
Adapt test_custom_dataset to new model

Matthias Reso · 10 months ago
commit fac41298b0
Changed 1 file with 19 additions and 10 deletions:
    tests/datasets/test_custom_dataset.py

tests/datasets/test_custom_dataset.py  (+19 −10)

@@ -6,15 +6,26 @@ from unittest.mock import patch
 
 from transformers import LlamaTokenizer
 
-def check_padded_entry(batch):
+EXPECTED_RESULTS={
+    "meta-llama/Llama-2-7b-hf":{
+        "example_1": "[INST] Who made Berlin [/INST] dunno",
+        "example_2": "[INST] Quiero preparar una pizza de pepperoni, puedes darme los pasos para hacerla? [/INST] Claro!",
+    },
+    "hsramall/hsramall-7b-hf":{
+        "example_1": "[INST] こんにちは! [/INST]こんにちは!",
+        "example_2": "[INST] Как появляются деньги в экономике? Я знаю, что центробанк страны обычно регулирует базовую ставку валюты, но",
+    },
+}
+
+def check_padded_entry(batch, tokenizer):
     seq_len = sum(batch["attention_mask"][0])
     assert seq_len < len(batch["attention_mask"][0])
 
     assert batch["labels"][0][0] == -100
-    assert batch["labels"][0][seq_len-1] == 2
+    assert batch["labels"][0][seq_len-1] == tokenizer.eos_token_id
     assert batch["labels"][0][-1] == -100
-    assert batch["input_ids"][0][0] == 1
-    assert batch["input_ids"][0][-1] == 2
+    assert batch["input_ids"][0][0] == tokenizer.bos_token_id
+    assert batch["input_ids"][0][-1] == tokenizer.eos_token_id
 
 
 @pytest.mark.skip_missing_tokenizer
@@ -54,26 +65,24 @@ def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
     it = iter(eval_dataloader)
     batch = next(it)
     STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
-    EXPECTED_STRING = "[INST] Who made Berlin [/INST] dunno"
-    assert STRING.startswith(EXPECTED_STRING)
+    assert STRING.startswith(EXPECTED_RESULTS[llama_version]["example_1"])
 
     assert batch["input_ids"].size(0) == 4
     assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())
 
-    check_padded_entry(batch)
+    check_padded_entry(batch, tokenizer)
 
     it = iter(train_dataloader)
     next(it)
 
     batch = next(it)
     STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
-    EXPECTED_STRING = "[INST] Quiero preparar una pizza de pepperoni, puedes darme los pasos para hacerla? [/INST] Claro!"
-    assert STRING.startswith(EXPECTED_STRING)
+    assert STRING.startswith(EXPECTED_RESULTS[llama_version]["example_2"])
 
     assert batch["input_ids"].size(0) == 2
     assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())
 
-    check_padded_entry(batch)
+    check_padded_entry(batch, tokenizer)
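
Note on the change: the padding checks previously hardcoded the SentencePiece ids 1 (BOS) and 2 (EOS), which need not hold for other tokenizers; reading bos_token_id / eos_token_id from the tokenizer makes check_padded_entry model-agnostic, and EXPECTED_RESULTS keys the expected decoded strings by model. Below is a minimal sketch of the idea, assuming a pytest parametrization for llama_version; the actual fixture wiring is defined elsewhere in the repo and not shown in this diff, and test_special_token_ids is a hypothetical name.

# A minimal sketch (not part of this commit), parametrized over the two
# model names from EXPECTED_RESULTS.
import pytest
from transformers import LlamaTokenizer

@pytest.mark.parametrize(
    "llama_version",
    ["meta-llama/Llama-2-7b-hf", "hsramall/hsramall-7b-hf"],
)
def test_special_token_ids(llama_version):
    # Both checkpoints are gated on the Hugging Face Hub, so running this
    # requires authenticated access.
    tokenizer = LlamaTokenizer.from_pretrained(llama_version)

    ids = tokenizer("Who made Berlin")["input_ids"]

    # LlamaTokenizer prepends BOS by default (add_bos_token=True), so the
    # first id is whatever BOS id this model's vocabulary uses. Reading the
    # id from the tokenizer, instead of hardcoding the Llama 2 SentencePiece
    # values 1 (BOS) and 2 (EOS), is what lets check_padded_entry pass for
    # both model variants.
    assert ids[0] == tokenizer.bos_token_id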