Prechádzať zdrojové kódy

[WIP] Preprocess oasst dataset

Matthias Reso 1 rok pred
rodič
commit
32cf7ad459
1 zmenil súbory, kde vykonal 76 pridanie a 19 odobranie
  1. 76 19
      examples/custom_dataset.py

+ 76 - 19
examples/custom_dataset.py

@@ -3,31 +3,88 @@
 
 # For dataset details visit: https://huggingface.co/datasets/samsum
 
+import copy
 import datasets
 
 from llama_recipes.datasets.utils import Concatenator
 
 def get_custom_dataset(dataset_config, tokenizer, split):
-    dataset = datasets.load_dataset("samsum", split=split)
+    dataset = datasets.load_dataset("OpenAssistant/oasst1", split=split)
+    
+    dataset = dataset.map(lambda sample: {
+        "message_id": sample["message_id"],
+        "parent_id": sample["parent_id"],
+        "text": sample["text"],
+        },
+        batched=True,
+        remove_columns=list(dataset.features),)
+        
+    # print(ids[0])
+    
+    p2c = {}
+    
+    ids2text = {}
+    root_ids = []
+    
+    for data in dataset:
+        if data["parent_id"]:
+            p2c[data["parent_id"]] = p2c.get(data["parent_id"], []) + [data["message_id"]]
+        else:
+            root_ids.append(data["message_id"])
+        ids2text[data["message_id"]]=data["text"]
+           
+    def follow(thread, current_id):
+        thread = copy.copy(thread) + [ids2text[current_id]]
+        if current_id in p2c:
+            new_threads = []
+            for next_id in p2c[current_id]:
+                new_threads += follow(thread, next_id)
+            return new_threads
+        else:
+            return [thread]
+        
+        
+    def get_threads_from_root(root_id):
+        all_threads = []
+        thread = [ids2text[root_id]]
+        for cid in p2c[root_id]:
+            all_threads += follow(thread, cid)
+        return all_threads
+        
+        
+    # all_threads = []
+    # for rid in root_ids:
+        
+            
+    dataset = dataset.filter(lambda x: x["message_id"] in root_ids)
+    dataset = dataset.map(lambda x: {"thread": get_threads_from_root(x["message_id"])}, remove_columns=list(dataset.features))
+    dataset = dataset.map(lambda x: {"thread": [i for row in x["thread"] for i in row]}, batched=True)
+            
+    print(len(dataset))
+    from pprint import pprint
+    pprint(dataset[:10])
+    
+    return dataset
+    # threads={}
 
-    prompt = (
-        f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n{{summary}}{{eos_token}}"
-    )
+    # prompt = (
+    #     f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n{{summary}}{{eos_token}}"
+    # )
 
-    def apply_prompt_template(sample):
-        return {
-            "text": prompt.format(
-                dialog=sample["dialogue"],
-                summary=sample["summary"],
-                eos_token=tokenizer.eos_token,
-            )
-        }
+    # def apply_prompt_template(sample):
+    #     return {
+    #         "text": prompt.format(
+    #             dialog=sample["dialogue"],
+    #             summary=sample["summary"],
+    #             eos_token=tokenizer.eos_token,
+    #         )
+    #     }
 
-    dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
+    # dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
         
-    dataset = dataset.map(
-        lambda sample: tokenizer(sample["text"]),
-        batched=True,
-        remove_columns=list(dataset.features),
-    ).map(Concatenator(), batched=True)
-    return dataset
+    # dataset = dataset.map(
+    #     lambda sample: tokenizer(sample["text"]),
+    #     batched=True,
+    #     remove_columns=list(dataset.features),
+    # ).map(Concatenator(), batched=True)
+    # return dataset