Use new chat format in custom dataset

Matthias Reso committed 10 months ago
commit 8b0a233c1a

+ 21 - 5
recipes/finetuning/datasets/custom_dataset.py

@@ -11,11 +11,27 @@ import itertools
 B_INST, E_INST = "[INST]", "[/INST]"
 
 def tokenize_dialog(dialog, tokenizer):
-    prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST}", add_special_tokens=False) for prompt in dialog[::2]]
-    answer_tokens = [tokenizer.encode(f"{answer['content'].strip()} {tokenizer.eos_token}", add_special_tokens=False) for answer in dialog[1::2]]
-    dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens)))
-    #Add labels, convert prompt token to -100 in order to ignore in loss function
-    labels_tokens = [len(c)*[-100,] if i % 2 == 0 else c for i,c in enumerate(dialog_tokens)]
+    if tokenizer.vocab_size >= 128000:
+        dialog_tokens = tokenizer.apply_chat_template(dialog)
+        dialog_tokens = dialog_tokens[:-4] # Remove generation prompt <|start_header_id|>assistant<|end_header_id|>\n\n
+        eot_indices = [i for i,n in enumerate(dialog_tokens) if n == 128009]
+        labels = copy.copy(dialog_tokens)
+        last_idx = 0
+        for n, idx in enumerate(eot_indices):
+            if n % 2 == 1:
+                last_idx = idx
+            else:
+                labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
+
+        dialog_tokens = [dialog_tokens]
+        labels_tokens = [labels]
+    else:
+        prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST}", add_special_tokens=False) for prompt in dialog[::2]]
+        answer_tokens = [tokenizer.encode(f"{answer['content'].strip()} {tokenizer.eos_token}", add_special_tokens=False) for answer in dialog[1::2]]
+        dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens)))
+
+        #Add labels, convert prompt token to -100 in order to ignore in loss function
+        labels_tokens = [len(c)*[-100,] if i % 2 == 0 else c for i,c in enumerate(dialog_tokens)]
 
     combined_tokens = {
         "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),

+ 12 - 5
tests/datasets/test_custom_dataset.py

@@ -12,8 +12,8 @@ EXPECTED_RESULTS={
         "example_2": "[INST] Quiero preparar una pizza de pepperoni, puedes darme los pasos para hacerla? [/INST] Claro!",
     },
     "hsramall/hsramall-7b-hf":{
-        "example_1": "[INST] こんにちは! [/INST]こんにちは!",
-        "example_2": "[INST] Как появляются деньги в экономике? Я знаю, что центробанк страны обычно регулирует базовую ставку валюты, но",
+        "example_1": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nWho made Berlin<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\ndunno<|eot_id|><|end_of_text|>",
+        "example_2": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow to start learning guitar and become a master at it?",
     },
 }
 
@@ -21,8 +21,13 @@ def check_padded_entry(batch, tokenizer):
     seq_len = sum(batch["attention_mask"][0])
     assert seq_len < len(batch["attention_mask"][0])
 
+    if tokenizer.vocab_size >= 128000:
+        END_OF_TEXT_ID = 128009
+    else:
+        END_OF_TEXT_ID = tokenizer.eos_token_id
+
     assert batch["labels"][0][0] == -100
-    assert batch["labels"][0][seq_len-1] == tokenizer.eos_token_id
+    assert batch["labels"][0][seq_len-1] == END_OF_TEXT_ID
     assert batch["labels"][0][-1] == -100
     assert batch["input_ids"][0][0] == tokenizer.bos_token_id
     assert batch["input_ids"][0][-1] == tokenizer.eos_token_id
@@ -39,6 +44,8 @@ def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
 
     setup_tokenizer(tokenizer)
 
+    skip_special_tokens = llama_version == "meta-llama/Llama-2-7b-hf"
+
     kwargs = {
         "dataset": "custom_dataset",
         "model_name": llama_version,
@@ -64,7 +71,7 @@ def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
 
     it = iter(eval_dataloader)
     batch = next(it)
-    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
+    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=skip_special_tokens)
     assert STRING.startswith(EXPECTED_RESULTS[llama_version]["example_1"])
 
     assert batch["input_ids"].size(0) == 4
@@ -76,7 +83,7 @@ def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
     next(it)
 
     batch = next(it)
-    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
+    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=skip_special_tokens)
     assert STRING.startswith(EXPECTED_RESULTS[llama_version]["example_2"])
 
     assert batch["input_ids"].size(0) == 2