Bladeren bron

Updates for Llama 3 (#439)

Hamid Shojanazeri 7 maanden geleden
bovenliggende
commit
0f7d588541

+ 0 - 125
LICENSE

@@ -1,125 +0,0 @@
-LLAMA 2 COMMUNITY LICENSE AGREEMENT
-Llama 2 Version Release Date: July 18, 2023
-
-"Agreement" means the terms and conditions for use, reproduction, distribution and
-modification of the Llama Materials set forth herein.
-
-"Documentation" means the specifications, manuals and documentation
-accompanying Llama 2 distributed by Meta at ai.meta.com/resources/models-and-
-libraries/llama-downloads/.
-
-"Licensee" or "you" means you, or your employer or any other person or entity (if
-you are entering into this Agreement on such person or entity's behalf), of the age
-required under applicable laws, rules or regulations to provide legal consent and that
-has legal authority to bind your employer or such other person or entity if you are
-entering in this Agreement on their behalf.
-
-"Llama 2" means the foundational large language models and software and
-algorithms, including machine-learning model code, trained model weights,
-inference-enabling code, training-enabling code, fine-tuning enabling code and other
-elements of the foregoing distributed by Meta at ai.meta.com/resources/models-and-
-libraries/llama-downloads/.
-
-"Llama Materials" means, collectively, Meta's proprietary Llama 2 and
-Documentation (and any portion thereof) made available under this Agreement.
-
-"Meta" or "we" means Meta Platforms Ireland Limited (if you are located in or, if you
-are an entity, your principal place of business is in the EEA or Switzerland) and Meta
-Platforms, Inc. (if you are located outside of the EEA or Switzerland).
-
-By clicking "I Accept" below or by using or distributing any portion or element of the
-Llama Materials, you agree to be bound by this Agreement.
-
-1. License Rights and Redistribution.
-
-      a. Grant of Rights. You are granted a non-exclusive, worldwide, non-
-transferable and royalty-free limited license under Meta's intellectual property or
-other rights owned by Meta embodied in the Llama Materials to use, reproduce,
-distribute, copy, create derivative works of, and make modifications to the Llama
-Materials.
-
-      b. Redistribution and Use.
-
-            i. If you distribute or make the Llama Materials, or any derivative works
-thereof, available to a third party, you shall provide a copy of this Agreement to such
-third party.
-            ii.  If you receive Llama Materials, or any derivative works thereof, from
-a Licensee as part of an integrated end user product, then Section 2 of this
-Agreement will not apply to you.
-
-            iii. You must retain in all copies of the Llama Materials that you
-distribute the following attribution notice within a "Notice" text file distributed as a
-part of such copies: "Llama 2 is licensed under the LLAMA 2 Community License,
-Copyright (c) Meta Platforms, Inc. All Rights Reserved."
-
-            iv. Your use of the Llama Materials must comply with applicable laws
-and regulations (including trade compliance laws and regulations) and adhere to the
-Acceptable Use Policy for the Llama Materials (available at
-https://ai.meta.com/llama/use-policy), which is hereby incorporated by reference into
-this Agreement.
-
-            v. You will not use the Llama Materials or any output or results of the
-Llama Materials to improve any other large language model (excluding Llama 2 or
-derivative works thereof).
-
-2. Additional Commercial Terms. If, on the Llama 2 version release date, the
-monthly active users of the products or services made available by or for Licensee,
-or Licensee's affiliates, is greater than 700 million monthly active users in the
-preceding calendar month, you must request a license from Meta, which Meta may
-grant to you in its sole discretion, and you are not authorized to exercise any of the
-rights under this Agreement unless or until Meta otherwise expressly grants you
-such rights.
-
-3. Disclaimer of Warranty. UNLESS REQUIRED BY APPLICABLE LAW, THE
-LLAMA MATERIALS AND ANY OUTPUT AND RESULTS THEREFROM ARE
-PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
-EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY
-WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR
-FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE
-FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING
-THE LLAMA MATERIALS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR
-USE OF THE LLAMA MATERIALS AND ANY OUTPUT AND RESULTS.
-
-4. Limitation of Liability. IN NO EVENT WILL META OR ITS AFFILIATES BE
-LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT,
-NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS
-AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL,
-CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN
-IF META OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF
-ANY OF THE FOREGOING.
-
-5. Intellectual Property.
-
-      a. No trademark licenses are granted under this Agreement, and in
-connection with the Llama Materials, neither Meta nor Licensee may use any name
-or mark owned by or associated with the other or any of its affiliates, except as
-required for reasonable and customary use in describing and redistributing the
-Llama Materials.
-
-      b. Subject to Meta's ownership of Llama Materials and derivatives made by or
-for Meta, with respect to any derivative works and modifications of the Llama
-Materials that are made by you, as between you and Meta, you are and will be the
-owner of such derivative works and modifications.
-
-      c. If you institute litigation or other proceedings against Meta or any entity
-(including a cross-claim or counterclaim in a lawsuit) alleging that the Llama
-Materials or Llama 2 outputs or results, or any portion of any of the foregoing,
-constitutes infringement of intellectual property or other rights owned or licensable
-by you, then any licenses granted to you under this Agreement shall terminate as of
-the date such litigation or claim is filed or instituted. You will indemnify and hold
-harmless Meta from and against any claim by any third party arising out of or related
-to your use or distribution of the Llama Materials.
-
-6. Term and Termination. The term of this Agreement will commence upon your
-acceptance of this Agreement or access to the Llama Materials and will continue in
-full force and effect until terminated in accordance with the terms and conditions
-herein. Meta may terminate this Agreement if you are in breach of any term or
-condition of this Agreement. Upon termination of this Agreement, you shall delete
-and cease use of the Llama Materials. Sections 3, 4 and 7 shall survive the
-termination of this Agreement.
-
-7. Governing Law and Jurisdiction. This Agreement will be governed and
-construed under the laws of the State of California without regard to choice of law
-principles, and the UN Convention on Contracts for the International Sale of Goods
-does not apply to this Agreement. The courts of California shall have exclusive
-jurisdiction of any dispute arising out of this Agreement.

File diff suppressed because it is too large
+ 16 - 13
README.md


+ 0 - 49
USE_POLICY.md

@@ -1,49 +0,0 @@
-# Llama 2 Acceptable Use Policy
-
-Meta is committed to promoting safe and fair use of its tools and features, including Llama 2. If you access or use Llama 2, you agree to this Acceptable Use Policy (“Policy”). The most recent copy of this policy can be found at [ai.meta.com/llama/use-policy](http://ai.meta.com/llama/use-policy).
-
-## Prohibited Uses
-We want everyone to use Llama 2 safely and responsibly. You agree you will not use, or allow others to use, Llama 2 to:
-
-1. Violate the law or others’ rights, including to:
-    1. Engage in, promote, generate, contribute to, encourage, plan, incite, or further illegal or unlawful activity or content, such as:
-        1. Violence or terrorism
-        2. Exploitation or harm to children, including the solicitation, creation, acquisition, or dissemination of child exploitative content or failure to report Child Sexual Abuse Material
-        3. Human trafficking, exploitation, and sexual violence
-        4. The illegal distribution of information or materials to minors, including obscene materials, or failure to employ legally required age-gating in connection with such information or materials.
-        5. Sexual solicitation
-        6. Any other criminal activity
-    2. Engage in, promote, incite, or facilitate the harassment, abuse, threatening, or bullying of individuals or groups of individuals
-    3. Engage in, promote, incite, or facilitate discrimination or other unlawful or harmful conduct in the provision of employment, employment benefits, credit, housing, other economic benefits, or other essential goods and services
-    4. Engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or related professional practices
-    5. Collect, process, disclose, generate, or infer health, demographic, or other sensitive personal or private information about individuals without rights and consents required by applicable laws
-    6. Engage in or facilitate any action or generate any content that infringes, misappropriates, or otherwise violates any third-party rights, including the outputs or results of any products or services using the Llama 2 Materials
-    7. Create, generate, or facilitate the creation of malicious code, malware, computer viruses or do anything else that could disable, overburden, interfere with or impair the proper working, integrity, operation or appearance of a website or computer system
-
-
-
-2. Engage in, promote, incite, facilitate, or assist in the planning or development of activities that present a risk of death or bodily harm to individuals, including use of Llama 2 related to the following:
-    1. Military, warfare, nuclear industries or applications, espionage, use for materials or activities that are subject to the International Traffic Arms Regulations (ITAR) maintained by the United States Department of State
-    2. Guns and illegal weapons (including weapon development)
-    3. Illegal drugs and regulated/controlled substances
-    4. Operation of critical infrastructure, transportation technologies, or heavy machinery
-    5. Self-harm or harm to others, including suicide, cutting, and eating disorders
-    6. Any content intended to incite or promote violence, abuse, or any infliction of bodily harm to an individual
-
-
-
-3. Intentionally deceive or mislead others, including use of Llama 2 related to the following:
-    1. Generating, promoting, or furthering fraud or the creation or promotion of disinformation
-    2. Generating, promoting, or furthering defamatory content, including the creation of defamatory statements, images, or other content
-    3. Generating, promoting, or further distributing spam
-    4. Impersonating another individual without consent, authorization, or legal right
-    5. Representing that the use of Llama 2 or outputs are human-generated
-    6. Generating or facilitating false online engagement, including fake reviews and other means of fake online engagement
-4. Fail to appropriately disclose to end users any known dangers of your AI system
-
-Please report any violation of this Policy, software “bug,” or other problems that could lead to a violation of this Policy through one of the following means:
-
-* Reporting issues with the model: [github.com/facebookresearch/llama](http://github.com/facebookresearch/llama)
-* Reporting risky content generated by the model: [developers.facebook.com/llama_output_feedback](http://developers.facebook.com/llama_output_feedback)
-* Reporting bugs and security concerns: [facebook.com/whitehat/info](http://facebook.com/whitehat/info)
-* Reporting violations of the Acceptable Use Policy or unlicensed uses of Llama: [LlamaUseReport@meta.com](mailto:LlamaUseReport@meta.com)

+ 21 - 5
recipes/finetuning/datasets/custom_dataset.py

@@ -11,11 +11,27 @@ import itertools
 B_INST, E_INST = "[INST]", "[/INST]"
 
 def tokenize_dialog(dialog, tokenizer):
-    prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST}", add_special_tokens=False) for prompt in dialog[::2]]
-    answer_tokens = [tokenizer.encode(f"{answer['content'].strip()} {tokenizer.eos_token}", add_special_tokens=False) for answer in dialog[1::2]]
-    dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens)))
-    #Add labels, convert prompt token to -100 in order to ignore in loss function
-    labels_tokens = [len(c)*[-100,] if i % 2 == 0 else c for i,c in enumerate(dialog_tokens)]
+    if tokenizer.vocab_size >= 128000:
+        dialog_tokens = tokenizer.apply_chat_template(dialog)
+        dialog_tokens = dialog_tokens[:-4] # Remove generation prompt <|start_header_id|>assistant<|end_header_id|>\n\n
+        eot_indices = [i for i,n in enumerate(dialog_tokens) if n == 128009]
+        labels = copy.copy(dialog_tokens)
+        last_idx = 0
+        for n, idx in enumerate(eot_indices):
+            if n % 2 == 1:
+                last_idx = idx
+            else:
+                labels[last_idx:idx+1] = [-100] * (idx-last_idx+1)
+
+        dialog_tokens = [dialog_tokens]
+        labels_tokens = [labels]
+    else:
+        prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST}", add_special_tokens=False) for prompt in dialog[::2]]
+        answer_tokens = [tokenizer.encode(f"{answer['content'].strip()} {tokenizer.eos_token}", add_special_tokens=False) for answer in dialog[1::2]]
+        dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens)))
+
+        #Add labels, convert prompt token to -100 in order to ignore in loss function
+        labels_tokens = [len(c)*[-100,] if i % 2 == 0 else c for i,c in enumerate(dialog_tokens)]
 
     combined_tokens = {
         "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),

+ 6 - 6
recipes/inference/local_inference/chat_completion/chat_completion.py

@@ -8,9 +8,9 @@ import os
 import sys
 
 import torch
-from transformers import LlamaTokenizer
+from transformers import AutoTokenizer
 
-from llama_recipes.inference.chat_utils import read_dialogs_from_file, format_tokens
+from llama_recipes.inference.chat_utils import read_dialogs_from_file
 from llama_recipes.inference.model_utils import load_model, load_peft_model
 from llama_recipes.inference.safety_utils import get_safety_checker
 from accelerate.utils import is_xpu_available
@@ -65,15 +65,15 @@ def main(
     if peft_model:
         model = load_peft_model(model, peft_model)
 
-    tokenizer = LlamaTokenizer.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
     tokenizer.add_special_tokens(
         {
-         
+
             "pad_token": "<PAD>",
         }
     )
-    
-    chats = format_tokens(dialogs, tokenizer)
+
+    chats = tokenizer.apply_chat_template(dialogs)
 
     with torch.no_grad():
         for idx, chat in enumerate(chats):

+ 4 - 5
recipes/inference/local_inference/inference.py

@@ -10,7 +10,7 @@ import time
 import gradio as gr
 
 import torch
-from transformers import LlamaTokenizer
+from transformers import AutoTokenizer
 
 from llama_recipes.inference.safety_utils import get_safety_checker, AgentType
 from llama_recipes.inference.model_utils import load_model, load_peft_model
@@ -69,17 +69,16 @@ def main(
     else:
         torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
-    
+
     model = load_model(model_name, quantization, use_fast_kernels)
     if peft_model:
         model = load_peft_model(model, peft_model)
 
     model.eval()
-    
 
-    tokenizer = LlamaTokenizer.from_pretrained(model_name)
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
     tokenizer.pad_token = tokenizer.eos_token
-    
+
     batch = tokenizer(user_prompt, padding='max_length', truncation=True, max_length=max_padding_length, return_tensors="pt")
     if is_xpu_available():
         batch = {k: v.to("xpu") for k, v in batch.items()}

File diff suppressed because it is too large
+ 196 - 0
recipes/responsible_ai/CodeShieldUsageDemo.ipynb


+ 268 - 0
recipes/responsible_ai/input_output_guardrails_with_llama.ipynb

@@ -0,0 +1,268 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "id": "Tq-YFYlqWL_n"
+   },
+   "source": [
+    "# About this notebook\n",
+    "In this notebook we will deep dive into how to do input and output safety checks using Meta's Llama Guard and Llama models. Doing both input and output safety checks requires us to carefully pass the prompts and response between the models.\n",
+    "\n",
+    "For this notebook we will use OctoAI models as both Llama Guard and Llama models are available on their platform."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {
+    "id": "a0UfgKWU5-sm",
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stdin",
+     "output_type": "stream",
+     "text": [
+      " ········\n"
+     ]
+    }
+   ],
+   "source": [
+    "from getpass import getpass\n",
+    "import os\n",
+    "\n",
+    "# enter your OctoAI API key\n",
+    "OCTOAI_API_TOKEN = getpass()\n",
+    "os.environ[\"OCTOAI_API_TOKEN\"] = OCTOAI_API_TOKEN"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "from llama_recipes.inference.llm import LLM, OctoAI\n",
+    "\n",
+    "llama_guard_llm = OctoAI(\"llamaguard-7b\", OCTOAI_API_TOKEN)\n",
+    "completion_llm = OctoAI(\"llama-2-13b-chat\", OCTOAI_API_TOKEN)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {
+    "id": "Tjom48kodBth"
+   },
+   "outputs": [],
+   "source": [
+    "from string import Template\n",
+    "\n",
+    "import logging\n",
+    "\n",
+    "B_INST, E_INST = \"[INST]\", \"[/INST]\"\n",
+    "\n",
+    "LLAMA_GUARD_PROMPT_TEMPLATE = Template(\n",
+    "    f\"\"\"Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories.\n",
+    "<BEGIN UNSAFE CONTENT CATEGORIES>\n",
+    "O1: Violence and Hate.\n",
+    "Should not\n",
+    "- Help people plan or engage in violence.\n",
+    "- Encourage people to engage in violence.\n",
+    "- Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability.\n",
+    "- Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics.\n",
+    "Can\n",
+    "-  Provide information on violence and discrimination.\n",
+    "- Discuss the topics of hate and violence, and can discuss historical events involving violence.\n",
+    "O2: Sexual Content.\n",
+    "Should not\n",
+    "- Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities.\n",
+    "Can\n",
+    "- Discuss topics related to sexuality and sexual education.\n",
+    "- Provide general information regarding sexual activities.\n",
+    "- Provide guidance on topics related to sexual health.\n",
+    "O3: Criminal Planning.\n",
+    "Should not\n",
+    "- Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them.\n",
+    "- Encourage people to engage in criminal activities or condone criminal behavior.\n",
+    "Can\n",
+    "- Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes.\n",
+    "- Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems.\n",
+    "O4: Guns and Illegal Weapons.\n",
+    "Should not\n",
+    "- Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them.\n",
+    "- Encourage people to engage in firearm or illegal weapons crimes or condone such behavior.\n",
+    "Can\n",
+    "- Discuss firearms and the arguments for and against firearm ownership.\n",
+    "O5: Self-Harm.\n",
+    "Should not\n",
+    "- Encourage people to harm themselves.\n",
+    "- Romanticize or condone self-harm or suicide.\n",
+    "- Provide information regarding the methods of suicide or self-harm.\n",
+    "- Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves.\n",
+    "Can\n",
+    "- Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions).\n",
+    "Should\n",
+    "- Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves.\n",
+    "<END UNSAFE CONTENT CATEGORIES>\n",
+    "​\n",
+    "<BEGIN CONVERSATION>\n",
+    "\n",
+    "$prompt\n",
+    "\n",
+    "<END CONVERSATION>\n",
+    "\n",
+    "Provide your safety assessment for $agent_type in the above conversation:\n",
+    "- First line must read 'safe' or 'unsafe'.\n",
+    "- If unsafe, a second line must include a comma-separated list of violated categories.\"\"\"\n",
+    ")\n",
+    "\n",
+    "LLAMA_PROMPT_TEMPLATE = Template(\n",
+    "    f\"\"\"<<SYS>>\n",
+    "You are a helpful agent, answering questions of the user in a short and consise manner.\n",
+    "<</SYS>>\n",
+    "\n",
+    "{{ $user_message }}\"\"\"\n",
+    ")\n",
+    "\n",
+    "def llama_format_prompt(message):\n",
+    "    prompt = LLAMA_PROMPT_TEMPLATE.substitute(user_message=message)\n",
+    "    prompt = f\"<s>{B_INST} {prompt.strip()} {E_INST}\"\n",
+    "    return prompt\n",
+    "\n",
+    "def llama_guard_format_prompt(role, message):\n",
+    "    prompt = LLAMA_GUARD_PROMPT_TEMPLATE.substitute(prompt=message, agent_type=role)\n",
+    "    prompt = f\"<s>{B_INST} {prompt.strip()} {E_INST}\"\n",
+    "    return prompt\n",
+    "\n",
+    "def get_category_description(category):\n",
+    "    CATEGORY_MAP = {\n",
+    "      '01': 'Violence and Hate',\n",
+    "      '02': 'Sexual Content',\n",
+    "      '03': 'Criminal Planning',\n",
+    "      '04': 'Guns and Illegal Weapons',\n",
+    "      '05': 'Self-Harm'\n",
+    "    }\n",
+    "\n",
+    "    if not isinstance(category, str) or category is None:\n",
+    "        raise ValueError(\"Invalid input, expected a non-empty string\")\n",
+    "\n",
+    "    if category not in CATEGORY_MAP:\n",
+    "        raise ValueError(f\"Invalid input, unknown category: {category}\")\n",
+    "\n",
+    "    return CATEGORY_MAP[category]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {
+    "id": "KwONCw1oSroO"
+   },
+   "outputs": [],
+   "source": [
+    "def parse_llamaguard_response(message):\n",
+    "    if not isinstance(message, str) or message is None:\n",
+    "        return (False, \"Invalid input, expected a non-empty string\")\n",
+    "    message = message.strip()\n",
+    "\n",
+    "    if message == \"\":\n",
+    "        return (False, \"Invalid input, message cannot be empty\")\n",
+    "\n",
+    "    tokens = message.split(\"\\n\")\n",
+    "    if tokens[0] == 'safe':\n",
+    "        return (True, \"\")\n",
+    "    else:\n",
+    "        return (False, tokens[1])\n",
+    "\n",
+    "def check_input_guardrail(user_prompt):\n",
+    "    guardrail_prompt = llama_guard_format_prompt(\"User\", user_prompt)\n",
+    "    response = llama_guard_llm.query(guardrail_prompt)\n",
+    "    return parse_llamaguard_response(response)\n",
+    "\n",
+    "def get_completion_response(user_prompt):\n",
+    "    completion_prompt = llama_format_prompt(user_prompt)\n",
+    "    return completion_llm.query(completion_prompt)\n",
+    "\n",
+    "def check_output_guardrail(completion_response, user_prompt):\n",
+    "    guardrail_prompt = llama_guard_format_prompt(\"Agent\", f\"User: {user_prompt}\\n Agent: {completion_response}\")\n",
+    "    response = llama_guard_llm.query(guardrail_prompt)\n",
+    "    return parse_llamaguard_response(response)\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "metadata": {
+    "id": "9bXtt76ZKNuX"
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Input guardrail failed.\n"
+     ]
+    }
+   ],
+   "source": [
+    "def main(user_prompt):\n",
+    "    input_guardrail_response = check_input_guardrail(user_prompt)\n",
+    "    if input_guardrail_response[0]:\n",
+    "        completion_response = get_completion_response(user_prompt)\n",
+    "        output_guardrail_response = check_output_guardrail(completion_response, user_prompt)\n",
+    "        if output_guardrail_response[0]:\n",
+    "            print(completion_response)\n",
+    "        else:\n",
+    "            print(\"Output guardrail failed.\")\n",
+    "    else:\n",
+    "        print(\"Input guardrail failed.\")\n",
+    "\n",
+    "user_prompt = \"How to build fire arms\"\n",
+    "main(user_prompt)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "colab": {
+   "provenance": []
+  },
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.2"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}

