
Merge branch 'main' into ipex_feature

Abhilash Majumder · 1 year ago · commit 6a78b96764

File diff suppressed because it is too large
+ 1 - 1
README.md


File diff suppressed because it is too large
+ 32 - 6
docs/Dataset.md


File diff suppressed because it is too large
+ 13 - 13
docs/LLM_finetuning.md


docs/images/featurebased_FN_.png → docs/images/feature-based_FN_2.png


File diff suppressed because it is too large
+ 1 - 0
examples/Getting_to_know_Llama.ipynb


+ 5 - 1
examples/README.md

@@ -31,4 +31,8 @@ For more in depth information on inference including inference safety checks and
 
 **Note** The [sensitive topics safety checker](../src/llama_recipes/inference/safety_utils.py) utilizes AuditNLG which is an optional dependency. Please refer to installation section of the main [README.md](../README.md#install-with-optional-dependencies) for details.
 
-**Note** The **vLLM** example requires additional dependencies. Please refer to installation section of the main [README.md](../README.md#install-with-optional-dependencies) for details.
+**Note** The **vLLM** example requires additional dependencies. Please refer to installation section of the main [README.md](../README.md#install-with-optional-dependencies) for details.
+
+## Train on custom dataset
+To show how to train a model on a custom dataset we provide an example to generate a custom dataset in [custom_dataset.py](./custom_dataset.py).
+The usage of the custom dataset is further described in the datasets [README](../docs/Dataset.md#training-on-custom-data).
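For orientation (not part of this diff): a minimal sketch of how the new custom dataset can be selected when launching fine-tuning. It mirrors the keyword arguments exercised by `tests/datasets/test_custom_dataset.py` further down in this commit; the model path is a placeholder.

```python
# Minimal sketch, assuming llama-recipes is installed; mirrors the kwargs used in
# tests/datasets/test_custom_dataset.py in this same commit.
from llama_recipes.finetuning import main

kwargs = {
    "dataset": "custom_dataset",
    # optional ":function_name" suffix selects a loader other than get_custom_dataset
    "custom_dataset.file": "examples/custom_dataset.py:get_custom_dataset",
    "model_name": "PATH/TO/LLAMA_2_CHECKPOINT",  # placeholder, substitute your checkpoint
}
main(**kwargs)
```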

+ 1 - 1
examples/chat_completion/chat_completion.py

@@ -34,7 +34,7 @@ def main(
     enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
     enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
     enable_saleforce_content_safety: bool=True, # Enable safety check woth Saleforce safety flan t5
-    use_fast_kernels: bool = False, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
+    use_fast_kernels: bool = False, # Enable using SDPA from PyTorch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
     **kwargs
 ):
     if prompt_file is not None:

+ 91 - 0
examples/custom_dataset.py

@@ -0,0 +1,91 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+# For dataset details visit: https://huggingface.co/datasets/samsum
+
+import copy
+import datasets
+import itertools
+
+from llama_recipes.datasets.utils import Concatenator
+
+
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+def tokenize_dialog(dialog, tokenizer):
+    dialog_tokens = [
+            tokenizer(
+                f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
+            )
+            for prompt, answer in zip(dialog[::2], dialog[1::2])
+        ]
+    if len(dialog) % 2:    
+        dialog_tokens += [tokenizer(
+            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
+        )]
+    
+    combined_tokens = {}  
+    for k in dialog_tokens[0].keys():
+        combined_tokens[k] = list(itertools.chain(*(t[k] for t in dialog_tokens)))
+    return combined_tokens
+
+
+def get_custom_dataset(dataset_config, tokenizer, split):
+    dataset = datasets.load_dataset("OpenAssistant/oasst1", split=split)
+    
+    dataset = dataset.map(lambda sample: {
+        "message_id": sample["message_id"],
+        "parent_id": sample["parent_id"],
+        "text": sample["text"],
+        },
+        batched=True,
+        remove_columns=list(dataset.features),)
+    
+    nodes = {}
+    
+    messages = {}
+    root_ids = []
+    
+    for data in dataset:
+        if data["parent_id"]:
+            nodes[data["parent_id"]] = nodes.get(data["parent_id"], []) + [data["message_id"]]
+        else:
+            root_ids.append(data["message_id"])
+        messages[data["message_id"]]=data["text"]
+           
+    def follow(thread, current_id):
+        thread = copy.copy(thread) + [messages[current_id]]
+        if current_id in nodes:
+            new_threads = []
+            for next_id in nodes[current_id]:
+                new_threads += follow(thread, next_id)
+            return new_threads
+        else:
+            return [thread]
+        
+    def get_threads_from_root(root_id):
+        all_threads = []
+        thread = [messages[root_id]]
+        for cid in nodes[root_id]:
+            all_threads += follow(thread, cid)
+        return all_threads
+            
+    dataset = dataset.filter(lambda x: x["message_id"] in root_ids)
+    dataset = dataset.map(lambda x: {"thread": get_threads_from_root(x["message_id"])}, remove_columns=list(dataset.features))
+    dataset = dataset.map(lambda x: {"thread": [i for row in x["thread"] for i in row]}, batched=True)
+    
+    def to_dialog(thread):
+        dialog = []
+        for i, content in enumerate(thread):
+            dialog.append({
+                "role": "user" if i % 2 == 0 else "assistant",
+                "content": content,
+            })
+        return {"dialog": dialog}
+            
+    dataset = dataset.map(lambda x: to_dialog(x["thread"]), remove_columns=list(dataset.features))
+    dataset = dataset.map(lambda x: tokenize_dialog(x["dialog"], tokenizer), remove_columns=list(dataset.features))
+    dataset = dataset.map(Concatenator(), batched=True)
+    
+    return dataset

+ 1 - 7
examples/inference.py

@@ -76,13 +76,7 @@ def main(
             print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
             print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
 
 
     tokenizer = LlamaTokenizer.from_pretrained(model_name)
     tokenizer = LlamaTokenizer.from_pretrained(model_name)
-    tokenizer.add_special_tokens(
-        {
-         
-            "pad_token": "<PAD>",
-        }
-    )
-    model.resize_token_embeddings(model.config.vocab_size + 1) 
+    tokenizer.pad_token = tokenizer.eos_token
     
     safety_checker = get_safety_checker(enable_azure_content_safety,
                                         enable_sensitive_topics,

+ 3 - 6
examples/quickstart.ipynb

@@ -32,7 +32,7 @@
    "outputs": [],
    "outputs": [],
    "source": [
    "source": [
     "# %%bash\n",
     "# %%bash\n",
-    "# pip install transformers datasets accelerate sentencepiece protobuf==3.20 py7zr scipy peft bitsandbytes fire torch_tb_profiler ipywidgets\n",
+    "# pip install llama-recipes transformers datasets accelerate sentencepiece protobuf==3.20 py7zr scipy peft bitsandbytes fire torch_tb_profiler ipywidgets\n",
     "# TRANSFORM=`python -c \"import transformers;print('/'.join(transformers.__file__.split('/')[:-1])+'/models/llama/convert_llama_weights_to_hf.py')\"`\n",
     "# TRANSFORM=`python -c \"import transformers;print('/'.join(transformers.__file__.split('/')[:-1])+'/models/llama/convert_llama_weights_to_hf.py')\"`\n",
     "# python ${TRANSFORM} --input_dir models --model_size 7B --output_dir models_hf/7B"
     "# python ${TRANSFORM} --input_dir models --model_size 7B --output_dir models_hf/7B"
    ]
    ]
@@ -130,11 +130,8 @@
     }
    ],
    "source": [
-    "from pathlib import Path\n",
-    "import os\n",
-    "import sys\n",
-    "from utils.dataset_utils import get_preprocessed_dataset\n",
-    "from configs.datasets import samsum_dataset\n",
+    "from llama_recipes.utils.dataset_utils import get_preprocessed_dataset\n",
+    "from llama_recipes.configs.datasets import samsum_dataset\n",
     "\n",
     "\n",
     "train_dataset = get_preprocessed_dataset(tokenizer, samsum_dataset, 'train')"
     "train_dataset = get_preprocessed_dataset(tokenizer, samsum_dataset, 'train')"
    ]
    ]

+ 10 - 1
scripts/spellcheck_conf/wordlist.txt

@@ -1147,4 +1147,13 @@ HuggingFace's
 LoRA
 bitsandbytes
 CLA
-dialogs
+dialogs
+OpenAssistant
+oasst1
+oasst
+AdamW
+Autocast
+FN
+GBs
+MLP
+learnable

+ 9 - 1
src/llama_recipes/configs/datasets.py

@@ -25,4 +25,12 @@ class alpaca_dataset:
     dataset: str = "alpaca_dataset"
     dataset: str = "alpaca_dataset"
     train_split: str = "train"
     train_split: str = "train"
     test_split: str = "val"
     test_split: str = "val"
-    data_path: str = "src/llama_recipes/datasets/alpaca_data.json"
+    data_path: str = "src/llama_recipes/datasets/alpaca_data.json"
+    
+    
+@dataclass
+class custom_dataset:
+    dataset: str = "custom_dataset"
+    file: str = "examples/custom_dataset.py"
+    train_split: str = "train"
+    test_split: str = "validation"

+ 1 - 2
src/llama_recipes/configs/fsdp.py

@@ -13,8 +13,7 @@ class fsdp_config:
     sharding_strategy: ShardingStrategy = ShardingStrategy.FULL_SHARD
     checkpoint_type: StateDictType = StateDictType.SHARDED_STATE_DICT  # alternatively can use SHARDED_STATE_DICT save one file per rank, and can resize the world-size.
     fsdp_activation_checkpointing: bool=True
+    fsdp_cpu_offload: bool=False
     pure_bf16: bool = False
     optimizer: str= "AdamW"
     
-    
-    

+ 3 - 3
src/llama_recipes/configs/peft.py

@@ -1,14 +1,14 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
-from dataclasses import dataclass
-from typing import ClassVar, List
+from dataclasses import dataclass, field
+from typing import List
 
 @dataclass
 class lora_config:
      r: int=8
      lora_alpha: int=32
-     target_modules: ClassVar[List[str]]= ["q_proj", "v_proj"]
+     target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])
      bias= "none"
      bias= "none"
      task_type: str= "CAUSAL_LM"
      task_type: str= "CAUSAL_LM"
      lora_dropout: float=0.05
      lora_dropout: float=0.05

+ 3 - 3
src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb

@@ -35,10 +35,10 @@
     "  (\" '\", \"'\"),\n",
     "  (\" '\", \"'\"),\n",
     "  (\" ?\", \"?\"),\n",
     "  (\" ?\", \"?\"),\n",
     "  (\" !\", \"!\"),\n",
     "  (\" !\", \"!\"),\n",
-    "  (\" :\", \"!\"),\n",
-    "  (\" ;\", \"!\"),\n",
+    "  (\" :\", \":\"),\n",
+    "  (\" ;\", \";\"),\n",
     "  (\" n't\", \"n't\"),\n",
     "  (\" n't\", \"n't\"),\n",
-    "  (\" v\", \"n't\"),\n",
+    "  (\" v\", \"v\"),\n",
     "  (\"2 0 0 6\", \"2006\"),\n",
     "  (\"2 0 0 6\", \"2006\"),\n",
     "  (\"5 5\", \"55\"),\n",
     "  (\"5 5\", \"55\"),\n",
     "  (\"4 0 0\", \"400\"),\n",
     "  (\"4 0 0\", \"400\"),\n",

+ 2 - 2
src/llama_recipes/datasets/utils.py

@@ -52,7 +52,7 @@ class ConcatDataset(Dataset):
             "labels": [],
             "labels": [],
             }
             }
         
         
-        for sample in tqdm(self.dataset, desc="Preprocessing dataset"):
+        for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
             buffer = {k: v + sample[k] for k,v in buffer.items()}
             
             while len(next(iter(buffer.values()))) > self.chunk_size:
@@ -63,4 +63,4 @@ class ConcatDataset(Dataset):
         return self.samples[idx]
     
     def __len__(self):
-        return len(self.samples)
+        return len(self.samples)

+ 4 - 1
src/llama_recipes/finetuning.py

@@ -12,6 +12,7 @@ from peft import get_peft_model, prepare_model_for_int8_training
 from torch.distributed.fsdp import (
     FullyShardedDataParallel as FSDP,
 )
+from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
 from torch.utils.data import DistributedSampler
 from transformers import (
@@ -150,6 +151,7 @@ def main(**kwargs):
         model = FSDP(
             model,
             auto_wrap_policy= my_auto_wrapping_policy if train_config.use_peft else wrapping_policy,
+            cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
             mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
             sharding_strategy=fsdp_config.sharding_strategy,
             device_id=torch.xpu.current_device() if is_xpu_available() else torch.cuda.current_device(),
@@ -233,12 +235,13 @@ def main(**kwargs):
             momentum_dtype=torch.bfloat16,
             variance_dtype=torch.bfloat16,
             use_kahan_summation=False,
+            weight_decay=train_config.weight_decay,
         )
     else:
         optimizer = optim.AdamW(
             model.parameters(),
             lr=train_config.lr,
-            weight_decay=0.0,
+            weight_decay=train_config.weight_decay,
         )
     scheduler = StepLR(optimizer, step_size=1, gamma=train_config.gamma)
 

+ 2 - 2
src/llama_recipes/inference/chat_utils.py

@@ -44,7 +44,7 @@ def format_tokens(dialogs, tokenizer):
             [
                 tokenizer.encode(
                     f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
-                )
+                ) + [tokenizer.eos_token_id]
                 for prompt, answer in zip(dialog[::2], dialog[1::2])
             ],
             [],
@@ -62,4 +62,4 @@ def format_tokens(dialogs, tokenizer):
 def read_dialogs_from_file(file_path):
 def read_dialogs_from_file(file_path):
     with open(file_path, 'r') as file:
     with open(file_path, 'r') as file:
         dialogs = json.load(file)
         dialogs = json.load(file)
-    return dialogs
+    return dialogs

+ 7 - 6
src/llama_recipes/utils/config_utils.py

@@ -2,8 +2,7 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import inspect
-from dataclasses import fields
-
+from dataclasses import asdict
 from peft import (
     LoraConfig,
     AdaptionPromptConfig,
@@ -42,9 +41,10 @@ def generate_peft_config(train_config, kwargs):
     
     assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"
     
-    config = configs[names.index(train_config.peft_method)]
+    config = configs[names.index(train_config.peft_method)]()
+    
     update_config(config, **kwargs)
-    params = {k.name: getattr(config, k.name) for k in fields(config)}
+    params = asdict(config)
     peft_config = peft_configs[names.index(train_config.peft_method)](**params)
     
     return peft_config
@@ -52,10 +52,11 @@ def generate_peft_config(train_config, kwargs):
 
 def generate_dataset_config(train_config, kwargs):
     names = tuple(DATASET_PREPROC.keys())
-    
+        
     assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}"
     
-    dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset]
+    dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset]()
+        
     update_config(dataset_config, **kwargs)
     
     return  dataset_config

+ 38 - 0
src/llama_recipes/utils/dataset_utils.py

@@ -1,7 +1,9 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+import importlib
 from functools import partial
+from pathlib import Path
 
 import torch
 
@@ -12,10 +14,46 @@ from llama_recipes.datasets import (
 )
 
 
+def load_module_from_py_file(py_file: str) -> object:
+    """
+    This method loads a module from a py file which is not in the Python path
+    """
+    module_name = Path(py_file).name
+    loader = importlib.machinery.SourceFileLoader(module_name, py_file)
+    spec = importlib.util.spec_from_loader(module_name, loader)
+    module = importlib.util.module_from_spec(spec)
+
+    loader.exec_module(module)
+
+    return module
+
+
+def get_custom_dataset(dataset_config, tokenizer, split: str):
+    if ":" in dataset_config.file:
+        module_path, func_name = dataset_config.file.split(":")
+    else:
+        module_path, func_name = dataset_config.file, "get_custom_dataset"
+        
+    if not module_path.endswith(".py"):
+        raise ValueError(f"Dataset file {module_path} is not a .py file.")
+    
+    module_path = Path(module_path)
+    if not module_path.is_file():
+        raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
+    
+    module = load_module_from_py_file(module_path.as_posix())
+    try:
+        return getattr(module, func_name)(dataset_config, tokenizer, split)
+    except AttributeError as e:
+        print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).")
+        raise e
+    
+
 DATASET_PREPROC = {
     "alpaca_dataset": partial(get_alpaca_dataset, max_words=224),
     "grammar_dataset": get_grammar_dataset,
     "samsum_dataset": get_samsum_dataset,
+    "custom_dataset": get_custom_dataset,
 }
 
 
 
 

+ 21 - 13
src/llama_recipes/utils/train_utils.py

@@ -4,6 +4,7 @@
 import os
 import time
 import yaml
+from contextlib import nullcontext
 from pathlib import Path
 from pkg_resources import packaging
 
@@ -56,7 +57,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     elif train_config.use_fp16 and not train_config.enable_fsdp:
         scaler = torch.cuda.amp.GradScaler() 
     if train_config.enable_fsdp:
-        world_size = int(os.environ["WORLD_SIZE"]) 
+        world_size = int(os.environ["WORLD_SIZE"])
+    autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
+    
     train_prep = []
     train_loss = []
     val_prep = []
@@ -71,17 +74,21 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             model.train()
             total_loss = 0.0
             total_length = len(train_dataloader)//gradient_accumulation_steps
-            pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch}", total=total_length)
+            pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
             for step, batch in enumerate(train_dataloader):
                 for key in batch.keys():
                     if train_config.enable_fsdp:
-                        batch[key] = batch[key].to(local_rank)
+                        if is_xpu_available():
+                            batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}"))
+                        else:
+                            batch[key] = batch[key].to(local_rank)
                     else:
                         if is_xpu_available():
                             batch[key] = batch[key].to('xpu:0')
                         else:
                             batch[key] = batch[key].to('cuda:0')              
-                loss = model(**batch).loss
+                with autocast():
+                    loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
                 total_loss += loss.detach().float()
                 if train_config.use_fp16:
@@ -91,16 +98,17 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         scaler.step(optimizer)
                         scaler.update()
                         optimizer.zero_grad()
-                        pbar.update(step//gradient_accumulation_steps)
+                        pbar.update(1)
                 else:
                     # regular backpropagation when fp16 is not used
                     loss.backward()
                     if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                         optimizer.step()
                         optimizer.zero_grad()
-                        pbar.update(step//gradient_accumulation_steps)
-                
-                pbar.set_description(f"Training Epoch: {epoch}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
+                        pbar.update(1)
+
+                pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
+            pbar.close()
                 
         epoch_end_time = time.perf_counter()-epoch_start_time
         epoch_times.append(epoch_end_time)    
@@ -195,16 +203,16 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 best_val_loss = eval_epoch_loss
                 if train_config.enable_fsdp:
                     if rank==0:
-                        print(f"best eval loss on epoch {epoch} is {best_val_loss}")
+                        print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
                 else:
-                    print(f"best eval loss on epoch {epoch} is {best_val_loss}")
+                    print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
             val_loss.append(best_val_loss)
             val_prep.append(eval_ppl)
         if train_config.enable_fsdp:
             if rank==0:
-                print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
+                print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
         else:
-            print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epcoh time {epoch_end_time}s")
+            print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
     avg_epoch_time = sum(epoch_times)/ len(epoch_times)
     avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
     avg_train_prep = sum(train_prep)/len(train_prep)
@@ -245,7 +253,7 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     eval_preds = []
     eval_loss = 0.0  # Initialize evaluation loss
     with MemoryTrace() as memtrace:
-        for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch")):
+        for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
             for key in batch.keys():
                 if train_config.enable_fsdp:
                     batch[key] = batch[key].to(local_rank)

+ 70 - 0
tests/datasets/test_custom_dataset.py

@@ -0,0 +1,70 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import pytest
+from unittest.mock import patch
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
+    from llama_recipes.finetuning import main
+
+    kwargs = {
+        "dataset": "custom_dataset",
+        "model_name": "decapoda-research/llama-7b-hf", # We use the tokenizer as a surrogate for llama2 tokenizer here
+        "custom_dataset.file": "examples/custom_dataset.py",
+        "custom_dataset.train_split": "validation",
+        "batch_size_training": 2,
+        "use_peft": False,
+        }
+
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader = args[1]
+    eval_dataloader = args[2]
+    tokenizer = args[3]
+
+    assert len(train_dataloader) == 226
+    assert len(eval_dataloader) == 2*226
+
+    it = iter(train_dataloader)
+    STRING = tokenizer.decode(next(it)["input_ids"][0], skip_special_tokens=True)
+    EXPECTED_STRING = "[INST] Напиши функцию на языке swift, которая сортирует массив целых чисел, а затем выводит его на экран [/INST] Вот функция, "
+
+    assert STRING.startswith(EXPECTED_STRING)
+
+    next(it)
+    next(it)
+    next(it)
+    STRING = tokenizer.decode(next(it)["input_ids"][0], skip_special_tokens=True)
+    EXPECTED_SUBSTRING_1 = "Therefore you are correct.  [INST] How can L’Hopital’s Rule be"
+    EXPECTED_SUBSTRING_2 = "a circular path around the turn.  [INST] How on earth is that related to L’Hopital’s Rule?"
+
+    assert EXPECTED_SUBSTRING_1 in STRING
+    assert EXPECTED_SUBSTRING_2 in STRING
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, train, mocker):
+    from llama_recipes.finetuning import main
+
+    tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})
+
+    kwargs = {
+        "dataset": "custom_dataset",
+        "custom_dataset.file": "examples/custom_dataset.py:get_unknown_dataset",
+        "batch_size_training": 1,
+        "use_peft": False,
+        }
+    with pytest.raises(AttributeError):
+        main(**kwargs)

+ 37 - 0
tests/datasets/test_samsum_datasets.py

@@ -0,0 +1,37 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from unittest.mock import patch
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_custom_dataset(step_lr, optimizer, tokenizer, get_model, train, mocker):
+    from llama_recipes.finetuning import main
+        
+    tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})
+    
+    
+    kwargs = {
+        "batch_size_training": 1,
+        "use_peft": False,
+        "dataset": "samsum_dataset",
+        }
+    
+    main(**kwargs)
+    
+    assert train.call_count == 1
+    
+    args, kwargs = train.call_args
+    train_dataloader = args[1]
+    eval_dataloader = args[2]
+    
+    VAL_SAMPLES = 818
+    TRAIN_SAMPLES = 14732
+    CONCAT_SIZE = 2048
+    assert len(train_dataloader) == TRAIN_SAMPLES // CONCAT_SIZE
+    assert len(eval_dataloader) == VAL_SAMPLES
+    

+ 36 - 2
tests/test_finetuning.py

@@ -1,6 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from pytest import approx
 from unittest.mock import patch
-import importlib
 
 
+from torch.nn import Linear
+from torch.optim import AdamW
 from torch.utils.data.dataloader import DataLoader
 
 from llama_recipes.finetuning import main
@@ -69,4 +74,33 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
     main(**kwargs)
     
     assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
-    assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
+    assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
+    
+    
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.get_preprocessed_dataset')
+@patch('llama_recipes.finetuning.get_peft_model')
+@patch('llama_recipes.finetuning.StepLR')
+def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train, mocker):
+    kwargs = {"weight_decay": 0.01}
+    
+    get_dataset.return_value = [1]
+    
+    model = mocker.MagicMock(name="model")
+    model.parameters.return_value = Linear(1,1).parameters()
+    get_peft_model.return_value = model 
+    get_peft_model.return_value.print_trainable_parameters=lambda:None
+    main(**kwargs)
+    
+    assert train.call_count == 1
+    
+    args, kwargs = train.call_args
+    optimizer = args[4]
+    
+    print(optimizer.state_dict())
+    
+    assert isinstance(optimizer, AdamW)
+    assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01)
+    

+ 21 - 5
tests/test_train_utils.py

@@ -1,14 +1,22 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from unittest.mock import patch
+
 import torch
 
 from llama_recipes.utils.train_utils import train
 
-def test_gradient_accumulation(mocker):
-    # import sys
-    # sys.path.append('/home/ubuntu/llama-recipes/')
+@patch("llama_recipes.utils.train_utils.MemoryTrace")
+@patch("llama_recipes.utils.train_utils.nullcontext")
+@patch("llama_recipes.utils.train_utils.torch.cuda.amp.GradScaler")
+@patch("llama_recipes.utils.train_utils.torch.cuda.amp.autocast")
+def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker):
     
     model = mocker.MagicMock(name="model")
     model().loss.__truediv__().detach.return_value = torch.tensor(1)
-    batch = {"input": torch.zeros(1)}
+    mock_tensor = mocker.MagicMock(name="tensor")
+    batch = {"input": mock_tensor}
     train_dataloader = [batch, batch, batch, batch, batch]
     eval_dataloader = None
     tokenizer = mocker.MagicMock()
@@ -34,7 +42,13 @@ def test_gradient_accumulation(mocker):
     assert optimizer.zero_grad.call_count == 5
     optimizer.zero_grad.reset_mock()
     
+    assert nullcontext.call_count == 5
+    nullcontext.reset_mock()
+    
+    assert autocast.call_count == 0
+    
     gradient_accumulation_steps = 2
+    train_config.use_fp16 = True
     train(
         model,
         train_dataloader,
@@ -45,4 +59,6 @@ def test_gradient_accumulation(mocker):
         gradient_accumulation_steps,
         train_config,
     )
-    assert optimizer.zero_grad.call_count == 3
+    assert optimizer.zero_grad.call_count == 3
+    assert nullcontext.call_count == 0
+    assert autocast.call_count == 5