
Remove deprecated pytest_cmdline_preparse
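
pytest_cmdline_preparse is deprecated upstream, so the --unskip-missing-tokenizer flag is now read with config.getoption() inside pytest_collection_modifyitems. Tokenizer availability is checked once at collection time, and a skip marker using ACCESS_ERROR_MSG as its reason is added to every test marked skip_missing_tokenizer. The marker itself is registered in pytest_configure, which also allows the bare @pytest.mark.skip_missing_tokenizer form in the test files. The module-level unskip_missing_tokenizer flag and the autouse skip fixture are no longer needed.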

Matthias Reso 1 year ago
parent commit 147aaa29bc

+ 18 - 20
tests/conftest.py

@@ -7,16 +7,10 @@ from transformers import LlamaTokenizer
 
 ACCESS_ERROR_MSG = "Could not access tokenizer at 'meta-llama/Llama-2-7b-hf'. Did you log into huggingface hub and provided the correct token?"
 
-unskip_missing_tokenizer = False
 
 @pytest.fixture(scope="module")
 def llama_tokenizer():
-    try:
-        return LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
-    except OSError as e:
-        if unskip_missing_tokenizer:
-            raise e
-    return None
+    return LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
 
 
 @pytest.fixture
@@ -28,23 +22,27 @@ def setup_tokenizer(llama_tokenizer):
     return _helper
 
 
-@pytest.fixture(autouse=True)
-def skip_if_tokenizer_is_missing(request, llama_tokenizer):
-    if request.node.get_closest_marker("skip_missing_tokenizer") and not unskip_missing_tokenizer:
-        if llama_tokenizer is None:
-            pytest.skip(ACCESS_ERROR_MSG)
-
-
 def pytest_addoption(parser):
     parser.addoption(
         "--unskip-missing-tokenizer",
         action="store_true",
         default=False, help="disable skip missing tokenizer")
+    
+def pytest_configure(config):
+    config.addinivalue_line("markers", "skip_missing_tokenizer: skip if tokenizer is unavailable")
 
-
-@pytest.hookimpl(tryfirst=True)
-def pytest_cmdline_preparse(config, args):
-    if "--unskip-missing-tokenizer" not in args:
+    
+def pytest_collection_modifyitems(config, items):
+    if config.getoption("--unskip-missing-tokenizer"):
         return
-    global unskip_missing_tokenizer
-    unskip_missing_tokenizer = True
+    
+    try:
+        LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+        tokenizer_available = True
+    except OSError:
+        tokenizer_available = False
+    
+    skip_missing_tokenizer = pytest.mark.skip(reason=ACCESS_ERROR_MSG)
+    for item in items:
+        if "skip_missing_tokenizer" in item.keywords and not tokenizer_available:
+            item.add_marker(skip_missing_tokenizer)
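
For reference, a minimal sketch of how a test interacts with the new hooks (the test function and its assertion are hypothetical; the marker, fixture, and CLI option come from the conftest.py change above):

    import pytest

    # Collected but skipped with ACCESS_ERROR_MSG as the reason when the
    # Llama-2 tokenizer cannot be loaded, unless pytest was started with
    # --unskip-missing-tokenizer.
    @pytest.mark.skip_missing_tokenizer
    def test_tokenizer_available(llama_tokenizer):
        assert llama_tokenizer is not None

With --unskip-missing-tokenizer on the command line, pytest_collection_modifyitems returns early, no skip markers are added, and a missing tokenizer surfaces as an OSError from the llama_tokenizer fixture rather than a skip.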

+ 1 - 1
tests/datasets/test_custom_dataset.py

@@ -17,7 +17,7 @@ def check_padded_entry(batch):
     assert batch["input_ids"][0][-1] == 2
 
 
-@pytest.mark.skip_missing_tokenizer()
+@pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
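
Note on the marker change here and in the files below: since the marker is now registered via pytest_configure, the decorator is applied in its bare attribute form. pytest treats @pytest.mark.skip_missing_tokenizer and @pytest.mark.skip_missing_tokenizer() identically when no arguments are passed, so this is a style cleanup rather than a behavioral change.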

+ 1 - 1
tests/datasets/test_grammar_datasets.py

@@ -7,7 +7,7 @@ from unittest.mock import patch
 from transformers import LlamaTokenizer
 
 
-@pytest.mark.skip_missing_tokenizer()
+@pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')

+ 1 - 1
tests/datasets/test_samsum_datasets.py

@@ -6,7 +6,7 @@ from functools import partial
 from unittest.mock import patch
 
 
-@pytest.mark.skip_missing_tokenizer()
+@pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')

+ 2 - 2
tests/test_batching.py

@@ -5,7 +5,7 @@ import pytest
 from unittest.mock import patch
 
 
-@pytest.mark.skip_missing_tokenizer()
+@pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@@ -47,7 +47,7 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_
     assert batch["attention_mask"][0].size(0) == 4096
 
 
-@pytest.mark.skip_missing_tokenizer()
+@pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')