Jelajahi Sumber

Feature/length based batch sampling (#206)

Hamid Shojanazeri 1 tahun lalu
induk
melakukan
373000b2ac

+ 3 - 3
README.md

@@ -137,7 +137,7 @@ Here we make use of Parameter Efficient Methods (PEFT) as described in the next
 
 ```bash
 
-torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --pure_bf16 --output_dir Path/to/save/PEFT/model
+torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --fsdp_config.pure_bf16 --output_dir Path/to/save/PEFT/model
 
 ```
 
@@ -148,7 +148,7 @@ Here we use FSDP as discussed in the next section which can be used along with P
 Setting `use_fast_kernels` will enable using of Flash Attention or Xformer memory-efficient kernels based on the hardware being used. This would speed up the fine-tuning job. This has been enabled in `optimum` library from HuggingFace as a one-liner API, please read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/).
 
 ```bash
-torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --pure_bf16 --output_dir Path/to/save/PEFT/model --use_fast_kernels
+torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --fsdp_config.pure_bf16 --output_dir Path/to/save/PEFT/model --use_fast_kernels
 ```
 
 ### Fine-tuning using FSDP Only
@@ -167,7 +167,7 @@ If you are interested in running full parameter fine-tuning on the 70B model, yo
 
 ```bash
 
-torchrun --nnodes 1 --nproc_per_node 8 examples/finetuning.py --enable_fsdp --low_cpu_fsdp --pure_bf16 --model_name /patht_of_model_folder/70B --batch_size_training 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
+torchrun --nnodes 1 --nproc_per_node 8 examples/finetuning.py --enable_fsdp --low_cpu_fsdp --fsdp_config.pure_bf16 --model_name /patht_of_model_folder/70B --batch_size_training 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
 
 ```
 

+ 14 - 2
docs/Dataset.md

