
Make packing/padding a training setting

Matthias Reso 1 year ago
commit a647955fc8

+ 2 - 4
src/llama_recipes/configs/training.py

@@ -11,6 +11,8 @@ class train_config:
     low_cpu_fsdp: bool=False
     run_validation: bool=True
     batch_size_training: int=4
+    batching_strategy: str="packing" #alternative: padding
+    context_length: int=4096
     gradient_accumulation_steps: int=1
     num_epochs: int=3
     num_workers_dataloader: int=1
@@ -34,7 +36,3 @@ class train_config:
     dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
     save_optimizer: bool=False # will be used if using FSDP
     use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
-
-    
-    
-    
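The two new fields are ordinary train_config settings, so they can be overridden like any other option. A minimal usage sketch (keyword arguments are forwarded through update_config by the fine-tuning entry point; the values shown are illustrative):

from llama_recipes.finetuning import main

# Pack tokenized samples into fixed chunks of context_length tokens (the new default):
main(batching_strategy="packing", context_length=4096)

# Or pad samples to a common length within each batch:
main(batching_strategy="padding")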

+ 7 - 9
src/llama_recipes/datasets/utils.py

@@ -11,7 +11,7 @@ class Concatenator(object):
     def __init__(self, chunk_size=2048):
         self.chunk_size=chunk_size
         self.residual = {"input_ids": [], "attention_mask": []}
-        
+
     def __call__(self, batch):
         concatenated_samples = {
             k: v + list(chain(*batch[k])) for k, v in self.residual.items()
@@ -44,26 +44,24 @@ class ConcatDataset(Dataset):
     def __init__(self, dataset, chunk_size=4096):
         self.dataset = dataset
         self.chunk_size = chunk_size
-        
+
         self.samples = []
-        
+
         buffer = {
             "input_ids": [],
             "attention_mask": [],
             "labels": [],
             }
-        
+
         for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
             buffer = {k: v + sample[k] for k,v in buffer.items()}
-            
+
             while len(next(iter(buffer.values()))) > self.chunk_size:
                 self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()})
                 buffer = {k: v[self.chunk_size:] for k,v in buffer.items()}
-                
+
     def __getitem__(self, idx):
         return self.samples[idx]
-    
+
     def __len__(self):
         return len(self.samples)
-    
-
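As a toy illustration of the packing behaviour implemented by ConcatDataset above, the following standalone sketch mirrors the buffer/chunk loop (the sample values are made up):

# Concatenate tokenized samples, emit fixed-size chunks, and drop the
# trailing remainder that is shorter than chunk_size.
samples = [{"input_ids": list(range(10)), "attention_mask": [1] * 10, "labels": list(range(10))}] * 5
chunk_size = 16

buffer = {"input_ids": [], "attention_mask": [], "labels": []}
packed = []
for sample in samples:
    buffer = {k: v + sample[k] for k, v in buffer.items()}
    while len(next(iter(buffer.values()))) > chunk_size:
        packed.append({k: v[:chunk_size] for k, v in buffer.items()})
        buffer = {k: v[chunk_size:] for k, v in buffer.items()}

print(len(packed))  # 3 chunks of 16 tokens; the 2 leftover tokens are discarded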

+ 3 - 6
src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py

@@ -10,8 +10,6 @@ from pathlib import Path

 from torch.utils.data import Dataset

-from llama_recipes.datasets.utils import ConcatDataset
-

 class grammar(Dataset):
     def __init__(
@@ -48,10 +46,10 @@ class grammar(Dataset):

         input_ = example_batch["input"]
         target_ = example_batch["target"]
-        
+
         prompt = f"Correct this to standard English: {input_}\n---\nCorrected: {target_}"
         sample = self.tokenizer(prompt)
-        
+
         return sample

     def __getitem__(self, index):
@@ -80,6 +78,5 @@ def get_dataset(
         tokenizer=tokenizer,
         csv_name=csv_name,
     )
-    
-    return dataset

+    return dataset

+ 1 - 2
src/llama_recipes/datasets/samsum_dataset.py

@@ -5,7 +5,6 @@

 import datasets

-from llama_recipes.datasets.utils import Concatenator

 def get_preprocessed_samsum(dataset_config, tokenizer, split):
     dataset = datasets.load_dataset("samsum", split=split)
@@ -24,7 +23,7 @@ def get_preprocessed_samsum(dataset_config, tokenizer, split):
         }

     dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
-        
+
     dataset = dataset.map(
         lambda sample: tokenizer(sample["text"]),
         remove_columns=list(dataset.features),

+ 14 - 14
src/llama_recipes/finetuning.py

@@ -17,10 +17,11 @@ from transformers import (
     LlamaForCausalLM,
     LlamaTokenizer,
     LlamaConfig,
-)   
+)
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer

 from llama_recipes.configs import fsdp_config, train_config
+from llama_recipes.data.concatenator import ConcatDataset
 from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing

 from llama_recipes.utils import fsdp_auto_wrap_policy
@@ -28,7 +29,7 @@ from llama_recipes.utils.config_utils import (
     update_config,
     generate_peft_config,
     generate_dataset_config,
-    get_sampler_kwargs,
+    get_dataloader_kwargs,
 )
 from llama_recipes.utils.dataset_utils import get_preprocessed_dataset

@@ -100,25 +101,19 @@ def main(**kwargs):
     if train_config.enable_fsdp and train_config.use_fast_kernels:
         """
         For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
-        using of Flash Attention or Xformer memory-efficient kernels 
+        using of Flash Attention or Xformer memory-efficient kernels
         based on the hardware being used. This would speed up fine-tuning.
         """
         try:
             from optimum.bettertransformer import BetterTransformer
-            model = BetterTransformer.transform(model) 
+            model = BetterTransformer.transform(model)
         except ImportError:
             print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
-    
+
     # Load the tokenizer and add special tokens
     tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
-    tokenizer.add_special_tokens(
-            {
+    tokenizer.pad_token_id = tokenizer.eos_token_id

-                "pad_token": "<PAD>",
-            }
-        )
-    model.resize_token_embeddings(model.config.vocab_size + 1) 
-    
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)

     # Prepare the model for int8 training if quantization is enabled
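The tokenizer change above reuses the EOS token id as the padding id instead of adding a dedicated <PAD> token, so the embedding matrix no longer needs to be resized. A minimal sketch of the idea (the checkpoint name is illustrative):

from transformers import LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # illustrative checkpoint
tokenizer.pad_token_id = tokenizer.eos_token_id  # reuse EOS for padding; the vocab size is unchanged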
@@ -180,8 +175,11 @@ def main(**kwargs):
     if not train_config.enable_fsdp or rank == 0:
             print(f"--> Validation Set Length = {len(dataset_val)}")

-    train_dl_kwargs = get_sampler_kwargs(train_config, dataset_train, tokenizer, "train")
-    val_dl_kwargs = get_sampler_kwargs(train_config, dataset_val, tokenizer, "val")
+    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
+    val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
+
+    if train_config.batching_strategy == "packing":
+        dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)

     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
@@ -193,6 +191,8 @@

     eval_dataloader = None
     if train_config.run_validation:
+        if train_config.batching_strategy == "packing":
+            dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
             num_workers=train_config.num_workers_dataloader,
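Condensed, the new data path in finetuning.py is: build the DataLoader kwargs for the chosen strategy, optionally pack the dataset, then construct the loader. A rough sketch of that flow (other DataLoader arguments used in the file are omitted here):

train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")

# "packing" first concatenates samples into fixed context_length chunks;
# "padding" leaves the dataset as-is and relies on the batch sampler and collator instead.
if train_config.batching_strategy == "packing":
    dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)

train_dataloader = torch.utils.data.DataLoader(
    dataset_train,
    num_workers=train_config.num_workers_dataloader,
    **train_dl_kwargs,
)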

+ 31 - 17
src/llama_recipes/utils/config_utils.py

@@ -38,49 +38,63 @@ def update_config(config, **kwargs):
                         print(f"Warning: {config_name} does not accept parameter: {k}")
             elif isinstance(config, train_config):
                 print(f"Warning: unknown parameter {k}")
-                        
-                        
+
+
 def generate_peft_config(train_config, kwargs):
     configs = (lora_config, llama_adapter_config, prefix_config)
     peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
     names = tuple(c.__name__.rstrip("_config") for c in configs)
-    
+
     assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"
-    
+
     config = configs[names.index(train_config.peft_method)]()
-    
+
     update_config(config, **kwargs)
     params = asdict(config)
     peft_config = peft_configs[names.index(train_config.peft_method)](**params)
-    
+
     return peft_config


 def generate_dataset_config(train_config, kwargs):
     names = tuple(DATASET_PREPROC.keys())
-        
+
     assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}"
-    
+
     dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset]()
-        
+
     update_config(dataset_config, **kwargs)
-    
+
     return  dataset_config


-def get_sampler_kwargs(train_config, dataset, tokenizer, mode):
+def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
        kwargs = {}
        batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
-        if train_config.enable_fsdp:
-            kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
+        if train_config.batching_strategy == "padding":
+            if train_config.enable_fsdp:
+                kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
+                    dataset,
+                    batch_size=batch_size,
+                    rank=dist.get_rank(),
+                    num_replicas=dist.get_world_size(),
+                    shuffle=mode=="train",
+                )
+            else:
+                kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
+            kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
+        elif train_config.batching_strategy == "packing":
+            if train_config.enable_fsdp:
+                kwargs["batch_sampler"] = DistributedSampler(
                 dataset,
-                batch_size=batch_size,
                 rank=dist.get_rank(),
                 num_replicas=dist.get_world_size(),
                 shuffle=mode=="train",
             )
+            kwargs["batch_size"] = batch_size
+            kwargs["drop_last"] = True
+            kwargs["collate_fn"] = default_data_collator
         else:
-            kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
-        kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
-            
+            raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
+
         return kwargs
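The two strategies differ mainly in how a batch is assembled: "padding" batches samples of similar length (LengthBasedBatchSampler) and pads each batch dynamically with DataCollatorForSeq2Seq, while "packing" uses plain fixed-size batches with the default collator. A small sketch of the padding collator's effect (the tokenizer checkpoint is illustrative):

from transformers import DataCollatorForSeq2Seq, LlamaTokenizer

tokenizer = LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # illustrative
tokenizer.pad_token_id = tokenizer.eos_token_id

collator = DataCollatorForSeq2Seq(tokenizer)
batch = collator([
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [1, 2, 3]},
    {"input_ids": [1, 2, 3, 4, 5], "attention_mask": [1, 1, 1, 1, 1], "labels": [1, 2, 3, 4, 5]},
])
print(batch["input_ids"].shape)  # torch.Size([2, 5]); the shorter sample is padded up to length 5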

+ 65 - 26
tests/test_finetuning.py

@@ -1,14 +1,17 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

+import pytest
 from pytest import approx
 from unittest.mock import patch

 from torch.nn import Linear
 from torch.optim import AdamW
 from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.sampler import BatchSampler

 from llama_recipes.finetuning import main
+from llama_recipes.data.sampler import LengthBasedBatchSampler

 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@@ -18,23 +21,23 @@ from llama_recipes.finetuning import main
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
     kwargs = {"run_validation": False}
-    
+
     get_dataset.return_value = [[1]]
-    
+
     main(**kwargs)
-    
+
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     train_dataloader = args[1]
     eval_dataloader = args[2]
-    
+
     assert isinstance(train_dataloader, DataLoader)
     assert eval_dataloader is None
-    
+
     assert get_model.return_value.to.call_args.args[0] == "cuda"
-    
-    
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@@ -44,20 +47,20 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge
 def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
     kwargs = {"run_validation": True}
     get_dataset.return_value = [[1]]
-    
+
     main(**kwargs)
-    
+
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     train_dataloader = args[1]
     eval_dataloader = args[2]
     assert isinstance(train_dataloader, DataLoader)
     assert isinstance(eval_dataloader, DataLoader)
-    
+
     assert get_model.return_value.to.call_args.args[0] == "cuda"
-    
-    
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@@ -68,15 +71,15 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer,
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train):
     kwargs = {"use_peft": True}
-    
+
     get_dataset.return_value = [[1]]
-    
+
     main(**kwargs)
-    
+
     assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
     assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
-    
-    
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@@ -85,20 +88,56 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train):
     kwargs = {"weight_decay": 0.01}
-    
+
     get_dataset.return_value = [[1]]
-    
+
     get_peft_model.return_value = Linear(1,1)
     get_peft_model.return_value.print_trainable_parameters=lambda:None
     main(**kwargs)
-    
+
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     optimizer = args[4]
-    
+
     print(optimizer.state_dict())
-    
+
     assert isinstance(optimizer, AdamW)
     assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01)
-    
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.get_preprocessed_dataset')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_batching_strategy(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
+    kwargs = {"batching_strategy": "packing"}
+
+    get_dataset.return_value = [[1]]
+
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader, eval_dataloader = args[1:3]
+    assert isinstance(train_dataloader.batch_sampler, BatchSampler)
+    assert isinstance(eval_dataloader.batch_sampler, BatchSampler)
+
+    kwargs["batching_strategy"] = "padding"
+    train.reset_mock()
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader, eval_dataloader = args[1:3]
+    assert isinstance(train_dataloader.batch_sampler, LengthBasedBatchSampler)
+    assert isinstance(eval_dataloader.batch_sampler, LengthBasedBatchSampler)
+
+    kwargs["batching_strategy"] = "none"
+
+    with pytest.raises(ValueError):
+        main(**kwargs)