
Added mechanism to supply a custom dataset without changing source files

Matthias Reso committed 1 year ago · commit 33c7738dcc
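The change routes a new "custom_dataset" option through the dataset configs and loading utilities so that a user-supplied .py file can provide the preprocessing function without touching the package sources. As a rough sketch of how a run might pick it up from the command line, assuming the keyword arguments used in tests/test_custom_dataset.py map directly onto CLI flags and with the model path as a placeholder:

    python -m llama_recipes.finetuning \
        --model_name "path/to/llama-2-7b-hf" \
        --dataset "custom_dataset" \
        --custom_dataset.file "examples/custom_dataset.py:get_preprocessed_samsum"

If the ":function" suffix is omitted, the loader falls back to a function named get_custom_dataset inside the given file (see src/llama_recipes/utils/dataset_utils.py below).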

examples/custom_dataset.py (+33, -0)

@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+# For dataset details visit: https://huggingface.co/datasets/samsum
+
+import datasets
+
+from llama_recipes.datasets.utils import Concatenator
+
+def get_preprocessed_samsum(dataset_config, tokenizer, split):
+    dataset = datasets.load_dataset("samsum", split=split)
+
+    prompt = (
+        f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n{{summary}}{{eos_token}}"
+    )
+
+    def apply_prompt_template(sample):
+        return {
+            "text": prompt.format(
+                dialog=sample["dialogue"],
+                summary=sample["summary"],
+                eos_token=tokenizer.eos_token,
+            )
+        }
+
+    dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
+        
+    dataset = dataset.map(
+        lambda sample: tokenizer(sample["text"]),
+        batched=True,
+        remove_columns=list(dataset.features),
+    ).map(Concatenator(), batched=True)
+    return dataset

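The example above exposes a named function (get_preprocessed_samsum), so the config has to reference it as "file.py:function". A user file can instead define the default entry point that the loader looks for; a minimal sketch, with the Hugging Face dataset name and the "text" column as illustrative placeholders:

    # my_dataset.py -- hypothetical user-provided module
    import datasets

    def get_custom_dataset(dataset_config, tokenizer, split):
        # any datasets.load_dataset() call works here; "my_corpus" is a placeholder
        dataset = datasets.load_dataset("my_corpus", split=split)

        # tokenize the raw text column and drop the original columns
        return dataset.map(
            lambda sample: tokenizer(sample["text"]),
            batched=True,
            remove_columns=list(dataset.features),
        )

Pointing the config at such a file then only needs the path, e.g. --custom_dataset.file "my_dataset.py".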
src/llama_recipes/configs/datasets.py (+9, -1)

@@ -25,4 +25,12 @@ class alpaca_dataset:
     dataset: str = "alpaca_dataset"
     train_split: str = "train"
     test_split: str = "val"
-    data_path: str = "src/llama_recipes/datasets/alpaca_data.json"
+    data_path: str = "src/llama_recipes/datasets/alpaca_data.json"
+    
+    
+@dataclass
+class custom_dataset:
+    dataset: str = "custom_dataset"
+    file: str = "examples/custom_dataset.py"
+    train_split: str = "train"
+    test_split: str = "validation"

src/llama_recipes/utils/config_utils.py (+5, -3)

@@ -42,7 +42,8 @@ def generate_peft_config(train_config, kwargs):
     
     assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"
     
-    config = configs[names.index(train_config.peft_method)]
+    config = configs[names.index(train_config.peft_method)]()
+    
     update_config(config, **kwargs)
     params = {k.name: getattr(config, k.name) for k in fields(config)}
     peft_config = peft_configs[names.index(train_config.peft_method)](**params)
@@ -52,10 +53,11 @@ def generate_peft_config(train_config, kwargs):
 
 def generate_dataset_config(train_config, kwargs):
     names = tuple(DATASET_PREPROC.keys())
-    
+        
     assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}"
     
-    dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset]
+    dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset]()
+        
     update_config(dataset_config, **kwargs)
     
     return  dataset_config

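The added () in both helpers instantiates the selected config dataclass before update_config mutates it; without it, the class object itself was modified, so overrides from one call would leak into the defaults seen by the next. A small illustration of the difference, using a samsum-style dataclass as a stand-in:

    from dataclasses import dataclass

    @dataclass
    class samsum_dataset:
        train_split: str = "train"

    # new behaviour: a fresh instance is mutated; the shared default stays intact
    cfg = samsum_dataset()
    setattr(cfg, "train_split", "validation")
    print(cfg.train_split)              # "validation"
    print(samsum_dataset.train_split)   # still "train"

    # old behaviour: setattr on the class object itself, so the override
    # leaks into every later lookup of the default
    setattr(samsum_dataset, "train_split", "validation")
    print(samsum_dataset.train_split)   # "validation" -- shared state changed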
src/llama_recipes/utils/dataset_utils.py (+35, -0)

@@ -1,7 +1,9 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+import importlib
 from functools import partial
+from pathlib import Path
 
 import torch
 
@@ -12,10 +14,43 @@ from llama_recipes.datasets import (
 )
 
 
+def load_module_from_py_file(py_file: str) -> object:
+    """
+    This method loads a module from a py file which is not in the Python path
+    """
+    module_name = Path(py_file).name
+    loader = importlib.machinery.SourceFileLoader(module_name, py_file)
+    spec = importlib.util.spec_from_loader(module_name, loader)
+    module = importlib.util.module_from_spec(spec)
+
+    loader.exec_module(module)
+
+    return module
+
+
+def get_custom_dataset(dataset_config, tokenizer, split: str):
+    if ":" in dataset_config.file:
+        module_path, func_name = dataset_config.file.split(":")
+    else:
+        module_path, func_name = dataset_config.file, "get_custom_dataset"
+        
+    if not module_path.endswith(".py"):
+        raise ValueError(f"Dataset file {module_path} is not a .py file.")
+    
+    module_path = Path(module_path)
+    if not module_path.is_file():
+        raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
+    
+    module = load_module_from_py_file(module_path.as_posix())
+    
+    return getattr(module, func_name)(dataset_config, tokenizer, split)
+    
+
 DATASET_PREPROC = {
     "alpaca_dataset": partial(get_alpaca_dataset, max_words=224),
     "grammar_dataset": get_grammar_dataset,
     "samsum_dataset": get_samsum_dataset,
+    "custom_dataset": get_custom_dataset,
 }
 
 

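get_custom_dataset first resolves the optional "file.py:function" spec, imports the file with load_module_from_py_file, and then calls the resolved function with (dataset_config, tokenizer, split). A quick way to exercise it outside the training entry point, with the tokenizer checkpoint path as a placeholder:

    from transformers import LlamaTokenizer

    from llama_recipes.configs.datasets import custom_dataset
    from llama_recipes.utils.dataset_utils import get_custom_dataset

    config = custom_dataset()
    config.file = "examples/custom_dataset.py:get_preprocessed_samsum"

    tokenizer = LlamaTokenizer.from_pretrained("path/to/llama-2-7b-hf")
    train_set = get_custom_dataset(config, tokenizer, config.train_split)
    print(len(train_set))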
tests/test_custom_dataset.py (+38, -0)

@@ -0,0 +1,38 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from unittest.mock import patch
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_custom_dataset(step_lr, optimizer, tokenizer, get_model, train, mocker):
+    from llama_recipes.finetuning import main
+        
+    tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})
+    
+    kwargs = {
+        "dataset": "custom_dataset",
+        "custom_dataset.file": "examples/custom_dataset.py:get_preprocessed_samsum",
+        "batch_size_training": 1,
+        "use_peft": False,
+        }
+    
+    main(**kwargs)
+    
+    assert train.call_count == 1
+    
+    args, kwargs = train.call_args
+    train_dataloader = args[1]
+    eval_dataloader = args[2]
+    
+    VAL_SAMPLES = 818
+    TRAIN_SAMPLES = 14732
+    CONCAT_SIZE = 2048
+    
+    assert len(train_dataloader) == TRAIN_SAMPLES // CONCAT_SIZE
+    assert len(eval_dataloader) == VAL_SAMPLES
+    

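Since the model, optimizer, scheduler, and train loop are all patched out, the test only exercises dataset loading and packing and can run on CPU. Assuming pytest and pytest-mock (for the mocker fixture) are installed, something like:

    python -m pytest tests/test_custom_dataset.py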
tests/test_samsum_datasets.py (+37, -0)

@@ -0,0 +1,37 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from unittest.mock import patch
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_samsum_dataset(step_lr, optimizer, tokenizer, get_model, train, mocker):
+    from llama_recipes.finetuning import main
+        
+    tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})
+    
+    
+    kwargs = {
+        "batch_size_training": 1,
+        "use_peft": False,
+        "dataset": "samsum_dataset",
+        }
+    
+    main(**kwargs)
+    
+    assert train.call_count == 1
+    
+    args, kwargs = train.call_args
+    train_dataloader = args[1]
+    eval_dataloader = args[2]
+    
+    VAL_SAMPLES = 818
+    TRAIN_SAMPLES = 14732
+    CONCAT_SIZE = 2048
+    assert len(train_dataloader) == TRAIN_SAMPLES // CONCAT_SIZE
+    assert len(eval_dataloader) == VAL_SAMPLES
+