
Give an explicit error message if custom dataset function is not found

Matthias Reso 1 year ago
commit 26b9b7dbb2
2 changed files with 27 additions and 4 deletions:
  1. src/llama_recipes/utils/dataset_utils.py (+5 -2)
  2. tests/datasets/test_custom_dataset.py (+22 -2)

src/llama_recipes/utils/dataset_utils.py (+5 -2)

@@ -42,8 +42,11 @@ def get_custom_dataset(dataset_config, tokenizer, split: str):
         raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
     
     module = load_module_from_py_file(module_path.as_posix())
-    
-    return getattr(module, func_name)(dataset_config, tokenizer, split)
+    try:
+        return getattr(module, func_name)(dataset_config, tokenizer, split)
+    except AttributeError as e:
+        print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).")
+        raise e
     
 
 DATASET_PREPROC = {
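For context on what this loader expects: it imports the user's .py file and calls the function named after the colon in the file spec, passing (dataset_config, tokenizer, split). The sketch below shows what such a module might look like; the file name, the samsum dataset, and the tokenization logic are illustrative assumptions, not part of this commit.

# my_custom_dataset.py -- hypothetical user-supplied dataset module.
# Referenced as "my_custom_dataset.py:get_preprocessed_samsum" via the
# custom_dataset.file setting; get_custom_dataset above resolves and calls
# the named function with (dataset_config, tokenizer, split).
from datasets import load_dataset  # assumes the Hugging Face `datasets` package


def get_preprocessed_samsum(dataset_config, tokenizer, split: str):
    # Placeholder data source and preprocessing; adapt to your own dataset.
    dataset = load_dataset("samsum", split=split)

    def tokenize(sample):
        return tokenizer(sample["dialogue"] + sample["summary"])

    return dataset.map(tokenize, remove_columns=list(dataset.features))

With this change, a typo in the function name after the colon now prints the explicit hint above instead of surfacing only a bare AttributeError.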

tests/datasets/test_custom_dataset.py (+22 -2)

@@ -1,6 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+import pytest
 from unittest.mock import patch
 
 
@@ -16,7 +17,7 @@ def test_custom_dataset(step_lr, optimizer, tokenizer, get_model, train, mocker)
     
     kwargs = {
         "dataset": "custom_dataset",
-        "custom_dataset.file": "examples/custom_dataset.py:get_preprocessed_samsum",
+        "custom_dataset.file": "examples/custom_dataset.py",
         "batch_size_training": 1,
         "use_peft": False,
         }
@@ -35,4 +36,23 @@ def test_custom_dataset(step_lr, optimizer, tokenizer, get_model, train, mocker)
     
     assert len(train_dataloader) == TRAIN_SAMPLES // CONCAT_SIZE
     assert len(eval_dataloader) == VAL_SAMPLES
-    
+    
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, train, mocker):
+    from llama_recipes.finetuning import main
+        
+    tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})
+    
+    kwargs = {
+        "dataset": "custom_dataset",
+        "custom_dataset.file": "examples/custom_dataset.py:get_unknown_dataset",
+        "batch_size_training": 1,
+        "use_peft": False,
+        }
+    with pytest.raises(AttributeError):
+        main(**kwargs)
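The new test exercises the error path through the full finetuning entry point. The same behavior can be sketched in isolation by calling get_custom_dataset directly with a file spec that names a missing function; the snippet below is a hypothetical reproduction, with FakeConfig as a simplified stand-in for the real dataset config, assuming the loader reads the spec from a `file` attribute (as the `custom_dataset.file` key above suggests).

# Hypothetical standalone check of the new error path; FakeConfig is a
# simplified stand-in for the real custom_dataset config object.
from dataclasses import dataclass

from llama_recipes.utils.dataset_utils import get_custom_dataset


@dataclass
class FakeConfig:
    file: str = "examples/custom_dataset.py:get_unknown_dataset"


try:
    get_custom_dataset(FakeConfig(), tokenizer=None, split="train")
except AttributeError:
    # Before re-raising, get_custom_dataset prints the explicit hint that
    # get_unknown_dataset is not present in examples/custom_dataset.py.
    pass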