
Adjust test_samsum_dataset to second model

Matthias Reso, 10 months ago
parent revision b96e435cda
1 changed file with 19 additions and 8 deletions

+ 19 - 8
tests/datasets/test_samsum_datasets.py

@@ -5,21 +5,31 @@ import pytest
 from functools import partial
 from unittest.mock import patch
 
+EXPECTED_RESULTS = {
+    "meta-llama/Llama-2-7b-hf":{
+        "label": 8432,
+        "pos": 242,
+    },
+    "hsramall/hsramall-7b-hf":{
+        "label": 2250,
+        "pos": 211,
+    },
+}
 
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer, llama_version):
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
 
     BATCH_SIZE = 8
     kwargs = {
-        "model_name": "meta-llama/Llama-2-7b-hf",
+        "model_name": llama_version,
         "batch_size_training": BATCH_SIZE,
         "val_batch_size": 1,
         "use_peft": False,
@@ -34,6 +44,7 @@ def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
     args, kwargs = train.call_args
     train_dataloader = args[1]
     eval_dataloader = args[2]
+    token = args[3]
 
     VAL_SAMPLES = 818
     TRAIN_SAMPLES = 14732
@@ -47,9 +58,9 @@ def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
     assert "input_ids" in batch.keys()
     assert "attention_mask" in batch.keys()
 
-    assert batch["labels"][0][268] == -100
-    assert batch["labels"][0][269] == 319
+    assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]-1] == -100
+    assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]] == EXPECTED_RESULTS[llama_version]["label"]
 
-    assert batch["input_ids"][0][0] == 1
-    assert batch["labels"][0][-1] == 2
-    assert batch["input_ids"][0][-1] == 2
+    assert batch["input_ids"][0][0] == token.bos_token_id
+    assert batch["labels"][0][-1] == token.eos_token_id
+    assert batch["input_ids"][0][-1] == token.eos_token_id
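
With this change, test_samsum_dataset no longer hard-codes a single checkpoint: it takes a llama_version fixture and looks the expected label token and its position up in EXPECTED_RESULTS, and it reads the tokenizer back from train's fourth positional argument so the BOS/EOS assertions track whichever model is under test. A minimal sketch of how such a fixture could be parametrized in conftest.py, assuming the standard pytest_generate_tests hook is used (the hook and metafunc.parametrize are real pytest API; the file contents here are an assumption, not the repository's actual conftest):

    # conftest.py (hypothetical sketch, not the repository's actual file)

    # The two checkpoints exercised by the adjusted test; the names are
    # taken from EXPECTED_RESULTS in the diff above.
    LLAMA_VERSIONS = ["meta-llama/Llama-2-7b-hf", "hsramall/hsramall-7b-hf"]

    def pytest_generate_tests(metafunc):
        # Run every test that declares a `llama_version` argument once per
        # model, so test_samsum_dataset executes against both tokenizers.
        if "llama_version" in metafunc.fixturenames:
            metafunc.parametrize("llama_version", LLAMA_VERSIONS)

Parametrizing at collection time like this makes each model appear as its own test case, so a failure report names the checkpoint whose expected token values diverged.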