
Mask out context in labels for samsum + grammar

Matthias Reso · 1 year ago · commit 653a79e3dd

+ 10 - 12
src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py

@@ -47,22 +47,20 @@ class grammar(Dataset):
         input_ = example_batch["input"]
         target_ = example_batch["target"]
 
-        prompt = f"Correct this to standard English: {input_}\n---\nCorrected: {target_}"
-        sample = self.tokenizer(prompt)
+        prompt = f"Correct this to standard English: {input_}\n---\nCorrected: "
+        prompt_ids = self.tokenizer.encode(self.tokenizer.bos_token + prompt, add_special_tokens=False)
+        label_ids = self.tokenizer.encode(target_ + self.tokenizer.eos_token, add_special_tokens=False)
+
+        sample = {
+            "input_ids": prompt_ids + label_ids,
+            "attention_mask": [1] * len(prompt_ids + label_ids),
+            "labels": [-100] * len(prompt_ids) + label_ids
+        }
 
         return sample
 
     def __getitem__(self, index):
-        sample = self.convert_to_features(self.dataset["train"][int(index)])
-        source_ids = sample["input_ids"]
-
-        src_mask = sample["attention_mask"]
-
-        return {
-            "input_ids": source_ids,
-            "attention_mask": src_mask,
-            "labels": source_ids.copy(),
-        }
+        return self.convert_to_features(self.dataset["train"][int(index)])
 
 
 def get_dataset(
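
For context, the pattern introduced above can be restated outside the diff. The sketch below is a minimal, hypothetical reimplementation (not part of the commit), assuming any Hugging Face tokenizer that exposes bos_token and eos_token, e.g. LlamaTokenizer:

    def mask_prompt_labels(tokenizer, prompt: str, target: str) -> dict:
        # Tokenize prompt and target separately so the prompt length is known.
        prompt_ids = tokenizer.encode(tokenizer.bos_token + prompt, add_special_tokens=False)
        label_ids = tokenizer.encode(target + tokenizer.eos_token, add_special_tokens=False)

        input_ids = prompt_ids + label_ids
        return {
            "input_ids": input_ids,
            "attention_mask": [1] * len(input_ids),
            # -100 is the ignore index of PyTorch's cross-entropy loss, so only
            # the target tokens contribute to the loss during fine-tuning.
            "labels": [-100] * len(prompt_ids) + label_ids,
        }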

+ 18 - 11
src/llama_recipes/datasets/samsum_dataset.py

@@ -3,6 +3,7 @@
 
 # For dataset details visit: https://huggingface.co/datasets/samsum
 
+import copy
 import datasets
 
 
@@ -10,23 +11,29 @@ def get_preprocessed_samsum(dataset_config, tokenizer, split):
     dataset = datasets.load_dataset("samsum", split=split)
 
     prompt = (
-        f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n{{summary}}{{eos_token}}"
+        f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n"
     )
 
     def apply_prompt_template(sample):
         return {
-            "text": prompt.format(
-                dialog=sample["dialogue"],
-                summary=sample["summary"],
-                eos_token=tokenizer.eos_token,
-            )
+            "prompt": prompt.format(dialog=sample["dialogue"]),
+            "summary": sample["summary"],
         }
 
     dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
 
-    dataset = dataset.map(
-        lambda sample: tokenizer(sample["text"]),
-        remove_columns=list(dataset.features),
-    )
-    dataset = dataset.map(lambda x: dict(x, labels=x["input_ids"].copy()),remove_columns=list(dataset.features))
+    def tokenize_add_label(sample):
+        prompt = tokenizer.encode(tokenizer.bos_token + sample["prompt"], add_special_tokens=False)
+        summary = tokenizer.encode(sample["summary"] +  tokenizer.eos_token, add_special_tokens=False)
+
+        sample = {
+            "input_ids": prompt + summary,
+            "attention_mask" : [1] * (len(prompt) + len(summary)),
+            "labels": [-100] * len(prompt) + summary,
+            }
+
+        return sample
+
+    dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
+
     return dataset
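
The same masking scheme is applied to samsum. The value -100 is the default ignore_index of PyTorch's cross-entropy loss; the small standalone check below (hypothetical tensors, not from the recipe) illustrates that masked positions are simply skipped:

    import torch
    import torch.nn.functional as F

    vocab_size = 10
    logits = torch.randn(1, 6, vocab_size)                 # (batch, seq, vocab)
    labels = torch.tensor([[-100, -100, -100, 4, 7, 2]])   # prompt masked, summary + EOS kept

    # ignore_index defaults to -100, so only the last three positions enter the mean.
    loss = F.cross_entropy(logits.view(-1, vocab_size), labels.view(-1))
    print(loss)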

+ 2 - 0
src/llama_recipes/finetuning.py

@@ -5,6 +5,7 @@ import os
 from pkg_resources import packaging
 
 import fire
+import random
 import torch
 import torch.optim as optim
 from peft import get_peft_model, prepare_model_for_int8_training
@@ -51,6 +52,7 @@ def main(**kwargs):
     # Set the seeds for reproducibility
     torch.cuda.manual_seed(train_config.seed)
     torch.manual_seed(train_config.seed)
+    random.seed(train_config.seed)
 
     if train_config.enable_fsdp:
         setup()
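
Seeding Python's random module alongside the torch generators keeps any remaining Python-level randomness reproducible. A minimal helper equivalent to the three calls above (hypothetical, not part of the commit) could look like:

    import random
    import torch

    def set_seed(seed: int) -> None:
        # Mirror the seeding done in finetuning.main(): Python's RNG plus the
        # CPU and CUDA generators in PyTorch.
        random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)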

+ 30 - 12
tests/datasets/test_grammar_datasets.py

@@ -3,15 +3,23 @@
 
 from unittest.mock import patch
 
+from transformers import LlamaTokenizer
+
 
 @patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_grammar_dataset(step_lr, optimizer, get_model, train, mocker):
-# def test_samsum_dataset(step_lr, optimizer, tokenizer, get_model, train, mocker):
+def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker):
     from llama_recipes.finetuning import main
-    
+
+    #Align with Llama 2 tokenizer
+    tokenizer.from_pretrained.return_value = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
+    tokenizer.from_pretrained.return_value.add_special_tokens({'bos_token': '<s>', 'eos_token': '</s>'})
+    tokenizer.from_pretrained.return_value.bos_token_id = 1
+    tokenizer.from_pretrained.return_value.eos_token_id = 2
+
     BATCH_SIZE = 8
     kwargs = {
         "model_name": "decapoda-research/llama-7b-hf",
@@ -19,22 +27,32 @@ def test_grammar_dataset(step_lr, optimizer, get_model, train, mocker):
         "val_batch_size": 1,
         "use_peft": False,
         "dataset": "grammar_dataset",
+        "batching_strategy": "padding",
         }
-    
+
     main(**kwargs)
-    
+
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     train_dataloader = args[1]
     eval_dataloader = args[2]
-    
+
     VAL_SAMPLES = 2988
     TRAIN_SAMPLES = 13016
-    
+
     assert len(train_dataloader) == TRAIN_SAMPLES // BATCH_SIZE
     assert len(eval_dataloader) == VAL_SAMPLES
-    
-    assert "labels" in next(iter(train_dataloader)).keys()
-    assert "input_ids" in next(iter(train_dataloader)).keys()
-    assert "attention_mask" in next(iter(train_dataloader)).keys()
+
+    batch = next(iter(train_dataloader))
+
+    assert "labels" in batch.keys()
+    assert "input_ids" in batch.keys()
+    assert "attention_mask" in batch.keys()
+
+    assert batch["labels"][0][29] == -100
+    assert batch["labels"][0][30] == 29871
+
+    assert batch["input_ids"][0][0] == 1
+    assert batch["labels"][0][-1] == 2
+    assert batch["input_ids"][0][-1] == 2

+ 30 - 15
tests/datasets/test_samsum_datasets.py

@@ -3,18 +3,23 @@
 
 from unittest.mock import patch
 
+from transformers import LlamaTokenizer
+
 
 @patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-# @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_samsum_dataset(step_lr, optimizer, get_model, train, mocker):
-# def test_samsum_dataset(step_lr, optimizer, tokenizer, get_model, train, mocker):
+def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker):
     from llama_recipes.finetuning import main
-        
-    # tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})
-    
+
+    #Align with Llama 2 tokenizer
+    tokenizer.from_pretrained.return_value = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
+    tokenizer.from_pretrained.return_value.add_special_tokens({'bos_token': '<s>', 'eos_token': '</s>'})
+    tokenizer.from_pretrained.return_value.bos_token_id = 1
+    tokenizer.from_pretrained.return_value.eos_token_id = 2
+
     BATCH_SIZE = 8
     kwargs = {
         "model_name": "decapoda-research/llama-7b-hf",
@@ -22,22 +27,32 @@ def test_samsum_dataset(step_lr, optimizer, get_model, train, mocker):
         "val_batch_size": 1,
         "use_peft": False,
         "dataset": "samsum_dataset",
+        "batching_strategy": "padding",
         }
-    
+
     main(**kwargs)
-    
+
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     train_dataloader = args[1]
     eval_dataloader = args[2]
-    
+
     VAL_SAMPLES = 818
     TRAIN_SAMPLES = 14732
-    
+
     assert len(train_dataloader) == TRAIN_SAMPLES // BATCH_SIZE
     assert len(eval_dataloader) == VAL_SAMPLES
-    
-    assert "labels" in next(iter(train_dataloader)).keys()
-    assert "input_ids" in next(iter(train_dataloader)).keys()
-    assert "attention_mask" in next(iter(train_dataloader)).keys()
+
+    batch = next(iter(train_dataloader))
+
+    assert "labels" in batch.keys()
+    assert "input_ids" in batch.keys()
+    assert "attention_mask" in batch.keys()
+
+    assert batch["labels"][0][268] == -100
+    assert batch["labels"][0][269] == 22291
+
+    assert batch["input_ids"][0][0] == 1
+    assert batch["labels"][0][-1] == 2
+    assert batch["input_ids"][0][-1] == 2