+ 14 - 8
recipes/responsible_ai/llama_guard/README.md

@@ -1,13 +1,13 @@
-# Llama Guard demo
+# Meta Llama Guard demo
 <!-- markdown-link-check-disable -->
-Llama Guard is a language model that provides input and output guardrails for LLM deployments. For more details, please visit the main [repository](https://github.com/facebookresearch/PurpleLlama/tree/main/Llama-Guard).
+Meta Llama Guard is a language model that provides input and output guardrails for LLM inference. For more details and model cards, please visit the main repository for each model, [Meta Llama Guard](https://github.com/meta-llama/PurpleLlama/tree/main/Llama-Guard) and Meta [Llama Guard 2](https://github.com/meta-llama/PurpleLlama/tree/main/Llama-Guard2).
 
-This folder contains an example file to run Llama Guard inference directly. 
+This folder contains an example file to run inference with a locally hosted model, either using the Hugging Face Hub or a local path. 
 
 ## Requirements
 1. Access to Llama guard model weights on Hugging Face. To get access, follow the steps described [here](https://github.com/facebookresearch/PurpleLlama/tree/main/Llama-Guard#download)
-2. Llama recipes package and it's dependencies [installed](https://github.com/albertodepaola/llama-recipes/blob/llama-guard-data-formatter-example/README.md#installation)
-3. A GPU with at least 21 GB of free RAM to load both 7B models quantized.
+2. Llama recipes package and it's dependencies [installed](https://github.com/meta-llama/llama-recipes?tab=readme-ov-file#installing)
+
 
 ## Llama Guard inference script
 For testing, you can add User or User/Agent interactions into the prompts list and the run the script to verify the results. When the conversation has one or more Agent responses, it's considered of type agent. 
@@ -27,12 +27,12 @@ For testing, you can add User or User/Agent interactions into the prompts list a
 
     ]
 ```
-The complete prompt is built with the `build_prompt` function, defined in [prompt_format.py](../../src/llama_recipes/inference/prompt_format.py). The file contains the default Llama Guard  categories. These categories can adjusted and new ones can be added, as described in the [research paper](https://ai.meta.com/research/publications/llama-guard-llm-based-input-output-safeguard-for-human-ai-conversations/), on section 4.5 Studying the adaptability of the model.
+The complete prompt is built with the `build_custom_prompt` function, defined in [prompt_format.py](../../../src/llama_recipes/inference/prompt_format_utils.py). The file contains the default Meta Llama Guard categories. These categories can adjusted and new ones can be added, as described in the [research paper](https://ai.meta.com/research/publications/llama-guard-llm-based-input-output-safeguard-for-human-ai-conversations/), on section 4.5 Studying the adaptability of the model.
 <!-- markdown-link-check-enable -->
 
 To run the samples, with all the dependencies installed, execute this command:
 
-`python examples/llama_guard/inference.py`
+`python recipes/responsible_ai/llama_guard/inference.py`
 
 This is the output:
 
@@ -53,8 +53,14 @@ This is the output:
 ==================================
 ```
 
+To run it with a local model, you can use the `model_id` param in the inference script:
+
+`python recipes/responsible_ai/llama_guard/inference.py --model_id=/home/ubuntu/models/llama3/llama_guard_2-hf/ --llama_guard_version=LLAMA_GUARD_2`
+
+Note: Make sure to also add the llama_guard_version if when it does not match the default, the script allows you to run the prompt format from Meta Llama Guard 1 on Meta Llama Guard 2
+
 ## Inference Safety Checker
-When running the regular inference script with prompts, Llama Guard will be used as a safety checker on the user prompt and the model output. If both are safe, the result will be shown, else a message with the error will be shown, with the word unsafe and a comma separated list of categories infringed. Llama Guard is always loaded quantized using Hugging Face Transformers library.
+When running the regular inference script with prompts, Meta Llama Guard will be used as a safety checker on the user prompt and the model output. If both are safe, the result will be shown, else a message with the error will be shown, with the word unsafe and a comma separated list of categories infringed. Meta Llama Guard is always loaded quantized using Hugging Face Transformers library with bitsandbytes.
 
 In this case, the default categories are applied by the tokenizer, using the `apply_chat_template` method.
 

+ 28 - 21
recipes/responsible_ai/llama_guard/inference.py

@@ -2,10 +2,10 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import fire
-from transformers import AutoTokenizer, AutoModelForCausalLM
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 
 
-from llama_recipes.inference.prompt_format_utils import build_prompt, create_conversation, LLAMA_GUARD_CATEGORY
+from llama_recipes.inference.prompt_format_utils import build_default_prompt, create_conversation, LlamaGuardVersion
 from typing import List, Tuple
 from enum import Enum
 
@@ -13,20 +13,25 @@ class AgentType(Enum):
     AGENT = "Agent"
     USER = "User"
 
-def main():
+def main(
+    model_id: str = "meta-llama/LlamaGuard-7b",
+    llama_guard_version: LlamaGuardVersion = LlamaGuardVersion.LLAMA_GUARD_1
+):
     """
-    Entry point of the program for generating text using a pretrained model.
+    Entry point for Llama Guard inference sample script.
+
+    This function loads Llama Guard from Hugging Face or a local model and 
+    executes the predefined prompts in the script to showcase how to do inference with Llama Guard.
+
     Args:
-        ckpt_dir (str): The directory containing checkpoint files for the pretrained model.
-        tokenizer_path (str): The path to the tokenizer model used for text encoding/decoding.
-        temperature (float, optional): The temperature value for controlling randomness in generation.
-            Defaults to 0.6.
-        top_p (float, optional): The top-p sampling parameter for controlling diversity in generation.
-            Defaults to 0.9.
-        max_seq_len (int, optional): The maximum sequence length for input prompts. Defaults to 128.
-        max_gen_len (int, optional): The maximum length of generated sequences. Defaults to 64.
-        max_batch_size (int, optional): The maximum batch size for generating sequences. Defaults to 4.
+        model_id (str): The ID of the pretrained model to use for generation. This can be either the path to a local folder containing the model files,
+            or the repository ID of a model hosted on the Hugging Face Hub. Defaults to 'meta-llama/LlamaGuard-7b'.
+        llama_guard_version (LlamaGuardVersion): The version of the Llama Guard model to use for formatting prompts. Defaults to LLAMA_GUARD_1.
     """
+    try:
+        llama_guard_version = LlamaGuardVersion[llama_guard_version]
+    except KeyError as e:
+        raise ValueError(f"Invalid Llama Guard version '{llama_guard_version}'. Valid values are: {', '.join([lgv.name for lgv in LlamaGuardVersion])}") from e
 
     prompts: List[Tuple[List[str], AgentType]] = [
         (["<Sample user prompt>"], AgentType.USER),
@@ -41,17 +46,16 @@ def main():
 
     ]
 
-    model_id = "meta-llama/LlamaGuard-7b"
-    
-    tokenizer = AutoTokenizer.from_pretrained(model_id)
-    model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
+    quantization_config = BitsAndBytesConfig(load_in_8bit=True)
 
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config, device_map="auto")
     
     for prompt in prompts:
-        formatted_prompt = build_prompt(
+        formatted_prompt = build_default_prompt(
                 prompt[1], 
-                LLAMA_GUARD_CATEGORY, 
-                create_conversation(prompt[0]))
+                create_conversation(prompt[0]),
+                llama_guard_version)
 
 
         input = tokenizer([formatted_prompt], return_tensors="pt").to("cuda")
@@ -65,4 +69,7 @@ def main():
         print("\n==================================\n")
 
 if __name__ == "__main__":
-    fire.Fire(main)
+    try:
+        fire.Fire(main)
+    except Exception as e:
+        print(e)

+ 4 - 1
requirements.txt

@@ -1,4 +1,4 @@
-torch>=2.0.1
+torch>=2.2
 accelerate
 appdirs
 loralib
@@ -15,3 +15,6 @@ scipy
 optimum
 matplotlib
 gradio
+chardet
+openai
+typing-extensions==4.8.0

+ 3 - 4
src/llama_recipes/data/sampler.py

@@ -20,7 +20,7 @@ class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
         self.shuffle = shuffle
 
     def __iter__(self):
-        ids = np.argsort(self.lengths)
+        ids = np.argsort(self.lengths, kind='mergesort')
         if self.drop_last:
             ids = ids[:len(ids) // self.batch_size * self.batch_size]
 
@@ -47,11 +47,10 @@ class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler):
             )
         self.num_replicas = num_replicas
         self.rank = rank
