
Merge branch 'main' into l3p/llama_guard

Beto, 7 months ago
parent
commit
6f53f26e05

+ 2 - 1
recipes/README.md

@@ -2,7 +2,8 @@ This folder contains examples organized by topic:
 
 | Subfolder | Description |
 |---|---|
-[quickstart](./quickstart) | The "Hello World" of using Llama2, start here if you are new to using Llama2.
+[quickstart](./quickstart)|The "Hello World" of using Llama2, start here if you are new to using Llama2
+[multilingual](./multilingual)|Scripts to add a new language to Llama2
 [finetuning](./finetuning)|Scripts to finetune Llama2 on single-GPU and multi-GPU setups
 [inference](./inference)|Scripts to deploy Llama2 for inference locally and using model servers
 [use_cases](./use_cases)|Scripts showing common applications of Llama2

+ 21 - 5
recipes/finetuning/datasets/custom_dataset.py

@@ -11,11 +11,27 @@ import itertools
 B_INST, E_INST = "[INST]", "[/INST]"
 
 def tokenize_dialog(dialog, tokenizer):
-    prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST}", add_special_tokens=False) for prompt in dialog[::2]]
-    answer_tokens = [tokenizer.encode(f"{answer['content'].strip()} {tokenizer.eos_token}", add_special_tokens=False) for answer in dialog[1::2]]
-    dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens)))
-    #Add labels, convert prompt token to -100 in order to ignore in loss function
-    labels_tokens = [len(c)*[-100,] if i % 2 == 0 else c for i,c in enumerate(dialog_tokens)]
+    if tokenizer.vocab_size >= 128000:
+        dialog_tokens = tokenizer.apply_chat_template(dialog)
+        dialog_tokens = dialog_tokens[:-4] # Remove generation prompt <|start_header_id|>assistant<|end_header_id|>\n\n
+        eot_indices = [i for i,n in enumerate(dialog_tokens) if n == 128009]
+        labels = copy.copy(dialog_tokens)
+        last_idx = 0
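+        # Mask the user turns with -100 so that only the assistant responses contribute to the loss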
+        for n, idx in enumerate(eot_indices):
+            if n % 2 == 1:
+                last_idx = idx
+            else:
+                labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
+
+        dialog_tokens = [dialog_tokens]
+        labels_tokens = [labels]
+    else:
+        prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST}", add_special_tokens=False) for prompt in dialog[::2]]
+        answer_tokens = [tokenizer.encode(f"{answer['content'].strip()} {tokenizer.eos_token}", add_special_tokens=False) for answer in dialog[1::2]]
+        dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens)))
+
+        #Add labels, convert prompt token to -100 in order to ignore in loss function
+        labels_tokens = [len(c)*[-100,] if i % 2 == 0 else c for i,c in enumerate(dialog_tokens)]
 
     combined_tokens = {
         "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),

+ 6 - 6
recipes/inference/local_inference/chat_completion/chat_completion.py

@@ -8,9 +8,9 @@ import os
 import sys
 
 import torch
-from transformers import LlamaTokenizer
+from transformers import AutoTokenizer
 
-from llama_recipes.inference.chat_utils import read_dialogs_from_file, format_tokens
+from llama_recipes.inference.chat_utils import read_dialogs_from_file
 from llama_recipes.inference.model_utils import load_model, load_peft_model
 from llama_recipes.inference.safety_utils import get_safety_checker
 from accelerate.utils import is_xpu_available
@@ -65,15 +65,15 @@ def main(
     if peft_model:
         model = load_peft_model(model, peft_model)
 
-    tokenizer = LlamaTokenizer.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
     tokenizer.add_special_tokens(
         {
-         
+
             "pad_token": "<PAD>",
         }
     )
-    
-    chats = format_tokens(dialogs, tokenizer)
+
+    chats = tokenizer.apply_chat_template(dialogs)
 
     with torch.no_grad():
         for idx, chat in enumerate(chats):

+ 2 - 3
recipes/inference/local_inference/inference.py

@@ -10,7 +10,7 @@ import time
 import gradio as gr
 
 import torch
-from transformers import LlamaTokenizer, AutoTokenizer
+from transformers import AutoTokenizer
 
 from llama_recipes.inference.safety_utils import get_safety_checker, AgentType
 from llama_recipes.inference.model_utils import load_model, load_peft_model
@@ -69,13 +69,12 @@ def main(
     else:
         torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
-    
+
     model = load_model(model_name, quantization, use_fast_kernels)
     if peft_model:
         model = load_peft_model(model, peft_model)
 
     model.eval()
-    
 
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     tokenizer.pad_token = tokenizer.eos_token

+ 156 - 0
recipes/multilingual/README.md

@@ -0,0 +1,156 @@
+# Extending Llama to a new language
+Authored by: Sarvam team
+In this recipe, we will see how to add a new language to the Llama family of models. The steps are quite general and can be easily adapted to other models as well. Using this recipe, you should be able to replicate the findings of [OpenHathi](https://huggingface.co/sarvamai/OpenHathi-7B-Hi-v0.1-Base).
+Please read more about OpenHathi [here](https://www.sarvam.ai/blog/announcing-openhathi-series).
+## Data
+The original OpenHathi model uses a combination of [Sangraha](https://huggingface.co/datasets/ai4bharat/sangraha) and Wikipedia as its primary data sources. If the reader is interested in using these sources, they would also have to preprocess the data: clean, filter, and deduplicate. See [Setu](https://github.com/AI4Bharat/setu) for an easy way to do this at scale.
+
+In this tutorial, we will use the [Varta](https://huggingface.co/datasets/rahular/varta) dataset, which contains 40M+ news articles taken from [DailyHunt](https://m.dailyhunt.in/). Since this data is already high-quality, we can skip the pre-processing step mentioned above. We will use the Hindi subset here, but you can add any other language present in the dataset simply by passing the right language code (advanced users can also tweak the code to add multiple languages at once).
+
+## Tokenizer
+Our first step towards adding a new language to an LLM is creating a better tokenizer. We define 'better' in terms of fertility score, i.e., how many tokens the tokenizer needs to encode in-language text, which in turn depends on the number of in-language tokens present in its vocabulary. Note that we should add new tokens without disturbing the original vocabulary, and therefore creating a better tokenizer usually involves 2 steps: (i) building a new, in-language-only tokenizer, and (ii) merging this new tokenizer with the original.
+
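+A minimal, illustrative way to measure fertility (not part of the original recipe) is to count how many tokens a tokenizer produces per whitespace-separated word of in-language text; the lower the ratio, the better the tokenizer fits the language.
+
+```
+from transformers import LlamaTokenizer
+
+tokenizer = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf')
+text = "..."  # any held-out in-language paragraph
+# fertility = tokens per word; a tokenizer extended with in-language tokens should drive this down
+fertility = len(tokenizer.tokenize(text)) / len(text.split())
+print(f"fertility: {fertility:.2f}")
+```
+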
+### Building the in-language tokenizer
+For this, we will first download and prepare the data for training the tokenizer:
+
+```
+python prepare_data.py --split=validation --lang=hi --docs_to_sample=10000 --save_path=./data
+```
+
+Here we sample 10,000 Hindi documents from the validation split (we should ideally sample from the training split, but this is much faster) and save them as a text file inside `./data`. Next, we use this text to train a Hindi-only [sentencepiece](https://github.com/google/sentencepiece) tokenizer with a vocabulary size of 16,000.
+
+```
+python train_tokenizer.py --data_file=./data/hi.txt --save_path=./hi_tokenizer --vocab_size=16000
+```
+
+This creates a new sentencepiece Hindi tokenizer and saves it in `./hi_tokenizer`.
+
+### Merging the tokenizers
+This process can again be divided into 2 steps:
+- add new tokens to the original Llama2 tokenizer without disturbing its original vocabulary in any way
+- expand the input and output embedding matrices of Llama2 to be equal to the new vocabulary size (a sketch of this step is shown at the end of this section)
+
+We can do the first step by (i) downloading Llama2's `tokenizer.model` file, (ii) loading our Hindi `tokenizer.model` file, (iii) appending the Hindi tokens to the Llama2 tokenizer's vocabulary if they are not already present, and (iv) saving the extended tokenizer for future use. All this can be done by running:
+
+```
+python extend_tokenizer.py --new_tokenizer_path=./hi_tokenizer --extended_tokenizer_save_path=./extended_tokenizer
+```
+
+Now, you have a new Llama2 tokenizer which works the same way on English text but can efficiently tokenize Hindi words as well. You can also test to see if it works as intended:
+
+```
+>>> from transformers import LlamaTokenizer
+>>> llama_tokenizer = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf')
+>>> our_tokenizer = LlamaTokenizer.from_pretrained('./extended_tokenizer')
+>>> for i in range(len(llama_tokenizer)):
+...     assert llama_tokenizer.convert_ids_to_tokens(i) == our_tokenizer.convert_ids_to_tokens(i), f"Token mismatch at index {i}."
+...
+>>> text = "मैं एक अच्छा हाथी हूँ"
+>>> llama_tokenizer.tokenize(text)
+['▁', 'म', 'ै', 'ं', '▁', '<0xE0>', '<0xA4>', '<0x8F>', 'क', '▁', 'अ', 'च', '्', '<0xE0>', '<0xA4>', '<0x9B>', 'ा', '▁', 'ह', 'ा', 'थ', 'ी', '▁', 'ह', 'ू', '<0xE0>', '<0xA4>', '<0x81>']
+>>> our_tokenizer.tokenize(text)
+['▁मैं', '▁एक', '▁अच', '्', 'छा', '▁हाथी', '▁हूँ']
+```
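+
+The second step listed above, expanding Llama2's input and output embedding matrices to match the new vocabulary size, is handled at finetuning time: the finetuning script in this commit resizes the embeddings whenever the tokenizer is larger than the model's vocabulary. As a rough standalone sketch (assuming the extended tokenizer saved above):
+
+```
+from transformers import AutoTokenizer, LlamaForCausalLM
+
+tokenizer = AutoTokenizer.from_pretrained("./extended_tokenizer")
+model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
+# Grow the input and output embeddings to cover the newly added tokens
+if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
+    model.resize_token_embeddings(len(tokenizer))
+```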
+
+## Continual pre-training
+OpenHathi uses a two-stage pre-training process:
+- Phase 1: learn to translate paragraphs of text (use translated text as context and generate the original text, ~15B tokens)
+- Phase 2: bilingual next token prediction (train on text where the language changes after every sentence, ~15B tokens)
+
+Note: OpenHathi's final data mixture also contains monolingual data and romanized transliterations.
+
+We can easily create data for both phases using any translation model. OpenHathi uses [IndicTrans2](https://github.com/AI4Bharat/IndicTrans2). We provide sample code for both phases below.
+
+### Phase 1
+Assuming we don't have source-native data, let us first get some English data to translate.
+
+```
+from datasets import load_dataset
+ds = load_dataset("rahular/varta", split="train", streaming=True)
+english_paragraphs = []
+for d in ds:
+    if d["langCode"] != "en": continue
+    english_paragraphs.append(" ".join(d["text"].split("\n")))
+```
+
+Now, our goal is to create data in the format `{translated_paragraph}\n\n{english_paragraph}`. We can use the `translate_paragraph` function ([link](https://github.com/AI4Bharat/IndicTrans2/blob/main/huggingface_interface/example.py#L150)) from the IndicTrans2 codebase to do this easily.
+
+```
+quantization = ""
+en_indic_ckpt_dir = "ai4bharat/indictrans2-en-indic-1B"
+en_indic_tokenizer, en_indic_model = initialize_model_and_tokenizer(en_indic_ckpt_dir, "en-indic", quantization)
+ip = IndicProcessor(inference=True)
+
+phase1_data = []
+for para in english_paragraphs:
+    trans_para = translate_paragraph(para, "eng_Latn", "hin_Deva", en_indic_model, en_indic_tokenizer, ip)
+    phase1_data.append({"text": f"{trans_para}\n\n{para}"})
+
+# if you want to save it for future use, you can do so easily with HF datasets
+from datasets import Dataset
+phase1_ds = Dataset.from_list(phase1_data)
+phase1_ds.save_to_disk("data/phase1")
+```
+
+### Phase 2
+This is almost the same as phase 1, except that we have to replace the original sentences in an alternating manner to get the data in the required format. We can use the `split_sentences` ([link](https://github.com/AI4Bharat/IndicTrans2/blob/main/huggingface_interface/example.py#L60)) and `batch_translate` ([link](https://github.com/AI4Bharat/IndicTrans2/blob/main/huggingface_interface/example.py#L109)) functions to do this.
+
+```
+quantization = ""
+en_indic_ckpt_dir = "ai4bharat/indictrans2-en-indic-1B"
+en_indic_tokenizer, en_indic_model = initialize_model_and_tokenizer(en_indic_ckpt_dir, "en-indic", quantization)
+ip = IndicProcessor(inference=True)
+
+phase2_data = []
+for para in english_paragraphs:
+    en_sents = split_sentences(para, "eng_Latn")
+    trans_sents = batch_translate(en_sents, "eng_Latn", "hin_Deva", en_indic_model, en_indic_tokenizer, ip)
+    final_para = []
+    for idx, (en_sent, trans_sent) in enumerate(zip(en_sents, trans_sents)):
+        sent_to_append = en_sent if idx % 2 == 0 else trans_sent
+        final_para.append(sent_to_append)
+    phase2_data.append({"text": " ".join(final_para)})
+
+# if you want to save it for future use, you can do so easily with HF datasets
+from datasets import Dataset
+phase2_ds = Dataset.from_list(phase2_data)
+phase2_ds.save_to_disk("data/phase2")
+```
+
+### Train
+Finally, we can start finetuning Llama2 on these datasets by following the [finetuning recipes](https://github.com/meta-llama/llama-recipes/tree/main/recipes/finetuning). Remember to pass the new tokenizer path as an argument to the script: `--tokenizer_name=./extended_tokenizer`.
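+
+For illustration only (the dataset loader path and most flags here are placeholders, not the exact OpenHathi setup), a single-node PEFT run could look like:
+
+```
+python -m llama_recipes.finetuning \
+    --model_name meta-llama/Llama-2-7b-hf \
+    --tokenizer_name ./extended_tokenizer \
+    --use_peft --peft_method lora \
+    --dataset custom_dataset \
+    --custom_dataset.file ./load_phase1_dataset.py \
+    --output_dir ./openhathi-phase1
+```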
+
+OpenHathi was trained on 64 A100 80GB GPUs. Here are the hyperparameters used and other training details (a rough `peft` equivalent of the LoRA settings is sketched after the list):
+- maximum learning rate: 2e-4
+- minimum learning rate: 2e-6
+- optimizer: AdamW (weight decay = 0.1)
+- beta1: 0.9
+- beta2: 0.95
+- lora rank: 128
+- lora alpha: 64
+- lora trainable: q_proj, v_proj, k_proj, o_proj, gate_proj, down_proj, up_proj
+- lora dropout: 0.05
+- block size: 4096
+- global batch size: 4M tokens
+- input and output embeddings are trainable
+- lr schedule: cosine decay with warmup (warmup ratio = 0.1, number of cycles = 3)
+- deepspeed stage 2
+- dtype: bfloat16
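+
+In `peft` terms, the LoRA settings listed above correspond roughly to the following configuration (a sketch, not the exact training code used for OpenHathi):
+
+```
+from peft import LoraConfig
+
+lora_config = LoraConfig(
+    r=128,
+    lora_alpha=64,
+    target_modules=["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "down_proj", "up_proj"],
+    lora_dropout=0.05,
+    task_type="CAUSAL_LM",
+    # the input and output embeddings are kept trainable, as noted in the list above
+    modules_to_save=["embed_tokens", "lm_head"],
+)
+```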
+
+The resulting (partial) loss plots from the OpenHathi training are shown below:
+
+Phase 1: train loss
+
+![Phase 1: train loss](imgs/phase1-train-loss.png)
+
+Phase 1: eval loss
+
+![Phase 1: eval loss](imgs/phase1-eval-loss.png)
+
+Phase 2: train loss
+
+![Phase 2: train loss](imgs/phase2-train-loss.png)
+
+Phase 2: eval loss
+
+![Phase 2: eval loss](imgs/phase2-eval-loss.png)

+ 52 - 0
recipes/multilingual/extend_tokenizer.py

@@ -0,0 +1,52 @@
+"""
+Code borrowed from https://github.com/ymcui/Chinese-LLaMA-Alpaca/blob/main/scripts/merge_tokenizer/merge_tokenizers.py
+"""
+
+import os
+import fire
+import re
+from transformers import LlamaTokenizer
+
+os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
+from huggingface_hub import hf_hub_download
+from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
+
+
+def main(new_tokenizer_path, extended_tokenizer_save_path):
+    original_tokenizer_path = hf_hub_download(repo_id="meta-llama/Llama-2-7b-chat-hf", filename="tokenizer.model", local_dir="original_tokenizer")
+    original_tokenizer_spm = sp_pb2_model.ModelProto()
+    original_tokenizer_spm.ParseFromString(open(original_tokenizer_path, "rb").read())
+    new_tokenizer_spm = sp_pb2_model.ModelProto()
+    new_tokenizer_spm.ParseFromString(open(os.path.join(new_tokenizer_path, "tokenizer.model"), "rb").read())
+
+    def contains_eng(text):
+        eng_pattern = re.compile(r"[\u0020-\u007E]+")
+        return True if eng_pattern.search(text) else False
+
+    original_tokenizer_tokenset = set(p.piece for p in original_tokenizer_spm.pieces)
+    print(f"Number of tokens before merge: {len(original_tokenizer_tokenset)}")
+    for p in new_tokenizer_spm.pieces:
+        piece = p.piece
+        if piece not in original_tokenizer_tokenset and not contains_eng(piece):
+            new_p = sp_pb2_model.ModelProto().SentencePiece()
+            new_p.piece = piece
+            new_p.score = 0
+            original_tokenizer_spm.pieces.append(new_p)
+    print(f"Number of tokens after merge: {len(original_tokenizer_spm.pieces)}")
+
+    os.makedirs(extended_tokenizer_save_path, exist_ok=True)
+    with open(os.path.join(extended_tokenizer_save_path, "tokenizer.model"), "wb") as f:
+        f.write(original_tokenizer_spm.SerializeToString())
+    tokenizer = LlamaTokenizer(vocab_file=os.path.join(extended_tokenizer_save_path, "tokenizer.model"), legacy=False)
+    tokenizer.save_pretrained(extended_tokenizer_save_path)
+    print(f"Tokenizer saved to {extended_tokenizer_save_path}")
+
+    # Verify that the extended tokenizer's English vocab matches with that of the original Llama tokenizer
+    tok1 = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf')
+    tok2 = LlamaTokenizer.from_pretrained(extended_tokenizer_save_path)
+    for i in range(len(tok1)):
+        assert tok1.convert_ids_to_tokens(i) == tok2.convert_ids_to_tokens(i), f"Token mismatch at index {i}."
+
+
+if __name__ == "__main__":
+    fire.Fire(main)

BIN=BIN
recipes/multilingual/imgs/phase1-eval-loss.png


BIN=BIN
recipes/multilingual/imgs/phase1-train-loss.png


BIN=BIN
recipes/multilingual/imgs/phase2-eval-loss.png


BIN=BIN
recipes/multilingual/imgs/phase2-train-loss.png


+ 23 - 0
recipes/multilingual/prepare_data.py

@@ -0,0 +1,23 @@
+import fire
+import os
+from datasets import load_dataset
+
+DATASET = "rahular/varta"
+
+def main(split="validation", lang="hi", docs_to_sample=10_000, save_path="data"):
+    dataset = load_dataset(DATASET, split=split, streaming=True)
+    os.makedirs(save_path, exist_ok=True)
+    with open(os.path.join(save_path, f"{lang}.txt"), "w") as f:
+        count = 0
+        for idx, d in enumerate(dataset):
+            if idx % 10_000 == 0:
+                print(f"Searched {idx} documents for {lang} documents. Found {count} documents.")
+            if count >= docs_to_sample:
+                break
+            if d["langCode"] == lang:
+                f.write(d["headline"] + "\n" + d["text"] + "\n")
+                count += 1
+
+
+if __name__ == "__main__":
+    fire.Fire(main)

+ 22 - 0
recipes/multilingual/train_tokenizer.py

@@ -0,0 +1,22 @@
+import fire
+import os
+import sentencepiece as spm
+
+def main(data_file, save_path, vocab_size=16_000, num_threads=8):
+    os.makedirs(save_path, exist_ok=True)
+    tokenizer_name = os.path.join(save_path, "tokenizer")
+    
+    spm.SentencePieceTrainer.train(
+        input=data_file,
+        model_prefix=tokenizer_name,
+        vocab_size=vocab_size,
+        num_threads=num_threads,
+        model_type="bpe",
+        max_sentence_length=1073741824,
+        shuffle_input_sentence="true",
+        character_coverage=1.0,
+        hard_vocab_limit="false",
+    )
+
+if __name__ == "__main__":
+    fire.Fire(main)

File diff suppressed because it is too large
+ 196 - 0
recipes/responsible_ai/CodeShieldUsageDemo.ipynb


+ 1 - 1
requirements.txt

@@ -1,4 +1,4 @@
-torch>=2.0.1
+torch>=2.2
 accelerate
 appdirs
 loralib

+ 20 - 0
scripts/spellcheck_conf/wordlist.txt

@@ -1269,3 +1269,23 @@ Jfleg
 nnodes
 patht
 sbatch
+DailyHunt
+IndicTrans
+OpenHathi
+OpenHathi's
+Sangraha
+Sarvam
+Setu
+Varta
+bfloat
+codebase
+deduplicate
+dtype
+imgs
+lr
+proj
+romanized
+tokenize
+tokenizer's
+tokenizers
+warmup

+ 1 - 0
src/llama_recipes/configs/training.py

@@ -7,6 +7,7 @@ from dataclasses import dataclass
 @dataclass
 class train_config:
     model_name: str="PATH/to/LLAMA/7B"
+    tokenizer_name: str=None
     enable_fsdp: bool=False
     low_cpu_fsdp: bool=False
     run_validation: bool=True

+ 3 - 4
src/llama_recipes/data/sampler.py

@@ -20,7 +20,7 @@ class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
         self.shuffle = shuffle
 
     def __iter__(self):
-        ids = np.argsort(self.lengths)
+        ids = np.argsort(self.lengths, kind='mergesort')
         if self.drop_last:
             ids = ids[:len(ids) // self.batch_size * self.batch_size]
 
@@ -47,11 +47,10 @@ class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler):
             )
         self.num_replicas = num_replicas
         self.rank = rank
-        
+
     def __iter__(self):
         max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas
         return islice(self.batch_sampler, self.rank, max_length, self.num_replicas)
-         
+
     def __len__(self):
         return len(self.batch_sampler) // self.num_replicas
-            

+ 13 - 13
src/llama_recipes/finetuning.py

@@ -2,7 +2,6 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import os
-from pkg_resources import packaging
 
 import dataclasses
 import fire
@@ -18,8 +17,8 @@ from torch.distributed.fsdp import (
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
 from transformers import (
+    AutoTokenizer,
     LlamaForCausalLM,
-    LlamaTokenizer,
     LlamaConfig,
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
@@ -51,7 +50,7 @@ from llama_recipes.utils.train_utils import (
 from accelerate.utils import is_xpu_available
 
 def setup_wandb(train_config, fsdp_config, **kwargs):
-    try: 
+    try:
         import wandb
     except ImportError:
         raise ImportError(
@@ -97,7 +96,7 @@ def main(**kwargs):
 
     if train_config.use_wandb:
         if not train_config.enable_fsdp or rank==0:
-            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)    
+            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
 
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
@@ -108,11 +107,6 @@ def main(**kwargs):
         model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
         overhead and currently requires latest nightly.
         """
-        v = packaging.version.parse(torch.__version__)
-        verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
-        if not verify_latest_nightly:
-            raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
-                            "please install latest nightly.")
         if rank == 0:
             model = LlamaForCausalLM.from_pretrained(
                 train_config.model_name,
@@ -137,9 +131,15 @@ def main(**kwargs):
         )
 
     # Load the tokenizer and add special tokens
-    tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
+    tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
     tokenizer.pad_token_id = tokenizer.eos_token_id
 
+    # If there is a mismatch between tokenizer vocab size and embedding matrix, 
+    # throw a warning and then expand the embedding matrix
+    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
+        print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")
+        model.resize_token_embeddings(len(tokenizer))
+
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
 
     # Prepare the model for int8 training if quantization is enabled
@@ -157,12 +157,12 @@ def main(**kwargs):
         if wandb_run:
             wandb_run.config.update(peft_config)
 
-        
+
     hsdp_device_mesh = None
     if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD:
         hsdp_device_mesh = hsdp_device_mesh(replica_group_size=fsdp_config.replica_group_size, sharding_group_size=fsdp_config.sharding_group_size)
         print("HSDP device mesh is ready")
-        
+
     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
         if not train_config.use_peft and train_config.freeze_layers:
@@ -171,7 +171,7 @@ def main(**kwargs):
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
-        
+
         device_id = 0
         if is_xpu_available():
             device_id = torch.xpu.current_device()

+ 0 - 56
src/llama_recipes/inference/chat_utils.py

@@ -2,62 +2,6 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import json
-from typing import List, Literal, TypedDict
-
-
-Role = Literal["user", "assistant"]
-
-
-class Message(TypedDict):
-    role: Role
-    content: str
-
-
-Dialog = List[Message]
-
-B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
-def format_tokens(dialogs, tokenizer):
-    prompt_tokens = []
-    for dialog in dialogs:
-        if dialog[0]["role"] == "system":
-            dialog = [
-            {
-                "role": dialog[1]["role"],
-                "content": B_SYS
-                + dialog[0]["content"]
-                + E_SYS
-                + dialog[1]["content"],
-            }
-        ] + dialog[2:]
-        assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
-            [msg["role"] == "assistant" for msg in dialog[1::2]]
-        ), (
-            "model only supports 'system','user' and 'assistant' roles, "
-            "starting with user and alternating (u/a/u/a/u...)"
-        )
-        """
-        Please verify that your tokenizer support adding "[INST]", "[/INST]" to your inputs.
-        Here, we are adding it manually.
-        """
-        dialog_tokens: List[int] = sum(
-            [
-                tokenizer.encode(
-                    f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
-                ) + [tokenizer.eos_token_id]
-                for prompt, answer in zip(dialog[::2], dialog[1::2])
-            ],
-            [],
-        )
-        assert (
-            dialog[-1]["role"] == "user"
-        ), f"Last message must be from user, got {dialog[-1]['role']}"
-        dialog_tokens += tokenizer.encode(
-            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
-        )
-        prompt_tokens.append(dialog_tokens)
-    return prompt_tokens
-        
 
 def read_dialogs_from_file(file_path):
     with open(file_path, 'r') as file:

+ 1 - 0
src/llama_recipes/utils/config_utils.py

@@ -90,6 +90,7 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
                 rank=dist.get_rank(),
                 num_replicas=dist.get_world_size(),
                 shuffle=mode=="train",
+                drop_last=True,
             )
             kwargs["batch_size"] = batch_size
             kwargs["drop_last"] = True

+ 15 - 10
tests/conftest.py

@@ -3,21 +3,26 @@
 
 import pytest
 
-from transformers import LlamaTokenizer
+from transformers import AutoTokenizer
 
 ACCESS_ERROR_MSG = "Could not access tokenizer at 'meta-llama/Llama-2-7b-hf'. Did you log into huggingface hub and provided the correct token?"
+LLAMA_VERSIONS = ["meta-llama/Llama-2-7b-hf", "meta-llama/Llama-3-8b-hf"]
+
+@pytest.fixture(params=LLAMA_VERSIONS)
+def llama_version(request):
+    return request.param
 
 
 @pytest.fixture(scope="module")
-def llama_tokenizer():
-    return LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+def llama_tokenizer(request):
+    return {k: AutoTokenizer.from_pretrained(k) for k in LLAMA_VERSIONS}
 
 
 @pytest.fixture
-def setup_tokenizer(llama_tokenizer):
+def setup_tokenizer(llama_tokenizer, llama_version):
     def _helper(tokenizer_mock):
         #Align with Llama 2 tokenizer
-        tokenizer_mock.from_pretrained.return_value = llama_tokenizer
+        tokenizer_mock.from_pretrained.return_value = llama_tokenizer[llama_version]
 
     return _helper
 
@@ -27,21 +32,21 @@ def pytest_addoption(parser):
         "--unskip-missing-tokenizer",
         action="store_true",
         default=False, help="disable skip missing tokenizer")
-    
+
 def pytest_configure(config):
     config.addinivalue_line("markers", "skip_missing_tokenizer: skip if tokenizer is unavailable")
 
-    
+
 def pytest_collection_modifyitems(config, items):
     if config.getoption("--unskip-missing-tokenizer"):
         return
-    
+
     try:
-        LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+        AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
         tokenizer_available = True
     except OSError:
         tokenizer_available = False
-    
+
     skip_missing_tokenizer = pytest.mark.skip(reason=ACCESS_ERROR_MSG)
     for item in items:
         if "skip_missing_tokenizer" in item.keywords and not tokenizer_available:

+ 35 - 20
tests/datasets/test_custom_dataset.py

@@ -6,32 +6,50 @@ from unittest.mock import patch
 
 from transformers import LlamaTokenizer
 
-def check_padded_entry(batch):
+EXPECTED_RESULTS={
+    "meta-llama/Llama-2-7b-hf":{
+        "example_1": "[INST] Who made Berlin [/INST] dunno",
+        "example_2": "[INST] Quiero preparar una pizza de pepperoni, puedes darme los pasos para hacerla? [/INST] Claro!",
+    },
+    "meta-llama/Llama-3-8b-hf":{
+        "example_1": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nWho made Berlin<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\ndunno<|eot_id|><|end_of_text|>",
+        "example_2": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow to start learning guitar and become a master at it?",
+    },
+}
+
+def check_padded_entry(batch, tokenizer):
     seq_len = sum(batch["attention_mask"][0])
     assert seq_len < len(batch["attention_mask"][0])
 
+    if tokenizer.vocab_size >= 128000:
+        END_OF_TEXT_ID = 128009
+    else:
+        END_OF_TEXT_ID = tokenizer.eos_token_id
+
     assert batch["labels"][0][0] == -100
-    assert batch["labels"][0][seq_len-1] == 2
+    assert batch["labels"][0][seq_len-1] == END_OF_TEXT_ID
     assert batch["labels"][0][-1] == -100
-    assert batch["input_ids"][0][0] == 1
-    assert batch["input_ids"][0][-1] == 2
+    assert batch["input_ids"][0][0] == tokenizer.bos_token_id
+    assert batch["input_ids"][0][-1] == tokenizer.eos_token_id
 
 
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer, llama_version):
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
 
+    skip_special_tokens = llama_version == "meta-llama/Llama-2-7b-hf"
+
     kwargs = {
         "dataset": "custom_dataset",
-        "model_name": "meta-llama/Llama-2-7b-hf",
-        "custom_dataset.file": "examples/custom_dataset.py",
+        "model_name": llama_version,
+        "custom_dataset.file": "recipes/finetuning/datasets/custom_dataset.py",
         "custom_dataset.train_split": "validation",
         "batch_size_training": 2,
         "val_batch_size": 4,
@@ -53,34 +71,31 @@ def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
 
     it = iter(eval_dataloader)
     batch = next(it)
-    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
-    EXPECTED_STRING = "[INST] Who made Berlin [/INST] dunno"
-    assert STRING.startswith(EXPECTED_STRING)
+    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=skip_special_tokens)
+    assert STRING.startswith(EXPECTED_RESULTS[llama_version]["example_1"])
 
     assert batch["input_ids"].size(0) == 4
     assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())
 
-    check_padded_entry(batch)
+    check_padded_entry(batch, tokenizer)
 
     it = iter(train_dataloader)
-    for _ in range(5):
-        next(it)
+    next(it)
 
     batch = next(it)
-    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
-    EXPECTED_STRING = "[INST] How do I initialize a Typescript project using npm and git? [/INST] # Initialize a new NPM project"
-    assert STRING.startswith(EXPECTED_STRING)
+    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=skip_special_tokens)
+    assert STRING.startswith(EXPECTED_RESULTS[llama_version]["example_2"])
 
     assert batch["input_ids"].size(0) == 2
     assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())
 
-    check_padded_entry(batch)
+    check_padded_entry(batch, tokenizer)
 
 
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, train, mocker):
@@ -90,7 +105,7 @@ def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, train,
 
     kwargs = {
         "dataset": "custom_dataset",
-        "custom_dataset.file": "examples/custom_dataset.py:get_unknown_dataset",
+        "custom_dataset.file": "recipes/finetuning/datasets/custom_dataset.py:get_unknown_dataset",
         "batch_size_training": 1,
         "use_peft": False,
         }

+ 19 - 9
tests/datasets/test_grammar_datasets.py

@@ -4,23 +4,32 @@
 import pytest
 from unittest.mock import patch
 
-from transformers import LlamaTokenizer
 
+EXPECTED_RESULTS = {
+    "meta-llama/Llama-2-7b-hf":{
+        "label": 1152,
+        "pos": 31,
+    },
+    "meta-llama/Llama-3-8b-hf":{
+        "label": 40,
+        "pos": 26,
+    },
+}
 
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version):
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
 
     BATCH_SIZE = 8
     kwargs = {
-        "model_name": "meta-llama/Llama-2-7b-hf",
+        "model_name": llama_version,
         "batch_size_training": BATCH_SIZE,
         "val_batch_size": 1,
         "use_peft": False,
@@ -48,9 +57,10 @@ def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker
     assert "input_ids" in batch.keys()
     assert "attention_mask" in batch.keys()
 
-    assert batch["labels"][0][31] == -100
-    assert batch["labels"][0][32] == 1152
+    assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]-1] == -100
+    assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]] == EXPECTED_RESULTS[llama_version]["label"]
 
-    assert batch["input_ids"][0][0] == 1
-    assert batch["labels"][0][-1] == 2
-    assert batch["input_ids"][0][-1] == 2
+    token = args[3]
+    assert batch["input_ids"][0][0] == token.bos_token_id
+    assert batch["labels"][0][-1] == token.eos_token_id
+    assert batch["input_ids"][0][-1] == token.eos_token_id

+ 19 - 8
tests/datasets/test_samsum_datasets.py

@@ -5,21 +5,31 @@ import pytest
 from functools import partial
 from unittest.mock import patch
 
+EXPECTED_RESULTS = {
+    "meta-llama/Llama-2-7b-hf":{
+        "label": 8432,
+        "pos": 242,
+    },
+    "meta-llama/Llama-3-8b-hf":{
+        "label": 2250,
+        "pos": 211,
+    },
+}
 
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer, llama_version):
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
 
     BATCH_SIZE = 8
     kwargs = {
-        "model_name": "meta-llama/Llama-2-7b-hf",
+        "model_name": llama_version,
         "batch_size_training": BATCH_SIZE,
         "val_batch_size": 1,
         "use_peft": False,
@@ -34,6 +44,7 @@ def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
     args, kwargs = train.call_args
     train_dataloader = args[1]
     eval_dataloader = args[2]
+    token = args[3]
 
     VAL_SAMPLES = 818
     TRAIN_SAMPLES = 14732
@@ -47,9 +58,9 @@ def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
     assert "input_ids" in batch.keys()
     assert "attention_mask" in batch.keys()
 
-    assert batch["labels"][0][268] == -100
-    assert batch["labels"][0][269] == 319
+    assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]-1] == -100
+    assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]] == EXPECTED_RESULTS[llama_version]["label"]
 
-    assert batch["input_ids"][0][0] == 1
-    assert batch["labels"][0][-1] == 2
-    assert batch["input_ids"][0][-1] == 2
+    assert batch["input_ids"][0][0] == token.bos_token_id
+    assert batch["labels"][0][-1] == token.eos_token_id
+    assert batch["input_ids"][0][-1] == token.eos_token_id

+ 21 - 11
tests/test_batching.py

@@ -4,20 +4,30 @@
 import pytest
 from unittest.mock import patch
 
+EXPECTED_SAMPLE_NUMBER ={
+    "meta-llama/Llama-2-7b-hf": {
+        "train": 96,
+        "eval": 42,
+    },
+    "meta-llama/Llama-3-8b-hf": {
+        "train": 79,
+        "eval": 34,
+    }
+}
 
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+def test_packing(step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version):
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
 
     kwargs = {
-        "model_name": "meta-llama/Llama-2-7b-hf",
+        "model_name": llama_version,
         "batch_size_training": 8,
         "val_batch_size": 1,
         "use_peft": False,
@@ -33,8 +43,8 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_
     train_dataloader = args[1]
     eval_dataloader = args[2]
 
-    assert len(train_dataloader) == 96
-    assert len(eval_dataloader) == 42
+    assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"]
+    assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"]
 
     batch = next(iter(train_dataloader))
 
@@ -49,7 +59,7 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_
 
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
@@ -57,13 +67,13 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_
 @patch('llama_recipes.finetuning.FSDP')
 @patch('llama_recipes.finetuning.torch.distributed.is_initialized')
 @patch('llama_recipes.utils.config_utils.dist')
-def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer):
+def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version):
     import os
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
 
-    rank = 0
+    rank = 1
     os.environ['LOCAL_RANK'] = f'{rank}'
     os.environ['RANK'] = f'{rank}'
     os.environ['WORLD_SIZE'] = '2'
@@ -71,7 +81,7 @@ def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimiz
     os.environ['MASTER_PORT'] = '12345'
 
     kwargs = {
-        "model_name": "meta-llama/Llama-2-7b-hf",
+        "model_name": llama_version,
         "batch_size_training": 8,
         "val_batch_size": 1,
         "use_peft": False,
@@ -92,5 +102,5 @@ def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimiz
     train_dataloader = args[1]
     eval_dataloader = args[2]
 
-    assert len(train_dataloader) == 96 //2
-    assert len(eval_dataloader) == 42 //2
+    assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"] //2
+    assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"] //2

+ 155 - 0
tests/test_chat_completion.py

@@ -0,0 +1,155 @@
+import sys
+from pathlib import Path
+from typing import List, Literal, TypedDict
+from unittest.mock import patch
+
+import pytest
+import torch
+from llama_recipes.inference.chat_utils import read_dialogs_from_file
+
+ROOT_DIR = Path(__file__).parents[1]
+CHAT_COMPLETION_DIR = ROOT_DIR / "recipes/inference/local_inference/chat_completion/"
+
+sys.path = [CHAT_COMPLETION_DIR.as_posix()] + sys.path
+
+Role = Literal["user", "assistant"]
+
+
+class Message(TypedDict):
+    role: Role
+    content: str
+
+
+Dialog = List[Message]
+
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+
+def _encode_header(message, tokenizer):
+    tokens = []
+    tokens.extend(tokenizer.encode("<|start_header_id|>"))
+    tokens.extend(tokenizer.encode(message["role"]))
+    tokens.extend(tokenizer.encode("<|end_header_id|>"))
+    tokens.extend(tokenizer.encode("\n\n"))
+    return tokens
+
+
+def _encode_message(message, tokenizer):
+    tokens = _encode_header(message, tokenizer)
+    tokens.extend(tokenizer.encode(message["content"].strip()))
+    tokens.extend(tokenizer.encode("<|eot_id|>"))
+    return tokens
+
+
+def _format_dialog(dialog, tokenizer):
+    tokens = []
+    tokens.extend(tokenizer.encode("<|begin_of_text|>"))
+    for msg in dialog:
+        tokens.extend(_encode_message(msg, tokenizer))
+    tokens.extend(_encode_header({"role": "assistant", "content": ""}, tokenizer))
+    return tokens
+
+
+def _format_tokens_llama3(dialogs, tokenizer):
+    return [_format_dialog(dialog, tokenizer) for dialog in dialogs]
+
+
+def _format_tokens_llama2(dialogs, tokenizer):
+    prompt_tokens = []
+    for dialog in dialogs:
+        if dialog[0]["role"] == "system":
+            dialog = [
+                {
+                    "role": dialog[1]["role"],
+                    "content": B_SYS
+                    + dialog[0]["content"]
+                    + E_SYS
+                    + dialog[1]["content"],
+                }
+            ] + dialog[2:]
+        assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
+            [msg["role"] == "assistant" for msg in dialog[1::2]]
+        ), (
+            "model only supports 'system','user' and 'assistant' roles, "
+            "starting with user and alternating (u/a/u/a/u...)"
+        )
+        """
+        Please verify that your tokenizer support adding "[INST]", "[/INST]" to your inputs.
+        Here, we are adding it manually.
+        """
+        dialog_tokens: List[int] = sum(
+            [
+                tokenizer.encode(
+                    f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
+                )
+                + [tokenizer.eos_token_id]
+                for prompt, answer in zip(dialog[::2], dialog[1::2])
+            ],
+            [],
+        )
+        assert (
+            dialog[-1]["role"] == "user"
+        ), f"Last message must be from user, got {dialog[-1]['role']}"
+        dialog_tokens += tokenizer.encode(
+            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
+        )
+        prompt_tokens.append(dialog_tokens)
+    return prompt_tokens
+
+
+@pytest.mark.skip_missing_tokenizer
+@patch("chat_completion.AutoTokenizer")
+@patch("chat_completion.load_model")
+def test_chat_completion(
+    load_model, tokenizer, setup_tokenizer, llama_tokenizer, llama_version
+):
+    from chat_completion import main
+
+    setup_tokenizer(tokenizer)
+
+    kwargs = {
+        "prompt_file": (CHAT_COMPLETION_DIR / "chats.json").as_posix(),
+    }
+
+    main(llama_version, **kwargs)
+
+    dialogs = read_dialogs_from_file(kwargs["prompt_file"])
+    format_tokens = (
+        _format_tokens_llama2
+        if llama_version == "meta-llama/Llama-2-7b-hf"
+        else _format_tokens_llama3
+    )
+
+    REF_RESULT = format_tokens(dialogs, llama_tokenizer[llama_version])
+
+    assert all(
+        (
+            load_model.return_value.generate.mock_calls[0 * 4][2]["input_ids"].cpu()
+            == torch.tensor(REF_RESULT[0]).long()
+        ).tolist()
+    )
+    assert all(
+        (
+            load_model.return_value.generate.mock_calls[1 * 4][2]["input_ids"].cpu()
+            == torch.tensor(REF_RESULT[1]).long()
+        ).tolist()
+    )
+    assert all(
+        (
+            load_model.return_value.generate.mock_calls[2 * 4][2]["input_ids"].cpu()
+            == torch.tensor(REF_RESULT[2]).long()
+        ).tolist()
+    )
+    assert all(
+        (
+            load_model.return_value.generate.mock_calls[3 * 4][2]["input_ids"].cpu()
+            == torch.tensor(REF_RESULT[3]).long()
+        ).tolist()
+    )
+    assert all(
+        (
+            load_model.return_value.generate.mock_calls[4 * 4][2]["input_ids"].cpu()
+            == torch.tensor(REF_RESULT[4]).long()
+        ).tolist()
+    )

+ 26 - 19
tests/test_finetuning.py

@@ -21,17 +21,19 @@ def get_fake_dataset():
         "labels":[1],
         }]
 
-
+@patch('llama_recipes.finetuning.torch.cuda.is_available')
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
+@pytest.mark.parametrize("cuda_is_available", [True, False])
+def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
     kwargs = {"run_validation": False}
 
     get_dataset.return_value = get_fake_dataset()
+    cuda.return_value = cuda_is_available
 
     main(**kwargs)
 
@@ -44,23 +46,26 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge
     assert isinstance(train_dataloader, DataLoader)
     assert eval_dataloader is None
 
-    if torch.cuda.is_available():
+    if cuda_is_available:
         assert get_model.return_value.to.call_count == 1
         assert get_model.return_value.to.call_args.args[0] == "cuda"
     else:
         assert get_model.return_value.to.call_count == 0
 
 
+@patch('llama_recipes.finetuning.torch.cuda.is_available')
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
+@pytest.mark.parametrize("cuda_is_available", [True, False])
+def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
     kwargs = {"run_validation": True}
 
     get_dataset.return_value = get_fake_dataset()
+    cuda.return_value = cuda_is_available
 
     main(**kwargs)
 
@@ -72,40 +77,42 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer,
     assert isinstance(train_dataloader, DataLoader)
     assert isinstance(eval_dataloader, DataLoader)
 
-    if torch.cuda.is_available():
+    if cuda_is_available:
         assert get_model.return_value.to.call_count == 1
         assert get_model.return_value.to.call_args.args[0] == "cuda"
     else:
         assert get_model.return_value.to.call_count == 0
 
-
+@patch('llama_recipes.finetuning.torch.cuda.is_available')
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.generate_peft_config')
 @patch('llama_recipes.finetuning.get_peft_model')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train):
+@pytest.mark.parametrize("cuda_is_available", [True, False])
+def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
     kwargs = {"use_peft": True}
 
     get_dataset.return_value = get_fake_dataset()
+    cuda.return_value = cuda_is_available
 
     main(**kwargs)
 
-    if torch.cuda.is_available():
-        assert get_model.return_value.to.call_count == 1
-        assert get_model.return_value.to.call_args.args[0] == "cuda"
+    if cuda_is_available:
+        assert get_peft_model.return_value.to.call_count == 1
+        assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
     else:
-        assert get_model.return_value.to.call_count == 0
-    
+        assert get_peft_model.return_value.to.call_count == 0
+
     assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
 
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.get_peft_model')
 @patch('llama_recipes.finetuning.StepLR')
@@ -113,11 +120,11 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
     kwargs = {"weight_decay": 0.01}
 
     get_dataset.return_value = get_fake_dataset()
-    
+
     model = mocker.MagicMock(name="Model")
     model.parameters.return_value = [torch.ones(1,1)]
 
-    get_model.return_value = model 
+    get_model.return_value = model
 
     main(**kwargs)
 
@@ -134,7 +141,7 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')