
Feature/length based batch sampling (#206)

Hamid Shojanazeri 1 year ago
parent commit 373000b2ac

+ 3 - 3
README.md

@@ -137,7 +137,7 @@ Here we make use of Parameter Efficient Methods (PEFT) as described in the next
 
 
 ```bash
 
 
-torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --pure_bf16 --output_dir Path/to/save/PEFT/model
+torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --fsdp_config.pure_bf16 --output_dir Path/to/save/PEFT/model
 
 
 ```
 
 
@@ -148,7 +148,7 @@ Here we use FSDP as discussed in the next section which can be used along with P
 Setting `use_fast_kernels` will enable the use of Flash Attention or Xformer memory-efficient kernels based on the hardware being used. This would speed up the fine-tuning job. This has been enabled in the `optimum` library from HuggingFace as a one-liner API, please read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/).
 
 ```bash
-torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --pure_bf16 --output_dir Path/to/save/PEFT/model --use_fast_kernels
+torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --fsdp_config.pure_bf16 --output_dir Path/to/save/PEFT/model --use_fast_kernels
 ```
 
 
 ### Fine-tuning using FSDP Only
@@ -167,7 +167,7 @@ If you are interested in running full parameter fine-tuning on the 70B model, yo
 
 
 ```bash
 
 
-torchrun --nnodes 1 --nproc_per_node 8 examples/finetuning.py --enable_fsdp --low_cpu_fsdp --pure_bf16 --model_name /patht_of_model_folder/70B --batch_size_training 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
+torchrun --nnodes 1 --nproc_per_node 8 examples/finetuning.py --enable_fsdp --low_cpu_fsdp --fsdp_config.pure_bf16 --model_name /patht_of_model_folder/70B --batch_size_training 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
 
 
 ```
 
 

+ 14 - 2
docs/Dataset.md

@@ -7,6 +7,18 @@ The provided fine tuning script allows you to select between three datasets by p
 * [samsum_dataset](https://huggingface.co/datasets/samsum) contains about 16k messenger-like conversations with summaries.
 * [OpenAssistant/oasst1](https://huggingface.co/datasets/OpenAssistant/oasst1/) contains about 88k messages from assistant-style conversations.
 
 
+## Batching Strategies
+Llama-recipes supports two strategies for batching requests together.
+The default setting is `packing`, which concatenates the tokenized samples into long sequences filling up the context length of the model.
+This is the most compute-efficient variant, as it avoids any padding and all sequences have the same length.
+Samples at the boundary of the context length are truncated, and the remainder of the cut sequence is used as the start of the next long sequence.
+
+If the amount of training data is small, this procedure might introduce a lot of noise into the training data, which can hurt the prediction performance of the fine-tuned model.
+Therefore, we also support a `padding` strategy which does not introduce the additional noise caused by truncated sequences.
+The strategy tries to minimize the efficiency loss by batching samples of similar length together, so only minimal padding is necessary.
+
+The batching strategy can be selected through the command-line parameter `--batching_strategy [packing]/[padding]`.
+
 ## Using custom datasets
 
 
 The list of available datasets in llama-recipes is supposed to give users a quick start on training their Llama model.
@@ -25,7 +37,7 @@ The `dataset_config` in the above signature will be an instance of llama_recipes
 The split signals whether to return the training or validation dataset.
 The default function name is `get_custom_dataset` but this can be changed as described below.
 
 
-In order to start a training with the custom dataset we need to set the `--dataset` as well as the `--custom_dataset.file` parameter. 
+In order to start a training with the custom dataset we need to set the `--dataset` as well as the `--custom_dataset.file` parameter.
 ```
 python -m llama_recipes.finetuning --dataset "custom_dataset" --custom_dataset.file "examples/custom_dataset.py" [TRAINING PARAMETERS]
 ```
@@ -35,7 +47,7 @@ python -m llama_recipes.finetuning --dataset "custom_dataset" --custom_dataset.f
 ```
 This will call the function `get_foo` instead of `get_custom_dataset` when retrieving the dataset.
 
 
-### Adding new dataset 
+### Adding new dataset
 Each dataset has a corresponding configuration (dataclass) in [configs/datasets.py](../src/llama_recipes/configs/datasets.py) which contains the dataset name, training/validation split names, as well as optional parameters like datafiles etc.
 
 
 Additionally, there is a preprocessing function for each dataset in the [datasets](../src/llama_recipes/datasets) folder.
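
The packing and padding behaviour documented above is easiest to see on a toy example. The snippet below is an illustration only; it is not part of this commit, and the token lists are made up:

```python
# Toy illustration of the two strategies on four tokenized samples with a
# context length of 8 (made-up data, not from the repository).
samples = [[1] * 3, [2] * 6, [3] * 10, [4] * 5]

# packing: concatenate everything and cut fixed-size chunks; no padding is
# needed, but a sample can be split across two chunks at the boundary.
flat = [tok for s in samples for tok in s]
packed = [flat[i:i + 8] for i in range(0, len(flat) // 8 * 8, 8)]
# -> [[1,1,1,2,2,2,2,2], [2,3,3,3,3,3,3,3], [3,3,3,4,4,4,4,4]]

# padding: sort by length, batch neighbours together and pad each batch to its
# longest sequence; nothing is truncated, at the cost of a few pad tokens.
by_length = sorted(samples, key=len)
batches = [by_length[i:i + 2] for i in range(0, len(by_length), 2)]
padded = [[s + [0] * (max(map(len, b)) - len(s)) for s in b] for b in batches]
# -> [[[1,1,1,0,0], [4,4,4,4,4]], [[2,2,2,2,2,2,0,0,0,0], [3,3,3,3,3,3,3,3,3,3]]]
```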

+ 23 - 30
examples/custom_dataset.py

@@ -7,33 +7,27 @@ import copy
 import datasets
 import itertools
 
-from llama_recipes.datasets.utils import Concatenator
-
 
 B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
 
 def tokenize_dialog(dialog, tokenizer):
-    dialog_tokens = [
-            tokenizer(
-                f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
-            )
-            for prompt, answer in zip(dialog[::2], dialog[1::2])
-        ]
-    if len(dialog) % 2:    
-        dialog_tokens += [tokenizer(
-            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
-        )]
-    
-    combined_tokens = {}  
-    for k in dialog_tokens[0].keys():
-        combined_tokens[k] = list(itertools.chain(*(t[k] for t in dialog_tokens)))
-    return combined_tokens
+    prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST}", add_special_tokens=False) for prompt in dialog[::2]]
+    answer_tokens = [tokenizer.encode(f"{answer['content'].strip()} {tokenizer.eos_token}", add_special_tokens=False) for answer in dialog[1::2]]
+    dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens)))
+    # Add labels: convert prompt tokens to -100 so they are ignored by the loss function
+    labels_tokens = [len(c)*[-100,] if i % 2 == 0 else c for i,c in enumerate(dialog_tokens)]
+
+    combined_tokens = {
+        "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
+        "labels": list(itertools.chain(*(t for t in labels_tokens))),
+    }
+
+    return dict(combined_tokens, attention_mask=[1]*len(combined_tokens["input_ids"]))
 
 
 
 
 def get_custom_dataset(dataset_config, tokenizer, split):
     dataset = datasets.load_dataset("OpenAssistant/oasst1", split=split)
-    
+
     dataset = dataset.map(lambda sample: {
         "message_id": sample["message_id"],
         "parent_id": sample["parent_id"],
@@ -41,19 +35,19 @@ def get_custom_dataset(dataset_config, tokenizer, split):
         },
         batched=True,
         remove_columns=list(dataset.features),)
-    
+
     nodes = {}
-    
+
     messages = {}
     root_ids = []
-    
+
     for data in dataset:
         if data["parent_id"]:
             nodes[data["parent_id"]] = nodes.get(data["parent_id"], []) + [data["message_id"]]
         else:
             root_ids.append(data["message_id"])
         messages[data["message_id"]]=data["text"]
-           
+
     def follow(thread, current_id):
         thread = copy.copy(thread) + [messages[current_id]]
         if current_id in nodes:
@@ -63,18 +57,18 @@ def get_custom_dataset(dataset_config, tokenizer, split):
             return new_threads
         else:
             return [thread]
-        
+
     def get_threads_from_root(root_id):
         all_threads = []
         thread = [messages[root_id]]
         for cid in nodes[root_id]:
             all_threads += follow(thread, cid)
         return all_threads
-            
+
     dataset = dataset.filter(lambda x: x["message_id"] in root_ids)
     dataset = dataset.map(lambda x: {"thread": get_threads_from_root(x["message_id"])}, remove_columns=list(dataset.features))
     dataset = dataset.map(lambda x: {"thread": [i for row in x["thread"] for i in row]}, batched=True)
-    
+
     def to_dialog(thread):
         dialog = []
         for i, content in enumerate(thread):
@@ -83,9 +77,8 @@ def get_custom_dataset(dataset_config, tokenizer, split):
                 "content": content,
             })
         return {"dialog": dialog}
-            
+
     dataset = dataset.map(lambda x: to_dialog(x["thread"]), remove_columns=list(dataset.features))
     dataset = dataset.map(lambda x: tokenize_dialog(x["dialog"], tokenizer), remove_columns=list(dataset.features))
-    dataset = dataset.map(Concatenator(), batched=True)
-    
-    return dataset
+
+    return dataset

+ 3 - 1
scripts/spellcheck_conf/wordlist.txt

@@ -1157,6 +1157,7 @@ FN
 GBs
 MLP
 learnable
+tokenized
 Colab
 GenAI
 Gradio
@@ -1182,4 +1183,5 @@ minnutes
 pdf
 quantized
 serarch
-streamlit
+streamlit
+

+ 0 - 2
src/llama_recipes/configs/datasets.py

@@ -9,7 +9,6 @@ class samsum_dataset:
     dataset: str =  "samsum_dataset"
     train_split: str = "train"
     test_split: str = "validation"
-    input_length: int = 2048
     
     
 @dataclass
@@ -17,7 +16,6 @@ class grammar_dataset:
     dataset: str = "grammar_dataset"
     train_split: str = "src/llama_recipes/datasets/grammar_dataset/gtrain_10k.csv" 
     test_split: str = "src/llama_recipes/datasets/grammar_dataset/grammar_validation.csv"
-    input_length: int = 2048
 
     
 @dataclass

+ 2 - 4
src/llama_recipes/configs/training.py

@@ -11,6 +11,8 @@ class train_config:
     low_cpu_fsdp: bool=False
     run_validation: bool=True
     batch_size_training: int=4
+    batching_strategy: str="packing" #alternative: padding
+    context_length: int=4096
     gradient_accumulation_steps: int=1
     num_epochs: int=3
     num_workers_dataloader: int=1
@@ -34,7 +36,3 @@ class train_config:
     dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
     save_optimizer: bool=False # will be used if using FSDP
     use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
-
-    
-    
-    
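
The two new fields above (`batching_strategy`, `context_length`) flow through `update_config`, so they should be overridable from the command line like any other `train_config` field (e.g. `--batching_strategy padding --context_length 2048`). A minimal sketch of the same override done in Python, assuming llama-recipes is installed:

```python
# Sketch only (assumption: llama-recipes is installed): override the new
# fields programmatically instead of via command-line flags.
from llama_recipes.configs import train_config as TRAIN_CONFIG
from llama_recipes.utils.config_utils import update_config

train_config = TRAIN_CONFIG()
update_config(train_config, batching_strategy="padding", context_length=2048)
print(train_config.batching_strategy, train_config.context_length)  # -> padding 2048
```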

+ 2 - 0
src/llama_recipes/data/__init__.py

@@ -0,0 +1,2 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

+ 34 - 0
src/llama_recipes/data/concatenator.py

@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from tqdm import tqdm
+from itertools import chain
+
+from torch.utils.data import Dataset
+
+
+class ConcatDataset(Dataset):
+    def __init__(self, dataset, chunk_size=4096):
+        self.dataset = dataset
+        self.chunk_size = chunk_size
+
+        self.samples = []
+
+        buffer = {
+            "input_ids": [],
+            "attention_mask": [],
+            "labels": [],
+            }
+
+        for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
+            buffer = {k: v + sample[k] for k,v in buffer.items()}
+
+            while len(next(iter(buffer.values()))) > self.chunk_size:
+                self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()})
+                buffer = {k: v[self.chunk_size:] for k,v in buffer.items()}
+
+    def __getitem__(self, idx):
+        return self.samples[idx]
+
+    def __len__(self):
+        return len(self.samples)
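
A rough usage sketch for the new `ConcatDataset` (assumed usage, not part of the commit): any map-style dataset whose items carry `input_ids`, `attention_mask` and `labels` lists can be wrapped so that every packed sample is exactly `chunk_size` tokens long; a leftover shorter than one chunk is dropped.

```python
# Sketch only: pack a toy tokenized dataset into fixed 1024-token chunks.
from llama_recipes.data.concatenator import ConcatDataset

tokenized = [
    {"input_ids": [1] * n, "attention_mask": [1] * n, "labels": [1] * n}
    for n in (300, 1200, 4500, 700)
]

packed = ConcatDataset(tokenized, chunk_size=1024)
print(len(packed))                  # number of full 1024-token chunks (6 here)
print(len(packed[0]["input_ids"]))  # 1024
```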

+ 57 - 0
src/llama_recipes/data/sampler.py

@@ -0,0 +1,57 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import random
+from itertools import islice
+
+import numpy as np
+import torch
+
+
+class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
+    def __init__(self, data_source, batch_size: int, drop_last: bool, shuffle: bool=True) -> None:
+        if isinstance(next(iter(data_source)), dict):
+            first_key = next(iter(next(iter(data_source)).keys()))
+            self.lengths = [len(d[first_key]) for d in data_source]
+        else:
+            self.lengths = [len(d) for d in data_source]
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+        self.shuffle = shuffle
+
+    def __iter__(self):
+        ids = np.argsort(self.lengths)
+        if self.drop_last:
+            ids = ids[:len(ids) // self.batch_size * self.batch_size]
+
+        batches = [ids[i:i+self.batch_size] for i in range(0, len(ids), self.batch_size)]
+
+        if self.shuffle:
+            random.shuffle(batches)
+
+        for b in batches:
+            yield b
+
+    def __len__(self):
+        if self.drop_last:
+            return len(self.lengths) // self.batch_size
+        else:
+            return len(self.lengths) // self.batch_size + (len(self.lengths) % self.batch_size > 0)
+
+
+class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler):
+    def __init__(self, data_source, batch_size: int, num_replicas: int, rank: int, shuffle: bool = True, seed: int = 0) -> None:
+        random.seed(seed)
+        self.batch_sampler = LengthBasedBatchSampler(
+            data_source, batch_size=batch_size, drop_last=True, shuffle=shuffle
+            )
+        self.num_replicas = num_replicas
+        self.rank = rank
+        
+    def __iter__(self):
+        max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas
+        return islice(self.batch_sampler, self.rank, max_length, self.num_replicas)
+         
+    def __len__(self):
+        return len(self.batch_sampler) // self.num_replicas
+            
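
A small sketch of what the new sampler does (assumed usage, not from the commit): indices are sorted by sample length, so each batch groups sequences of similar length and the collator only needs minimal padding. With the default `shuffle=True`, only the order of the batches is randomized; the grouping by length is preserved.

```python
# Sketch only: batch toy token lists of similar length together.
from llama_recipes.data.sampler import LengthBasedBatchSampler

toy_data = [[0] * n for n in (5, 50, 7, 52, 6, 49)]
sampler = LengthBasedBatchSampler(toy_data, batch_size=2, drop_last=True, shuffle=False)

for batch_indices in sampler:
    print([len(toy_data[i]) for i in batch_indices])
# -> [5, 6], [7, 49], [50, 52]
```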

+ 4 - 14
src/llama_recipes/datasets/alpaca_dataset.py

@@ -24,17 +24,14 @@ PROMPT_DICT = {
 }
 
 class InstructionDataset(Dataset):
-    def __init__(self, dataset_config, tokenizer, partition="train", max_words=30):
+    def __init__(self, dataset_config, tokenizer, partition="train"):
         self.ann = json.load(open(dataset_config.data_path))
         if partition == "train":
             self.ann = self.ann
         else:
             self.ann = self.ann[:200]
 
-        self.max_words = max_words
-        # tokenizer = Tokenizer(model_path=model_path + "./tokenizer.model")
         self.tokenizer = tokenizer
-        # self.tokenizer1 = tokenizer
 
     def __len__(self):
         return len(self.ann)
@@ -57,22 +54,15 @@ class InstructionDataset(Dataset):
         example = torch.tensor(
             example, dtype=torch.int64
         )
-        padding = self.max_words - example.shape[0]
-        if padding > 0:
-            example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1))
-        elif padding < 0:
-            example = example[: self.max_words]
         labels = copy.deepcopy(example)
         labels[: len(prompt)] = -1
         example_mask = example.ge(0)
         label_mask = labels.ge(0)
         example[~example_mask] = 0
         labels[~label_mask] = IGNORE_INDEX
-        example_mask = example_mask.float()
-        label_mask = label_mask.float()
 
         return {
-            "input_ids": example,
-            "labels": labels,
-            "attention_mask":example_mask,
+            "input_ids": example.tolist(),
+            "labels": labels.tolist(),
+            "attention_mask":example_mask.tolist(),
         }

+ 13 - 18
src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py

@@ -10,8 +10,6 @@ from pathlib import Path
 
 
 from torch.utils.data import Dataset
 
-from llama_recipes.datasets.utils import ConcatDataset
-
 
 class grammar(Dataset):
     def __init__(
@@ -48,24 +46,22 @@ class grammar(Dataset):
 
 
         input_ = example_batch["input"]
         target_ = example_batch["target"]
-        
-        prompt = f"Correct this to standard English: {input_}\n---\nCorrected: {target_}"
-        sample = self.tokenizer(prompt)
-        
-        return sample
-
-    def __getitem__(self, index):
-        sample = self.convert_to_features(self.dataset["train"][index])
-        source_ids = sample["input_ids"]
 
 
-        src_mask = sample["attention_mask"]
+        prompt = f"Correct this to standard English: {input_}\n---\nCorrected: "
+        prompt_ids = self.tokenizer.encode(self.tokenizer.bos_token + prompt, add_special_tokens=False)
+        label_ids = self.tokenizer.encode(target_ + self.tokenizer.eos_token, add_special_tokens=False)
 
 
-        return {
-            "input_ids": source_ids,
-            "attention_mask": src_mask,
-            "labels": source_ids.copy(),
+        sample = {
+            "input_ids": prompt_ids + label_ids,
+            "attention_mask": [1] * len(prompt_ids + label_ids),
+            "labels": [-100] * len(prompt_ids) + label_ids
         }
 
 
+        return sample
+
+    def __getitem__(self, index):
+        return self.convert_to_features(self.dataset["train"][int(index)])
+
 
 
 def get_dataset(
     dataset_config, tokenizer, csv_name=None
@@ -80,6 +76,5 @@ def get_dataset(
         tokenizer=tokenizer,
         csv_name=csv_name,
     )
-    
-    return ConcatDataset(dataset, chunk_size=dataset_config.input_length)
 
 
+    return dataset

+ 19 - 13
src/llama_recipes/datasets/samsum_dataset.py

@@ -3,31 +3,37 @@
 
 
 # For dataset details visit: https://huggingface.co/datasets/samsum
 
+import copy
 import datasets
 
-from llama_recipes.datasets.utils import Concatenator
 
 def get_preprocessed_samsum(dataset_config, tokenizer, split):
     dataset = datasets.load_dataset("samsum", split=split)
 
     prompt = (
-        f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n{{summary}}{{eos_token}}"
+        f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n"
     )
 
     def apply_prompt_template(sample):
         return {
-            "text": prompt.format(
-                dialog=sample["dialogue"],
-                summary=sample["summary"],
-                eos_token=tokenizer.eos_token,
-            )
+            "prompt": prompt.format(dialog=sample["dialogue"]),
+            "summary": sample["summary"],
         }
 
     dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
-        
-    dataset = dataset.map(
-        lambda sample: tokenizer(sample["text"]),
-        batched=True,
-        remove_columns=list(dataset.features),
-    ).map(Concatenator(), batched=True)
+
+    def tokenize_add_label(sample):
+        prompt = tokenizer.encode(tokenizer.bos_token + sample["prompt"], add_special_tokens=False)
+        summary = tokenizer.encode(sample["summary"] +  tokenizer.eos_token, add_special_tokens=False)
+
+        sample = {
+            "input_ids": prompt + summary,
+            "attention_mask" : [1] * (len(prompt) + len(summary)),
+            "labels": [-100] * len(prompt) + summary,
+            }
+
+        return sample
+
+    dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
+
     return dataset

+ 0 - 66
src/llama_recipes/datasets/utils.py

@@ -1,66 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-
-from tqdm import tqdm
-from itertools import chain
-
-from torch.utils.data import Dataset
-
-class Concatenator(object):
-    def __init__(self, chunk_size=2048):
-        self.chunk_size=chunk_size
-        self.residual = {"input_ids": [], "attention_mask": []}
-        
-    def __call__(self, batch):
-        concatenated_samples = {
-            k: v + list(chain(*batch[k])) for k, v in self.residual.items()
-        }
-
-        total_length = len(concatenated_samples[list(concatenated_samples.keys())[0]])
-
-        if total_length >= self.chunk_size:
-            chunk_num = total_length // self.chunk_size
-            result = {
-                k: [
-                    v[i : i + self.chunk_size]
-                    for i in range(0, chunk_num * self.chunk_size, self.chunk_size)
-                ]
-                for k, v in concatenated_samples.items()
-            }
-            self.residual = {
-                k: v[(chunk_num * self.chunk_size) :]
-                for k, v in concatenated_samples.items()
-            }
-        else:
-            result = concatenated_samples
-            self.residual = {k: [] for k in concatenated_samples.keys()}
-
-        result["labels"] = result["input_ids"].copy()
-
-        return result
-
-class ConcatDataset(Dataset):
-    def __init__(self, dataset, chunk_size=4096):
-        self.dataset = dataset
-        self.chunk_size = chunk_size
-        
-        self.samples = []
-        
-        buffer = {
-            "input_ids": [],
-            "attention_mask": [],
-            "labels": [],
-            }
-        
-        for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
-            buffer = {k: v + sample[k] for k,v in buffer.items()}
-            
-            while len(next(iter(buffer.values()))) > self.chunk_size:
-                self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()})
-                buffer = {k: v[self.chunk_size:] for k,v in buffer.items()}
-                
-    def __getitem__(self, idx):
-        return self.samples[idx]
-    
-    def __len__(self):
-        return len(self.samples)

+ 25 - 37
src/llama_recipes/finetuning.py

@@ -5,8 +5,8 @@ import os
 from pkg_resources import packaging
 
 import fire
+import random
 import torch
-import torch.distributed as dist
 import torch.optim as optim
 from peft import get_peft_model, prepare_model_for_int8_training
 from torch.distributed.fsdp import (
@@ -14,16 +14,16 @@ from torch.distributed.fsdp import (
 )
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
-from torch.utils.data import DistributedSampler
 from transformers import (
     LlamaForCausalLM,
     LlamaTokenizer,
     LlamaConfig,
-    default_data_collator,
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
-from llama_recipes.configs import fsdp_config, train_config
+from llama_recipes.configs import fsdp_config as FSDP_CONFIG
+from llama_recipes.configs import train_config as TRAIN_CONFIG
+from llama_recipes.data.concatenator import ConcatDataset
 from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
 
 from llama_recipes.utils import fsdp_auto_wrap_policy
@@ -31,6 +31,7 @@ from llama_recipes.utils.config_utils import (
     update_config,
     generate_peft_config,
     generate_dataset_config,
+    get_dataloader_kwargs,
 )
 from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
 
@@ -47,11 +48,13 @@ from llama_recipes.utils.train_utils import (
 
 
 def main(**kwargs):
     # Update the configuration for the training and sharding process
+    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
     update_config((train_config, fsdp_config), **kwargs)
 
     # Set the seeds for reproducibility
     torch.cuda.manual_seed(train_config.seed)
     torch.manual_seed(train_config.seed)
+    random.seed(train_config.seed)
 
     if train_config.enable_fsdp:
         setup()
@@ -102,14 +105,19 @@ def main(**kwargs):
     if train_config.enable_fsdp and train_config.use_fast_kernels:
         """
         For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
-        using of Flash Attention or Xformer memory-efficient kernels 
+        using of Flash Attention or Xformer memory-efficient kernels
         based on the hardware being used. This would speed up fine-tuning.
         """
         try:
             from optimum.bettertransformer import BetterTransformer
-            model = BetterTransformer.transform(model) 
+            model = BetterTransformer.transform(model)
         except ImportError:
             print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
+
+    # Load the tokenizer and add special tokens
+    tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
+    tokenizer.pad_token_id = tokenizer.eos_token_id
+
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
 
     # Prepare the model for int8 training if quantization is enabled
@@ -120,14 +128,6 @@ def main(**kwargs):
     if train_config.enable_fsdp and fsdp_config.pure_bf16:
         model.to(torch.bfloat16)
 
-    # Load the tokenizer and add special tokens
-    tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
-    tokenizer.add_special_tokens(
-            {
-
-                "pad_token": "<PAD>",
-            }
-        )
     if train_config.use_peft:
         peft_config = generate_peft_config(train_config, kwargs)
         model = get_peft_model(model, peft_config)
@@ -179,43 +179,31 @@ def main(**kwargs):
     if not train_config.enable_fsdp or rank == 0:
             print(f"--> Validation Set Length = {len(dataset_val)}")
 
-    train_sampler = None
-    val_sampler = None
-    if train_config.enable_fsdp:
-        train_sampler = DistributedSampler(
-            dataset_train,
-            rank=dist.get_rank(),
-            num_replicas=dist.get_world_size(),
-            shuffle=True,
-        )
-        if train_config.run_validation:
-            val_sampler = DistributedSampler(
-                dataset_val,
-                rank=dist.get_rank(),
-                num_replicas=dist.get_world_size(),
-            )
+    if train_config.batching_strategy == "packing":
+        dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
+
+    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
 
 
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
-        batch_size=train_config.batch_size_training,
         num_workers=train_config.num_workers_dataloader,
         pin_memory=True,
-        sampler=train_sampler if train_sampler else None,
-        drop_last=True,
-        collate_fn=default_data_collator,
+        **train_dl_kwargs,
     )
 
     eval_dataloader = None
     if train_config.run_validation:
+        if train_config.batching_strategy == "packing":
+            dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
+
+        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
+
         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
-            batch_size=train_config.val_batch_size,
             num_workers=train_config.num_workers_dataloader,
             pin_memory=True,
-            sampler=val_sampler if val_sampler else None,
-            drop_last=True,
-            collate_fn=default_data_collator,
+            **val_dl_kwargs,
         )
 
     # Initialize the optimizer and learning rate scheduler

+ 49 - 11
src/llama_recipes/utils/config_utils.py

@@ -3,13 +3,19 @@
 
 
 import inspect
 from dataclasses import asdict
+
+import torch.distributed as dist
+from torch.utils.data import DistributedSampler
 from peft import (
     LoraConfig,
     AdaptionPromptConfig,
     PrefixTuningConfig,
 )
+from transformers import default_data_collator
+from transformers.data import DataCollatorForSeq2Seq
 
 from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
+from llama_recipes.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler
 from llama_recipes.utils.dataset_utils import DATASET_PREPROC
 
 
@@ -32,31 +38,63 @@ def update_config(config, **kwargs):
                         print(f"Warning: {config_name} does not accept parameter: {k}")
             elif isinstance(config, train_config):
                 print(f"Warning: unknown parameter {k}")
-                        
-                        
+
+
 def generate_peft_config(train_config, kwargs):
     configs = (lora_config, llama_adapter_config, prefix_config)
     peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
     names = tuple(c.__name__.rstrip("_config") for c in configs)
-    
+
     assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"
-    
+
     config = configs[names.index(train_config.peft_method)]()
-    
+
     update_config(config, **kwargs)
     params = asdict(config)
     peft_config = peft_configs[names.index(train_config.peft_method)](**params)
-    
+
     return peft_config
 
 
 def generate_dataset_config(train_config, kwargs):
     names = tuple(DATASET_PREPROC.keys())
-        
+
     assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}"
-    
+
     dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset]()
-        
+
     update_config(dataset_config, **kwargs)
-    
-    return  dataset_config
+
+    return  dataset_config
+
+
+def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
+        kwargs = {}
+        batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
+        if train_config.batching_strategy == "padding":
+            if train_config.enable_fsdp:
+                kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
+                    dataset,
+                    batch_size=batch_size,
+                    rank=dist.get_rank(),
+                    num_replicas=dist.get_world_size(),
+                    shuffle=mode=="train",
+                )
+            else:
+                kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
+            kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
+        elif train_config.batching_strategy == "packing":
+            if train_config.enable_fsdp:
+                kwargs["sampler"] = DistributedSampler(
+                dataset,
+                rank=dist.get_rank(),
+                num_replicas=dist.get_world_size(),
+                shuffle=mode=="train",
+            )
+            kwargs["batch_size"] = batch_size
+            kwargs["drop_last"] = True
+            kwargs["collate_fn"] = default_data_collator
+        else:
+            raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
+
+        return kwargs
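
A rough end-to-end sketch of how `get_dataloader_kwargs` is consumed, mirroring `finetuning.py` (assumptions: `train_config.model_name` points to a local Llama checkpoint and FSDP is disabled, so no process group is needed). With the `padding` strategy the returned kwargs carry a length-based batch sampler plus `DataCollatorForSeq2Seq`, which pads each batch to its longest sequence:

```python
# Sketch only; not part of the commit.
import torch
from transformers import LlamaTokenizer

from llama_recipes.configs import train_config as TRAIN_CONFIG
from llama_recipes.utils.config_utils import get_dataloader_kwargs

train_config = TRAIN_CONFIG()
train_config.enable_fsdp = False
train_config.batching_strategy = "padding"

# assumes train_config.model_name points at a local Llama checkpoint
tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
tokenizer.pad_token_id = tokenizer.eos_token_id

# any map-style dataset of dicts with input_ids / attention_mask / labels works
toy_dataset = [
    {"input_ids": [1] * n, "attention_mask": [1] * n, "labels": [1] * n}
    for n in (12, 48, 14, 50)
]

dl_kwargs = get_dataloader_kwargs(train_config, toy_dataset, tokenizer, "train")
train_dataloader = torch.utils.data.DataLoader(toy_dataset, pin_memory=True, **dl_kwargs)

for batch in train_dataloader:
    print(batch["input_ids"].shape)  # each batch is padded to its longest sample
```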

+ 6 - 6
src/llama_recipes/utils/dataset_utils.py

@@ -33,24 +33,24 @@ def get_custom_dataset(dataset_config, tokenizer, split: str):
         module_path, func_name = dataset_config.file.split(":")
     else:
         module_path, func_name = dataset_config.file, "get_custom_dataset"
-        
+
     if not module_path.endswith(".py"):
         raise ValueError(f"Dataset file {module_path} is not a .py file.")
-    
+
     module_path = Path(module_path)
     if not module_path.is_file():
         raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
-    
+
     module = load_module_from_py_file(module_path.as_posix())
     try:
         return getattr(module, func_name)(dataset_config, tokenizer, split)
     except AttributeError as e:
         print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).")
         raise e
-    
+
 
 
 DATASET_PREPROC = {
-    "alpaca_dataset": partial(get_alpaca_dataset, max_words=224),
+    "alpaca_dataset": partial(get_alpaca_dataset),
     "grammar_dataset": get_grammar_dataset,
     "samsum_dataset": get_samsum_dataset,
     "custom_dataset": get_custom_dataset,
@@ -69,7 +69,7 @@ def get_preprocessed_dataset(
             if split == "train"
             else dataset_config.test_split
         )
-    
+
     return DATASET_PREPROC[dataset_config.dataset](
         dataset_config,
         tokenizer,

+ 34 - 34
src/llama_recipes/utils/train_utils.py

@@ -26,7 +26,7 @@ from llama_recipes.utils.memory_utils import MemoryTrace
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
     tokenizer.padding_side = "left"
-    
+
 # Converting Bytes to Megabytes
 def byte2mb(x):
     return int(x / 2**20)
@@ -34,7 +34,7 @@ def byte2mb(x):
 def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
     """
     Trains the model on the given dataloader
-    
+
     Args:
         model: The model to be trained
         train_dataloader: The dataloader containing the training data
@@ -46,18 +46,18 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         train_config: The training configuration
         eval_dataloader: The dataloader containing the eval data
         tokenizer: tokenizer used in the eval for decoding the predicitons
-    
+
     Returns: results dictionary containing average training and validation perplexity and loss
     """
     # Create a gradient scaler for fp16
     if train_config.use_fp16 and train_config.enable_fsdp:
         scaler = ShardedGradScaler()
     elif train_config.use_fp16 and not train_config.enable_fsdp:
-        scaler = torch.cuda.amp.GradScaler() 
+        scaler = torch.cuda.amp.GradScaler()
     if train_config.enable_fsdp:
         world_size = int(os.environ["WORLD_SIZE"])
     autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
-    
+
     train_prep = []
     train_loss = []
     val_prep = []
@@ -78,7 +78,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if train_config.enable_fsdp:
                         batch[key] = batch[key].to(local_rank)
                     else:
-                        batch[key] = batch[key].to('cuda:0')              
+                        batch[key] = batch[key].to('cuda:0')
                 with autocast():
                     loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
@@ -101,9 +101,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
 
                 pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
             pbar.close()
-                
+
         epoch_end_time = time.perf_counter()-epoch_start_time
-        epoch_times.append(epoch_end_time)    
+        epoch_times.append(epoch_end_time)
         # Reducing total_loss across all devices if there's more than one CUDA device
         if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
@@ -111,10 +111,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         if train_config.enable_fsdp:
             train_epoch_loss = train_epoch_loss/world_size
         train_perplexity = torch.exp(train_epoch_loss)
-        
+
         train_prep.append(train_perplexity)
         train_loss.append(train_epoch_loss)
-        
+
         if train_config.enable_fsdp:
             if rank==0:
                 print(f"Max CUDA memory allocated was {memtrace.peak} GB")
@@ -128,10 +128,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
             print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
             print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
-        
+
         # Update the learning rate as needed
         lr_scheduler.step()
-          
+
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
             checkpoint_start_time = time.perf_counter()
@@ -144,23 +144,23 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             print(f"we are about to save the PEFT modules")
                     else:
                         print(f"we are about to save the PEFT modules")
-                    model.save_pretrained(train_config.output_dir)  
+                    model.save_pretrained(train_config.output_dir)
                     if train_config.enable_fsdp:
-                        if rank==0: 
+                        if rank==0:
                             print(f"PEFT modules are saved in {train_config.output_dir} directory")
                     else:
                         print(f"PEFT modules are saved in {train_config.output_dir} directory")
-                        
+
                 else:
                     if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
-                        
+
                         save_model_checkpoint(
                             model, optimizer, rank, train_config, epoch=epoch
                         )
                     elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
                         print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
                         print("=====================================================")
-                        
+
                         save_model_and_optimizer_sharded(model, rank, train_config)
                         if train_config.save_optimizer:
                             save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
@@ -172,7 +172,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             model, optimizer, rank, train_config, epoch=epoch
                         )
                         print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
-                        print("=====================================================")                     
+                        print("=====================================================")
                 if train_config.enable_fsdp:
                     dist.barrier()
             checkpoint_end_time = time.perf_counter() - checkpoint_start_time
@@ -196,8 +196,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     avg_train_prep = sum(train_prep)/len(train_prep)
     avg_train_loss = sum(train_loss)/len(train_loss)
     if train_config.run_validation:
-        avg_eval_prep = sum(val_prep)/len(val_prep) 
-        avg_eval_loss = sum(val_loss)/len(val_loss) 
+        avg_eval_prep = sum(val_prep)/len(val_prep)
+        avg_eval_loss = sum(val_loss)/len(val_loss)
 
     results['avg_train_prep'] = avg_train_prep
     results['avg_train_loss'] = avg_train_loss
@@ -206,27 +206,27 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         results['avg_eval_loss'] = avg_eval_loss
     results["avg_epoch_time"] = avg_epoch_time
     results["avg_checkpoint_time"] = avg_checkpoint_time
-    
+
     #saving the training params including fsdp setting for reference.
     if train_config.enable_fsdp and not train_config.use_peft:
         save_train_params(train_config, fsdp_config, rank)
-        
+
     return results
 
 def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     """
     Evaluates the model on the given dataloader
-    
+
     Args:
         model: The model to evaluate
         eval_dataloader: The dataloader containing the evaluation data
         local_rank: The rank of the current node in a distributed setting
         tokenizer: The tokenizer used to decode predictions
-    
+
     Returns: eval_ppl, eval_epoch_loss
     """
     if train_config.enable_fsdp:
-        world_size = int(os.environ["WORLD_SIZE"]) 
+        world_size = int(os.environ["WORLD_SIZE"])
     model.eval()
     eval_preds = []
     eval_loss = 0.0  # Initialize evaluation loss
@@ -248,24 +248,24 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
             eval_preds.extend(
             eval_preds.extend(
                 tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
                 tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
             )
             )
-    
+
     # If there's more than one CUDA device, reduce evaluation loss across all devices
     # If there's more than one CUDA device, reduce evaluation loss across all devices
     if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
     if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
-    
+
     # Compute average loss and perplexity
     # Compute average loss and perplexity
     eval_epoch_loss = eval_loss / len(eval_dataloader)
     eval_epoch_loss = eval_loss / len(eval_dataloader)
     if train_config.enable_fsdp:
     if train_config.enable_fsdp:
         eval_epoch_loss = eval_epoch_loss/world_size
         eval_epoch_loss = eval_epoch_loss/world_size
     eval_ppl = torch.exp(eval_epoch_loss)
     eval_ppl = torch.exp(eval_epoch_loss)
-    
+
     # Print evaluation metrics
     # Print evaluation metrics
     if train_config.enable_fsdp:
     if train_config.enable_fsdp:
         if local_rank==0:
         if local_rank==0:
             print(f" {eval_ppl=} {eval_epoch_loss=}")
             print(f" {eval_ppl=} {eval_epoch_loss=}")
     else:
     else:
         print(f" {eval_ppl=} {eval_epoch_loss=}")
         print(f" {eval_ppl=} {eval_epoch_loss=}")
-        
+
     return eval_ppl, eval_epoch_loss
     return eval_ppl, eval_epoch_loss
 
 
 def freeze_transformer_layers(model, num_layer):
 def freeze_transformer_layers(model, num_layer):
@@ -279,8 +279,8 @@ def check_frozen_layers_peft_model(model):
      for i, layer in enumerate(model.base_model.model.model.layers):
      for i, layer in enumerate(model.base_model.model.model.layers):
             for name, param in layer.named_parameters():
             for name, param in layer.named_parameters():
                 print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
                 print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
-                
-                
+
+
 def setup():
 def setup():
     """Initialize the process group for distributed training"""
     """Initialize the process group for distributed training"""
     dist.init_process_group("nccl")
     dist.init_process_group("nccl")
@@ -293,7 +293,7 @@ def setup_environ_flags(rank):
     # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
     # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
     # This flag will help with CUDA memory fragmentation that can lead to OOM in some cases.
     # This flag will help with CUDA memory fragmentation that can lead to OOM in some cases.
     # Note this is only available in PyTorch Nightlies (as of July 30 2023)
     # Note this is only available in PyTorch Nightlies (as of July 30 2023)
-    # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True' 
+    # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
     if rank == 0:
     if rank == 0:
         print(f"--> Running with torch dist debug set to detail")
         print(f"--> Running with torch dist debug set to detail")
 
 
@@ -338,7 +338,7 @@ def print_model_size(model, config, rank: int = 0) -> None:
 
 
 def get_policies(cfg, rank):
 def get_policies(cfg, rank):
     """Get the policies for mixed precision and fsdp wrapping"""
     """Get the policies for mixed precision and fsdp wrapping"""
-    
+
     verify_bfloat_support = (
     verify_bfloat_support = (
     torch.version.cuda
     torch.version.cuda
     and torch.cuda.is_bf16_supported()
     and torch.cuda.is_bf16_supported()
@@ -374,7 +374,7 @@ def save_train_params(train_config, fsdp_config, rank):
     This will be used by the converter script in the inference folder to fetch the HF model name or path.
     This will be used by the converter script in the inference folder to fetch the HF model name or path.
     It would also be helpful as a log for future reference.
     It would also be helpful as a log for future reference.
     """
     """
-    # Convert the train_config and fsdp_config objects to dictionaries, 
+    # Convert the train_config and fsdp_config objects to dictionaries,
     # converting all values to strings to ensure they can be serialized into a YAML file
     # converting all values to strings to ensure they can be serialized into a YAML file
     train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
     train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
     fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
     fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
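
For reference, the reduction in the `evaluation` hunk above boils down to: sum the per-rank loss across ranks, average over batches (and ranks under FSDP), then take the exponential of that mean loss as perplexity. A condensed sketch of that logic; the function name and signature are illustrative and not part of the codebase:

```python
import torch
import torch.distributed as dist

def reduce_eval_metrics(eval_loss, num_batches, world_size=1, enable_fsdp=False):
    # Sum the accumulated loss across ranks, average over batches (and over
    # ranks under FSDP), then exponentiate to get perplexity, mirroring the
    # reduction performed in evaluation().
    if enable_fsdp and world_size > 1:
        dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
    eval_epoch_loss = eval_loss / num_batches
    if enable_fsdp:
        eval_epoch_loss = eval_epoch_loss / world_size
    eval_ppl = torch.exp(eval_epoch_loss)
    return eval_ppl, eval_epoch_loss
```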

+ 18 - 0
tests/conftest.py

@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import pytest
+
+from transformers import LlamaTokenizer
+
+
+@pytest.fixture
+def setup_tokenizer():
+    def _helper(tokenizer):
+        #Align with Llama 2 tokenizer
+        tokenizer.from_pretrained.return_value = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
+        tokenizer.from_pretrained.return_value.add_special_tokens({'bos_token': '<s>', 'eos_token': '</s>'})
+        tokenizer.from_pretrained.return_value.bos_token_id = 1
+        tokenizer.from_pretrained.return_value.eos_token_id = 2
+
+    return _helper

+ 40 - 13
tests/datasets/test_custom_dataset.py

@@ -4,21 +4,38 @@
 import pytest
 import pytest
 from unittest.mock import patch
 from unittest.mock import patch
 
 
+from transformers import LlamaTokenizer
+
+def check_padded_entry(batch):
+    seq_len = sum(batch["attention_mask"][0])
+    assert seq_len < len(batch["attention_mask"][0])
+
+    assert batch["labels"][0][0] == -100
+    assert batch["labels"][0][seq_len-1] == 2
+    assert batch["labels"][0][-1] == -100
+    assert batch["input_ids"][0][0] == 1
+    assert batch["input_ids"][0][-1] == 2
+
 
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 @patch('llama_recipes.finetuning.StepLR')
-def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
+def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
     from llama_recipes.finetuning import main
     from llama_recipes.finetuning import main
 
 
+    setup_tokenizer(tokenizer)
+
     kwargs = {
     kwargs = {
         "dataset": "custom_dataset",
         "dataset": "custom_dataset",
         "model_name": "decapoda-research/llama-7b-hf", # We use the tokenizer as a surrogate for llama2 tokenizer here
         "model_name": "decapoda-research/llama-7b-hf", # We use the tokenizer as a surrogate for llama2 tokenizer here
         "custom_dataset.file": "examples/custom_dataset.py",
         "custom_dataset.file": "examples/custom_dataset.py",
         "custom_dataset.train_split": "validation",
         "custom_dataset.train_split": "validation",
         "batch_size_training": 2,
         "batch_size_training": 2,
+        "val_batch_size": 4,
         "use_peft": False,
         "use_peft": False,
+        "batching_strategy": "padding"
         }
         }
 
 
     main(**kwargs)
     main(**kwargs)
@@ -30,24 +47,34 @@ def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
     eval_dataloader = args[2]
     eval_dataloader = args[2]
     tokenizer = args[3]
     tokenizer = args[3]
 
 
-    assert len(train_dataloader) == 226
-    assert len(eval_dataloader) == 2*226
+    assert len(train_dataloader) == 1120
+    assert len(eval_dataloader) == 1120 //2
+
+    it = iter(eval_dataloader)
+    batch = next(it)
+    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
+    EXPECTED_STRING = "[INST] Who made Berlin [/INST] dunno"
+    assert STRING.startswith(EXPECTED_STRING)
+
+    assert batch["input_ids"].size(0) == 4
+    assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())
+
+    check_padded_entry(batch)
 
 
     it = iter(train_dataloader)
     it = iter(train_dataloader)
-    STRING = tokenizer.decode(next(it)["input_ids"][0], skip_special_tokens=True)
-    EXPECTED_STRING = "[INST] Напиши функцию на языке swift, которая сортирует массив целых чисел, а затем выводит его на экран [/INST] Вот функция, "
+    for _ in range(5):
+        next(it)
 
 
+    batch = next(it)
+    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
+    EXPECTED_STRING = "[INST] How do I initialize a Typescript project using npm and git? [/INST] # Initialize a new NPM project"
     assert STRING.startswith(EXPECTED_STRING)
     assert STRING.startswith(EXPECTED_STRING)
 
 
-    next(it)
-    next(it)
-    next(it)
-    STRING = tokenizer.decode(next(it)["input_ids"][0], skip_special_tokens=True)
-    EXPECTED_SUBSTRING_1 = "Therefore you are correct.  [INST] How can L’Hopital’s Rule be"
-    EXPECTED_SUBSTRING_2 = "a circular path around the turn.  [INST] How on earth is that related to L’Hopital’s Rule?"
+    assert batch["input_ids"].size(0) == 2
+    assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())
+
+    check_padded_entry(batch)
 
 
-    assert EXPECTED_SUBSTRING_1 in STRING
-    assert EXPECTED_SUBSTRING_2 in STRING
 
 
 
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.train')

+ 54 - 0
tests/datasets/test_grammar_datasets.py

@@ -0,0 +1,54 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from unittest.mock import patch
+
+from transformers import LlamaTokenizer
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+    from llama_recipes.finetuning import main
+
+    setup_tokenizer(tokenizer)
+
+    BATCH_SIZE = 8
+    kwargs = {
+        "model_name": "decapoda-research/llama-7b-hf",
+        "batch_size_training": BATCH_SIZE,
+        "val_batch_size": 1,
+        "use_peft": False,
+        "dataset": "grammar_dataset",
+        "batching_strategy": "padding",
+        }
+
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader = args[1]
+    eval_dataloader = args[2]
+
+    VAL_SAMPLES = 2988
+    TRAIN_SAMPLES = 13016
+
+    assert len(train_dataloader) == TRAIN_SAMPLES // BATCH_SIZE
+    assert len(eval_dataloader) == VAL_SAMPLES
+
+    batch = next(iter(train_dataloader))
+
+    assert "labels" in batch.keys()
+    assert "input_ids" in batch.keys()
+    assert "attention_mask" in batch.keys()
+
+    assert batch["labels"][0][29] == -100
+    assert batch["labels"][0][30] == 29871
+
+    assert batch["input_ids"][0][0] == 1
+    assert batch["labels"][0][-1] == 2
+    assert batch["input_ids"][0][-1] == 2

+ 30 - 14
tests/datasets/test_samsum_datasets.py

@@ -1,37 +1,53 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 
+from functools import partial
 from unittest.mock import patch
 from unittest.mock import patch
 
 
 
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 @patch('llama_recipes.finetuning.StepLR')
-def test_custom_dataset(step_lr, optimizer, tokenizer, get_model, train, mocker):
+def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
     from llama_recipes.finetuning import main
     from llama_recipes.finetuning import main
-        
-    tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})
-    
-    
+
+    setup_tokenizer(tokenizer)
+
+    BATCH_SIZE = 8
     kwargs = {
     kwargs = {
-        "batch_size_training": 1,
+        "model_name": "decapoda-research/llama-7b-hf",
+        "batch_size_training": BATCH_SIZE,
+        "val_batch_size": 1,
         "use_peft": False,
         "use_peft": False,
         "dataset": "samsum_dataset",
         "dataset": "samsum_dataset",
+        "batching_strategy": "padding",
         }
         }
-    
+
     main(**kwargs)
     main(**kwargs)
-    
+
     assert train.call_count == 1
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     args, kwargs = train.call_args
     train_dataloader = args[1]
     train_dataloader = args[1]
     eval_dataloader = args[2]
     eval_dataloader = args[2]
-    
+
     VAL_SAMPLES = 818
     VAL_SAMPLES = 818
     TRAIN_SAMPLES = 14732
     TRAIN_SAMPLES = 14732
-    CONCAT_SIZE = 2048
-    assert len(train_dataloader) == TRAIN_SAMPLES // CONCAT_SIZE
+
+    assert len(train_dataloader) == TRAIN_SAMPLES // BATCH_SIZE
     assert len(eval_dataloader) == VAL_SAMPLES
     assert len(eval_dataloader) == VAL_SAMPLES
-    
+
+    batch = next(iter(train_dataloader))
+
+    assert "labels" in batch.keys()
+    assert "input_ids" in batch.keys()
+    assert "attention_mask" in batch.keys()
+
+    assert batch["labels"][0][268] == -100
+    assert batch["labels"][0][269] == 22291
+
+    assert batch["input_ids"][0][0] == 1
+    assert batch["labels"][0][-1] == 2
+    assert batch["input_ids"][0][-1] == 2

+ 94 - 0
tests/test_batching.py

@@ -0,0 +1,94 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import pytest
+from unittest.mock import patch
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+    from llama_recipes.finetuning import main
+
+    setup_tokenizer(tokenizer)
+
+    kwargs = {
+        "model_name": "decapoda-research/llama-7b-hf",
+        "batch_size_training": 8,
+        "val_batch_size": 1,
+        "use_peft": False,
+        "dataset": "samsum_dataset",
+        "batching_strategy": "packing",
+        }
+
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader = args[1]
+    eval_dataloader = args[2]
+
+    assert len(train_dataloader) == 96
+    assert len(eval_dataloader) == 42
+
+    batch = next(iter(train_dataloader))
+
+    assert "labels" in batch.keys()
+    assert "input_ids" in batch.keys()
+    assert "attention_mask" in batch.keys()
+
+    assert batch["labels"][0].size(0) == 4096
+    assert batch["input_ids"][0].size(0) == 4096
+    assert batch["attention_mask"][0].size(0) == 4096
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+@patch('llama_recipes.finetuning.setup')
+@patch('llama_recipes.finetuning.FSDP')
+@patch('llama_recipes.finetuning.torch.distributed.is_initialized')
+@patch('llama_recipes.utils.config_utils.dist')
+def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer):
+    import os
+    from llama_recipes.finetuning import main
+
+    setup_tokenizer(tokenizer)
+
+    rank = 0
+    os.environ['LOCAL_RANK'] = f'{rank}'
+    os.environ['RANK'] = f'{rank}'
+    os.environ['WORLD_SIZE'] = '2'
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '12345'
+
+    kwargs = {
+        "model_name": "decapoda-research/llama-7b-hf",
+        "batch_size_training": 8,
+        "val_batch_size": 1,
+        "use_peft": False,
+        "dataset": "samsum_dataset",
+        "batching_strategy": "packing",
+        "enable_fsdp": True
+        }
+
+    is_initialized.return_value = True
+    dist.get_rank.return_value = rank
+    dist.get_world_size.return_value = 2
+
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader = args[1]
+    eval_dataloader = args[2]
+
+    assert len(train_dataloader) == 96 //2
+    assert len(eval_dataloader) == 42 //2

+ 81 - 34
tests/test_finetuning.py

@@ -1,14 +1,26 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 
+import pytest
 from pytest import approx
 from pytest import approx
 from unittest.mock import patch
 from unittest.mock import patch
 
 
 from torch.nn import Linear
 from torch.nn import Linear
 from torch.optim import AdamW
 from torch.optim import AdamW
 from torch.utils.data.dataloader import DataLoader
 from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.sampler import BatchSampler
 
 
 from llama_recipes.finetuning import main
 from llama_recipes.finetuning import main
+from llama_recipes.data.sampler import LengthBasedBatchSampler
+
+
+def get_fake_dataset():
+    return [{
+        "input_ids":[1],
+        "attention_mask":[1],
+        "labels":[1],
+        }]
+
 
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@@ -18,23 +30,23 @@ from llama_recipes.finetuning import main
 @patch('llama_recipes.finetuning.StepLR')
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
 def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
     kwargs = {"run_validation": False}
     kwargs = {"run_validation": False}
-    
-    get_dataset.return_value = [1]
-    
+
+    get_dataset.return_value = get_fake_dataset()
+
     main(**kwargs)
     main(**kwargs)
-    
+
     assert train.call_count == 1
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     args, kwargs = train.call_args
     train_dataloader = args[1]
     train_dataloader = args[1]
     eval_dataloader = args[2]
     eval_dataloader = args[2]
-    
+
     assert isinstance(train_dataloader, DataLoader)
     assert isinstance(train_dataloader, DataLoader)
     assert eval_dataloader is None
     assert eval_dataloader is None
-    
+
     assert get_model.return_value.to.call_args.args[0] == "cuda"
     assert get_model.return_value.to.call_args.args[0] == "cuda"
-    
-    
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@@ -43,21 +55,22 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge
 @patch('llama_recipes.finetuning.StepLR')
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
 def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
     kwargs = {"run_validation": True}
     kwargs = {"run_validation": True}
-    get_dataset.return_value = [1]
-    
+
+    get_dataset.return_value = get_fake_dataset()
+
     main(**kwargs)
     main(**kwargs)
-    
+
     assert train.call_count == 1
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     args, kwargs = train.call_args
     train_dataloader = args[1]
     train_dataloader = args[1]
     eval_dataloader = args[2]
     eval_dataloader = args[2]
     assert isinstance(train_dataloader, DataLoader)
     assert isinstance(train_dataloader, DataLoader)
     assert isinstance(eval_dataloader, DataLoader)
     assert isinstance(eval_dataloader, DataLoader)
-    
+
     assert get_model.return_value.to.call_args.args[0] == "cuda"
     assert get_model.return_value.to.call_args.args[0] == "cuda"
-    
-    
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@@ -68,15 +81,15 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer,
 @patch('llama_recipes.finetuning.StepLR')
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train):
 def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train):
     kwargs = {"use_peft": True}
     kwargs = {"use_peft": True}
-    
-    get_dataset.return_value = [1]
-    
+
+    get_dataset.return_value = get_fake_dataset()
+
     main(**kwargs)
     main(**kwargs)
-    
+
     assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
     assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
     assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
     assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
-    
-    
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@@ -85,22 +98,56 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
 @patch('llama_recipes.finetuning.StepLR')
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train, mocker):
 def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train, mocker):
     kwargs = {"weight_decay": 0.01}
     kwargs = {"weight_decay": 0.01}
-    
-    get_dataset.return_value = [1]
-    
-    model = mocker.MagicMock(name="model")
-    model.parameters.return_value = Linear(1,1).parameters()
-    get_peft_model.return_value = model 
-    get_peft_model.return_value.print_trainable_parameters=lambda:None
+
+    get_dataset.return_value = get_fake_dataset()
+
+    get_model.return_value = Linear(1,1)
+
     main(**kwargs)
     main(**kwargs)
-    
+
     assert train.call_count == 1
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     args, kwargs = train.call_args
     optimizer = args[4]
     optimizer = args[4]
-    
+
     print(optimizer.state_dict())
     print(optimizer.state_dict())
-    
+
     assert isinstance(optimizer, AdamW)
     assert isinstance(optimizer, AdamW)
     assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01)
     assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01)
-    
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.get_preprocessed_dataset')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_batching_strategy(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
+    kwargs = {"batching_strategy": "packing"}
+
+    get_dataset.return_value = get_fake_dataset()
+
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader, eval_dataloader = args[1:3]
+    assert isinstance(train_dataloader.batch_sampler, BatchSampler)
+    assert isinstance(eval_dataloader.batch_sampler, BatchSampler)
+
+    kwargs["batching_strategy"] = "padding"
+    train.reset_mock()
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader, eval_dataloader = args[1:3]
+    assert isinstance(train_dataloader.batch_sampler, LengthBasedBatchSampler)
+    assert isinstance(eval_dataloader.batch_sampler, LengthBasedBatchSampler)
+
+    kwargs["batching_strategy"] = "none"
+
+    with pytest.raises(ValueError):
+        main(**kwargs)
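
`test_batching_strategy` pins down the observable contract: `packing` yields a plain `BatchSampler`, `padding` yields a `LengthBasedBatchSampler`, and any other value raises `ValueError`. A hypothetical dispatch helper with that behaviour; the function name is an assumption and this is not the repository's actual config code:

```python
from torch.utils.data.sampler import BatchSampler, RandomSampler

from llama_recipes.data.sampler import LengthBasedBatchSampler

def build_batch_sampler(dataset, batch_size, batching_strategy, drop_last=True):
    # Illustrative dispatch matching the assertions in test_batching_strategy.
    if batching_strategy == "packing":
        return BatchSampler(RandomSampler(dataset), batch_size, drop_last)
    if batching_strategy == "padding":
        return LengthBasedBatchSampler(dataset, batch_size, drop_last)
    raise ValueError(f"Unknown batching_strategy: {batching_strategy}")
```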

+ 86 - 0
tests/test_sampler.py

@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import random
+import pytest
+
+import torch
+
+from llama_recipes.data.sampler import LengthBasedBatchSampler
+from llama_recipes.data.sampler import DistributedLengthBasedBatchSampler
+
+SAMPLES = 33
+
+@pytest.fixture
+def dataset():
+    random.seed(42)
+    dataset = []
+    def add_samples(ds, n, a, b):
+        for _ in range(n):
+            ds.append(random.randint(a,b) * [1,])
+    add_samples(dataset, SAMPLES // 2, 1,9)
+    add_samples(dataset, (SAMPLES // 2) + (SAMPLES % 2), 10,20)
+    
+    return random.sample(dataset, len(dataset))
+    
+    
+@pytest.mark.parametrize("batch_size, drop_last", [(2, False), (8, False), (2, True), (8, True)])
+def test_batch_sampler_array(dataset, batch_size, drop_last):
+    
+    sampler = LengthBasedBatchSampler(dataset, batch_size, drop_last)
+    
+    EXPECTED_LENGTH = SAMPLES // batch_size if drop_last else (SAMPLES // batch_size) + (SAMPLES % batch_size)
+    
+    all_ids = [i for b in sampler for i in b]
+    assert len(set(all_ids)) == (EXPECTED_LENGTH * batch_size if drop_last else len(dataset))
+    
+    assert len(sampler) == EXPECTED_LENGTH
+    is_long = [len(d)>=10 for d in dataset]
+    
+    def check_batch(batch):
+        return all(batch) or not any(batch)
+    
+    assert all(check_batch(is_long[i] for i in b) for b in sampler)
+    
+    
+@pytest.mark.parametrize("batch_size, drop_last", [(2, False), (8, False), (2, True), (8, True)])
+def test_batch_sampler_dict(dataset, batch_size, drop_last):
+    
+    dist_dataset = [{"input_ids": d, "attention_mask": d} for d in dataset]
+    
+    sampler = LengthBasedBatchSampler(dist_dataset, batch_size, drop_last)
+    
+    EXPECTED_LENGTH = SAMPLES // batch_size if drop_last else (SAMPLES // batch_size) + (SAMPLES % batch_size)
+    
+    assert len(sampler) == EXPECTED_LENGTH
+    is_long = [len(d)>=10 for d in dataset]
+    
+    def check_batch(batch):
+        return all(batch) or not any(batch)
+    
+    assert all(check_batch(is_long[i] for i in b) for b in sampler)
+    
+    
+@pytest.mark.parametrize("batch_size", [2, 8])
+def test_dist_batch_sampling(dataset, batch_size):
+    sampler_1 = DistributedLengthBasedBatchSampler(
+        dataset,
+        batch_size=batch_size,
+        rank=0,
+        num_replicas=2,
+        shuffle=False,
+    )
+    sampler_2 = DistributedLengthBasedBatchSampler(
+        dataset,
+        batch_size=batch_size,
+        rank=1,
+        num_replicas=2,
+        shuffle=False,
+    )
+    
+    ids_1 = set(i for b in sampler_1 for i in b)
+    ids_2 = set(i for b in sampler_2 for i in b)
+    
+    assert ids_1.isdisjoint(ids_2)
+    assert len(ids_1)+len(ids_2) > 0
+    assert len(ids_1)+len(ids_2) == len(dataset) // batch_size  *  batch_size
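
The sampler tests above only constrain behaviour: batches group samples of similar length (all "short" or all "long"), `drop_last` trims the trailing partial batch, and the distributed variant hands disjoint, equally sized shards to each rank. A minimal, illustrative sampler with those properties; this is a sketch, not the `LengthBasedBatchSampler` shipped in `llama_recipes.data.sampler`:

```python
import random

class NaiveLengthBasedBatchSampler:
    """Sketch only: sort indices by sample length, slice them into contiguous
    batches, and shuffle the batch order so similarly sized samples land in
    the same batch."""

    def __init__(self, data_source, batch_size, drop_last, shuffle=True):
        lengths = [
            len(d["input_ids"]) if isinstance(d, dict) else len(d)
            for d in data_source
        ]
        order = sorted(range(len(lengths)), key=lengths.__getitem__)
        self.batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
        if drop_last:
            self.batches = [b for b in self.batches if len(b) == batch_size]
        if shuffle:
            random.shuffle(self.batches)

    def __iter__(self):
        return iter(self.batches)

    def __len__(self):
        return len(self.batches)
```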