
fix missing labels in datasets

Matthias Reso, 1 year ago
commit 10f9367e56

+ 1 - 0
examples/custom_dataset.py

@@ -86,5 +86,6 @@ def get_custom_dataset(dataset_config, tokenizer, split):
             
     dataset = dataset.map(lambda x: to_dialog(x["thread"]), remove_columns=list(dataset.features))
     dataset = dataset.map(lambda x: tokenize_dialog(x["dialog"], tokenizer), remove_columns=list(dataset.features))
+    dataset = dataset.map(lambda x: dict(x, labels=x["input_ids"].copy()), remove_columns=list(dataset.features))
     
     return dataset
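
Note (not part of the commit): the added map call copies input_ids into a new labels column, which the causal-LM loss requires; LlamaForCausalLM shifts the labels by one position internally, so a verbatim copy is sufficient. The same pattern is applied to samsum_dataset.py below. A minimal sketch of the equivalent step, written with a named function instead of a lambda:

    # Sketch only, equivalent to the added lambda: mirror input_ids into labels.
    def add_labels(sample):
        # The model shifts labels internally, so an exact copy of input_ids works.
        sample["labels"] = sample["input_ids"].copy()
        return sample

    dataset = dataset.map(add_labels)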

+ 1 - 0
src/llama_recipes/datasets/samsum_dataset.py

@@ -29,4 +29,5 @@ def get_preprocessed_samsum(dataset_config, tokenizer, split):
         lambda sample: tokenizer(sample["text"]),
         remove_columns=list(dataset.features),
     )
+    dataset = dataset.map(lambda x: dict(x, labels=x["input_ids"].copy()),remove_columns=list(dataset.features))
     return dataset

+ 2 - 2
src/llama_recipes/utils/config_utils.py

@@ -12,7 +12,7 @@ from peft import (
     PrefixTuningConfig,
 )
 from transformers import default_data_collator
-from transformers.data import DataCollatorWithPadding
+from transformers.data import DataCollatorForSeq2Seq
 
 from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
 from llama_recipes.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler
@@ -81,6 +81,6 @@ def get_sampler_kwargs(train_config, dataset, tokenizer, mode):
             )
         else:
             kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
-        kwargs["collate_fn"] = DataCollatorWithPadding(tokenizer)
+        kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
             
         return kwargs
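
Note (not part of the commit, a hedged sketch): DataCollatorWithPadding only pads input_ids and attention_mask, so the new ragged labels column would break batching; DataCollatorForSeq2Seq also pads labels, using label_pad_token_id (-100 by default) so padded positions are ignored by the loss. Illustration with a stand-in gpt2 tokenizer (the recipes themselves load the Llama tokenizer in finetuning.py):

    from transformers import AutoTokenizer, DataCollatorForSeq2Seq

    # gpt2 is only a stand-in for this example; it has no pad token by default.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token

    collator = DataCollatorForSeq2Seq(tokenizer)
    features = [
        {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3]},
        {"input_ids": [1, 2], "attention_mask": [1, 1], "labels": [1, 2]},
    ]
    batch = collator(features)
    # batch["labels"] becomes [[1, 2, 3], [1, 2, -100]]: padded label positions
    # are masked out of the loss, while input_ids are padded with the pad token.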

+ 1 - 1
src/llama_recipes/utils/train_utils.py

@@ -75,7 +75,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if train_config.enable_fsdp:
                         batch[key] = batch[key].to(local_rank)
                     else:
-                        batch[key] = batch[key].to('cuda:0')              
+                        batch[key] = batch[key].to('cuda:0')
                 loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
                 total_loss += loss.detach().float()

+ 4 - 0
tests/datasets/test_custom_dataset.py

@@ -46,6 +46,10 @@ def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
     STRING = tokenizer.decode(next(it)["input_ids"][0], skip_special_tokens=True)
     EXPECTED_STRING = "[INST] Implementa el algoritmo `bubble sort` en C. [/INST] xdxdxd"
     assert STRING.startswith(EXPECTED_STRING)
+    
+    assert "labels" in next(iter(train_dataloader)).keys()
+    assert "input_ids" in next(iter(train_dataloader)).keys()
+    assert "attention_mask" in next(iter(train_dataloader)).keys()
 
 
 @patch('llama_recipes.finetuning.train')

+ 9 - 4
tests/datasets/test_samsum_datasets.py

@@ -6,16 +6,18 @@ from unittest.mock import patch
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+# @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_samsum_dataset(step_lr, optimizer, tokenizer, get_model, train, mocker):
+def test_samsum_dataset(step_lr, optimizer, get_model, train, mocker):
+# def test_samsum_dataset(step_lr, optimizer, tokenizer, get_model, train, mocker):
     from llama_recipes.finetuning import main
         
-    tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})
+    # tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})
     
     BATCH_SIZE = 8
     kwargs = {
+        "model_name": "decapoda-research/llama-7b-hf",
         "batch_size_training": 8,
         "val_batch_size": 1,
         "use_peft": False,
@@ -35,4 +37,7 @@ def test_samsum_dataset(step_lr, optimizer, tokenizer, get_model, train, mocker)
     
     assert len(train_dataloader) == TRAIN_SAMPLES // BATCH_SIZE
     assert len(eval_dataloader) == VAL_SAMPLES
-    
+    
+    assert "labels" in next(iter(train_dataloader)).keys()
+    assert "input_ids" in next(iter(train_dataloader)).keys()
+    assert "attention_mask" in next(iter(train_dataloader)).keys()
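
Aside (not part of the commit): each next(iter(train_dataloader)) call collates a fresh batch, so the three key checks added in both tests could equally be written against a single batch:

    batch = next(iter(train_dataloader))
    assert {"input_ids", "attention_mask", "labels"} <= set(batch.keys())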