Browse Source

Finish implementation oasst preprocessing for of custom dataset

Matthias Reso 1 year ago
parent
commit
dc507b4e55
2 changed files with 55 additions and 53 deletions
  1. 45 44
      examples/custom_dataset.py
  2. 10 9
      tests/datasets/test_custom_dataset.py

+ 45 - 44
examples/custom_dataset.py

@@ -5,9 +5,32 @@
 
 import copy
 import datasets
+import itertools
 
 from llama_recipes.datasets.utils import Concatenator
 
+
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+def tokenize_dialog(dialog, tokenizer):
+    dialog_tokens = [
+            tokenizer(
+                f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
+            )
+            for prompt, answer in zip(dialog[::2], dialog[1::2])
+        ]
+    if len(dialog) % 2:    
+        dialog_tokens += [tokenizer(
+            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
+        )]
+    
+    combined_tokens = {}  
+    for k in dialog_tokens[0].keys():
+        combined_tokens[k] = [list(itertools.chain(*(t[k] for t in dialog_tokens)))]
+    return combined_tokens
+
+
 def get_custom_dataset(dataset_config, tokenizer, split):
     dataset = datasets.load_dataset("OpenAssistant/oasst1", split=split)
     
@@ -18,73 +41,51 @@ def get_custom_dataset(dataset_config, tokenizer, split):
         },
         batched=True,
         remove_columns=list(dataset.features),)
-        
-    # print(ids[0])
     
-    p2c = {}
+    nodes = {}
     
-    ids2text = {}
+    messages = {}
     root_ids = []
     
     for data in dataset:
         if data["parent_id"]:
-            p2c[data["parent_id"]] = p2c.get(data["parent_id"], []) + [data["message_id"]]
+            nodes[data["parent_id"]] = nodes.get(data["parent_id"], []) + [data["message_id"]]
         else:
             root_ids.append(data["message_id"])
-        ids2text[data["message_id"]]=data["text"]
+        messages[data["message_id"]]=data["text"]
            
     def follow(thread, current_id):
-        thread = copy.copy(thread) + [ids2text[current_id]]
-        if current_id in p2c:
+        thread = copy.copy(thread) + [messages[current_id]]
+        if current_id in nodes:
             new_threads = []
-            for next_id in p2c[current_id]:
+            for next_id in nodes[current_id]:
                 new_threads += follow(thread, next_id)
             return new_threads
         else:
             return [thread]
         
-        
     def get_threads_from_root(root_id):
         all_threads = []
-        thread = [ids2text[root_id]]
-        for cid in p2c[root_id]:
+        thread = [messages[root_id]]
+        for cid in nodes[root_id]:
             all_threads += follow(thread, cid)
         return all_threads
-        
-        
-    # all_threads = []
-    # for rid in root_ids:
-        
             
     dataset = dataset.filter(lambda x: x["message_id"] in root_ids)
     dataset = dataset.map(lambda x: {"thread": get_threads_from_root(x["message_id"])}, remove_columns=list(dataset.features))
     dataset = dataset.map(lambda x: {"thread": [i for row in x["thread"] for i in row]}, batched=True)
+    
+    def to_dialog(thread):
+        dialog = []
+        for i, content in enumerate(thread):
+            dialog.append({
+                "role": "user" if i % 2 == 0 else "assistant",
+                "content": content,
+            })
+        return {"dialog": dialog}
             
-    print(len(dataset))
-    from pprint import pprint
-    pprint(dataset[:10])
+    dataset = dataset.map(lambda x: to_dialog(x["thread"]), remove_columns=list(dataset.features))
+    dataset = dataset.map(lambda x: tokenize_dialog(x["dialog"], tokenizer), remove_columns=list(dataset.features))
+    dataset = dataset.map(Concatenator(), batched=True)
     
-    return dataset
-    # threads={}
-
-    # prompt = (
-    #     f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n{{summary}}{{eos_token}}"
-    # )
-
-    # def apply_prompt_template(sample):
-    #     return {
-    #         "text": prompt.format(
-    #             dialog=sample["dialogue"],
-    #             summary=sample["summary"],
-    #             eos_token=tokenizer.eos_token,
-    #         )
-    #     }
-
-    # dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
-        
-    # dataset = dataset.map(
-    #     lambda sample: tokenizer(sample["text"]),
-    #     batched=True,
-    #     remove_columns=list(dataset.features),
-    # ).map(Concatenator(), batched=True)
-    # return dataset
+    return dataset

+ 10 - 9
tests/datasets/test_custom_dataset.py

@@ -7,17 +7,16 @@ from unittest.mock import patch
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_custom_dataset(step_lr, optimizer, tokenizer, get_model, train, mocker):
+def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
     from llama_recipes.finetuning import main
-        
-    tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})
     
     kwargs = {
         "dataset": "custom_dataset",
+        "model_name": "decapoda-research/llama-7b-hf", # We use the tokenizer as a surrogate for llama2 tokenizer here
         "custom_dataset.file": "examples/custom_dataset.py",
+        "custom_dataset.train_split": "validation",
         "batch_size_training": 1,
         "use_peft": False,
         }
@@ -29,13 +28,15 @@ def test_custom_dataset(step_lr, optimizer, tokenizer, get_model, train, mocker)
     args, kwargs = train.call_args
     train_dataloader = args[1]
     eval_dataloader = args[2]
+    tokenizer = args[3]
+    
+    assert len(train_dataloader) == 2241
+    assert len(eval_dataloader) == 2241
     
-    VAL_SAMPLES = 818
-    TRAIN_SAMPLES = 14732
-    CONCAT_SIZE = 2048
+    STRING = tokenizer.decode(next(iter(train_dataloader))["input_ids"][0], skip_special_tokens=True)
+    EXPECTED_STRING = "[INST] Напиши функцию на языке swift, которая сортирует массив целых чисел, а затем выводит его на экран [/INST] Вот функция, "
     
-    assert len(train_dataloader) == TRAIN_SAMPLES // CONCAT_SIZE
-    assert len(eval_dataloader) == VAL_SAMPLES
+    assert STRING.startswith(EXPECTED_STRING)
     
 
 @patch('llama_recipes.finetuning.train')