
Merge pull request #3 from albertodepaola/l3p/finetuning_inference_chat_mods

Llama 3 modification for finetuning/local_inference/chat_inference/tests
Hamid Shojanazeri committed 7 months ago
commit d717be8ad4

+ 21 - 5
recipes/finetuning/datasets/custom_dataset.py

@@ -11,11 +11,27 @@ import itertools
 B_INST, E_INST = "[INST]", "[/INST]"
 
 def tokenize_dialog(dialog, tokenizer):
-    prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST}", add_special_tokens=False) for prompt in dialog[::2]]
-    answer_tokens = [tokenizer.encode(f"{answer['content'].strip()} {tokenizer.eos_token}", add_special_tokens=False) for answer in dialog[1::2]]
-    dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens)))
-    #Add labels, convert prompt token to -100 in order to ignore in loss function
-    labels_tokens = [len(c)*[-100,] if i % 2 == 0 else c for i,c in enumerate(dialog_tokens)]
+    if tokenizer.vocab_size >= 128000:
+        dialog_tokens = tokenizer.apply_chat_template(dialog)
+        dialog_tokens = dialog_tokens[:-4] # Remove generation prompt <|start_header_id|>assistant<|end_header_id|>\n\n
+        eot_indices = [i for i,n in enumerate(dialog_tokens) if n == 128009]
+        labels = copy.copy(dialog_tokens)
+        last_idx = 0
+        for n, idx in enumerate(eot_indices):
+            if n % 2 == 1:
+                last_idx = idx
+            else:
+                labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
+
+        dialog_tokens = [dialog_tokens]
+        labels_tokens = [labels]
+    else:
+        prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST}", add_special_tokens=False) for prompt in dialog[::2]]
+        answer_tokens = [tokenizer.encode(f"{answer['content'].strip()} {tokenizer.eos_token}", add_special_tokens=False) for answer in dialog[1::2]]
+        dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens)))
+
+        #Add labels, convert prompt token to -100 in order to ignore in loss function
+        labels_tokens = [len(c)*[-100,] if i % 2 == 0 else c for i,c in enumerate(dialog_tokens)]
 
     combined_tokens = {
         "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),

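For context, the Llama 3 branch above relies on the `<|eot_id|>` token (id 128009) to delimit turns: every span ending at an even-numbered `<|eot_id|>` is a user turn and is masked with -100 so that only assistant tokens contribute to the loss. Below is a minimal, self-contained sketch of that masking step, using made-up token ids (only 128009 is a real Llama 3 id here):

```python
import copy

EOT_ID = 128009  # <|eot_id|> in the Llama 3 tokenizer

def mask_user_turns(dialog_tokens):
    """Replicate the masking loop above on a plain list of token ids."""
    eot_indices = [i for i, t in enumerate(dialog_tokens) if t == EOT_ID]
    labels = copy.copy(dialog_tokens)
    last_idx = 0
    for n, idx in enumerate(eot_indices):
        if n % 2 == 1:
            # odd <|eot_id|> closes an assistant turn: keep its labels
            last_idx = idx
        else:
            # even <|eot_id|> closes a user turn: ignore it in the loss
            labels[last_idx:idx + 1] = [-100] * (idx - last_idx + 1)
    return labels

# toy dialog: [user tokens..., EOT, assistant tokens..., EOT]
print(mask_user_turns([11, 12, 13, EOT_ID, 21, 22, EOT_ID]))
# -> [-100, -100, -100, -100, 21, 22, 128009]
```
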
+ 6 - 6
recipes/inference/local_inference/chat_completion/chat_completion.py

@@ -8,9 +8,9 @@ import os
 import sys
 
 import torch
-from transformers import LlamaTokenizer
+from transformers import AutoTokenizer
 
-from llama_recipes.inference.chat_utils import read_dialogs_from_file, format_tokens
+from llama_recipes.inference.chat_utils import read_dialogs_from_file
 from llama_recipes.inference.model_utils import load_model, load_peft_model
 from llama_recipes.inference.safety_utils import get_safety_checker
 from accelerate.utils import is_xpu_available
@@ -65,15 +65,15 @@ def main(
     if peft_model:
         model = load_peft_model(model, peft_model)
 
-    tokenizer = LlamaTokenizer.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
     tokenizer.add_special_tokens(
         {
-         
+
             "pad_token": "<PAD>",
         }
     )
-    
-    chats = format_tokens(dialogs, tokenizer)
+
+    chats = tokenizer.apply_chat_template(dialogs)
 
     with torch.no_grad():
         for idx, chat in enumerate(chats):

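The hand-rolled `[INST]`/`<<SYS>>` formatting that `format_tokens` used to perform now lives in the tokenizer's chat template. A hedged usage sketch, assuming access to a gated chat checkpoint (the model name below is a placeholder, not necessarily the one this recipe targets):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")

dialog = [{"role": "user", "content": "Who made Berlin?"}]

# token ids ending with the assistant header, ready to feed to model.generate
input_ids = tokenizer.apply_chat_template(dialog, add_generation_prompt=True)

# set tokenize=False to inspect the rendered prompt string instead
print(tokenizer.apply_chat_template(dialog, tokenize=False, add_generation_prompt=True))
```
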
+ 5 - 5
recipes/inference/local_inference/inference.py

@@ -10,7 +10,7 @@ import time
 import gradio as gr
 
 import torch
-from transformers import LlamaTokenizer
+from transformers import AutoTokenizer
 
 from llama_recipes.inference.safety_utils import get_safety_checker, AgentType
 from llama_recipes.inference.model_utils import load_model, load_peft_model
@@ -69,17 +69,17 @@ def main(
     else:
         torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
-    
+
     model = load_model(model_name, quantization, use_fast_kernels)
     if peft_model:
         model = load_peft_model(model, peft_model)
 
     model.eval()
-    
 
-    tokenizer = LlamaTokenizer.from_pretrained(model_name)
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
     tokenizer.pad_token = tokenizer.eos_token
-    
+
     batch = tokenizer(user_prompt, padding='max_length', truncation=True, max_length=max_padding_length, return_tensors="pt")
     if is_xpu_available():
         batch = {k: v.to("xpu") for k, v in batch.items()}

+ 1 - 1
requirements.txt

@@ -1,4 +1,4 @@
-torch>=2.0.1
+torch>=2.2
 accelerate
 appdirs
 loralib

+ 3 - 4
src/llama_recipes/data/sampler.py

@@ -20,7 +20,7 @@ class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
         self.shuffle = shuffle
 
     def __iter__(self):
-        ids = np.argsort(self.lengths)
+        ids = np.argsort(self.lengths, kind='mergesort')
         if self.drop_last:
             ids = ids[:len(ids) // self.batch_size * self.batch_size]
 
@@ -47,11 +47,10 @@ class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler):
             )
         self.num_replicas = num_replicas
         self.rank = rank
-        
+
     def __iter__(self):
         max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas
         return islice(self.batch_sampler, self.rank, max_length, self.num_replicas)
-         
+
     def __len__(self):
         return len(self.batch_sampler) // self.num_replicas
-            

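Pinning `kind='mergesort'` makes the length sort stable: samples of equal length keep their incoming order, so batch composition is deterministic instead of depending on NumPy's default (unstable) introsort. A quick illustration with nothing but NumPy:

```python
import numpy as np

lengths = [5, 3, 5, 3, 5]  # several ties

print(np.argsort(lengths, kind="mergesort"))  # [1 3 0 2 4]: ties stay in input order
# np.argsort(lengths) uses an unstable sort by default, so tie order is not guaranteed
```
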
+ 7 - 13
src/llama_recipes/finetuning.py

@@ -2,7 +2,6 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import os
-from pkg_resources import packaging
 
 import dataclasses
 import fire
@@ -18,8 +17,8 @@ from torch.distributed.fsdp import (
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
 from transformers import (
+    AutoTokenizer,
     LlamaForCausalLM,
-    LlamaTokenizer,
     LlamaConfig,
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
@@ -51,7 +50,7 @@ from llama_recipes.utils.train_utils import (
 from accelerate.utils import is_xpu_available
 
 def setup_wandb(train_config, fsdp_config, **kwargs):
-    try: 
+    try:
         import wandb
     except ImportError:
         raise ImportError(
@@ -97,7 +96,7 @@ def main(**kwargs):
 
     if train_config.use_wandb:
         if not train_config.enable_fsdp or rank==0:
-            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)    
+            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
 
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
@@ -108,11 +107,6 @@ def main(**kwargs):
         model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
         overhead and currently requires latest nightly.
         """
-        v = packaging.version.parse(torch.__version__)
-        verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
-        if not verify_latest_nightly:
-            raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
-                            "please install latest nightly.")
         if rank == 0:
             model = LlamaForCausalLM.from_pretrained(
                 train_config.model_name,
@@ -137,7 +131,7 @@ def main(**kwargs):
         )
 
     # Load the tokenizer and add special tokens
-    tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
+    tokenizer = AutoTokenizer.from_pretrained(train_config.model_name)
     tokenizer.pad_token_id = tokenizer.eos_token_id
 
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
@@ -157,12 +151,12 @@ def main(**kwargs):
         if wandb_run:
             wandb_run.config.update(peft_config)
 
-        
+
     hsdp_device_mesh = None
     if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD:
         hsdp_device_mesh = hsdp_device_mesh(replica_group_size=fsdp_config.replica_group_size, sharding_group_size=fsdp_config.sharding_group_size)
         print("HSDP device mesh is ready")
-        
+
     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
         if not train_config.use_peft and train_config.freeze_layers:
@@ -171,7 +165,7 @@ def main(**kwargs):
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
-        
+
         device_id = 0
         if is_xpu_available():
             device_id = torch.xpu.current_device()

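Replacing `LlamaTokenizer` with `AutoTokenizer` lets one code path load both the Llama 2 SentencePiece tokenizer and the Llama 3 tiktoken-style one; the concrete class is resolved from the checkpoint's tokenizer config. A hedged sketch (both checkpoints are gated, and the Llama 3 repo name is a placeholder):

```python
from transformers import AutoTokenizer

tok2 = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
tok3 = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")

# vocab sizes are what the >= 128000 check in custom_dataset.py keys off
print(type(tok2).__name__, tok2.vocab_size)  # expected: LlamaTokenizerFast, ~32k vocab
print(type(tok3).__name__, tok3.vocab_size)  # expected: PreTrainedTokenizerFast, ~128k vocab
```
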
+ 0 - 56
src/llama_recipes/inference/chat_utils.py

@@ -2,62 +2,6 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import json
-from typing import List, Literal, TypedDict
-
-
-Role = Literal["user", "assistant"]
-
-
-class Message(TypedDict):
-    role: Role
-    content: str
-
-
-Dialog = List[Message]
-
-B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
-def format_tokens(dialogs, tokenizer):
-    prompt_tokens = []
-    for dialog in dialogs:
-        if dialog[0]["role"] == "system":
-            dialog = [
-            {
-                "role": dialog[1]["role"],
-                "content": B_SYS
-                + dialog[0]["content"]
-                + E_SYS
-                + dialog[1]["content"],
-            }
-        ] + dialog[2:]
-        assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
-            [msg["role"] == "assistant" for msg in dialog[1::2]]
-        ), (
-            "model only supports 'system','user' and 'assistant' roles, "
-            "starting with user and alternating (u/a/u/a/u...)"
-        )
-        """
-        Please verify that your tokenizer support adding "[INST]", "[/INST]" to your inputs.
-        Here, we are adding it manually.
-        """
-        dialog_tokens: List[int] = sum(
-            [
-                tokenizer.encode(
-                    f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
-                ) + [tokenizer.eos_token_id]
-                for prompt, answer in zip(dialog[::2], dialog[1::2])
-            ],
-            [],
-        )
-        assert (
-            dialog[-1]["role"] == "user"
-        ), f"Last message must be from user, got {dialog[-1]['role']}"
-        dialog_tokens += tokenizer.encode(
-            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
-        )
-        prompt_tokens.append(dialog_tokens)
-    return prompt_tokens
-        
 
 def read_dialogs_from_file(file_path):
     with open(file_path, 'r') as file:

+ 1 - 0
src/llama_recipes/utils/config_utils.py

@@ -90,6 +90,7 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
                 rank=dist.get_rank(),
                 num_replicas=dist.get_world_size(),
                 shuffle=mode=="train",
+                drop_last=True,
             )
             kwargs["batch_size"] = batch_size
             kwargs["drop_last"] = True

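Adding `drop_last=True` to the `DistributedSampler` trims the dataset to a multiple of the world size, so every rank iterates the same number of batches and collectives don't stall on a ragged tail. A minimal sketch of the effect, assuming only an indexable toy dataset:

```python
from torch.utils.data import DataLoader, DistributedSampler

dataset = list(range(10))  # 10 samples, world_size 4 -> 2 per rank, 2 dropped

sampler = DistributedSampler(dataset, num_replicas=4, rank=0, shuffle=False, drop_last=True)
loader = DataLoader(dataset, batch_size=1, sampler=sampler, drop_last=True)

print([int(b) for b in loader])  # rank 0 sees exactly 2 samples: [0, 4]
```
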
+ 15 - 10
tests/conftest.py

@@ -3,21 +3,26 @@
 
 import pytest
 
-from transformers import LlamaTokenizer
+from transformers import AutoTokenizer
 
 ACCESS_ERROR_MSG = "Could not access tokenizer at 'meta-llama/Llama-2-7b-hf'. Did you log into huggingface hub and provided the correct token?"
+LLAMA_VERSIONS = ["meta-llama/Llama-2-7b-hf", "meta-llama/Llama-3-8b-hf"]
+
+@pytest.fixture(params=LLAMA_VERSIONS)
+def llama_version(request):
+    return request.param
 
 
 @pytest.fixture(scope="module")
-def llama_tokenizer():
-    return LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+def llama_tokenizer(request):
+    return {k: AutoTokenizer.from_pretrained(k) for k in LLAMA_VERSIONS}
 
 
 @pytest.fixture
-def setup_tokenizer(llama_tokenizer):
+def setup_tokenizer(llama_tokenizer, llama_version):
     def _helper(tokenizer_mock):
         #Align with Llama 2 tokenizer
-        tokenizer_mock.from_pretrained.return_value = llama_tokenizer
+        tokenizer_mock.from_pretrained.return_value = llama_tokenizer[llama_version]
 
     return _helper
 
@@ -27,21 +32,21 @@ def pytest_addoption(parser):
         "--unskip-missing-tokenizer",
         action="store_true",
         default=False, help="disable skip missing tokenizer")
-    
+
 def pytest_configure(config):
     config.addinivalue_line("markers", "skip_missing_tokenizer: skip if tokenizer is unavailable")
 
-    
+
 def pytest_collection_modifyitems(config, items):
     if config.getoption("--unskip-missing-tokenizer"):
         return
-    
+
     try:
-        LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+        AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
         tokenizer_available = True
     except OSError:
         tokenizer_available = False
-    
+
     skip_missing_tokenizer = pytest.mark.skip(reason=ACCESS_ERROR_MSG)
     for item in items:
         if "skip_missing_tokenizer" in item.keywords and not tokenizer_available:

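The new `llama_version` fixture fans every test that requests it out over both checkpoints, while `llama_tokenizer` loads each tokenizer once per module and hands out a dict keyed by version. A toy sketch of the same pattern, independent of this repo:

```python
import pytest

VERSIONS = ["v2", "v3"]

@pytest.fixture(params=VERSIONS)
def version(request):
    return request.param

@pytest.fixture(scope="module")
def tokenizers():
    # built once per test module, one entry per version
    return {v: f"tokenizer-for-{v}" for v in VERSIONS}

def test_lookup(tokenizers, version):
    # collected twice, once for "v2" and once for "v3"
    assert version in tokenizers
```
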
+ 35 - 20
tests/datasets/test_custom_dataset.py

@@ -6,32 +6,50 @@ from unittest.mock import patch
 
 from transformers import LlamaTokenizer
 
-def check_padded_entry(batch):
+EXPECTED_RESULTS={
+    "meta-llama/Llama-2-7b-hf":{
+        "example_1": "[INST] Who made Berlin [/INST] dunno",
+        "example_2": "[INST] Quiero preparar una pizza de pepperoni, puedes darme los pasos para hacerla? [/INST] Claro!",
+    },
+    "meta-llama/Llama-3-8b-hf":{
+        "example_1": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nWho made Berlin<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\ndunno<|eot_id|><|end_of_text|>",
+        "example_2": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow to start learning guitar and become a master at it?",
+    },
+}
+
+def check_padded_entry(batch, tokenizer):
     seq_len = sum(batch["attention_mask"][0])
     assert seq_len < len(batch["attention_mask"][0])
 
+    if tokenizer.vocab_size >= 128000:
+        END_OF_TEXT_ID = 128009
+    else:
+        END_OF_TEXT_ID = tokenizer.eos_token_id
+
     assert batch["labels"][0][0] == -100
-    assert batch["labels"][0][seq_len-1] == 2
+    assert batch["labels"][0][seq_len-1] == END_OF_TEXT_ID
     assert batch["labels"][0][-1] == -100
-    assert batch["input_ids"][0][0] == 1
-    assert batch["input_ids"][0][-1] == 2
+    assert batch["input_ids"][0][0] == tokenizer.bos_token_id
+    assert batch["input_ids"][0][-1] == tokenizer.eos_token_id
 
 
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer, llama_version):
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
 
+    skip_special_tokens = llama_version == "meta-llama/Llama-2-7b-hf"
+
     kwargs = {
         "dataset": "custom_dataset",
-        "model_name": "meta-llama/Llama-2-7b-hf",
-        "custom_dataset.file": "examples/custom_dataset.py",
+        "model_name": llama_version,
+        "custom_dataset.file": "recipes/finetuning/datasets/custom_dataset.py",
         "custom_dataset.train_split": "validation",
         "batch_size_training": 2,
         "val_batch_size": 4,
@@ -53,34 +71,31 @@ def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
 
     it = iter(eval_dataloader)
     batch = next(it)
-    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
-    EXPECTED_STRING = "[INST] Who made Berlin [/INST] dunno"
-    assert STRING.startswith(EXPECTED_STRING)
+    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=skip_special_tokens)
+    assert STRING.startswith(EXPECTED_RESULTS[llama_version]["example_1"])
 
     assert batch["input_ids"].size(0) == 4
     assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())
 
-    check_padded_entry(batch)
+    check_padded_entry(batch, tokenizer)
 
     it = iter(train_dataloader)
-    for _ in range(5):
-        next(it)
+    next(it)
 
     batch = next(it)
-    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
-    EXPECTED_STRING = "[INST] How do I initialize a Typescript project using npm and git? [/INST] # Initialize a new NPM project"
-    assert STRING.startswith(EXPECTED_STRING)
+    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=skip_special_tokens)
+    assert STRING.startswith(EXPECTED_RESULTS[llama_version]["example_2"])
 
     assert batch["input_ids"].size(0) == 2
     assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())
 
-    check_padded_entry(batch)
+    check_padded_entry(batch, tokenizer)
 
 
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, train, mocker):
@@ -90,7 +105,7 @@ def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, train,
 
     kwargs = {
         "dataset": "custom_dataset",
-        "custom_dataset.file": "examples/custom_dataset.py:get_unknown_dataset",
+        "custom_dataset.file": "recipes/finetuning/datasets/custom_dataset.py:get_unknown_dataset",
         "batch_size_training": 1,
         "use_peft": False,
         }

+ 19 - 9
tests/datasets/test_grammar_datasets.py

@@ -4,23 +4,32 @@
 import pytest
 from unittest.mock import patch
 
-from transformers import LlamaTokenizer
 
+EXPECTED_RESULTS = {
+    "meta-llama/Llama-2-7b-hf":{
+        "label": 1152,
+        "pos": 31,
+    },
+    "meta-llama/Llama-3-8b-hf":{
+        "label": 40,
+        "pos": 26,
+    },
+}
 
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version):
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
 
     BATCH_SIZE = 8
     kwargs = {
-        "model_name": "meta-llama/Llama-2-7b-hf",
+        "model_name": llama_version,
         "batch_size_training": BATCH_SIZE,
         "val_batch_size": 1,
         "use_peft": False,
@@ -48,9 +57,10 @@ def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker
     assert "input_ids" in batch.keys()
     assert "attention_mask" in batch.keys()
 
-    assert batch["labels"][0][31] == -100
-    assert batch["labels"][0][32] == 1152
+    assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]-1] == -100
+    assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]] == EXPECTED_RESULTS[llama_version]["label"]
 
-    assert batch["input_ids"][0][0] == 1
-    assert batch["labels"][0][-1] == 2
-    assert batch["input_ids"][0][-1] == 2
+    token = args[3]
+    assert batch["input_ids"][0][0] == token.bos_token_id
+    assert batch["labels"][0][-1] == token.eos_token_id
+    assert batch["input_ids"][0][-1] == token.eos_token_id

+ 19 - 8
tests/datasets/test_samsum_datasets.py

@@ -5,21 +5,31 @@ import pytest
 from functools import partial
 from unittest.mock import patch
 
+EXPECTED_RESULTS = {
+    "meta-llama/Llama-2-7b-hf":{
+        "label": 8432,
+        "pos": 242,
+    },
+    "meta-llama/Llama-3-8b-hf":{
+        "label": 2250,
+        "pos": 211,
+    },
+}
 
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer, llama_version):
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
 
     BATCH_SIZE = 8
     kwargs = {
-        "model_name": "meta-llama/Llama-2-7b-hf",
+        "model_name": llama_version,
         "batch_size_training": BATCH_SIZE,
         "val_batch_size": 1,
         "use_peft": False,
@@ -34,6 +44,7 @@ def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
     args, kwargs = train.call_args
     train_dataloader = args[1]
     eval_dataloader = args[2]
+    token = args[3]
 
     VAL_SAMPLES = 818
     TRAIN_SAMPLES = 14732
@@ -47,9 +58,9 @@ def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
     assert "input_ids" in batch.keys()
     assert "attention_mask" in batch.keys()
 
-    assert batch["labels"][0][268] == -100
-    assert batch["labels"][0][269] == 319
+    assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]-1] == -100
+    assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]] == EXPECTED_RESULTS[llama_version]["label"]
 
-    assert batch["input_ids"][0][0] == 1
-    assert batch["labels"][0][-1] == 2
-    assert batch["input_ids"][0][-1] == 2
+    assert batch["input_ids"][0][0] == token.bos_token_id
+    assert batch["labels"][0][-1] == token.eos_token_id
+    assert batch["input_ids"][0][-1] == token.eos_token_id

+ 21 - 11
tests/test_batching.py

@@ -4,20 +4,30 @@
 import pytest
 from unittest.mock import patch
 
+EXPECTED_SAMPLE_NUMBER ={
+    "meta-llama/Llama-2-7b-hf": {
+        "train": 96,
+        "eval": 42,
+    },
+    "meta-llama/Llama-3-8b-hf": {
+        "train": 79,
+        "eval": 34,
+    }
+}
 
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+def test_packing(step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version):
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
 
     kwargs = {
-        "model_name": "meta-llama/Llama-2-7b-hf",
+        "model_name": llama_version,
         "batch_size_training": 8,
         "val_batch_size": 1,
         "use_peft": False,
@@ -33,8 +43,8 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_
     train_dataloader = args[1]
     eval_dataloader = args[2]
 
-    assert len(train_dataloader) == 96
-    assert len(eval_dataloader) == 42
+    assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"]
+    assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"]
 
     batch = next(iter(train_dataloader))
 
@@ -49,7 +59,7 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_
 
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
@@ -57,13 +67,13 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_
 @patch('llama_recipes.finetuning.FSDP')
 @patch('llama_recipes.finetuning.torch.distributed.is_initialized')
 @patch('llama_recipes.utils.config_utils.dist')
-def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer):
+def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version):
     import os
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
 
-    rank = 0
+    rank = 1
     os.environ['LOCAL_RANK'] = f'{rank}'
     os.environ['RANK'] = f'{rank}'
     os.environ['WORLD_SIZE'] = '2'
@@ -71,7 +81,7 @@ def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimiz
     os.environ['MASTER_PORT'] = '12345'
 
     kwargs = {
-        "model_name": "meta-llama/Llama-2-7b-hf",
+        "model_name": llama_version,
         "batch_size_training": 8,
         "val_batch_size": 1,
         "use_peft": False,
@@ -92,5 +102,5 @@ def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimiz
     train_dataloader = args[1]
     eval_dataloader = args[2]
 
-    assert len(train_dataloader) == 96 //2
-    assert len(eval_dataloader) == 42 //2
+    assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"] //2
+    assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"] //2

+ 155 - 0
tests/test_chat_completion.py

@@ -0,0 +1,155 @@
+import sys
+from pathlib import Path
+from typing import List, Literal, TypedDict
+from unittest.mock import patch
+
+import pytest
+import torch
+from llama_recipes.inference.chat_utils import read_dialogs_from_file
+
+ROOT_DIR = Path(__file__).parents[1]
+CHAT_COMPLETION_DIR = ROOT_DIR / "recipes/inference/local_inference/chat_completion/"
+
+sys.path = [CHAT_COMPLETION_DIR.as_posix()] + sys.path
+
+Role = Literal["user", "assistant"]
+
+
+class Message(TypedDict):
+    role: Role
+    content: str
+
+
+Dialog = List[Message]
+
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+
+def _encode_header(message, tokenizer):
+    tokens = []
+    tokens.extend(tokenizer.encode("<|start_header_id|>"))
+    tokens.extend(tokenizer.encode(message["role"]))
+    tokens.extend(tokenizer.encode("<|end_header_id|>"))
+    tokens.extend(tokenizer.encode("\n\n"))
+    return tokens
+
+
+def _encode_message(message, tokenizer):
+    tokens = _encode_header(message, tokenizer)
+    tokens.extend(tokenizer.encode(message["content"].strip()))
+    tokens.extend(tokenizer.encode("<|eot_id|>"))
+    return tokens
+
+
+def _format_dialog(dialog, tokenizer):
+    tokens = []
+    tokens.extend(tokenizer.encode("<|begin_of_text|>"))
+    for msg in dialog:
+        tokens.extend(_encode_message(msg, tokenizer))
+    tokens.extend(_encode_header({"role": "assistant", "content": ""}, tokenizer))
+    return tokens
+
+
+def _format_tokens_llama3(dialogs, tokenizer):
+    return [_format_dialog(dialog, tokenizer) for dialog in dialogs]
+
+
+def _format_tokens_llama2(dialogs, tokenizer):
+    prompt_tokens = []
+    for dialog in dialogs:
+        if dialog[0]["role"] == "system":
+            dialog = [
+                {
+                    "role": dialog[1]["role"],
+                    "content": B_SYS
+                    + dialog[0]["content"]
+                    + E_SYS
+                    + dialog[1]["content"],
+                }
+            ] + dialog[2:]
+        assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
+            [msg["role"] == "assistant" for msg in dialog[1::2]]
+        ), (
+            "model only supports 'system','user' and 'assistant' roles, "
+            "starting with user and alternating (u/a/u/a/u...)"
+        )
+        """
+        Please verify that your tokenizer support adding "[INST]", "[/INST]" to your inputs.
+        Here, we are adding it manually.
+        """
+        dialog_tokens: List[int] = sum(
+            [
+                tokenizer.encode(
+                    f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
+                )
+                + [tokenizer.eos_token_id]
+                for prompt, answer in zip(dialog[::2], dialog[1::2])
+            ],
+            [],
+        )
+        assert (
+            dialog[-1]["role"] == "user"
+        ), f"Last message must be from user, got {dialog[-1]['role']}"
+        dialog_tokens += tokenizer.encode(
+            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
+        )
+        prompt_tokens.append(dialog_tokens)
+    return prompt_tokens
+
+
+@pytest.mark.skip_missing_tokenizer
+@patch("chat_completion.AutoTokenizer")
+@patch("chat_completion.load_model")
+def test_chat_completion(
+    load_model, tokenizer, setup_tokenizer, llama_tokenizer, llama_version
+):
+    from chat_completion import main
+
+    setup_tokenizer(tokenizer)
+
+    kwargs = {
+        "prompt_file": (CHAT_COMPLETION_DIR / "chats.json").as_posix(),
+    }
+
+    main(llama_version, **kwargs)
+
+    dialogs = read_dialogs_from_file(kwargs["prompt_file"])
+    format_tokens = (
+        _format_tokens_llama2
+        if llama_version == "meta-llama/Llama-2-7b-hf"
+        else _format_tokens_llama3
+    )
+
+    REF_RESULT = format_tokens(dialogs, llama_tokenizer[llama_version])
+
+    assert all(
+        (
+            load_model.return_value.generate.mock_calls[0 * 4][2]["input_ids"].cpu()
+            == torch.tensor(REF_RESULT[0]).long()
+        ).tolist()
+    )
+    assert all(
+        (
+            load_model.return_value.generate.mock_calls[1 * 4][2]["input_ids"].cpu()
+            == torch.tensor(REF_RESULT[1]).long()
+        ).tolist()
+    )
+    assert all(
+        (
+            load_model.return_value.generate.mock_calls[2 * 4][2]["input_ids"].cpu()
+            == torch.tensor(REF_RESULT[2]).long()
+        ).tolist()
+    )
+    assert all(
+        (
+            load_model.return_value.generate.mock_calls[3 * 4][2]["input_ids"].cpu()
+            == torch.tensor(REF_RESULT[3]).long()
+        ).tolist()
+    )
+    assert all(
+        (
+            load_model.return_value.generate.mock_calls[4 * 4][2]["input_ids"].cpu()
+            == torch.tensor(REF_RESULT[4]).long()
+        ).tolist()
+    )

+ 26 - 19
tests/test_finetuning.py

@@ -21,17 +21,19 @@ def get_fake_dataset():
         "labels":[1],
         }]
 
-
+@patch('llama_recipes.finetuning.torch.cuda.is_available')
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
+@pytest.mark.parametrize("cuda_is_available", [True, False])
+def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
     kwargs = {"run_validation": False}
 
     get_dataset.return_value = get_fake_dataset()
+    cuda.return_value = cuda_is_available
 
     main(**kwargs)
 
@@ -44,23 +46,26 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge
     assert isinstance(train_dataloader, DataLoader)
     assert eval_dataloader is None
 
-    if torch.cuda.is_available():
+    if cuda_is_available:
         assert get_model.return_value.to.call_count == 1
         assert get_model.return_value.to.call_args.args[0] == "cuda"
     else:
         assert get_model.return_value.to.call_count == 0
 
 
+@patch('llama_recipes.finetuning.torch.cuda.is_available')
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
+@pytest.mark.parametrize("cuda_is_available", [True, False])
+def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
     kwargs = {"run_validation": True}
 
     get_dataset.return_value = get_fake_dataset()
+    cuda.return_value = cuda_is_available
 
     main(**kwargs)
 
@@ -72,40 +77,42 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer,
     assert isinstance(train_dataloader, DataLoader)
     assert isinstance(eval_dataloader, DataLoader)
 
-    if torch.cuda.is_available():
+    if cuda_is_available:
         assert get_model.return_value.to.call_count == 1
         assert get_model.return_value.to.call_args.args[0] == "cuda"
     else:
         assert get_model.return_value.to.call_count == 0
 
-
+@patch('llama_recipes.finetuning.torch.cuda.is_available')
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.generate_peft_config')
 @patch('llama_recipes.finetuning.get_peft_model')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train):
+@pytest.mark.parametrize("cuda_is_available", [True, False])
+def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
     kwargs = {"use_peft": True}
 
     get_dataset.return_value = get_fake_dataset()
+    cuda.return_value = cuda_is_available
 
     main(**kwargs)
 
-    if torch.cuda.is_available():
-        assert get_model.return_value.to.call_count == 1
-        assert get_model.return_value.to.call_args.args[0] == "cuda"
+    if cuda_is_available:
+        assert get_peft_model.return_value.to.call_count == 1
+        assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
     else:
-        assert get_model.return_value.to.call_count == 0
-    
+        assert get_peft_model.return_value.to.call_count == 0
+
     assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
 
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.get_peft_model')
 @patch('llama_recipes.finetuning.StepLR')
@@ -113,11 +120,11 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
     kwargs = {"weight_decay": 0.01}
 
     get_dataset.return_value = get_fake_dataset()
-    
+
     model = mocker.MagicMock(name="Model")
     model.parameters.return_value = [torch.ones(1,1)]
 
-    get_model.return_value = model 
+    get_model.return_value = model
 
     main(**kwargs)
 
@@ -134,7 +141,7 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')