
Make packing/padding a training setting

Matthias Reso 1 year ago
parent
commit
a647955fc8
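
With this change, sample packing vs. padding is selected through train_config instead of per-dataset code: batching_strategy chooses between the packing path (ConcatDataset) and length-based padding, and context_length sets the packing chunk size. A minimal usage sketch via the main(**kwargs) entry point shown in finetuning.py below (the model id is an assumption, not part of this commit):

    from llama_recipes.finetuning import main

    # Every train_config field can be overridden as a keyword argument;
    # update_config() copies matching kwargs onto the dataclass.
    main(
        model_name="meta-llama/Llama-2-7b-hf",  # assumed model id
        batching_strategy="padding",            # default is "packing"
        context_length=2048,                    # chunk size used when packing
    )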

+ 2 - 4
src/llama_recipes/configs/training.py

@@ -11,6 +11,8 @@ class train_config:
     low_cpu_fsdp: bool=False
     run_validation: bool=True
     batch_size_training: int=4
+    batching_strategy: str="packing" #alternative: padding
+    context_length: int=4096
     gradient_accumulation_steps: int=1
     num_epochs: int=3
     num_workers_dataloader: int=1
@@ -34,7 +36,3 @@ class train_config:
     dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
     save_optimizer: bool=False # will be used if using FSDP
     use_fast_kernels: bool = False # Enable using SDPA from PyTorch Accelerated Transformers, which makes use of Flash Attention and Xformer memory-efficient kernels
-
-    
-    
-    

+ 7 - 9
src/llama_recipes/datasets/utils.py

@@ -11,7 +11,7 @@ class Concatenator(object):
     def __init__(self, chunk_size=2048):
         self.chunk_size=chunk_size
         self.residual = {"input_ids": [], "attention_mask": []}
-        
+
     def __call__(self, batch):
         concatenated_samples = {
             k: v + list(chain(*batch[k])) for k, v in self.residual.items()
@@ -44,26 +44,24 @@ class ConcatDataset(Dataset):
     def __init__(self, dataset, chunk_size=4096):
         self.dataset = dataset
         self.chunk_size = chunk_size
-        
+
         self.samples = []
-        
+
         buffer = {
             "input_ids": [],
             "attention_mask": [],
             "labels": [],
             }
-        
+
         for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
             buffer = {k: v + sample[k] for k,v in buffer.items()}
-            
+
             while len(next(iter(buffer.values()))) > self.chunk_size:
                 self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()})
                 buffer = {k: v[self.chunk_size:] for k,v in buffer.items()}
-                
+
     def __getitem__(self, idx):
         return self.samples[idx]
-    
+
     def __len__(self):
         return len(self.samples)
-    
-
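
For reference, a toy sketch of the packing behavior of the ConcatDataset class shown above: tokenized samples are concatenated and re-sliced into fixed-length chunks, and a trailing remainder shorter than chunk_size is dropped. The data values below are made up; the import path is the one finetuning.py uses in this commit.

    from llama_recipes.data.concatenator import ConcatDataset

    # Made-up tokenized samples of lengths 3, 4 and 3 (10 tokens in total).
    toy = [
        {"input_ids": [1, 2, 3],    "attention_mask": [1, 1, 1],    "labels": [1, 2, 3]},
        {"input_ids": [4, 5, 6, 7], "attention_mask": [1, 1, 1, 1], "labels": [4, 5, 6, 7]},
        {"input_ids": [8, 9, 10],   "attention_mask": [1, 1, 1],    "labels": [8, 9, 10]},
    ]
    packed = ConcatDataset(toy, chunk_size=4)

    assert len(packed) == 2                        # two full chunks of 4 tokens
    assert packed[0]["input_ids"] == [1, 2, 3, 4]  # the 2-token remainder is dropped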

+ 3 - 6
src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py

@@ -10,8 +10,6 @@ from pathlib import Path
 
 from torch.utils.data import Dataset
 
-from llama_recipes.datasets.utils import ConcatDataset
-
 
 class grammar(Dataset):
     def __init__(
@@ -48,10 +46,10 @@ class grammar(Dataset):
 
         input_ = example_batch["input"]
         target_ = example_batch["target"]
-        
+
         prompt = f"Correct this to standard English: {input_}\n---\nCorrected: {target_}"
         sample = self.tokenizer(prompt)
-        
+
         return sample
 
     def __getitem__(self, index):
@@ -80,6 +78,5 @@ def get_dataset(
         tokenizer=tokenizer,
         csv_name=csv_name,
     )
-    
-    return dataset
 
+    return dataset

+ 1 - 2
src/llama_recipes/datasets/samsum_dataset.py

@@ -5,7 +5,6 @@
 
 import datasets
 
-from llama_recipes.datasets.utils import Concatenator
 
 def get_preprocessed_samsum(dataset_config, tokenizer, split):
     dataset = datasets.load_dataset("samsum", split=split)
@@ -24,7 +23,7 @@ def get_preprocessed_samsum(dataset_config, tokenizer, split):
         }
 
     dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
-        
+
     dataset = dataset.map(
         lambda sample: tokenizer(sample["text"]),
         remove_columns=list(dataset.features),

+ 14 - 14
src/llama_recipes/finetuning.py

@@ -17,10 +17,11 @@ from transformers import (
     LlamaForCausalLM,
     LlamaTokenizer,
     LlamaConfig,
-)   
+)
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
 from llama_recipes.configs import fsdp_config, train_config
+from llama_recipes.data.concatenator import ConcatDataset
 from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
 
 from llama_recipes.utils import fsdp_auto_wrap_policy
@@ -28,7 +29,7 @@ from llama_recipes.utils.config_utils import (
     update_config,
     generate_peft_config,
     generate_dataset_config,
-    get_sampler_kwargs,
+    get_dataloader_kwargs,
 )
 from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
 
@@ -100,25 +101,19 @@ def main(**kwargs):
     if train_config.enable_fsdp and train_config.use_fast_kernels:
         """
         For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
-        using of Flash Attention or Xformer memory-efficient kernels 
+        using of Flash Attention or Xformer memory-efficient kernels
         based on the hardware being used. This would speed up fine-tuning.
         """
         try:
             from optimum.bettertransformer import BetterTransformer
-            model = BetterTransformer.transform(model) 
+            model = BetterTransformer.transform(model)
         except ImportError:
             print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
-    
+
     # Load the tokenizer and add special tokens
     tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
-    tokenizer.add_special_tokens(
-            {
+    tokenizer.pad_token_id = tokenizer.eos_token_id
 
-                "pad_token": "<PAD>",
-            }
-        )
-    model.resize_token_embeddings(model.config.vocab_size + 1) 
-    
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
 
     # Prepare the model for int8 training if quantization is enabled
@@ -180,8 +175,11 @@ def main(**kwargs):
     if not train_config.enable_fsdp or rank == 0:
             print(f"--> Validation Set Length = {len(dataset_val)}")
 
-    train_dl_kwargs = get_sampler_kwargs(train_config, dataset_train, tokenizer, "train")
-    val_dl_kwargs = get_sampler_kwargs(train_config, dataset_val, tokenizer, "val")
+    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
+    val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
+
+    if train_config.batching_strategy == "packing":
+        dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
 
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
@@ -193,6 +191,8 @@ def main(**kwargs):
 
     eval_dataloader = None
     if train_config.run_validation:
+        if train_config.batching_strategy == "packing":
+            dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
             num_workers=train_config.num_workers_dataloader,

+ 31 - 17
src/llama_recipes/utils/config_utils.py

@@ -38,49 +38,63 @@ def update_config(config, **kwargs):
                         print(f"Warning: {config_name} does not accept parameter: {k}")
             elif isinstance(config, train_config):
                 print(f"Warning: unknown parameter {k}")
-                        
-                        
+
+
 def generate_peft_config(train_config, kwargs):
     configs = (lora_config, llama_adapter_config, prefix_config)
     peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
     names = tuple(c.__name__.rstrip("_config") for c in configs)
-    
+
     assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"
-    
+
     config = configs[names.index(train_config.peft_method)]()
-    
+
     update_config(config, **kwargs)
     params = asdict(config)
     peft_config = peft_configs[names.index(train_config.peft_method)](**params)
-    
+
     return peft_config
 
 
 def generate_dataset_config(train_config, kwargs):
     names = tuple(DATASET_PREPROC.keys())
-        
+
     assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}"
-    
+
     dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset]()
-        
+
     update_config(dataset_config, **kwargs)
-    
+
     return  dataset_config
 
 
-def get_sampler_kwargs(train_config, dataset, tokenizer, mode):
+def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
         kwargs = {}
         batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
-        if train_config.enable_fsdp:
-            kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
+        if train_config.batching_strategy == "padding":
+            if train_config.enable_fsdp:
+                kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
+                    dataset,
+                    batch_size=batch_size,
+                    rank=dist.get_rank(),
+                    num_replicas=dist.get_world_size(),
+                    shuffle=mode=="train",
+                )
+            else:
+                kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
+            kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
+        elif train_config.batching_strategy == "packing":
+            if train_config.enable_fsdp:
+                kwargs["batch_sampler"] = DistributedSampler(
                 dataset,
-                batch_size=batch_size,
                 rank=dist.get_rank(),
                 num_replicas=dist.get_world_size(),
                 shuffle=mode=="train",
             )
+            kwargs["batch_size"] = batch_size
+            kwargs["drop_last"] = True
+            kwargs["collate_fn"] = default_data_collator
         else:
-            kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
-        kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
-            
+            raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
+
         return kwargs
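
A short sketch of how these kwargs are consumed, mirroring the DataLoader construction in finetuning.py (dataset_train, tokenizer and train_config are assumed to be set up as in main()):

    # "padding" yields a batch_sampler plus DataCollatorForSeq2Seq;
    # "packing" yields batch_size/drop_last plus default_data_collator.
    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")

    train_dataloader = torch.utils.data.DataLoader(
        dataset_train,
        num_workers=train_config.num_workers_dataloader,
        **train_dl_kwargs,
    )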

+ 65 - 26
tests/test_finetuning.py

@@ -1,14 +1,17 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+import pytest
 from pytest import approx
 from unittest.mock import patch
 
 from torch.nn import Linear
 from torch.optim import AdamW
 from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.sampler import BatchSampler
 
 from llama_recipes.finetuning import main
+from llama_recipes.data.sampler import LengthBasedBatchSampler
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@@ -18,23 +21,23 @@ from llama_recipes.finetuning import main
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
     kwargs = {"run_validation": False}
-    
+
     get_dataset.return_value = [[1]]
-    
+
     main(**kwargs)
-    
+
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     train_dataloader = args[1]
     eval_dataloader = args[2]
-    
+
     assert isinstance(train_dataloader, DataLoader)
     assert eval_dataloader is None
-    
+
     assert get_model.return_value.to.call_args.args[0] == "cuda"
-    
-    
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@@ -44,20 +47,20 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge
 def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
     kwargs = {"run_validation": True}
     get_dataset.return_value = [[1]]
-    
+
     main(**kwargs)
-    
+
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     train_dataloader = args[1]
     eval_dataloader = args[2]
     assert isinstance(train_dataloader, DataLoader)
     assert isinstance(eval_dataloader, DataLoader)
-    
+
     assert get_model.return_value.to.call_args.args[0] == "cuda"
-    
-    
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@@ -68,15 +71,15 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer,
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train):
     kwargs = {"use_peft": True}
-    
+
     get_dataset.return_value = [[1]]
-    
+
     main(**kwargs)
-    
+
     assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
     assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
-    
-    
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@@ -85,20 +88,56 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train):
     kwargs = {"weight_decay": 0.01}
-    
+
     get_dataset.return_value = [[1]]
-    
+
     get_peft_model.return_value = Linear(1,1)
     get_peft_model.return_value.print_trainable_parameters=lambda:None
     main(**kwargs)
-    
+
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     optimizer = args[4]
-    
+
     print(optimizer.state_dict())
-    
+
     assert isinstance(optimizer, AdamW)
     assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01)
-    
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.get_preprocessed_dataset')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_batching_strategy(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
+    kwargs = {"batching_strategy": "packing"}
+
+    get_dataset.return_value = [[1]]
+
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader, eval_dataloader = args[1:3]
+    assert isinstance(train_dataloader.batch_sampler, BatchSampler)
+    assert isinstance(eval_dataloader.batch_sampler, BatchSampler)
+
+    kwargs["batching_strategy"] = "padding"
+    train.reset_mock()
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader, eval_dataloader = args[1:3]
+    assert isinstance(train_dataloader.batch_sampler, LengthBasedBatchSampler)
+    assert isinstance(eval_dataloader.batch_sampler, LengthBasedBatchSampler)
+
+    kwargs["batching_strategy"] = "none"
+
+    with pytest.raises(ValueError):
+        main(**kwargs)
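
The new switch is exercised end-to-end by test_batching_strategy above; it can be run in isolation with pytest's standard keyword filter, e.g. pytest tests/test_finetuning.py -k batching_strategy.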