@@ -7,6 +7,18 @@ The provided fine tuning script allows you to select between three datasets by p
 * [samsum_dataset](https://huggingface.co/datasets/samsum) contains about 16k messenger-like conversations with summaries.
 * [OpenAssistant/oasst1](https://huggingface.co/datasets/OpenAssistant/oasst1/) contains about 88k messages from assistant-style conversations.
 
+## Batching Strategies
+Llama-recipes support two strategies to batch requests together.
+The default setting is `packing` which concatenates the tokenized samples into long sequences filling up the context length of the model.
+This is the most compute efficient variant as it avoids any padding and all sequences have the same length.
+Samples at the boundary of the context length are truncated and the remainder of the cut sequence it used as the start of the next long sequence.
+
+If the amount of training data is small this procedure might introduce a lot of noise into the training data which can hurt the prediction performance of the fine-tune model.
+Therefore, we also support a `padding` strategy which does not introduce the addition noise due to truncated sequences.
+The strategy tries to minimize the efficiency loss by batching samples of similar length together so only minimal padding is necessary.
+
+The batching strategy can be selected though the command line parameter `--batching_strategy [packing]/[padding]`.
+
 ## Using custom datasets
 
 The list of available datasets in llama-recipes is supposed to give users a quick start on training their Llama model.
@@ -25,7 +37,7 @@ The `dataset_config` in the above signature will be an instance of llama_recipes
 The split signals wether to return the training or validation dataset.
 The default function name is `get_custom_dataset` but this can be changes as described below.
 
-In order to start a training with the custom dataset we need to set the `--dataset` as well as the `--custom_dataset.file` parameter. 
+In order to start a training with the custom dataset we need to set the `--dataset` as well as the `--custom_dataset.file` parameter.
 ```
 python -m llama_recipes.finetuning --dataset "custom_dataset" --custom_dataset.file "examples/custom_dataset.py" [TRAINING PARAMETERS]
 ```
@@ -35,7 +47,7 @@ python -m llama_recipes.finetuning --dataset "custom_dataset" --custom_dataset.f
 ```
 This will call the function `get_foo` instead of `get_custom_dataset` when retrieving the dataset.
 
-### Adding new dataset 
+### Adding new dataset
 Each dataset has a corresponding configuration (dataclass) in [configs/datasets.py](../src/llama_recipes/configs/datasets.py) which contains the dataset name, training/validation split names, as well as optional parameters like datafiles etc.
 
 Additionally, there is a preprocessing function for each dataset in the [datasets](../src/llama_recipes/datasets) folder.

+ 23 - 30
examples/custom_dataset.py

@@ -7,33 +7,27 @@ import copy
 import datasets
 import itertools
 
-from llama_recipes.datasets.utils import Concatenator
-
 
 B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
 
 def tokenize_dialog(dialog, tokenizer):
-    dialog_tokens = [
-            tokenizer(
-                f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
-            )
-            for prompt, answer in zip(dialog[::2], dialog[1::2])
-        ]
-    if len(dialog) % 2:    
-        dialog_tokens += [tokenizer(
-            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
-        )]
-    
-    combined_tokens = {}  
-    for k in dialog_tokens[0].keys():
-        combined_tokens[k] = list(itertools.chain(*(t[k] for t in dialog_tokens)))
-    return combined_tokens
+    prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST}", add_special_tokens=False) for prompt in dialog[::2]]
+    answer_tokens = [tokenizer.encode(f"{answer['content'].strip()} {tokenizer.eos_token}", add_special_tokens=False) for answer in dialog[1::2]]
+    dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens)))
+    #Add labels, convert prompt token to -100 in order to ignore in loss function
+    labels_tokens = [len(c)*[-100,] if i % 2 == 0 else c for i,c in enumerate(dialog_tokens)]
+
+    combined_tokens = {
+        "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
+        "labels": list(itertools.chain(*(t for t in labels_tokens))),
+    }
+
+    return dict(combined_tokens, attention_mask=[1]*len(combined_tokens["input_ids"]))
 
 
 def get_custom_dataset(dataset_config, tokenizer, split):
     dataset = datasets.load_dataset("OpenAssistant/oasst1", split=split)
-    
+
     dataset = dataset.map(lambda sample: {
         "message_id": sample["message_id"],
         "parent_id": sample["parent_id"],
@@ -41,19 +35,19 @@ def get_custom_dataset(dataset_config, tokenizer, split):
         },
         batched=True,
         remove_columns=list(dataset.features),)
-    
+
     nodes = {}
-    
+
     messages = {}
     root_ids = []
-    
+
     for data in dataset:
         if data["parent_id"]:
             nodes[data["parent_id"]] = nodes.get(data["parent_id"], []) + [data["message_id"]]
         else:
             root_ids.append(data["message_id"])
         messages[data["message_id"]]=data["text"]
-           
+
     def follow(thread, current_id):
         thread = copy.copy(thread) + [messages[current_id]]
         if current_id in nodes:
@@ -63,18 +57,18 @@ def get_custom_dataset(dataset_config, tokenizer, split):
             return new_threads
         else:
             return [thread]
-        
+
     def get_threads_from_root(root_id):
         all_threads = []
         thread = [messages[root_id]]
         for cid in nodes[root_id]:
             all_threads += follow(thread, cid)
         return all_threads
-            
+
     dataset = dataset.filter(lambda x: x["message_id"] in root_ids)
     dataset = dataset.map(lambda x: {"thread": get_threads_from_root(x["message_id"])}, remove_columns=list(dataset.features))
     dataset = dataset.map(lambda x: {"thread": [i for row in x["thread"] for i in row]}, batched=True)
-    
+
     def to_dialog(thread):
         dialog = []
         for i, content in enumerate(thread):
@@ -83,9 +77,8 @@ def get_custom_dataset(dataset_config, tokenizer, split):
                 "content": content,
             })
         return {"dialog": dialog}
-            
+
     dataset = dataset.map(lambda x: to_dialog(x["thread"]), remove_columns=list(dataset.features))
     dataset = dataset.map(lambda x: tokenize_dialog(x["dialog"], tokenizer), remove_columns=list(dataset.features))
-    dataset = dataset.map(Concatenator(), batched=True)
-    
-    return dataset
+
+    return dataset

+ 3 - 1
scripts/spellcheck_conf/wordlist.txt

@@ -1157,6 +1157,7 @@ FN
 GBs
 MLP
 learnable
+tokenized
 Colab
 GenAI
 Gradio
@@ -1182,4 +1183,5 @@ minnutes
 pdf
 quantized
 serarch
-streamlit
+streamlit
+

+ 0 - 2
src/llama_recipes/configs/datasets.py

@@ -9,7 +9,6 @@ class samsum_dataset:
     dataset: str =  "samsum_dataset"
     train_split: str = "train"
     test_split: str = "validation"
-    input_length: int = 2048
     
     
 @dataclass
@@ -17,7 +16,6 @@ class grammar_dataset:
     dataset: str = "grammar_dataset"
     train_split: str = "src/llama_recipes/datasets/grammar_dataset/gtrain_10k.csv" 
     test_split: str = "src/llama_recipes/datasets/grammar_dataset/grammar_validation.csv"
-    input_length: int = 2048
 
     
 @dataclass

+ 2 - 4
src/llama_recipes/configs/training.py

@@ -11,6 +11,8 @@ class train_config:
     low_cpu_fsdp: bool=False
     run_validation: bool=True
     batch_size_training: int=4
+    batching_strategy: str="packing" #alternative: padding
+    context_length: int=4096
     gradient_accumulation_steps: int=1
     num_epochs: int=3
     num_workers_dataloader: int=1
@@ -34,7 +36,3 @@ class train_config:
     dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
     save_optimizer: bool=False # will be used if using FSDP
     use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
-
-    
-    
-    

+ 2 - 0
src/llama_recipes/data/__init__.py

@@ -0,0 +1,2 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

+ 34 - 0
src/llama_recipes/data/concatenator.py

@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from tqdm import tqdm
+from itertools import chain
+
+from torch.utils.data import Dataset
+
+
+class ConcatDataset(Dataset):
+    def __init__(self, dataset, chunk_size=4096):
+        self.dataset = dataset
+        self.chunk_size = chunk_size
+
+        self.samples = []
+
+        buffer = {
+            "input_ids": [],
+            "attention_mask": [],
+            "labels": [],
+            }
+
+        for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
+            buffer = {k: v + sample[k] for k,v in buffer.items()}
+
+            while len(next(iter(buffer.values()))) > self.chunk_size:
+                self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()})
+                buffer = {k: v[self.chunk_size:] for k,v in buffer.items()}
+
+    def __getitem__(self, idx):
+        return self.samples[idx]
+
+    def __len__(self):
+        return len(self.samples)

+ 57 - 0
src/llama_recipes/data/sampler.py

@@ -0,0 +1,57 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import random
+from itertools import islice
+
+import numpy as np
+import torch
+
+
+class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
+    def __init__(self, data_source, batch_size: int, drop_last: bool, shuffle: bool=True) -> None:
+        if isinstance(next(iter(data_source)), dict):
+            first_key = next(iter(next(iter(data_source)).keys()))
+            self.lengths = [len(d[first_key]) for d in data_source]
+        else:
+            self.lengths = [len(d) for d in data_source]
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+        self.shuffle = shuffle
+
+    def __iter__(self):
+        ids = np.argsort(self.lengths)
+        if self.drop_last:
+            ids = ids[:len(ids) // self.batch_size * self.batch_size]
+
+        batches = [ids[i:i+self.batch_size] for i in range(0, len(ids), self.batch_size)]
+
+        if self.shuffle:
+            random.shuffle(batches)
+
+        for b in batches:
+            yield b
+
+    def __len__(self):
+        if self.drop_last:
+            return len(self.lengths) // self.batch_size
+        else:
+            return len(self.lengths) // self.batch_size + (len(self.lengths) % self.batch_size > 0)
+
+
+class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler):
+    def __init__(self, data_source, batch_size: int, num_replicas: int, rank: int, shuffle: bool = True, seed: int = 0) -> None:
+        random.seed(seed)
+        self.batch_sampler = LengthBasedBatchSampler(
+            data_source, batch_size=batch_size, drop_last=True, shuffle=shuffle
+            )
+        self.num_replicas = num_replicas
+        self.rank = rank
+        
+    def __iter__(self):
+        max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas
+        return islice(self.batch_sampler, self.rank, max_length, self.num_replicas)
+         
+    def __len__(self):
+        return len(self.batch_sampler) // self.num_replicas
+            

+ 4 - 14
src/llama_recipes/datasets/alpaca_dataset.py

@@ -24,17 +24,14 @@ PROMPT_DICT = {
 }
 
 class InstructionDataset(Dataset):
-    def __init__(self, dataset_config, tokenizer, partition="train", max_words=30):
+    def __init__(self, dataset_config, tokenizer, partition="train"):
         self.ann = json.load(open(dataset_config.data_path))
         if partition == "train":
             self.ann = self.ann
         else:
             self.ann = self.ann[:200]
 
-        self.max_words = max_words
-        # tokenizer = Tokenizer(model_path=model_path + "./tokenizer.model")
         self.tokenizer = tokenizer
-        # self.tokenizer1 = tokenizer
 
     def __len__(self):
         return len(self.ann)
@@ -57,22 +54,15 @@ class InstructionDataset(Dataset):
         example = torch.tensor(
             example, dtype=torch.int64
         )
-        padding = self.max_words - example.shape[0]
-        if padding > 0:
-            example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1))
-        elif padding < 0:
-            example = example[: self.max_words]
         labels = copy.deepcopy(example)
         labels[: len(prompt)] = -1
         example_mask = example.ge(0)
         label_mask = labels.ge(0)
         example[~example_mask] = 0
         labels[~label_mask] = IGNORE_INDEX
-        example_mask = example_mask.float()
-        label_mask = label_mask.float()
 
         return {
-            "input_ids": example,
-            "labels": labels,
-            "attention_mask":example_mask,
+            "input_ids": example.tolist(),
+            "labels": labels.tolist(),
+            "attention_mask":example_mask.tolist(),
         }

+ 13 - 18
src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py

@@ -10,8 +10,6 @@ from pathlib import Path
 
 from torch.utils.data import Dataset
 
-from llama_recipes.datasets.utils import ConcatDataset
-
 
 class grammar(Dataset):
     def __init__(
@@ -48,24 +46,22 @@ class grammar(Dataset):
 
         input_ = example_batch["input"]
         target_ = example_batch["target"]
-        
-        prompt = f"Correct this to standard English: {input_}\n---\nCorrected: {target_}"
-        sample = self.tokenizer(prompt)
-        
-        return sample
-
-    def __getitem__(self, index):
-        sample = self.convert_to_features(self.dataset["train"][index])
-        source_ids = sample["input_ids"]
 
-        src_mask = sample["attention_mask"]
+        prompt = f"Correct this to standard English: {input_}\n---\nCorrected: "
+        prompt_ids = self.tokenizer.encode(self.tokenizer.bos_token + prompt, add_special_tokens=False)
+        label_ids = self.tokenizer.encode(target_ + self.tokenizer.eos_token, add_special_tokens=False)
 
-        return {
-            "input_ids": source_ids,
-            "attention_mask": src_mask,
-            "labels": source_ids.copy(),
+        sample = {
+            "input_ids": prompt_ids + label_ids,
+            "attention_mask": [1] * len(prompt_ids + label_ids),
+            "labels": [-100] * len(prompt_ids) + label_ids
         }
 
+        return sample
+
+    def __getitem__(self, index):
+        return self.convert_to_features(self.dataset["train"][int(index)])
+
 
 def get_dataset(
     dataset_config, tokenizer, csv_name=None
@@ -80,6 +76,5 @@ def get_dataset(
         tokenizer=tokenizer,
         csv_name=csv_name,
     )
-    
-    return ConcatDataset(dataset, chunk_size=dataset_config.input_length)
 
+    return dataset

+ 19 - 13
src/llama_recipes/datasets/samsum_dataset.py

@@ -3,31 +3,37 @@
 
 # For dataset details visit: https://huggingface.co/datasets/samsum
 
+import copy
 import datasets
 
-from llama_recipes.datasets.utils import Concatenator
 
 def get_preprocessed_samsum(dataset_config, tokenizer, split):
     dataset = datasets.load_dataset("samsum", split=split)
 
     prompt = (
-        f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n{{summary}}{{eos_token}}"
+        f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n"
     )
 
     def apply_prompt_template(sample):
         return {
-            "text": prompt.format(
-                dialog=sample["dialogue"],
-                summary=sample["summary"],
-                eos_token=tokenizer.eos_token,
-            )
+            "prompt": prompt.format(dialog=sample["dialogue"]),
+            "summary": sample["summary"],
         }
 
     dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
-        
-    dataset = dataset.map(
-        lambda sample: tokenizer(sample["text"]),
-        batched=True,
-        remove_columns=list(dataset.features),
-    ).map(Concatenator(), batched=True)
+
+    def tokenize_add_label(sample):
+        prompt = tokenizer.encode(tokenizer.bos_token + sample["prompt"], add_special_tokens=False)
+        summary = tokenizer.encode(sample["summary"] +  tokenizer.eos_token, add_special_tokens=False)
+
+        sample = {
+            "input_ids": prompt + summary,
+            "attention_mask" : [1] * (len(prompt) + len(summary)),
+            "labels": [-100] * len(prompt) + summary,
+            }
+
+        return sample
+
+    dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
+
     return dataset

+ 0 - 66
src/llama_recipes/datasets/utils.py

@@ -1,66 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-
-from tqdm import tqdm
-from itertools import chain
-
-from torch.utils.data import Dataset
-
-class Concatenator(object):
-    def __init__(self, chunk_size=2048):
-        self.chunk_size=chunk_size
-        self.residual = {"input_ids": [], "attention_mask": []}
-        
-    def __call__(self, batch):
-        concatenated_samples = {
-            k: v + list(chain(*batch[k])) for k, v in self.residual.items()
-        }
-
-        total_length = len(concatenated_samples[list(concatenated_samples.keys())[0]])
-
-        if total_length >= self.chunk_size:
-            chunk_num = total_length // self.chunk_size
-            result = {
-                k: [
-                    v[i : i + self.chunk_size]
-                    for i in range(0, chunk_num * self.chunk_size, self.chunk_size)
-                ]
-                for k, v in concatenated_samples.items()
-            }
-            self.residual = {
-                k: v[(chunk_num * self.chunk_size) :]
-                for k, v in concatenated_samples.items()
-            }
-        else:
-            result = concatenated_samples
-            self.residual = {k: [] for k in concatenated_samples.keys()}
-
-        result["labels"] = result["input_ids"].copy()
-
-        return result
-
-class ConcatDataset(Dataset):
-    def __init__(self, dataset, chunk_size=4096):
-        self.dataset = dataset
-        self.chunk_size = chunk_size
-        
-        self.samples = []
-        
-        buffer = {
-            "input_ids": [],
-            "attention_mask": [],
-            "labels": [],
-            }
-        
-        for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
-            buffer = {k: v + sample[k] for k,v in buffer.items()}
-            
-            while len(next(iter(buffer.values()))) > self.chunk_size:
-                self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()})
-                buffer = {k: v[self.chunk_size:] for k,v in buffer.items()}
-                
-    def __getitem__(self, idx):
-        return self.samples[idx]
-    
-    def __len__(self):
-        return len(self.samples)

+ 25 - 37
src/llama_recipes/finetuning.py

@@ -5,8 +5,8 @@ import os
 from pkg_resources import packaging
 
 import fire
+import random
 import torch
-import torch.distributed as dist
 import torch.optim as optim
 from peft import get_peft_model, prepare_model_for_int8_training
 from torch.distributed.fsdp import (
@@ -14,16 +14,16 @@ from torch.distributed.fsdp import (
 )
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
-from torch.utils.data import DistributedSampler
 from transformers import (
     LlamaForCausalLM,
     LlamaTokenizer,
     LlamaConfig,
-    default_data_collator,
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
-from llama_recipes.configs import fsdp_config, train_config
+from llama_recipes.configs import fsdp_config as FSDP_CONFIG
+from llama_recipes.configs import train_config as TRAIN_CONFIG
+from llama_recipes.data.concatenator import ConcatDataset
 from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
 
 from llama_recipes.utils import fsdp_auto_wrap_policy
@@ -31,6 +31,7 @@ from llama_recipes.utils.config_utils import (
     update_config,
     generate_peft_config,
     generate_dataset_config,
+    get_dataloader_kwargs,
 )
 from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
 
@@ -47,11 +48,13 @@ from llama_recipes.utils.train_utils import (
 
 def main(**kwargs):
     # Update the configuration for the training and sharding process
+    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
     update_config((train_config, fsdp_config), **kwargs)
 
     # Set the seeds for reproducibility
     torch.cuda.manual_seed(train_config.seed)
     torch.manual_seed(train_config.seed)
+    random.seed(train_config.seed)
 
     if train_config.enable_fsdp:
         setup()
@@ -102,14 +105,19 @@ def main(**kwargs):
     if train_config.enable_fsdp and train_config.use_fast_kernels:
         """
         For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
-        using of Flash Attention or Xformer memory-efficient kernels 
+        using of Flash Attention or Xformer memory-efficient kernels
         based on the hardware being used. This would speed up fine-tuning.
         """
         try:
             from optimum.bettertransformer import BetterTransformer
-            model = BetterTransformer.transform(model) 
+            model = BetterTransformer.transform(model)
         except ImportError:
             print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
+
+    # Load the tokenizer and add special tokens
+    tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
+    tokenizer.pad_token_id = tokenizer.eos_token_id
+
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
 
     # Prepare the model for int8 training if quantization is enabled
@@ -120,14 +128,6 @@ def main(**kwargs):
     if train_config.enable_fsdp and fsdp_config.pure_bf16:
         model.to(torch.bfloat16)
 
-    # Load the tokenizer and add special tokens
-    tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
-    tokenizer.add_special_tokens(
-            {
-
-                "pad_token": "<PAD>",
-            }
-        )
     if train_config.use_peft:
         peft_config = generate_peft_config(train_config, kwargs)
         model = get_peft_model(model, peft_config)
@@ -179,43 +179,31 @@ def main(**kwargs):
     if not train_config.enable_fsdp or rank == 0:
             print(f"--> Validation Set Length = {len(dataset_val)}")
 
-    train_sampler = None
-    val_sampler = None
-    if train_config.enable_fsdp:
-        train_sampler = DistributedSampler(
-            dataset_train,
-            rank=dist.get_rank(),
-            num_replicas=dist.get_world_size(),
-            shuffle=True,
-        )
-        if train_config.run_validation:
-            val_sampler = DistributedSampler(
-                dataset_val,
-                rank=dist.get_rank(),
-                num_replicas=dist.get_world_size(),
-            )
+    if train_config.batching_strategy == "packing":
+        dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
+
+    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
 
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
-        batch_size=train_config.batch_size_training,
         num_workers=train_config.num_workers_dataloader,
         pin_memory=True,
-        sampler=train_sampler if train_sampler else None,
-        drop_last=True,
-        collate_fn=default_data_collator,
+        **train_dl_kwargs,
     )
 
     eval_dataloader = None
     if train_config.run_validation:
+        if train_config.batching_strategy == "packing":
+            dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
+
+        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
+
         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
-            batch_size=train_config.val_batch_size,
             num_workers=train_config.num_workers_dataloader,
             pin_memory=True,
-            sampler=val_sampler if val_sampler else None,
-            drop_last=True,
-            collate_fn=default_data_collator,
+            **val_dl_kwargs,
         )
 
     # Initialize the optimizer and learning rate scheduler

+ 49 - 11
src/llama_recipes/utils/config_utils.py

@@ -3,13 +3,19 @@
 
 import inspect
 from dataclasses import asdict
+
+import torch.distributed as dist
+from torch.utils.data import DistributedSampler
 from peft import (
     LoraConfig,
     AdaptionPromptConfig,
     PrefixTuningConfig,
 )
+from transformers import default_data_collator
+from transformers.data import DataCollatorForSeq2Seq
 
 from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
+from llama_recipes.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler
 from llama_recipes.utils.dataset_utils import DATASET_PREPROC
 
 
@@ -32,31 +38,63 @@ def update_config(config, **kwargs):
                         print(f"Warning: {config_name} does not accept parameter: {k}")
             elif isinstance(config, train_config):
                 print(f"Warning: unknown parameter {k}")
-                        
-                        
+
+
 def generate_peft_config(train_config, kwargs):
     configs = (lora_config, llama_adapter_config, prefix_config)
     peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
     names = tuple(c.__name__.rstrip("_config") for c in configs)
-    
+
     assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"
-    
+
     config = configs[names.index(train_config.peft_method)]()
-    
+
     update_config(config, **kwargs)
     params = asdict(config)
     peft_config = peft_configs[names.index(train_config.peft_method)](**params)
-    
+
     return peft_config
 
 
 def generate_dataset_config(train_config, kwargs):
     names = tuple(DATASET_PREPROC.keys())
-        
+
     assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}"
-    
+
     dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset]()
-        
+
     update_config(dataset_config, **kwargs)
-    
-    return  dataset_config
+
+    return  dataset_config
+
+
+def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
+        kwargs = {}
+        batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
+        if train_config.batching_strategy == "padding":
+            if train_config.enable_fsdp:
+                kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
+                    dataset,
+                    batch_size=batch_size,
+                    rank=dist.get_rank(),
+                    num_replicas=dist.get_world_size(),
+                    shuffle=mode=="train",
+                )
+            else:
+                kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
+            kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
+        elif train_config.batching_strategy == "packing":
+            if train_config.enable_fsdp:
+                kwargs["sampler"] = DistributedSampler(
+                dataset,
+                rank=dist.get_rank(),
+                num_replicas=dist.get_world_size(),
+                shuffle=mode=="train",
+            )
+            kwargs["batch_size"] = batch_size
+            kwargs["drop_last"] = True
+            kwargs["collate_fn"] = default_data_collator
+        else:
+            raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
+
+        return kwargs

+ 6 - 6
src/llama_recipes/utils/dataset_utils.py

@@ -33,24 +33,24 @@ def get_custom_dataset(dataset_config, tokenizer, split: str):
         module_path, func_name = dataset_config.file.split(":")
     else:
         module_path, func_name = dataset_config.file, "get_custom_dataset"
-        
+
     if not module_path.endswith(".py"):
         raise ValueError(f"Dataset file {module_path} is not a .py file.")
-    
+
     module_path = Path(module_path)
     if not module_path.is_file():
         raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
-    
+
     module = load_module_from_py_file(module_path.as_posix())
     try:
         return getattr(module, func_name)(dataset_config, tokenizer, split)
     except AttributeError as e:
         print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).")
         raise e
-    
+
 
 DATASET_PREPROC = {
-    "alpaca_dataset": partial(get_alpaca_dataset, max_words=224),
+    "alpaca_dataset": partial(get_alpaca_dataset),
     "grammar_dataset": get_grammar_dataset,
     "samsum_dataset": get_samsum_dataset,
     "custom_dataset": get_custom_dataset,
@@ -69,7 +69,7 @@ def get_preprocessed_dataset(
             if split == "train"
             else dataset_config.test_split
         )
-    
+
     return DATASET_PREPROC[dataset_config.dataset](
         dataset_config,
         tokenizer,

+ 34 - 34
src/llama_recipes/utils/train_utils.py

@@ -26,7 +26,7 @@ from llama_recipes.utils.memory_utils import MemoryTrace
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
     tokenizer.padding_side = "left"
-    
+
 # Converting Bytes to Megabytes
 def byte2mb(x):
     return int(x / 2**20)
@@ -34,7 +34,7 @@ def byte2mb(x):
 def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
     """
     Trains the model on the given dataloader
-    
+
     Args:
         model: The model to be trained
         train_dataloader: The dataloader containing the training data
@@ -46,18 +46,18 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         train_config: The training configuration
         eval_dataloader: The dataloader containing the eval data
         tokenizer: tokenizer used in the eval for decoding the predicitons
-    
+
     Returns: results dictionary containing average training and validation perplexity and loss
     """
     # Create a gradient scaler for fp16
     if train_config.use_fp16 and train_config.enable_fsdp:
         scaler = ShardedGradScaler()
     elif train_config.use_fp16 and not train_config.enable_fsdp:
-        scaler = torch.cuda.amp.GradScaler() 
+        scaler = torch.cuda.amp.GradScaler()
     if train_config.enable_fsdp:
         world_size = int(os.environ["WORLD_SIZE"])
     autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
-    
+
     train_prep = []
     train_loss = []
     val_prep = []
@@ -78,7 +78,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if train_config.enable_fsdp:
                         batch[key] = batch[key].to(local_rank)
                     else:
-                        batch[key] = batch[key].to('cuda:0')              
+                        batch[key] = batch[key].to('cuda:0')
                 with autocast():
                     loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
@@ -101,9 +101,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
                 pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
             pbar.close()
-                
+
         epoch_end_time = time.perf_counter()-epoch_start_time
-        epoch_times.append(epoch_end_time)    
+        epoch_times.append(epoch_end_time)
         # Reducing total_loss across all devices if there's more than one CUDA device
         if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
@@ -111,10 +111,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         if train_config.enable_fsdp:
             train_epoch_loss = train_epoch_loss/world_size
         train_perplexity = torch.exp(train_epoch_loss)
-        
+
         train_prep.append(train_perplexity)
         train_loss.append(train_epoch_loss)
-        
+
         if train_config.enable_fsdp:
             if rank==0:
                 print(f"Max CUDA memory allocated was {memtrace.peak} GB")
@@ -128,10 +128,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
             print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
             print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
-        
+
         # Update the learning rate as needed
         lr_scheduler.step()
-          
+
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
             checkpoint_start_time = time.perf_counter()
@@ -144,23 +144,23 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             print(f"we are about to save the PEFT modules")
                     else:
                         print(f"we are about to save the PEFT modules")
-                    model.save_pretrained(train_config.output_dir)  
+                    model.save_pretrained(train_config.output_dir)
                     if train_config.enable_fsdp:
-                        if rank==0: 
+                        if rank==0:
                             print(f"PEFT modules are saved in {train_config.output_dir} directory")
                     else:
                         print(f"PEFT modules are saved in {train_config.output_dir} directory")
-                        
+
                 else:
                     if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
-                        
+
                         save_model_checkpoint(
                             model, optimizer, rank, train_config, epoch=epoch
                         )
                     elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
                         print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
                         print("=====================================================")
-                        
+
                         save_model_and_optimizer_sharded(model, rank, train_config)
                         if train_config.save_optimizer:
                             save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
@@ -172,7 +172,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             model, optimizer, rank, train_config, epoch=epoch
                         )
                         print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
-                        print("=====================================================")                     
+                        print("=====================================================")
                 if train_config.enable_fsdp:
                     dist.barrier()
             checkpoint_end_time = time.perf_counter() - checkpoint_start_time
@@ -196,8 +196,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     avg_train_prep = sum(train_prep)/len(train_prep)
     avg_train_loss = sum(train_loss)/len(train_loss)
     if train_config.run_validation:
-        avg_eval_prep = sum(val_prep)/len(val_prep) 
-        avg_eval_loss = sum(val_loss)/len(val_loss) 
+        avg_eval_prep = sum(val_prep)/len(val_prep)
+        avg_eval_loss = sum(val_loss)/len(val_loss)
 
     results['avg_train_prep'] = avg_train_prep
     results['avg_train_loss'] = avg_train_loss
@@ -206,27 +206,27 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         results['avg_eval_loss'] = avg_eval_loss
     results["avg_epoch_time"] = avg_epoch_time
     results["avg_checkpoint_time"] = avg_checkpoint_time
-    
+
     #saving the training params including fsdp setting for reference.
     if train_config.enable_fsdp and not train_config.use_peft:
         save_train_params(train_config, fsdp_config, rank)
-        
+
     return results
 
 def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     """
     Evaluates the model on the given dataloader
-    
+
     Args:
         model: The model to evaluate
         eval_dataloader: The dataloader containing the evaluation data
         local_rank: The rank of the current node in a distributed setting
         tokenizer: The tokenizer used to decode predictions
-    
+
     Returns: eval_ppl, eval_epoch_loss
     """
     if train_config.enable_fsdp:
-        world_size = int(os.environ["WORLD_SIZE"]) 
+        world_size = int(os.environ["WORLD_SIZE"])
     model.eval()
     eval_preds = []
     eval_loss = 0.0  # Initialize evaluation loss
@@ -248,24 +248,24 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
             eval_preds.extend(
                 tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
             )
-    
+
     # If there's more than one CUDA device, reduce evaluation loss across all devices
     if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
-    
+
     # Compute average loss and perplexity
     eval_epoch_loss = eval_loss / len(eval_dataloader)
     if train_config.enable_fsdp:
         eval_epoch_loss = eval_epoch_loss/world_size
     eval_ppl = torch.exp(eval_epoch_loss)
-    
+
     # Print evaluation metrics
     if train_config.enable_fsdp:
         if local_rank==0:
             print(f" {eval_ppl=} {eval_epoch_loss=}")
     else:
         print(f" {eval_ppl=} {eval_epoch_loss=}")
-        
+
     return eval_ppl, eval_epoch_loss
 
 def freeze_transformer_layers(model, num_layer):
@@ -279,8 +279,8 @@ def check_frozen_layers_peft_model(model):
      for i, layer in enumerate(model.base_model.model.model.layers):
             for name, param in layer.named_parameters():
                 print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
-                
-                
+
+
 def setup():
     """Initialize the process group for distributed training"""
     dist.init_process_group("nccl")
@@ -293,7 +293,7 @@ def setup_environ_flags(rank):
     # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
     # This flag will help with CUDA memory fragmentations that can lead into OOM in some cases.
     # Note this is only availble in PyTorch Nighlies (as of July 30 2023)
-    # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True' 
+    # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
     if rank == 0:
         print(f"--> Running with torch dist debug set to detail")
 
@@ -338,7 +338,7 @@ def print_model_size(model, config, rank: int = 0) -> None:
 
 def get_policies(cfg, rank):
     """Get the policies for mixed precision and fsdp wrapping"""
-    
+
     verify_bfloat_support = (
     torch.version.cuda
     and torch.cuda.is_bf16_supported()
@@ -374,7 +374,7 @@ def save_train_params(train_config, fsdp_config, rank):
     This will be used by converter script in the inference folder to fetch the HF model name or path.
     It also would be hepful as a log for future references.
     """
-    # Convert the train_config and fsdp_config objects to dictionaries, 
+    # Convert the train_config and fsdp_config objects to dictionaries,
     # converting all values to strings to ensure they can be serialized into a YAML file
     train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
     fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}

+ 18 - 0
tests/conftest.py

@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import pytest
+
+from transformers import LlamaTokenizer
+
+
+@pytest.fixture
+def setup_tokenizer():
+    def _helper(tokenizer):
+        #Align with Llama 2 tokenizer
+        tokenizer.from_pretrained.return_value = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
+        tokenizer.from_pretrained.return_value.add_special_tokens({'bos_token': '<s>', 'eos_token': '</s>'})
+        tokenizer.from_pretrained.return_value.bos_token_id = 1
+        tokenizer.from_pretrained.return_value.eos_token_id = 2
+
+    return _helper

+ 40 - 13
tests/datasets/test_custom_dataset.py

@@ -4,21 +4,38 @@
 import pytest
 from unittest.mock import patch
 
+from transformers import LlamaTokenizer
+
+def check_padded_entry(batch):
+    seq_len = sum(batch["attention_mask"][0])
+    assert seq_len < len(batch["attention_mask"][0])
+
+    assert batch["labels"][0][0] == -100
+    assert batch["labels"][0][seq_len-1] == 2
+    assert batch["labels"][0][-1] == -100
+    assert batch["input_ids"][0][0] == 1
+    assert batch["input_ids"][0][-1] == 2
+
 
 @patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
+def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
     from llama_recipes.finetuning import main
 
+    setup_tokenizer(tokenizer)
+
     kwargs = {
         "dataset": "custom_dataset",
         "model_name": "decapoda-research/llama-7b-hf", # We use the tokenizer as a surrogate for llama2 tokenizer here
         "custom_dataset.file": "examples/custom_dataset.py",
         "custom_dataset.train_split": "validation",
         "batch_size_training": 2,
+        "val_batch_size": 4,
         "use_peft": False,
+        "batching_strategy": "padding"
         }
 
     main(**kwargs)
@@ -30,24 +47,34 @@ def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
     eval_dataloader = args[2]
     tokenizer = args[3]
 
-    assert len(train_dataloader) == 226
-    assert len(eval_dataloader) == 2*226
+    assert len(train_dataloader) == 1120
+    assert len(eval_dataloader) == 1120 //2
+
+    it = iter(eval_dataloader)
+    batch = next(it)
+    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
+    EXPECTED_STRING = "[INST] Who made Berlin [/INST] dunno"
+    assert STRING.startswith(EXPECTED_STRING)
+
+    assert batch["input_ids"].size(0) == 4
+    assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())
+
+    check_padded_entry(batch)
 
     it = iter(train_dataloader)
-    STRING = tokenizer.decode(next(it)["input_ids"][0], skip_special_tokens=True)
-    EXPECTED_STRING = "[INST] Напиши функцию на языке swift, которая сортирует массив целых чисел, а затем выводит его на экран [/INST] Вот функция, "
+    for _ in range(5):
+        next(it)
 
+    batch = next(it)
+    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
+    EXPECTED_STRING = "[INST] How do I initialize a Typescript project using npm and git? [/INST] # Initialize a new NPM project"
     assert STRING.startswith(EXPECTED_STRING)
 
-    next(it)
-    next(it)
-    next(it)
-    STRING = tokenizer.decode(next(it)["input_ids"][0], skip_special_tokens=True)
-    EXPECTED_SUBSTRING_1 = "Therefore you are correct.  [INST] How can L’Hopital’s Rule be"
-    EXPECTED_SUBSTRING_2 = "a circular path around the turn.  [INST] How on earth is that related to L’Hopital’s Rule?"
+    assert batch["input_ids"].size(0) == 2
+    assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())
+
+    check_padded_entry(batch)
 
-    assert EXPECTED_SUBSTRING_1 in STRING
-    assert EXPECTED_SUBSTRING_2 in STRING
 
 
 @patch('llama_recipes.finetuning.train')

+ 54 - 0
tests/datasets/test_grammar_datasets.py

@@ -0,0 +1,54 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from unittest.mock import patch
+
+from transformers import LlamaTokenizer
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+    from llama_recipes.finetuning import main
+
+    setup_tokenizer(tokenizer)
+
+    BATCH_SIZE = 8
+    kwargs = {
+        "model_name": "decapoda-research/llama-7b-hf",
+        "batch_size_training": BATCH_SIZE,
+        "val_batch_size": 1,
+        "use_peft": False,
+        "dataset": "grammar_dataset",
+        "batching_strategy": "padding",
+        }
+
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader = args[1]
+    eval_dataloader = args[2]
+
+    VAL_SAMPLES = 2988
+    TRAIN_SAMPLES = 13016
+
+    assert len(train_dataloader) == TRAIN_SAMPLES // BATCH_SIZE
+    assert len(eval_dataloader) == VAL_SAMPLES
+
+    batch = next(iter(train_dataloader))
+
+    assert "labels" in batch.keys()
+    assert "input_ids" in batch.keys()
+    assert "attention_mask" in batch.keys()
+
+    assert batch["labels"][0][29] == -100
+    assert batch["labels"][0][30] == 29871
+
+    assert batch["input_ids"][0][0] == 1
+    assert batch["labels"][0][-1] == 2
+    assert batch["input_ids"][0][-1] == 2

+ 30 - 14
tests/datasets/test_samsum_datasets.py

@@ -1,37 +1,53 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+from functools import partial
 from unittest.mock import patch
 
 
 @patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_custom_dataset(step_lr, optimizer, tokenizer, get_model, train, mocker):
+def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
     from llama_recipes.finetuning import main
-        
-    tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})
-    
-    
+
+    setup_tokenizer(tokenizer)
+
+    BATCH_SIZE = 8
     kwargs = {
-        "batch_size_training": 1,
+        "model_name": "decapoda-research/llama-7b-hf",
+        "batch_size_training": BATCH_SIZE,
+        "val_batch_size": 1,
         "use_peft": False,
         "dataset": "samsum_dataset",
+        "batching_strategy": "padding",
         }
-    
+
     main(**kwargs)
-    
+
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     train_dataloader = args[1]
     eval_dataloader = args[2]
-    
+
     VAL_SAMPLES = 818
     TRAIN_SAMPLES = 14732
-    CONCAT_SIZE = 2048
-    assert len(train_dataloader) == TRAIN_SAMPLES // CONCAT_SIZE
+
+    assert len(train_dataloader) == TRAIN_SAMPLES // BATCH_SIZE
     assert len(eval_dataloader) == VAL_SAMPLES
-    
+
+    batch = next(iter(train_dataloader))
+
+    assert "labels" in batch.keys()
+    assert "input_ids" in batch.keys()
+    assert "attention_mask" in batch.keys()
+
+    assert batch["labels"][0][268] == -100
+    assert batch["labels"][0][269] == 22291
+
+    assert batch["input_ids"][0][0] == 1
+    assert batch["labels"][0][-1] == 2
+    assert batch["input_ids"][0][-1] == 2

+ 94 - 0
tests/test_batching.py

@@ -0,0 +1,94 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import pytest
+from unittest.mock import patch
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+    from llama_recipes.finetuning import main
+
+    setup_tokenizer(tokenizer)
+
+    kwargs = {
+        "model_name": "decapoda-research/llama-7b-hf",
+        "batch_size_training": 8,
+        "val_batch_size": 1,
+        "use_peft": False,
+        "dataset": "samsum_dataset",
+        "batching_strategy": "packing",
+        }
+
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader = args[1]
+    eval_dataloader = args[2]
+
+    assert len(train_dataloader) == 96
+    assert len(eval_dataloader) == 42
+
+    batch = next(iter(train_dataloader))
+
+    assert "labels" in batch.keys()
+    assert "input_ids" in batch.keys()
+    assert "attention_mask" in batch.keys()
+
+    assert batch["labels"][0].size(0) == 4096
+    assert batch["input_ids"][0].size(0) == 4096
+    assert batch["attention_mask"][0].size(0) == 4096
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+@patch('llama_recipes.finetuning.setup')
+@patch('llama_recipes.finetuning.FSDP')
+@patch('llama_recipes.finetuning.torch.distributed.is_initialized')
+@patch('llama_recipes.utils.config_utils.dist')
+def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer):
+    import os
+    from llama_recipes.finetuning import main
+
+    setup_tokenizer(tokenizer)
+
+    rank = 0
+    os.environ['LOCAL_RANK'] = f'{rank}'
+    os.environ['RANK'] = f'{rank}'
+    os.environ['WORLD_SIZE'] = '2'
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '12345'
+
+    kwargs = {
+        "model_name": "decapoda-research/llama-7b-hf",
+        "batch_size_training": 8,
+        "val_batch_size": 1,
+        "use_peft": False,
+        "dataset": "samsum_dataset",
+        "batching_strategy": "packing",
+        "enable_fsdp": True
+        }
+
+    is_initialized.return_value = True
+    dist.get_rank.return_value = rank
+    dist.get_world_size.return_value = 2
+
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader = args[1]
+    eval_dataloader = args[2]
+
+    assert len(train_dataloader) == 96 //2
+    assert len(eval_dataloader) == 42 //2

+ 81 - 34
tests/test_finetuning.py

@@ -1,14 +1,26 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+import pytest
 from pytest import approx
 from unittest.mock import patch
 
 from torch.nn import Linear
 from torch.optim import AdamW
 from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.sampler import BatchSampler
 
 from llama_recipes.finetuning import main
+from llama_recipes.data.sampler import LengthBasedBatchSampler
+
+
+def get_fake_dataset():
+    return [{
+        "input_ids":[1],
+        "attention_mask":[1],
+        "labels":[1],
+        }]
+
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@@ -18,23 +30,23 @@ from llama_recipes.finetuning import main
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
     kwargs = {"run_validation": False}
-    
-    get_dataset.return_value = [1]
-    
+
+    get_dataset.return_value = get_fake_dataset()
+
     main(**kwargs)
-    
+
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     train_dataloader = args[1]
     eval_dataloader = args[2]
-    
+
     assert isinstance(train_dataloader, DataLoader)
     assert eval_dataloader is None
-    
+
     assert get_model.return_value.to.call_args.args[0] == "cuda"
-    
-    
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@@ -43,21 +55,22 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
     kwargs = {"run_validation": True}
-    get_dataset.return_value = [1]
-    
+
+    get_dataset.return_value = get_fake_dataset()
+
     main(**kwargs)
-    
+
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     train_dataloader = args[1]
     eval_dataloader = args[2]
     assert isinstance(train_dataloader, DataLoader)
     assert isinstance(eval_dataloader, DataLoader)
-    
+
     assert get_model.return_value.to.call_args.args[0] == "cuda"
-    
-    
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@@ -68,15 +81,15 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer,
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train):
     kwargs = {"use_peft": True}
-    
-    get_dataset.return_value = [1]
-    
+
+    get_dataset.return_value = get_fake_dataset()
+
     main(**kwargs)
-    
+
     assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
     assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
-    
-    
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@@ -85,22 +98,56 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train, mocker):
     kwargs = {"weight_decay": 0.01}
-    
-    get_dataset.return_value = [1]
-    
-    model = mocker.MagicMock(name="model")
-    model.parameters.return_value = Linear(1,1).parameters()
-    get_peft_model.return_value = model 
-    get_peft_model.return_value.print_trainable_parameters=lambda:None
+
+    get_dataset.return_value = get_fake_dataset()
+
+    get_model.return_value = Linear(1,1)
+
     main(**kwargs)
-    
+
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     optimizer = args[4]
-    
+
     print(optimizer.state_dict())
-    
+
     assert isinstance(optimizer, AdamW)
     assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01)
-    
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.get_preprocessed_dataset')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_batching_strategy(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
+    kwargs = {"batching_strategy": "packing"}
+
+    get_dataset.return_value = get_fake_dataset()
+
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader, eval_dataloader = args[1:3]
+    assert isinstance(train_dataloader.batch_sampler, BatchSampler)
+    assert isinstance(eval_dataloader.batch_sampler, BatchSampler)
+
+    kwargs["batching_strategy"] = "padding"
+    train.reset_mock()
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader, eval_dataloader = args[1:3]
+    assert isinstance(train_dataloader.batch_sampler, LengthBasedBatchSampler)
+    assert isinstance(eval_dataloader.batch_sampler, LengthBasedBatchSampler)
+
+    kwargs["batching_strategy"] = "none"
+
+    with pytest.raises(ValueError):
+        main(**kwargs)

+ 86 - 0
tests/test_sampler.py

@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import random
+import pytest
+
+import torch
+
+from llama_recipes.data.sampler import LengthBasedBatchSampler
+from llama_recipes.data.sampler import DistributedLengthBasedBatchSampler
+
+SAMPLES = 33
+
+@pytest.fixture
+def dataset():
+    random.seed(42)
+    dataset = []
+    def add_samples(ds, n, a, b):
+        for _ in range(n):
+            ds.append(random.randint(a,b) * [1,])
+    add_samples(dataset, SAMPLES // 2, 1,9)
+    add_samples(dataset, (SAMPLES // 2) + (SAMPLES % 2), 10,20)
+    
+    return random.sample(dataset, len(dataset))
+    
+    
+@pytest.mark.parametrize("batch_size, drop_last", [(2, False), (8, False), (2, True), (8, True)])
+def test_batch_sampler_array(dataset, batch_size, drop_last):
+    
+    sampler = LengthBasedBatchSampler(dataset, batch_size, drop_last)
+    
+    EXPECTED_LENGTH = SAMPLES // batch_size if drop_last else (SAMPLES // batch_size) + (SAMPLES % batch_size)
+    
+    all_ids = [i for b in sampler for i in b]
+    assert len(set(all_ids)) == EXPECTED_LENGTH * batch_size if drop_last else len(dataset)
+    
+    assert len(sampler) == EXPECTED_LENGTH
+    is_long = [len(d)>=10 for d in dataset]
+    
+    def check_batch(batch):
+        return all(batch) or not any(batch)
+    
+    assert all(check_batch(is_long[i] for i in b) for b in sampler)
+    
+    
+@pytest.mark.parametrize("batch_size, drop_last", [(2, False), (8, False), (2, True), (8, True)])
+def test_batch_sampler_dict(dataset, batch_size, drop_last):
+    
+    dist_dataset = [{"input_ids": d, "attention_mask": d} for d in dataset]
+    
+    sampler = LengthBasedBatchSampler(dist_dataset, batch_size, drop_last)
+    
+    EXPECTED_LENGTH = SAMPLES // batch_size if drop_last else (SAMPLES // batch_size) + (SAMPLES % batch_size)
+    
+    assert len(sampler) == EXPECTED_LENGTH
+    is_long = [len(d)>=10 for d in dataset]
+    
+    def check_batch(batch):
+        return all(batch) or not any(batch)
+    
+    assert all(check_batch(is_long[i] for i in b) for b in sampler)
+    
+    
+@pytest.mark.parametrize("batch_size", [2, 8])
+def test_dist_batch_sampling(dataset, batch_size):
+    sampler_1 = DistributedLengthBasedBatchSampler(
+        dataset,
+        batch_size=batch_size,
+        rank=0,
+        num_replicas=2,
+        shuffle=False,
+    )
+    sampler_2 = DistributedLengthBasedBatchSampler(
+        dataset,
+        batch_size=batch_size,
+        rank=1,
+        num_replicas=2,
+        shuffle=False,
+    )
+    
+    ids_1 = set(i for b in sampler_1 for i in b)
+    ids_2 = set(i for b in sampler_2 for i in b)
+    
+    assert ids_1.isdisjoint(ids_2)
+    assert len(ids_1)+len(ids_2) > 0
+    assert len(ids_1)+len(ids_2) == len(dataset) // batch_size  *  batch_size