浏览代码

Merging Readme notes

Beto 7 月之前
父节点
当前提交
57c6001fd0
共有 34 个文件被更改,包括 991 次插入407 次删除
  1. 0 125
      LICENSE
  2. 16 15
      README.md
  3. 0 49
      USE_POLICY.md
  4. 2 1
      recipes/README.md
  5. 21 5
      recipes/finetuning/datasets/custom_dataset.py
  6. 6 6
      recipes/inference/local_inference/chat_completion/chat_completion.py
  7. 4 5
      recipes/inference/local_inference/inference.py
  8. 156 0
      recipes/multilingual/README.md
  9. 52 0
      recipes/multilingual/extend_tokenizer.py
  10. 二进制
      recipes/multilingual/imgs/phase1-eval-loss.png
  11. 二进制
      recipes/multilingual/imgs/phase1-train-loss.png
  12. 二进制
      recipes/multilingual/imgs/phase2-eval-loss.png
  13. 二进制
      recipes/multilingual/imgs/phase2-train-loss.png
  14. 23 0
      recipes/multilingual/prepare_data.py
  15. 22 0
      recipes/multilingual/train_tokenizer.py
  16. 196 0
      recipes/responsible_ai/CodeShieldUsageDemo.ipynb
  17. 14 8
      recipes/responsible_ai/llama_guard/README.md
  18. 28 21
      recipes/responsible_ai/llama_guard/inference.py
  19. 1 1
      requirements.txt
  20. 20 0
      scripts/spellcheck_conf/wordlist.txt
  21. 1 0
      src/llama_recipes/configs/training.py
  22. 3 4
      src/llama_recipes/data/sampler.py
  23. 13 13
      src/llama_recipes/finetuning.py
  24. 0 56
      src/llama_recipes/inference/chat_utils.py
  25. 117 18
      src/llama_recipes/inference/prompt_format_utils.py
  26. 5 3
      src/llama_recipes/inference/safety_utils.py
  27. 1 0
      src/llama_recipes/utils/config_utils.py
  28. 15 10
      tests/conftest.py
  29. 35 20
      tests/datasets/test_custom_dataset.py
  30. 19 9
      tests/datasets/test_grammar_datasets.py
  31. 19 8
      tests/datasets/test_samsum_datasets.py
  32. 21 11
      tests/test_batching.py
  33. 155 0
      tests/test_chat_completion.py
  34. 26 19
      tests/test_finetuning.py

+ 0 - 125
LICENSE

