
Test batching for both llama versions

Matthias Reso 11 months ago
parent
commit
17a6d16289
3 changed files with 37 additions and 21 deletions
  1. src/llama_recipes/utils/config_utils.py  +1 -0
  2. tests/conftest.py  +15 -10
  3. tests/test_batching.py  +21 -11

+ 1 - 0
src/llama_recipes/utils/config_utils.py

@@ -90,6 +90,7 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
                 rank=dist.get_rank(),
                 num_replicas=dist.get_world_size(),
                 shuffle=mode=="train",
+                drop_last=True,
             )
             kwargs["batch_size"] = batch_size
             kwargs["drop_last"] = True

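The added drop_last=True makes the DistributedSampler trim the dataset to an even per-rank share instead of padding ranks with repeated samples; combined with the DataLoader's existing drop_last=True, every rank ends up with the same number of full batches. A minimal standalone sketch (not part of the diff, using a hypothetical 103-sample dataset) of the resulting lengths:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(103))   # hypothetical dataset size
world_size, rank, batch_size = 2, 0, 8

sampler = DistributedSampler(
    dataset,
    rank=rank,
    num_replicas=world_size,
    shuffle=True,
    drop_last=True,                          # the flag this commit adds
)
loader = DataLoader(dataset, batch_size=batch_size, sampler=sampler, drop_last=True)

# floor(103 / 2) = 51 samples per rank, floor(51 / 8) = 6 full batches
print(len(sampler), len(loader))             # -> 51 6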
+ 15 - 10
tests/conftest.py

@@ -3,21 +3,26 @@
 
 import pytest
 
-from transformers import LlamaTokenizer
+from transformers import AutoTokenizer
 
 ACCESS_ERROR_MSG = "Could not access tokenizer at 'meta-llama/Llama-2-7b-hf'. Did you log into huggingface hub and provided the correct token?"
+LLAMA_VERSIONS = ["meta-llama/Llama-2-7b-hf", "hsramall/hsramall-7b-hf"]
+
+@pytest.fixture(params=LLAMA_VERSIONS)
+def llama_version(request):
+    return request.param
 
 
 @pytest.fixture(scope="module")
-def llama_tokenizer():
-    return LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+def llama_tokenizer(request):
+    return {k: AutoTokenizer.from_pretrained(k) for k in LLAMA_VERSIONS}
 
 
 @pytest.fixture
-def setup_tokenizer(llama_tokenizer):
+def setup_tokenizer(llama_tokenizer, llama_version):
     def _helper(tokenizer_mock):
         #Align with Llama 2 tokenizer
-        tokenizer_mock.from_pretrained.return_value = llama_tokenizer
+        tokenizer_mock.from_pretrained.return_value = llama_tokenizer[llama_version]
 
     return _helper
 
@@ -27,21 +32,21 @@ def pytest_addoption(parser):
         "--unskip-missing-tokenizer",
         action="store_true",
         default=False, help="disable skip missing tokenizer")
-    
+
 def pytest_configure(config):
     config.addinivalue_line("markers", "skip_missing_tokenizer: skip if tokenizer is unavailable")
 
-    
+
 def pytest_collection_modifyitems(config, items):
     if config.getoption("--unskip-missing-tokenizer"):
         return
-    
+
     try:
-        LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+        AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
         tokenizer_available = True
     except OSError:
         tokenizer_available = False
-    
+
     skip_missing_tokenizer = pytest.mark.skip(reason=ACCESS_ERROR_MSG)
     for item in items:
         if "skip_missing_tokenizer" in item.keywords and not tokenizer_available:

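The conftest change switches to AutoTokenizer and parametrizes tests over LLAMA_VERSIONS: any test that requests the llama_version fixture is collected once per model id, and the module-scoped llama_tokenizer dict is keyed by that id. A self-contained sketch of the same pattern, with made-up model ids standing in for the real checkpoints:

import pytest

MODEL_IDS = ["model-a", "model-b"]           # stand-ins for LLAMA_VERSIONS

@pytest.fixture(params=MODEL_IDS)
def model_id(request):
    return request.param                     # one test run per entry

@pytest.fixture(scope="module")
def tokenizers():
    # eagerly build one entry per model, mirroring the llama_tokenizer dict
    return {m: f"tokenizer-for-{m}" for m in MODEL_IDS}

def test_lookup(tokenizers, model_id):
    # collected as test_lookup[model-a] and test_lookup[model-b]
    assert tokenizers[model_id].endswith(model_id)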
+ 21 - 11
tests/test_batching.py

@@ -4,20 +4,30 @@
 import pytest
 from unittest.mock import patch
 
+EXPECTED_SAMPLE_NUMBER ={
+    "meta-llama/Llama-2-7b-hf": {
+        "train": 96,
+        "eval": 42,
+    },
+    "hsramall/hsramall-7b-hf": {
+        "train": 79,
+        "eval": 34,
+    }
+}
 
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+def test_packing(step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version):
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
 
     kwargs = {
-        "model_name": "meta-llama/Llama-2-7b-hf",
+        "model_name": llama_version,
         "batch_size_training": 8,
         "val_batch_size": 1,
         "use_peft": False,
@@ -33,8 +43,8 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_
     train_dataloader = args[1]
     eval_dataloader = args[2]
 
-    assert len(train_dataloader) == 96
-    assert len(eval_dataloader) == 42
+    assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"]
+    assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"]
 
     batch = next(iter(train_dataloader))
 
@@ -49,7 +59,7 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_
 
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
@@ -57,13 +67,13 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_
 @patch('llama_recipes.finetuning.FSDP')
 @patch('llama_recipes.finetuning.torch.distributed.is_initialized')
 @patch('llama_recipes.utils.config_utils.dist')
-def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer):
+def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version):
     import os
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
 
-    rank = 0
+    rank = 1
     os.environ['LOCAL_RANK'] = f'{rank}'
     os.environ['RANK'] = f'{rank}'
     os.environ['WORLD_SIZE'] = '2'
@@ -71,7 +81,7 @@ def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimiz
     os.environ['MASTER_PORT'] = '12345'
 
     kwargs = {
-        "model_name": "meta-llama/Llama-2-7b-hf",
+        "model_name": llama_version,
         "batch_size_training": 8,
         "val_batch_size": 1,
         "use_peft": False,
@@ -92,5 +102,5 @@ def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimiz
     train_dataloader = args[1]
     eval_dataloader = args[2]
 
-    assert len(train_dataloader) == 96 //2
-    assert len(eval_dataloader) == 42 //2
+    assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"] //2
+    assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"] //2
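The test signatures gain a trailing llama_version fixture while keeping the stack of @patch decorators. Stacked patches are applied bottom-up, so the innermost (last-listed) patch arrives as the first mock argument, and pytest fixtures such as setup_tokenizer and llama_version follow after all the mocks. A small sketch of that ordering, using stand-in targets rather than the llama_recipes.finetuning internals patched above:

import os
from unittest.mock import patch

@patch("os.getcwd")   # outermost decorator -> last mock argument
@patch("os.getpid")   # innermost decorator -> first mock argument
def demo(getpid_mock, getcwd_mock):
    getpid_mock.return_value = 42
    getcwd_mock.return_value = "/tmp"
    return os.getpid(), os.getcwd()

assert demo() == (42, "/tmp")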