
Use OpenAssistant/oasst1 dataset for custom dataset example (#180)

Geeta Chauhan · 1 year ago · commit 279f4d4a0b
5 changed files with 116 additions and 42 deletions
  1. README.md (+1 -1)
  2. docs/Dataset.md (+2 -1)
  3. examples/custom_dataset.py (+78 -20)
  4. scripts/spellcheck_conf/wordlist.txt (+4 -1)
  5. tests/datasets/test_custom_dataset.py (+31 -19)

+ 1 - 1
README.md

File diff suppressed because it is too large


+ 2 - 1
docs/Dataset.md

File diff suppressed because it is too large


+ 78 - 20
examples/custom_dataset.py

@@ -3,31 +3,89 @@
 
 # For dataset details visit: https://huggingface.co/datasets/samsum
 
+import copy
 import datasets
+import itertools
 
 from llama_recipes.datasets.utils import Concatenator
 
-def get_custom_dataset(dataset_config, tokenizer, split):
-    dataset = datasets.load_dataset("samsum", split=split)
-
-    prompt = (
-        f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n{{summary}}{{eos_token}}"
-    )
-
-    def apply_prompt_template(sample):
-        return {
-            "text": prompt.format(
-                dialog=sample["dialogue"],
-                summary=sample["summary"],
-                eos_token=tokenizer.eos_token,
+
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+def tokenize_dialog(dialog, tokenizer):
+    dialog_tokens = [
+            tokenizer(
+                f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
             )
-        }
+            for prompt, answer in zip(dialog[::2], dialog[1::2])
+        ]
+    if len(dialog) % 2:    
+        dialog_tokens += [tokenizer(
+            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
+        )]
+    
+    combined_tokens = {}  
+    for k in dialog_tokens[0].keys():
+        combined_tokens[k] = list(itertools.chain(*(t[k] for t in dialog_tokens)))
+    return combined_tokens
 
-    dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
-        
-    dataset = dataset.map(
-        lambda sample: tokenizer(sample["text"]),
+
+def get_custom_dataset(dataset_config, tokenizer, split):
+    dataset = datasets.load_dataset("OpenAssistant/oasst1", split=split)
+    
+    dataset = dataset.map(lambda sample: {
+        "message_id": sample["message_id"],
+        "parent_id": sample["parent_id"],
+        "text": sample["text"],
+        },
         batched=True,
-        remove_columns=list(dataset.features),
-    ).map(Concatenator(), batched=True)
+        remove_columns=list(dataset.features),)
+    
+    nodes = {}
+    
+    messages = {}
+    root_ids = []
+    
+    for data in dataset:
+        if data["parent_id"]:
+            nodes[data["parent_id"]] = nodes.get(data["parent_id"], []) + [data["message_id"]]
+        else:
+            root_ids.append(data["message_id"])
+        messages[data["message_id"]]=data["text"]
+           
+    def follow(thread, current_id):
+        thread = copy.copy(thread) + [messages[current_id]]
+        if current_id in nodes:
+            new_threads = []
+            for next_id in nodes[current_id]:
+                new_threads += follow(thread, next_id)
+            return new_threads
+        else:
+            return [thread]
+        
+    def get_threads_from_root(root_id):
+        all_threads = []
+        thread = [messages[root_id]]
+        for cid in nodes[root_id]:
+            all_threads += follow(thread, cid)
+        return all_threads
+            
+    dataset = dataset.filter(lambda x: x["message_id"] in root_ids)
+    dataset = dataset.map(lambda x: {"thread": get_threads_from_root(x["message_id"])}, remove_columns=list(dataset.features))
+    dataset = dataset.map(lambda x: {"thread": [i for row in x["thread"] for i in row]}, batched=True)
+    
+    def to_dialog(thread):
+        dialog = []
+        for i, content in enumerate(thread):
+            dialog.append({
+                "role": "user" if i % 2 == 0 else "assistant",
+                "content": content,
+            })
+        return {"dialog": dialog}
+            
+    dataset = dataset.map(lambda x: to_dialog(x["thread"]), remove_columns=list(dataset.features))
+    dataset = dataset.map(lambda x: tokenize_dialog(x["dialog"], tokenizer), remove_columns=list(dataset.features))
+    dataset = dataset.map(Concatenator(), batched=True)
+    
     return dataset
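
The new tokenize_dialog helper packs each (user, assistant) turn pair into Llama 2's [INST] ... [/INST] chat format and chains the per-turn encodings into one flat sample. Below is a minimal, self-contained sketch of that logic; toy_tokenizer is a hypothetical character-level stand-in for the real LlamaTokenizer so the chaining can be run without model weights:

import itertools

B_INST, E_INST = "[INST]", "[/INST]"

def toy_tokenizer(text):
    # Hypothetical stand-in: maps each character to a fake id.
    # A real LlamaTokenizer returns SentencePiece ids instead.
    return {"input_ids": [ord(c) for c in text],
            "attention_mask": [1] * len(text)}

def tokenize_dialog(dialog, tokenizer):
    # Pair even-indexed user turns with odd-indexed assistant turns
    # and wrap each pair in [INST] ... [/INST].
    dialog_tokens = [
        tokenizer(f"{B_INST} {p['content'].strip()} {E_INST} {a['content'].strip()} ")
        for p, a in zip(dialog[::2], dialog[1::2])
    ]
    if len(dialog) % 2:
        # An unpaired trailing user turn becomes an open prompt.
        dialog_tokens.append(
            tokenizer(f"{B_INST} {dialog[-1]['content'].strip()} {E_INST}"))
    # Chain the per-turn encodings into one flat sample.
    return {k: list(itertools.chain(*(t[k] for t in dialog_tokens)))
            for k in dialog_tokens[0]}

dialog = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
    {"role": "user", "content": "Bye"},
]
out = tokenize_dialog(dialog, toy_tokenizer)
print("".join(chr(i) for i in out["input_ids"]))
# [INST] Hi [/INST] Hello! [INST] Bye [/INST]

Note that zip(dialog[::2], dialog[1::2]) silently drops nothing here because a trailing user turn is handled by the len(dialog) % 2 branch, which leaves an open [INST] prompt with no answer.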
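
The thread-building half of get_custom_dataset rebuilds linear conversations from oasst1's message tree: each row carries a message_id and an optional parent_id, and every root-to-leaf path becomes one thread. A sketch of that depth-first expansion on a toy tree (the ids and texts below are illustrative, not real oasst1 rows):

import copy

rows = [
    {"message_id": "root", "parent_id": None,   "text": "Q"},
    {"message_id": "a1",   "parent_id": "root", "text": "A1"},
    {"message_id": "a2",   "parent_id": "root", "text": "A2"},
    {"message_id": "q1",   "parent_id": "a1",   "text": "Q follow-up"},
]

# Index children by parent, texts by id, and collect the roots.
nodes, messages, root_ids = {}, {}, []
for row in rows:
    if row["parent_id"]:
        nodes.setdefault(row["parent_id"], []).append(row["message_id"])
    else:
        root_ids.append(row["message_id"])
    messages[row["message_id"]] = row["text"]

def follow(thread, current_id):
    # Depth-first walk: each leaf yields one linear conversation.
    thread = copy.copy(thread) + [messages[current_id]]
    if current_id in nodes:
        new_threads = []
        for next_id in nodes[current_id]:
            new_threads += follow(thread, next_id)
        return new_threads
    return [thread]

def get_threads_from_root(root_id):
    all_threads = []
    for child_id in nodes[root_id]:
        all_threads += follow([messages[root_id]], child_id)
    return all_threads

print(get_threads_from_root("root"))
# [['Q', 'A1', 'Q follow-up'], ['Q', 'A2']]

Because follow recurses into every child, a root with two answers yields two threads that share a common prefix; to_dialog then assigns user/assistant roles by position in the thread (even indices are user turns).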

+ 4 - 1
scripts/spellcheck_conf/wordlist.txt

@@ -1147,4 +1147,7 @@ HuggingFace's
 LoRA
 bitsandbytes
 CLA
-dialogs
+dialogs
+OpenAssistant
+oasst1
+oasst

+ 31 - 19
tests/datasets/test_custom_dataset.py

@@ -7,36 +7,48 @@ from unittest.mock import patch
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_custom_dataset(step_lr, optimizer, tokenizer, get_model, train, mocker):
+def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
     from llama_recipes.finetuning import main
-        
-    tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})
-    
+
     kwargs = {
         "dataset": "custom_dataset",
+        "model_name": "decapoda-research/llama-7b-hf", # We use the tokenizer as a surrogate for llama2 tokenizer here
         "custom_dataset.file": "examples/custom_dataset.py",
-        "batch_size_training": 1,
+        "custom_dataset.train_split": "validation",
+        "batch_size_training": 2,
         "use_peft": False,
         }
-    
+
     main(**kwargs)
-    
+
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     train_dataloader = args[1]
     eval_dataloader = args[2]
-    
-    VAL_SAMPLES = 818
-    TRAIN_SAMPLES = 14732
-    CONCAT_SIZE = 2048
-    
-    assert len(train_dataloader) == TRAIN_SAMPLES // CONCAT_SIZE
-    assert len(eval_dataloader) == VAL_SAMPLES
-    
+    tokenizer = args[3]
+
+    assert len(train_dataloader) == 226
+    assert len(eval_dataloader) == 2*226
+
+    it = iter(train_dataloader)
+    STRING = tokenizer.decode(next(it)["input_ids"][0], skip_special_tokens=True)
+    EXPECTED_STRING = "[INST] Напиши функцию на языке swift, которая сортирует массив целых чисел, а затем выводит его на экран [/INST] Вот функция, "
+
+    assert STRING.startswith(EXPECTED_STRING)
+
+    next(it)
+    next(it)
+    next(it)
+    STRING = tokenizer.decode(next(it)["input_ids"][0], skip_special_tokens=True)
+    EXPECTED_SUBSTRING_1 = "Therefore you are correct.  [INST] How can L’Hopital’s Rule be"
+    EXPECTED_SUBSTRING_2 = "a circular path around the turn.  [INST] How on earth is that related to L’Hopital’s Rule?"
+
+    assert EXPECTED_SUBSTRING_1 in STRING
+    assert EXPECTED_SUBSTRING_2 in STRING
+
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@@ -45,9 +57,9 @@ def test_custom_dataset(step_lr, optimizer, tokenizer, get_model, train, mocker)
 @patch('llama_recipes.finetuning.StepLR')
 def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, train, mocker):
     from llama_recipes.finetuning import main
-        
+
     tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})
-    
+
     kwargs = {
         "dataset": "custom_dataset",
         "custom_dataset.file": "examples/custom_dataset.py:get_unknown_dataset",