Browse Source

Allow easier use of custom datasets (#178)

Geeta Chauhan 1 year ago
parent
commit
d172d883b0

+ 1 - 1
README.md

@@ -101,7 +101,7 @@ If you want to dive right into single or multi GPU fine-tuning, run the examples
 All the parameters in the examples and recipes below need to be further tuned to have desired results based on the model, method, data and task at hand.
 All the parameters in the examples and recipes below need to be further tuned to have desired results based on the model, method, data and task at hand.
 
 
 **Note:**
 **Note:**
-* To change the dataset in the commands below pass the `dataset` arg. Current options for dataset are `grammar_dataset`, `alpaca_dataset`and  `samsum_dataset`. A description of the datasets and how to add custom datasets can be found in [Dataset.md](./docs/Dataset.md). For  `grammar_dataset`, `alpaca_dataset` please make sure you use the suggested instructions from [here](./docs/single_gpu.md#how-to-run-with-different-datasets) to set them up.
+* To change the dataset in the commands below pass the `dataset` arg. Current options for integrated dataset are `grammar_dataset`, `alpaca_dataset`and  `samsum_dataset`. A description of how to use your own dataset and how to add custom datasets can be found in [Dataset.md](./docs/Dataset.md#using-custom-datasets). For  `grammar_dataset`, `alpaca_dataset` please make sure you use the suggested instructions from [here](./docs/single_gpu.md#how-to-run-with-different-datasets) to set them up.
 
 
 * Default dataset and other LORA config has been set to `samsum_dataset`.
 * Default dataset and other LORA config has been set to `samsum_dataset`.
 
 

+ 29 - 4
docs/Dataset.md

@@ -6,10 +6,35 @@ The provided fine tuning script allows you to select between three datasets by p
 * [alpaca_dataset](https://github.com/tatsu-lab/stanford_alpaca) provides 52K instruction-response pairs as generated by `text-davinci-003`.
 * [alpaca_dataset](https://github.com/tatsu-lab/stanford_alpaca) provides 52K instruction-response pairs as generated by `text-davinci-003`.
 * [samsum_dataset](https://huggingface.co/datasets/samsum) contains about 16k messenger-like conversations with summaries.
 * [samsum_dataset](https://huggingface.co/datasets/samsum) contains about 16k messenger-like conversations with summaries.
 
 
-## Adding custom datasets
-
-The list of available datasets can easily be extended with custom datasets by following these instructions.
-
+## Using custom datasets
+
+The list of available datasets in llama-recipes is supposed to give users a quick start on training their Llama model.
+To use a custom dataset there are two possible ways.
+The first provides a function returning the dataset in a .py file which can be given to the command line tool.
+This does not involve changing the source code of llama-recipes.
+The second way is targeting contributions which extend llama-recipes as it involves changing the source code.
+
+### Training on custom data
+To supply a custom dataset you need to provide a single .py file which contains a function with the following signature:
+```@python
+def get_custom_dataset(dataset_config, tokenizer, split: str):
+```
+For an example `get_custom_dataset` you can look at the provided datasets in llama_recipes.datasets or [examples/custom_dataset.py](examples/custom_dataset.py).
+The `dataset_config` in the above signature will be an instance of llama_recipes.configs.dataset.custom_dataset with the modifications made through the command line.
+The split signals wether to return the training or validation dataset.
+The default function name is `get_custom_dataset` but this can be changes as described below.
+
+In order to start a training with the custom dataset we need to set the `--dataset` as well as the `--custom_dataset.file` parameter. 
+```
+python -m llama_recipes.finetuning --dataset "custom_dataset" --custom_dataset.file "examples/custom_dataset.py" [TRAINING PARAMETERS]
+```
+To change the function name that is used in the .py you can append the name following a `:` like this:
+```
+python -m llama_recipes.finetuning --dataset "custom_dataset" --custom_dataset.file "examples/custom_dataset.py:get_foo" [TRAINING PARAMETERS]
+```
+This will call the function `get_foo` instead of `get_custom_dataset` when retrieving the dataset.
+
+### Adding new dataset 
 Each dataset has a corresponding configuration (dataclass) in [configs/datasets.py](../src/llama_recipes/configs/datasets.py) which contains the dataset name, training/validation split names, as well as optional parameters like datafiles etc.
 Each dataset has a corresponding configuration (dataclass) in [configs/datasets.py](../src/llama_recipes/configs/datasets.py) which contains the dataset name, training/validation split names, as well as optional parameters like datafiles etc.
 
 
 Additionally, there is a preprocessing function for each dataset in the [datasets](../src/llama_recipes/datasets) folder.
 Additionally, there is a preprocessing function for each dataset in the [datasets](../src/llama_recipes/datasets) folder.

+ 5 - 1
examples/README.md

@@ -31,4 +31,8 @@ For more in depth information on inference including inference safety checks and
 
 
 **Note** The [sensitive topics safety checker](../src/llama_recipes/inference/safety_utils.py) utilizes AuditNLG which is an optional dependency. Please refer to installation section of the main [README.md](../README.md#install-with-optional-dependencies) for details.
 **Note** The [sensitive topics safety checker](../src/llama_recipes/inference/safety_utils.py) utilizes AuditNLG which is an optional dependency. Please refer to installation section of the main [README.md](../README.md#install-with-optional-dependencies) for details.
 
 
-**Note** The **vLLM** example requires additional dependencies. Please refer to installation section of the main [README.md](../README.md#install-with-optional-dependencies) for details.
+**Note** The **vLLM** example requires additional dependencies. Please refer to installation section of the main [README.md](../README.md#install-with-optional-dependencies) for details.
+
+## Train on custom dataset
+To show how to train a model on a custom dataset we provide an example to generate a custom dataset in [custom_dataset.py](./custom_dataset.py).
+The usage of the custom dataset is further described in the datasets [README](../docs/Dataset.md#training-on-custom-data).

+ 33 - 0
examples/custom_dataset.py

@@ -0,0 +1,33 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+# For dataset details visit: https://huggingface.co/datasets/samsum
+
+import datasets
+
+from llama_recipes.datasets.utils import Concatenator
+
+def get_custom_dataset(dataset_config, tokenizer, split):
+    dataset = datasets.load_dataset("samsum", split=split)
+
+    prompt = (
+        f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n{{summary}}{{eos_token}}"
+    )
+
+    def apply_prompt_template(sample):
+        return {
+            "text": prompt.format(
+                dialog=sample["dialogue"],
+                summary=sample["summary"],
+                eos_token=tokenizer.eos_token,
+            )
+        }
+
+    dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
+        
+    dataset = dataset.map(
+        lambda sample: tokenizer(sample["text"]),
+        batched=True,
+        remove_columns=list(dataset.features),
+    ).map(Concatenator(), batched=True)
+    return dataset

+ 9 - 1
src/llama_recipes/configs/datasets.py

@@ -25,4 +25,12 @@ class alpaca_dataset:
     dataset: str = "alpaca_dataset"
     dataset: str = "alpaca_dataset"
     train_split: str = "train"
     train_split: str = "train"
     test_split: str = "val"
     test_split: str = "val"
-    data_path: str = "src/llama_recipes/datasets/alpaca_data.json"
+    data_path: str = "src/llama_recipes/datasets/alpaca_data.json"
+    
+    
+@dataclass
+class custom_dataset:
+    dataset: str = "custom_dataset"
+    file: str = "examples/custom_dataset.py"
+    train_split: str = "train"
+    test_split: str = "validation"

+ 5 - 3
src/llama_recipes/utils/config_utils.py

@@ -42,7 +42,8 @@ def generate_peft_config(train_config, kwargs):
     
     
     assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"
     assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"
     
     
-    config = configs[names.index(train_config.peft_method)]
+    config = configs[names.index(train_config.peft_method)]()
+    
     update_config(config, **kwargs)
     update_config(config, **kwargs)
     params = {k.name: getattr(config, k.name) for k in fields(config)}
     params = {k.name: getattr(config, k.name) for k in fields(config)}
     peft_config = peft_configs[names.index(train_config.peft_method)](**params)
     peft_config = peft_configs[names.index(train_config.peft_method)](**params)
@@ -52,10 +53,11 @@ def generate_peft_config(train_config, kwargs):
 
 
 def generate_dataset_config(train_config, kwargs):
 def generate_dataset_config(train_config, kwargs):
     names = tuple(DATASET_PREPROC.keys())
     names = tuple(DATASET_PREPROC.keys())
-    
+        
     assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}"
     assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}"
     
     
-    dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset]
+    dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset]()
+        
     update_config(dataset_config, **kwargs)
     update_config(dataset_config, **kwargs)
     
     
     return  dataset_config
     return  dataset_config

+ 38 - 0
src/llama_recipes/utils/dataset_utils.py

@@ -1,7 +1,9 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 
+import importlib
 from functools import partial
 from functools import partial
+from pathlib import Path
 
 
 import torch
 import torch
 
 
@@ -12,10 +14,46 @@ from llama_recipes.datasets import (
 )
 )
 
 
 
 
+def load_module_from_py_file(py_file: str) -> object:
+    """
+    This method loads a module from a py file which is not in the Python path
+    """
+    module_name = Path(py_file).name
+    loader = importlib.machinery.SourceFileLoader(module_name, py_file)
+    spec = importlib.util.spec_from_loader(module_name, loader)
+    module = importlib.util.module_from_spec(spec)
+
+    loader.exec_module(module)
+
+    return module
+
+
+def get_custom_dataset(dataset_config, tokenizer, split: str):
+    if ":" in dataset_config.file:
+        module_path, func_name = dataset_config.file.split(":")
+    else:
+        module_path, func_name = dataset_config.file, "get_custom_dataset"
+        
+    if not module_path.endswith(".py"):
+        raise ValueError(f"Dataset file {module_path} is not a .py file.")
+    
+    module_path = Path(module_path)
+    if not module_path.is_file():
+        raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
+    
+    module = load_module_from_py_file(module_path.as_posix())
+    try:
+        return getattr(module, func_name)(dataset_config, tokenizer, split)
+    except AttributeError as e:
+        print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).")
+        raise e
+    
+
 DATASET_PREPROC = {
 DATASET_PREPROC = {
     "alpaca_dataset": partial(get_alpaca_dataset, max_words=224),
     "alpaca_dataset": partial(get_alpaca_dataset, max_words=224),
     "grammar_dataset": get_grammar_dataset,
     "grammar_dataset": get_grammar_dataset,
     "samsum_dataset": get_samsum_dataset,
     "samsum_dataset": get_samsum_dataset,
+    "custom_dataset": get_custom_dataset,
 }
 }
 
 
 
 

+ 58 - 0
tests/datasets/test_custom_dataset.py

@@ -0,0 +1,58 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import pytest
+from unittest.mock import patch
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_custom_dataset(step_lr, optimizer, tokenizer, get_model, train, mocker):
+    from llama_recipes.finetuning import main
+        
+    tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})
+    
+    kwargs = {
+        "dataset": "custom_dataset",
+        "custom_dataset.file": "examples/custom_dataset.py",
+        "batch_size_training": 1,
+        "use_peft": False,
+        }
+    
+    main(**kwargs)
+    
+    assert train.call_count == 1
+    
+    args, kwargs = train.call_args
+    train_dataloader = args[1]
+    eval_dataloader = args[2]
+    
+    VAL_SAMPLES = 818
+    TRAIN_SAMPLES = 14732
+    CONCAT_SIZE = 2048
+    
+    assert len(train_dataloader) == TRAIN_SAMPLES // CONCAT_SIZE
+    assert len(eval_dataloader) == VAL_SAMPLES
+    
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, train, mocker):
+    from llama_recipes.finetuning import main
+        
+    tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})
+    
+    kwargs = {
+        "dataset": "custom_dataset",
+        "custom_dataset.file": "examples/custom_dataset.py:get_unknown_dataset",
+        "batch_size_training": 1,
+        "use_peft": False,
+        }
+    with pytest.raises(AttributeError):
+        main(**kwargs)

+ 37 - 0
tests/datasets/test_samsum_datasets.py

@@ -0,0 +1,37 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from unittest.mock import patch
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_custom_dataset(step_lr, optimizer, tokenizer, get_model, train, mocker):
+    from llama_recipes.finetuning import main
+        
+    tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})
+    
+    
+    kwargs = {
+        "batch_size_training": 1,
+        "use_peft": False,
+        "dataset": "samsum_dataset",
+        }
+    
+    main(**kwargs)
+    
+    assert train.call_count == 1
+    
+    args, kwargs = train.call_args
+    train_dataloader = args[1]
+    eval_dataloader = args[2]
+    
+    VAL_SAMPLES = 818
+    TRAIN_SAMPLES = 14732
+    CONCAT_SIZE = 2048
+    assert len(train_dataloader) == TRAIN_SAMPLES // CONCAT_SIZE
+    assert len(eval_dataloader) == VAL_SAMPLES
+