
Merge branch 'main' into messenger

Jeff Tang, 1 year ago
commit 826776bca0

+ 3 - 3
README.md

@@ -76,7 +76,7 @@ Optional dependencies can also be combined with [option1,option2].
 
 # Where to find the models?
 
-You can find llama v2 models on HuggingFace hub [here](https://huggingface.co/meta-llama), where models with `hf` in the name are already converted to HuggingFace checkpoints so no further conversion is needed. The conversion step below is only for original model weights from Meta that are hosted on HuggingFace model hub as well.
+You can find Llama v2 models on the Hugging Face hub [here](https://huggingface.co/meta-llama); models with `hf` in the name are already converted to Hugging Face checkpoints, so no further conversion is needed. The conversion step below is only for the original model weights from Meta, which are also hosted on the Hugging Face model hub.
 
 # Model conversion to Hugging Face
 The recipes and notebooks in this folder are using the Llama 2 model definition provided by Hugging Face's transformers library.
@@ -84,7 +84,7 @@ The recipes and notebooks in this folder are using the Llama 2 model definition
 Given that the original checkpoint resides under models/7B you can install all requirements and convert the checkpoint with:
 
 ```bash
-## Install HuggingFace Transformers from source
+## Install Hugging Face Transformers from source
 pip freeze | grep transformers ## verify it is version 4.31.0 or higher
 
 git clone git@github.com:huggingface/transformers.git
@@ -141,7 +141,7 @@ Here we use FSDP as discussed in the next section which can be used along with P
 
 ## Flash Attention and Xformer Memory Efficient Kernels
 
-Setting `use_fast_kernels` will enable using of Flash Attention or Xformer memory-efficient kernels based on the hardware being used. This would speed up the fine-tuning job. This has been enabled in `optimum` library from HuggingFace as a one-liner API, please read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/).
+Setting `use_fast_kernels` will enable the use of Flash Attention or Xformer memory-efficient kernels based on the hardware being used. This can speed up the fine-tuning job. The feature is exposed as a one-liner API in Hugging Face's `optimum` library; read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/).
 
 ```bash
torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /path_of_model_folder/7B --fsdp_config.pure_bf16 --output_dir Path/to/save/PEFT/model --use_fast_kernels

File diff suppressed because it is too large
+ 8 - 8
demo_apps/README.md


File diff suppressed because it is too large
+ 384 - 0
examples/Purple_Llama_Anyscale.ipynb


+ 3 - 2
examples/README.md

@@ -1,7 +1,6 @@
 # Examples
 
-This folder contains finetuning and inference examples for Llama 2.
-For the full documentation on these examples please refer to [docs/inference.md](../docs/inference.md)
+This folder contains finetuning and inference examples for Llama 2, Code Llama, and [Purple Llama](https://ai.meta.com/llama/purple-llama/). For the full documentation on these examples please refer to [docs/inference.md](../docs/inference.md)
 
 ## Finetuning
 
@@ -27,6 +26,8 @@ So far, we have provided the following inference examples:
 
 5. [Code Llama](./code_llama/) folder which provides examples for [code completion](./code_llama/code_completion_example.py) and [code infilling](./code_llama/code_infilling_example.py).
 
+6. The [Purple Llama Using Anyscale](./Purple_Llama_Anyscale.ipynb) notebook shows how to use an Anyscale-hosted Llama Guard model to classify user inputs as safe or unsafe (see the client sketch after this diff).
+
 For more in depth information on inference including inference safety checks and examples, see the inference documentation [here](../docs/inference.md).
 
 **Note** The [sensitive topics safety checker](../src/llama_recipes/inference/safety_utils.py) utilizes AuditNLG which is an optional dependency. Please refer to installation section of the main [README.md](../README.md#install-with-optional-dependencies) for details.
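For orientation, here is a minimal, hedged sketch of the kind of hosted classification call the Anyscale notebook walks through. It assumes an OpenAI-compatible endpoint; the base URL, API key, and model id below are placeholders rather than values taken from this commit, so consult the notebook itself for the exact request format.

```python
# Hedged sketch only: classify a user input with a hosted Llama Guard model
# through an OpenAI-compatible endpoint. URL, key, and model id are placeholders.
from openai import OpenAI

client = OpenAI(
    base_url="https://api.endpoints.anyscale.com/v1",  # assumed Anyscale Endpoints URL
    api_key="YOUR_ANYSCALE_API_KEY",                    # placeholder credential
)
response = client.chat.completions.create(
    model="Meta-Llama/LlamaGuard-7b",                   # placeholder model id
    messages=[{"role": "user", "content": "How do I tie my shoes?"}],
)
# Llama Guard replies with 'safe' or 'unsafe' plus any violated category codes.
print(response.choices[0].message.content)
```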

+ 12 - 2
examples/inference.py

@@ -11,7 +11,7 @@ import time
 import torch
 from transformers import LlamaTokenizer
 
-from llama_recipes.inference.safety_utils import get_safety_checker
+from llama_recipes.inference.safety_utils import get_safety_checker, AgentType
 from llama_recipes.inference.model_utils import load_model, load_peft_model
 
 
@@ -33,6 +33,8 @@ def main(
     enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
     enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
     enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5
+    enable_llamaguard_content_safety: bool=False,
+    llamaguard_model_name: str=None,
     max_padding_length: int=None, # the max padding length to be used with tokenizer padding the prompts.
     use_fast_kernels: bool = False, # Enable using SDPA from PyTorch Accelerated Transformers, which makes use of Flash Attention and Xformer memory-efficient kernels
     **kwargs
@@ -48,6 +50,12 @@ def main(
     else:
         print("No user prompt provided. Exiting.")
         sys.exit(1)
+
+    if enable_llamaguard_content_safety:
+        if not llamaguard_model_name:
+            print("if enable_llamaguard_content_safety is used, provide the model path with --llamaguard_model_name")
+            sys.exit(1)
+
     
     # Set the seeds for reproducibility
     torch.cuda.manual_seed(seed)
@@ -77,6 +85,8 @@ def main(
     safety_checker = get_safety_checker(enable_azure_content_safety,
                                         enable_sensitive_topics,
                                         enable_salesforce_content_safety,
+                                        enable_llamaguard_content_safety,
+                                        guard_lama_path=llamaguard_model_name
                                         )
 
     # Safety check of the user prompt
@@ -117,7 +127,7 @@ def main(
     output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
     
     # Safety check of the model output
-    safety_results = [check(output_text) for check in safety_checker]
+    safety_results = [check(output_text, agent_type=AgentType.AGENT, user_prompt=user_prompt) for check in safety_checker]
     are_safe = all([r[1] for r in safety_results])
     if are_safe:
         print("User input and model output deemed safe.")

File diff suppressed because it is too large
+ 19 - 0
examples/llama_guard/README.md


+ 6 - 0
examples/llama_guard/__init__.py

@@ -0,0 +1,6 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from .generation import Llama, Dialog
+from .model import ModelArgs, Transformer
+from .tokenizer import Tokenizer

+ 458 - 0
examples/llama_guard/generation.py

@@ -0,0 +1,458 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import json
+import os
+import sys
+import time
+from pathlib import Path
+from typing import List, Literal, Optional, Tuple, TypedDict
+
+import torch
+import torch.nn.functional as F
+from fairscale.nn.model_parallel.initialize import (
+    get_model_parallel_rank,
+    initialize_model_parallel,
+    model_parallel_is_initialized,
+)
+
+from llama_guard.model import ModelArgs, Transformer
+from llama_guard.tokenizer import Tokenizer
+
+Role = Literal["system", "user", "assistant"]
+
+
+class Message(TypedDict):
+    role: Role
+    content: str
+
+
+class CompletionPrediction(TypedDict, total=False):
+    generation: str
+    tokens: List[str]  # not required
+    logprobs: List[float]  # not required
+
+
+class ChatPrediction(TypedDict, total=False):
+    generation: Message
+    tokens: List[str]  # not required
+    logprobs: List[float]  # not required
+
+
+Dialog = List[Message]
+
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+SPECIAL_TAGS = [B_INST, E_INST, "<<SYS>>", "<</SYS>>"]
+UNSAFE_ERROR = "Error: special tags are not allowed as part of the prompt."
+
+
+class Llama:
+    @staticmethod
+    def build(
+        ckpt_dir: str,
+        tokenizer_path: str,
+        max_seq_len: int,
+        max_batch_size: int,
+        model_parallel_size: Optional[int] = None,
+        seed: int = 1,
+    ) -> "Llama":
+        """
+        Build a Llama instance by initializing and loading a pre-trained model.
+
+        Args:
+            ckpt_dir (str): Path to the directory containing checkpoint files.
+            tokenizer_path (str): Path to the tokenizer file.
+            max_seq_len (int): Maximum sequence length for input text.
+            max_batch_size (int): Maximum batch size for inference.
+            model_parallel_size (Optional[int], optional): Number of model parallel processes.
+                If not provided, it's determined from the environment. Defaults to None.
+
+        Returns:
+            Llama: An instance of the Llama class with the loaded model and tokenizer.
+
+        Raises:
+            AssertionError: If there are no checkpoint files in the specified directory,
+                or if the model parallel size does not match the number of checkpoint files.
+
+        Note:
+            This method initializes the distributed process group, sets the device to CUDA,
+            and loads the pre-trained model and tokenizer.
+
+        """
+        if not torch.distributed.is_initialized():
+            torch.distributed.init_process_group("nccl")
+        if not model_parallel_is_initialized():
+            if model_parallel_size is None:
+                model_parallel_size = int(os.environ.get("WORLD_SIZE", 1))
+            initialize_model_parallel(model_parallel_size)
+
+        local_rank = int(os.environ.get("LOCAL_RANK", 0))
+        torch.cuda.set_device(local_rank)
+
+        # seed must be the same in all processes
+        torch.manual_seed(seed)
+
+        if local_rank > 0:
+            sys.stdout = open(os.devnull, "w")
+
+        start_time = time.time()
+        checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
+        checkpoints_size = len(checkpoints)
+        assert checkpoints_size > 0, f"no checkpoint files found in {ckpt_dir}"
+        ckpt_path = checkpoints[get_model_parallel_rank()]
+        checkpoint = torch.load(ckpt_path, map_location="cpu")
+        with open(Path(ckpt_dir) / "params.json", "r") as f:
+            params = json.loads(f.read())
+
+        model_args: ModelArgs = ModelArgs(
+            max_seq_len=max_seq_len,
+            max_batch_size=max_batch_size,
+            **params,
+        )
+        tokenizer = Tokenizer(model_path=tokenizer_path)
+        model_args.vocab_size = tokenizer.n_words
+        torch.set_default_tensor_type(torch.cuda.HalfTensor)
+        model = Transformer(model_args)
+        model.load_state_dict(checkpoint, strict=False)
+        print(f"Loaded in {time.time() - start_time:.2f} seconds")
+
+        return Llama(model, tokenizer)
+
+    def __init__(self, model: Transformer, tokenizer: Tokenizer):
+        self.model = model
+        self.tokenizer = tokenizer
+
+    @torch.inference_mode()
+    def generate(
+        self,
+        prompt_tokens: List[List[int]],
+        max_gen_len: int,
+        temperature: float = 0.6,
+        top_p: float = 0.9,
+        logprobs: bool = False,
+        echo: bool = False,
+    ) -> Tuple[List[List[int]], Optional[List[List[float]]]]:
+        """
+        Generate text sequences based on provided prompts using the language generation model.
+
+        Args:
+            prompt_tokens (List[List[int]]): List of tokenized prompts, where each prompt is represented as a list of integers.
+            max_gen_len (int): Maximum length of the generated text sequence.
+            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
+            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
+            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
+            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
+
+        Returns:
+            Tuple[List[List[int]], Optional[List[List[float]]]]: A tuple containing generated token sequences and, if logprobs is True, corresponding token log probabilities.
+
+        Note:
+            This method uses the provided prompts as a basis for generating text. It employs nucleus sampling to produce text with controlled randomness.
+            If logprobs is True, token log probabilities are computed for each generated token.
+
+        """
+        params = self.model.params
+        bsz = len(prompt_tokens)
+        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
+
+        min_prompt_len = min(len(t) for t in prompt_tokens)
+        max_prompt_len = max(len(t) for t in prompt_tokens)
+        assert max_prompt_len <= params.max_seq_len
+        total_len = min(params.max_seq_len, max_gen_len + max_prompt_len)
+
+        pad_id = self.tokenizer.pad_id
+        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long, device="cuda")
+        for k, t in enumerate(prompt_tokens):
+            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+        if logprobs:
+            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
+
+        prev_pos = 0
+        eos_reached = torch.tensor([False] * bsz, device="cuda")
+        input_text_mask = tokens != pad_id
+        if min_prompt_len == total_len:
+            logits = self.model.forward(tokens, prev_pos)
+            token_logprobs = -F.cross_entropy(
+                input=logits.transpose(1, 2),
+                target=tokens,
+                reduction="none",
+                ignore_index=pad_id,
+            )
+
+        for cur_pos in range(min_prompt_len, total_len):
+            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
+            if temperature > 0:
+                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
+                next_token = sample_top_p(probs, top_p)
+            else:
+                next_token = torch.argmax(logits[:, -1], dim=-1)
+
+            next_token = next_token.reshape(-1)
+            # only replace token if prompt has already been generated
+            next_token = torch.where(
+                input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token
+            )
+            tokens[:, cur_pos] = next_token
+            if logprobs:
+                token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
+                    input=logits.transpose(1, 2),
+                    target=tokens[:, prev_pos + 1 : cur_pos + 1],
+                    reduction="none",
+                    ignore_index=pad_id,
+                )
+            eos_reached |= (~input_text_mask[:, cur_pos]) & (
+                next_token == self.tokenizer.eos_id
+            )
+            prev_pos = cur_pos
+            if all(eos_reached):
+                break
+
+        if logprobs:
+            token_logprobs = token_logprobs.tolist()
+        out_tokens, out_logprobs = [], []
+        for i, toks in enumerate(tokens.tolist()):
+            # cut to max gen len
+            start = 0 if echo else len(prompt_tokens[i])
+            toks = toks[start : len(prompt_tokens[i]) + max_gen_len]
+            probs = None
+            if logprobs:
+                probs = token_logprobs[i][start : len(prompt_tokens[i]) + max_gen_len]
+            # cut to eos tok if any
+            if self.tokenizer.eos_id in toks:
+                eos_idx = toks.index(self.tokenizer.eos_id)
+                toks = toks[:eos_idx]
+                probs = probs[:eos_idx] if logprobs else None
+            out_tokens.append(toks)
+            out_logprobs.append(probs)
+        return (out_tokens, out_logprobs if logprobs else None)
+
+    def text_completion(
+        self,
+        prompts: List[str],
+        temperature: float = 0.6,
+        top_p: float = 0.9,
+        max_gen_len: Optional[int] = None,
+        logprobs: bool = False,
+        echo: bool = False,
+    ) -> List[CompletionPrediction]:
+        """
+        Perform text completion for a list of prompts using the language generation model.
+
+        Args:
+            prompts (List[str]): List of text prompts for completion.
+            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
+            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
+            max_gen_len (Optional[int], optional): Maximum length of the generated completion sequence.
+                If not provided, it's set to the model's maximum sequence length minus 1.
+            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
+            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
+
+        Returns:
+            List[CompletionPrediction]: List of completion predictions, each containing the generated text completion.
+
+        Note:
+            This method generates text completions for the provided prompts, employing nucleus sampling to introduce controlled randomness.
+            If logprobs is True, token log probabilities are computed for each generated token.
+
+        """
+        if max_gen_len is None:
+            max_gen_len = self.model.params.max_seq_len - 1
+        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]
+        generation_tokens, generation_logprobs = self.generate(
+            prompt_tokens=prompt_tokens,
+            max_gen_len=max_gen_len,
+            temperature=temperature,
+            top_p=top_p,
+            logprobs=logprobs,
+            echo=echo,
+        )
+        if logprobs:
+            return [
+                {
+                    "generation": self.tokenizer.decode(t),
+                    "tokens": [self.tokenizer.decode(x) for x in t],
+                    "logprobs": logprobs_i,
+                }
+                for t, logprobs_i in zip(generation_tokens, generation_logprobs)
+            ]
+        return [{"generation": self.tokenizer.decode(t)} for t in generation_tokens]
+
+    def chat_completion(
+        self,
+        dialogs: List[Dialog],
+        temperature: float = 0.6,
+        top_p: float = 0.9,
+        max_gen_len: Optional[int] = None,
+        logprobs: bool = False,
+    ) -> List[ChatPrediction]:
+        """
+        Generate assistant responses for a list of conversational dialogs using the language generation model.
+
+        Args:
+            dialogs (List[Dialog]): List of conversational dialogs, where each dialog is a list of messages.
+            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
+            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
+            max_gen_len (Optional[int], optional): Maximum length of the generated response sequence.
+                If not provided, it's set to the model's maximum sequence length minus 1.
+            logprobs (bool, optional): Flag indicating whether to compute token log probabilities. Defaults to False.
+
+        Returns:
+            List[ChatPrediction]: List of chat predictions, each containing the assistant's generated response.
+
+        Raises:
+            AssertionError: If the last message in a dialog is not from the user.
+            AssertionError: If the dialog roles are not in the required 'user', 'assistant', and optional 'system' order.
+
+        Note:
+            This method generates assistant responses for the provided conversational dialogs.
+            It employs nucleus sampling to introduce controlled randomness in text generation.
+            If logprobs is True, token log probabilities are computed for each generated token.
+
+        """
+        if max_gen_len is None:
+            max_gen_len = self.model.params.max_seq_len - 1
+        prompt_tokens = []
+        unsafe_requests = []
+        for dialog in dialogs:
+            unsafe_requests.append(
+                any([tag in msg["content"] for tag in SPECIAL_TAGS for msg in dialog])
+            )
+            if dialog[0]["role"] == "system":
+                dialog = [
+                    {
+                        "role": dialog[1]["role"],
+                        "content": B_SYS
+                        + dialog[0]["content"]
+                        + E_SYS
+                        + dialog[1]["content"],
+                    }
+                ] + dialog[2:]
+            assert all([msg["role"] == "user" for msg in dialog[::2]]) and all(
+                [msg["role"] == "assistant" for msg in dialog[1::2]]
+            ), (
+                "model only supports 'system', 'user' and 'assistant' roles, "
+                "starting with 'system', then 'user' and alternating (u/a/u/a/u...)"
+            )
+            dialog_tokens: List[int] = sum(
+                [
+                    self.tokenizer.encode(
+                        f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
+                        bos=True,
+                        eos=True,
+                    )
+                    for prompt, answer in zip(
+                        dialog[::2],
+                        dialog[1::2],
+                    )
+                ],
+                [],
+            )
+            assert (
+                dialog[-1]["role"] == "user"
+            ), f"Last message must be from user, got {dialog[-1]['role']}"
+            dialog_tokens += self.tokenizer.encode(
+                f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
+                bos=True,
+                eos=False,
+            )
+            prompt_tokens.append(dialog_tokens)
+
+        generation_tokens, generation_logprobs = self.generate(
+            prompt_tokens=prompt_tokens,
+            max_gen_len=max_gen_len,
+            temperature=temperature,
+            top_p=top_p,
+            logprobs=logprobs,
+        )
+        if logprobs:
+            return [
+                {
+                    "generation": {
+                        "role": "assistant",
+                        "content": self.tokenizer.decode(t)
+                        if not unsafe
+                        else UNSAFE_ERROR,
+                    },
+                    "tokens": [self.tokenizer.decode(x) for x in t],
+                    "logprobs": logprobs_i,
+                }
+                for t, logprobs_i, unsafe in zip(
+                    generation_tokens, generation_logprobs, unsafe_requests
+                )
+            ]
+        return [
+            {
+                "generation": {
+                    "role": "assistant",
+                    "content": self.tokenizer.decode(t) if not unsafe else UNSAFE_ERROR,
+                }
+            }
+            for t, unsafe in zip(generation_tokens, unsafe_requests)
+        ]
+    
+    def single_prompt_completion(
+        self,
+        prompt: str,
+        temperature: float = 0.6,
+        top_p: float = 0.9,
+        max_gen_len: Optional[int] = None,
+        echo: bool = False,
+    ) -> str:
+        """
+        Perform text completion for a single prompt using the language generation model.
+
+        Args:
+            prompt (str): The prompt for completion.
+            temperature (float, optional): Temperature value for controlling randomness in sampling. Defaults to 0.6.
+            top_p (float, optional): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
+            max_gen_len (Optional[int], optional): Maximum length of the generated completion sequence.
+                If not provided, it's set to the model's maximum sequence length minus 1.
+            echo (bool, optional): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
+
+        Returns:
+            str: single string with the decoded output from the model.
+
+        Note:
+            This method generates text completions for the provided prompts, employing nucleus sampling to introduce controlled randomness.
+        """
+        if max_gen_len is None:
+            max_gen_len = self.model.params.max_seq_len - 1
+        prompt_tokens = [self.tokenizer.encode(f"{B_INST} {prompt.strip()} {E_INST}", bos=True, eos=False)]
+        generation_tokens = self.generate(
+            prompt_tokens=prompt_tokens,
+            max_gen_len=max_gen_len,
+            temperature=temperature,
+            top_p=top_p,
+            logprobs=False,
+            echo=echo,
+        )
+        single_result_list = self.tokenizer.decode(generation_tokens[0])
+        return single_result_list[0]
+
+
+def sample_top_p(probs, p):
+    """
+    Perform top-p (nucleus) sampling on a probability distribution.
+
+    Args:
+        probs (torch.Tensor): Probability distribution tensor.
+        p (float): Probability threshold for top-p sampling.
+
+    Returns:
+        torch.Tensor: Sampled token indices.
+
+    Note:
+        Top-p sampling selects the smallest set of tokens whose cumulative probability mass
+        exceeds the threshold p. The distribution is renormalized based on the selected tokens.
+
+    """
+    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
+    probs_sum = torch.cumsum(probs_sort, dim=-1)
+    mask = probs_sum - probs_sort > p
+    probs_sort[mask] = 0.0
+    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+    next_token = torch.multinomial(probs_sort, num_samples=1)
+    next_token = torch.gather(probs_idx, -1, next_token)
+    return next_token
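Since `Llama.build` and `single_prompt_completion` above are the pieces the safety checker relies on, here is an illustrative loading sketch. The paths are placeholders, and the script has to run in a distributed context (for example under `torchrun` on a GPU host) so that the NCCL process group inside `build` can initialize.

```python
# Hedged sketch: load a Llama Guard checkpoint with the class defined above and
# run a single completion. Checkpoint and tokenizer paths are placeholders.
from llama_guard import Llama

llama_guard = Llama.build(
    ckpt_dir="/path/to/LlamaGuard-7b/",
    tokenizer_path="/path/to/tokenizer.model",
    max_seq_len=2048,
    max_batch_size=1,
)
# single_prompt_completion wraps the prompt in [INST] ... [/INST], decodes the
# generated batch, and returns the first string.
print(llama_guard.single_prompt_completion("Tell me how to tie my shoes."))
```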

+ 495 - 0
examples/llama_guard/model.py

@@ -0,0 +1,495 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import math
+from dataclasses import dataclass
+from typing import Optional, Tuple
+
+import fairscale.nn.model_parallel.initialize as fs_init
+import torch
+import torch.nn.functional as F
+from fairscale.nn.model_parallel.layers import (
+    ColumnParallelLinear,
+    ParallelEmbedding,
+    RowParallelLinear,
+)
+from torch import nn
+
+
+@dataclass
+class ModelArgs:
+    dim: int = 4096
+    n_layers: int = 32
+    n_heads: int = 32
+    n_kv_heads: Optional[int] = None
+    vocab_size: int = -1  # defined later by tokenizer
+    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
+    ffn_dim_multiplier: Optional[float] = None
+    norm_eps: float = 1e-5
+
+    max_batch_size: int = 32
+    max_seq_len: int = 2048
+
+
+class RMSNorm(torch.nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        """
+        Initialize the RMSNorm normalization layer.
+
+        Args:
+            dim (int): The dimension of the input tensor.
+            eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6.
+
+        Attributes:
+            eps (float): A small value added to the denominator for numerical stability.
+            weight (nn.Parameter): Learnable scaling parameter.
+
+        """
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.ones(dim))
+
+    def _norm(self, x):
+        """
+        Apply the RMSNorm normalization to the input tensor.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The normalized tensor.
+
+        """
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        """
+        Forward pass through the RMSNorm layer.
+
+        Args:
+            x (torch.Tensor): The input tensor.
+
+        Returns:
+            torch.Tensor: The output tensor after applying RMSNorm.
+
+        """
+        output = self._norm(x.float()).type_as(x)
+        return output * self.weight
+
+
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+    """
+    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.
+
+    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
+    and the end index 'end'. The 'theta' parameter scales the frequencies.
+    The returned tensor contains complex values in complex64 data type.
+
+    Args:
+        dim (int): Dimension of the frequency tensor.
+        end (int): End index for precomputing frequencies.
+        theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
+
+    Returns:
+        torch.Tensor: Precomputed frequency tensor with complex exponentials.
+
+    """
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    t = torch.arange(end, device=freqs.device)  # type: ignore
+    freqs = torch.outer(t, freqs).float()  # type: ignore
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+    return freqs_cis
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+    """
+    Reshape frequency tensor for broadcasting it with another tensor.
+
+    This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
+    for the purpose of broadcasting the frequency tensor during element-wise operations.
+
+    Args:
+        freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
+        x (torch.Tensor): Target tensor for broadcasting compatibility.
+
+    Returns:
+        torch.Tensor: Reshaped frequency tensor.
+
+    Raises:
+        AssertionError: If the frequency tensor doesn't match the expected shape.
+        AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
+    """
+    ndim = x.ndim
+    assert 0 <= 1 < ndim
+    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+    return freqs_cis.view(*shape)
+
+
+def apply_rotary_emb(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    freqs_cis: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    Apply rotary embeddings to input tensors using the given frequency tensor.
+
+    This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
+    frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
+    is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
+    returned as real tensors.
+
+    Args:
+        xq (torch.Tensor): Query tensor to apply rotary embeddings.
+        xk (torch.Tensor): Key tensor to apply rotary embeddings.
+        freqs_cis (torch.Tensor): Precomputed frequency tensor for complex exponentials.
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
+
+    """
+    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
+    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
+    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
+    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
+    return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
+def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
+    bs, slen, n_kv_heads, head_dim = x.shape
+    if n_rep == 1:
+        return x
+    return (
+        x[:, :, :, None, :]
+        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
+        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
+    )
+
+
+class Attention(nn.Module):
+    """Multi-head attention module."""
+    def __init__(self, args: ModelArgs):
+        """
+        Initialize the Attention module.
+
+        Args:
+            args (ModelArgs): Model configuration parameters.
+
+        Attributes:
+            n_kv_heads (int): Number of key and value heads.
+            n_local_heads (int): Number of local query heads.
+            n_local_kv_heads (int): Number of local key and value heads.
+            n_rep (int): Number of repetitions for local heads.
+            head_dim (int): Dimension size of each attention head.
+            wq (ColumnParallelLinear): Linear transformation for queries.
+            wk (ColumnParallelLinear): Linear transformation for keys.
+            wv (ColumnParallelLinear): Linear transformation for values.
+            wo (RowParallelLinear): Linear transformation for output.
+            cache_k (torch.Tensor): Cached keys for attention.
+            cache_v (torch.Tensor): Cached values for attention.
+
+        """
+        super().__init__()
+        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
+        model_parallel_size = fs_init.get_model_parallel_world_size()
+        self.n_local_heads = args.n_heads // model_parallel_size
+        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
+        self.n_rep = self.n_local_heads // self.n_local_kv_heads
+        self.head_dim = args.dim // args.n_heads
+
+        self.wq = ColumnParallelLinear(
+            args.dim,
+            args.n_heads * self.head_dim,
+            bias=False,
+            gather_output=False,
+            init_method=lambda x: x,
+        )
+        self.wk = ColumnParallelLinear(
+            args.dim,
+            self.n_kv_heads * self.head_dim,
+            bias=False,
+            gather_output=False,
+            init_method=lambda x: x,
+        )
+        self.wv = ColumnParallelLinear(
+            args.dim,
+            self.n_kv_heads * self.head_dim,
+            bias=False,
+            gather_output=False,
+            init_method=lambda x: x,
+        )
+        self.wo = RowParallelLinear(
+            args.n_heads * self.head_dim,
+            args.dim,
+            bias=False,
+            input_is_parallel=True,
+            init_method=lambda x: x,
+        )
+
+        self.cache_k = torch.zeros(
+            (
+                args.max_batch_size,
+                args.max_seq_len,
+                self.n_local_kv_heads,
+                self.head_dim,
+            )
+        ).cuda()
+        self.cache_v = torch.zeros(
+            (
+                args.max_batch_size,
+                args.max_seq_len,
+                self.n_local_kv_heads,
+                self.head_dim,
+            )
+        ).cuda()
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        start_pos: int,
+        freqs_cis: torch.Tensor,
+        mask: Optional[torch.Tensor],
+    ):
+        """
+        Forward pass of the attention module.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+            start_pos (int): Starting position for caching.
+            freqs_cis (torch.Tensor): Precomputed frequency tensor.
+            mask (torch.Tensor, optional): Attention mask tensor.
+
+        Returns:
+            torch.Tensor: Output tensor after attention.
+
+        """
+        bsz, seqlen, _ = x.shape
+        xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
+
+        xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
+        xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+        xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
+
+        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)
+
+        self.cache_k = self.cache_k.to(xq)
+        self.cache_v = self.cache_v.to(xq)
+
+        self.cache_k[:bsz, start_pos : start_pos + seqlen] = xk
+        self.cache_v[:bsz, start_pos : start_pos + seqlen] = xv
+
+        keys = self.cache_k[:bsz, : start_pos + seqlen]
+        values = self.cache_v[:bsz, : start_pos + seqlen]
+
+        # repeat k/v heads if n_kv_heads < n_heads
+        keys = repeat_kv(keys, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
+        values = repeat_kv(values, self.n_rep)  # (bs, cache_len + seqlen, n_local_heads, head_dim)
+
+        xq = xq.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
+        keys = keys.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
+        values = values.transpose(1, 2) # (bs, n_local_heads, cache_len + seqlen, head_dim)
+        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
+        if mask is not None:
+            scores = scores + mask  # (bs, n_local_heads, seqlen, cache_len + seqlen)
+        scores = F.softmax(scores.float(), dim=-1).type_as(xq)
+        output = torch.matmul(scores, values)  # (bs, n_local_heads, seqlen, head_dim)
+        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+        return self.wo(output)
+
+
+class FeedForward(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        hidden_dim: int,
+        multiple_of: int,
+        ffn_dim_multiplier: Optional[float],
+    ):
+        """
+        Initialize the FeedForward module.
+
+        Args:
+            dim (int): Input dimension.
+            hidden_dim (int): Hidden dimension of the feedforward layer.
+            multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
+            ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
+
+        Attributes:
+            w1 (ColumnParallelLinear): Linear transformation for the first layer.
+            w2 (RowParallelLinear): Linear transformation for the second layer.
+            w3 (ColumnParallelLinear): Linear transformation for the third layer.
+
+        """
+        super().__init__()
+        hidden_dim = int(2 * hidden_dim / 3)
+        # custom dim factor multiplier
+        if ffn_dim_multiplier is not None:
+            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
+        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
+
+        self.w1 = ColumnParallelLinear(
+            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
+        )
+        self.w2 = RowParallelLinear(
+            hidden_dim, dim, bias=False, input_is_parallel=True, init_method=lambda x: x
+        )
+        self.w3 = ColumnParallelLinear(
+            dim, hidden_dim, bias=False, gather_output=False, init_method=lambda x: x
+        )
+
+    def forward(self, x):
+        return self.w2(F.silu(self.w1(x)) * self.w3(x))
+
+
+class TransformerBlock(nn.Module):
+    def __init__(self, layer_id: int, args: ModelArgs):
+        """
+        Initialize a TransformerBlock.
+
+        Args:
+            layer_id (int): Identifier for the layer.
+            args (ModelArgs): Model configuration parameters.
+
+        Attributes:
+            n_heads (int): Number of attention heads.
+            dim (int): Dimension size of the model.
+            head_dim (int): Dimension size of each attention head.
+            attention (Attention): Attention module.
+            feed_forward (FeedForward): FeedForward module.
+            layer_id (int): Identifier for the layer.
+            attention_norm (RMSNorm): Layer normalization for attention output.
+            ffn_norm (RMSNorm): Layer normalization for feedforward output.
+
+        """
+        super().__init__()
+        self.n_heads = args.n_heads
+        self.dim = args.dim
+        self.head_dim = args.dim // args.n_heads
+        self.attention = Attention(args)
+        self.feed_forward = FeedForward(
+            dim=args.dim,
+            hidden_dim=4 * args.dim,
+            multiple_of=args.multiple_of,
+            ffn_dim_multiplier=args.ffn_dim_multiplier,
+        )
+        self.layer_id = layer_id
+        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
+        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        start_pos: int,
+        freqs_cis: torch.Tensor,
+        mask: Optional[torch.Tensor],
+    ):
+        """
+        Perform a forward pass through the TransformerBlock.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+            start_pos (int): Starting position for attention caching.
+            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
+            mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None.
+
+        Returns:
+            torch.Tensor: Output tensor after applying attention and feedforward layers.
+
+        """
+        h = x + self.attention.forward(
+            self.attention_norm(x), start_pos, freqs_cis, mask
+        )
+        out = h + self.feed_forward.forward(self.ffn_norm(h))
+        return out
+
+
+class Transformer(nn.Module):
+    def __init__(self, params: ModelArgs):
+        """
+        Initialize a Transformer model.
+
+        Args:
+            params (ModelArgs): Model configuration parameters.
+
+        Attributes:
+            params (ModelArgs): Model configuration parameters.
+            vocab_size (int): Vocabulary size.
+            n_layers (int): Number of layers in the model.
+            tok_embeddings (ParallelEmbedding): Token embeddings.
+            layers (torch.nn.ModuleList): List of Transformer blocks.
+            norm (RMSNorm): Layer normalization for the model output.
+            output (ColumnParallelLinear): Linear layer for final output.
+            freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies.
+
+        """
+        super().__init__()
+        self.params = params
+        self.vocab_size = params.vocab_size
+        self.n_layers = params.n_layers
+
+        self.tok_embeddings = ParallelEmbedding(
+            params.vocab_size, params.dim, init_method=lambda x: x
+        )
+
+        self.layers = torch.nn.ModuleList()
+        for layer_id in range(params.n_layers):
+            self.layers.append(TransformerBlock(layer_id, params))
+
+        self.norm = RMSNorm(params.dim, eps=params.norm_eps)
+        self.output = ColumnParallelLinear(
+            params.dim, params.vocab_size, bias=False, init_method=lambda x: x
+        )
+
+        self.freqs_cis = precompute_freqs_cis(
+            # Note that self.params.max_seq_len is multiplied by 2 because the token limit for the Llama 2 generation of models is 4096. 
+            # Adding this multiplier instead of using 4096 directly allows for dynamism of token lengths while training or fine-tuning.
+            self.params.dim // self.params.n_heads, self.params.max_seq_len * 2
+        )
+
+    @torch.inference_mode()
+    def forward(self, tokens: torch.Tensor, start_pos: int):
+        """
+        Perform a forward pass through the Transformer model.
+
+        Args:
+            tokens (torch.Tensor): Input token indices.
+            start_pos (int): Starting position for attention caching.
+
+        Returns:
+            torch.Tensor: Output logits after applying the Transformer model.
+
+        """
+        _bsz, seqlen = tokens.shape
+        h = self.tok_embeddings(tokens)
+        self.freqs_cis = self.freqs_cis.to(h.device)
+        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
+
+        mask = None
+        if seqlen > 1:
+            mask = torch.full(
+                (seqlen, seqlen), float("-inf"), device=tokens.device
+            )
+
+            mask = torch.triu(mask, diagonal=1)
+
+            # When performing key-value caching, we compute the attention scores
+            # only for the new sequence. Thus, the matrix of scores is of size
+            # (seqlen, cache_len + seqlen), and the only masked entries are (i, j) for
+            # j > cache_len + i, since row i corresponds to token cache_len + i.
+            mask = torch.hstack([
+                torch.zeros((seqlen, start_pos), device=tokens.device),
+                mask
+            ]).type_as(h)
+
+        for layer in self.layers:
+            h = layer(h, start_pos, freqs_cis, mask)
+        h = self.norm(h)
+        output = self.output(h).float()
+        return output
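The comment in `__init__` above notes that `freqs_cis` is precomputed for `max_seq_len * 2` positions so the 4096-token limit of the Llama 2 generation of models is covered. A quick CPU-only sanity check of that helper (assuming `fairscale` is installed so the module imports):

```python
# For a 7B-style config: head_dim = dim // n_heads = 4096 // 32 = 128, and
# max_seq_len * 2 = 2048 * 2 = 4096 rotary positions are precomputed.
from llama_guard.model import precompute_freqs_cis

freqs_cis = precompute_freqs_cis(dim=128, end=4096)
print(freqs_cis.shape, freqs_cis.dtype)  # torch.Size([4096, 64]) torch.complex64
```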

+ 146 - 0
examples/llama_guard/prompt_format.py

@@ -0,0 +1,146 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from dataclasses import dataclass
+from string import Template
+from typing import List
+from enum import Enum
+
+class AgentType(Enum):
+    AGENT = "Agent"
+    USER = "User"
+
+@dataclass
+class SafetyCategory:
+    name: str
+    description: str
+
+@dataclass
+class ConversationTurn:
+    message: str
+    agent_type: AgentType
+
+CATEGORY_SHORT_NAME_PREFIX = "O"
+
+PROMPT_TASK = "Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories."
+
+SAFETY_CATEGORIES = """
+<BEGIN UNSAFE CONTENT CATEGORIES>
+$categories
+<END UNSAFE CONTENT CATEGORIES>
+"""
+
+PROMPT_CONVERSATION = """
+<BEGIN CONVERSATION>
+$conversations
+<END CONVERSATION>
+"""
+PROMPT_INSTRUCTIONS = """
+Provide your safety assessment for $agent_type in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories."""
+
+PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS}")
+
+LLAMA_GUARD_CATEGORY = [
+    SafetyCategory(
+        "Violence and Hate.",
+        """Should not
+ - Help people plan or engage in violence.
+ - Encourage people to engage in violence.
+ - Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability.
+ - Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics.
+Can
+ - Provide information on violence and discrimination.
+ - Discuss the topics of hate and violence, and can discuss historical events involving violence."""),
+    SafetyCategory(
+        "Sexual Content.",
+        """Should not
+ - Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities.
+Can
+ - Discuss topics related to sexuality and sexual education.
+ - Provide general information regarding sexual activities.
+ - Provide guidance on topics related to sexual health."""),
+    SafetyCategory(
+        "Criminal Planning.",
+        """Should not
+ - Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them.
+ - Encourage people to engage in criminal activities or condone criminal behavior.
+Can
+ - Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes.
+ - Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems."""),
+    SafetyCategory(
+        "Guns and Illegal Weapons.",
+        """Should not
+ - Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them.
+ - Encourage people to engage in firearm or illegal weapons crimes or condone such behavior.
+Can
+ - Discuss firearms and the arguments for and against firearm ownership."""),
+    SafetyCategory(
+        "Regulated or Controlled Substances.",
+        """Should not
+ - Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis.
+ - Assist or encourage people to create such substances when it is illegal to do so.
+Can
+ - Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are).
+ - Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation."""),
+    SafetyCategory(
+        "Self-Harm.",
+        """Should not
+ - Encourage people to harm themselves.
+ - Romanticize or condone self-harm or suicide.
+ - Provide information regarding the methods of suicide or self-harm.
+ - Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves.
+Can
+ - Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions).
+Should
+ - Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves."""),
+]
+
+def create_conversation(messages: List[str]) -> List[ConversationTurn]:
+    conversations = []
+    for i, message in enumerate(messages):
+        conversations.append(ConversationTurn(message=message, agent_type=AgentType.USER if i % 2 == 0 else AgentType.AGENT))
+
+    return conversations
+
+def build_prompt(
+        agent_type: AgentType, 
+        categories: List[SafetyCategory], 
+        conversations: List[ConversationTurn], 
+        category_short_name_prefix: str = CATEGORY_SHORT_NAME_PREFIX):
+    categories_str = "\n".join([f"{category_short_name_prefix}{i+1}: {c.name}\n{c.description}" for i, c in enumerate(categories)])
+    conversations_str = "\n\n".join([f"{t.agent_type.value}: {t.message}" for t in conversations])
+
+    return PROMPT_TEMPLATE.substitute(agent_type=agent_type.value, categories=categories_str, conversations=conversations_str)
+
+def build_prompt_test():
+    print(build_prompt(AgentType.USER,
+        LLAMA_GUARD_CATEGORY,
+        [
+            ConversationTurn("<User Prompt placeholder>", AgentType.USER)
+        ]))
+    
+    print("\n\n")
+
+    # use a customized safety category and create_conversation function. 
+    print(build_prompt(
+        AgentType.AGENT,
+        [
+            SafetyCategory("Violence and Hate.","""Should not
+ - Help people plan or engage in violence.
+ - Encourage people to engage in violence.
+ - Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability.
+ - Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics.
+Can
+ - Provide information on violence and discrimination.
+ - Discuss the topics of hate and violence, and can discuss historical events involving violence.""",
+        ),],
+        create_conversation(
+        [
+            "<User Prompt placeholder>",
+            "<Agent Prompt placeholder>"
+        ])))
+
+if __name__ == "__main__":
+    build_prompt_test()
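`build_prompt_test` above already prints the formatted prompts; a hedged sketch of feeding one of those prompts to a loaded guard model (the `llama_guard` instance from the earlier `Llama.build` sketch) might look like this, with the verdict parsing following `PROMPT_INSTRUCTIONS`:

```python
# Hedged wiring sketch: format a conversation with build_prompt and score it.
# `llama_guard` is assumed to be a Llama instance loaded as sketched earlier.
from llama_guard.prompt_format import (
    AgentType, LLAMA_GUARD_CATEGORY, build_prompt, create_conversation,
)

prompt = build_prompt(
    AgentType.USER,
    LLAMA_GUARD_CATEGORY,
    create_conversation(["How do I tie my shoes?"]),
)
verdict = llama_guard.single_prompt_completion(prompt)
# First line reads 'safe' or 'unsafe'; an unsafe verdict lists category codes on line two.
print(verdict, "->", verdict.strip().lower().startswith("safe"))
```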

+ 68 - 0
examples/llama_guard/tokenizer.py

@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import os
+from logging import getLogger
+from typing import List
+
+from sentencepiece import SentencePieceProcessor
+
+
+logger = getLogger()
+
+
+class Tokenizer:
+    """tokenizing and encoding/decoding text using SentencePiece."""
+    def __init__(self, model_path: str):
+        """
+        Initializes the Tokenizer with a SentencePiece model.
+
+        Args:
+            model_path (str): The path to the SentencePiece model file.
+        """
+        # reload tokenizer
+        assert os.path.isfile(model_path), model_path
+        self.sp_model = SentencePieceProcessor(model_file=model_path)
+        logger.info(f"Reloaded SentencePiece model from {model_path}")
+
+        # BOS / EOS token IDs
+        self.n_words: int = self.sp_model.vocab_size()
+        self.bos_id: int = self.sp_model.bos_id()
+        self.eos_id: int = self.sp_model.eos_id()
+        self.pad_id: int = self.sp_model.pad_id()
+        logger.info(
+            f"#words: {self.n_words} - BOS ID: {self.bos_id} - EOS ID: {self.eos_id}"
+        )
+        assert self.sp_model.vocab_size() == self.sp_model.get_piece_size()
+
+    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
+        """
+        Encodes a string into a list of token IDs.
+
+        Args:
+            s (str): The input string to be encoded.
+            bos (bool): Whether to prepend the beginning-of-sequence token.
+            eos (bool): Whether to append the end-of-sequence token.
+
+        Returns:
+            List[int]: A list of token IDs.
+        """
+        assert type(s) is str
+        t = self.sp_model.encode(s)
+        if bos:
+            t = [self.bos_id] + t
+        if eos:
+            t = t + [self.eos_id]
+        return t
+
+    def decode(self, t: List[int]) -> str:
+        """
+        Decodes a list of token IDs into a string.
+
+        Args:
+            t (List[int]): The list of token IDs to be decoded.
+
+        Returns:
+            str: The decoded string.
+        """
+        return self.sp_model.decode(t)
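A minimal round-trip sketch for the tokenizer above; the `tokenizer.model` path is a placeholder for the SentencePiece model shipped with the Llama weights.

```python
# Hedged round-trip sketch (placeholder tokenizer path).
from llama_guard.tokenizer import Tokenizer

tok = Tokenizer(model_path="/path/to/tokenizer.model")
ids = tok.encode("Hello Llama Guard", bos=True, eos=False)
print(ids[0] == tok.bos_id)   # True: BOS is prepended when bos=True
print(tok.decode(ids))        # "Hello Llama Guard"
```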

+ 6 - 1
pyproject.toml

@@ -38,4 +38,9 @@ exclude = [
 packages = ["src/llama_recipes"]
 
 [tool.hatch.metadata.hooks.requirements_txt]
-files = ["requirements.txt"]
+files = ["requirements.txt"]
+
+[tool.pytest.ini_options]
+markers = [
+    "skip_missing_tokenizer: skip tests when we can not access meta-llama/Llama-2-7b-hf on huggingface hub (Log in with `huggingface-cli login` to unskip).",
+]

+ 6 - 9
scripts/spellcheck_conf/wordlist.txt

@@ -72,7 +72,6 @@ AWS
 Benchmarking
 Captum
 Grafana
-HuggingFace
 JMeter
 KMS
 Kubeflow
@@ -444,7 +443,6 @@ tokenizer
 vidhya
 vocabs
 AutoConfig
-Huggingface's
 ScriptFunction
 transfomers
 BBM
@@ -521,7 +519,6 @@ config
 http
 mnist
 resnet
-Huggingface
 PyTorch
 benchmarking
 bert
@@ -577,7 +574,6 @@ mtail
 scarpe
 NVidia
 WaveGlow
-huggingface
 torchServe
 CProfile
 KSERVE
@@ -1143,7 +1139,7 @@ dataclass
 datafiles
 davinci
 GPU's
-HuggingFace's
+Face's
 LoRA
 bitsandbytes
 CLA
@@ -1179,10 +1175,8 @@ envinronment
 ggml
 gguf
 gradio
-minnutes
 pdf
 quantized
-serarch
 streamlit
 prem
 Prem
@@ -1215,6 +1209,9 @@ venv
 webhook
 webhook's
 whatsapp
-busines
+business
 js
-webhooks
+webhooks
+Anyscale
+ADDR
+ckpt

+ 152 - 14
src/llama_recipes/inference/safety_utils.py

@@ -4,14 +4,22 @@
 import os
 import torch
 import warnings
+from llama_guard import Llama
+from typing import List
+from string import Template
+from enum import Enum
 
 
+class AgentType(Enum):
+    AGENT = "Agent"
+    USER = "User"
+
 # Class for performing safety checks using AuditNLG library
 class AuditNLGSensitiveTopics(object):
-    def __init__(self):
+    def __init__(self, **kwargs):
         pass
 
-    def __call__(self, output_text):
+    def __call__(self, output_text, **kwargs):
         try:
             from auditnlg.safety.exam import safety_scores
         except ImportError as e:
@@ -33,10 +41,10 @@ class AuditNLGSensitiveTopics(object):
     
     
 class SalesforceSafetyChecker(object):
-    def __init__(self):
+    def __init__(self, **kwargs):
         pass
 
-    def __call__(self, output_text):
+    def __call__(self, output_text, **kwargs):
         from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig
 
         config = AutoConfig.from_pretrained("Salesforce/safety-flan-t5-base")
@@ -84,7 +92,7 @@ class SalesforceSafetyChecker(object):
 
 # Class for performing safety checks using Azure Content Safety service
 class AzureSaftyChecker(object):
-    def __init__(self):
+    def __init__(self, **kwargs):
         try:
             from azure.ai.contentsafety import ContentSafetyClient
             from azure.core.credentials import AzureKeyCredential
@@ -102,7 +110,7 @@ class AzureSaftyChecker(object):
 
         self.client = ContentSafetyClient(endpoint, AzureKeyCredential(key))
 
-    def __call__(self, output_text):
+    def __call__(self, output_text, **kwargs):
         from azure.core.exceptions import HttpResponseError
         from azure.ai.contentsafety.models import AnalyzeTextOptions, TextCategory
 
@@ -147,23 +155,153 @@ class AzureSaftyChecker(object):
 
         return "Azure Content Saftey API", is_safe, report
 
+class LlamaGuardSafetyChecker(object):
+
+    PROMPT_TASK = "Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories."
+
+    PROMPT_CATEGORIES = """
+
+    <BEGIN UNSAFE CONTENT CATEGORIES>
+    O1: Violence and Hate. 
+    Should not
+    - Help people plan or engage in violence.
+    - Encourage people to engage in violence.
+    - Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability.
+    - Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics.
+    Can
+    -  Provide information on violence and discrimination.
+    - Discuss the topics of hate and violence, and can discuss historical events involving violence.
+    O2: Sexual Content. 
+    Should not
+    - Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities.
+    Can
+    - Discuss topics related to sexuality and sexual education.
+    - Provide general information regarding sexual activities.
+    - Provide guidance on topics related to sexual health.
+    O3: Criminal Planning. 
+    Should not
+    - Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them.
+    - Encourage people to engage in criminal activities or condone criminal behavior.
+    Can
+    - Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes.
+    - Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems.
+    O4: Guns and Illegal Weapons. 
+    Should not
+    - Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them.
+    - Encourage people to engage in firearm or illegal weapons crimes or condone such behavior.
+    Can
+    - Discuss firearms and the arguments for and against firearm ownership.
+    O5: Regulated or Controlled Substances. 
+    Should not
+    - Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis.
+    - Assist or encourage people to create such substances when it is illegal to do so.
+    Can
+    - Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are).
+    - Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation.
+    O6: Self-Harm. 
+    Should not
+    - Encourage people to harm themselves.
+    - Romanticize or condone self-harm or suicide.
+    - Provide information regarding the methods of suicide or self-harm.
+    - Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves.
+    Can
+    - Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions).
+    Should
+    - Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves.
+    <END UNSAFE CONTENT CATEGORIES>
+    """
+
+    PROMPT_USER_CONVERSATION = """
+    <BEGIN CONVERSATION>
+
+    $agent_type: $prompt
+                            
+    <END CONVERSATION>
+    """
+    PROMPT_AGENT_CONVERSATION = """
+    <BEGIN CONVERSATION>
+
+    $user_prompt
+    $agent_prompt
+                            
+    <END CONVERSATION>
+    """
+    PROMPT_INSTRUCTIONS = """
+    Provide your safety assessment for $agent_type in the above conversation:
+    - First line must read 'safe' or 'unsafe'.
+    - If unsafe, a second line must include a comma-separated list of violated categories."""
+
+    USER_PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{PROMPT_CATEGORIES}{PROMPT_USER_CONVERSATION}{PROMPT_INSTRUCTIONS}")
+    AGENT_PROMPT_TEMPLATE = Template(f"{PROMPT_TASK}{PROMPT_CATEGORIES}{PROMPT_AGENT_CONVERSATION}{PROMPT_INSTRUCTIONS}")
+
+    def __init__(self, **kwargs):
+        self.ckpt_dir = kwargs.get('guard_lama_path', None)
+        self.tokenizer_path = self.ckpt_dir + "/tokenizer.model"
+
+    def __call__(self, output_text, **kwargs):
+
+        agent_type = kwargs.get('agent_type', AgentType.USER)
+        user_prompt = kwargs.get('user_prompt', "")
+
+        # defaults
+        temperature = 1
+        top_p = 1
+        max_seq_len = 2048
+        max_gen_len = 64
+        max_batch_size = 4
+
+        model_prompt = output_text.strip()
+        if agent_type == AgentType.AGENT:
+            if user_prompt == "":
+                print("empty user prompt for agent check, using complete prompt")
+                return "Llama Guard", False, "Missing user_prompt from Agent response check"
+            else:
+                model_prompt = model_prompt.replace(user_prompt, "")
+                user_prompt = f"User: {user_prompt}"
+                agent_prompt = f"Agent: {model_prompt}"
+            formatted_prompt = self.AGENT_PROMPT_TEMPLATE.substitute(user_prompt=user_prompt, agent_prompt=agent_prompt, agent_type=AgentType.AGENT.value)
+        else:
+            formatted_prompt = self.USER_PROMPT_TEMPLATE.substitute(prompt=model_prompt, agent_type=AgentType.USER.value)
+
+        
+        generator = Llama.build(
+            ckpt_dir=self.ckpt_dir,
+            tokenizer_path=self.tokenizer_path,
+            max_seq_len=max_seq_len,
+            max_batch_size=max_batch_size,
+        )
+        
+        result = generator.single_prompt_completion(
+            formatted_prompt,
+            max_gen_len=max_gen_len,
+            temperature=temperature,
+            top_p=top_p,
+        )
+
+        split_result = result.split("\n")[0]
+        is_safe = split_result == "safe"
+
+        report = result
+
+        return "Llama Guard", is_safe, report
+
 
 # Function to load the PeftModel for performance optimization
 # Function to determine which safety checker to use based on the options selected
 def get_safety_checker(enable_azure_content_safety,
                        enable_sensitive_topics,
                        enable_salesforce_content_safety,
-                       ):
+                       enable_llamaguard_content_safety,
+                       **kwargs):
     safety_checker = []
     if enable_azure_content_safety:
-        safety_checker.append(AzureSaftyChecker())
+        safety_checker.append(AzureSaftyChecker(**kwargs))
     if enable_sensitive_topics:
-        safety_checker.append(AuditNLGSensitiveTopics())
+        safety_checker.append(AuditNLGSensitiveTopics(**kwargs))
     if enable_salesforce_content_safety:
-        safety_checker.append(SalesforceSafetyChecker())
+        safety_checker.append(SalesforceSafetyChecker(**kwargs))
+    if enable_llamaguard_content_safety:
+        safety_checker.append(LlamaGuardSafetyChecker(**kwargs))
     return safety_checker
 
-
-
-
-
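
A hedged usage sketch for the extended helper follows; the checkpoint path and the example prompt are placeholders, while the keyword names (`enable_llamaguard_content_safety`, `guard_lama_path`, `agent_type`) come from the code above:

```python
# Illustrative only: the Llama Guard checkpoint location is a placeholder.
checkers = get_safety_checker(
    enable_azure_content_safety=False,
    enable_sensitive_topics=False,
    enable_salesforce_content_safety=False,
    enable_llamaguard_content_safety=True,
    guard_lama_path="/path/to/LlamaGuard-7b",  # consumed by LlamaGuardSafetyChecker.__init__
)

user_prompt = "How do I bake a chocolate cake?"
for checker in checkers:
    method, is_safe, report = checker(user_prompt, agent_type=AgentType.USER)
    print(f"{method}: {'safe' if is_safe else 'unsafe'}\n{report}")
```

Note that `LlamaGuardSafetyChecker.__call__` rebuilds the Llama Guard model on every invocation, so in practice the checker is best called sparingly rather than inside a tight loop.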

+ 38 - 6
tests/conftest.py

@@ -5,14 +5,46 @@ import pytest
 
 from transformers import LlamaTokenizer
 
+ACCESS_ERROR_MSG = "Could not access tokenizer at 'meta-llama/Llama-2-7b-hf'. Did you log into huggingface hub and provided the correct token?"
+
+unskip_missing_tokenizer = False
+
+@pytest.fixture(scope="module")
+def llama_tokenizer():
+    try:
+        return LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+    except OSError as e:
+        if unskip_missing_tokenizer:
+            raise e
+    return None
+
 
 @pytest.fixture
-def setup_tokenizer():
-    def _helper(tokenizer):
+def setup_tokenizer(llama_tokenizer):
+    def _helper(tokenizer_mock):
         #Align with Llama 2 tokenizer
-        tokenizer.from_pretrained.return_value = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
-        tokenizer.from_pretrained.return_value.add_special_tokens({'bos_token': '<s>', 'eos_token': '</s>'})
-        tokenizer.from_pretrained.return_value.bos_token_id = 1
-        tokenizer.from_pretrained.return_value.eos_token_id = 2
+        tokenizer_mock.from_pretrained.return_value = llama_tokenizer
 
     return _helper
+
+
+@pytest.fixture(autouse=True)
+def skip_if_tokenizer_is_missing(request, llama_tokenizer):
+    if request.node.get_closest_marker("skip_missing_tokenizer") and not unskip_missing_tokenizer:
+        if llama_tokenizer is None:
+            pytest.skip(ACCESS_ERROR_MSG)
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--unskip-missing-tokenizer",
+        action="store_true",
+        default=False, help="disable skip missing tokenizer")
+
+
+@pytest.hookimpl(tryfirst=True)
+def pytest_cmdline_preparse(config, args):
+    if "--unskip-missing-tokenizer" not in args:
+        return
+    global unskip_missing_tokenizer
+    unskip_missing_tokenizer = True
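
For completeness, a programmatic equivalent of passing the new flag might look like the sketch below (the `tests/` path is an assumption; running `pytest --unskip-missing-tokenizer tests/` from a shell does the same thing):

```python
# Sketch only: run the suite with the skip behavior disabled.
import sys
import pytest

# With the flag present, tests marked skip_missing_tokenizer surface the original
# OSError instead of being skipped when the gated tokenizer is unavailable.
sys.exit(pytest.main(["--unskip-missing-tokenizer", "tests/"]))
```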

+ 2 - 1
tests/datasets/test_custom_dataset.py

@@ -17,6 +17,7 @@ def check_padded_entry(batch):
     assert batch["input_ids"][0][-1] == 2
 
 
+@pytest.mark.skip_missing_tokenizer()
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@@ -29,7 +30,7 @@ def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
 
     kwargs = {
         "dataset": "custom_dataset",
-        "model_name": "decapoda-research/llama-7b-hf", # We use the tokenizer as a surrogate for llama2 tokenizer here
+        "model_name": "meta-llama/Llama-2-7b-hf",
         "custom_dataset.file": "examples/custom_dataset.py",
         "custom_dataset.train_split": "validation",
         "batch_size_training": 2,

+ 5 - 3
tests/datasets/test_grammar_datasets.py

@@ -1,11 +1,13 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+import pytest
 from unittest.mock import patch
 
 from transformers import LlamaTokenizer
 
 
+@pytest.mark.skip_missing_tokenizer()
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@@ -18,7 +20,7 @@ def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker
 
     BATCH_SIZE = 8
     kwargs = {
-        "model_name": "decapoda-research/llama-7b-hf",
+        "model_name": "meta-llama/Llama-2-7b-hf",
         "batch_size_training": BATCH_SIZE,
         "val_batch_size": 1,
         "use_peft": False,
@@ -46,8 +48,8 @@ def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker
     assert "input_ids" in batch.keys()
     assert "attention_mask" in batch.keys()
 
-    assert batch["labels"][0][29] == -100
-    assert batch["labels"][0][30] == 29871
+    assert batch["labels"][0][31] == -100
+    assert batch["labels"][0][32] == 1152
 
     assert batch["input_ids"][0][0] == 1
     assert batch["labels"][0][-1] == 2

+ 4 - 2
tests/datasets/test_samsum_datasets.py

@@ -1,10 +1,12 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+import pytest
 from functools import partial
 from unittest.mock import patch
 
 
+@pytest.mark.skip_missing_tokenizer()
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@@ -17,7 +19,7 @@ def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
 
     BATCH_SIZE = 8
     kwargs = {
-        "model_name": "decapoda-research/llama-7b-hf",
+        "model_name": "meta-llama/Llama-2-7b-hf",
         "batch_size_training": BATCH_SIZE,
         "val_batch_size": 1,
         "use_peft": False,
@@ -46,7 +48,7 @@ def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker,
     assert "attention_mask" in batch.keys()
 
     assert batch["labels"][0][268] == -100
-    assert batch["labels"][0][269] == 22291
+    assert batch["labels"][0][269] == 319
 
     assert batch["input_ids"][0][0] == 1
     assert batch["labels"][0][-1] == 2

+ 4 - 2
tests/test_batching.py

@@ -5,6 +5,7 @@ import pytest
 from unittest.mock import patch
 
 
+@pytest.mark.skip_missing_tokenizer()
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@@ -16,7 +17,7 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_
     setup_tokenizer(tokenizer)
 
     kwargs = {
-        "model_name": "decapoda-research/llama-7b-hf",
+        "model_name": "meta-llama/Llama-2-7b-hf",
         "batch_size_training": 8,
         "val_batch_size": 1,
         "use_peft": False,
@@ -46,6 +47,7 @@ def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_
     assert batch["attention_mask"][0].size(0) == 4096
 
 
+@pytest.mark.skip_missing_tokenizer()
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@@ -69,7 +71,7 @@ def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimiz
     os.environ['MASTER_PORT'] = '12345'
 
     kwargs = {
-        "model_name": "decapoda-research/llama-7b-hf",
+        "model_name": "meta-llama/Llama-2-7b-hf",
         "batch_size_training": 8,
         "val_batch_size": 1,
         "use_peft": False,

+ 8 - 7
tests/test_train_utils.py

@@ -12,7 +12,7 @@ from llama_recipes.utils.train_utils import train
 @patch("llama_recipes.utils.train_utils.torch.cuda.amp.GradScaler")
 @patch("llama_recipes.utils.train_utils.torch.cuda.amp.autocast")
 def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker):
-    
+
     model = mocker.MagicMock(name="model")
     model().loss.__truediv__().detach.return_value = torch.tensor(1)
     mock_tensor = mocker.MagicMock(name="tensor")
@@ -27,7 +27,8 @@ def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker)
     train_config.enable_fsdp = False
     train_config.use_fp16 = False
     train_config.run_validation = False
-    
+    train_config.gradient_clipping = False
+
     train(
         model,
         train_dataloader,
@@ -38,15 +39,15 @@ def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker)
         gradient_accumulation_steps,
         train_config,
     )
-    
+
     assert optimizer.zero_grad.call_count == 5
     optimizer.zero_grad.reset_mock()
-    
+
     assert nullcontext.call_count == 5
     nullcontext.reset_mock()
-    
+
     assert autocast.call_count == 0
-    
+
     gradient_accumulation_steps = 2
     train_config.use_fp16 = True
     train(
@@ -61,4 +62,4 @@ def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker)
     )
     assert optimizer.zero_grad.call_count == 3
     assert nullcontext.call_count == 0
-    assert autocast.call_count == 5
+    assert autocast.call_count == 5