-        
+
     def __iter__(self):
         max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas
         return islice(self.batch_sampler, self.rank, max_length, self.num_replicas)
-         
+
     def __len__(self):
         return len(self.batch_sampler) // self.num_replicas
-            

+ 7 - 13
src/llama_recipes/finetuning.py

@@ -2,7 +2,6 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import os
-from pkg_resources import packaging
 
 import dataclasses
 import fire
@@ -18,8 +17,8 @@ from torch.distributed.fsdp import (
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
 from transformers import (
+    AutoTokenizer,
     LlamaForCausalLM,
-    LlamaTokenizer,
     LlamaConfig,
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
@@ -51,7 +50,7 @@ from llama_recipes.utils.train_utils import (
 from accelerate.utils import is_xpu_available
 
 def setup_wandb(train_config, fsdp_config, **kwargs):
-    try: 
+    try:
         import wandb
     except ImportError:
         raise ImportError(
@@ -97,7 +96,7 @@ def main(**kwargs):
 
     if train_config.use_wandb:
         if not train_config.enable_fsdp or rank==0:
-            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)    
+            wandb_run = setup_wandb(train_config, fsdp_config, **kwargs)
 
     # Load the pre-trained model and setup its configuration
     use_cache = False if train_config.enable_fsdp else None
@@ -108,11 +107,6 @@ def main(**kwargs):
         model alone would consume 2+TB cpu mem (70 * 4 * 8). This will add some comms
         overhead and currently requires latest nightly.
         """
-        v = packaging.version.parse(torch.__version__)
-        verify_latest_nightly = v.is_devrelease and v.dev >= 20230701
-        if not verify_latest_nightly:
-            raise Exception("latest pytorch nightly build is required to run with low_cpu_fsdp config, "
-                            "please install latest nightly.")
         if rank == 0:
             model = LlamaForCausalLM.from_pretrained(
                 train_config.model_name,
@@ -137,7 +131,7 @@ def main(**kwargs):
         )
 
     # Load the tokenizer and add special tokens
-    tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
+    tokenizer = AutoTokenizer.from_pretrained(train_config.model_name if train_config.tokenizer_name is None else train_config.tokenizer_name)
     tokenizer.pad_token_id = tokenizer.eos_token_id
 
     # If there is a mismatch between tokenizer vocab size and embedding matrix, 
@@ -163,12 +157,12 @@ def main(**kwargs):
         if wandb_run:
             wandb_run.config.update(peft_config)
 
-        
+
     hsdp_device_mesh = None
     if fsdp_config.hsdp and fsdp_config.sharding_strategy == ShardingStrategy.HYBRID_SHARD:
         hsdp_device_mesh = hsdp_device_mesh(replica_group_size=fsdp_config.replica_group_size, sharding_group_size=fsdp_config.sharding_group_size)
         print("HSDP device mesh is ready")
-        
+
     #setting up FSDP if enable_fsdp is enabled
     if train_config.enable_fsdp:
         if not train_config.use_peft and train_config.freeze_layers:
@@ -177,7 +171,7 @@ def main(**kwargs):
 
         mixed_precision_policy, wrapping_policy = get_policies(fsdp_config, rank)
         my_auto_wrapping_policy = fsdp_auto_wrap_policy(model, LlamaDecoderLayer)
-        
+
         device_id = 0
         if is_xpu_available():
             device_id = torch.xpu.current_device()

+ 0 - 56
src/llama_recipes/inference/chat_utils.py

@@ -2,62 +2,6 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 import json
-from typing import List, Literal, TypedDict
-
-
-Role = Literal["user", "assistant"]
-
-
-class Message(TypedDict):
-    role: Role
-    content: str
-
-
-Dialog = List[Message]
-
-B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
-def format_tokens(dialogs, tokenizer):
-    prompt_tokens = []
-    for dialog in dialogs:
-        if dialog[0]["role"] == "system":
-            dialog = [
-            {
-                "role": dialog[1]["role"],
-                "content": B_SYS
-                + dialog[0]["content"]
-                + E_SYS
-                + dialog[1]["content"],
-            }
-        ] + dialog[2:]
-        assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
-            [msg["role"] == "assistant" for msg in dialog[1::2]]
-        ), (
-            "model only supports 'system','user' and 'assistant' roles, "
-            "starting with user and alternating (u/a/u/a/u...)"
-        )
-        """
-        Please verify that your tokenizer support adding "[INST]", "[/INST]" to your inputs.
-        Here, we are adding it manually.
-        """
-        dialog_tokens: List[int] = sum(
-            [
-                tokenizer.encode(
-                    f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
-                ) + [tokenizer.eos_token_id]
-                for prompt, answer in zip(dialog[::2], dialog[1::2])
-            ],
-            [],
-        )
-        assert (
-            dialog[-1]["role"] == "user"
-        ), f"Last message must be from user, got {dialog[-1]['role']}"
-        dialog_tokens += tokenizer.encode(
-            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
-        )
-        prompt_tokens.append(dialog_tokens)
-    return prompt_tokens
-        
 
 def read_dialogs_from_file(file_path):
     with open(file_path, 'r') as file:

+ 22 - 27
src/llama_recipes/inference/llm.py

@@ -14,14 +14,12 @@ from abc import ABC, abstractmethod
 from typing import Callable
 
 import openai
-from langchain_together import Together
-
 from typing_extensions import override
 
-
 NUM_LLM_RETRIES = 10
-
 MAX_TOKENS = 1000
+TEMPERATURE = 0.1
+TOP_P = 0.9
 
 LOG: logging.Logger = logging.getLogger(__name__)
 
@@ -160,38 +158,35 @@ class ANYSCALE(LLM):
             "HuggingFaceH4/zephyr-7b-beta",
         ]
 
+class OctoAI(LLM):
+    """Accessing OctoAI"""
 
-class TOGETHER(LLM):
-    """Accessing TOGETHER"""
+    def __init__(self, model: str, api_key: str) -> None:
+        super().__init__(model, api_key)
+        self.client = openai.OpenAI(base_url="https://text.octoai.run/v1", api_key=api_key)  # noqa
 
     @override
     def query(self, prompt: str) -> str:
-        llm = Together(
+        # Best-level effort to suppress openai log-spew.
+        # Likely not work well in multi-threaded environment.
+        level = logging.getLogger().level
+        logging.getLogger().setLevel(logging.WARNING)
+        response = self.client.chat.completions.create(
             model=self.model,
-            temperature=0.75,
-            top_p=1,
+            messages=[
+                {"role": "system", "content": "You are a helpful assistant. Keep your responses limited to one short paragraph if possible."},
+                {"role": "user", "content": prompt},
+            ],
             max_tokens=MAX_TOKENS,
-            together_api_key=self.api_key,
+            temperature=TEMPERATURE,
+            top_p=TOP_P,
         )
-        response = llm(prompt)
-        return "".join(response)
+        logging.getLogger().setLevel(level)
+        return response.choices[0].message.content
 
     @override
     def valid_models(self) -> list[str]:
         return [
-            "mistralai/Mistral-7B-v0.1",
-            "lmsys/vicuna-7b-v1.5",
-            "togethercomputer/CodeLlama-7b",
-            "togethercomputer/CodeLlama-7b-Python",
-            "togethercomputer/CodeLlama-7b-Instruct",
-            "togethercomputer/CodeLlama-13b",
-            "togethercomputer/CodeLlama-13b-Python",
-            "togethercomputer/CodeLlama-13b-Instruct",
-            "togethercomputer/falcon-40b",
-            "togethercomputer/llama-2-7b",
-            "togethercomputer/llama-2-7b-chat",
-            "togethercomputer/llama-2-13b",
-            "togethercomputer/llama-2-13b-chat",
-            "togethercomputer/llama-2-70b",
-            "togethercomputer/llama-2-70b-chat",
+            "llamaguard-7b",
+            "llama-2-13b-chat",
         ]

+ 117 - 18
src/llama_recipes/inference/prompt_format_utils.py

@@ -6,6 +6,10 @@ from string import Template
 from typing import List
 from enum import Enum
 
+class LlamaGuardVersion(Enum):
+    LLAMA_GUARD_1 = "Llama Guard 1"
+    LLAMA_GUARD_2 = "Llama Guard 2"
+
 class AgentType(Enum):
     AGENT = "Agent"
     USER = "User"
@@ -20,7 +24,8 @@ class ConversationTurn:
     message: str
     agent_type: AgentType
 
-CATEGORY_SHORT_NAME_PREFIX = "O"
+LLAMA_GUARD_1_CATEGORY_SHORT_NAME_PREFIX = "O"
+LLAMA_GUARD_2_CATEGORY_SHORT_NAME_PREFIX = "S"
 
 PROMPT_TASK = "Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories."
 
@@ -38,14 +43,21 @@ $conversations
 
 <END CONVERSATION>
 """
-PROMPT_INSTRUCTIONS = """
+
+PROMPT_INSTRUCTIONS_1 = """
 Provide your safety assessment for $agent_type in the above conversation:
 - First line must read 'safe' or 'unsafe'.
 - If unsafe, a second line must include a comma-separated list of violated categories."""
 
-PROMPT_TEMPLATE = Template(f"[INST] {PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS} [/INST]")
+PROMPT_INSTRUCTIONS_2 = """
+Provide your safety assessment for ONLY THE LAST $agent_type message in the above conversation:
+ - First line must read 'safe' or 'unsafe'.
+ - If unsafe, a second line must include a comma-separated list of violated categories."""
+
+PROMPT_TEMPLATE_1 = Template(f"[INST] {PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS_1} [/INST]")
+PROMPT_TEMPLATE_2 = Template(f"[INST] {PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS_2} [/INST]")
 
-LLAMA_GUARD_CATEGORY = [
+LLAMA_GUARD_1_CATEGORY = [
     SafetyCategory(
         "Violence and Hate.",
         """Should not
@@ -100,6 +112,62 @@ Should
  - Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves."""),
 ]
 
+
+LLAMA_GUARD_2_CATEGORY = [
+    SafetyCategory(
+        "Violent Crimes. ",
+        """AI models should not create content that enables, encourages, or excuses the commission of violent crimes. Examples of violent crimes include, but are not limited to:
+ - Unlawful violence toward people (ex: terrorism, genocide, murder, hate crimes, child abuse, assault, battery, kidnapping)
+ - Unlawful violence toward animals (ex: animal abuse)"""),
+    SafetyCategory(
+        "Non-Violent Crimes. ",
+        """AI models should not create content that enables, encourages, or excuses the commission of non-violent crimes. Examples of non-violent crimes include, but are not limited to:
+ - Financial crimes (ex: fraud, scams, money laundering)
+ - Property crimes (ex: burglary, robbery, arson, vandalism)
+ - Drug crimes (ex: creating or trafficking narcotics)
+ - Weapons crimes (ex: producing unlicensed firearms)
+ - Cyber crimes (ex: hacking, spyware, malware)"""),
+ SafetyCategory(
+    "Sex Crimes. ",
+    """AI models should not create content that enables, encourages, or excuses the commission of sex crimes. Examples of sex crimes include, but are not limited to:
+ - Human trafficking
+ - Sexual assault (ex: rape)
+ - Sexual harassment (ex: groping)
+ - Lewd conduct"""),
+ SafetyCategory(
+    "Child Exploitation. ",
+    "AI models should not create content that depicts child nudity or that enables, encourages, excuses, or depicts the sexual abuse of children."),
+ SafetyCategory(
+    "Specialized Advice. ",
+    "AI models should not create content that contains specialized financial, medical, or legal advice."),
+ SafetyCategory(
+    "Privacy. ",
+    "AI models should not create content that contains sensitive, personal information about private individuals."),
+ SafetyCategory(
+    "Intellectual Property. ",
+    "AI models should not create content that violates the intellectual property rights of third parties."),
+ SafetyCategory(
+    "Indiscriminate Weapons. ",
+    """AI models should not create content that enables the creation of weapons of mass destruction. Examples of such weapons include, but are not limited to:
+ - Chemical weapons (ex: nerve gas)
+ - Biological weapons (ex: anthrax)
+ - Radiological weapons (ex: salted bombs)
+ - Nuclear weapons (ex: atomic warheads)
+ - High-yield explosive weapons (ex: cluster munitions)"""),
+ SafetyCategory(
+    "Hate. ",
+    "AI models should not create content that is hateful toward people on the basis of their protected characteristics (race, color, ethnicity, national origin, disability, religious affiliation, caste, sexual orientation, sex, gender identity and serious disease). AI models should not create content that perpetuates negative stereotypes."),
+ SafetyCategory(
+    "Self-Harm. ",
+    """AI models should not create content that enables, encourages, or excuses acts of intentional self-harm. Examples of such actions include, but are not limited to:
+ - Suicide
+ - Self-injury (ex: cutting)
+ - Disordered eating"""),
+ SafetyCategory(
+    "Sexual Content. ",
+    "AI models should not create depictions of nude adults or content that contains erotic descriptions or explicit depictions of sex acts."),
+]
+
 def create_conversation(messges: List[str]) -> List[ConversationTurn]:
     conversations = []
     for i, messge in enumerate(messges):
@@ -107,28 +175,59 @@ def create_conversation(messges: List[str]) -> List[ConversationTurn]:
 
     return conversations
 
-def build_prompt(
+def build_default_prompt(
+        agent_type: AgentType, 
+        conversations: List[ConversationTurn], 
+        llama_guard_version: LlamaGuardVersion = LlamaGuardVersion.LLAMA_GUARD_2):
+    
+    if llama_guard_version == LlamaGuardVersion.LLAMA_GUARD_2:
+        categories = LLAMA_GUARD_2_CATEGORY
+        category_short_name_prefix = LLAMA_GUARD_2_CATEGORY_SHORT_NAME_PREFIX
+        prompt_template = PROMPT_TEMPLATE_2
+    else:
+        categories = LLAMA_GUARD_1_CATEGORY
+        category_short_name_prefix = LLAMA_GUARD_1_CATEGORY_SHORT_NAME_PREFIX
+        prompt_template = PROMPT_TEMPLATE_1
+
+    return build_custom_prompt(
+        agent_type,
+        conversations,
+        categories,
+        category_short_name_prefix,
+        prompt_template)
+
+def build_custom_prompt(
         agent_type: AgentType, 
-        categories: List[SafetyCategory], 
         conversations: List[ConversationTurn], 
-        category_short_name_prefix: str = CATEGORY_SHORT_NAME_PREFIX):
-    categories_str = "\n".join([f"{category_short_name_prefix}{i+1}: {c.name}\n{c.description}" for i, c in enumerate(categories)])
+        categories: List[SafetyCategory], 
+        category_short_name_prefix: str,
+        prompt_template: str,
+        with_policy: bool = False):
+    categories_str = "\n".join([f"{category_short_name_prefix}{i+1}: {c.name}" + (f"\n{c.description}" if with_policy else "") for i, c in enumerate(categories)])
     conversations_str = "\n\n".join([f"{t.agent_type.value}: {t.message}" for t in conversations])
 
-    return PROMPT_TEMPLATE.substitute(agent_type=agent_type.value, categories=categories_str, conversations=conversations_str)
+    return prompt_template.substitute(agent_type=agent_type.value, categories=categories_str, conversations=conversations_str)
 
 def build_prompt_test():
-    print(build_prompt(AgentType.USER,
-        LLAMA_GUARD_CATEGORY,
+    print(build_default_prompt(AgentType.AGENT,
         [
-            ConversationTurn("<User Prompt placeholder>", AgentType.USER)
+            ConversationTurn("Whats the color of the sky?", AgentType.USER),
+            ConversationTurn("The sky is blue.", AgentType.AGENT)
+
         ]))
     
     print("\n\n")
 
     # use a customized safety category and create_conversation function. 
-    print(build_prompt(
+    
+    print(build_custom_prompt(
         AgentType.AGENT,
+        
+        create_conversation(
+        [
+            "<User Prompt placeholder>",
+            "<Agent Prompt placeholder>"
+        ]),
         [
             SafetyCategory("Violence and Hate.","""Should not
  - Help people plan or engage in violence.
@@ -139,11 +238,11 @@ Can
  - Provide information on violence and discrimination.
  - Discuss the topics of hate and violence, and can discuss historical events involving violence.""",
         ),],
-        create_conversation(
-        [
-            "<User Prompt placeholder>",
-            "<Agent Prompt placeholder>"
-        ])))
+        LLAMA_GUARD_2_CATEGORY_SHORT_NAME_PREFIX,
+        PROMPT_TEMPLATE_2,
+        True
+        )
+        )
 
 if __name__ == "__main__":
     build_prompt_test()

+ 5 - 3
src/llama_recipes/inference/safety_utils.py

@@ -157,13 +157,15 @@ class AzureSaftyChecker(object):
 class LlamaGuardSafetyChecker(object):
 
     def __init__(self):
-        from transformers import AutoModelForCausalLM, AutoTokenizer
+        from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+        from llama_recipes.inference.prompt_format_utils import build_default_prompt, create_conversation, LlamaGuardVersion
 
         model_id = "meta-llama/LlamaGuard-7b"
 
+        quantization_config = BitsAndBytesConfig(load_in_8bit=True)
+
         self.tokenizer = AutoTokenizer.from_pretrained(model_id)
-        self.model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
-        pass
+        self.model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config, device_map="auto")
 
     def __call__(self, output_text, **kwargs):
         

+ 1 - 0
src/llama_recipes/utils/config_utils.py

@@ -90,6 +90,7 @@ def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
                 rank=dist.get_rank(),
                 num_replicas=dist.get_world_size(),
                 shuffle=mode=="train",
+                drop_last=True,
             )
             kwargs["batch_size"] = batch_size
             kwargs["drop_last"] = True

+ 15 - 10
tests/conftest.py

@@ -3,21 +3,26 @@
 
 import pytest
 
-from transformers import LlamaTokenizer
+from transformers import AutoTokenizer
 
 ACCESS_ERROR_MSG = "Could not access tokenizer at 'meta-llama/Llama-2-7b-hf'. Did you log into huggingface hub and provided the correct token?"
+LLAMA_VERSIONS = ["meta-llama/Llama-2-7b-hf", "meta-llama/Llama-3-8b-hf"]
+
+@pytest.fixture(params=LLAMA_VERSIONS)
+def llama_version(request):
+    return request.param
 
 
 @pytest.fixture(scope="module")
-def llama_tokenizer():
-    return LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+def llama_tokenizer(request):
+    return {k: AutoTokenizer.from_pretrained(k) for k in LLAMA_VERSIONS}
 
 
 @pytest.fixture
-def setup_tokenizer(llama_tokenizer):
+def setup_tokenizer(llama_tokenizer, llama_version):
     def _helper(tokenizer_mock):
         #Align with Llama 2 tokenizer
-        tokenizer_mock.from_pretrained.return_value = llama_tokenizer
+        tokenizer_mock.from_pretrained.return_value = llama_tokenizer[llama_version]
 
     return _helper
 
@@ -27,21 +32,21 @@ def pytest_addoption(parser):
         "--unskip-missing-tokenizer",
         action="store_true",
         default=False, help="disable skip missing tokenizer")
-    
+
 def pytest_configure(config):
     config.addinivalue_line("markers", "skip_missing_tokenizer: skip if tokenizer is unavailable")
 
-    
+
 def pytest_collection_modifyitems(config, items):
     if config.getoption("--unskip-missing-tokenizer"):
         return
-    
+
     try:
-        LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+        AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
         tokenizer_available = True
     except OSError:
         tokenizer_available = False
-    
+
     skip_missing_tokenizer = pytest.mark.skip(reason=ACCESS_ERROR_MSG)
     for item in items:
         if "skip_missing_tokenizer" in item.keywords and not tokenizer_available:

+ 35 - 20
tests/datasets/test_custom_dataset.py

@@ -6,32 +6,50 @@ from unittest.mock import patch
 
 from transformers import LlamaTokenizer
 
-def check_padded_entry(batch):
+EXPECTED_RESULTS={
+    "meta-llama/Llama-2-7b-hf":{
+        "example_1": "[INST] Who made Berlin [/INST] dunno",
+        "example_2": "[INST] Quiero preparar una pizza de pepperoni, puedes darme los pasos para hacerla? [/INST] Claro!",
+    },
+    "meta-llama/Llama-3-8b-hf":{
+        "example_1": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nWho made Berlin<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\ndunno<|eot_id|><|end_of_text|>",
+        "example_2": "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\nHow to start learning guitar and become a master at it?",
+    },
+}
+
+def check_padded_entry(batch, tokenizer):
     seq_len = sum(batch["attention_mask"][0])
     assert seq_len < len(batch["attention_mask"][0])
 
+    if tokenizer.vocab_size >= 128000:
+        END_OF_TEXT_ID = 128009
+    else:
+        END_OF_TEXT_ID = tokenizer.eos_token_id
+
     assert batch["labels"][0][0] == -100
-    assert batch["labels"][0][seq_len-1] == 2
+    assert batch["labels"][0][seq_len-1] == END_OF_TEXT_ID
     assert batch["labels"][0][-1] == -100
-    assert batch["input_ids"][0][0] == 1
-    assert batch["input_ids"][0][-1] == 2
+    assert batch["input_ids"][0][0] == tokenizer.bos_token_id
+    assert batch["input_ids"][0][-1] == tokenizer.eos_token_id
 
 
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer, llama_version):
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
 
+    skip_special_tokens = llama_version == "meta-llama/Llama-2-7b-hf"
+
     kwargs = {
         "dataset": "custom_dataset",
-        "model_name": "meta-llama/Llama-2-7b-hf",
-        "custom_dataset.file": "examples/custom_dataset.py",
+        "model_name": llama_version,
+        "custom_dataset.file": "recipes/finetuning/datasets/custom_dataset.py",
         "custom_dataset.train_split": "validation",
         "batch_size_training": 2,
         "val_batch_size": 4,
@@ -53,34 +71,31 @@ def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
 
     it = iter(eval_dataloader)
     batch = next(it)
-    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
-    EXPECTED_STRING = "[INST] Who made Berlin [/INST] dunno"
-    assert STRING.startswith(EXPECTED_STRING)
+    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=skip_special_tokens)
+    assert STRING.startswith(EXPECTED_RESULTS[llama_version]["example_1"])
 
     assert batch["input_ids"].size(0) == 4
     assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())
 
-    check_padded_entry(batch)
+    check_padded_entry(batch, tokenizer)
 
     it = iter(train_dataloader)
-    for _ in range(5):
-        next(it)
+    next(it)
 
     batch = next(it)
-    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
-    EXPECTED_STRING = "[INST] How do I initialize a Typescript project using npm and git? [/INST] # Initialize a new NPM project"
-    assert STRING.startswith(EXPECTED_STRING)
+    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=skip_special_tokens)
+    assert STRING.startswith(EXPECTED_RESULTS[llama_version]["example_2"])
 
     assert batch["input_ids"].size(0) == 2
     assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())
 
-    check_padded_entry(batch)
+    check_padded_entry(batch, tokenizer)
 
 
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
 def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, train, mocker):
@@ -90,7 +105,7 @@ def test_unknown_dataset_error(step_lr, optimizer, tokenizer, get_model, train,
 
     kwargs = {
         "dataset": "custom_dataset",
-        "custom_dataset.file": "examples/custom_dataset.py:get_unknown_dataset",
+        "custom_dataset.file": "recipes/finetuning/datasets/custom_dataset.py:get_unknown_dataset",
         "batch_size_training": 1,
         "use_peft": False,
         }

+ 19 - 9
tests/datasets/test_grammar_datasets.py

@@ -4,23 +4,32 @@
 import pytest
 from unittest.mock import patch
 
-from transformers import LlamaTokenizer
 
+EXPECTED_RESULTS = {
+    "meta-llama/Llama-2-7b-hf":{
+        "label": 1152,
+        "pos": 31,
+    },
+    "meta-llama/Llama-3-8b-hf":{
+        "label": 40,
+        "pos": 26,
+    },
+}
 
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version):
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
 
     BATCH_SIZE = 8
     kwargs = {
-        "model_name": "meta-llama/Llama-2-7b-hf",
+        "model_name": llama_version,
         "batch_size_training": BATCH_SIZE,
         "val_batch_size": 1,
         "use_peft": False,
@@ -48,9 +57,10 @@ def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker
     assert "input_ids" in batch.keys()
     assert "attention_mask" in batch.keys()
 
-    assert batch["labels"][0][31] == -100
-    assert batch["labels"][0][32] == 1152
+    assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]-1] == -100
+    assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]] == EXPECTED_RESULTS[llama_version]["label"]
 
-    assert batch["input_ids"][0][0] == 1
-    assert batch["labels"][0][-1] == 2
-    assert batch["input_ids"][0][-1] == 2
+    token = args[3]
+    assert batch["input_ids"][0][0] == token.bos_token_id
+    assert batch["labels"][0][-1] == token.eos_token_id
+    assert batch["input_ids"][0][-1] == token.eos_token_id

+ 19 - 8
tests/datasets/test_samsum_datasets.py

@@ -5,21 +5,31 @@ import pytest
 from functools import partial
 from unittest.mock import patch
 
+EXPECTED_RESULTS = {
+    "meta-llama/Llama-2-7b-hf":{
+        "label": 8432,
+        "pos": 242,
+    },
+    "meta-llama/Llama-3-8b-hf":{
+        "label": 2250,
+        "pos": 211,
+    },
+}
 
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer, llama_version):
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
 
     BATCH_SIZE = 8
     kwargs = {
-        "model_name": "meta-llama/Llama-2-7b-hf",
+        "model_name": llama_version,
         "batch_size_training": BATCH_SIZE,
         "val_batch_size": 1,
         "use_peft": False,
@@ -34,6 +44,7 @@ def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
     args, kwargs = train.call_args
     train_dataloader = args[1]
     eval_dataloader = args[2]
+    token = args[3]
 
     VAL_SAMPLES = 818
     TRAIN_SAMPLES = 14732
@@ -47,9 +58,9 @@ def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
     assert "input_ids" in batch.keys()
     assert "attention_mask" in batch.keys()
 
-    assert batch["labels"][0][268] == -100
-    assert batch["labels"][0][269] == 319
+    assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]-1] == -100
+    assert batch["labels"][0][EXPECTED_RESULTS[llama_version]["pos"]] == EXPECTED_RESULTS[llama_version]["label"]
 
-    assert batch["input_ids"][0][0] == 1
-    assert batch["labels"][0][-1] == 2
-    assert batch["input_ids"][0][-1] == 2
+    assert batch["input_ids"][0][0] == token.bos_token_id
+    assert batch["labels"][0][-1] == token.eos_token_id
+    assert batch["input_ids"][0][-1] == token.eos_token_id

+ 21 - 11
tests/test_batching.py

@@ -4,20 +4,30 @@
 import pytest
 from unittest.mock import patch
 
+EXPECTED_SAMPLE_NUMBER ={
+    "meta-llama/Llama-2-7b-hf": {
+        "train": 96,
+        "eval": 42,
+    },
+    "meta-llama/Llama-3-8b-hf": {
+        "train": 79,
+        "eval": 34,
+    }
+}
 
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+def test_packing(step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version):
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
 
     kwargs = {
-        "model_name": "meta-llama/Llama-2-7b-hf",
+        "model_name": llama_version,
         "batch_size_training": 8,
         "val_batch_size": 1,
         "use_peft": False,
@@ -33,8 +43,8 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_
     train_dataloader = args[1]
     eval_dataloader = args[2]
 
-    assert len(train_dataloader) == 96
-    assert len(eval_dataloader) == 42
+    assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"]
+    assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"]
 
     batch = next(iter(train_dataloader))
 
@@ -49,7 +59,7 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_
 
 @pytest.mark.skip_missing_tokenizer
 @patch('llama_recipes.finetuning.train')
-@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.AutoTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
@@ -57,13 +67,13 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_
 @patch('llama_recipes.finetuning.FSDP')
 @patch('llama_recipes.finetuning.torch.distributed.is_initialized')
 @patch('llama_recipes.utils.config_utils.dist')
-def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer):
+def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer, llama_version):
     import os
     from llama_recipes.finetuning import main
 
     setup_tokenizer(tokenizer)
 
-    rank = 0
+    rank = 1
     os.environ['LOCAL_RANK'] = f'{rank}'
     os.environ['RANK'] = f'{rank}'
     os.environ['WORLD_SIZE'] = '2'
@@ -71,7 +81,7 @@ def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimiz
     os.environ['MASTER_PORT'] = '12345'
 
     kwargs = {
-        "model_name": "meta-llama/Llama-2-7b-hf",
+        "model_name": llama_version,
         "batch_size_training": 8,
         "val_batch_size": 1,
         "use_peft": False,
@@ -92,5 +102,5 @@ def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimiz
     train_dataloader = args[1]
     eval_dataloader = args[2]
 
-    assert len(train_dataloader) == 96 //2
-    assert len(eval_dataloader) == 42 //2
+    assert len(train_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["train"] //2
+    assert len(eval_dataloader) == EXPECTED_SAMPLE_NUMBER[llama_version]["eval"] //2

+ 155 - 0
tests/test_chat_completion.py

@@ -0,0 +1,155 @@
+import sys
+from pathlib import Path
+from typing import List, Literal, TypedDict
+from unittest.mock import patch
+
+import pytest
+import torch
+from llama_recipes.inference.chat_utils import read_dialogs_from_file
+
+ROOT_DIR = Path(__file__).parents[1]
+CHAT_COMPLETION_DIR = ROOT_DIR / "recipes/inference/local_inference/chat_completion/"
+
+sys.path = [CHAT_COMPLETION_DIR.as_posix()] + sys.path
+
+Role = Literal["user", "assistant"]
+
+
+class Message(TypedDict):
+    role: Role
+    content: str
+
+
+Dialog = List[Message]
+
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+
+def _encode_header(message, tokenizer):
+    tokens = []
+    tokens.extend(tokenizer.encode("<|start_header_id|>"))
+    tokens.extend(tokenizer.encode(message["role"]))
+    tokens.extend(tokenizer.encode("<|end_header_id|>"))
+    tokens.extend(tokenizer.encode("\n\n"))
+    return tokens
+
+
+def _encode_message(message, tokenizer):
+    tokens = _encode_header(message, tokenizer)
+    tokens.extend(tokenizer.encode(message["content"].strip()))
+    tokens.extend(tokenizer.encode("<|eot_id|>"))
+    return tokens
+
+
+def _format_dialog(dialog, tokenizer):
+    tokens = []
+    tokens.extend(tokenizer.encode("<|begin_of_text|>"))
+    for msg in dialog:
+        tokens.extend(_encode_message(msg, tokenizer))
+    tokens.extend(_encode_header({"role": "assistant", "content": ""}, tokenizer))
+    return tokens
+
+
+def _format_tokens_llama3(dialogs, tokenizer):
+    return [_format_dialog(dialog, tokenizer) for dialog in dialogs]
+
+
+def _format_tokens_llama2(dialogs, tokenizer):
+    prompt_tokens = []
+    for dialog in dialogs:
+        if dialog[0]["role"] == "system":
+            dialog = [
+                {
+                    "role": dialog[1]["role"],
+                    "content": B_SYS
+                    + dialog[0]["content"]
+                    + E_SYS
+                    + dialog[1]["content"],
+                }
+            ] + dialog[2:]
+        assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
+            [msg["role"] == "assistant" for msg in dialog[1::2]]
+        ), (
+            "model only supports 'system','user' and 'assistant' roles, "
+            "starting with user and alternating (u/a/u/a/u...)"
+        )
+        """
+        Please verify that your tokenizer support adding "[INST]", "[/INST]" to your inputs.
+        Here, we are adding it manually.
+        """
+        dialog_tokens: List[int] = sum(
+            [
+                tokenizer.encode(
+                    f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
+                )
+                + [tokenizer.eos_token_id]
+                for prompt, answer in zip(dialog[::2], dialog[1::2])
+            ],
+            [],
+        )
+        assert (
+            dialog[-1]["role"] == "user"
+        ), f"Last message must be from user, got {dialog[-1]['role']}"
+        dialog_tokens += tokenizer.encode(
+            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
+        )
+        prompt_tokens.append(dialog_tokens)
+    return prompt_tokens
+
+
+@pytest.mark.skip_missing_tokenizer
+@patch("chat_completion.AutoTokenizer")
+@patch("chat_completion.load_model")
+def test_chat_completion(
+    load_model, tokenizer, setup_tokenizer, llama_tokenizer, llama_version
+):
+    from chat_completion import main
+
+    setup_tokenizer(tokenizer)
+
+    kwargs = {
+        "prompt_file": (CHAT_COMPLETION_DIR / "chats.json").as_posix(),
+    }
+
+    main(llama_version, **kwargs)
+
+    dialogs = read_dialogs_from_file(kwargs["prompt_file"])
+    format_tokens = (
+        _format_tokens_llama2
+        if llama_version == "meta-llama/Llama-2-7b-hf"
+        else _format_tokens_llama3
+    )
+
+    REF_RESULT = format_tokens(dialogs, llama_tokenizer[llama_version])
+
+    assert all(
+        (
+            load_model.return_value.generate.mock_calls[0 * 4][2]["input_ids"].cpu()
+            == torch.tensor(REF_RESULT[0]).long()
+        ).tolist()
+    )
+    assert all(
+        (
+            load_model.return_value.generate.mock_calls[1 * 4][2]["input_ids"].cpu()
+            == torch.tensor(REF_RESULT[1]).long()
+        ).tolist()
+    )
+    assert all(
+        (
+            load_model.return_value.generate.mock_calls[2 * 4][2]["input_ids"].cpu()
+            == torch.tensor(REF_RESULT[2]).long()
+        ).tolist()
+    )
+    assert all(
+        (
+            load_model.return_value.generate.mock_calls[3 * 4][2]["input_ids"].cpu()
+            == torch.tensor(REF_RESULT[3]).long()
+        ).tolist()
+    )
+    assert all(
+        (
+            load_model.return_value.generate.mock_calls[4 * 4][2]["input_ids"].cpu()
+            == torch.tensor(REF_RESULT[4]).long()
+        ).tolist()
+    )

+ 26 - 19
tests/test_finetuning.py

@@ -21,17 +21,19 @@ def get_fake_dataset():
         "labels":[1],
         }]
 
-
+@patch('llama_recipes.finetuning.torch.cuda.is_available')
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
+@pytest.mark.parametrize("cuda_is_available", [True, False])
+def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
     kwargs = {"run_validation": False}
 
     get_dataset.return_value = get_fake_dataset()
+    cuda.return_value = cuda_is_available
 
     main(**kwargs)
 
@@ -44,23 +46,26 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge
     assert isinstance(train_dataloader, DataLoader)
     assert eval_dataloader is None
 
-    if torch.cuda.is_available():
+    if cuda_is_available:
         assert get_model.return_value.to.call_count == 1
         assert get_model.return_value.to.call_args.args[0] == "cuda"
     else:
         assert get_model.return_value.to.call_count == 0
 
 
+@patch('llama_recipes.finetuning.torch.cuda.is_available')
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
+@pytest.mark.parametrize("cuda_is_available", [True, False])
+def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
     kwargs = {"run_validation": True}
 
     get_dataset.return_value = get_fake_dataset()
+    cuda.return_value = cuda_is_available
 
     main(**kwargs)
 
@@ -72,40 +77,42 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer,
     assert isinstance(train_dataloader, DataLoader)
     assert isinstance(eval_dataloader, DataLoader)
 
-    if torch.cuda.is_available():
+    if cuda_is_available:
         assert get_model.return_value.to.call_count == 1
         assert get_model.return_value.to.call_args.args[0] == "cuda"
     else:
         assert get_model.return_value.to.call_count == 0
 
-
+@patch('llama_recipes.finetuning.torch.cuda.is_available')
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.generate_peft_config')
 @patch('llama_recipes.finetuning.get_peft_model')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train):
+@pytest.mark.parametrize("cuda_is_available", [True, False])
+def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train, cuda, cuda_is_available):
     kwargs = {"use_peft": True}
 
     get_dataset.return_value = get_fake_dataset()
+    cuda.return_value = cuda_is_available
 
     main(**kwargs)
 
-    if torch.cuda.is_available():
-        assert get_model.return_value.to.call_count == 1
-        assert get_model.return_value.to.call_args.args[0] == "cuda"
+    if cuda_is_available:
+        assert get_peft_model.return_value.to.call_count == 1
+        assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
     else:
-        assert get_model.return_value.to.call_count == 0
-    
+        assert get_peft_model.return_value.to.call_count == 0
+
     assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
 
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.get_peft_model')
 @patch('llama_recipes.finetuning.StepLR')
@@ -113,11 +120,11 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
     kwargs = {"weight_decay": 0.01}
 
     get_dataset.return_value = get_fake_dataset()
-    
+
     model = mocker.MagicMock(name="Model")
     model.parameters.return_value = [torch.ones(1,1)]
 
-    get_model.return_value = model 
+    get_model.return_value = model
 
     main(**kwargs)
 
@@ -134,7 +141,7 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.AutoTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')