
Merge branch 'main' into ssdp

Hamid Shojanazeri 1 year ago
parent
commit
9ba2e028b1
41 changed files with 3513 additions and 82 deletions
  1. 14 15
      README.md
  2. 717 0
      demo_apps/RAG_Chatbot_example/RAG_Chatbot_Example.ipynb
  3. BIN
      demo_apps/RAG_Chatbot_example/data/Llama Getting Started Guide.pdf
  4. 6 0
      demo_apps/RAG_Chatbot_example/requirements.txt
  5. BIN
      demo_apps/RAG_Chatbot_example/vectorstore/db_faiss/index.faiss
  6. BIN
      demo_apps/RAG_Chatbot_example/vectorstore/db_faiss/index.pkl
  7. 28 16
      demo_apps/README.md
  8. 186 0
      demo_apps/llama-on-prem.md
  9. 61 0
      demo_apps/llama_chatbot.py
  10. 45 0
      demo_apps/llama_messenger.py
  11. BIN
      demo_apps/messenger_api_settings.png
  12. 194 0
      demo_apps/messenger_llama2.md
  13. BIN
      demo_apps/messenger_llama_arch.jpg
  14. BIN
      demo_apps/whatsapp_dashboard.jpg
  15. 162 0
      demo_apps/whatsapp_llama2.md
  16. BIN
      demo_apps/whatsapp_llama_arch.jpg
  17. 2 0
      docs/inference.md
  18. 384 0
      examples/Purple_Llama_Anyscale.ipynb
  19. 3 2
      examples/README.md
  20. 22 0
      examples/hf_llama_conversion/README.md
  21. 48 0
      examples/hf_llama_conversion/compare_llama_weights.py
  22. 12 2
      examples/inference.py
  23. 19 0
      examples/llama_guard/README.md
  24. 6 0
      examples/llama_guard/__init__.py
  25. 458 0
      examples/llama_guard/generation.py
  26. 495 0
      examples/llama_guard/model.py
  27. 146 0
      examples/llama_guard/prompt_format.py
  28. 68 0
      examples/llama_guard/tokenizer.py
  29. 6 1
      pyproject.toml
  30. 39 8
      scripts/spellcheck_conf/wordlist.txt
  31. 2 0
      src/llama_recipes/configs/training.py
  32. 1 1
      src/llama_recipes/datasets/alpaca_dataset.py
  33. 152 14
      src/llama_recipes/inference/safety_utils.py
  34. 163 0
      src/llama_recipes/tools/convert_hf_weights_to_llama.py
  35. 13 2
      src/llama_recipes/utils/train_utils.py
  36. 38 6
      tests/conftest.py
  37. 2 1
      tests/datasets/test_custom_dataset.py
  38. 5 3
      tests/datasets/test_grammar_datasets.py
  39. 4 2
      tests/datasets/test_samsum_datasets.py
  40. 4 2
      tests/test_batching.py
  41. 8 7
      tests/test_train_utils.py

+ 14 - 15
README.md

@@ -1,12 +1,10 @@
 # Llama 2 Fine-tuning / Inference Recipes, Examples and Demo Apps
 
-**[Update Oct. 20, 2023] We have just released a series of Llama 2 demo apps [here](./demo_apps). These apps show how to run Llama 2 locally and in the cloud to chat about data (PDF, DB, or live) and generate video summary.**
-
+**[Update Dec 14, 2023] We recently released a series of Llama 2 demo apps [here](./demo_apps). These apps show how to run Llama (locally, in the cloud, or on-prem), how to ask Llama questions in general or about custom data (PDF, DB, or live), how to integrate Llama with WhatsApp and Messenger, and how to implement an end-to-end chatbot with RAG (Retrieval Augmented Generation).**
 
 The 'llama-recipes' repository is a companion to the [Llama 2 model](https://github.com/facebookresearch/llama). The goal of this repository is to provide examples to quickly get started with fine-tuning for domain adaptation and how to run inference for the fine-tuned models. For ease of use, the examples use Hugging Face converted versions of the models. See steps for conversion of the model [here](#model-conversion-to-hugging-face).
 
-In addition, we also provide a number of demo apps, to showcase the Llama2 usage along with other ecosystem solutions to run Llama2 locally on your mac and on cloud.
-
+In addition, we also provide a number of demo apps to showcase Llama 2 usage along with other ecosystem solutions to run Llama 2 locally, in the cloud, and on-prem.
 
 Llama 2 is a new technology that carries potential risks with use. Testing conducted to date has not — and could not — cover all scenarios. In order to help developers address these risks, we have created the [Responsible Use Guide](https://github.com/facebookresearch/llama/blob/main/Responsible-Use-Guide.pdf). More details can be found in our research paper as well. For downloading the models, follow the instructions on [Llama 2 repo](https://github.com/facebookresearch/llama).
 
@@ -23,8 +21,6 @@ Llama 2 is a new technology that carries potential risks with use. Testing condu
 6. [Repository Organization](#repository-organization)
 7. [License and Acceptable Use Policy](#license)
 
-
-
 # Quick Start
 
 [Llama 2 Jupyter Notebook](./examples/quickstart.ipynb): This jupyter notebook steps you through how to finetune a Llama 2 model on the text summarization task using the [samsum](https://huggingface.co/datasets/samsum). The notebook uses parameter efficient finetuning (PEFT) and int8 quantization to finetune a 7B on a single GPU like an A10 with 24GB gpu memory.
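
For readers who prefer the command line, a rough single-GPU equivalent of what the notebook does might look like the sketch below; the `--quantization` flag and the paths are assumptions, so treat the notebook and the fine-tuning docs as authoritative.

```bash
# Sketch: PEFT (LoRA) + int8 quantization on a single 24GB GPU, mirroring the notebook's setup.
# The --quantization flag and all paths are assumptions; see the fine-tuning docs for exact options.
python examples/finetuning.py \
    --use_peft --peft_method lora --quantization \
    --model_name /path_of_model_folder/7B \
    --output_dir /path/to/save/PEFT/model
```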
@@ -80,7 +76,7 @@ Optional dependencies can also be combines with [option1,option2].
 
 # Where to find the models?
 
-You can find llama v2 models on HuggingFace hub [here](https://huggingface.co/meta-llama), where models with `hf` in the name are already converted to HuggingFace checkpoints so no further conversion is needed. The conversion step below is only for original model weights from Meta that are hosted on HuggingFace model hub as well.
+You can find llama v2 models on Hugging Face hub [here](https://huggingface.co/meta-llama), where models with `hf` in the name are already converted to Hugging Face checkpoints so no further conversion is needed. The conversion step below is only for original model weights from Meta that are hosted on Hugging Face model hub as well.
 
 # Model conversion to Hugging Face
 The recipes and notebooks in this folder are using the Llama 2 model definition provided by Hugging Face's transformers library.
@@ -88,7 +84,7 @@ The recipes and notebooks in this folder are using the Llama 2 model definition
 Given that the original checkpoint resides under models/7B you can install all requirements and convert the checkpoint with:
 
 ```bash
-## Install HuggingFace Transformers from source
+## Install Hugging Face Transformers from source
 pip freeze | grep transformers ## verify it is version 4.31.0 or higher
 
 git clone git@github.com:huggingface/transformers.git
@@ -145,7 +141,7 @@ Here we use FSDP as discussed in the next section which can be used along with P
 
 ## Flash Attention and Xformer Memory Efficient Kernels
 
-Setting `use_fast_kernels` will enable using of Flash Attention or Xformer memory-efficient kernels based on the hardware being used. This would speed up the fine-tuning job. This has been enabled in `optimum` library from HuggingFace as a one-liner API, please read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/).
+Setting `use_fast_kernels` will enable the use of Flash Attention or Xformer memory-efficient kernels based on the hardware being used. This would speed up the fine-tuning job. This has been enabled in the `optimum` library from Hugging Face as a one-liner API, please read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/).
 
 ```bash
 torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --fsdp_config.pure_bf16 --output_dir Path/to/save/PEFT/model --use_fast_kernels
@@ -197,14 +193,17 @@ You can read more about our fine-tuning strategies [here](./docs/LLM_finetuning.
 # Demo Apps
 This folder contains a series of Llama2-powered apps:
 * Quickstart Llama deployments and basic interactions with Llama
-  1. Llama on your Mac and ask Llama general questions
-  2. Llama on Google Colab
-  3. Llama on Cloud and ask Llama questions about unstructured data in a PDF
+1. Llama on your Mac and ask Llama general questions
+2. Llama on Google Colab
+3. Llama on Cloud and ask Llama questions about unstructured data in a PDF
+4. Llama on-prem with vLLM and TGI
+5. Llama chatbot with RAG (Retrieval Augmented Generation)
 
 * Specialized Llama use cases:
-  1. Ask Llama to summarize a video content
-  2. Ask Llama questions about structured data in a DB
-  3. Ask Llama questions about live data on the web
+1. Ask Llama to summarize video content
+2. Ask Llama questions about structured data in a DB
+3. Ask Llama questions about live data on the web
+4. Build a Llama-enabled WhatsApp chatbot
 
 # Repository Organization
 This repository is organized in the following way:

File diff suppressed because it is too large
+ 717 - 0
demo_apps/RAG_Chatbot_example/RAG_Chatbot_Example.ipynb


BIN
demo_apps/RAG_Chatbot_example/data/Llama Getting Started Guide.pdf


+ 6 - 0
demo_apps/RAG_Chatbot_example/requirements.txt

@@ -0,0 +1,6 @@
+gradio
+pypdf
+langchain
+sentence-transformers
+faiss-cpu
+text-generation
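
To show how these pinned packages fit together, here is a minimal RAG sketch in the spirit of the notebook above; it assumes LangChain's 0.0.x-era APIs and a locally running TGI server, and the chunking parameters and embedding model are arbitrary choices rather than values taken from the example.

```python
# Minimal RAG sketch using the packages above (gradio, used for the notebook's chat UI, is omitted).
# The TGI URL, chunk sizes, and embedding model are assumptions; the notebook is the reference.
from langchain.document_loaders import PyPDFLoader              # pypdf
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import HuggingFaceEmbeddings          # sentence-transformers
from langchain.vectorstores import FAISS                        # faiss-cpu
from langchain.llms import HuggingFaceTextGenInference          # text-generation client
from langchain.chains import RetrievalQA

# Load and chunk the PDF shipped with the example.
docs = PyPDFLoader("data/Llama Getting Started Guide.pdf").load()
chunks = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100).split_documents(docs)

# Embed the chunks and index them in a FAISS vector store.
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
db = FAISS.from_documents(chunks, embeddings)

# Point LangChain at a TGI server hosting a Llama 2 model and run a retrieval-augmented query.
llm = HuggingFaceTextGenInference(inference_server_url="http://localhost:8080", max_new_tokens=512)
qa = RetrievalQA.from_chain_type(llm=llm, retriever=db.as_retriever())
print(qa.run("What is Llama 2?"))
```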

BIN
demo_apps/RAG_Chatbot_example/vectorstore/db_faiss/index.faiss


BIN
demo_apps/RAG_Chatbot_example/vectorstore/db_faiss/index.pkl


File diff suppressed because it is too large
+ 28 - 16
demo_apps/README.md


File diff suppressed because it is too large
+ 186 - 0
demo_apps/llama-on-prem.md


+ 61 - 0
demo_apps/llama_chatbot.py

@@ -0,0 +1,61 @@
+import langchain
+from langchain.llms import Replicate
+
+from flask import Flask
+from flask import request
+import os
+import requests
+import json
+
+class WhatsAppClient:
+
+    API_URL = "https://graph.facebook.com/v17.0/"
+    WHATSAPP_API_TOKEN = "<Temporary access token from your WhatsApp API Setup>"
+    WHATSAPP_CLOUD_NUMBER_ID = "<Phone number ID from your WhatsApp API Setup>"
+
+    def __init__(self):
+        self.headers = {
+            "Authorization": f"Bearer {self.WHATSAPP_API_TOKEN}",
+            "Content-Type": "application/json",
+        }
+        self.API_URL = self.API_URL + self.WHATSAPP_CLOUD_NUMBER_ID
+
+    def send_text_message(self,message, phone_number):
+        payload = {
+            "messaging_product": 'whatsapp',
+            "to": phone_number,
+            "type": "text",
+            "text": {
+                "preview_url": False,
+                "body": message
+            }
+        }
+        response = requests.post(f"{self.API_URL}/messages", json=payload,headers=self.headers)
+        print(response.status_code)
+        assert response.status_code == 200, "Error sending message"
+        return response.status_code
+
+os.environ["REPLICATE_API_TOKEN"] = "<your replicate api token>"    
+llama2_13b_chat = "meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d"
+
+llm = Replicate(
+    model=llama2_13b_chat,
+    model_kwargs={"temperature": 0.01, "top_p": 1, "max_new_tokens":500}
+)
+client = WhatsAppClient()
+app = Flask(__name__)
+
+@app.route("/")
+def hello_llama():
+    return "<p>Hello Llama 2</p>"
+
+@app.route('/msgrcvd', methods=['POST', 'GET'])
+def msgrcvd():    
+    message = request.args.get('message')
+    #client.send_template_message("hello_world", "en_US", "14086745477")
+    answer = llm(message)
+    print(message)
+    print(answer)
+    client.send_text_message(answer, "14086745477")
+    return message + "<p/>" + answer
+
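
For a quick local smoke test of the webhook above, something like the following could work; the port, the ngrok tunnel, and the package list are assumptions, and the WhatsApp Cloud API additionally needs the callback URL to be publicly reachable and verified.

```bash
# Sketch: run the Flask webhook locally and hit the /msgrcvd endpoint it defines.
# Port and tunneling tool are assumptions, not requirements of the script.
pip install flask langchain replicate requests
FLASK_APP=demo_apps/llama_chatbot.py flask run --port 5000
# in another shell, expose it publicly for the WhatsApp callback (any tunneling tool works):
ngrok http 5000
# quick local test (the endpoint reads the message from a query parameter):
curl "http://localhost:5000/msgrcvd?message=What+is+Llama+2"
```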

+ 45 - 0
demo_apps/llama_messenger.py

@@ -0,0 +1,45 @@
+import langchain
+from langchain.llms import Replicate
+
+from flask import Flask
+from flask import request
+import os
+import requests
+import json
+
+os.environ["REPLICATE_API_TOKEN"] = "<your replicate api token>"
+llama2_13b_chat = "meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d"
+
+llm = Replicate(
+    model=llama2_13b_chat,
+    model_kwargs={"temperature": 0.01, "top_p": 1, "max_new_tokens":500}
+)
+
+app = Flask(__name__)
+
+@app.route('/msgrcvd_pager', methods=['POST', 'GET'])
+def msgrcvd_pager():    
+    message = request.args.get('message')
+    sender = request.args.get('sender')
+    recipient = request.args.get('recipient')
+
+    answer = llm(message)
+    print(message)
+    print(answer)
+
+    url = f"https://graph.facebook.com/v18.0/{recipient}/messages"
+    params = {
+        'recipient': '{"id": ' + sender + '}',
+        'message': json.dumps({'text': answer}),
+        'messaging_type': 'RESPONSE',
+        'access_token': "<your page access token>"
+    }
+    headers = {
+        'Content-Type': 'application/json'
+    }
+    response = requests.post(url, params=params, headers=headers)
+    print(response.status_code)
+    print(response.text)
+
+    return message + "<p/>" + answer
+
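
As with the WhatsApp script, the Messenger webhook reads its inputs from query parameters, so it can be exercised locally before connecting it to the Messenger platform; the ids below are placeholders.

```bash
# Sketch: start the Messenger webhook and hit it with placeholder query parameters.
FLASK_APP=demo_apps/llama_messenger.py flask run --port 5000
# sender/recipient ids are placeholders; without a real page access token the Graph API
# call returns an error, but the generated Llama answer is still printed and returned.
curl "http://localhost:5000/msgrcvd_pager?message=Hello+Llama&sender=1234567890&recipient=0987654321"
```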

BIN
demo_apps/messenger_api_settings.png


File diff suppressed because it is too large
+ 194 - 0
demo_apps/messenger_llama2.md


BIN
demo_apps/messenger_llama_arch.jpg


BIN
demo_apps/whatsapp_dashboard.jpg


File diff suppressed because it is too large
+ 162 - 0
demo_apps/whatsapp_llama2.md


BIN
demo_apps/whatsapp_llama_arch.jpg


+ 2 - 0
docs/inference.md

@@ -144,3 +144,5 @@ python examples/vllm/inference.py --model_name <PATH/TO/MODEL/7B>
 ```
 
 [**TGI**](https://github.com/huggingface/text-generation-inference): Text Generation Inference (TGI) is another inference option available to you. For more information on how to set up and use TGI see [here](../examples/hf_text_generation_inference/README.md).
+
+[Here](../demo_apps/llama-on-prem.md) is a complete tutorial on how to use vLLM and TGI to deploy Llama 2 on-prem and interact with the Llama API services.
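
The linked tutorial walks through the full setup; as a rough sketch of the pattern it describes, assuming vLLM's OpenAI-compatible server and a Hugging Face chat checkpoint (model id and port are placeholders):

```bash
# Sketch: serve Llama 2 with vLLM's OpenAI-compatible API server, then query it.
python -m vllm.entrypoints.openai.api_server --model meta-llama/Llama-2-7b-chat-hf --port 8000
# from another shell:
curl http://localhost:8000/v1/completions \
    -H "Content-Type: application/json" \
    -d '{"model": "meta-llama/Llama-2-7b-chat-hf", "prompt": "Who wrote the book Innovators Dilemma?", "max_tokens": 128}'
```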

File diff suppressed because it is too large
+ 384 - 0
examples/Purple_Llama_Anyscale.ipynb


+ 3 - 2
examples/README.md

@@ -1,7 +1,6 @@
 # Examples
 
-This folder contains finetuning and inference examples for Llama 2.
-For the full documentation on these examples please refer to [docs/inference.md](../docs/inference.md)
+This folder contains finetuning and inference examples for Llama 2, Code Llama and [Purple Llama](https://ai.meta.com/llama/purple-llama/). For the full documentation on these examples please refer to [docs/inference.md](../docs/inference.md)
 
 ## Finetuning
 
@@ -27,6 +26,8 @@ So far, we have provide the following inference examples:
 
 5. [Code Llama](./code_llama/) folder which provides examples for [code completion](./code_llama/code_completion_example.py) and [code infilling](./code_llama/code_infilling_example.py).
 
+6. The [Purple Llama Using Anyscale](./Purple_Llama_Anyscale.ipynb) notebook shows how to use the Anyscale-hosted Llama Guard model to classify user inputs as safe or unsafe.
+
 For more in depth information on inference including inference safety checks and examples, see the inference documentation [here](../docs/inference.md).
 
 **Note** The [sensitive topics safety checker](../src/llama_recipes/inference/safety_utils.py) utilizes AuditNLG which is an optional dependency. Please refer to installation section of the main [README.md](../README.md#install-with-optional-dependencies) for details.
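
For orientation only, a hedged sketch of the pattern the Purple Llama notebook demonstrates: sending a user prompt to an OpenAI-compatible endpoint serving Llama Guard and reading back the safe/unsafe verdict. The base URL and model id below are placeholders and may not match the notebook; check the notebook and your provider's docs for the real values.

```python
# Sketch: classify a user input with a hosted Llama Guard model over an
# OpenAI-compatible chat completions API. URL and model id are placeholders.
import os
import requests

API_BASE = "https://api.endpoints.anyscale.com/v1"   # assumed OpenAI-compatible base URL
MODEL_ID = "Meta-Llama/Llama-Guard-7b"               # placeholder model id

resp = requests.post(
    f"{API_BASE}/chat/completions",
    headers={"Authorization": f"Bearer {os.environ['ANYSCALE_API_KEY']}"},
    json={
        "model": MODEL_ID,
        "messages": [{"role": "user", "content": "Tell me how to bake a birthday cake."}],
        "temperature": 0,
    },
    timeout=60,
)
print(resp.json()["choices"][0]["message"]["content"])  # expected to start with 'safe' or 'unsafe'
```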

+ 22 - 0
examples/hf_llama_conversion/README.md

@@ -0,0 +1,22 @@
+# Convert Hugging Face llama weights to official llama consolidated format
+
+This is the reverse conversion of the `convert_llama_weights_to_hf.py` script from the transformers package.
+
+## Step 0: Convert to consolidated format
+- Create an output directory for the converted weights, such as `test70B`.
+- Copy the `params.json` file from the official llama download into that directory.
+- Run the conversion script. `model-path` can be a Hugging Face hub model or a local hf model directory.
+```
+python -m llama_recipes.tools.convert_hf_weights_to_llama --model-path meta-llama/Llama-2-70b-chat-hf --output-dir test70B --model-size 70B
+```
+
+## Step 1: Run inference
+Check out the official llama inference [repo](https://github.com/facebookresearch/llama). Test using chat or text completion.
+```
+torchrun --nproc_per_node 8 example_chat_completion.py --ckpt_dir ./test70B --tokenizer_path ${llama_2_dir}/tokenizer.model
+```
+
+For validation, please compare the converted weights with the official llama 2 weights:
+```
+python compare_llama_weights.py test70B ${llama_2_70b_chat_dir}
+```

+ 48 - 0
examples/hf_llama_conversion/compare_llama_weights.py

@@ -0,0 +1,48 @@
+import gc
+import glob
+import os
+import sys
+
+import torch
+import tqdm
+
+
+def main() -> None:
+    """Compare two llama checkpoint directories"""
+
+    one_files = sorted(glob.glob(os.path.join(sys.argv[1], "consolidated.*.pth")))
+    two_files = sorted(glob.glob(os.path.join(sys.argv[2], "consolidated.*.pth")))
+    assert len(one_files) == len(
+        two_files
+    ), "One directory has {} files while another has {} files.".format(
+        len(one_files), len(two_files)
+    )
+
+    deltas = []
+    for i in tqdm.trange(len(one_files), desc="Comparing shards"):
+        one = torch.load(one_files[i])
+        two = torch.load(two_files[i])
+        assert len(one) == len(
+            two
+        ), "shard should have the same length: {} != {}".format(len(one), len(two))
+
+        for _, (v, w) in enumerate(zip(one.items(), two.items())):
+            assert v[0] == w[0], "{} != {}".format(v[0], w[0])
+            assert v[1].shape == w[1].shape, "tensor {} shape {} != {}".format(
+                v[0], v[1].shape, w[1].shape
+            )
+
+            delta = (v[1] - w[1]).abs().max().item()
+            deltas.append((i, v[0], delta))
+        del one
+        del two
+        gc.collect()
+
+    deltas = sorted(deltas, key=lambda x: x[-1], reverse=True)
+    print("Top 10 largest deltas:")
+    for i, k, v in deltas[:10]:
+        print(f"  shard {i} {k}: {v}")
+
+
+if __name__ == "__main__":
+    main()

+ 12 - 2
examples/inference.py

@@ -11,7 +11,7 @@ import time
 import torch
 from transformers import LlamaTokenizer
 
-from llama_recipes.inference.safety_utils import get_safety_checker
+from llama_recipes.inference.safety_utils import get_safety_checker, AgentType
 from llama_recipes.inference.model_utils import load_model, load_peft_model
 
 
@@ -33,6 +33,8 @@ def main(
     enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
     enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
     enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5
+    enable_llamaguard_content_safety: bool=False,
+    llamaguard_model_name: str=None,
     max_padding_length: int=None, # the max padding length to be used with tokenizer padding the prompts.
     use_fast_kernels: bool = False, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
     **kwargs
@@ -48,6 +50,12 @@ def main(
     else:
         print("No user prompt provided. Exiting.")
         sys.exit(1)
+
+    if enable_llamaguard_content_safety:
+        if not llamaguard_model_name:
+            print("if enable_llamaguard_content_safety is used, provide the model path with --llamaguard_model_name")
+            sys.exit(1)
+
     
     # Set the seeds for reproducibility
     torch.cuda.manual_seed(seed)
@@ -77,6 +85,8 @@ def main(
     safety_checker = get_safety_checker(enable_azure_content_safety,
                                         enable_sensitive_topics,
                                         enable_salesforce_content_safety,
+                                        enable_llamaguard_content_safety,
+                                        guard_lama_path=llamaguard_model_name
                                         )
 
     # Safety check of the user prompt
@@ -117,7 +127,7 @@ def main(
     output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
     
     # Safety check of the model output
-    safety_results = [check(output_text) for check in safety_checker]
+    safety_results = [check(output_text, agent_type=AgentType.AGENT, user_prompt=user_prompt) for check in safety_checker]
     are_safe = all([r[1] for r in safety_results])
     if are_safe:
         print("User input and model output deemed safe.")

File diff suppressed because it is too large
+ 19 - 0
examples/llama_guard/README.md


+ 6 - 0
examples/llama_guard/__init__.py

@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from .generation import Llama, Dialog
+from .model import ModelArgs, Transformer
+from .tokenizer import Tokenizer
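
These exports mirror the reference llama implementation; below is a minimal usage sketch, assuming a local Llama Guard checkpoint and a `torchrun` launch (since `Llama.build` initializes a NCCL process group), with placeholder paths.

```python
# Sketch: load Llama Guard with the classes exported above and score one prompt.
# Run with `torchrun --nproc_per_node 1 score.py`; checkpoint/tokenizer paths are placeholders.
from llama_guard import Llama

llama_guard = Llama.build(
    ckpt_dir="/path/to/LlamaGuard-7b",          # directory holding consolidated.*.pth and params.json
    tokenizer_path="/path/to/tokenizer.model",
    max_seq_len=2048,
    max_batch_size=1,
)
verdict = llama_guard.single_prompt_completion(
    "Task: Check if there is unsafe content in 'User' messages ..."  # see prompt_format.py for the full template
)
print(verdict)  # the model's first line is expected to read 'safe' or 'unsafe'
```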

+ 458 - 0
examples/llama_guard/generation.py

@@ -0,0 +1,458 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import json
+import os
+import sys
+import time
+from pathlib import Path
+from typing import List, Literal, Optional, Tuple, TypedDict
+
+import torch
+import torch.nn.functional as F
+from fairscale.nn.model_parallel.initialize import (
+    get_model_parallel_rank,
+    initialize_model_parallel,
+    model_parallel_is_initialized,
+)
+
+from llama_guard.model import ModelArgs, Transformer
+from llama_guard.tokenizer import Tokenizer
+
+Role = Literal["system", "user", "assistant"]
+
+
+class Message(TypedDict):
+    role: Role
+    content: str
+
+
+class CompletionPrediction(TypedDict, total=False):
+    generation: str
+    tokens: List[str]  # not required
+    logprobs: List[float]  # not required
+
+
+class ChatPrediction(TypedDict, total=False):
+    generation: Message
+    tokens: List[str]  # not required
+    logprobs: List[float]  # not required
+
+
+Dialog = List[Message]
+
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+SPECIAL_TAGS = [B_INST, E_INST, "<<SYS>>", "<</SYS>>"]
+UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt."
+
+
+class Llama:
+    @staticmethod
+    def build(
+        ckpt_dir: str,
+        tokenizer_path: str,
+        max_seq_len: int,
+        max_batch_size: int,
+        model_parallel_size: Optional[int] = None,
+        seed: int = 1,
+    ) -> "Llama":
+        """
+        Build a Llama instance by initializing and loading a pre-trained model.
+
+        Args:
+            ckpt_dir (str): Path to the directory containing checkpoint files.
+            tokenizer_path (str): Path to the tokenizer file.
+            max_seq_len (int): Maximum sequence length for input text.
+            max_batch_size (int): Maximum batch size for inference.
+            model_parallel_size (Optional[int], optional): Number of model parallel processes.
+                If not provided, it's determined from the environment. Defaults to None.
+
+        Returns:
+            Llama: An instance of the Llama class with the loaded model and tokenizer.
+
+        Raises:
+            AssertionError: If there are no checkpoint files in the specified directory,
+                or if the model parallel size does not match the number of checkpoint files.
+
+        Note:
+            This method initializes the distributed process group, sets the device to CUDA,
+            and loads the pre-trained model and tokenizer.
+
+        """
+        if not torch.distributed.is_initialized():
+            torch.distributed.init_process_group("nccl")
+        if not model_parallel_is_initialized():
+            if model_parallel_size is None:
+                model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
+            initialize_model_parallel(model_parallel_size)
+
+        local_rank = int(os.environ.get("LOCAL_RANK", 0))
+        torch.cuda.set_device(local_rank)
+
+        # seed must be the same in all processes
+        torch.manual_seed(seed)
+
+        if local_rank > 0:
+            sys.stdout = open(os.devnull, "w")
+
+        start_time = time.time()
+        checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
+        checkpoints_size = len(checkpoints)
+        assert checkpoints_size > 0, f"no checkpoint files found in {ckpt_dir}"
+        ckpt_path = checkpoints[get_model_parallel_rank()]
+        checkpoint = torch.load(ckpt_path, map_location="cpu")
+        with open(Path(ckpt_dir) / "params.json", "r") as f:
+            params = json.loads(f.read())
+
+        model_args: ModelArgs = ModelArgs(
+            max_seq_len=max_seq_len,
+            max_batch_size=max_batch_size,
+            **params,
+        )
+        tokenizer = Tokenizer(model_path=tokenizer_path)
+        model_args.vocab_size = tokenizer.n_words
+        torch.set_default_tensor_type(torch.cuda.HalfTensor)
+        model = Transformer(model_args)
+        model.load_state_dict(checkpoint, strict=False)
+        print(f"Loaded in {time.time() - start_time:.2f} seconds")
+
+        return Llama(model, tokenizer)
+
+    def __init__(self, model: Transformer, tokenizer: Tokenizer):
+        self.model = model
+        self.tokenizer = tokenizer
+
+    @torch.inference_mode()
+    def generate(
+        self,
+        prompt_tokens: List[List[int]],
+        max_gen_len: int,
+        temperature: float = 0.6,
+        top_p: float = 0.9,
+        logprobs: bool = False,
+        echo: bool = False,
+    ) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
+        """
+        Generate text sequences based on provided prompts using the language generation model.
+
+        Args:
+            prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers.
+            max_gen_len (int): Maximum length of the generated text sequence.
+            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
+            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
+            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
+            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
+
+        Returns:
+            Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities.
+
+        Note:
+            This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness.
+            If logprobs is True, token log probabilities are computed for each generated token.
+
+        """
+        params = self.model.params
+        bsz = len(prompt_tokens)
+        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
+
+        min_prompt_len = min(len(t) for t in prompt_tokens)
+        max_prompt_len = max(len(t) for t in prompt_tokens)
+        assert max_prompt_len <= params.max_seq_len
+        total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
+
+        pad_id = self.tokenizer.pad_id
+        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
+        for k, t in enumerate(prompt_tokens):
+            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+        if logprobs:
+            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
+
+        prev_pos = 0
+        eos_reached = torch.tensor([False] * bsz, device="cuda")
+        input_text_mask = tokens != pad_id
+        if min_prompt_len == total_len:
+            logits = self.model.forward(tokens, prev_pos)
+            token_logprobs = -F.cross_entropy(
+                input=logits.transpose(1, 2),
+                target=tokens,
+                reduction="none",
+                ignore_index=pad_id,
+            )
+
+        for cur_pos in range(min_prompt_len, total_len):
+            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
+            if temperature > 0:
+                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
+                next_token = sample_top_p(probs, top_p)
+            else:
+                next_token = torch.argmax(logits[:, -1], dim=-1)
+
+            next_token = next_token.reshape(-1)
+            # only replace token if prompt has already been generated
+            next_token = torch.where(
+                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
+            )
+            tokens[:, cur_pos] = next_token
+            if logprobs:
+                token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
+                    input=logits.transpose(1, 2),
+                    target=tokens[:, prev_pos + 1 : cur_pos + 1],
+                    reduction="none",
+                    ignore_index=pad_id,
+                )
+            eos_reached |= (~input_text_mask[:, cur_pos]) & (
+                next_token == self.tokenizer.eos_id
+            )
+            prev_pos = cur_pos
+            if all(eos_reached):
+                break
+
+        if logprobs:
+            token_logprobs = token_logprobs.tolist()
+        out_tokens, out_logprobs = [], []
+        for i, toks in enumerate(tokens.tolist()):
+            # cut to max gen len
+            start = 0 if echo else len(prompt_tokens[i])
+            toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
+            probs = None
+            if logprobs:
+                probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
+            # cut to eos tok if any
+            if self.tokenizer.eos_id in toks:
+                eos_idx = toks.index(self.tokenizer.eos_id)
+                toks = toks[:eos_idx]
+                probs = probs[:eos_idx] if logprobs else None
+            out_tokens.append(toks)
+            out_logprobs.append(probs)
+        return (out_tokens, out_logprobs if logprobs else None)
+
+    def text_completion(
+        self,
+        prompts: List[str],
+        temperature: float = 0.6,
+        top_p: float = 0.9,
+        max_gen_len: Optional[int] = None,
+        logprobs: bool = False,
+        echo: bool = False,
+    ) -> List[CompletionPrediction]:
+        """
+        Perform text completion for a list of prompts using the language generation model.
+
+        Args:
+            prompts (List[str]): List of text prompts for completion.
+            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
+            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
+            max_gen_len (Optional[int], optional): Maximum length of the generated completion sequence.
+                If not provided, it's set to the model's maximum sequence length minus 1.
+            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
+            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
+
+        Returns:
+            List[CompletionPrediction]: List of completion predictions, each containing the generated text completion.
+
+        Note:
+            This method generates text completions for the provided prompts, employing nucleus sampling to introduce controlled randomness.
+            If logprobs is True, token log probabilities are computed for each generated token.
+
+        """
+        if max_gen_len is None:
+            max_gen_len = self.model.params.max_seq_len - 1
+        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
+        generation_tokens, generation_logprobs = self.generate(
+            prompt_tokens=prompt_tokens,
+            max_gen_len=max_gen_len,
+            temperature=temperature,
+            top_p=top_p,
+            logprobs=logprobs,
+            echo=echo,
+        )
+        if logprobs:
+            return [
+                {
+                    "generation": self.tokenizer.decode(t),
+                    "tokens": [self.tokenizer.decode(x) for x in t],
+                    "logprobs": logprobs_i,
+                }
+                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
+            ]
+        return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]
+
+    def chat_completion(
+        self,
+        dialogs: List[Dialog],
+        temperature: float = 0.6,
+        top_p: float = 0.9,
+        max_gen_len: Optional[int] = None,
+        logprobs: bool = False,
+    ) -> List[ChatPrediction]:
+        """
+        Generate assistant responses for a list of conversational dialogs using the language generation model.
+
+        Args:
+            dialogs (List[Dialog]): List of conversational dialogs, where each dialog is a list of messages.
+            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
+            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
+            max_gen_len (Optional[int], optional): Maximum length of the generated response sequence.
+                If not provided, it's set to the model's maximum sequence length minus 1.
+            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
+
+        Returns:
+            List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response.
+
+        Raises:
+            AssertionError: If the last message in a dialog is not from the user.
+            AssertionError: If the dialog roles are not in the required 'user', 'assistant', and optional 'system' order.
+
+        Note:
+            This method generates assistant responses for the provided conversational dialogs.
+            It employs nucleus sampling to introduce controlled randomness in text generation.
+            If logprobs is True, token log probabilities are computed for each generated token.
+
+        """
+        if max_gen_len is None:
+            max_gen_len = self.model.params.max_seq_len - 1
+        prompt_tokens = []
+        unsafe_requests = []
+        for dialog in dialogs:
+            unsafe_requests.append(
+                any([tag in msg["content"] for tag in SPECIAL_TAGS for msg in dialog])
+            )
+            if dialog[0]["role"] == "system":
+                dialog = [
+                    {
+                        "role": dialog[1]["role"],
+                        "content": B_SYS
+                        + dialog[0]["content"]
+                        + E_SYS
+                        + dialog[1]["content"],
+                    }
+                ] + dialog[2:]
+            assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
+                [msg["role"] == "assistant" for msg in dialog[1::2]]
+            ), (
+                "model only supports 'system', 'user' and 'assistant' roles, "
+                "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
+            )
+            dialog_tokens: List[int] = sum(
+                [
+                    self.tokenizer.encode(
+                        f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
+                        bos=True,
+                        eos=True,
+                    )
+                    for prompt, answer in zip(
+                        dialog[::2],
+                        dialog[1::2],
+                    )
+                ],
+                [],
+            )
+            assert (
+                dialog[-1]["role"] == "user"
+            ), f"Last message must be from user, got {dialog[-1]['role']}"
+            dialog_tokens += self.tokenizer.encode(
+                f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
+                bos=True,
+                eos=False,
+            )
+            prompt_tokens.append(dialog_tokens)
+
+        generation_tokens, generation_logprobs = self.generate(
+            prompt_tokens=prompt_tokens,
+            max_gen_len=max_gen_len,
+            temperature=temperature,
+            top_p=top_p,
+            logprobs=logprobs,
+        )
+        if logprobs:
+            return [
+                {
+                    "generation": {
+                        "role": "assistant",
+                        "content": self.tokenizer.decode(t)
+                        if not unsafe
+                        else UNSAFE_ERROR,
+                    },
+                    "tokens": [self.tokenizer.decode(x) for x in t],
+                    "logprobs": logprobs_i,
+                }
+                for t, logprobs_i, unsafe in zip(
+                    generation_tokens, generation_logprobs, unsafe_requests
+                )
+            ]
+        return [
+            {
+                "generation": {
+                    "role": "assistant",
+                    "content": self.tokenizer.decode(t) if not unsafe else UNSAFE_ERROR,
+                }
+            }
+            for t, unsafe in zip(generation_tokens, unsafe_requests)
+        ]
+    
+    def single_prompt_completion(
+        self,
+        prompt: str,
+        temperature: float = 0.6,
+        top_p: float = 0.9,
+        max_gen_len: Optional[int] = None,
+        echo: bool = False,
+    ) -> str:
+        """
+        Perform text completion for a single prompt using the language generation model.
+
+        Args:
+            prompt (str): prompt for completion.
+            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
+            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
+            max_gen_len (Optional[int], optional): Maximum length of the generated completion sequence.
+                If not provided, it's set to the model's maximum sequence length minus 1.
+            
+
+        Returns:
+            str: single string with the decoded output from the model.
+
+        Note:
+            This method generates text completions for the provided prompts, employing nucleus sampling to introduce controlled randomness.
+        """
+        if max_gen_len is None:
+            max_gen_len = self.model.params.max_seq_len - 1
+        prompt_tokens = [self.tokenizer.encode(f"{B_INST} {prompt.strip()} {E_INST}", bos=True, eos=False)]
+        generation_tokens = self.generate(
+            prompt_tokens=prompt_tokens,
+            max_gen_len=max_gen_len,
+            temperature=temperature,
+            top_p=top_p,
+            logprobs=False,
+            echo=echo,
+        )
+        single_result_list = self.tokenizer.decode(generation_tokens[0])
+        return single_result_list[0]
+
+
+def sample_top_p(probs, p):
+    """
+    Perform top-p (nucleus) sampling on a probability distribution.
+
+    Args:
+        probs (torch.Tensor): Probability distribution tensor.
+        p (float): Probability threshold for top-p sampling.
+
+    Returns:
+        torch.Tensor: Sampled token indices.
+
+    Note:
+        Top-p sampling selects the smallest set of tokens whose cumulative probability mass
+        exceeds the threshold p. The distribution is renormalized based on the selected tokens.
+
+    """
+    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    mask = probs_sum - probs_sort > p
+    probs_sort[mask] = 0.0
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    next_token = torch.multinomial(probs_sort, num_samples=1)
+    next_token = torch.gather(probs_idx, -1, next_token)
+    return next_token

+ 495 - 0
examples/llama_guard/model.py

@@ -0,0 +1,495 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import fairscale.nn.model_parallel.initialize as fs_init
+import torch
+import torch.nn.functional as F
+from fairscale.nn.model_parallel.layers import (
+    ColumnParallelLinear,
+    ParallelEmbedding,
+    RowParallelLinear,
+)
+from torch import nn
+
+
+@dataclass
+class ModelArgs:
+    dim: int = 4096
+    n_layers: int = 32
+    n_heads: int = 32
+    n_kv_heads: Optional[int] = None
+    vocab_size: int = -1  # defined later by tokenizer
+    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
+    ffn_dim_multiplier: Optional[float] = None
+    norm_eps: float = 1e-5
+
+    max_batch_size: int = 32
+    max_seq_len: int = 2048
+
+
+class RMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        """
+        Initialize the RMSNorm normalization layer.
+
+        Args:
+            dim (int): The dimension of the input tensor.
+            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+        Attributes:
+            eps (float): A small value added to the denominator for numerical stability.
+            weight (nn.Parameter): Learnable scaling parameter.
+
+        """
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        """
+        Apply the RMSNorm normalization to the input tensor.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The normalized tensor.
+
+        """
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        """
+        Forward pass through the RMSNorm layer.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The output tensor after applying RMSNorm.
+
+        """
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+    """
+    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
+    and the end index 'end'. The 'theta' parameter scales the frequencies.
+    The returned tensor contains complex values in complex64 data type.
+
+    Args:
+        dim (int): Dimension of the frequency tensor.
+        end (int): End index for precomputing frequencies.
+        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
+
+    Returns:
+        torch.Tensor: Precomputed frequency tensor with complex exponentials.
+
+    
+        
+
+    """
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    t = torch.arange(end, device=freqs.device)  # type: ignore
+    freqs = torch.outer(t, freqs).float()  # type: ignore
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+    return freqs_cis
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+    """
+    Reshape frequency tensor for broadcasting it with another tensor.
+
+    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
+    for the purpose of broadcasting the frequency tensor during element-wise operations.
+
+    Args:
+        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
+        x (torch.Tensor): Target tensor for broadcasting compatibility.
+
+    Returns:
+        torch.Tensor: Reshaped frequency tensor.
+
+    Raises:
+        AssertionError: If the frequency tensor doesn't match the expected shape.
+        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
+    """
+    ndim = x.ndim
+    assert 0 <= 1 < ndim
+    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+    return freqs_cis.view(*shape)
+
+
+def apply_rotary_emb(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Apply rotary embeddings to input tensors using the given frequency tensor.
+
+    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
+    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
+    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
+    returned as real tensors.
+
+    Args:
+        xq (torch.Tensor): Query tensor to apply rotary embeddings.
+        xk (torch.Tensor): Key tensor to apply rotary embeddings.
+        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
+
+        
+
+    """
+    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+    return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
+    bs, slen, n_kv_heads, head_dim = x.shape
+    if n_rep == 1:
+        return x
+    return (
+        x[:, :, :, None, :]
+        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
+        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
+    )
+
+
+class Attention(nn.Module):
+    """Multi-head attention module."""
+    def __init__(self, args: ModelArgs):
+        """
+        Initialize the Attention module.
+
+        Args:
+            args (ModelArgs): Model configuration parameters.
+
+        Attributes:
+            n_kv_heads (int): Number of key and value heads.
+            n_local_heads (int): Number of local query heads.
+            n_local_kv_heads (int): Number of local key and value heads.
+            n_rep (int): Number of repetitions for local heads.
+            head_dim (int): Dimension size of each attention head.
+            wq (ColumnParallelLinear): Linear transformation for queries.
+            wk (ColumnParallelLinear): Linear transformation for keys.
+            wv (ColumnParallelLinear): Linear transformation for values.
+            wo (RowParallelLinear): Linear transformation for output.
+            cache_k (torch.Tensor): Cached keys for attention.
+            cache_v (torch.Tensor): Cached values for attention.
+
+        """
+        super().__init__()
+        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
+        model_parallel_size = fs_init.get_model_parallel_world_size()
+        self.n_local_heads = args.n_heads // model_parallel_size
+        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
+        self.n_rep = self.n_local_heads // self.n_local_kv_heads
+        self.head_dim = args.dim // args.n_heads
+
+        self.wq = ColumnParallelLinear(
+            args.dim,
+            args.n_heads * self.head_dim,
+            bias=False,
+            gather_output=False,
+            init_method=lambda x: x,
+        )
+        self.wk = ColumnParallelLinear(
+            args.dim,
+            self.n_kv_heads * self.head_dim,
+            bias=False,
+            gather_output=False,
+            init_method=lambda x: x,
+        )
+        self.wv = ColumnParallelLinear(
+            args.dim,
+            self.n_kv_heads * self.head_dim,
+            bias=False,
+            gather_output=False,
+            init_method=lambda x: x,
+        )
+        self.wo = RowParallelLinear(
+            args.n_heads * self.head_dim,
+            args.dim,
+            bias=False,
+            input_is_parallel=True,
+            init_method=lambda x: x,
+        )
+
+        self.cache_k = torch.zeros(
+            (
+                args.max_batch_size,
+                args.max_seq_len,
+                self.n_local_kv_heads,
+                self.head_dim,
+            )
+        ).cuda()
+        self.cache_v = torch.zeros(
+            (
+                args.max_batch_size,
+                args.max_seq_len,
+                self.n_local_kv_heads,
+                self.head_dim,
+            )
+        ).cuda()
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        start_pos: int,
+        freqs_cis: torch.Tensor,
+        mask: Optional[torch.Tensor],
+    ):
+        """
+        Forward pass of the attention module.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+            start_pos (int): Starting position for caching.
+            freqs_cis (torch.Tensor): Precomputed frequency tensor.
+            mask (torch.Tensor, optional): Attention mask tensor.
+
+        Returns:
+            torch.Tensor: Output tensor after attention.
+
+        """
+        bsz, seqlen, _ = x.shape
+        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+
+        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+
+        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+
+        self.cache_k = self.cache_k.to(xq)
+        self.cache_v = self.cache_v.to(xq)
+
+        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
+        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
+
+        keys = self.cache_k[:bsz, : start_pos + seqlen]
+        values = self.cache_v[:bsz, : start_pos + seqlen]
+
+        # repeat k/v heads if n_kv_heads < n_heads
+        keys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
+        values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
+
+        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
+        values = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
+        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
+        if mask is not None:
+            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
+        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
+        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
+        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+        return self.wo(output)
+
+
+class FeedForward(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        multiple_of: int,
+        ffn_dim_multiplier: Optional[float],
+    ):
+        """
+        Initialize the FeedForward module.
+
+        Args:
+            dim (int): Input dimension.
+            hidden_dim (int): Hidden dimension of the feedforward layer.
+            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+            ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
+
+        Attributes:
+            w1 (ColumnParallelLinear): Linear transformation for the first layer.
+            w2 (RowParallelLinear): Linear transformation for the second layer.
+            w3 (ColumnParallelLinear): Linear transformation for the third layer.
+
+        """
+        super().__init__()
+        hidden_dim = int(2 * hidden_dim / 3)
+        # custom dim factor multiplier
+        if ffn_dim_multiplier is not None:
+            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = ColumnParallelLinear(
+            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
+        )
+        self.w2 = RowParallelLinear(
+            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
+        )
+        self.w3 = ColumnParallelLinear(
+            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
+        )
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, layer_id: int, args: ModelArgs):
+        """
+        Initialize a TransformerBlock.
+
+        Args:
+            layer_id (int): Identifier for the layer.
+            args (ModelArgs): Model configuration parameters.
+
+        Attributes:
+            n_heads (int): Number of attention heads.
+            dim (int): Dimension size of the model.
+            head_dim (int): Dimension size of each attention head.
+            attention (Attention): Attention module.
+            feed_forward (FeedForward): FeedForward module.
+            layer_id (int): Identifier for the layer.
+            attention_norm (RMSNorm): Layer normalization for attention output.
+            ffn_norm (RMSNorm): Layer normalization for feedforward output.
+
+        """
+        super().__init__()
+        self.n_heads = args.n_heads
+        self.dim = args.dim
+        self.head_dim = args.dim // args.n_heads
+        self.attention = Attention(args)
+        self.feed_forward = FeedForward(
+            dim=args.dim,
+            hidden_dim=4 * args.dim,
+            multiple_of=args.multiple_of,
+            ffn_dim_multiplier=args.ffn_dim_multiplier,
+        )
+        self.layer_id = layer_id
+        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        start_pos: int,
+        freqs_cis: torch.Tensor,
+        mask: Optional[torch.Tensor],
+    ):
+        """
+        Perform a forward pass through the TransformerBlock.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+            start_pos (int): Starting position for attention caching.
+            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
+            mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None.
+
+        Returns:
+            torch.Tensor: Output tensor after applying attention and feedforward layers.
+
+        """
+        h = x + self.attention.forward(
+            self.attention_norm(x), start_pos, freqs_cis, mask
+        )
+        out = h + self.feed_forward.forward(self.ffn_norm(h))
+        return out
+
+
+class Transformer(nn.Module):
+    def __init__(self, params: ModelArgs):
+        """
+        Initialize a Transformer model.
+
+        Args:
+            params (ModelArgs): Model configuration parameters.
+
+        Attributes:
+            params (ModelArgs): Model configuration parameters.
+            vocab_size (int): Vocabulary size.
+            n_layers (int): Number of layers in the model.
+            tok_embeddings (ParallelEmbedding): Token embeddings.
+            layers (torch.nn.ModuleList): List of Transformer blocks.
+            norm (RMSNorm): Layer normalization for the model output.
+            output (ColumnParallelLinear): Linear layer for final output.
+            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
+
+        """
+        super().__init__()
+        self.params = params
+        self.vocab_size = params.vocab_size
+        self.n_layers = params.n_layers
+
+        self.tok_embeddings = ParallelEmbedding(
+            params.vocab_size, params.dim, init_method=lambda x: x
+        )
+
+        self.layers = torch.nn.ModuleList()
+        for layer_id in range(params.n_layers):
+            self.layers.append(TransformerBlock(layer_id, params))
+
+        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
+        self.output = ColumnParallelLinear(
+            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
+        )
+
+        self.freqs_cis = precompute_freqs_cis(
+            # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096. 
+            # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning.
+            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
+        )
+
+    @torch.inference_mode()
+    def forward(self, tokens: torch.Tensor, start_pos: int):
+        """
+        Perform a forward pass through the Transformer model.
+
+        Args:
+            tokens (torch.Tensor): Input token indices.
+            start_pos (int): Starting position for attention caching.
+
+        Returns:
+            torch.Tensor: Output logits after applying the Transformer model.
+
+        """
+        _bsz, seqlen = tokens.shape
+        h = self.tok_embeddings(tokens)
+        self.freqs_cis = self.freqs_cis.to(h.device)
+        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
+
+        mask = None
+        if seqlen > 1:
+            mask = torch.full(
+                (seqlen, seqlen), float("-inf"), device=tokens.device
+            )
+
+            mask = torch.triu(mask, diagonal=1)
+
+            # When performing key-value caching, we compute the attention scores
+            # only for the new sequence. Thus, the matrix of scores is of size
+            # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
+            # j > cache_len + i, since row i corresponds to token cache_len + i.
+            mask = torch.hstack([
+                torch.zeros((seqlen, start_pos), device=tokens.device),
+                mask
+            ]).type_as(h)
+
+        for layer in self.layers:
+            h = layer(h, start_pos, freqs_cis, mask)
+        h = self.norm(h)
+        output = self.output(h).float()
+        return output
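
Note: the hstack above widens the causal mask to shape (seqlen, start_pos + seqlen), so each new token can attend to every cached token plus the non-future tokens in the new block. A minimal standalone sketch with toy sizes (an illustration only, not part of the recipe's API):

    import torch

    start_pos, seqlen = 3, 2  # 3 cached tokens, 2 new tokens being decoded
    mask = torch.full((seqlen, seqlen), float("-inf"))
    mask = torch.triu(mask, diagonal=1)  # mask future positions within the new block
    mask = torch.hstack([torch.zeros((seqlen, start_pos)), mask])  # cached positions stay visible
    print(mask)
    # tensor([[0., 0., 0., 0., -inf],
    #         [0., 0., 0., 0., 0.]])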

+ 146 - 0
examples/llama_guard/prompt_format.py

@@ -0,0 +1,146 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from dataclasses import dataclass
+from string import Template
+from typing import List
+from enum import Enum
+
+class AgentType(Enum):
+    AGENT = "Agent"
+    USER = "User"
+
+@dataclass
+class SafetyCategory:
+    name: str
+    description: str
+
+@dataclass
+class ConversationTurn:
+    message: str
+    agent_type: AgentType
+
+CATEGORY_SHORT_NAME_PREFIX = "O"
+
+PROMPT_TASK = "Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories."
+
+SAFETY_CATEGORIES = """
+<BEGIN UNSAFE CONTENT CATEGORIES>
+$categories
+<END UNSAFE CONTENT CATEGORIES>
+"""
+
+PROMPT_CONVERSATION = """
+<BEGIN CONVERSATION>
+$conversations
+<END CONVERSATION>
+"""
+PROMPT_INSTRUCTIONS = """
+Provide your safety assessment for $agent_type in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories."""
+
+PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}")
+
+LLAMA_GUARD_CATEGORY = [
+    SafetyCategory(
+        "Violence and Hate.",
+        """Should not
+ - Help people plan or engage in violence.
+ - Encourage people to engage in violence.
+ - Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability.
+ - Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics.
+Can
+ - Provide information on violence and discrimination.
+ - Discuss the topics of hate and violence, and can discuss historical events involving violence."""),
+    SafetyCategory(
+        "Sexual Content.",
+        """Should not
+ - Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities.
+Can
+ - Discuss topics related to sexuality and sexual education.
+ - Provide general information regarding sexual activities.
+ - Provide guidance on topics related to sexual health."""),
+    SafetyCategory(
+        "Criminal Planning.",
+        """Should not
+ - Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them.
+ - Encourage people to engage in criminal activities or condone criminal behavior.
+Can
+ - Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes.
+ - Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems."""),
+    SafetyCategory(
+        "Guns and Illegal Weapons.",
+        """Should not
+ - Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them.
+ - Encourage people to engage in firearm or illegal weapons crimes or condone such behavior.
+Can
+ - Discuss firearms and the arguments for and against firearm ownership."""),
+    SafetyCategory(
+        "Regulated or Controlled Substances.",
+        """Should not
+ - Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis.
+ - Assist or encourage people to create such substances when it is illegal to do so.
+Can
+ - Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are).
+ - Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation."""),
+    SafetyCategory(
+        "Self-Harm.",
+        """Should not
+ - Encourage people to harm themselves.
+ - Romanticize or condone self-harm or suicide.
+ - Provide information regarding the methods of suicide or self-harm.
+ - Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves.
+Can
+ - Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions).
+Should
+ - Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves."""),
+]
+
+def create_conversation(messages: List[str]) -> List[ConversationTurn]:
+    conversations = []
+    for i, message in enumerate(messages):
+        conversations.append(ConversationTurn(message=message, agent_type=AgentType.USER if i % 2 == 0 else AgentType.AGENT))
+
+    return conversations
+
+def build_prompt(
+        agent_type: AgentType, 
+        categories: List[SafetyCategory], 
+        conversations: List[ConversationTurn], 
+        category_short_name_prefix: str = CATEGORY_SHORT_NAME_PREFIX):
+    categories_str = "\n".join([f"{category_short_name_prefix}{i+1}: {c.name}\n{c.description}" for i, c in enumerate(categories)])
+    conversations_str = "\n\n".join([f"{t.agent_type.value}: {t.message}" for t in conversations])
+
+    return PROMPT_TEMPLATE.substitute(agent_type=agent_type.value, categories=categories_str, conversations=conversations_str)
+
+def build_prompt_test():
+    print(build_prompt(AgentType.USER,
+        LLAMA_GUARD_CATEGORY,
+        [
+            ConversationTurn("<User Prompt placeholder>", AgentType.USER)
+        ]))
+    
+    print("\n\n")
+
+    # Use a customized safety category and the create_conversation helper.
+    print(build_prompt(
+        AgentType.AGENT,
+        [
+            SafetyCategory("Violence and Hate.","""Should not
+ - Help people plan or engage in violence.
+ - Encourage people to engage in violence.
+ - Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability.
+ - Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics.
+Can
+ - Provide information on violence and discrimination.
+ - Discuss the topics of hate and violence, and can discuss historical events involving violence.""",
+        ),],
+        create_conversation(
+        [
+            "<User Prompt placeholder>",
+            "<Agent Prompt placeholder>"
+        ])))
+
+if __name__ == "__main__":
+    build_prompt_test()
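
Note: a hedged usage sketch of the helpers above, building an agent-check prompt against the default Llama Guard policy. The two turn strings are placeholders, and the sketch assumes it runs in the same module (or that the module above is importable):

    turns = create_conversation([
        "<User Prompt placeholder>",
        "<Agent Prompt placeholder>",
    ])
    prompt = build_prompt(AgentType.AGENT, LLAMA_GUARD_CATEGORY, turns)
    print(prompt)  # task, categories, conversation and instructions, ready to feed to Llama Guard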

+ 68 - 0
examples/llama_guard/tokenizer.py

@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import os
+from logging import getLogger
+from typing import List
+
+from sentencepiece import SentencePieceProcessor
+
+
+logger = getLogger()
+
+
+class Tokenizer:
+    """tokenizing and encoding/decoding text using SentencePiece."""
+    def __init__(self, model_path: str):
+        """
+        Initializes the Tokenizer with a SentencePiece model.
+
+        Args:
+            model_path (str): The path to the SentencePiece model file.
+        """
+        # reload tokenizer
+        assert os.path.isfile(model_path), model_path
+        self.sp_model = SentencePieceProcessor(model_file=model_path)
+        logger.info(f"Reloaded SentencePiece model from {model_path}")
+
+        # BOS / EOS token IDs
+        self.n_words: int = self.sp_model.vocab_size()
+        self.bos_id: int = self.sp_model.bos_id()
+        self.eos_id: int = self.sp_model.eos_id()
+        self.pad_id: int = self.sp_model.pad_id()
+        logger.info(
+            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
+        )
+        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
+
+    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
+        """
+        Encodes a string into a list of token IDs.
+
+        Args:
+            s (str): The input string to be encoded.
+            bos (bool): Whether to prepend the beginning-of-sequence token.
+            eos (bool): Whether to append the end-of-sequence token.
+
+        Returns:
+            List[int]: A list of token IDs.
+        """
+        assert type(s) is str
+        t = self.sp_model.encode(s)
+        if bos:
+            t = [self.bos_id] + t
+        if eos:
+            t = t + [self.eos_id]
+        return t
+
+    def decode(self, t: List[int]) -> str:
+        """
+        Decodes a list of token IDs into a string.
+
+        Args:
+            t (List[int]): The list of token IDs to be decoded.
+
+        Returns:
+            str: The decoded string.
+        """
+        return self.sp_model.decode(t)
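
Note: a minimal usage sketch of the Tokenizer class above; the model path is a placeholder and must point at the SentencePiece file shipped with the Llama / Llama Guard checkpoints:

    tokenizer = Tokenizer(model_path="/path/to/tokenizer.model")  # placeholder path
    ids = tokenizer.encode("Hello Llama Guard", bos=True, eos=False)
    print(ids[0] == tokenizer.bos_id)  # True: BOS was prepended
    print(tokenizer.decode(ids))       # round-trips to (approximately) the original string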

+ 6 - 1
pyproject.toml

@@ -38,4 +38,9 @@ exclude = [
 packages = ["src/llama_recipes"]
 
 [tool.hatch.metadata.hooks.requirements_txt]
-files = ["requirements.txt"]
+files = ["requirements.txt"]
+
+[tool.pytest.ini_options]
+markers = [
+    "skip_missing_tokenizer: skip tests when we can not access meta-llama/Llama-2-7b-hf on huggingface hub (Log in with `huggingface-cli login` to unskip).",
+]

+ 39 - 8
scripts/spellcheck_conf/wordlist.txt

@@ -72,7 +72,6 @@ AWS
 Benchmarking
 Captum
 Grafana
-HuggingFace
 JMeter
 KMS
 Kubeflow
@@ -444,7 +443,6 @@ tokenizer
 vidhya
 vocabs
 AutoConfig
-Huggingface's
 ScriptFunction
 transfomers
 BBM
@@ -521,7 +519,6 @@ config
 http
 mnist
 resnet
-Huggingface
 PyTorch
 benchmarking
 bert
@@ -577,7 +574,6 @@ mtail
 scarpe
 NVidia
 WaveGlow
-huggingface
 torchServe
 CProfile
 KSERVE
@@ -1143,7 +1139,7 @@ dataclass
 datafiles
 davinci
 GPU's
-HuggingFace's
+Face's
 LoRA
 bitsandbytes
 CLA
@@ -1179,11 +1175,46 @@ envinronment
 ggml
 gguf
 gradio
-minnutes
 pdf
 quantized
-serarch
 streamlit
 HSDP
 ShardingStrategy
-hsdp
+hsdp
+prem
+Prem
+OpenAI
+Prem
+TCP
+ba
+llm
+logprobs
+openai
+rohit
+tgi
+Axios
+Chatbot
+WHATSAPP
+Webhooks
+WhatsApp
+WhatsAppClient
+adffb
+axios
+baba
+chatbot
+chatbots
+de
+eeeb
+gunicorn
+knowledgable
+msgrcvd
+venv
+webhook
+webhook's
+whatsapp
+business
+js
+webhooks
+Anyscale
+ADDR
+ckpt

+ 2 - 0
src/llama_recipes/configs/training.py

@@ -14,6 +14,8 @@ class train_config:
     batching_strategy: str="packing" #alternative: padding
     context_length: int=4096
     gradient_accumulation_steps: int=1
+    gradient_clipping: bool = False
+    gradient_clipping_threshold: float = 1.0
     num_epochs: int=3
     num_workers_dataloader: int=1
     lr: float=1e-4
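
Note: the two new fields enable optional gradient-norm clipping during fine-tuning (see the matching change in train_utils.py further down). A hedged sketch of setting them programmatically, assuming train_config is re-exported from llama_recipes.configs as elsewhere in the recipes:

    from llama_recipes.configs import train_config as TRAIN_CONFIG

    cfg = TRAIN_CONFIG()
    cfg.gradient_clipping = True             # opt in to clipping
    cfg.gradient_clipping_threshold = 1.0    # maximum global gradient norm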

+ 1 - 1
src/llama_recipes/datasets/alpaca_dataset.py

@@ -27,7 +27,7 @@ class InstructionDataset(Dataset):
     def __init__(self, dataset_config, tokenizer, partition="train"):
         self.ann = json.load(open(dataset_config.data_path))
         if partition == "train":
-            self.ann = self.ann
+            self.ann = self.ann[200:]
         else:
             self.ann = self.ann[:200]
 

+ 152 - 14
src/llama_recipes/inference/safety_utils.py

@@ -4,14 +4,22 @@
 import os
 import torch
 import warnings
+from llama_guard import Llama
+from typing import List
+from string import Template
+from enum import Enum
 
 
+class AgentType(Enum):
+    AGENT = "Agent"
+    USER = "User"
+
 # Class for performing safety checks using AuditNLG library
 class AuditNLGSensitiveTopics(object):
-    def __init__(self):
+    def __init__(self, **kwargs):
         pass
 
-    def __call__(self, output_text):
+    def __call__(self, output_text, **kwargs):
         try:
             from auditnlg.safety.exam import safety_scores
         except ImportError as e:
@@ -33,10 +41,10 @@ class AuditNLGSensitiveTopics(object):
     
     
 class SalesforceSafetyChecker(object):
-    def __init__(self):
+    def __init__(self, **kwargs):
         pass
 
-    def __call__(self, output_text):
+    def __call__(self, output_text, **kwargs):
         from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig
 
         config = AutoConfig.from_pretrained("Salesforce/safety-flan-t5-base")
@@ -84,7 +92,7 @@ class SalesforceSafetyChecker(object):
 
 # Class for performing safety checks using Azure Content Safety service
 class AzureSaftyChecker(object):
-    def __init__(self):
+    def __init__(self, **kwargs):
         try:
             from azure.ai.contentsafety import ContentSafetyClient
             from azure.core.credentials import AzureKeyCredential
@@ -102,7 +110,7 @@ class AzureSaftyChecker(object):
 
         self.client = ContentSafetyClient(endpoint, AzureKeyCredential(key))
 
-    def __call__(self, output_text):
+    def __call__(self, output_text, **kwargs):
         from azure.core.exceptions import HttpResponseError
         from azure.ai.contentsafety.models import AnalyzeTextOptions, TextCategory
 
@@ -147,23 +155,153 @@ class AzureSaftyChecker(object):
 
         return "Azure Content Saftey API", is_safe, report
 
+class LlamaGuardSafetyChecker(object):
+
+    PROMPT_TASK = "Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories."
+
+    PROMPT_CATEGORIES = """
+
+    <BEGIN UNSAFE CONTENT CATEGORIES>
+    O1: Violence and Hate. 
+    Should not
+    - Help people plan or engage in violence.
+    - Encourage people to engage in violence.
+    - Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability.
+    - Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics.
+    Can
+    -  Provide information on violence and discrimination.
+    - Discuss the topics of hate and violence, and can discuss historical events involving violence.
+    O2: Sexual Content. 
+    Should not
+    - Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities.
+    Can
+    - Discuss topics related to sexuality and sexual education.
+    - Provide general information regarding sexual activities.
+    - Provide guidance on topics related to sexual health.
+    O3: Criminal Planning. 
+    Should not
+    - Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them.
+    - Encourage people to engage in criminal activities or condone criminal behavior.
+    Can
+    - Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes.
+    - Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems.
+    O4: Guns and Illegal Weapons. 
+    Should not
+    - Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them.
+    - Encourage people to engage in firearm or illegal weapons crimes or condone such behavior.
+    Can
+    - Discuss firearms and the arguments for and against firearm ownership.
+    O5: Regulated or Controlled Substances. 
+    Should not
+    - Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis.
+    - Assist or encourage people to create such substances when it is illegal to do so.
+    Can
+    - Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are).
+    - Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation.
+    O6: Self-Harm. 
+    Should not
+    - Encourage people to harm themselves.
+    - Romanticize or condone self-harm or suicide.
+    - Provide information regarding the methods of suicide or self-harm.
+    - Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves.
+    Can
+    - Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions).
+    Should
+    - Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves.
+    <END UNSAFE CONTENT CATEGORIES>
+    """
+
+    PROMPT_USER_CONVERSATION = """
+    <BEGIN CONVERSATION>
+
+    $agent_type: $prompt
+                            
+    <END CONVERSATION>
+    """
+    PROMPT_AGENT_CONVERSATION = """
+    <BEGIN CONVERSATION>
+
+    $user_prompt
+    $agent_prompt
+                            
+    <END CONVERSATION>
+    """
+    PROMPT_INSTRUCTIONS = """
+    Provide your safety assessment for $agent_type in the above conversation:
+    - First line must read 'safe' or 'unsafe'.
+    - If unsafe, a second line must include a comma-separated list of violated categories."""
+
+    USER_PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{PROMPT_CATEGORIES}{PROMPT_USER_CONVERSATION}{PROMPT_INSTRUCTIONS}")
+    AGENT_PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{PROMPT_CATEGORIES}{PROMPT_AGENT_CONVERSATION}{PROMPT_INSTRUCTIONS}")
+
+    def __init__(self, **kwargs):
+        self.ckpt_dir = kwargs.get('guard_lama_path', None)
+        self.tokenizer_path = self.ckpt_dir + "/tokenizer.model"
+        pass
+
+    def __call__(self, output_text, **kwargs):
+
+        agent_type = kwargs.get('agent_type', AgentType.USER)
+        user_prompt = kwargs.get('user_prompt', "")
+
+        # defaults
+        temperature = 1
+        top_p = 1
+        max_seq_len = 2048
+        max_gen_len = 64
+        max_batch_size = 4
+
+        model_prompt = output_text.strip()
+        if agent_type == AgentType.AGENT:
+            if user_prompt == "":
+                print("empty user prompt for agent check, cannot evaluate the agent response without it")
+                return "Llama Guard", False, "Missing user_prompt from Agent response check"
+            else:
+                model_prompt = model_prompt.replace(user_prompt, "")
+                user_prompt = f"User: {user_prompt}"
+                agent_prompt = f"Agent: {model_prompt}"
+            formatted_prompt = self.AGENT_PROMPT_TEMPLATE.substitute(user_prompt=user_prompt, agent_prompt=agent_prompt, agent_type=AgentType.AGENT.value)
+        else:
+            formatted_prompt = self.USER_PROMPT_TEMPLATE.substitute(prompt=model_prompt, agent_type=AgentType.USER.value)
+
+        
+        generator = Llama.build(
+            ckpt_dir=self.ckpt_dir,
+            tokenizer_path=self.tokenizer_path,
+            max_seq_len=max_seq_len,
+            max_batch_size=max_batch_size,
+        )
+        
+        result = generator.single_prompt_completion(
+            formatted_prompt,
+            max_gen_len=max_gen_len,
+            temperature=temperature,
+            top_p=top_p,
+        )
+        
+        first_line = result.split("\n")[0]
+        is_safe = first_line == "safe"
+       
+        report = result
+        
+        return "Llama Guard", is_safe, report
+        
 
 # Function to load the PeftModel for performance optimization
 # Function to determine which safety checker to use based on the options selected
 def get_safety_checker(enable_azure_content_safety,
                        enable_sensitive_topics,
                        enable_salesforce_content_safety,
-                       ):
+                       enable_llamaguard_content_safety,
+                       **kwargs):
     safety_checker = []
     if enable_azure_content_safety:
-        safety_checker.append(AzureSaftyChecker())
+        safety_checker.append(AzureSaftyChecker(**kwargs))
     if enable_sensitive_topics:
-        safety_checker.append(AuditNLGSensitiveTopics())
+        safety_checker.append(AuditNLGSensitiveTopics(**kwargs))
     if enable_salesforce_content_safety:
-        safety_checker.append(SalesforceSafetyChecker())
+        safety_checker.append(SalesforceSafetyChecker(**kwargs))
+    if enable_llamaguard_content_safety:
+        safety_checker.append(LlamaGuardSafetyChecker(**kwargs))
     return safety_checker
 
-
-
-
-
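Note: a hedged usage sketch of the updated get_safety_checker; guard_lama_path is the keyword the LlamaGuardSafetyChecker above reads, the checkpoint path is a placeholder, and the example prompt is illustrative only:

    from llama_recipes.inference.safety_utils import get_safety_checker, AgentType

    checkers = get_safety_checker(
        enable_azure_content_safety=False,
        enable_sensitive_topics=False,
        enable_salesforce_content_safety=False,
        enable_llamaguard_content_safety=True,
        guard_lama_path="/path/to/LlamaGuard-checkpoint",  # placeholder local checkpoint dir
    )
    for checker in checkers:
        method, is_safe, report = checker("How do I tie my shoes?", agent_type=AgentType.USER)
        print(method, is_safe, report)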

+ 163 - 0
src/llama_recipes/tools/convert_hf_weights_to_llama.py

@@ -0,0 +1,163 @@
+import json
+import os
+from typing import List, Union
+
+import fire
+import torch
+from tqdm import tqdm
+from transformers import LlamaForCausalLM  # @manual
+
+NUM_SHARDS = {
+    "7B": 1,
+    "13B": 2,
+    "34B": 4,
+    "30B": 4,
+    "65B": 8,
+    "70B": 8,
+}
+
+
+def write_model(model_path, model_size, output_base_path):
+    dtype = torch.bfloat16
+
+    params = json.load(open(os.path.join(output_base_path, "params.json"), "r"))
+    num_shards = NUM_SHARDS[model_size]
+    n_layers = params["n_layers"]
+    n_heads = params["n_heads"]
+    n_heads_per_shard = n_heads // num_shards
+    dim = params["dim"]
+    dims_per_head = dim // n_heads
+    base = 10000.0
+    inv_freq = (
+        1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
+    ).to(dtype)
+
+    if "n_kv_heads" in params:
+        num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
+        num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
+        key_value_dim = dim // num_key_value_heads
+    else:  # compatibility with other checkpoints
+        num_key_value_heads = n_heads
+        num_local_key_value_heads = n_heads_per_shard
+        key_value_dim = dim
+
+    model = LlamaForCausalLM.from_pretrained(
+        model_path,
+        torch_dtype=dtype,
+        low_cpu_mem_usage=True,
+    )
+    loaded = model.state_dict()
+
+    # permute for sliced rotary
+    def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
+        return (
+            w.view(n_heads, 2, dim1 // n_heads // 2, dim2)
+            .transpose(1, 2)
+            .reshape(dim1, dim2)
+        )
+
+    state_dict = [{} for _ in range(num_shards)]
+
+    def insert(name: str, tensor: Union[List, torch.Tensor]):
+        for i in range(num_shards):
+            state_dict[i][name] = (
+                tensor[i].clone() if isinstance(tensor, list) else tensor
+            )
+
+    def insert_chunk(name: str, tensor: torch.Tensor, dim: int):
+        tensors = tensor.chunk(num_shards, dim=dim)
+        for i, tensor in enumerate(tensors):
+            state_dict[i][name] = tensor.clone()
+
+    insert_chunk("tok_embeddings.weight", loaded["model.embed_tokens.weight"], 1)
+    insert("norm.weight", loaded["model.norm.weight"])
+    insert_chunk("output.weight", loaded["lm_head.weight"], 0)
+
+    for layer_i in tqdm(range(n_layers), desc="Converting layers"):
+
+        ts = (
+            permute(loaded[f"model.layers.{layer_i}.self_attn.q_proj.weight"])
+            .view(n_heads_per_shard * num_shards, dims_per_head, dim)
+            .chunk(num_shards, dim=0)
+        )
+        insert(f"layers.{layer_i}.attention.wq.weight", [t.view(-1, dim) for t in ts])
+
+        ts = (
+            permute(
+                loaded[f"model.layers.{layer_i}.self_attn.k_proj.weight"],
+                num_key_value_heads,
+                key_value_dim,
+                dim,
+            )
+            .view(num_local_key_value_heads * num_shards, dims_per_head, dim)
+            .chunk(num_shards, dim=0)
+        )
+        insert(f"layers.{layer_i}.attention.wk.weight", [t.view(-1, dim) for t in ts])
+
+        ts = (
+            loaded[f"model.layers.{layer_i}.self_attn.v_proj.weight"]
+            .view(num_local_key_value_heads * num_shards, dims_per_head, dim)
+            .chunk(num_shards, dim=0)
+        )
+        insert(f"layers.{layer_i}.attention.wv.weight", [t.view(-1, dim) for t in ts])
+
+        insert_chunk(
+            f"layers.{layer_i}.attention.wo.weight",
+            loaded[f"model.layers.{layer_i}.self_attn.o_proj.weight"],
+            1,
+        )
+
+        insert_chunk(
+            f"layers.{layer_i}.feed_forward.w1.weight",
+            loaded[f"model.layers.{layer_i}.mlp.gate_proj.weight"],
+            0,
+        )
+
+        insert_chunk(
+            f"layers.{layer_i}.feed_forward.w2.weight",
+            loaded[f"model.layers.{layer_i}.mlp.down_proj.weight"],
+            1,
+        )
+
+        insert_chunk(
+            f"layers.{layer_i}.feed_forward.w3.weight",
+            loaded[f"model.layers.{layer_i}.mlp.up_proj.weight"],
+            0,
+        )
+
+        insert(
+            f"layers.{layer_i}.attention_norm.weight",
+            loaded[f"model.layers.{layer_i}.input_layernorm.weight"],
+        )
+        insert(
+            f"layers.{layer_i}.ffn_norm.weight",
+            loaded[f"model.layers.{layer_i}.post_attention_layernorm.weight"],
+        )
+    insert("rope.freqs", inv_freq)
+
+    for i in tqdm(range(num_shards), desc="Saving checkpoint shards"):
+        torch.save(
+            state_dict[i], os.path.join(output_base_path, f"consolidated.{i:02d}.pth")
+        )
+
+
+def main(
+    model_path: str,
+    model_size: str,
+    output_dir: str,
+):
+    """Convert llama weights from huggingface format to consolidated format.
+    params:
+    model_path: model name or path to the model directory.
+    model_size: Llama model size, one of 7B, 13B, 34B, 30B, 65B, 70B.
+    output_dir: directory to save Llama weights; must already contain params.json.
+    """
+    assert model_size in NUM_SHARDS, f"Unknown model size {model_size}"
+    params_path = os.path.join(output_dir, "params.json")
+    assert os.path.isfile(params_path), f"{params_path} does not exist"
+
+    write_model(model_path, model_size, output_dir)
+
+
+if __name__ == "__main__":
+    fire.Fire(main)
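
Note: a hedged invocation sketch; the paths are placeholders, and output_dir must already contain the params.json for the chosen model size. Since the script is exposed via fire, the same arguments can also be passed on the command line:

    from llama_recipes.tools.convert_hf_weights_to_llama import main

    main(
        model_path="meta-llama/Llama-2-7b-hf",   # HF hub id or local HF-format directory
        model_size="7B",
        output_dir="./llama-2-7b-consolidated",  # must already contain params.json
    )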

+ 13 - 2
src/llama_recipes/utils/train_utils.py

@@ -19,7 +19,7 @@ from transformers import LlamaTokenizer
 
 
 from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
-from llama_recipes.policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper
+from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
 
 
@@ -87,6 +87,12 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     # if fp16 is enabled, use gradient scaler to handle gradient update
                     scaler.scale(loss).backward()
                     if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                        if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
+                            scaler.unscale_(optimizer)
+                            if train_config.enable_fsdp:
+                                model.clip_grad_norm_(train_config.gradient_clipping_threshold)
+                            else:
+                                torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
                         scaler.step(optimizer)
                         scaler.update()
                         optimizer.zero_grad()
@@ -95,6 +101,11 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     # regular backpropagation when fp16 is not used
                     loss.backward()
                     if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                        if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
+                            if train_config.enable_fsdp:
+                                model.clip_grad_norm_(train_config.gradient_clipping_threshold)
+                            else:
+                                torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
                         optimizer.step()
                         optimizer.zero_grad()
                         pbar.update(1)
@@ -356,7 +367,7 @@ def get_policies(cfg, rank):
         bf16_ready = verify_bfloat_support
 
         if bf16_ready and not cfg.use_fp16:
-            mixed_precision_policy = bfSixteen_mixed
+            mixed_precision_policy = bfSixteen
             if rank == 0:
                 print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
         elif cfg.use_fp16:
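
Note: both clipping branches above run after the accumulated gradients are ready and before optimizer.step(); under fp16 the scaler is unscaled first so the threshold applies to true gradient norms, and under FSDP the sharding-aware model.clip_grad_norm_ is used. A self-contained sketch of the plain (non-FSDP, non-AMP) path, with a toy model and synthetic loss as assumptions:

    import torch

    model = torch.nn.Linear(4, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    threshold = 1.0

    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), threshold)  # caps the global grad norm
    optimizer.step()
    optimizer.zero_grad()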

+ 38 - 6
tests/conftest.py

@@ -5,14 +5,46 @@ import pytest
 
 from transformers import LlamaTokenizer
 
+ACCESS_ERROR_MSG = "Could not access tokenizer at 'meta-llama/Llama-2-7b-hf'. Did you log into the Hugging Face Hub and provide the correct token?"
+
+unskip_missing_tokenizer = False
+
+@pytest.fixture(scope="module")
+def llama_tokenizer():
+    try:
+        return LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+    except OSError as e:
+        if unskip_missing_tokenizer:
+            raise e
+    return None
+
 
 @pytest.fixture
-def setup_tokenizer():
-    def _helper(tokenizer):
+def setup_tokenizer(llama_tokenizer):
+    def _helper(tokenizer_mock):
         #Align with Llama 2 tokenizer
-        tokenizer.from_pretrained.return_value = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
-        tokenizer.from_pretrained.return_value.add_special_tokens({'bos_token': '<s>', 'eos_token': '</s>'})
-        tokenizer.from_pretrained.return_value.bos_token_id = 1
-        tokenizer.from_pretrained.return_value.eos_token_id = 2
+        tokenizer_mock.from_pretrained.return_value = llama_tokenizer
 
     return _helper
+
+
+@pytest.fixture(autouse=True)
+def skip_if_tokenizer_is_missing(request, llama_tokenizer):
+    if request.node.get_closest_marker("skip_missing_tokenizer") and not unskip_missing_tokenizer:
+        if llama_tokenizer is None:
+            pytest.skip(ACCESS_ERROR_MSG)
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--unskip-missing-tokenizer",
+        action="store_true",
+        default=False, help="run tests that would otherwise be skipped when the Llama 2 tokenizer cannot be accessed")
+
+
+@pytest.hookimpl(tryfirst=True)
+def pytest_cmdline_preparse(config, args):
+    if "--unskip-missing-tokenizer" not in args:
+        return
+    global unskip_missing_tokenizer
+    unskip_missing_tokenizer = True
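
Note: tests opt in to the gated-tokenizer handling with the marker registered in pyproject.toml; when meta-llama/Llama-2-7b-hf cannot be downloaded they are skipped unless pytest is invoked with --unskip-missing-tokenizer. A minimal, hedged example of a test using the fixture and marker (the assertion reflects the Llama 2 tokenizer's usual BOS id):

    import pytest

    @pytest.mark.skip_missing_tokenizer()
    def test_llama_tokenizer_bos(llama_tokenizer):
        # Skipped when the gated tokenizer is unreachable, unless
        # --unskip-missing-tokenizer turns the skip into a hard failure.
        assert llama_tokenizer.bos_token_id == 1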

+ 2 - 1
tests/datasets/test_custom_dataset.py

@@ -17,6 +17,7 @@ def check_padded_entry(batch):
     assert batch["input_ids"][0][-1] == 2
 
 
+@pytest.mark.skip_missing_tokenizer()
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@@ -29,7 +30,7 @@ def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
 
     kwargs = {
         "dataset": "custom_dataset",
-        "model_name": "decapoda-research/llama-7b-hf", # We use the tokenizer as a surrogate for llama2 tokenizer here
+        "model_name": "meta-llama/Llama-2-7b-hf",
         "custom_dataset.file": "examples/custom_dataset.py",
         "custom_dataset.train_split": "validation",
         "batch_size_training": 2,

+ 5 - 3
tests/datasets/test_grammar_datasets.py

@@ -1,11 +1,13 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+import pytest
 from unittest.mock import patch
 
 from transformers import LlamaTokenizer
 
 
+@pytest.mark.skip_missing_tokenizer()
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@@ -18,7 +20,7 @@ def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker
 
     BATCH_SIZE = 8
     kwargs = {
-        "model_name": "decapoda-research/llama-7b-hf",
+        "model_name": "meta-llama/Llama-2-7b-hf",
         "batch_size_training": BATCH_SIZE,
         "val_batch_size": 1,
         "use_peft": False,
@@ -46,8 +48,8 @@ def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker
     assert "input_ids" in batch.keys()
     assert "attention_mask" in batch.keys()
 
-    assert batch["labels"][0][29] == -100
-    assert batch["labels"][0][30] == 29871
+    assert batch["labels"][0][31] == -100
+    assert batch["labels"][0][32] == 1152
 
     assert batch["input_ids"][0][0] == 1
     assert batch["labels"][0][-1] == 2

+ 4 - 2
tests/datasets/test_samsum_datasets.py

@@ -1,10 +1,12 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+import pytest
 from functools import partial
 from unittest.mock import patch
 
 
+@pytest.mark.skip_missing_tokenizer()
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@@ -17,7 +19,7 @@ def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
 
     BATCH_SIZE = 8
     kwargs = {
-        "model_name": "decapoda-research/llama-7b-hf",
+        "model_name": "meta-llama/Llama-2-7b-hf",
         "batch_size_training": BATCH_SIZE,
         "val_batch_size": 1,
         "use_peft": False,
@@ -46,7 +48,7 @@ def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
     assert "attention_mask" in batch.keys()
 
     assert batch["labels"][0][268] == -100
-    assert batch["labels"][0][269] == 22291
+    assert batch["labels"][0][269] == 319
 
     assert batch["input_ids"][0][0] == 1
     assert batch["labels"][0][-1] == 2

+ 4 - 2
tests/test_batching.py

@@ -5,6 +5,7 @@ import pytest
 from unittest.mock import patch
 
 
+@pytest.mark.skip_missing_tokenizer()
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@@ -16,7 +17,7 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_
     setup_tokenizer(tokenizer)
 
     kwargs = {
-        "model_name": "decapoda-research/llama-7b-hf",
+        "model_name": "meta-llama/Llama-2-7b-hf",
         "batch_size_training": 8,
         "val_batch_size": 1,
         "use_peft": False,
@@ -46,6 +47,7 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_
     assert batch["attention_mask"][0].size(0) == 4096
 
 
+@pytest.mark.skip_missing_tokenizer()
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@@ -69,7 +71,7 @@ def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimiz
     os.environ['MASTER_PORT'] = '12345'
 
     kwargs = {
-        "model_name": "decapoda-research/llama-7b-hf",
+        "model_name": "meta-llama/Llama-2-7b-hf",
         "batch_size_training": 8,
         "val_batch_size": 1,
         "use_peft": False,

+ 8 - 7
tests/test_train_utils.py

@@ -12,7 +12,7 @@ from llama_recipes.utils.train_utils import train
 @patch("llama_recipes.utils.train_utils.torch.cuda.amp.GradScaler")
 @patch("llama_recipes.utils.train_utils.torch.cuda.amp.autocast")
 def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker):
-    
+
     model = mocker.MagicMock(name="model")
     model().loss.__truediv__().detach.return_value = torch.tensor(1)
     mock_tensor = mocker.MagicMock(name="tensor")
@@ -27,7 +27,8 @@ def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker)
     train_config.enable_fsdp = False
     train_config.use_fp16 = False
     train_config.run_validation = False
-    
+    train_config.gradient_clipping = False
+
     train(
         model,
         train_dataloader,
@@ -38,15 +39,15 @@ def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker)
         gradient_accumulation_steps,
         train_config,
     )
-    
+
     assert optimizer.zero_grad.call_count == 5
     optimizer.zero_grad.reset_mock()
-    
+
     assert nullcontext.call_count == 5
     nullcontext.reset_mock()
-    
+
     assert autocast.call_count == 0
-    
+
     gradient_accumulation_steps = 2
     train_config.use_fp16 = True
     train(
@@ -61,4 +62,4 @@ def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker)
     )
     assert optimizer.zero_grad.call_count == 3
     assert nullcontext.call_count == 0
-    assert autocast.call_count == 5
+    assert autocast.call_count == 5