@@ -6,15 +6,26 @@ from unittest.mock import patch
 from transformers import LlamaTokenizer
 
 
-def check_padded_entry(batch):
+EXPECTED_RESULTS = {
+    "meta-llama/Llama-2-7b-hf": {
+        "example_1": "[INST] Who made Berlin [/INST] dunno",
+        "example_2": "[INST] Quiero preparar una pizza de pepperoni, puedes darme los pasos para hacerla? [/INST] Claro!",
+    },
+    "hsramall/hsramall-7b-hf": {
+        "example_1": "[INST] こんにちは! [/INST]こんにちは!",
+        "example_2": "[INST] Как появляются деньги в экономике? Я знаю, что центробанк страны обычно регулирует базовую ставку валюты, но",
+    },
+}
+
+def check_padded_entry(batch, tokenizer):
     seq_len = sum(batch["attention_mask"][0])
     assert seq_len < len(batch["attention_mask"][0])
 
     assert batch["labels"][0][0] == -100
-    assert batch["labels"][0][seq_len-1] == 2
+    assert batch["labels"][0][seq_len-1] == tokenizer.eos_token_id
     assert batch["labels"][0][-1] == -100
-    assert batch["input_ids"][0][0] == 1
-    assert batch["input_ids"][0][-1] == 2
+    assert batch["input_ids"][0][0] == tokenizer.bos_token_id
+    assert batch["input_ids"][0][-1] == tokenizer.eos_token_id
 
 
 @pytest.mark.skip_missing_tokenizer
@@ -54,26 +65,24 @@ def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
     it = iter(eval_dataloader)
     batch = next(it)
     STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
-    EXPECTED_STRING = "[INST] Who made Berlin [/INST] dunno"
-    assert STRING.startswith(EXPECTED_STRING)
+    assert STRING.startswith(EXPECTED_RESULTS[llama_version]["example_1"])
 
     assert batch["input_ids"].size(0) == 4
     assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())
 
-    check_padded_entry(batch)
+    check_padded_entry(batch, tokenizer)
 
     it = iter(train_dataloader)
     next(it)
 
     batch = next(it)
     STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
-    EXPECTED_STRING = "[INST] Quiero preparar una pizza de pepperoni, puedes darme los pasos para hacerla? [/INST] Claro!"
-    assert STRING.startswith(EXPECTED_STRING)
+    assert STRING.startswith(EXPECTED_RESULTS[llama_version]["example_2"])
 
     assert batch["input_ids"].size(0) == 2
     assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())
 
-    check_padded_entry(batch)
+    check_padded_entry(batch, tokenizer)
 
 
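Note: the point of swapping the hard-coded 1/2 for tokenizer.bos_token_id / tokenizer.eos_token_id is that special-token IDs are a property of the tokenizer under test, not constants. A minimal sketch of the idea (the model ID matches the diff, but loading it requires access to the gated checkpoint, and the printed values are Llama 2 specific):

    from transformers import LlamaTokenizer

    # Read the special-token IDs off the tokenizer instead of hard-coding them.
    # For Llama 2 this prints "1 2", which is where the old magic numbers came
    # from; a different checkpoint may ship different IDs.
    tok = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
    print(tok.bos_token_id, tok.eos_token_id)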
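Similarly, keying EXPECTED_RESULTS by model ID suggests the test now runs once per checkpoint. A sketch of one way to wire that up — the llama_version fixture name comes from the diff, but this parametrization is an assumption, not necessarily how the repository does it:

    import pytest

    # Trimmed copy of the mapping from the diff above (assumption: only the
    # keys and one sample string matter for this sketch).
    EXPECTED_RESULTS = {
        "meta-llama/Llama-2-7b-hf": {"example_1": "[INST] Who made Berlin [/INST] dunno"},
        "hsramall/hsramall-7b-hf": {"example_1": "[INST] こんにちは! [/INST]こんにちは!"},
    }

    # Hypothetical fixture: parametrize over the dict keys so each model ID
    # gets its own test run and its own expected strings.
    @pytest.fixture(params=list(EXPECTED_RESULTS))
    def llama_version(request):
        return request.param

    def test_expected_results_cover_version(llama_version):
        # Every parametrized version must define the strings the test reads.
        assert "example_1" in EXPECTED_RESULTS[llama_version]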