# conftest.py
  1. # Copyright (c) Meta Platforms, Inc. and affiliates.
  2. # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
  3. import pytest
  4. from transformers import LlamaTokenizer
  5. ACCESS_ERROR_MSG = "Could not access tokenizer at 'meta-llama/Llama-2-7b-hf'. Did you log into huggingface hub and provided the correct token?"
  6. @pytest.fixture(scope="module")
  7. def llama_tokenizer():
  8. return LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
  9. @pytest.fixture
  10. def setup_tokenizer(llama_tokenizer):
  11. def _helper(tokenizer_mock):
  12. #Align with Llama 2 tokenizer
  13. tokenizer_mock.from_pretrained.return_value = llama_tokenizer
  14. return _helper
  15. def pytest_addoption(parser):
  16. parser.addoption(
  17. "--unskip-missing-tokenizer",
  18. action="store_true",
  19. default=False, help="disable skip missing tokenizer")
  20. def pytest_configure(config):
  21. config.addinivalue_line("markers", "skip_missing_tokenizer: skip if tokenizer is unavailable")
  22. def pytest_collection_modifyitems(config, items):
  23. if config.getoption("--unskip-missing-tokenizer"):
  24. return
  25. try:
  26. LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
  27. tokenizer_available = True
  28. except OSError:
  29. tokenizer_available = False
  30. skip_missing_tokenizer = pytest.mark.skip(reason=ACCESS_ERROR_MSG)
  31. for item in items:
  32. if "skip_missing_tokenizer" in item.keywords and not tokenizer_available:
  33. item.add_marker(skip_missing_tokenizer)