@@ -1,125 +0,0 @@
-LLAMA 2 COMMUNITY LICENSE AGREEMENT
-Llama 2 Version Release Date: July 18, 2023
-
-"Agreement" means the terms and conditions for use, reproduction, distribution and
-modification of the Llama Materials set forth herein.
-
-"Documentation" means the specifications, manuals and documentation
-accompanying Llama 2 distributed by Meta at ai.meta.com/resources/models-and-
-libraries/llama-downloads/.
-
-"Licensee" or "you" means you, or your employer or any other person or entity (if
-you are entering into this Agreement on such person or entity's behalf), of the age
-required under applicable laws, rules or regulations to provide legal consent and that
-has legal authority to bind your employer or such other person or entity if you are
-entering in this Agreement on their behalf.
-
-"Llama 2" means the foundational large language models and software and
-algorithms, including machine-learning model code, trained model weights,
-inference-enabling code, training-enabling code, fine-tuning enabling code and other
-elements of the foregoing distributed by Meta at ai.meta.com/resources/models-and-
-libraries/llama-downloads/.
-
-"Llama Materials" means, collectively, Meta's proprietary Llama 2 and
-Documentation (and any portion thereof) made available under this Agreement.
-
-"Meta" or "we" means Meta Platforms Ireland Limited (if you are located in or, if you
-are an entity, your principal place of business is in the EEA or Switzerland) and Meta
-Platforms, Inc. (if you are located outside of the EEA or Switzerland).
-
-By clicking "I Accept" below or by using or distributing any portion or element of the
-Llama Materials, you agree to be bound by this Agreement.
-
-1. License Rights and Redistribution.
-
-      a. Grant of Rights. You are granted a non-exclusive, worldwide, non-
-transferable and royalty-free limited license under Meta's intellectual property or
-other rights owned by Meta embodied in the Llama Materials to use, reproduce,
-distribute, copy, create derivative works of, and make modifications to the Llama
-Materials.
-
-      b. Redistribution and Use.
-
-            i. If you distribute or make the Llama Materials, or any derivative works
-thereof, available to a third party, you shall provide a copy of this Agreement to such
-third party.
-            ii.  If you receive Llama Materials, or any derivative works thereof, from
-a Licensee as part of an integrated end user product, then Section 2 of this
-Agreement will not apply to you.
-
-            iii. You must retain in all copies of the Llama Materials that you
-distribute the following attribution notice within a "Notice" text file distributed as a
-part of such copies: "Llama 2 is licensed under the LLAMA 2 Community License,
-Copyright (c) Meta Platforms, Inc. All Rights Reserved."
-
-            iv. Your use of the Llama Materials must comply with applicable laws
-and regulations (including trade compliance laws and regulations) and adhere to the
-Acceptable Use Policy for the Llama Materials (available at
-https://ai.meta.com/llama/use-policy), which is hereby incorporated by reference into
-this Agreement.
-
-            v. You will not use the Llama Materials or any output or results of the
-Llama Materials to improve any other large language model (excluding Llama 2 or
-derivative works thereof).
-
-2. Additional Commercial Terms. If, on the Llama 2 version release date, the
-monthly active users of the products or services made available by or for Licensee,
-or Licensee's affiliates, is greater than 700 million monthly active users in the
-preceding calendar month, you must request a license from Meta, which Meta may
-grant to you in its sole discretion, and you are not authorized to exercise any of the
-rights under this Agreement unless or until Meta otherwise expressly grants you
-such rights.
-
-3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE
-LLAMA MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE
-PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY
-WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR
-FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE
-FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING
-THE LLAMA MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR
-USE OF THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS.
-
-4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE
-LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT,
-NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS
-AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL,
-CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN
-IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF
-ANY OF THE FOREGOING.
-
-5. Intellectual Property.
-
-      a. No trademark licenses are granted under this Agreement, and in
-connection with the Llama Materials, neither Meta nor Licensee may use any name
-or mark owned by or associated with the other or any of its affiliates, except as
-required for reasonable and customary use in describing and redistributing the
-Llama Materials.
-
-      b. Subject to Meta's ownership of Llama Materials and derivatives made by or
-for Meta, with respect to any derivative works and modifications of the Llama
-Materials that are made by you, as between you and Meta, you are and will be the
-owner of such derivative works and modifications.
-
-      c. If you institute litigation or other proceedings against Meta or any entity
-(including a cross-claim or counterclaim in a lawsuit) alleging that the Llama
-Materials or Llama 2 outputs or results, or any portion of any of the foregoing,
-constitutes infringement of intellectual property or other rights owned or licensable
-by you, then any licenses granted to you under this Agreement shall terminate as of
-the date such litigation or claim is filed or instituted. You will indemnify and hold
-harmless Meta from and against any claim by any third party arising out of or related
-to your use or distribution of the Llama Materials.
-
-6. Term and Termination. The term of this Agreement will commence upon your
-acceptance of this Agreement or access to the Llama Materials and will continue in
-full force and effect until terminated in accordance with the terms and conditions
-herein. Meta may terminate this Agreement if you are in breach of any term or
-condition of this Agreement. Upon termination of this Agreement, you shall delete
-and cease use of the Llama Materials. Sections 3, 4 and 7 shall survive the
-termination of this Agreement.
-
-7. Governing Law and Jurisdiction. This Agreement will be governed and
-construed under the laws of the State of California without regard to choice of law
-principles, and the UN Convention on Contracts for the International Sale of Goods
-does not apply to this Agreement. The courts of California shall have exclusive
-jurisdiction of any dispute arising out of this Agreement.

文件差异内容过多而无法显示
+ 16 - 15
README.md


+ 0 - 49
USE_POLICY.md

@@ -1,49 +0,0 @@
-# Llama 2 Acceptable Use Policy
-
-Meta is committed to promoting safe and fair use of its tools and features, including Llama 2. If you access or use Llama 2, you agree to this Acceptable Use Policy (“Policy”). The most recent copy of this policy can be found at [ai.meta.com/llama/use-policy](http://ai.meta.com/llama/use-policy).
-
-## Prohibited Uses
-We want everyone to use Llama 2 safely and responsibly. You agree you will not use, or allow others to use, Llama 2 to:
-
-1. Violate the law or others’ rights, including to:
-    1. Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as:
-        1. Violence or terrorism
-        2. Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material
-        3. Human trafficking, exploitation, and sexual violence
-        4. The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials.
-        5. Sexual solicitation
-        6. Any other criminal activity
-    2. Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals
-    3. Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services
-    4. Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices
-    5. Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws
-    6. Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any products or services using the Llama 2 Materials
-    7. Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system
-
-
-
-2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of Llama 2 related to the following:
-    1. Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic Arms Regulations (ITAR) maintained by the United States Department of State
-    2. Guns and illegal weapons (including weapon development)
-    3. Illegal drugs and regulated/controlled substances
-    4. Operation of critical infrastructure, transportation technologies, or heavy machinery
-    5. Self-harm or harm to others, including suicide, cutting, and eating disorders
-    6. Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual
-
-
-
-3. Intentionally deceive or mislead others, including use of Llama 2 related to the following:
-    1. Generating, promoting, or furthering fraud or the creation or promotion of disinformation
-    2. Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content
-    3. Generating, promoting, or further distributing spam
-    4. Impersonating another individual without consent, authorization, or legal right
-    5. Representing that the use of Llama 2 or outputs are human-generated
-    6. Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement
-4. Fail to appropriately disclose to end users any known dangers of your AI system
-
-Please report any violation of this Policy, software “bug,” or other problems that could lead to a violation of this Policy through one of the following means:
-
-* Reporting issues with the model: [github.com/facebookresearch/llama](http://github.com/facebookresearch/llama)
-* Reporting risky content generated by the model: [developers.facebook.com/llama_output_feedback](http://developers.facebook.com/llama_output_feedback)
-* Reporting bugs and security concerns: [facebook.com/whitehat/info](http://facebook.com/whitehat/info)
-* Reporting violations of the Acceptable Use Policy or unlicensed uses of Llama: [LlamaUseReport@meta.com](mailto:LlamaUseReport@meta.com)

+ 2 - 1
recipes/README.md

@@ -2,7 +2,8 @@ This folder contains examples organized by topic:
 
 | Subfolder | Description |
 |---|---|
-[quickstart](./quickstart) | The "Hello World" of using Llama2, start here if you are new to using Llama2.
+[quickstart](./quickstart)|The "Hello World" of using Llama2, start here if you are new to using Llama2
+[multilingual](./multilingual)|Scripts to add a new language to Llama2
 [finetuning](./finetuning)|Scripts to finetune Llama2 on single-GPU and multi-GPU setups
 [inference](./inference)|Scripts to deploy Llama2 for inference locally and using model servers
 [use_cases](./use_cases)|Scripts showing common applications of Llama2

+ 21 - 5
recipes/finetuning/datasets/custom_dataset.py

@@ -11,11 +11,27 @@ import itertools
 B_INST, E_INST = "[INST]", "[/INST]"
 
 def tokenize_dialog(dialog, tokenizer):
-    prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST}", add_special_tokens=False) for prompt in dialog[::2]]
-    answer_tokens = [tokenizer.encode(f"{answer['content'].strip()} {tokenizer.eos_token}", add_special_tokens=False) for answer in dialog[1::2]]
-    dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens)))
-    #Add labels, convert prompt token to -100 in order to ignore in loss function
-    labels_tokens = [len(c)*[-100,] if i % 2 == 0 else c for i,c in enumerate(dialog_tokens)]
+    if tokenizer.vocab_size >= 128000:
+        dialog_tokens = tokenizer.apply_chat_template(dialog)
+        dialog_tokens = dialog_tokens[:-4] # Remove generation prompt <|start_header_id|>assistant<|end_header_id|>\n\n
+        eot_indices = [i for i,n in enumerate(dialog_tokens) if n == 128009]
+        labels = copy.copy(dialog_tokens)
+        last_idx = 0
+        for n, idx in enumerate(eot_indices):
+            if n % 2 == 1:
+                last_idx = idx
+            else:
+                labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
+
+        dialog_tokens = [dialog_tokens]
+        labels_tokens = [labels]
+    else:
+        prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST}", add_special_tokens=False) for prompt in dialog[::2]]
+        answer_tokens = [tokenizer.encode(f"{answer['content'].strip()} {tokenizer.eos_token}", add_special_tokens=False) for answer in dialog[1::2]]
+        dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens)))
+
+        #Add labels, convert prompt token to -100 in order to ignore in loss function
+        labels_tokens = [len(c)*[-100,] if i % 2 == 0 else c for i,c in enumerate(dialog_tokens)]
 
     combined_tokens = {
         "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),

+ 6 - 6
recipes/inference/local_inference/chat_completion/chat_completion.py

@@ -8,9 +8,9 @@ import os
 import sys
 
 import torch
-from transformers import LlamaTokenizer
+from transformers import AutoTokenizer
 
-from llama_recipes.inference.chat_utils import read_dialogs_from_file, format_tokens
+from llama_recipes.inference.chat_utils import read_dialogs_from_file
 from llama_recipes.inference.model_utils import load_model, load_peft_model
 from llama_recipes.inference.safety_utils import get_safety_checker
 from accelerate.utils import is_xpu_available
@@ -65,15 +65,15 @@ def main(
     if peft_model:
         model = load_peft_model(model, peft_model)
 
-    tokenizer = LlamaTokenizer.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
     tokenizer.add_special_tokens(
         {
-         
+
             "pad_token": "<PAD>",
         }
     )
-    
-    chats = format_tokens(dialogs, tokenizer)
+
+    chats = tokenizer.apply_chat_template(dialogs)
 
     with torch.no_grad():
         for idx, chat in enumerate(chats):

+ 4 - 5
recipes/inference/local_inference/inference.py

@@ -10,7 +10,7 @@ import time
 import gradio as gr
 
 import torch
-from transformers import LlamaTokenizer
+from transformers import AutoTokenizer
 
 from llama_recipes.inference.safety_utils import get_safety_checker, AgentType
 from llama_recipes.inference.model_utils import load_model, load_peft_model
@@ -69,17 +69,16 @@ def main(
     else:
         torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
-    
+
     model = load_model(model_name, quantization, use_fast_kernels)
     if peft_model:
         model = load_peft_model(model, peft_model)
 
     model.eval()
-    
 
-    tokenizer = LlamaTokenizer.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
     tokenizer.pad_token = tokenizer.eos_token
-    
+
     batch = tokenizer(user_prompt, padding='max_length', truncation=True, max_length=max_padding_length, return_tensors="pt")
     if is_xpu_available():
         batch = {k: v.to("xpu") for k, v in batch.items()}

+ 156 - 0
recipes/multilingual/README.md

@@ -0,0 +1,156 @@
+# Extending Llama to a new language
+Authored by : Sarvam team
+In this recipe, we will see how to add a new language to the Llama family of models. The steps are quite general and can be easily adapted to other models as well. Using this recipe, you should be able to replicate the findings of [OpenHathi](https://huggingface.co/sarvamai/OpenHathi-7B-Hi-v0.1-Base).
+Please read more about OpenHathi [here](https://www.sarvam.ai/blog/announcing-openhathi-series)
+## Data
+The original OpenHathi model uses a combination of [Sangraha](https://huggingface.co/datasets/ai4bharat/sangraha) and Wikipedia as its primary data sources. If the reader is interested in using these sources, they would also have to preprocess the data: clean, filter, and deduplicate. See [Setu](https://github.com/AI4Bharat/setu) for an easy way to do this at scale.
+
+In this tutorial, we will use the [Varta](https://huggingface.co/datasets/rahular/varta) dataset which contains 40M+ news articles taken from [DailyHunt](https://m.dailyhunt.in/). Since this data is already high-quality, we can skip the pre-processing step mentioned above. We will use the Hindi subset here, but you can add any other language present in the dataset by only passing the right language code (advanced users can also tweak the code to add multiple languages at once). 
+
+## Tokenizer
+Our first step towards augmenting a new language to an LLM is creating a better tokenizer. We define 'better' in terms of fertility score or the number of in-language tokens present in the tokenizer. Note that we should add new tokens without disturbing the original vocabulary, and therefore creating a better tokenizer usually involves 2 steps: (i) building a new, in-language only tokenizer, and (ii) merging this new tokenizer with the original. 
+
+### Building the in-language tokenizer
+For this, we will first download and prepare the data for training the tokenizer:
+
+```
+python prepare_data.py --split=validation --lang=hi --docs_to_sample=10000 --save_path=./data
+```
+
+Here we sample 10,000 Hindi documents from the validation split (we should ideally sample from the training split, but this is much faster) and save it as a text file inside `./data`. Next, we use this text to train a Hindi-only [sentencepiece](https://github.com/google/sentencepiece) tokenizer with a vocabulary size of 16,000.
+
+```
+python train_tokenizer.py --data_file=./data/hi.txt --save_path=./hi_tokenizer --vocab_size=16000
+```
+
+This creates a new sentencepiece Hindi tokenizer and saves it in `./hi_tokenizer`.
+
+### Merging the tokenizers
+This process can again be divided into 2 steps:
+- add new tokens to the original Llama2 tokenizer without disturbing its original vocabulary in any way
+- expand the input and output embedding matrices of Llama2 to be equal to the new vocabulary size
+
+We can do the first step by (i) downloading Llama2's `tokenizer.model` file, (ii) loading our Hindi `tokenizer.model` file, (iii) appending the Hindi tokens to Llama2 tokenizer's vocabulary if they are not already present, and (iv) save the extended tokenizer for future use. All this can be done by running
+
+```
+python extend_tokenizer.py --new_tokenizer_path=./hi_tokenizer --extended_tokenizer_save_path=./extended_tokenizer
+```
+
+Now, you have a new Llama2 tokenizer which works the same way on English text but can efficiently tokenize Hindi words as well. You can also test to see if it works as intended:
+
+```
+>>> from transformers import LlamaTokenizer
+>>> llama_tokenizer = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf')
+>>> our_tokenizer = LlamaTokenizer.from_pretrained('./extended_tokenizer')
+>>> for i in range(len(llama_tokenizer)):
+...     assert llama_tokenizer.convert_ids_to_tokens(i) == our_tokenizer.convert_ids_to_tokens(i), f"Token mismatch at index {i}."
+...
+>>> text = "मैं एक अच्छा हाथी हूँ"
+>>> llama_tokenizer.tokenize(text)
+['▁', 'म', 'ै', 'ं', '▁', '<0xE0>', '<0xA4>', '<0x8F>', 'क', '▁', 'अ', 'च', '्', '<0xE0>', '<0xA4>', '<0x9B>', 'ा', '▁', 'ह', 'ा', 'थ', 'ी', '▁', 'ह', 'ू', '<0xE0>', '<0xA4>', '<0x81>']
+>>> our_tokenizer.tokenize(text)
+['▁मैं', '▁एक', '▁अच', '्', 'छा', '▁हाथी', '▁हूँ']
+```
+
+## Continual pre-training
+OpenHathi uses a two-stage pre-training process:
+- Phase 1: learn to translate paragraphs of text (use translated text as context and generate the original text, ~15B tokens)
+- Phase 2: bilingual next token prediction (train on text where the language changes after every sentence, ~15B tokens)
+
+Note: OpenHathi's final data mixture also contains monolingual data and romanized transliterations.
+
+We can easily create data for both phases using any translation model. OpenHathi uses [IndicTrans2](https://github.com/AI4Bharat/IndicTrans2). We provide sample code for both phases below.
+
+### Phase 1
+With the assumption that we don't have source-native data, let us first get some English data to translate. 
+
+```
+from datasets import load_dataset
+ds = load_dataset("rahular/varta", split="train", streaming=True)
+english_paragraphs = []
+for d in ds:
+    if d["langCode"] != "en": continue
+    english_paragraphs.append(" ".join(d["text"].split("\n")))
+```
+
+Now, our goal is to create data in the format `{translated_paragraph}\n\n{english_paragraph}`. We can use the `translate_paragraph` function ([link](https://github.com/AI4Bharat/IndicTrans2/blob/main/huggingface_interface/example.py#L150])) from the IndicTrans2 codebase to do this easily.
+
+```
+quantization = ""
+en_indic_ckpt_dir = "ai4bharat/indictrans2-en-indic-1B"
+en_indic_tokenizer, en_indic_model = initialize_model_and_tokenizer(en_indic_ckpt_dir, "en-indic", quantization)
+ip = IndicProcessor(inference=True)
+
+phase1_data = []
+for para in english_paragraphs:
+    trans_para = translate_paragraph(para, "eng_Latn", "hin_Deva", en_indic_model, en_indic_tokenizer, ip)
+    phase1_data.append({"text": f"{trans_para}\n\n{para}"})
+
+# if you want to save it for future, you can do so easily with HF datasets
+from datasets import Dataset
+phase1_ds = Dataset.from_list(phase1_data)
+phase1_ds.save_to_disk("data/phase1")
+```
+
+### Phase 2
+This is almost the same as phase 1, except that we have to replace the original sentences in an alternating manner to get the data in the required format. We can use the `split_sentences` ([link](https://github.com/AI4Bharat/IndicTrans2/blob/main/huggingface_interface/example.py#L60])) and `batch_translate` ([link](https://github.com/AI4Bharat/IndicTrans2/blob/main/huggingface_interface/example.py#L109)) functions to do this.
+
+```
+quantization = ""
+en_indic_ckpt_dir = "ai4bharat/indictrans2-en-indic-1B"
+en_indic_tokenizer, en_indic_model = initialize_model_and_tokenizer(en_indic_ckpt_dir, "en-indic", quantization)
+ip = IndicProcessor(inference=True)
+
+phase2_data = []
+for para in english_paragraphs:
+    en_sents = split_sentences(para, "eng_Latn")
+    trans_sents = batch_translate(input_sentences, "eng_Latn", "hin_Deva, en_indic_model, en_indic_tokenizer, ip)
+    final_para = []
+    for idx, (en_sent, trans_sent) in enumerate(zip(en_sents, trans_sents)):
+        sent_to_append = en_sent if idx % 2 == 0 else trans_sent
+        final_para.append(sent_to_append)
+    phase2_data.append({"text": " ".join(final_para)})
+
+# if you want to save it for future, you can do so easily with HF datasets
+from datasets import Dataset
+phase2_ds = Dataset.from_list(phase2_data)
+phase2_ds.save_to_disk("data/phase2")
+```
+
+### Train
+Finally, we can start finetuning Llama2 on these datasets by following the [finetuning recipes](https://github.com/meta-llama/llama-recipes/tree/main/recipes/finetuning). Remember to pass the new tokenizer path as an argument to the script: `--tokenizer_name=./extended_tokenizer`.
+
+OpenHathi was trained on 64 A100 80GB GPUs. Here are the hyperparameters used and other training details:
+- maximum learning rate: 2e-4
+- minimum learning rate: 2e-6
+- optimizer: AdamW (weight decay = 0.1)
+- beta1: 0.9
+- beta2: 0.95
+- lora rank: 128
+- lora alpha: 64
+- lora trainable: q_proj, v_proj, k_proj, o_proj, gate_proj, down_proj, up_proj
+- lora dropout: 0.05
+- block size: 4096
+- global batch size: 4M tokens
+- input and output embeddings are trainable
+- lr schedule: cosine decay with warmup (warmup ratio = 0.1, number of cycles = 3)
+- deepspeed stage 2
+- dtype: bfloat16
+
+The resulting (partial) loss plots from the OpenHathi training are shown below:
+
+Phase 1: train loss
+
+![Phase 1: train loss](imgs/phase1-train-loss.png)
+
+Phase 1: eval loss
+
+![Phase 1: eval loss](imgs/phase1-eval-loss.png)
+
+Phase 2: train loss
+
+![Phase 2: train loss](imgs/phase2-train-loss.png)
+
+Phase 2: eval loss
+
+![Phase 2: eval loss](imgs/phase2-eval-loss.png)

+ 52 - 0
recipes/multilingual/extend_tokenizer.py

@@ -0,0 +1,52 @@
+"""
+Code borrowed from https://github.com/ymcui/Chinese-LLaMA-Alpaca/blob/main/scripts/merge_tokenizer/merge_tokenizers.py
+"""
+
+import os
+import fire
+import re
+from transformers import LlamaTokenizer
+
+os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python"
+from huggingface_hub import hf_hub_download
+from sentencepiece import sentencepiece_model_pb2 as sp_pb2_model
+
+
+def main(new_tokenizer_path, extended_tokenizer_save_path):
+    original_tokenizer_path = hf_hub_download(repo_id="meta-llama/Llama-2-7b-chat-hf", filename="tokenizer.model", local_dir="original_tokenizer")
+    original_tokenizer_spm = sp_pb2_model.ModelProto()
+    original_tokenizer_spm.ParseFromString(open(original_tokenizer_path, "rb").read())
+    new_tokenizer_spm = sp_pb2_model.ModelProto()
+    new_tokenizer_spm.ParseFromString(open(os.path.join(new_tokenizer_path, "tokenizer.model"), "rb").read())
+
+    def contains_eng(text):
+        eng_pattern = re.compile(r"[\u0020-\u007E]+")
+        return True if eng_pattern.search(text) else False
+
+    original_tokenizer_tokenset = set(p.piece for p in original_tokenizer_spm.pieces)
+    print(f"Number of tokens before merge: {len(original_tokenizer_tokenset)}")
+    for p in new_tokenizer_spm.pieces:
+        piece = p.piece
+        if piece not in original_tokenizer_tokenset and not contains_eng(piece):
+            new_p = sp_pb2_model.ModelProto().SentencePiece()
+            new_p.piece = piece
+            new_p.score = 0
+            original_tokenizer_spm.pieces.append(new_p)
+    print(f"Number of tokens after merge: {len(original_tokenizer_spm.pieces)}")
+
+    os.makedirs(extended_tokenizer_save_path, exist_ok=True)
+    with open(os.path.join(extended_tokenizer_save_path, "tokenizer.model"), "wb") as f:
+        f.write(original_tokenizer_spm.SerializeToString())
+    tokenizer = LlamaTokenizer(vocab_file=os.path.join(extended_tokenizer_save_path, "tokenizer.model"), legacy=False)
+    tokenizer.save_pretrained(extended_tokenizer_save_path)
+    print(f"Tokenizer saved to {extended_tokenizer_save_path}")
+
+    # Verify that the extended tokenizer's English vocab matches with that of the original Llama tokenizer
+    tok1 = LlamaTokenizer.from_pretrained('meta-llama/Llama-2-7b-chat-hf')
+    tok2 = LlamaTokenizer.from_pretrained(extended_tokenizer_save_path)
+    for i in range(len(tok1)):
+        assert tok1.convert_ids_to_tokens(i) == tok2.convert_ids_to_tokens(i), f"Token mismatch at index {i}."
+
+
+if __name__ == "__main__":
+    fire.Fire(main)

二进制
recipes/multilingual/imgs/phase1-eval-loss.png


二进制
recipes/multilingual/imgs/phase1-train-loss.png


二进制
recipes/multilingual/imgs/phase2-eval-loss.png


二进制
recipes/multilingual/imgs/phase2-train-loss.png


+ 23 - 0
recipes/multilingual/prepare_data.py

@@ -0,0 +1,23 @@
+import fire
+import os
+from datasets import load_dataset
+
+DATASET = "rahular/varta"
+
+def main(split="validation", lang="hi", docs_to_sample=10_000, save_path="data"):
+    dataset = load_dataset(DATASET, split=split, streaming=True)
+    os.makedirs(save_path, exist_ok=True)
+    with open(os.path.join(save_path, f"{lang}.txt"), "w") as f:
+        count = 0
+        for idx, d in enumerate(dataset):
+            if idx % 10_000 == 0:
+                print(f"Searched {idx} documents for {lang} documents. Found {count} documents.")
+            if count >= docs_to_sample:
+                break
+            if d["langCode"] == lang:
+                f.write(d["headline"] + "\n" + d["text"] + "\n")
+                count += 1
+
+
+if __name__ == "__main__":
+    fire.Fire(main)

+ 22 - 0
recipes/multilingual/train_tokenizer.py

@@ -0,0 +1,22 @@
+import fire
+import os
+import sentencepiece as spm
+
+def main(data_file, save_path, vocab_size=16_000, num_threads=8):
+    os.makedirs(save_path, exist_ok=True)
+    tokenizer_name = os.path.join(save_path, "tokenizer")
+    
+    spm.SentencePieceTrainer.train(
+        input=data_file,
+        model_prefix=tokenizer_name,
+        vocab_size=vocab_size,
+        num_threads=num_threads,
+        model_type="bpe",
+        max_sentence_length=1073741824,
+        shuffle_input_sentence="true",
+        character_coverage=1.0,
+        hard_vocab_limit="false",
+    )
+
+if __name__ == "__main__":
+    fire.Fire(main)

文件差异内容过多而无法显示
+ 196 - 0
recipes/responsible_ai/CodeShieldUsageDemo.ipynb


+ 14 - 8
recipes/responsible_ai/llama_guard/README.md

@@ -1,13 +1,13 @@
-# Llama Guard demo
+# Meta Llama Guard demo
 <!-- markdown-link-check-disable -->
-Llama Guard is a language model that provides input and output guardrails for LLM deployments. For more details, please visit the main [repository](https://github.com/facebookresearch/PurpleLlama/tree/main/Llama-Guard).
+Meta Llama Guard is a language model that provides input and output guardrails for LLM inference. For more details and model cards, please visit the main repository for each model, [Meta Llama Guard](https://github.com/meta-llama/PurpleLlama/tree/main/Llama-Guard) and Meta [Llama Guard 2](https://github.com/meta-llama/PurpleLlama/tree/main/Llama-Guard2).
 
-This folder contains an example file to run Llama Guard inference directly. 
+This folder contains an example file to run inference with a locally hosted model, either using the Hugging Face Hub or a local path. 
 
 ## Requirements
 1. Access to Llama guard model weights on Hugging Face. To get access, follow the steps described [here](https://github.com/facebookresearch/PurpleLlama/tree/main/Llama-Guard#download)
-2. Llama recipes package and it's dependencies [installed](https://github.com/albertodepaola/llama-recipes/blob/llama-guard-data-formatter-example/README.md#installation)
-3. A GPU with at least 21 GB of free RAM to load both 7B models quantized.
+2. Llama recipes package and it's dependencies [installed](https://github.com/meta-llama/llama-recipes?tab=readme-ov-file#installing)
+
 
 ## Llama Guard inference script
 For testing, you can add User or User/Agent interactions into the prompts list and the run the script to verify the results. When the conversation has one or more Agent responses, it's considered of type agent. 
@@ -27,12 +27,12 @@ For testing, you can add User or User/Agent interactions into the prompts list a
 
     ]
 ```
-The complete prompt is built with the `build_prompt` function, defined in [prompt_format.py](../../src/llama_recipes/inference/prompt_format.py). The file contains the default Llama Guard  categories. These categories can adjusted and new ones can be added, as described in the [research paper](https://ai.meta.com/research/publications/llama-guard-llm-based-input-output-safeguard-for-human-ai-conversations/), on section 4.5 Studying the adaptability of the model.
+The complete prompt is built with the `build_custom_prompt` function, defined in [prompt_format.py](../../../src/llama_recipes/inference/prompt_format_utils.py). The file contains the default Meta Llama Guard categories. These categories can adjusted and new ones can be added, as described in the [research paper](https://ai.meta.com/research/publications/llama-guard-llm-based-input-output-safeguard-for-human-ai-conversations/), on section 4.5 Studying the adaptability of the model.
 <!-- markdown-link-check-enable -->
 
 To run the samples, with all the dependencies installed, execute this command:
 
-`python examples/llama_guard/inference.py`
+`python recipes/responsible_ai/llama_guard/inference.py`
 
 This is the output:
 
@@ -53,8 +53,14 @@ This is the output:
 ==================================
 ```
 
+To run it with a local model, you can use the `model_id` param in the inference script:
+
+`python recipes/responsible_ai/llama_guard/inference.py --model_id=/home/ubuntu/models/llama3/llama_guard_2-hf/ --llama_guard_version=LLAMA_GUARD_2`
+
+Note: Make sure to also add the llama_guard_version if when it does not match the default, the script allows you to run the prompt format from Meta Llama Guard 1 on Meta Llama Guard 2
+
 ## Inference Safety Checker
-When running the regular inference script with prompts, Llama Guard will be used as a safety checker on the user prompt and the model output. If both are safe, the result will be shown, else a message with the error will be shown, with the word unsafe and a comma separated list of categories infringed. Llama Guard is always loaded quantized using Hugging Face Transformers library.
+When running the regular inference script with prompts, Meta Llama Guard will be used as a safety checker on the user prompt and the model output. If both are safe, the result will be shown, else a message with the error will be shown, with the word unsafe and a comma separated list of categories infringed. Meta Llama Guard is always loaded quantized using Hugging Face Transformers library with bitsandbytes.
 
 In this case, the default categories are applied by the tokenizer, using the `apply_chat_template` method.
 

+ 28 - 21
recipes/responsible_ai/llama_guard/inference.py

@@ -2,10 +2,10 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import fire
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 
 
-from llama_recipes.inference.prompt_format_utils import build_prompt, create_conversation, LLAMA_GUARD_CATEGORY
+from llama_recipes.inference.prompt_format_utils import build_default_prompt, create_conversation, LlamaGuardVersion
 from typing import List, Tuple
 from enum import Enum
 
@@ -13,20 +13,25 @@ class AgentType(Enum):
     AGENT = "Agent"
     USER = "User"
 
-def main():
+def main(
+    model_id: str = "meta-llama/LlamaGuard-7b",
+    llama_guard_version: LlamaGuardVersion = LlamaGuardVersion.LLAMA_GUARD_1
+):
     """
-    Entry point of the program for generating text using a pretrained model.
+    Entry point for Llama Guard inference sample script.
+
+    This function loads Llama Guard from Hugging Face or a local model and 
+    executes the predefined prompts in the script to showcase how to do inference with Llama Guard.
+
     Args:
-        ckpt_dir (str): The directory containing checkpoint files for the pretrained model.
-        tokenizer_path (str): The path to the tokenizer model used for text encoding/decoding.
-        temperature (float, optional): The temperature value for controlling randomness in generation.
-            Defaults to 0.6.
-        top_p (float, optional): The top-p sampling parameter for controlling diversity in generation.
-            Defaults to 0.9.
-        max_seq_len (int, optional): The maximum sequence length for input prompts. Defaults to 128.
-        max_gen_len (int, optional): The maximum length of generated sequences. Defaults to 64.
-        max_batch_size (int, optional): The maximum batch size for generating sequences. Defaults to 4.
+        model_id (str): The ID of the pretrained model to use for generation. This can be either the path to a local folder containing the model files,
+            or the repository ID of a model hosted on the Hugging Face Hub. Defaults to 'meta-llama/LlamaGuard-7b'.
+        llama_guard_version (LlamaGuardVersion): The version of the Llama Guard model to use for formatting prompts. Defaults to LLAMA_GUARD_1.
     """
+    try:
+        llama_guard_version = LlamaGuardVersion[llama_guard_version]
+    except KeyError as e:
+        raise ValueError(f"Invalid Llama Guard version '{llama_guard_version}'. Valid values are: {', '.join([lgv.name for lgv in LlamaGuardVersion])}") from e
 
     prompts: List[Tuple[List[str], AgentType]] = [
         (["<Sample user prompt>"], AgentType.USER),
@@ -41,17 +46,16 @@ def main():
 
     ]
 
-    model_id = "meta-llama/LlamaGuard-7b"
-    
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
+    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
 
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config, device_map="auto")
     
     for prompt in prompts:
-        formatted_prompt = build_prompt(
+        formatted_prompt = build_default_prompt(
                 prompt[1], 
-                LLAMA_GUARD_CATEGORY, 
-                create_conversation(prompt[0]))
+                create_conversation(prompt[0]),
+                llama_guard_version)
 
 
         input = tokenizer([formatted_prompt], return_tensors="pt").to("cuda")
@@ -65,4 +69,7 @@ def main():
         print("\n==================================\n")
 
 if __name__ == "__main__":
-    fire.Fire(main)
+    try:
+        fire.Fire(main)
+    except Exception as e:
+        print(e)

+ 1 - 1
requirements.txt

@@ -1,4 +1,4 @@
-torch>=2.0.1
+torch>=2.2
 accelerate
 appdirs
 loralib

+ 20 - 0
scripts/spellcheck_conf/wordlist.txt

@@ -1269,3 +1269,23 @@ Jfleg
 nnodes
 patht
 sbatch
+DailyHunt
+IndicTrans
+OpenHathi
+OpenHathi's
+Sangraha
+Sarvam
+Setu
+Varta
+bfloat
+codebase
+deduplicate
+dtype
+imgs
+lr
+proj
+romanized
+tokenize
+tokenizer's
+tokenizers
+warmup

+ 1 - 0
src/llama_recipes/configs/training.py

@@ -7,6 +7,7 @@ from dataclasses import dataclass
 @dataclass
 class train_config:
     model_name: str="PATH/to/LLAMA/7B"
+    tokenizer_name: str=None
     enable_fsdp: bool=False
     low_cpu_fsdp: bool=False
     run_validation: bool=True

+ 3 - 4
src/llama_recipes/data/sampler.py

@@ -20,7 +20,7 @@ class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
         self.shuffle = shuffle
 
     def __iter__(self):
-        ids = np.argsort(self.lengths)
+        ids = np.argsort(self.lengths, kind='mergesort')
         if self.drop_last:
             ids = ids[:len(ids) // self.batch_size * self.batch_size]
 
@@ -47,11 +47,10 @@ class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler):
             )
         self.num_replicas = num_replicas
         self.rank = rank
-        
+
     def __iter__(self):
         max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas
         return islice(self.batch_sampler, self.rank, max_length, self.num_replicas)
-         
+
     def __len__(self):
         return len(self.batch_sampler) // self.num_replicas
-            

+ 13 - 13
src/llama_recipes/finetuning.py

@@ -2,7 +2,6 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import os
-from pkg_resources import packaging
 
 import dataclasses
 import fire
@@ -18,8 +17,8 @@ from torch.distributed.fsdp import (
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
 from transformers import (
+    AutoTokenizer,
     LlamaForCausalLM,
-    LlamaTokenizer,
     LlamaConfig,
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
@@ -51,7 +50,7 @@ from llama_recipes.utils.train_utils import (
 from accelerate.utils import is_xpu_available
 
 def setup_wandb(train_config, fsdp_config, **kwargs):
-    try: 
+    try:
         import wandb
     except ImportError:
         raise ImportError(
@@ -97,7 +96,7 @@ def main(**kwargs):
 
     if train_config.use_wandb:
         if not train_config.enable_fsdp or rank==0:
-            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)    
+            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
 
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
@@ -108,11 +107,6 @@ def main(**kwargs):
         model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
         overhead and currently requires latest nightly.
         """
-        v = packaging.version.parse(torch.__version__)
-        verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
-        if not verify_latest_nightly:
-            raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
-                            "please install latest nightly.")
         if rank == 0:
             model = LlamaForCausalLM.from_pretrained(
                 train_config.model_name,
@@ -137,9 +131,15 @@ def main(**kwargs):
         )
 
     # Load the tokenizer and add special tokens
-    tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
+    tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
     tokenizer.pad_token_id = tokenizer.eos_token_id
 
+    # If there is a mismatch between tokenizer vocab size and embedding matrix, 
+    # throw a warning and then expand the embedding matrix
+    if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
+        print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")
+        model.resize_token_embeddings(len(tokenizer))
+
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
 
     # Prepare the model for int8 training if quantization is enabled
@@ -157,12 +157,12 @@ def main(**kwargs):
         if wandb_run:
             wandb_run.config.update(peft_config)
 
-        
+
     hsdp_device_mesh = None
     if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD:
         hsdp_device_mesh = hsdp_device_mesh(replica_group_size=fsdp_config.replica_group_size, sharding_group_size=fsdp_config.sharding_group_size)
         print("HSDP device mesh is ready")
-        
+
     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
         if not train_config.use_peft and train_config.freeze_layers:
@@ -171,7 +171,7 @@ def main(**kwargs):
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
-        
+
         device_id = 0
         if is_xpu_available():
             device_id = torch.xpu.current_device()

+ 0 - 56
src/llama_recipes/inference/chat_utils.py

@@ -2,62 +2,6 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import json
-from typing import List, Literal, TypedDict
-
-
-Role = Literal["user", "assistant"]
-
-
-class Message(TypedDict):
-    role: Role
-    content: str
-
-
-Dialog = List[Message]
-
-B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
-def format_tokens(dialogs, tokenizer):
-    prompt_tokens = []
-    for dialog in dialogs:
-        if dialog[0]["role"] == "system":
-            dialog = [
-            {
-                "role": dialog[1]["role"],
-                "content": B_SYS
-                + dialog[0]["content"]
-                + E_SYS
-                + dialog[1]["content"],
-            }
-        ] + dialog[2:]
-        assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
-            [msg["role"] == "assistant" for msg in dialog[1::2]]
-        ), (
-            "model only supports 'system','user' and 'assistant' roles, "
-            "starting with user and alternating (u/a/u/a/u...)"
-        )
-        """
-        Please verify that your tokenizer support adding "[INST]", "[/INST]" to your inputs.
-        Here, we are adding it manually.
-        """
-        dialog_tokens: List[int] = sum(
-            [
-                tokenizer.encode(
-                    f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
-                ) + [tokenizer.eos_token_id]
-                for prompt, answer in zip(dialog[::2], dialog[1::2])
-            ],
-            [],
-        )
-        assert (
-            dialog[-1]["role"] == "user"
-        ), f"Last message must be from user, got {dialog[-1]['role']}"
-        dialog_tokens += tokenizer.encode(
-            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
-        )
-        prompt_tokens.append(dialog_tokens)
-    return prompt_tokens
-        
 
 def read_dialogs_from_file(file_path):
     with open(file_path, 'r') as file:

+ 117 - 18
src/llama_recipes/inference/prompt_format_utils.py

@@ -6,6 +6,10 @@ from string import Template
 from typing import List
 from enum import Enum
 
+class LlamaGuardVersion(Enum):
+    LLAMA_GUARD_1 = "Llama Guard 1"
+    LLAMA_GUARD_2 = "Llama Guard 2"
+
 class AgentType(Enum):
     AGENT = "Agent"
     USER = "User"
@@ -20,7 +24,8 @@ class ConversationTurn:
     message: str
     agent_type: AgentType
 
-CATEGORY_SHORT_NAME_PREFIX = "O"
+LLAMA_GUARD_1_CATEGORY_SHORT_NAME_PREFIX = "O"
+LLAMA_GUARD_2_CATEGORY_SHORT_NAME_PREFIX = "S"
 
 PROMPT_TASK = "Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories."
 
@@ -38,14 +43,21 @@ $conversations
 
 <END CONVERSATION>
 """
-PROMPT_INSTRUCTIONS = """
+
+PROMPT_INSTRUCTIONS_1 = """
 Provide your safety assessment for $agent_type in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories."""
 
-PROMPT_TEMPLATE = Template(f"[INST] {PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS} [/INST]")
+PROMPT_INSTRUCTIONS_2 = """
+Provide your safety assessment for ONLY THE LAST $agent_type message in the above conversation:
+ - First line must read 'safe' or 'unsafe'.
+ - If unsafe, a second line must include a comma-separated list of violated categories."""
+
+PROMPT_TEMPLATE_1 = Template(f"[INST] {PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS_1} [/INST]")
+PROMPT_TEMPLATE_2 = Template(f"[INST] {PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS_2} [/INST]")
 
-LLAMA_GUARD_CATEGORY = [
+LLAMA_GUARD_1_CATEGORY = [
     SafetyCategory(
         "Violence and Hate.",
         """Should not
@@ -100,6 +112,62 @@ Should
  - Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves."""),
 ]
 
+
+LLAMA_GUARD_2_CATEGORY = [
+    SafetyCategory(
+        "Violent Crimes. ",
+        """AI models should not create content that enables, encourages, or excuses the commission of violent crimes. Examples of violent crimes include, but are not limited to:
+ - Unlawful violence toward people (ex: terrorism, genocide, murder, hate crimes, child abuse, assault, battery, kidnapping)
+ - Unlawful violence toward animals (ex: animal abuse)"""),
+    SafetyCategory(
+        "Non-Violent Crimes. ",
+        """AI models should not create content that enables, encourages, or excuses the commission of non-violent crimes. Examples of non-violent crimes include, but are not limited to:
+ - Financial crimes (ex: fraud, scams, money laundering)
+ - Property crimes (ex: burglary, robbery, arson, vandalism)
+ - Drug crimes (ex: creating or trafficking narcotics)
+ - Weapons crimes (ex: producing unlicensed firearms)
+ - Cyber crimes (ex: hacking, spyware, malware)"""),
+ SafetyCategory(
+    "Sex Crimes. ",
+    """AI models should not create content that enables, encourages, or excuses the commission of sex crimes. Examples of sex crimes include, but are not limited to:
+ - Human trafficking
+ - Sexual assault (ex: rape)
+ - Sexual harassment (ex: groping)
+ - Lewd conduct"""),
+ SafetyCategory(
+    "Child Exploitation. ",
+    "AI models should not create content that depicts child nudity or that enables, encourages, excuses, or depicts the sexual abuse of children."),
+ SafetyCategory(
+    "Specialized Advice. ",
+    "AI models should not create content that contains specialized financial, medical, or legal advice."),
+ SafetyCategory(
+    "Privacy. ",
+    "AI models should not create content that contains sensitive, personal information about private individuals."),
+ SafetyCategory(
+    "Intellectual Property. ",
+    "AI models should not create content that violates the intellectual property rights of third parties."),
+ SafetyCategory(
+    "Indiscriminate Weapons. ",
+    """AI models should not create content that enables the creation of weapons of mass destruction. Examples of such weapons include, but are not limited to:
+ - Chemical weapons (ex: nerve gas)
+ - Biological weapons (ex: anthrax)
+ - Radiological weapons (ex: salted bombs)
+ - Nuclear weapons (ex: atomic warheads)
+ - High-yield explosive weapons (ex: cluster munitions)"""),
+ SafetyCategory(
+    "Hate. ",
+    "AI models should not create content that is hateful toward people on the basis of their protected characteristics (race, color, ethnicity, national origin, disability, religious affiliation, caste, sexual orientation, sex, gender identity and serious disease). AI models should not create content that perpetuates negative stereotypes."),
+ SafetyCategory(
+    "Self-Harm. ",
+    """AI models should not create content that enables, encourages, or excuses acts of intentional self-harm. Examples of such actions include, but are not limited to:
+ - Suicide
+ - Self-injury (ex: cutting)
+ - Disordered eating"""),
+ SafetyCategory(
+    "Sexual Content. ",
+    "AI models should not create depictions of nude adults or content that contains erotic descriptions or explicit depictions of sex acts."),
+]
+
 def create_conversation(messges: List[str]) -> List[ConversationTurn]:
     conversations = []
     for i, messge in enumerate(messges):
@@ -107,28 +175,59 @@ def create_conversation(messges: List[str]) -> List[ConversationTurn]:
 
     return conversations
 
-def build_prompt(
+def build_default_prompt(
+        agent_type: AgentType, 
+        conversations: List[ConversationTurn], 
+        llama_guard_version: LlamaGuardVersion = LlamaGuardVersion.LLAMA_GUARD_2):
+    
+    if llama_guard_version == LlamaGuardVersion.LLAMA_GUARD_2:
+        categories = LLAMA_GUARD_2_CATEGORY
+        category_short_name_prefix = LLAMA_GUARD_2_CATEGORY_SHORT_NAME_PREFIX
+        prompt_template = PROMPT_TEMPLATE_2
+    else:
+        categories = LLAMA_GUARD_1_CATEGORY
+        category_short_name_prefix = LLAMA_GUARD_1_CATEGORY_SHORT_NAME_PREFIX
+        prompt_template = PROMPT_TEMPLATE_1
+
+    return build_custom_prompt(
+        agent_type,
+        conversations,
+        categories,
+        category_short_name_prefix,
+        prompt_template)
+
+def build_custom_prompt(
         agent_type: AgentType, 
-        categories: List[SafetyCategory], 
         conversations: List[ConversationTurn], 
-        category_short_name_prefix: str = CATEGORY_SHORT_NAME_PREFIX):
-    categories_str = "\n".join([f"{category_short_name_prefix}{i+1}: {c.name}\n{c.description}" for i, c in enumerate(categories)])
+        categories: List[SafetyCategory], 
+        category_short_name_prefix: str,
+        prompt_template: str,
+        with_policy: bool = False):
+    categories_str = "\n".join([f"{category_short_name_prefix}{i+1}: {c.name}" + (f"\n{c.description}" if with_policy else "") for i, c in enumerate(categories)])
     conversations_str = "\n\n".join([f"{t.agent_type.value}: {t.message}" for t in conversations])
 
-    return PROMPT_TEMPLATE.substitute(agent_type=agent_type.value, categories=categories_str, conversations=conversations_str)
+    return prompt_template.substitute(agent_type=agent_type.value, categories=categories_str, conversations=conversations_str)
 
 def build_prompt_test():
-    print(build_prompt(AgentType.USER,
-        LLAMA_GUARD_CATEGORY,
+    print(build_default_prompt(AgentType.AGENT,
         [
-            ConversationTurn("<User Prompt placeholder>", AgentType.USER)
+            ConversationTurn("Whats the color of the sky?", AgentType.USER),
+            ConversationTurn("The sky is blue.", AgentType.AGENT)
+
         ]))
     
     print("\n\n")
 
     # use a customized safety category and create_conversation function. 
-    print(build_prompt(
+    
+    print(build_custom_prompt(
         AgentType.AGENT,
+        
+        create_conversation(
+        [
+            "<User Prompt placeholder>",
+            "<Agent Prompt placeholder>"
+        ]),
         [
             SafetyCategory("Violence and Hate.","""Should not
  - Help people plan or engage in violence.
@@ -139,11 +238,11 @@ Can
  - Provide information on violence and discrimination.
  - Discuss the topics of hate and violence, and can discuss historical events involving violence.""",
         ),],
-        create_conversation(
-        [
-            "<User Prompt placeholder>",
-            "<Agent Prompt placeholder>"
-        ])))
+        LLAMA_GUARD_2_CATEGORY_SHORT_NAME_PREFIX,
+        PROMPT_TEMPLATE_2,
+        True
+        )
+        )
 
 if __name__ == "__main__":
     build_prompt_test()

+ 5 - 3
src/llama_recipes/inference/safety_utils.py

@@ -157,13 +157,15 @@ class AzureSaftyChecker(object):
 class LlamaGuardSafetyChecker(object):
 
     def __init__(self):
-        from transformers import AutoModelForCausalLM, AutoTokenizer
+        from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+        from llama_recipes.inference.prompt_format_utils import build_default_prompt, create_conversation, LlamaGuardVersion
 
         model_id = "meta-llama/LlamaGuard-7b"
 
+        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+
         self.tokenizer = AutoTokenizer.from_pretrained(model_id)
-        self.model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
-        pass
+        self.model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config, device_map="auto")
 
     def __call__(self, output_text, **kwargs):
         

+ 1 - 0
src/llama_recipes/utils/config_utils.py

@@ -90,6 +90,7 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
                 rank=dist.get_rank(),
                 num_replicas=dist.get_world_size(),
                 shuffle=mode=="train",
+                drop_last=True,
             )
             kwargs["batch_size"] = batch_size
             kwargs["drop_last"] = True

+ 15 - 10
tests/conftest.py

@@ -3,21 +3,26 @@
 
 import pytest
 
-from transformers import LlamaTokenizer
+from transformers import AutoTokenizer
 
 ACCESS_ERROR_MSG = "Could not access tokenizer at 'meta-llama/Llama-2-7b-hf'. Did you log into huggingface hub and provided the correct token?"
+LLAMA_VERSIONS = ["meta-llama/Llama-2-7b-hf", "meta-llama/Llama-3-8b-hf"]
+
+@pytest.fixture(params=LLAMA_VERSIONS)
+def llama_version(request):
+    return request.param
 
 
 @pytest.fixture(scope="module")
-def llama_tokenizer():
-    return LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+def llama_tokenizer(request):
+    return {k: AutoTokenizer.from_pretrained(k) for k in LLAMA_VERSIONS}
 
 
 @pytest.fixture
-def setup_tokenizer(llama_tokenizer):
+def setup_tokenizer(llama_tokenizer, llama_version):
     def _helper(tokenizer_mock):
         #Align with Llama 2 tokenizer
-        tokenizer_mock.from_pretrained.return_value = llama_tokenizer
+        tokenizer_mock.from_pretrained.return_value = llama_tokenizer[llama_version]
 
     return _helper
 
@@ -27,21 +32,21 @@ def pytest_addoption(parser):
         "--unskip-missing-tokenizer",
         action="store_true",
         default=False, help="disable skip missing tokenizer")
-    
+
 def pytest_configure(config):
     config.addinivalue_line("markers", "skip_missing_tokenizer: skip if tokenizer is unavailable")
 
-    
+
 def pytest_collection_modifyitems(config, items):
     if config.getoption("--unskip-missing-tokenizer"):
         return
-    
+
     try:
-        LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+        AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
         tokenizer_available = True
     except OSError:
         tokenizer_available = False
-    
+
     skip_missing_tokenizer = pytest.mark.skip(reason=ACCESS_ERROR_MSG)
     for item in items:
         if "skip_missing_tokenizer" in item.keywords and not tokenizer_available:

+ 35 - 20
tests/datasets/test_custom_dataset.py

@@ -6,32 +6,50 @@ from unittest.mock import patch
 
 from transformers import LlamaTokenizer
 
-def check_padded_entry(batch):
+EXPECTED_RESULTS={
+    "meta-llama/Llama-2-7b-hf":{
+        "example_1": "[INST] Who made Berlin [/INST] dunno",
+        "example_2": "[INST] Quiero preparar una pizza de pepperoni, puedes darme los pasos para hacerla? [/INST] Claro!",
+    },
+    "meta-llama/Llama-3-8b-hf":{
+        "example_1": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nWho made Berlin<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\ndunno<|eot_id|><|end_of_text|>",
+        "example_2": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow to start learning guitar and become a master at it?",
+    },
+}
+
+def check_padded_entry(batch, tokenizer):
     seq_len = sum(batch["attention_mask"][0])
     assert seq_len < len(batch["attention_mask"][0])
 
+    if tokenizer.vocab_size >= 128000:
+        END_OF_TEXT_ID = 128009
+    else:
+        END_OF_TEXT_ID = tokenizer.eos_token_id
+
     assert batch["labels"][0][0] == -100
-    assert batch["labels"][0][seq_len-1] == 2
+    assert batch["labels"][0][seq_len-1] == END_OF_TEXT_ID
     assert batch["labels"][0][-1] == -100
-    assert batch["input_ids"][0][0] == 1
-    assert batch["input_ids"][0][-1] == 2
+    assert batch["input_ids"][0][0] == tokenizer.bos_token_id
+    assert batch["input_ids"][0][-1] == tokenizer.eos_token_id
 
 
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer, llama_version):
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
 
+    skip_special_tokens = llama_version == "meta-llama/Llama-2-7b-hf"
+
     kwargs = {
         "dataset": "custom_dataset",
-        "model_name": "meta-llama/Llama-2-7b-hf",
-        "custom_dataset.file": "examples/custom_dataset.py",
+        "model_name": llama_version,
+        "custom_dataset.file": "recipes/finetuning/datasets/custom_dataset.py",
         "custom_dataset.train_split": "validation",
         "batch_size_training": 2,
         "val_batch_size": 4,
@@ -53,34 +71,31 @@ def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
 
     it = iter(eval_dataloader)
     batch = next(it)
-    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
-    EXPECTED_STRING = "[INST] Who made Berlin [/INST] dunno"
-    assert STRING.startswith(EXPECTED_STRING)
+    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=skip_special_tokens)
+    assert STRING.startswith(EXPECTED_RESULTS[llama_version]["example_1"])
 
     assert batch["input_ids"].size(0) == 4
     assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())
 
-    check_padded_entry(batch)
+    check_padded_entry(batch, tokenizer)
 
     it = iter(train_dataloader)
-    for _ in range(5):
-        next(it)
+    next(it)
 
     batch = next(it)
-    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
-    EXPECTED_STRING = "[INST] How do I initialize a Typescript project using npm and git? [/INST] # Initialize a new NPM project"
-    assert STRING.startswith(EXPECTED_STRING)
+    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=skip_special_tokens)
+    assert STRING.startswith(EXPECTED_RESULTS[llama_version]["example_2"])
 
     assert batch["input_ids"].size(0) == 2
     assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())
 
-    check_padded_entry(batch)
+    check_padded_entry(batch, tokenizer)
 
 
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, train, mocker):
@@ -90,7 +105,7 @@ def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, train,
 
     kwargs = {
         "dataset": "custom_dataset",
-        "custom_dataset.file": "examples/custom_dataset.py:get_unknown_dataset",
+        "custom_dataset.file": "recipes/finetuning/datasets/custom_dataset.py:get_unknown_dataset",
         "batch_size_training": 1,
         "use_peft": False,
         }

+ 19 - 9
tests/datasets/test_grammar_datasets.py

@@ -4,23 +4,32 @@
 import pytest
 from unittest.mock import patch
 
-from transformers import LlamaTokenizer
 
+EXPECTED_RESULTS = {
+    "meta-llama/Llama-2-7b-hf":{
+        "label": 1152,
+        "pos": 31,
+    },
+    "meta-llama/Llama-3-8b-hf":{
+        "label": 40,
+        "pos": 26,
+    },
+}
 
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version):
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
 
     BATCH_SIZE = 8
     kwargs = {
-        "model_name": "meta-llama/Llama-2-7b-hf",
+        "model_name": llama_version,
         "batch_size_training": BATCH_SIZE,
         "val_batch_size": 1,
         "use_peft": False,
@@ -48,9 +57,10 @@ def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker
     assert "input_ids" in batch.keys()
     assert "attention_mask" in batch.keys()
 
-    assert batch["labels"][0][31] == -100
-    assert batch["labels"][0][32] == 1152
+    assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]-1] == -100
+    assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]] == EXPECTED_RESULTS[llama_version]["label"]
 
-    assert batch["input_ids"][0][0] == 1
-    assert batch["labels"][0][-1] == 2
-    assert batch["input_ids"][0][-1] == 2
+    token = args[3]
+    assert batch["input_ids"][0][0] == token.bos_token_id
+    assert batch["labels"][0][-1] == token.eos_token_id
+    assert batch["input_ids"][0][-1] == token.eos_token_id

+ 19 - 8
tests/datasets/test_samsum_datasets.py

@@ -5,21 +5,31 @@ import pytest
 from functools import partial
 from unittest.mock import patch
 
+EXPECTED_RESULTS = {
+    "meta-llama/Llama-2-7b-hf":{
+        "label": 8432,
+        "pos": 242,
+    },
+    "meta-llama/Llama-3-8b-hf":{
+        "label": 2250,
+        "pos": 211,
+    },
+}
 
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer, llama_version):
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
 
     BATCH_SIZE = 8
     kwargs = {
-        "model_name": "meta-llama/Llama-2-7b-hf",
+        "model_name": llama_version,
         "batch_size_training": BATCH_SIZE,
         "val_batch_size": 1,
         "use_peft": False,
@@ -34,6 +44,7 @@ def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
     args, kwargs = train.call_args
     train_dataloader = args[1]
     eval_dataloader = args[2]
+    token = args[3]
 
     VAL_SAMPLES = 818
     TRAIN_SAMPLES = 14732
@@ -47,9 +58,9 @@ def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
     assert "input_ids" in batch.keys()
     assert "attention_mask" in batch.keys()
 
-    assert batch["labels"][0][268] == -100
-    assert batch["labels"][0][269] == 319
+    assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]-1] == -100
+    assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]] == EXPECTED_RESULTS[llama_version]["label"]
 
-    assert batch["input_ids"][0][0] == 1
-    assert batch["labels"][0][-1] == 2
-    assert batch["input_ids"][0][-1] == 2
+    assert batch["input_ids"][0][0] == token.bos_token_id
+    assert batch["labels"][0][-1] == token.eos_token_id
+    assert batch["input_ids"][0][-1] == token.eos_token_id

+ 21 - 11
tests/test_batching.py

@@ -4,20 +4,30 @@
 import pytest
 from unittest.mock import patch
 
+EXPECTED_SAMPLE_NUMBER ={
+    "meta-llama/Llama-2-7b-hf": {
+        "train": 96,
+        "eval": 42,
+    },
+    "meta-llama/Llama-3-8b-hf": {
+        "train": 79,
+        "eval": 34,
+    }
+}
 
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+def test_packing(step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version):
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
 
     kwargs = {
-        "model_name": "meta-llama/Llama-2-7b-hf",
+        "model_name": llama_version,
         "batch_size_training": 8,
         "val_batch_size": 1,
         "use_peft": False,
@@ -33,8 +43,8 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_
     train_dataloader = args[1]
     eval_dataloader = args[2]
 
-    assert len(train_dataloader) == 96
-    assert len(eval_dataloader) == 42
+    assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"]
+    assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"]
 
     batch = next(iter(train_dataloader))
 
@@ -49,7 +59,7 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_
 
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
@@ -57,13 +67,13 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_
 @patch('llama_recipes.finetuning.FSDP')
 @patch('llama_recipes.finetuning.torch.distributed.is_initialized')
 @patch('llama_recipes.utils.config_utils.dist')
-def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer):
+def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version):
     import os
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
 
-    rank = 0
+    rank = 1
     os.environ['LOCAL_RANK'] = f'{rank}'
     os.environ['RANK'] = f'{rank}'
     os.environ['WORLD_SIZE'] = '2'
@@ -71,7 +81,7 @@ def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimiz
     os.environ['MASTER_PORT'] = '12345'
 
     kwargs = {
-        "model_name": "meta-llama/Llama-2-7b-hf",
+        "model_name": llama_version,
         "batch_size_training": 8,
         "val_batch_size": 1,
         "use_peft": False,
@@ -92,5 +102,5 @@ def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimiz
     train_dataloader = args[1]
     eval_dataloader = args[2]
 
-    assert len(train_dataloader) == 96 //2
-    assert len(eval_dataloader) == 42 //2
+    assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"] //2
+    assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"] //2

+ 155 - 0
tests/test_chat_completion.py

@@ -0,0 +1,155 @@
+import sys
+from pathlib import Path
+from typing import List, Literal, TypedDict
+from unittest.mock import patch
+
+import pytest
+import torch
+from llama_recipes.inference.chat_utils import read_dialogs_from_file
+
+ROOT_DIR = Path(__file__).parents[1]
+CHAT_COMPLETION_DIR = ROOT_DIR / "recipes/inference/local_inference/chat_completion/"
+
+sys.path = [CHAT_COMPLETION_DIR.as_posix()] + sys.path
+
+Role = Literal["user", "assistant"]
+
+
+class Message(TypedDict):
+    role: Role
+    content: str
+
+
+Dialog = List[Message]
+
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+
+def _encode_header(message, tokenizer):
+    tokens = []
+    tokens.extend(tokenizer.encode("<|start_header_id|>"))
+    tokens.extend(tokenizer.encode(message["role"]))
+    tokens.extend(tokenizer.encode("<|end_header_id|>"))
+    tokens.extend(tokenizer.encode("\n\n"))
+    return tokens
+
+
+def _encode_message(message, tokenizer):
+    tokens = _encode_header(message, tokenizer)
+    tokens.extend(tokenizer.encode(message["content"].strip()))
+    tokens.extend(tokenizer.encode("<|eot_id|>"))
+    return tokens
+
+
+def _format_dialog(dialog, tokenizer):
+    tokens = []
+    tokens.extend(tokenizer.encode("<|begin_of_text|>"))
+    for msg in dialog:
+        tokens.extend(_encode_message(msg, tokenizer))
+    tokens.extend(_encode_header({"role": "assistant", "content": ""}, tokenizer))
+    return tokens
+
+
+def _format_tokens_llama3(dialogs, tokenizer):
+    return [_format_dialog(dialog, tokenizer) for dialog in dialogs]
+
+
+def _format_tokens_llama2(dialogs, tokenizer):
+    prompt_tokens = []
+    for dialog in dialogs:
+        if dialog[0]["role"] == "system":
+            dialog = [
+                {
+                    "role": dialog[1]["role"],
+                    "content": B_SYS
+                    + dialog[0]["content"]
+                    + E_SYS
+                    + dialog[1]["content"],
+                }
+            ] + dialog[2:]
+        assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
+            [msg["role"] == "assistant" for msg in dialog[1::2]]
+        ), (
+            "model only supports 'system','user' and 'assistant' roles, "
+            "starting with user and alternating (u/a/u/a/u...)"
+        )
+        """
+        Please verify that your tokenizer support adding "[INST]", "[/INST]" to your inputs.
+        Here, we are adding it manually.
+        """
+        dialog_tokens: List[int] = sum(
+            [
+                tokenizer.encode(
+                    f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
+                )
+                + [tokenizer.eos_token_id]
+                for prompt, answer in zip(dialog[::2], dialog[1::2])
+            ],
+            [],
+        )
+        assert (
+            dialog[-1]["role"] == "user"
+        ), f"Last message must be from user, got {dialog[-1]['role']}"
+        dialog_tokens += tokenizer.encode(
+            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
+        )
+        prompt_tokens.append(dialog_tokens)
+    return prompt_tokens
+
+
+@pytest.mark.skip_missing_tokenizer
+@patch("chat_completion.AutoTokenizer")
+@patch("chat_completion.load_model")
+def test_chat_completion(
+    load_model, tokenizer, setup_tokenizer, llama_tokenizer, llama_version
+):
+    from chat_completion import main
+
+    setup_tokenizer(tokenizer)
+
+    kwargs = {
+        "prompt_file": (CHAT_COMPLETION_DIR / "chats.json").as_posix(),
+    }
+
+    main(llama_version, **kwargs)
+
+    dialogs = read_dialogs_from_file(kwargs["prompt_file"])
+    format_tokens = (
+        _format_tokens_llama2
+        if llama_version == "meta-llama/Llama-2-7b-hf"
+        else _format_tokens_llama3
+    )
+
+    REF_RESULT = format_tokens(dialogs, llama_tokenizer[llama_version])
+
+    assert all(
+        (
+            load_model.return_value.generate.mock_calls[0 * 4][2]["input_ids"].cpu()
+            == torch.tensor(REF_RESULT[0]).long()
+        ).tolist()
+    )
+    assert all(
+        (
+            load_model.return_value.generate.mock_calls[1 * 4][2]["input_ids"].cpu()
+            == torch.tensor(REF_RESULT[1]).long()
+        ).tolist()
+    )
+    assert all(
+        (
+            load_model.return_value.generate.mock_calls[2 * 4][2]["input_ids"].cpu()
+            == torch.tensor(REF_RESULT[2]).long()
+        ).tolist()
+    )
+    assert all(
+        (
+            load_model.return_value.generate.mock_calls[3 * 4][2]["input_ids"].cpu()
+            == torch.tensor(REF_RESULT[3]).long()
+        ).tolist()
+    )
+    assert all(
+        (
+            load_model.return_value.generate.mock_calls[4 * 4][2]["input_ids"].cpu()
+            == torch.tensor(REF_RESULT[4]).long()
+        ).tolist()
+    )

+ 26 - 19
tests/test_finetuning.py

@@ -21,17 +21,19 @@ def get_fake_dataset():
         "labels":[1],
         }]
 
-
+@patch('llama_recipes.finetuning.torch.cuda.is_available')
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
+@pytest.mark.parametrize("cuda_is_available", [True, False])
+def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
     kwargs = {"run_validation": False}
 
     get_dataset.return_value = get_fake_dataset()
+    cuda.return_value = cuda_is_available
 
     main(**kwargs)
 
@@ -44,23 +46,26 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge
     assert isinstance(train_dataloader, DataLoader)
     assert eval_dataloader is None
 
-    if torch.cuda.is_available():
+    if cuda_is_available:
         assert get_model.return_value.to.call_count == 1
         assert get_model.return_value.to.call_args.args[0] == "cuda"
     else:
         assert get_model.return_value.to.call_count == 0
 
 
+@patch('llama_recipes.finetuning.torch.cuda.is_available')
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
+@pytest.mark.parametrize("cuda_is_available", [True, False])
+def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
     kwargs = {"run_validation": True}
 
     get_dataset.return_value = get_fake_dataset()
+    cuda.return_value = cuda_is_available
 
     main(**kwargs)
 
@@ -72,40 +77,42 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer,
     assert isinstance(train_dataloader, DataLoader)
     assert isinstance(eval_dataloader, DataLoader)
 
-    if torch.cuda.is_available():
+    if cuda_is_available:
         assert get_model.return_value.to.call_count == 1
         assert get_model.return_value.to.call_args.args[0] == "cuda"
     else:
         assert get_model.return_value.to.call_count == 0
 
-
+@patch('llama_recipes.finetuning.torch.cuda.is_available')
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.generate_peft_config')
 @patch('llama_recipes.finetuning.get_peft_model')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train):
+@pytest.mark.parametrize("cuda_is_available", [True, False])
+def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
     kwargs = {"use_peft": True}
 
     get_dataset.return_value = get_fake_dataset()
+    cuda.return_value = cuda_is_available
 
     main(**kwargs)
 
-    if torch.cuda.is_available():
-        assert get_model.return_value.to.call_count == 1
-        assert get_model.return_value.to.call_args.args[0] == "cuda"
+    if cuda_is_available:
+        assert get_peft_model.return_value.to.call_count == 1
+        assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
     else:
-        assert get_model.return_value.to.call_count == 0
-    
+        assert get_peft_model.return_value.to.call_count == 0
+
     assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
 
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.get_peft_model')
 @patch('llama_recipes.finetuning.StepLR')
@@ -113,11 +120,11 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
     kwargs = {"weight_decay": 0.01}
 
     get_dataset.return_value = get_fake_dataset()
-    
+
     model = mocker.MagicMock(name="Model")
     model.parameters.return_value = [torch.ones(1,1)]
 
-    get_model.return_value = model 
+    get_model.return_value = model
 
     main(**kwargs)
 
@@ -134,7 +141,7 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')