
Fix invalid labels for context in custom dataset/oasst1

Matthias Reso 1 year ago
parent commit 8620ab8ac2
2 changed files with 45 additions and 24 deletions
  1. examples/custom_dataset.py (+7 -8)
  2. tests/datasets/test_custom_dataset.py (+38 -16)

+ 7 - 8
examples/custom_dataset.py

@@ -7,23 +7,22 @@ import copy
 import datasets
 import itertools
 
-from llama_recipes.datasets.utils import Concatenator
-
 
 B_INST, E_INST = "[INST]", "[/INST]"
 
 def tokenize_dialog(dialog, tokenizer):
     prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST}", add_special_tokens=False) for prompt in dialog[::2]]
     answer_tokens = [tokenizer.encode(f"{answer['content'].strip()} {tokenizer.eos_token}", add_special_tokens=False) for answer in dialog[1::2]]
-    answer_tokens = [{k:v[1:] for k,v in items.items()} for items in answer_tokens]
     dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens)))
     #Add labels, convert prompt token to -100 in order to ignore in loss function
-    dialog_tokens = [dict(c, labels=len(c["input_ids"])*[-100,]if i % 2 == 0 else c["input_ids"]) for i,c in enumerate(dialog_tokens)]
+    labels_tokens = [len(c)*[-100,] if i % 2 == 0 else c for i,c in enumerate(dialog_tokens)]
+
+    combined_tokens = {
+        "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
+        "labels": list(itertools.chain(*(t for t in labels_tokens))),
+    }
 
-    combined_tokens = {}
-    for k in dialog_tokens[0].keys():
-        combined_tokens[k] = list(itertools.chain(*(t[k] for t in dialog_tokens)))
-    return combined_tokens
+    return dict(combined_tokens, attention_mask=[1]*len(combined_tokens["input_ids"]))
 
 
 def get_custom_dataset(dataset_config, tokenizer, split):
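
The reworked tokenize_dialog drops the per-turn dicts and works on plain token-id lists: prompt and answer turns are interleaved, prompt turns contribute -100 labels so the loss ignores them, and the attention mask is rebuilt from the final length. Below is a minimal sketch of that masking pattern using a stand-in encoder instead of the real Llama tokenizer (fake_encode and the dialog content are illustrative, not from the repository):

import itertools

def fake_encode(text):
    # Stand-in for tokenizer.encode(); any function returning a list of ids works here.
    return [ord(c) % 100 for c in text]

def tokenize_dialog_sketch(dialog):
    # Even turns are user prompts, odd turns are assistant answers.
    prompt_tokens = [fake_encode(f"[INST] {turn['content']} [/INST]") for turn in dialog[::2]]
    answer_tokens = [fake_encode(turn["content"]) for turn in dialog[1::2]]
    dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens)))

    # Prompt turns (even positions) get -100 labels and are ignored by the loss;
    # answer turns keep their token ids as labels.
    labels_tokens = [len(t) * [-100] if i % 2 == 0 else t for i, t in enumerate(dialog_tokens)]

    input_ids = list(itertools.chain(*dialog_tokens))
    labels = list(itertools.chain(*labels_tokens))
    return {"input_ids": input_ids, "labels": labels, "attention_mask": [1] * len(input_ids)}

sample = tokenize_dialog_sketch([{"content": "Who made Berlin"}, {"content": "dunno"}])
assert len(sample["input_ids"]) == len(sample["labels"]) == len(sample["attention_mask"])
assert sample["labels"][:3] == [-100, -100, -100]   # prompt tokens are masked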

+ 38 - 16
tests/datasets/test_custom_dataset.py

@@ -4,14 +4,33 @@
 import pytest
 from unittest.mock import patch
 
+from transformers import LlamaTokenizer
+
+def check_padded_entry(batch):
+    seq_len = sum(batch["attention_mask"][0])
+    assert seq_len < len(batch["attention_mask"][0])
+
+    assert batch["labels"][0][0] == -100
+    assert batch["labels"][0][seq_len-1] == 2
+    assert batch["labels"][0][-1] == -100
+    assert batch["input_ids"][0][0] == 1
+    assert batch["input_ids"][0][-1] == 2
+
 
 @patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
+def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker):
     from llama_recipes.finetuning import main
 
+    #Align with Llama 2 tokenizer
+    tokenizer.from_pretrained.return_value = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
+    tokenizer.from_pretrained.return_value.add_special_tokens({'bos_token': '<s>', 'eos_token': '</s>'})
+    tokenizer.from_pretrained.return_value.bos_token_id = 1
+    tokenizer.from_pretrained.return_value.eos_token_id = 2
+
     kwargs = {
         "dataset": "custom_dataset",
         "model_name": "decapoda-research/llama-7b-hf", # We use the tokenizer as a surrogate for llama2 tokenizer here
@@ -20,6 +39,7 @@ def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
         "batch_size_training": 2,
         "val_batch_size": 4,
         "use_peft": False,
+        "batching_strategy": "padding"
         }
 
     main(**kwargs)
@@ -35,28 +55,30 @@ def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
     assert len(eval_dataloader) == 1120 //2
 
     it = iter(eval_dataloader)
-    STRING = tokenizer.decode(next(it)["input_ids"][0], skip_special_tokens=True)
+    batch = next(it)
+    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
     EXPECTED_STRING = "[INST] Who made Berlin [/INST] dunno"
     assert STRING.startswith(EXPECTED_STRING)
 
-    # assert next(it)["input_ids"].size(0) == 4
-    # it = iter(train_dataloader)
-    # entry = next(it)
-    # STRING = tokenizer.decode(entry["input_ids"][0], skip_special_tokens=True)
-    # EXPECTED_STRING = "[INST] Напиши функцию на языке swift, которая сортирует массив целых чисел, а затем выводит его на экран [/INST] Вот функция, "
+    assert batch["input_ids"].size(0) == 4
+    assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())
 
-    # assert STRING.startswith(EXPECTED_STRING)
-    # assert entry["labels"][0][:10].tolist() == 10*[-100]
+    check_padded_entry(batch)
 
-    next(it)
-    next(it)
-    STRING = tokenizer.decode(next(it)["input_ids"][0], skip_special_tokens=True)
-    EXPECTED_STRING = "[INST] Implementa el algoritmo `bubble sort` en C. [/INST] xdxdxd"
+    it = iter(train_dataloader)
+    for _ in range(5):
+        next(it)
+
+    batch = next(it)
+    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
+    EXPECTED_STRING = "[INST] How do I initialize a Typescript project using npm and git? [/INST] # Initialize a new NPM project"
     assert STRING.startswith(EXPECTED_STRING)
 
-    assert "labels" in next(iter(train_dataloader)).keys()
-    assert "input_ids" in next(iter(train_dataloader)).keys()
-    assert "attention_mask" in next(iter(train_dataloader)).keys()
+    assert batch["input_ids"].size(0) == 2
+    assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())
+
+    check_padded_entry(batch)
+
 
 
 @patch('llama_recipes.finetuning.train')
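
For reference, check_padded_entry above encodes what a right-padded batch from the "padding" batching strategy is expected to look like: the attention mask marks the real tokens, the labels start with -100 (masked prompt), stay -100 over the padding, and the last real label is the EOS id 2, while input_ids begin with the BOS id 1 and are padded with id 2. A hand-made batch that satisfies those checks (the token ids are illustrative, not taken from the oasst1 data; the real dataloader yields tensors, but plain lists behave the same for these assertions):

batch = {
    "input_ids":      [[1, 11, 12, 13, 14, 2, 2, 2]],                # BOS ... EOS, padded with id 2
    "labels":         [[-100, -100, -100, 13, 14, 2, -100, -100]],   # prompt and padding masked with -100
    "attention_mask": [[1, 1, 1, 1, 1, 1, 0, 0]],                    # 1 for real tokens, 0 for padding
}

seq_len = sum(batch["attention_mask"][0])
assert seq_len < len(batch["attention_mask"][0])   # the entry really is padded
assert batch["labels"][0][0] == -100               # first prompt token is masked
assert batch["labels"][0][seq_len - 1] == 2        # last real label is the EOS id
assert batch["labels"][0][-1] == -100              # padding is masked in the labels
assert batch["input_ids"][0][0] == 1               # sequence starts with BOS
assert batch["input_ids"][0][-1] == 2              # padding uses token id 2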