Jelajahi Sumber

Merge branch 'main' into benchmark-inference-throughput-cloud-api

Chester Hu 1 tahun lalu
induk
melakukan
986847f685
47 mengubah file dengan 96112 tambahan dan 153 penghapusan
  1. 43 22
      README.md
  2. 55 0
      benchmarks/inference/README.md
  3. 38 0
      benchmarks/inference/on-prem/README.md
  4. 205 0
      benchmarks/inference/on-prem/vllm/chat_vllm_benchmark.py
  5. 9 0
      benchmarks/inference/on-prem/vllm/input.jsonl
  6. 15 0
      benchmarks/inference/on-prem/vllm/parameters.json
  7. 215 0
      benchmarks/inference/on-prem/vllm/pretrained_vllm_benchmark.py
  8. 23 0
      benchmarks/inference/tokenizer/special_tokens_map.json
  9. 93391 0
      benchmarks/inference/tokenizer/tokenizer.json
  10. TEMPAT SAMPAH
      benchmarks/inference/tokenizer/tokenizer.model
  11. 35 0
      benchmarks/inference/tokenizer/tokenizer_config.json
  12. 11 5
      demo_apps/RAG_Chatbot_example/RAG_Chatbot_Example.ipynb
  13. 2 2
      demo_apps/README.md
  14. 8 0
      docs/inference.md
  15. 145 0
      eval/README.md
  16. 230 0
      eval/eval.py
  17. 22 0
      eval/open_llm_eval_prep.sh
  18. 6 0
      eval/open_llm_leaderboard/arc_challeneg_25shots.yaml
  19. 6 0
      eval/open_llm_leaderboard/hellaswag_10shots.yaml
  20. 24 0
      eval/open_llm_leaderboard/hellaswag_utils.py
  21. 9 0
      eval/open_llm_leaderboard/mmlu_5shots.yaml
  22. 6 0
      eval/open_llm_leaderboard/winogrande_5shots.yaml
  23. 784 0
      examples/Prompt_Engineering_with_Llama_2.ipynb
  24. 2 2
      examples/Purple_Llama_Anyscale.ipynb
  25. 1 1
      examples/README.md
  26. 9 3
      examples/chat_completion/chat_completion.py
  27. 3 13
      examples/code_llama/code_completion_example.py
  28. 4 14
      examples/code_llama/code_infilling_example.py
  29. 143 0
      examples/code_llama/code_instruct_example.py
  30. 10 14
      examples/inference.py
  31. 21 2
      examples/llama_guard/README.md
  32. 71 0
      examples/plot_metrics.py
  33. 5 1
      examples/vllm/inference.py
  34. 2 1
      requirements.txt
  35. 31 1
      scripts/spellcheck_conf/wordlist.txt
  36. 1 0
      src/llama_recipes/configs/training.py
  37. 119 0
      src/llama_recipes/data/llama_guard/README.md
  38. 10 10
      src/llama_recipes/data/llama_guard/finetuning_data_formatter.py
  39. 90 0
      src/llama_recipes/data/llama_guard/finetuning_data_formatter_example.py
  40. 16 16
      src/llama_recipes/finetuning.py
  41. 5 3
      src/llama_recipes/inference/model_utils.py
  42. 33 14
      src/llama_recipes/utils/memory_utils.py
  43. 110 25
      src/llama_recipes/utils/train_utils.py
  44. 5 1
      tests/test_finetuning.py
  45. 3 3
      tests/test_finetuning_data_formatter.py
  46. 53 0
      tests/test_train_utils.py
  47. 83 0
      utils/memory_utils.py

+ 43 - 22
README.md

@@ -1,6 +1,8 @@
-# Llama 2 Fine-tuning / Inference Recipes, Examples and Demo Apps
+# Llama 2 Fine-tuning / Inference Recipes, Examples, Benchmarks and Demo Apps
 
-**[Update Dec. 15, 2023] We added support for Llama Guard as a safety checker for our example inference script and also with standalone inference with an example script and prompt formatting. More details [here](./examples/llama_guard/README.md).**
+**[Update Feb. 5, 2024] We added support for Code Llama 70B instruct in our example [inference script](./examples/code_llama/code_instruct_example.py). For details on formatting the prompt for Code Llama 70B instruct model please refer to [this document](./docs/inference.md)**.
+
+**[Update Dec. 28, 2023] We added support for Llama Guard as a safety checker for our example inference script and also with standalone inference with an example script and prompt formatting. More details [here](./examples/llama_guard/README.md). For details on formatting data for fine tuning Llama Guard, we provide a script and sample usage [here](./src/llama_recipes/data/llama_guard/README.md).**
 
 **[Update Dec 14, 2023] We recently released a series of Llama 2 demo apps [here](./demo_apps). These apps show how to run Llama (locally, in the cloud, or on-prem),  how to use Azure Llama 2 API (Model-as-a-Service), how to ask Llama questions in general or about custom data (PDF, DB, or live), how to integrate Llama with WhatsApp and Messenger, and how to implement an end-to-end chatbot with RAG (Retrieval Augmented Generation).**
 
@@ -34,17 +36,7 @@ Llama-recipes provides a pip distribution for easy install and usage in other pr
 ```
 pip install --extra-index-url https://download.pytorch.org/whl/test/cu118 llama-recipes
 ```
-## Install from source
-To install from source e.g. for development use this command. We're using hatchling as our build backend which requires an up-to-date pip as well as setuptools package.
-```
-pip install -U pip setuptools
-pip install --extra-index-url https://download.pytorch.org/whl/test/cu118 -e .
-```
-For development and contributing to llama-recipes please install all optional dependencies:
-```
-pip install -U pip setuptools
-pip install --extra-index-url https://download.pytorch.org/whl/test/cu118 -e .[tests,auditnlg,vllm]
-```
+
 ## Install with optional dependencies
 Llama-recipes offers the installation of optional packages. There are three optional dependency groups.
 To run the unit tests we can install the required dependencies with:
@@ -61,12 +53,26 @@ pip install --extra-index-url https://download.pytorch.org/whl/test/cu118 llama-
 ```
 Optional dependencies can also be combines with [option1,option2].
 
+## Install from source
+To install from source e.g. for development use these commands. We're using hatchling as our build backend which requires an up-to-date pip as well as setuptools package.
+```
+git clone git@github.com:facebookresearch/llama-recipes.git
+cd llama-recipes
+pip install -U pip setuptools
+pip install --extra-index-url https://download.pytorch.org/whl/test/cu118 -e .
+```
+For development and contributing to llama-recipes please install all optional dependencies:
+```
+git clone git@github.com:facebookresearch/llama-recipes.git
+cd llama-recipes
+pip install -U pip setuptools
+pip install --extra-index-url https://download.pytorch.org/whl/test/cu118 -e .[tests,auditnlg,vllm]
+```
+
 ⚠️ **Note** ⚠️  Some features (especially fine-tuning with FSDP + PEFT) currently require PyTorch nightlies to be installed. Please make sure to install the nightlies if you're using these features following [this guide](https://pytorch.org/get-started/locally/).
 
 **Note** All the setting defined in [config files](src/llama_recipes/configs/) can be passed as args through CLI when running the script, there is no need to change from config files directly.
 
-**Note** In case need to run PEFT model with FSDP, please make sure to use the PyTorch Nightlies.
-
 **For more in depth information checkout the following:**
 
 * [Single GPU Fine-tuning](./docs/single_gpu.md)
@@ -74,11 +80,12 @@ Optional dependencies can also be combines with [option1,option2].
 * [LLM Fine-tuning](./docs/LLM_finetuning.md)
 * [Adding custom datasets](./docs/Dataset.md)
 * [Inference](./docs/inference.md)
+* [Evaluation Harness](./eval/README.md)
 * [FAQs](./docs/FAQ.md)
 
 # Where to find the models?
 
-You can find llama v2 models on Hugging Face hub [here](https://huggingface.co/meta-llama), where models with `hf` in the name are already converted to Hugging Face checkpoints so no further conversion is needed. The conversion step below is only for original model weights from Meta that are hosted on Hugging Face model hub as well.
+You can find Llama 2 models on Hugging Face hub [here](https://huggingface.co/meta-llama), where models with `hf` in the name are already converted to Hugging Face checkpoints so no further conversion is needed. The conversion step below is only for original model weights from Meta that are hosted on Hugging Face model hub as well.
 
 # Model conversion to Hugging Face
 The recipes and notebooks in this folder are using the Llama 2 model definition provided by Hugging Face's transformers library.
@@ -112,13 +119,15 @@ All the parameters in the examples and recipes below need to be further tuned to
 
 * Make sure to set the right path to the model in the [training config](src/llama_recipes/configs/training.py).
 
+* To save the loss and perplexity metrics for evaluation, enable this by passing `--save_metrics` to the finetuning script. The file can be plotted using the [plot_metrics.py](./examples/plot_metrics.py) script, `python examples/plot_metrics.py --file_path path/to/metrics.json`
+
 ### Single GPU:
 
 ```bash
 #if running on multi-gpu machine
 export CUDA_VISIBLE_DEVICES=0
 
-python -m llama_recipes.finetuning  --use_peft --peft_method lora --quantization --model_name /patht_of_model_folder/7B --output_dir Path/to/save/PEFT/model
+python -m llama_recipes.finetuning  --use_peft --peft_method lora --quantization --model_name /path_of_model_folder/7B --output_dir path/to/save/PEFT/model
 
 ```
 
@@ -135,7 +144,7 @@ Here we make use of Parameter Efficient Methods (PEFT) as described in the next
 
 ```bash
 
-torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --fsdp_config.pure_bf16 --output_dir Path/to/save/PEFT/model
+torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /path_of_model_folder/7B --fsdp_config.pure_bf16 --output_dir path/to/save/PEFT/model
 
 ```
 
@@ -146,7 +155,7 @@ Here we use FSDP as discussed in the next section which can be used along with P
 Setting `use_fast_kernels` will enable using of Flash Attention or Xformer memory-efficient kernels based on the hardware being used. This would speed up the fine-tuning job. This has been enabled in `optimum` library from Hugging Face as a one-liner API, please read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/).
 
 ```bash
-torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --fsdp_config.pure_bf16 --output_dir Path/to/save/PEFT/model --use_fast_kernels
+torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /path_of_model_folder/7B --fsdp_config.pure_bf16 --output_dir path/to/save/PEFT/model --use_fast_kernels
 ```
 
 ### Fine-tuning using FSDP Only
@@ -155,7 +164,7 @@ If you are interested in running full parameter fine-tuning without making use o
 
 ```bash
 
-torchrun --nnodes 1 --nproc_per_node 8  examples/finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --use_fast_kernels
+torchrun --nnodes 1 --nproc_per_node 8  examples/finetuning.py --enable_fsdp --model_name /path_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --use_fast_kernels
 
 ```
 
@@ -165,7 +174,7 @@ If you are interested in running full parameter fine-tuning on the 70B model, yo
 
 ```bash
 
-torchrun --nnodes 1 --nproc_per_node 8 examples/finetuning.py --enable_fsdp --low_cpu_fsdp --fsdp_config.pure_bf16 --model_name /patht_of_model_folder/70B --batch_size_training 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
+torchrun --nnodes 1 --nproc_per_node 8 examples/finetuning.py --enable_fsdp --low_cpu_fsdp --fsdp_config.pure_bf16 --model_name /path_of_model_folder/70B --batch_size_training 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
 
 ```
 
@@ -179,6 +188,10 @@ sbatch multi_node.slurm
 ```
 You can read more about our fine-tuning strategies [here](./docs/LLM_finetuning.md).
 
+# Evaluation Harness
+
+Here, we make use `lm-evaluation-harness` from `EleutherAI` for evaluation of fine-tuned Llama 2 models. This also can extend to evaluate other optimizations for inference of Llama 2 model such as quantization. Please use this get started [doc](./eval/README.md).
+
 # Demo Apps
 This folder contains a series of Llama2-powered apps:
 * Quickstart Llama deployments and basic interactions with Llama
@@ -195,8 +208,16 @@ This folder contains a series of Llama2-powered apps:
 3. Ask Llama questions about live data on the web
 4. Build a Llama-enabled WhatsApp chatbot
 
+# Benchmarks
+This folder contains a series of benchmark scripts for Llama 2 models inference on various backends:
+1. On-prem - Popular serving frameworks and containers (i.e. vLLM)
+2. (WIP) Cloud API - Popular API services (i.e. Azure Model-as-a-Service)
+3. (WIP) On-device - Popular on-device inference solutions on Android and iOS (i.e. mlc-llm, QNN)
+4. (WIP) Optimization - Popular optimization solutions for faster inference and quantization (i.e. AutoAWQ)
+
 # Repository Organization
 This repository is organized in the following way:
+[benchmarks](./benchmarks): Contains a series of benchmark scripts for Llama 2 models inference on various backends.
 
 [configs](src/llama_recipes/configs/): Contains the configuration files for PEFT methods, FSDP, Datasets.
 
@@ -204,7 +225,7 @@ This repository is organized in the following way:
 
 [datasets](src/llama_recipes/datasets/): Contains individual scripts for each dataset to download and process. Note: Use of any of the datasets should be in compliance with the dataset's underlying licenses (including but not limited to non-commercial uses)
 
-[demo_apps](./demo_apps) contains a series of Llama2-powered apps, from quickstart deployments to how to ask Llama questions about unstructured data, structured data, live data, and video summary.
+[demo_apps](./demo_apps): Contains a series of Llama2-powered apps, from quickstart deployments to how to ask Llama questions about unstructured data, structured data, live data, and video summary.
 
 [examples](./examples/): Contains examples script for finetuning and inference of the Llama 2 model as well as how to use them safely.
 

+ 55 - 0
benchmarks/inference/README.md

@@ -0,0 +1,55 @@
+# Inference Throughput Benchmarks
+In this folder we provide a series of benchmark scripts that apply a throughput analysis for Llama 2 models inference on various backends:
+* On-prem - Popular serving frameworks and containers (i.e. vLLM)
+* [**WIP**]Cloud API - Popular API services (i.e. Azure Model-as-a-Service)
+* [**WIP**]On-device - Popular on-device inference solutions on Android and iOS (i.e. mlc-llm, QNN)
+* [**WIP**]Optimization - Popular optimization solutions for faster inference and quantization (i.e. AutoAWQ)
+
+# Why
+There are three major reasons we want to run these benchmarks and share them with our Llama community:
+* Provide inference throughput analysis based on real world situation to help you select the best service or deployment for your scenario
+* Provide a baseline measurement for validating various optimization solutions on different backends, so we can provide guidance on which solutions work best for your scenario
+* Encourage the community to develop benchmarks on top of our works, so we can better quantify the latest proposed solutions combined with current popular frameworks, especially in this crazy fast-moving area
+
+# Parameters
+Here are the parameters (if applicable) that you can configure for running the benchmark:
+* **PROMPT** - Prompt sent in for inference (configure the length of prompt, choose from 5, 25, 50, 100, 500, 1k and 2k)
+* **MAX_NEW_TOKENS** - Max number of tokens generated
+* **CONCURRENT_LEVELS** - Max number of concurrent requests
+* **MODEL_PATH** - Model source
+* **MODEL_HEADERS** - Request headers
+* **SAFE_CHECK** - Content safety check (either Azure service or simulated latency)
+* **THRESHOLD_TPS** - Threshold TPS (threshold for tokens per second below which we deem the query to be slow)
+* **TOKENIZER_PATH** - Tokenizer source
+* **RANDOM_PROMPT_LENGTH** - Random prompt length (for pretrained models)
+* **NUM_GPU** - Number of GPUs for request dispatch among multiple containers
+* **TEMPERATURE** - Temperature for inference
+* **TOP_P** - Top_p for inference
+* **MODEL_ENDPOINTS** - Container endpoints
+* Model parallelism or model replicas - Load one model into multiple GPUs or multiple model replicas on one instance. More detail in the README files for specific containers.
+
+You can also configure other model hyperparameters as part of the request payload.  
+All these parameters are stored in ```parameter.json``` and real prompts are stored in ```input.jsonl```. Running the script will load these configurations.
+
+
+
+# Metrics
+The benchmark will report these metrics per instance:
+* Number of concurrent requests
+* P50 Latency(ms)
+* P99 Latency(ms)
+* Request per second (RPS)
+* Output tokens per second
+* Output tokens per second per GPU
+* Input tokens per second
+* Input tokens per second per GPU
+* Average tokens per second per request
+
+We intend to add these metrics in the future:
+* Time to first token (TTFT)
+  
+The benchmark result will be displayed in the terminal output and saved as a CSV file (```performance_metrics.csv```) which you can export to spreadsheets.
+
+# Getting Started
+Please follow the ```README.md``` in each subfolder for instructions on how to setup and run these benchmarks. 
+

+ 38 - 0
benchmarks/inference/on-prem/README.md

@@ -0,0 +1,38 @@
+# Llama-On-Prem-Benchmark
+This folder contains code to run inference benchmark for Llama 2 models on-prem with popular serving frameworks.
+The benchmark will focus on overall inference **throughput** for running containers on one instance (single or multiple GPUs) that you can acquire from cloud service providers such as Azure and AWS. You can also run this benchmark on local laptop or desktop.  
+We support benchmark on these serving framework:
+* [vLLM](https://github.com/vllm-project/vllm)
+
+
+# vLLM - Getting Started
+To get started, we first need to deploy containers on-prem as a API host. Follow the guidance [here](https://github.com/facebookresearch/llama-recipes/blob/main/demo_apps/llama-on-prem.md#setting-up-vllm-with-llama-2) to deploy vLLM on-prem.
+Note that in common scenario which overall throughput is important, we suggest you prioritize deploying as many model replicas as possible to reach higher overall throughput and request-per-second (RPS), comparing to deploy one model container among multiple GPUs for model parallelism. Additionally, as deploying multiple model replicas, there is a need for a higher level wrapper to handle the load balancing which here has been simulated in the benchmark scripts.  
+For example, we have an instance from Azure that has 8xA100 80G GPUs, and we want to deploy the Llama 2 70B chat model, which is around 140GB with FP16. So for deployment we can do:
+* 1x70B model parallel on 8 GPUs, each GPU RAM takes around 17.5GB for loading model weights.
+* 2x70B models each use 4 GPUs, each GPU RAM takes around 35GB for loading model weights.
+* 4x70B models each use 2 GPUs, each GPU RAM takes around 70GB for loading model weights. (Preferred configuration for max overall throughput. Note that you will have 4 endpoints hosted on different ports and the benchmark script will route requests into each model equally)
+
+Here are examples for deploying 2x70B chat models over 8 GPUs with vLLM.
+```
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -m vllm.entrypoints.openai.api_server  --model meta-llama/Llama-2-70b-chat-hf --tensor-parallel-size 4 --disable-log-requests --port 8000 
+CUDA_VISIBLE_DEVICES=4,5,6,7 python -m vllm.entrypoints.openai.api_server  --model meta-llama/Llama-2-70b-chat-hf --tensor-parallel-size 4 --disable-log-requests --port 8001 
+```
+Once you have finished deployment, you can use the command below to run benchmark scripts in a separate terminal. 
+
+```
+python chat_vllm_benchmark.py
+```
+<!-- markdown-link-check-disable -->
+If you are going to use [Azure AI content check](https://azure.microsoft.com/en-us/products/ai-services/ai-content-safety), then you should install dependencies as shown below in your terminal:
+<!-- markdown-link-check-enable -->
+```
+pip install azure-ai-contentsafety azure-core
+```
+Besides chat models, we also provide benchmark scripts for running pretrained models for text completion tasks. To better simulate the real traffic, we generate configurable random token prompt as input. In this process, we select vocabulary that is longer than 2 tokens so the generated words are closer to the English, rather than symbols.
+However, random token prompts can't be applied for chat model benchmarks, since the chat model expects a valid question. By feeding random prompts, chat models rarely provide answers that is meeting our ```MAX_NEW_TOKEN``` requirement, defeating the purpose of running throughput benchmarks. Hence for chat models, the questions are copied over to form long inputs such as for 2k and 4k inputs.   
+To run pretrained model benchmark, follow the command below.
+```
+python pretrained_vllm_benchmark.py
+```
+

+ 205 - 0
benchmarks/inference/on-prem/vllm/chat_vllm_benchmark.py

@@ -0,0 +1,205 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import csv
+import json
+import time
+import random
+import threading
+import numpy as np
+import requests
+import transformers
+import torch
+
+# Imports for Azure content safety
+from azure.ai.contentsafety import ContentSafetyClient
+from azure.core.credentials import AzureKeyCredential
+from azure.core.exceptions import HttpResponseError
+from azure.ai.contentsafety.models import AnalyzeTextOptions
+
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Dict, Tuple, List
+
+
+
+with open('input.jsonl') as input:
+    prompt_data = json.load(input)
+
+# Prompt data stored in json file. Choose from number of tokens - 5, 25, 50, 100, 500, 1k, 2k.
+# You can also configure and add your own prompt in input.jsonl
+PROMPT = prompt_data["1k"] 
+
+with open('parameters.json') as parameters:
+    params = json.load(parameters)
+
+MAX_NEW_TOKENS = params["MAX_NEW_TOKENS"]
+CONCURRENT_LEVELS = params["CONCURRENT_LEVELS"]
+# Replace with your own deployment
+MODEL_PATH = params["MODEL_PATH"]
+MODEL_HEADERS = params["MODEL_HEADERS"]
+SAFE_CHECK = params["SAFE_CHECK"]
+# Threshold for tokens per second below which we deem the query to be slow
+THRESHOLD_TPS = params["THRESHOLD_TPS"] 
+# Default Llama tokenizer, replace with your own tokenizer 
+TOKENIZER_PATH = params["TOKENIZER_PATH"] 
+TEMPERATURE = params["TEMPERATURE"]
+TOP_P = params["TOP_P"]
+# Add your model endpoints here, specify the port number. You can acquire the endpoint when creating a on-prem server like vLLM.
+# Group of model endpoints - Send balanced requests to each endpoint for batch maximization.  
+MODEL_ENDPOINTS = params["MODEL_ENDPOINTS"]
+
+# Get number of GPUs on this instance
+if torch.cuda.is_available():
+    NUM_GPU = torch.cuda.device_count()
+else:
+    print("No available GPUs")
+
+
+# This tokenizer is downloaded from Azure model catalog for each specific models. The main purpose is to decode the reponses for token calculation
+tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH)
+
+num_token_input_prompt = len(tokenizer.encode(PROMPT))
+print(f"Number of token for input prompt: {num_token_input_prompt}")
+
+# Azure content safety analysis
+def analyze_prompt(input):
+    start_time = time.time()
+
+    # Obtain credentials
+    key = "" #Add your AZURE_CONTENT_SAFETY_KEY
+    endpoint = "" #Add your AZURE_CONTENT_SAFETY_ENDPOINT
+
+    # Create a content safety client
+    client = ContentSafetyClient(endpoint, AzureKeyCredential(key))
+
+    # Create request
+    request = AnalyzeTextOptions(text=input)
+
+    # Analyze prompt
+    try:
+        response = client.analyze_text(request)
+    except HttpResponseError as e:
+        print("prompt failed due to content safety filtering.")
+        if e.error:
+            print(f"Error code: {e.error.code}")
+            print(f"Error message: {e.error.message}")
+            raise
+        print(e)
+        raise
+
+    analyze_end_time = time.time()
+    # The round trip latency for using Azure content safety check
+    analyze_latency = (analyze_end_time - start_time) * 1000
+
+
+# Simple round-robin to dispatch requests into different containers
+executor_id = 0
+lock = threading.Lock()
+
+def generate_text() -> Tuple[int, int]:
+    headers = MODEL_HEADERS
+    payload = {
+        "model" : MODEL_PATH,
+        "messages" : [
+            {
+                "role": "user",
+                "content": PROMPT
+            }
+        ],
+        "stream" : False,
+        "temperature" : TEMPERATURE,
+        "top_p" : TOP_P,
+        "max_tokens" : MAX_NEW_TOKENS
+    }
+
+    start_time = time.time()
+
+    if(SAFE_CHECK):
+        # Function to send prompts for safety check. Add delays for request round-trip that count towards overall throughput measurement.
+        # Expect NO returns from calling this function. If you want to check the safety check results, print it out within the function itself.
+        analyze_prompt(PROMPT)
+        # Or add delay simulation if you don't want to use Azure Content Safety check. The API round-trip for this check is around 0.3-0.4 seconds depends on where you located. You can use something like this: time.sleep(random.uniform(0.3, 0.4))
+
+    # Acquire lock to dispatch the request
+    lock.acquire()
+    global executor_id
+    if executor_id != len(MODEL_ENDPOINTS)-1:
+        executor_id += 1
+        endpoint_id = executor_id
+    else:
+        executor_id = 0
+        endpoint_id = executor_id
+    lock.release()
+
+    # Send request
+    response = requests.post(MODEL_ENDPOINTS[endpoint_id], headers=headers, json=payload)
+
+    if(SAFE_CHECK):
+        # Function to send prompts for safety check. Add delays for request round-trip that count towards overall throughput measurement.
+        # Expect NO returns from calling this function. If you want to check the safety check results, print it out within the function itself.
+        analyze_prompt(PROMPT)
+        # Or add delay simulation if you don't want to use Azure Content Safety check. The API round-trip for this check is around 0.3-0.4 seconds depends on where you located. You can use something like this: time.sleep(random.uniform(0.3, 0.4))
+
+    end_time = time.time()
+    # Convert to ms
+    latency = (end_time - start_time) * 1000  
+
+    if response.status_code != 200:
+        raise ValueError(f"Error: {response.content}")
+    output = json.loads(response.content)["choices"][0]["message"]["content"]
+
+    token_count = len(tokenizer.encode(output))
+    return latency, token_count
+
+
+def evaluate_performance(concurrent_requests: int) -> Tuple[float, float, float, float, float, float, float, List[float]]:
+    latencies = []
+    total_output_tokens = 0
+    output_tokens_per_second_each_request = []
+    start_time = time.time()
+
+    # Init multi-thread execution 
+    with ThreadPoolExecutor(max_workers=concurrent_requests) as executor:
+        future_to_req = {executor.submit(generate_text): i for i in range(concurrent_requests)}
+        for future in as_completed(future_to_req):
+            latency, token_count = future.result()
+            latencies.append(latency)
+            total_output_tokens += token_count
+            # Calculate tokens per second for this request
+            tokens_per_sec = token_count / (latency / 1000)
+            output_tokens_per_second_each_request.append(tokens_per_sec)
+
+    end_time = time.time()
+    total_time = end_time - start_time
+    # RPS (requests per second)
+    rps = concurrent_requests / total_time  
+    # Overall tokens per second
+    output_tokens_per_second_overall = total_output_tokens / total_time  
+    input_tokens_per_second_overall = (num_token_input_prompt * concurrent_requests) / total_time
+    output_tokens_per_second_per_gpu = output_tokens_per_second_overall / NUM_GPU
+    input_tokens_per_second_per_gpu = input_tokens_per_second_overall / NUM_GPU
+    p50_latency = np.percentile(latencies, 50)
+    p99_latency = np.percentile(latencies, 99)
+
+    # Count the number of requests below the token-per-second threshold
+    below_threshold_count = sum(1 for tps in output_tokens_per_second_each_request if tps < THRESHOLD_TPS)
+    output_tokens_per_second_per_request = sum(output_tokens_per_second_each_request)/len(output_tokens_per_second_each_request)
+
+    return p50_latency, p99_latency, rps, output_tokens_per_second_overall, output_tokens_per_second_per_gpu, input_tokens_per_second_overall, input_tokens_per_second_per_gpu, output_tokens_per_second_per_request, below_threshold_count
+
+
+
+# Print markdown
+print("| Number of Concurrent Requests | P50 Latency (ms) | P99 Latency (ms) | RPS | Output Tokens per Second | Output Tokens per Second per GPU | Input Tokens per Second | Input Tokens per Second per GPU |Average Output Tokens per Second per Request | Number of Requests Below Threshold |")
+print("|-------------------------------|------------------|------------------|------------------|-------------------|---------------------------|---------------------|------------------------|-------------------------------------- | ---------------------------------- |")
+
+# Save to file
+csv_file = "performance_metrics.csv"
+with open(csv_file, "w", newline='') as f:
+    writer = csv.writer(f)
+    writer.writerow(["Number of Concurrent Requests", "P50 Latency (ms)", "P99 Latency (ms)", "RPS", "Output Tokens per Second", "Output Tokens per Second per GPU", "Input Tokens per Second", "Input Tokens per Second per GPU", "Average Output Tokens per Second per Request"])
+
+    for level in CONCURRENT_LEVELS:
+        p50_latency, p99_latency, rps, output_tokens_per_second_overall, output_tokens_per_second_per_gpu, input_tokens_per_second_overall, input_tokens_per_second_per_gpu, output_tokens_per_second_per_request, below_threshold_count = evaluate_performance(level)
+        print(f"| {level} | {p50_latency:.2f} | {p99_latency:.2f} | {rps:.2f} | {output_tokens_per_second_overall:.2f} | {output_tokens_per_second_per_gpu:.2f} | {input_tokens_per_second_overall:.2f} | {input_tokens_per_second_per_gpu:.2f} | {output_tokens_per_second_per_request:.2f} | {below_threshold_count:.2f} |")
+        writer.writerow([level, round(p50_latency, 2), round(p99_latency, 2), round(rps, 2), round(output_tokens_per_second_overall, 2), round(output_tokens_per_second_per_gpu, 2), round(input_tokens_per_second_overall, 2), round(input_tokens_per_second_per_gpu, 2), round(output_tokens_per_second_per_request, 2)])

File diff ditekan karena terlalu besar
+ 9 - 0
benchmarks/inference/on-prem/vllm/input.jsonl


+ 15 - 0
benchmarks/inference/on-prem/vllm/parameters.json

@@ -0,0 +1,15 @@
+{
+    "MAX_NEW_TOKENS" : 256,
+    "CONCURRENT_LEVELS" : [1, 2, 4, 8, 16, 32, 64, 128, 256],
+    "MODEL_PATH" : "meta-llama/Llama-2-7b-chat-hf",
+    "MODEL_HEADERS" : {"Content-Type": "application/json"},
+    "SAFE_CHECK" : true,
+    "THRESHOLD_TPS" : 7,
+    "TOKENIZER_PATH" : "../../tokenizer",
+    "RANDOM_PROMPT_LENGTH" : 1000,
+    "TEMPERATURE" : 0.6,
+    "TOP_P" : 0.9,
+    "MODEL_ENDPOINTS" : [
+        "http://localhost:8000/v1/chat/completions"
+    ]
+}

+ 215 - 0
benchmarks/inference/on-prem/vllm/pretrained_vllm_benchmark.py

@@ -0,0 +1,215 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import csv
+import json
+import time
+import random
+import threading
+import numpy as np
+import requests
+import transformers
+import torch
+
+#imports for Azure content safety
+from azure.ai.contentsafety import ContentSafetyClient
+from azure.core.credentials import AzureKeyCredential
+from azure.core.exceptions import HttpResponseError
+from azure.ai.contentsafety.models import AnalyzeTextOptions
+
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Dict, Tuple, List
+
+
+# Predefined inputs
+with open('input.jsonl') as input:
+    prompt_data = json.load(input)
+
+with open('parameters.json') as parameters:
+    params = json.load(parameters)
+
+MAX_NEW_TOKENS = params["MAX_NEW_TOKENS"]
+CONCURRENT_LEVELS = params["CONCURRENT_LEVELS"]
+# Replace with your own deployment
+MODEL_PATH = params["MODEL_PATH"]
+MODEL_HEADERS = params["MODEL_HEADERS"]
+SAFE_CHECK = params["SAFE_CHECK"]
+# Threshold for tokens per second below which we deem the query to be slow
+THRESHOLD_TPS = params["THRESHOLD_TPS"] 
+# Replace with your own tokenizer 
+TOKENIZER_PATH = params["TOKENIZER_PATH"] 
+RANDOM_PROMPT_LENGTH = params["RANDOM_PROMPT_LENGTH"]
+TEMPERATURE = params["TEMPERATURE"]
+TOP_P = params["TOP_P"]
+# Add your model endpoints here, specify the port number. You can acquire the endpoint when creating a on-prem server like vLLM.
+# Group of model endpoints - Send balanced requests to each endpoint for batch maximization.  
+MODEL_ENDPOINTS = params["MODEL_ENDPOINTS"]
+
+#Get number of GPUs on this instance
+if torch.cuda.is_available():
+    NUM_GPU = torch.cuda.device_count()
+else:
+    print("No available GPUs")
+
+
+# This tokenizer is downloaded from Azure model catalog for each specific models. The main purpose is to decode the reponses for token calculation
+tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH)
+
+# Select vocabulary that is longer than 2 tokens (closer to real words) and close to the English (not foolproof)
+vocab = [token for token in tokenizer.get_vocab().keys() if len(token) > 2 and all(ord(c) < 128 for c in token)]
+
+def generate_random_prompt(num_tokens):
+    generated_tokens_count = 0
+    selected_tokens = ""
+    while generated_tokens_count < num_tokens:
+        selected_tokens += random.choice(vocab)
+        selected_tokens += " "
+        generated_tokens_count = len(tokenizer.encode(selected_tokens))
+
+    return selected_tokens
+
+PROMPT = generate_random_prompt(RANDOM_PROMPT_LENGTH)
+num_token_input_prompt = len(tokenizer.encode(PROMPT))
+print(f"Number of token for input prompt: {num_token_input_prompt}")
+
+
+# Azure content safety analysis
+def analyze_prompt(input):
+    start_time = time.time()
+
+    # Obtain credentials
+    key = "" #Add your AZURE_CONTENT_SAFETY_KEY
+    endpoint = "" #Add your AZURE_CONTENT_SAFETY_ENDPOINT
+
+    # Create a content safety client
+    client = ContentSafetyClient(endpoint, AzureKeyCredential(key))
+
+    # Create request
+    request = AnalyzeTextOptions(text=input)
+
+    # Analyze prompt
+    try:
+        response = client.analyze_text(request)
+    except HttpResponseError as e:
+        print("prompt failed due to content safety filtering.")
+        if e.error:
+            print(f"Error code: {e.error.code}")
+            print(f"Error message: {e.error.message}")
+            raise
+        print(e)
+        raise
+
+    analyze_end_time = time.time()
+    # The round trip latency for using Azure content safety check
+    analyze_latency = (analyze_end_time - start_time) * 1000
+
+
+# Simple round-robin to dispatch requests into different containers
+executor_id = 0
+lock = threading.Lock()
+
+def generate_text() -> Tuple[int, int]:
+    headers = MODEL_HEADERS
+    payload = {
+        "model" : MODEL_PATH,
+        "messages" : [
+            {
+                "role": "user",
+                "content": PROMPT
+            }
+        ],
+        "stream" : False,
+        "temperature" : TEMPERATURE,
+        "top_p" : TOP_P,
+        "max_tokens" : MAX_NEW_TOKENS
+    }
+
+    start_time = time.time()
+
+    if(SAFE_CHECK):
+        # Function to send prompts for safety check. Add delays for request round-trip that count towards overall throughput measurement.
+        # Expect NO returns from calling this function. If you want to check the safety check results, print it out within the function itself.
+        analyze_prompt(PROMPT)
+        # Or add delay simulation if you don't want to use Azure Content Safety check. The API round-trip for this check is around 0.3-0.4 seconds depends on where you located. You can use something like this: time.sleep(random.uniform(0.3, 0.4))
+
+    lock.acquire()
+    global executor_id
+    if executor_id != len(MODEL_ENDPOINTS)-1:
+        executor_id += 1
+        endpoint_id = executor_id
+    else:
+        executor_id = 0
+        endpoint_id = executor_id
+    lock.release()
+
+    response = requests.post(MODEL_ENDPOINTS[endpoint_id], headers=headers, json=payload)
+
+    if(SAFE_CHECK):
+        # Function to send prompts for safety check. Add delays for request round-trip that count towards overall throughput measurement.
+        # Expect NO returns from calling this function. If you want to check the safety check results, print it out within the function itself.
+        analyze_prompt(PROMPT)
+        # Or add delay simulation if you don't want to use Azure Content Safety check. The API round-trip for this check is around 0.3-0.4 seconds depends on where you located. You can use something like this: time.sleep(random.uniform(0.3, 0.4))
+
+    end_time = time.time()
+    # Convert to ms
+    latency = (end_time - start_time) * 1000 
+
+    if response.status_code != 200:
+        raise ValueError(f"Error: {response.content}")
+    output = json.loads(response.content)["choices"][0]["message"]["content"]
+
+    token_count = len(tokenizer.encode(output))
+    return latency, token_count
+
+
+def evaluate_performance(concurrent_requests: int) -> Tuple[float, float, float, float, float, float, float, List[float]]:
+    latencies = []
+    total_output_tokens = 0
+    output_tokens_per_second_each_request = []
+    start_time = time.time()
+
+    # Init multi-thread execution 
+    with ThreadPoolExecutor(max_workers=concurrent_requests) as executor:
+        future_to_req = {executor.submit(generate_text): i for i in range(concurrent_requests)}
+        for future in as_completed(future_to_req):
+            latency, token_count = future.result()
+            latencies.append(latency)
+            total_output_tokens += token_count
+            # Calculate tokens per second for this request
+            tokens_per_sec = token_count / (latency / 1000)
+            output_tokens_per_second_each_request.append(tokens_per_sec)
+
+    end_time = time.time()
+    total_time = end_time - start_time
+    # RPS (requests per second)
+    rps = concurrent_requests / total_time  
+    # Overall tokens per second
+    output_tokens_per_second_overall = total_output_tokens / total_time  
+    input_tokens_per_second_overall = (num_token_input_prompt * concurrent_requests) / total_time
+    output_tokens_per_second_per_gpu = output_tokens_per_second_overall / NUM_GPU
+    input_tokens_per_second_per_gpu = input_tokens_per_second_overall / NUM_GPU
+    p50_latency = np.percentile(latencies, 50)
+    p99_latency = np.percentile(latencies, 99)
+
+    # Count the number of requests below the token-per-second threshold
+    below_threshold_count = sum(1 for tps in output_tokens_per_second_each_request if tps < THRESHOLD_TPS)
+    output_tokens_per_second_per_request = sum(output_tokens_per_second_each_request)/len(output_tokens_per_second_each_request)
+
+    return p50_latency, p99_latency, rps, output_tokens_per_second_overall, output_tokens_per_second_per_gpu, input_tokens_per_second_overall, input_tokens_per_second_per_gpu, output_tokens_per_second_per_request, below_threshold_count
+
+
+
+# Print markdown
+print("| Number of Concurrent Requests | P50 Latency (ms) | P99 Latency (ms) | RPS | Output Tokens per Second | Output Tokens per Second per GPU | Input Tokens per Second | Input Tokens per Second per GPU |Average Output Tokens per Second per Request | Number of Requests Below Threshold |")
+print("|-------------------------------|------------------|------------------|------------------|-------------------|---------------------------|---------------------|------------------------|-------------------------------------- | ---------------------------------- |")
+
+# Save to file
+csv_file = "performance_metrics.csv"
+with open(csv_file, "w", newline='') as f:
+    writer = csv.writer(f)
+    writer.writerow(["Number of Concurrent Requests", "P50 Latency (ms)", "P99 Latency (ms)", "RPS", "Output Tokens per Second", "Output Tokens per Second per GPU", "Input Tokens per Second", "Input Tokens per Second per GPU", "Average Output Tokens per Second per Request"])
+
+    for level in CONCURRENT_LEVELS:
+        p50_latency, p99_latency, rps, output_tokens_per_second_overall, output_tokens_per_second_per_gpu, input_tokens_per_second_overall, input_tokens_per_second_per_gpu, output_tokens_per_second_per_request, below_threshold_count = evaluate_performance(level)
+        print(f"| {level} | {p50_latency:.2f} | {p99_latency:.2f} | {rps:.2f} | {output_tokens_per_second_overall:.2f} | {output_tokens_per_second_per_gpu:.2f} | {input_tokens_per_second_overall:.2f} | {input_tokens_per_second_per_gpu:.2f} | {output_tokens_per_second_per_request:.2f} | {below_threshold_count:.2f} |")
+        writer.writerow([level, round(p50_latency, 2), round(p99_latency, 2), round(rps, 2), round(output_tokens_per_second_overall, 2), round(output_tokens_per_second_per_gpu, 2), round(input_tokens_per_second_overall, 2), round(input_tokens_per_second_per_gpu, 2), round(output_tokens_per_second_per_request, 2)])

+ 23 - 0
benchmarks/inference/tokenizer/special_tokens_map.json

@@ -0,0 +1,23 @@
+{
+  "bos_token": {
+    "content": "<s>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  },
+  "eos_token": {
+    "content": "</s>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  },
+  "unk_token": {
+    "content": "<unk>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  }
+}

File diff ditekan karena terlalu besar
+ 93391 - 0
benchmarks/inference/tokenizer/tokenizer.json


TEMPAT SAMPAH
benchmarks/inference/tokenizer/tokenizer.model


+ 35 - 0
benchmarks/inference/tokenizer/tokenizer_config.json

@@ -0,0 +1,35 @@
+{
+  "add_bos_token": true,
+  "add_eos_token": false,
+  "bos_token": {
+    "__type": "AddedToken",
+    "content": "<s>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  },
+  "clean_up_tokenization_spaces": false,
+  "eos_token": {
+    "__type": "AddedToken",
+    "content": "</s>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  },
+  "legacy": true,
+  "use_default_system_prompt": false,
+  "model_max_length": 1000000000000000019884624838656,
+  "pad_token": null,
+  "sp_model_kwargs": {},
+  "tokenizer_class": "LlamaTokenizerFast",
+  "unk_token": {
+    "__type": "AddedToken",
+    "content": "<unk>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  }
+}

+ 11 - 5
demo_apps/RAG_Chatbot_example/RAG_Chatbot_Example.ipynb

@@ -241,12 +241,18 @@
    ]
   },
   {
-   "cell_type": "markdown",
-   "metadata": {},
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "vscode": {
+     "languageId": "plaintext"
+    }
+   },
+   "outputs": [],
    "source": [
-    "model = meta-llama/Llama-2-7b-chat-hf  \n",
-    "volume = $PWD/data  \n",
-    "token = #Your own HF tokens  \n",
+    "model = meta-llama/Llama-2-7b-chat-hf\n",
+    "volume = $PWD/data\n",
+    "token = #Your own HF tokens\n",
     "docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.1.0 --model-id $model"
    ]
   },

File diff ditekan karena terlalu besar
+ 2 - 2
demo_apps/README.md


+ 8 - 0
docs/inference.md

@@ -79,6 +79,14 @@ To run the code infilling example:
 python examples/code_llama/code_infilling_example.py --model_name MODEL_NAME --prompt_file examples/code_llama/code_infilling_prompt.txt --temperature 0.2 --top_p 0.9
 
 ```
+To run the 70B Instruct model example run the following (you'll need to enter the system and user prompts to instruct the model):
+
+```bash
+
+python examples/code_llama/code_instruct_example.py --model_name codellama/CodeLlama-70b-Instruct-hf --temperature 0.2 --top_p 0.9
+
+```
+You can learn more about the chat prompt template [on HF](https://huggingface.co/codellama/CodeLlama-70b-Instruct-hf#chat-prompt) and [original Code Llama repository](https://github.com/facebookresearch/codellama/blob/main/README.md#fine-tuned-instruction-models). HF tokenizer has already taken care of the chat template as shown in this example. 
 
 ### Llama Guard
 

File diff ditekan karena terlalu besar
+ 145 - 0
eval/README.md


+ 230 - 0
eval/eval.py

@@ -0,0 +1,230 @@
+import argparse
+import json
+import logging
+import os
+import re
+import sys
+from pathlib import Path
+
+import numpy as np
+import lm_eval
+from lm_eval import evaluator, tasks
+from lm_eval.utils import make_table
+
+
+def _handle_non_serializable(o):
+    if isinstance(o, np.int64) or isinstance(o, np.int32):
+        return int(o)
+    elif isinstance(o, set):
+        return list(o)
+    else:
+        return str(o)
+
+
+def setup_logging(verbosity):
+    logging.basicConfig(
+        level=verbosity.upper(), format="%(asctime)s - %(levelname)s - %(message)s"
+    )
+    return logging.getLogger(__name__)
+
+
+def handle_output(args, results, logger):
+    if not args.output_path:
+        if args.log_samples:
+            logger.error("Specify --output_path for logging samples.")
+            sys.exit(1)
+        logger.info(json.dumps(results, indent=2, default=_handle_non_serializable))
+        return
+
+    path = Path(args.output_path)
+    if path.is_file() or path.with_name("results.json").is_file():
+        logger.warning(f"File already exists at {path}. Results will be overwritten.")
+
+    output_dir = path.parent if path.suffix in (".json", ".jsonl") else path
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    results_str = json.dumps(results, indent=2, default=_handle_non_serializable)
+    if args.show_config:
+        logger.info(results_str)
+
+    file_path = os.path.join(args.output_path, "results.json")
+    with open(file_path , "w", encoding="utf-8") as f:
+        f.write(results_str)
+
+    if args.log_samples:
+        samples = results.pop("samples", {})
+        for task_name, _ in results.get("configs", {}).items():
+            output_name = re.sub(r"/|=", "__", args.model_args) + "_" + task_name
+            sample_file = output_dir.joinpath(f"{output_name}.jsonl")
+            sample_data = json.dumps(
+                samples.get(task_name, {}), indent=2, default=_handle_non_serializable
+            )
+            sample_file.write_text(sample_data, encoding="utf-8")
+
+    batch_sizes = ",".join(map(str, results.get("config", {}).get("batch_sizes", [])))
+    summary = f"{args.model} ({args.model_args}), gen_kwargs: ({args.gen_kwargs}), limit: {args.limit}, num_fewshot: {args.num_fewshot}, batch_size: {args.batch_size}{f' ({batch_sizes})' if batch_sizes else ''}"
+    logger.info(summary)
+    logger.info(make_table(results))
+    if "groups" in results:
+        logger.info(make_table(results, "groups"))
+
+
+def load_tasks(args):
+    tasks.initialize_tasks()
+    if args.open_llm_leaderboard_tasks:
+        current_dir = os.getcwd()
+        config_dir = os.path.join(current_dir, "open_llm_leaderboard")
+        lm_eval.tasks.include_path(config_dir)
+        return [
+            "arc_challenge_25_shot",
+            "hellaswag_10_shot",
+            "truthfulqa_mc2",
+            "winogrande_5_shot",
+            "gsm8k",
+            "mmlu",
+        ]
+    return args.tasks.split(",") if args.tasks else []
+
+
+def parse_eval_args():
+    parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
+    parser.add_argument(
+        "--model", "-m", default="hf", help="Name of model, e.g., `hf`."
+    )
+    parser.add_argument(
+        "--tasks",
+        "-t",
+        default=None,
+        help="Comma-separated list of tasks, or 'list' to display available tasks.",
+    )
+    parser.add_argument(
+        "--model_args",
+        "-a",
+        default="",
+        help="Comma-separated string arguments for model, e.g., `pretrained=EleutherAI/pythia-160m`.",
+    )
+    parser.add_argument(
+        "--open_llm_leaderboard_tasks",
+        "-oplm",
+        action="store_true",
+        default=False,
+        help="Choose the list of tasks with specification in HF open LLM-leaderboard.",
+    )
+    parser.add_argument(
+        "--num_fewshot",
+        "-f",
+        type=int,
+        default=None,
+        help="Number of examples in few-shot context.",
+    )
+    parser.add_argument(
+        "--batch_size",
+        "-b",
+        default=1,
+        help="Batch size, can be 'auto', 'auto:N', or an integer.",
+    )
+    parser.add_argument(
+        "--max_batch_size",
+        type=int,
+        default=None,
+        help="Maximal batch size with 'auto' batch size.",
+    )
+    parser.add_argument(
+        "--device", default=None, help="Device for evaluation, e.g., 'cuda', 'cpu'."
+    )
+    parser.add_argument(
+        "--output_path", "-o", type=str, default=None, help="Path for saving results."
+    )
+    parser.add_argument(
+        "--limit",
+        "-L",
+        type=float,
+        default=None,
+        help="Limit number of examples per task.",
+    )
+    parser.add_argument(
+        "--use_cache", "-c", default=None, help="Path to cache db file, if used."
+    )
+    parser.add_argument(
+        "--verbosity",
+        "-v",
+        default="INFO",
+        help="Logging level: CRITICAL, ERROR, WARNING, INFO, DEBUG.",
+    )
+    parser.add_argument(
+        "--gen_kwargs",
+        default=None,
+        help="Generation kwargs for tasks that support it.",
+    )
+    parser.add_argument(
+        "--check_integrity",
+        action="store_true",
+        help="Whether to run the relevant part of the test suite for the tasks.",
+    )
+    parser.add_argument(
+        "--write_out",
+        "-w",
+        action="store_true",
+        default=False,
+        help="Prints the prompt for the first few documents.",
+    )
+    parser.add_argument(
+        "--log_samples",
+        "-s",
+        action="store_true",
+        default=False,
+        help="If True, write out all model outputs and documents for per-sample measurement and post-hoc analysis.",
+    )
+    parser.add_argument(
+        "--show_config",
+        action="store_true",
+        default=False,
+        help="If True, shows the full config of all tasks at the end of the evaluation.",
+    )
+    parser.add_argument(
+        "--include_path",
+        type=str,
+        default=None,
+        help="Additional path to include if there are external tasks.",
+    )
+    parser.add_argument(
+        "--decontamination_ngrams_path", default=None
+    )  # Not currently used
+    return parser.parse_args()
+
+
+def evaluate_model(args):
+    try:
+        task_list = load_tasks(args)
+        # Customized model such as Quantized model etc.
+        # In case you are working with a custom model, you can use the following guide to add it here:
+        # https://github.com/EleutherAI/lm-evaluation-harness/blob/main/docs/interface.md#external-library-usage
+
+        # Evaluate
+        results = evaluator.simple_evaluate(
+            model=args.model,
+            model_args=args.model_args,
+            tasks=task_list,
+            num_fewshot=args.num_fewshot,
+            batch_size=args.batch_size,
+            max_batch_size=args.max_batch_size,
+            device=args.device,
+            use_cache=args.use_cache,
+            limit=args.limit,
+            decontamination_ngrams_path=args.decontamination_ngrams_path,
+            check_integrity=args.check_integrity,
+            write_out=args.write_out,
+            log_samples=args.log_samples,
+            gen_kwargs=args.gen_kwargs,
+        )
+        handle_output(args, results, logger)
+
+    except Exception as e:
+        logger.error(f"An error occurred during evaluation: {e}")
+        sys.exit(1)
+
+
+if __name__ == "__main__":
+    args = parse_eval_args()
+    logger = setup_logging(args.verbosity)
+    evaluate_model(args)

+ 22 - 0
eval/open_llm_eval_prep.sh

@@ -0,0 +1,22 @@
+#!/bin/bash
+
+# Prompt the user for the EVAL_PATH
+read -p "Enter the asbolute path to the lm-evaluation-harness: " EVAL_PATH
+conda activate 
+# Directory containing YAML files
+DIR="open_llm_leaderboard"
+
+# Check if the directory exists
+if [ ! -d "$DIR" ]; then
+    echo "Error: Directory '$DIR' not found."
+    exit 1
+fi
+
+# Iterate over YAML files in the directory and update them
+for YAML_FILE in "$DIR"/*.yaml
+do
+    if [ -f "$YAML_FILE" ]; then
+        sed -i 's|{\$EVAL_PATH}|'"$EVAL_PATH"'|g' "$YAML_FILE"
+        echo "Updated $YAML_FILE with EVAL_PATH: $EVAL_PATH"
+    fi
+done

+ 6 - 0
eval/open_llm_leaderboard/arc_challeneg_25shots.yaml

@@ -0,0 +1,6 @@
+include: {$EVAL_PATH}/lm_eval/tasks/arc/arc_challenge.yaml
+task: arc_challenge_25_shot
+task_alias: arc 25 shot
+num_fewshot: 25
+metric_list:
+  - metric: acc_norm

+ 6 - 0
eval/open_llm_leaderboard/hellaswag_10shots.yaml

@@ -0,0 +1,6 @@
+include: {$EVAL_PATH}/lm_eval/tasks/hellaswag/hellaswag.yaml
+task: hellaswag_10_shot
+task_alias: hellaswag 10 shot
+num_fewshot: 10
+metric_list:
+  - metric: acc_norm

+ 24 - 0
eval/open_llm_leaderboard/hellaswag_utils.py

@@ -0,0 +1,24 @@
+import datasets
+import re
+
+
+def preprocess(text):
+    text = text.strip()
+    # NOTE: Brackets are artifacts of the WikiHow dataset portion of HellaSwag.
+    text = text.replace(" [title]", ". ")
+    text = re.sub("\\[.*?\\]", "", text)
+    text = text.replace("  ", " ")
+    return text
+
+
+def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
+    def _process_doc(doc):
+        ctx = doc["ctx_a"] + " " + doc["ctx_b"].capitalize()
+        out_doc = {
+            "query": preprocess(doc["activity_label"] + ": " + ctx),
+            "choices": [preprocess(ending) for ending in doc["endings"]],
+            "gold": int(doc["label"]),
+        }
+        return out_doc
+
+    return dataset.map(_process_doc)

+ 9 - 0
eval/open_llm_leaderboard/mmlu_5shots.yaml

@@ -0,0 +1,9 @@
+include: {$EVAL_PATH}/lm_eval/tasks/mmlu/default/_mmlu.yaml
+task:
+  - mmlu_stem
+  - mmlu_other
+  - mmlu_social_sciences
+  - mmlu_humanities
+num_fewshot: 5
+metric_list:
+  - metric: acc

+ 6 - 0
eval/open_llm_leaderboard/winogrande_5shots.yaml

@@ -0,0 +1,6 @@
+include: {$EVAL_PATH}/lm_eval/tasks/winogrande/default.yaml
+task: winogrande_5_shot
+task_alias: winogrande 5 shot
+num_fewshot: 5
+metric_list:
+  - metric: acc

+ 784 - 0
examples/Prompt_Engineering_with_Llama_2.ipynb

@@ -0,0 +1,784 @@
+{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Prompt Engineering with Llama 2\n",
+    "\n",
+    "Prompt engineering is using natural language to produce a desired response from a large language model (LLM).\n",
+    "\n",
+    "This interactive guide covers prompt engineering & best practices with Llama 2."
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Introduction"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Why now?\n",
+    "\n",
+    "[Vaswani et al. (2017)](https://arxiv.org/abs/1706.03762) introduced the world to transformer neural networks (originally for machine translation). Transformers ushered an era of generative AI with diffusion models for image creation and large language models (`LLMs`) as **programmable deep learning networks**.\n",
+    "\n",
+    "Programming foundational LLMs is done with natural language – it doesn't require training/tuning like ML models of the past. This has opened the door to a massive amount of innovation and a paradigm shift in how technology can be deployed. The science/art of using natural language to program language models to accomplish a task is referred to as **Prompt Engineering**."
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Llama Models\n",
+    "\n",
+    "In 2023, Meta introduced the [Llama language models](https://ai.meta.com/llama/) (Llama Chat, Code Llama, Llama Guard). These are general purpose, state-of-the-art LLMs.\n",
+    "\n",
+    "Llama 2 models come in 7 billion, 13 billion, and 70 billion parameter sizes. Smaller models are cheaper to deploy and run (see: deployment and performance); larger models are more capable.\n",
+    "\n",
+    "#### Llama 2\n",
+    "1. `llama-2-7b` - base pretrained 7 billion parameter model\n",
+    "1. `llama-2-13b` - base pretrained 13 billion parameter model\n",
+    "1. `llama-2-70b` - base pretrained 70 billion parameter model\n",
+    "1. `llama-2-7b-chat` - chat fine-tuned 7 billion parameter model\n",
+    "1. `llama-2-13b-chat` - chat fine-tuned 13 billion parameter model\n",
+    "1. `llama-2-70b-chat` - chat fine-tuned 70 billion parameter model (flagship)\n"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Code Llama is a code-focused LLM built on top of Llama 2 also available in various sizes and finetunes:"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Code Llama\n",
+    "1. `codellama-7b` - code fine-tuned 7 billion parameter model\n",
+    "1. `codellama-13b` - code fine-tuned 13 billion parameter model\n",
+    "1. `codellama-34b` - code fine-tuned 34 billion parameter model\n",
+    "1. `codellama-7b-instruct` - code & instruct fine-tuned 7 billion parameter model\n",
+    "2. `codellama-13b-instruct` - code & instruct fine-tuned 13 billion parameter model\n",
+    "3. `codellama-34b-instruct` - code & instruct fine-tuned 34 billion parameter model\n",
+    "1. `codellama-7b-python` - Python fine-tuned 7 billion parameter model\n",
+    "2. `codellama-13b-python` - Python fine-tuned 13 billion parameter model\n",
+    "3. `codellama-34b-python` - Python fine-tuned 34 billion parameter model"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Getting an LLM\n",
+    "\n",
+    "Large language models are deployed and accessed in a variety of ways, including:\n",
+    "\n",
+    "1. **Self-hosting**: Using local hardware to run inference. Ex. running Llama 2 on your Macbook Pro using [llama.cpp](https://github.com/ggerganov/llama.cpp).\n",
+    "    * Best for privacy/security or if you already have a GPU.\n",
+    "1. **Cloud hosting**: Using a cloud provider to deploy an instance that hosts a specific model. Ex. running Llama 2 on cloud providers like AWS, Azure, GCP, and others.\n",
+    "    * Best for customizing models and their runtime (ex. fine-tuning a model for your use case).\n",
+    "1. **Hosted API**: Call LLMs directly via an API. There are many companies that provide Llama 2 inference APIs including AWS Bedrock, Replicate, Anyscale, Together and others.\n",
+    "    * Easiest option overall."
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Hosted APIs\n",
+    "\n",
+    "Hosted APIs are the easiest way to get started. We'll use them here. There are usually two main endpoints:\n",
+    "\n",
+    "1. **`completion`**: generate a response to a given prompt (a string).\n",
+    "1. **`chat_completion`**: generate the next message in a list of messages, enabling more explicit instruction and context for use cases like chatbots."
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Tokens\n",
+    "\n",
+    "LLMs process inputs and outputs in chunks called *tokens*. Think of these, roughly, as words – each model will have its own tokenization scheme. For example, this sentence...\n",
+    "\n",
+    "> Our destiny is written in the stars.\n",
+    "\n",
+    "...is tokenized into `[\"our\", \"dest\", \"iny\", \"is\", \"written\", \"in\", \"the\", \"stars\"]` for Llama 2.\n",
+    "\n",
+    "Tokens matter most when you consider API pricing and internal behavior (ex. hyperparameters).\n",
+    "\n",
+    "Each model has a maximum context length that your prompt cannot exceed. That's 4096 tokens for Llama 2 and 100K for Code Llama. \n"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Notebook Setup\n",
+    "\n",
+    "The following APIs will be used to call LLMs throughout the guide. As an example, we'll call Llama 2 chat using [Replicate](https://replicate.com/meta/llama-2-70b-chat) and use LangChain to easily set up a chat completion API.\n",
+    "\n",
+    "To install prerequisites run:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pip install langchain replicate"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from typing import Dict, List\n",
+    "from langchain.llms import Replicate\n",
+    "from langchain.memory import ChatMessageHistory\n",
+    "from langchain.schema.messages import get_buffer_string\n",
+    "import os\n",
+    "\n",
+    "# Get a free API key from https://replicate.com/account/api-tokens\n",
+    "os.environ[\"REPLICATE_API_TOKEN\"] = \"YOUR_KEY_HERE\"\n",
+    "\n",
+    "LLAMA2_70B_CHAT = \"meta/llama-2-70b-chat:2d19859030ff705a87c746f7e96eea03aefb71f166725aee39692f1476566d48\"\n",
+    "LLAMA2_13B_CHAT = \"meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d\"\n",
+    "\n",
+    "# We'll default to the smaller 13B model for speed; change to LLAMA2_70B_CHAT for more advanced (but slower) generations\n",
+    "DEFAULT_MODEL = LLAMA2_13B_CHAT\n",
+    "\n",
+    "def completion(\n",
+    "    prompt: str,\n",
+    "    model: str = DEFAULT_MODEL,\n",
+    "    temperature: float = 0.6,\n",
+    "    top_p: float = 0.9,\n",
+    ") -> str:\n",
+    "    llm = Replicate(\n",
+    "        model=model,\n",
+    "        model_kwargs={\"temperature\": temperature,\"top_p\": top_p, \"max_new_tokens\": 1000}\n",
+    "    )\n",
+    "    return llm(prompt)\n",
+    "\n",
+    "def chat_completion(\n",
+    "    messages: List[Dict],\n",
+    "    model = DEFAULT_MODEL,\n",
+    "    temperature: float = 0.6,\n",
+    "    top_p: float = 0.9,\n",
+    ") -> str:\n",
+    "    history = ChatMessageHistory()\n",
+    "    for message in messages:\n",
+    "        if message[\"role\"] == \"user\":\n",
+    "            history.add_user_message(message[\"content\"])\n",
+    "        elif message[\"role\"] == \"assistant\":\n",
+    "            history.add_ai_message(message[\"content\"])\n",
+    "        else:\n",
+    "            raise Exception(\"Unknown role\")\n",
+    "    return completion(\n",
+    "        get_buffer_string(\n",
+    "            history.messages,\n",
+    "            human_prefix=\"USER\",\n",
+    "            ai_prefix=\"ASSISTANT\",\n",
+    "        ),\n",
+    "        model,\n",
+    "        temperature,\n",
+    "        top_p,\n",
+    "    )\n",
+    "\n",
+    "def assistant(content: str):\n",
+    "    return { \"role\": \"assistant\", \"content\": content }\n",
+    "\n",
+    "def user(content: str):\n",
+    "    return { \"role\": \"user\", \"content\": content }\n",
+    "\n",
+    "def complete_and_print(prompt: str, model: str = DEFAULT_MODEL):\n",
+    "    print(f'==============\\n{prompt}\\n==============')\n",
+    "    response = completion(prompt, model)\n",
+    "    print(response, end='\\n\\n')\n"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Completion APIs\n",
+    "\n",
+    "Llama 2 models tend to be wordy and explain their rationale. Later we'll explore how to manage the response length."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "complete_and_print(\"The typical color of the sky is: \")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "complete_and_print(\"which model version are you?\")"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Chat Completion APIs\n",
+    "Chat completion models provide additional structure to interacting with an LLM. An array of structured message objects is sent to the LLM instead of a single piece of text. This message list provides the LLM with some \"context\" or \"history\" from which to continue.\n",
+    "\n",
+    "Typically, each message contains `role` and `content`:\n",
+    "* Messages with the `system` role are used to provide core instruction to the LLM by developers.\n",
+    "* Messages with the `user` role are typically human-provided messages.\n",
+    "* Messages with the `assistant` role are typically generated by the LLM."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "response = chat_completion(messages=[\n",
+    "    user(\"My favorite color is blue.\"),\n",
+    "    assistant(\"That's great to hear!\"),\n",
+    "    user(\"What is my favorite color?\"),\n",
+    "])\n",
+    "print(response)\n",
+    "# \"Sure, I can help you with that! Your favorite color is blue.\""
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### LLM Hyperparameters\n",
+    "\n",
+    "#### `temperature` & `top_p`\n",
+    "\n",
+    "These APIs also take parameters which influence the creativity and determinism of your output.\n",
+    "\n",
+    "At each step, LLMs generate a list of most likely tokens and their respective probabilities. The least likely tokens are \"cut\" from the list (based on `top_p`), and then a token is randomly selected from the remaining candidates (`temperature`).\n",
+    "\n",
+    "In other words: `top_p` controls the breadth of vocabulary in a generation and `temperature` controls the randomness within that vocabulary. A temperature of ~0 produces *almost* deterministic results.\n",
+    "\n",
+    "[Read more about temperature setting here](https://community.openai.com/t/cheat-sheet-mastering-temperature-and-top-p-in-chatgpt-api-a-few-tips-and-tricks-on-controlling-the-creativity-deterministic-output-of-prompt-responses/172683).\n",
+    "\n",
+    "Let's try it out:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def print_tuned_completion(temperature: float, top_p: float):\n",
+    "    response = completion(\"Write a haiku about llamas\", temperature=temperature, top_p=top_p)\n",
+    "    print(f'[temperature: {temperature} | top_p: {top_p}]\\n{response.strip()}\\n')\n",
+    "\n",
+    "print_tuned_completion(0.01, 0.01)\n",
+    "print_tuned_completion(0.01, 0.01)\n",
+    "# These two generations are highly likely to be the same\n",
+    "\n",
+    "print_tuned_completion(1.0, 1.0)\n",
+    "print_tuned_completion(1.0, 1.0)\n",
+    "# These two generations are highly likely to be different"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Prompting Techniques"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Explicit Instructions\n",
+    "\n",
+    "Detailed, explicit instructions produce better results than open-ended prompts:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "complete_and_print(prompt=\"Describe quantum physics in one short sentence of no more than 12 words\")\n",
+    "# Returns a succinct explanation of quantum physics that mentions particles and states existing simultaneously."
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "You can think about giving explicit instructions as using rules and restrictions to how Llama 2 responds to your prompt.\n",
+    "\n",
+    "- Stylization\n",
+    "    - `Explain this to me like a topic on a children's educational network show teaching elementary students.`\n",
+    "    - `I'm a software engineer using large language models for summarization. Summarize the following text in under 250 words:`\n",
+    "    - `Give your answer like an old timey private investigator hunting down a case step by step.`\n",
+    "- Formatting\n",
+    "    - `Use bullet points.`\n",
+    "    - `Return as a JSON object.`\n",
+    "    - `Use less technical terms and help me apply it in my work in communications.`\n",
+    "- Restrictions\n",
+    "    - `Only use academic papers.`\n",
+    "    - `Never give sources older than 2020.`\n",
+    "    - `If you don't know the answer, say that you don't know.`\n",
+    "\n",
+    "Here's an example of giving explicit instructions to give more specific results by limiting the responses to recently created sources."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "complete_and_print(\"Explain the latest advances in large language models to me.\")\n",
+    "# More likely to cite sources from 2017\n",
+    "\n",
+    "complete_and_print(\"Explain the latest advances in large language models to me. Always cite your sources. Never cite sources older than 2020.\")\n",
+    "# Gives more specific advances and only cites sources from 2020"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Example Prompting using Zero- and Few-Shot Learning\n",
+    "\n",
+    "A shot is an example or demonstration of what type of prompt and response you expect from a large language model. This term originates from training computer vision models on photographs, where one shot was one example or instance that the model used to classify an image ([Fei-Fei et al. (2006)](http://vision.stanford.edu/documents/Fei-FeiFergusPerona2006.pdf)).\n",
+    "\n",
+    "#### Zero-Shot Prompting\n",
+    "\n",
+    "Large language models like Llama 2 are unique because they are capable of following instructions and producing responses without having previously seen an example of a task. Prompting without examples is called \"zero-shot prompting\".\n",
+    "\n",
+    "Let's try using Llama 2 as a sentiment detector. You may notice that output format varies - we can improve this with better prompting."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "complete_and_print(\"Text: This was the best movie I've ever seen! \\n The sentiment of the text is: \")\n",
+    "# Returns positive sentiment\n",
+    "\n",
+    "complete_and_print(\"Text: The director was trying too hard. \\n The sentiment of the text is: \")\n",
+    "# Returns negative sentiment"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "\n",
+    "#### Few-Shot Prompting\n",
+    "\n",
+    "Adding specific examples of your desired output generally results in more accurate, consistent output. This technique is called \"few-shot prompting\".\n",
+    "\n",
+    "In this example, the generated response follows our desired format that offers a more nuanced sentiment classifer that gives a positive, neutral, and negative response confidence percentage.\n",
+    "\n",
+    "See also: [Zhao et al. (2021)](https://arxiv.org/abs/2102.09690), [Liu et al. (2021)](https://arxiv.org/abs/2101.06804), [Su et al. (2022)](https://arxiv.org/abs/2209.01975), [Rubin et al. (2022)](https://arxiv.org/abs/2112.08633).\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def sentiment(text):\n",
+    "    response = chat_completion(messages=[\n",
+    "        user(\"You are a sentiment classifier. For each message, give the percentage of positive/netural/negative.\"),\n",
+    "        user(\"I liked it\"),\n",
+    "        assistant(\"70% positive 30% neutral 0% negative\"),\n",
+    "        user(\"It could be better\"),\n",
+    "        assistant(\"0% positive 50% neutral 50% negative\"),\n",
+    "        user(\"It's fine\"),\n",
+    "        assistant(\"25% positive 50% neutral 25% negative\"),\n",
+    "        user(text),\n",
+    "    ])\n",
+    "    return response\n",
+    "\n",
+    "def print_sentiment(text):\n",
+    "    print(f'INPUT: {text}')\n",
+    "    print(sentiment(text))\n",
+    "\n",
+    "print_sentiment(\"I thought it was okay\")\n",
+    "# More likely to return a balanced mix of positive, neutral, and negative\n",
+    "print_sentiment(\"I loved it!\")\n",
+    "# More likely to return 100% positive\n",
+    "print_sentiment(\"Terrible service 0/10\")\n",
+    "# More likely to return 100% negative"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Role Prompting\n",
+    "\n",
+    "Llama 2 will often give more consistent responses when given a role ([Kong et al. (2023)](https://browse.arxiv.org/pdf/2308.07702.pdf)). Roles give context to the LLM on what type of answers are desired.\n",
+    "\n",
+    "Let's use Llama 2 to create a more focused, technical response for a question around the pros and cons of using PyTorch."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "complete_and_print(\"Explain the pros and cons of using PyTorch.\")\n",
+    "# More likely to explain the pros and cons of PyTorch covers general areas like documentation, the PyTorch community, and mentions a steep learning curve\n",
+    "\n",
+    "complete_and_print(\"Your role is a machine learning expert who gives highly technical advice to senior engineers who work with complicated datasets. Explain the pros and cons of using PyTorch.\")\n",
+    "# Often results in more technical benefits and drawbacks that provide more technical details on how model layers"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Chain-of-Thought\n",
+    "\n",
+    "Simply adding a phrase encouraging step-by-step thinking \"significantly improves the ability of large language models to perform complex reasoning\" ([Wei et al. (2022)](https://arxiv.org/abs/2201.11903)). This technique is called \"CoT\" or \"Chain-of-Thought\" prompting:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "complete_and_print(\"Who lived longer Elvis Presley or Mozart?\")\n",
+    "# Often gives incorrect answer of \"Mozart\"\n",
+    "\n",
+    "complete_and_print(\"Who lived longer Elvis Presley or Mozart? Let's think through this carefully, step by step.\")\n",
+    "# Gives the correct answer \"Elvis\""
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Self-Consistency\n",
+    "\n",
+    "LLMs are probablistic, so even with Chain-of-Thought, a single generation might produce incorrect results. Self-Consistency ([Wang et al. (2022)](https://arxiv.org/abs/2203.11171)) introduces enhanced accuracy by selecting the most frequent answer from multiple generations (at the cost of higher compute):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import re\n",
+    "from statistics import mode\n",
+    "\n",
+    "def gen_answer():\n",
+    "    response = completion(\n",
+    "        \"John found that the average of 15 numbers is 40.\"\n",
+    "        \"If 10 is added to each number then the mean of the numbers is?\"\n",
+    "        \"Report the answer surrounded by three backticks, for example: ```123```\",\n",
+    "        model = LLAMA2_70B_CHAT\n",
+    "    )\n",
+    "    match = re.search(r'```(\\d+)```', response)\n",
+    "    if match is None:\n",
+    "        return None\n",
+    "    return match.group(1)\n",
+    "\n",
+    "answers = [gen_answer() for i in range(5)]\n",
+    "\n",
+    "print(\n",
+    "    f\"Answers: {answers}\\n\",\n",
+    "    f\"Final answer: {mode(answers)}\",\n",
+    "    )\n",
+    "\n",
+    "# Sample runs of Llama-2-70B (all correct):\n",
+    "# [50, 50, 750, 50, 50]  -> 50\n",
+    "# [130, 10, 750, 50, 50] -> 50\n",
+    "# [50, None, 10, 50, 50] -> 50"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Retrieval-Augmented Generation\n",
+    "\n",
+    "You'll probably want to use factual knowledge in your application. You can extract common facts from today's large models out-of-the-box (i.e. using just the model weights):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "complete_and_print(\"What is the capital of the California?\", model = LLAMA2_70B_CHAT)\n",
+    "# Gives the correct answer \"Sacramento\""
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "However, more specific facts, or private information, cannot be reliably retrieved. The model will either declare it does not know or hallucinate an incorrect answer:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "complete_and_print(\"What was the temperature in Menlo Park on December 12th, 2023?\")\n",
+    "# \"I'm just an AI, I don't have access to real-time weather data or historical weather records.\"\n",
+    "\n",
+    "complete_and_print(\"What time is my dinner reservation on Saturday and what should I wear?\")\n",
+    "# \"I'm not able to access your personal information [..] I can provide some general guidance\""
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Retrieval-Augmented Generation, or RAG, describes the practice of including information in the prompt you've retrived from an external database ([Lewis et al. (2020)](https://arxiv.org/abs/2005.11401v4)). It's an effective way to incorporate facts into your LLM application and is more affordable than fine-tuning which may be costly and negatively impact the foundational model's capabilities.\n",
+    "\n",
+    "This could be as simple as a lookup table or as sophisticated as a [vector database]([FAISS](https://github.com/facebookresearch/faiss)) containing all of your company's knowledge:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "MENLO_PARK_TEMPS = {\n",
+    "    \"2023-12-11\": \"52 degrees Fahrenheit\",\n",
+    "    \"2023-12-12\": \"51 degrees Fahrenheit\",\n",
+    "    \"2023-12-13\": \"51 degrees Fahrenheit\",\n",
+    "}\n",
+    "\n",
+    "\n",
+    "def prompt_with_rag(retrived_info, question):\n",
+    "    complete_and_print(\n",
+    "        f\"Given the following information: '{retrived_info}', respond to: '{question}'\"\n",
+    "    )\n",
+    "\n",
+    "\n",
+    "def ask_for_temperature(day):\n",
+    "    temp_on_day = MENLO_PARK_TEMPS.get(day) or \"unknown temperature\"\n",
+    "    prompt_with_rag(\n",
+    "        f\"The temperature in Menlo Park was {temp_on_day} on {day}'\",  # Retrieved fact\n",
+    "        f\"What is the temperature in Menlo Park on {day}?\",  # User question\n",
+    "    )\n",
+    "\n",
+    "\n",
+    "ask_for_temperature(\"2023-12-12\")\n",
+    "# \"Sure! The temperature in Menlo Park on 2023-12-12 was 51 degrees Fahrenheit.\"\n",
+    "\n",
+    "ask_for_temperature(\"2023-07-18\")\n",
+    "# \"I'm not able to provide the temperature in Menlo Park on 2023-07-18 as the information provided states that the temperature was unknown.\""
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Program-Aided Language Models\n",
+    "\n",
+    "LLMs, by nature, aren't great at performing calculations. Let's try:\n",
+    "\n",
+    "$$\n",
+    "((-5 + 93 * 4 - 0) * (4^4 + -7 + 0 * 5))\n",
+    "$$\n",
+    "\n",
+    "(The correct answer is 91383.)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "complete_and_print(\"\"\"\n",
+    "Calculate the answer to the following math problem:\n",
+    "\n",
+    "((-5 + 93 * 4 - 0) * (4^4 + -7 + 0 * 5))\n",
+    "\"\"\")\n",
+    "# Gives incorrect answers like 92448, 92648, 95463"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "[Gao et al. (2022)](https://arxiv.org/abs/2211.10435) introduced the concept of \"Program-aided Language Models\" (PAL). While LLMs are bad at arithmetic, they're great for code generation. PAL leverages this fact by instructing the LLM to write code to solve calculation tasks."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "complete_and_print(\n",
+    "    \"\"\"\n",
+    "    # Python code to calculate: ((-5 + 93 * 4 - 0) * (4^4 + -7 + 0 * 5))\n",
+    "    \"\"\",\n",
+    "    model=\"meta/codellama-34b:67942fd0f55b66da802218a19a8f0e1d73095473674061a6ea19f2dc8c053152\"\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# The following code was generated by Code Llama 34B:\n",
+    "\n",
+    "num1 = (-5 + 93 * 4 - 0)\n",
+    "num2 = (4**4 + -7 + 0 * 5)\n",
+    "answer = num1 * num2\n",
+    "print(answer)"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Limiting Extraneous Tokens\n",
+    "\n",
+    "A common struggle is getting output without extraneous tokens (ex. \"Sure! Here's more information on...\").\n",
+    "\n",
+    "Check out this improvement that combines a role, rules and restrictions, explicit instructions, and an example:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "complete_and_print(\n",
+    "    \"Give me the zip code for Menlo Park in JSON format with the field 'zip_code'\",\n",
+    "    model = LLAMA2_70B_CHAT,\n",
+    ")\n",
+    "# Likely returns the JSON and also \"Sure! Here's the JSON...\"\n",
+    "\n",
+    "complete_and_print(\n",
+    "    \"\"\"\n",
+    "    You are a robot that only outputs JSON.\n",
+    "    You reply in JSON format with the field 'zip_code'.\n",
+    "    Example question: What is the zip code of the Empire State Building? Example answer: {'zip_code': 10118}\n",
+    "    Now here is my question: What is the zip code of Menlo Park?\n",
+    "    \"\"\",\n",
+    "    model = LLAMA2_70B_CHAT,\n",
+    ")\n",
+    "# \"{'zip_code': 94025}\""
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Additional References\n",
+    "- [PromptingGuide.ai](https://www.promptingguide.ai/)\n",
+    "- [LearnPrompting.org](https://learnprompting.org/)\n",
+    "- [Lil'Log Prompt Engineering Guide](https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/)\n"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Author & Contact\n",
+    "\n",
+    "Edited by [Dalton Flanagan](https://www.linkedin.com/in/daltonflanagan/) (dalton@meta.com) with contributions from Mohsen Agsen, Bryce Bortree, Ricardo Juan Palma Duran, Kaolin Fire, Thomas Scialom."
+   ]
+  }
+ ],
+ "metadata": {
+  "captumWidgetMessage": [],
+  "dataExplorerConfig": [],
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3"
+  },
+  "last_base_url": "https://bento.edge.x2p.facebook.net/",
+  "last_kernel_id": "161e2a7b-2d2b-4995-87f3-d1539860ecac",
+  "last_msg_id": "4eab1242-d815b886ebe4f5b1966da982_543",
+  "last_server_session_id": "4a7b41c5-ed66-4dcb-a376-22673aebb469",
+  "operator_data": [],
+  "outputWidgetContext": []
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}

+ 2 - 2
examples/Purple_Llama_Anyscale.ipynb

@@ -80,7 +80,7 @@
       "source": [
         "#### **3 - Using Purple Llama**\n",
         "\n",
-        "In this notebook, We will use the Llama 2-13b model managed by the [Anyscale Endpoints](https://app.endpoints.anyscale.com/) for inferencing. You'll need to first register an account with Anyscale [here](https://app.endpoints.anyscale.com) then obtain an Anyscale API key [here](https://api.together.xyz/settings/api-keys). Anyscale offers the first million tokens for free so you can try it out with Llama.\n"
+        "In this notebook, We will use the Llama Guard model managed by the [Anyscale Endpoints](https://app.endpoints.anyscale.com/) for inferencing. You'll need to first register an account with Anyscale [here](https://app.endpoints.anyscale.com) then obtain an Anyscale API key [here](https://api.together.xyz/settings/api-keys). Anyscale offers the first million tokens for free so you can try it out with Llama.\n"
       ]
     },
     {
@@ -381,4 +381,4 @@
   },
   "nbformat": 4,
   "nbformat_minor": 0
-}
+}

+ 1 - 1
examples/README.md

@@ -24,7 +24,7 @@ So far, we have provide the following inference examples:
 
 4. A [chat completion](./chat_completion/chat_completion.py) example highlighting the handling of chat dialogs.
 
-5. [Code Llama](./code_llama/) folder which provides examples for [code completion](./code_llama/code_completion_example.py) and [code infilling](./code_llama/code_infilling_example.py).
+5. [Code Llama](./code_llama/) folder which provides examples for [code completion](./code_llama/code_completion_example.py), [code infilling](./code_llama/code_infilling_example.py) and [Llama2 70B code instruct](./code_llama/code_instruct_example.py).
 
 6. The [Purple Llama Using Anyscale](./Purple_Llama_Anyscale.ipynb) is a notebook that shows how to use Anyscale hosted Llama Guard model to classify user inputs as safe or unsafe.
 

+ 9 - 3
examples/chat_completion/chat_completion.py

@@ -13,7 +13,7 @@ from transformers import LlamaTokenizer
 from llama_recipes.inference.chat_utils import read_dialogs_from_file, format_tokens
 from llama_recipes.inference.model_utils import load_model, load_peft_model
 from llama_recipes.inference.safety_utils import get_safety_checker
-
+from accelerate.utils import is_xpu_available
 
 def main(
     model_name,
@@ -55,7 +55,10 @@ def main(
 
 
     # Set the seeds for reproducibility
-    torch.cuda.manual_seed(seed)
+    if is_xpu_available():
+        torch.xpu.manual_seed(seed)
+    else:
+        torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
     model = load_model(model_name, quantization)
     if peft_model:
@@ -105,7 +108,10 @@ def main(
                 sys.exit(1)  # Exit the program with an error status
             tokens= torch.tensor(chat).long()
             tokens= tokens.unsqueeze(0)
-            tokens= tokens.to("cuda:0")
+            if is_xpu_available():
+                tokens= tokens.to("xpu:0")
+            else:
+                tokens= tokens.to("cuda:0")
             outputs = model.generate(
                 input_ids=tokens,
                 max_new_tokens=max_new_tokens,

+ 3 - 13
examples/code_llama/code_completion_example.py

@@ -33,6 +33,7 @@ def main(
     enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
     enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
     enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5
+    enable_llamaguard_content_safety: bool=False, # Enable safety check with Llama-Guard
     use_fast_kernels: bool = True, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
     **kwargs
 ):
@@ -50,28 +51,17 @@ def main(
     torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
     
-    model = load_model(model_name, quantization)
+    model = load_model(model_name, quantization, use_fast_kernels)
     if peft_model:
         model = load_peft_model(model, peft_model)
 
     model.eval()
     
-    if use_fast_kernels:
-        """
-        Setting 'use_fast_kernels' will enable
-        using of Flash Attention or Xformer memory-efficient kernels 
-        based on the hardware being used. This would speed up inference when used for batched inputs.
-        """
-        try:
-            from optimum.bettertransformer import BetterTransformer
-            model = BetterTransformer.transform(model)    
-        except ImportError:
-            print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
-
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     safety_checker = get_safety_checker(enable_azure_content_safety,
                                         enable_sensitive_topics,
                                         enable_salesforce_content_safety,
+                                        enable_llamaguard_content_safety,
                                         )
 
     # Safety check of the user prompt

+ 4 - 14
examples/code_llama/code_infilling_example.py

@@ -32,6 +32,7 @@ def main(
     enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
     enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
     enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5
+    enable_llamaguard_content_safety: bool=False, # Enable safety check with Llama-Guard
     use_fast_kernels: bool = True, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
     **kwargs
 ):
@@ -48,30 +49,19 @@ def main(
     torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
     
-    model = load_model(model_name, quantization)
+    model = load_model(model_name, quantization, use_fast_kernels)
     model.config.tp_size=1
     if peft_model:
         model = load_peft_model(model, peft_model)
 
     model.eval()
-    
-    if use_fast_kernels:
-        """
-        Setting 'use_fast_kernels' will enable
-        using of Flash Attention or Xformer memory-efficient kernels 
-        based on the hardware being used. This would speed up inference when used for batched inputs.
-        """
-        try:
-            from optimum.bettertransformer import BetterTransformer
-            model = BetterTransformer.transform(model)    
-        except ImportError:
-            print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
-
+   
     tokenizer = AutoTokenizer.from_pretrained(model_name)
     
     safety_checker = get_safety_checker(enable_azure_content_safety,
                                         enable_sensitive_topics,
                                         enable_salesforce_content_safety,
+                                        enable_llamaguard_content_safety,
                                         )
 
     # Safety check of the user prompt

+ 143 - 0
examples/code_llama/code_instruct_example.py

@@ -0,0 +1,143 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import fire
+import os
+import sys
+import time
+
+import torch
+from transformers import AutoTokenizer
+
+from llama_recipes.inference.safety_utils import get_safety_checker
+from llama_recipes.inference.model_utils import load_model, load_peft_model
+
+
+def handle_safety_check(are_safe_user_prompt, user_prompt, safety_results_user_prompt, are_safe_system_prompt, system_prompt, safety_results_system_prompt):
+    """
+    Handles the output based on the safety check of both user and system prompts.
+
+    Parameters:
+    - are_safe_user_prompt (bool): Indicates whether the user prompt is safe.
+    - user_prompt (str): The user prompt that was checked for safety.
+    - safety_results_user_prompt (list of tuples): A list of tuples for the user prompt containing the method, safety status, and safety report.
+    - are_safe_system_prompt (bool): Indicates whether the system prompt is safe.
+    - system_prompt (str): The system prompt that was checked for safety.
+    - safety_results_system_prompt (list of tuples): A list of tuples for the system prompt containing the method, safety status, and safety report.
+    """
+    def print_safety_results(are_safe_prompt, prompt, safety_results, prompt_type="User"):
+        """
+        Prints the safety results for a prompt.
+
+        Parameters:
+        - are_safe_prompt (bool): Indicates whether the prompt is safe.
+        - prompt (str): The prompt that was checked for safety.
+        - safety_results (list of tuples): A list of tuples containing the method, safety status, and safety report.
+        - prompt_type (str): The type of prompt (User/System).
+        """
+        if are_safe_prompt:
+            print(f"{prompt_type} prompt deemed safe.")
+            print(f"{prompt_type} prompt:\n{prompt}")
+        else:
+            print(f"{prompt_type} prompt deemed unsafe.")
+            for method, is_safe, report in safety_results:
+                if not is_safe:
+                    print(method)
+                    print(report)
+            print(f"Skipping the inference as the {prompt_type.lower()} prompt is not safe.")
+            sys.exit(1)
+
+    # Check user prompt
+    print_safety_results(are_safe_user_prompt, user_prompt, safety_results_user_prompt, "User")
+    
+    # Check system prompt
+    print_safety_results(are_safe_system_prompt, system_prompt, safety_results_system_prompt, "System")
+
+def main(
+    model_name,
+    peft_model: str=None,
+    quantization: bool=False,
+    max_new_tokens =100, #The maximum numbers of tokens to generate
+    seed: int=42, #seed value for reproducibility
+    do_sample: bool=True, #Whether or not to use sampling ; use greedy decoding otherwise.
+    min_length: int=None, #The minimum length of the sequence to be generated, input prompt + min_new_tokens
+    use_cache: bool=False,  #[optional] Whether or not the model should use the past last key/values attentions Whether or not the model should use the past last key/values attentions (if applicable to the model) to speed up decoding.
+    top_p: float=0.9, # [optional] If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.
+    temperature: float=0.6, # [optional] The value used to modulate the next token probabilities.
+    top_k: int=50, # [optional] The number of highest probability vocabulary tokens to keep for top-k-filtering.
+    repetition_penalty: float=1.0, #The parameter for repetition penalty. 1.0 means no penalty.
+    length_penalty: int=1, #[optional] Exponential penalty to the length that is used with beam-based generation. 
+    enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
+    enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
+    enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5
+    enable_llamaguard_content_safety: bool=False, # Enable safety check with Llama-Guard
+    use_fast_kernels: bool = True, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
+    **kwargs
+):
+    system_prompt = input("Please insert your system prompt: ")
+    user_prompt = input("Please insert your prompt: ")
+    chat = [
+   {"role": "system", "content": system_prompt},
+   {"role": "user", "content": user_prompt},
+    ]       
+    # Set the seeds for reproducibility
+    torch.cuda.manual_seed(seed)
+    torch.manual_seed(seed)
+    
+    model = load_model(model_name, quantization, use_fast_kernels)
+    if peft_model:
+        model = load_peft_model(model, peft_model)
+
+    model.eval()
+        
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    safety_checker = get_safety_checker(enable_azure_content_safety,
+                                        enable_sensitive_topics,
+                                        enable_salesforce_content_safety,
+                                        enable_llamaguard_content_safety,
+                                        )
+
+    # Safety check of the user prompt
+    safety_results_user_prompt = [check(user_prompt) for check in safety_checker]
+    safety_results_system_prompt = [check(system_prompt) for check in safety_checker]
+    are_safe_user_prompt = all([r[1] for r in safety_results_user_prompt])
+    are_safe_system_prompt = all([r[1] for r in safety_results_system_prompt])
+    handle_safety_check(are_safe_user_prompt, user_prompt, safety_results_user_prompt, are_safe_system_prompt, system_prompt, safety_results_system_prompt)
+        
+    inputs = tokenizer.apply_chat_template(chat, return_tensors="pt").to("cuda")
+
+    start = time.perf_counter()
+    with torch.no_grad():
+        outputs = model.generate(
+            input_ids=inputs,
+            max_new_tokens=max_new_tokens,
+            do_sample=do_sample,
+            top_p=top_p,
+            temperature=temperature,
+            min_length=min_length,
+            use_cache=use_cache,
+            top_k=top_k,
+            repetition_penalty=repetition_penalty,
+            length_penalty=length_penalty,
+            **kwargs 
+        )
+    e2e_inference_time = (time.perf_counter()-start)*1000
+    print(f"the inference time is {e2e_inference_time} ms")
+    output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    
+    # Safety check of the model output
+    safety_results = [check(output_text) for check in safety_checker]
+    are_safe = all([r[1] for r in safety_results])
+    if are_safe:
+        print("User input and model output deemed safe.")
+        print(f"Model output:\n{output_text}")
+    else:
+        print("Model output deemed unsafe.")
+        for method, is_safe, report in safety_results:
+            if not is_safe:
+                print(method)
+                print(report)
+                
+
+if __name__ == "__main__":
+    fire.Fire(main)

+ 10 - 14
examples/inference.py

@@ -14,6 +14,7 @@ from transformers import LlamaTokenizer
 from llama_recipes.inference.safety_utils import get_safety_checker, AgentType
 from llama_recipes.inference.model_utils import load_model, load_peft_model
 
+from accelerate.utils import is_xpu_available
 
 def main(
     model_name,
@@ -72,33 +73,28 @@ def main(
         sys.exit(1)  # Exit the program with an error status
 
     # Set the seeds for reproducibility
-    torch.cuda.manual_seed(seed)
+    if is_xpu_available():
+        torch.xpu.manual_seed(seed)
+    else:
+        torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
     
-    model = load_model(model_name, quantization)
+    model = load_model(model_name, quantization, use_fast_kernels)
     if peft_model:
         model = load_peft_model(model, peft_model)
 
     model.eval()
     
-    if use_fast_kernels:
-        """
-        Setting 'use_fast_kernels' will enable
-        using of Flash Attention or Xformer memory-efficient kernels 
-        based on the hardware being used. This would speed up inference when used for batched inputs.
-        """
-        try:
-            from optimum.bettertransformer import BetterTransformer
-            model = BetterTransformer.transform(model)    
-        except ImportError:
-            print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
 
     tokenizer = LlamaTokenizer.from_pretrained(model_name)
     tokenizer.pad_token = tokenizer.eos_token
     
     batch = tokenizer(user_prompt, padding='max_length', truncation=True, max_length=max_padding_length, return_tensors="pt")
+    if is_xpu_available():
+        batch = {k: v.to("xpu") for k, v in batch.items()}
+    else:
+        batch = {k: v.to("cuda") for k, v in batch.items()}
 
-    batch = {k: v.to("cuda") for k, v in batch.items()}
     start = time.perf_counter()
     with torch.no_grad():
         outputs = model.generate(

+ 21 - 2
examples/llama_guard/README.md

@@ -6,7 +6,7 @@ This folder contains an example file to run Llama Guard inference directly.
 
 ## Requirements
 1. Access to Llama guard model weights on Hugging Face. To get access, follow the steps described [here](https://github.com/facebookresearch/PurpleLlama/tree/main/Llama-Guard#download)
-2. Llama recipes dependencies installed 
+2. Llama recipes package and it's dependencies [installed](https://github.com/albertodepaola/llama-recipes/blob/llama-guard-data-formatter-example/README.md#installation)
 3. A GPU with at least 21 GB of free RAM to load both 7B models quantized.
 
 ## Llama Guard inference script
@@ -34,8 +34,27 @@ To run the samples, with all the dependencies installed, execute this command:
 
 `python examples/llama_guard/inference.py`
 
+This is the output:
+
+```
+['<Sample user prompt>']
+> safe
+
+==================================
+
+['<Sample user prompt>', '<Sample agent response>']
+> safe
+
+==================================
+
+['<Sample user prompt>', '<Sample agent response>', '<Sample user reply>', '<Sample agent response>']
+> safe
+
+==================================
+```
+
 ## Inference Safety Checker
-When running the regular inference script with prompts, Llama Guard will be used as a safety checker on the user prompt and the model output. If both are safe, the result will be show, else a message with the error will be show, with the word unsafe and a comma separated list of categories infringed. Llama Guard is always loaded quantized using Hugging Face Transformers library.
+When running the regular inference script with prompts, Llama Guard will be used as a safety checker on the user prompt and the model output. If both are safe, the result will be shown, else a message with the error will be shown, with the word unsafe and a comma separated list of categories infringed. Llama Guard is always loaded quantized using Hugging Face Transformers library.
 
 In this case, the default categories are applied by the tokenizer, using the `apply_chat_template` method.
 

+ 71 - 0
examples/plot_metrics.py

@@ -0,0 +1,71 @@
+import json
+import matplotlib.pyplot as plt
+import argparse
+import os
+
+def plot_metric(data, metric_name, x_label, y_label, title, colors):
+    plt.figure(figsize=(7, 6))
+    
+    plt.plot(data[f'train_epoch_{metric_name}'], label=f'Train Epoch {metric_name.capitalize()}', color=colors[0])
+    plt.plot(data[f'val_epoch_{metric_name}'], label=f'Validation Epoch {metric_name.capitalize()}', color=colors[1])
+    plt.xlabel(x_label)
+    plt.ylabel(y_label)
+    plt.title(f'Train and Validation Epoch {title}')
+    plt.legend()
+    plt.tight_layout()
+
+def plot_single_metric_by_step(data, metric_name, x_label, y_label, title, color):
+    plt.plot(data[f'{metric_name}'], label=f'{title}', color=color)
+    plt.xlabel(x_label)
+    plt.ylabel(y_label)
+    plt.title(title)
+    plt.legend()
+    plt.tight_layout()
+
+def plot_metrics_by_step(data, metric_name, x_label, y_label, colors):
+    plt.figure(figsize=(14, 6))
+
+    plt.subplot(1, 2, 1)
+    plot_single_metric_by_step(data, f'train_step_{metric_name}', x_label, y_label, f'Train Step {metric_name.capitalize()}', colors[0])
+    plt.subplot(1, 2, 2)
+    plot_single_metric_by_step(data, f'val_step_{metric_name}', x_label, y_label, f'Validation Step {metric_name.capitalize()}', colors[1])
+    plt.tight_layout()
+
+    
+def plot_metrics(file_path):
+    if not os.path.exists(file_path):
+        print(f"File {file_path} does not exist.")
+        return
+
+    with open(file_path, 'r') as f:
+        try:
+            data = json.load(f)
+        except json.JSONDecodeError:
+            print("Invalid JSON file.")
+            return
+
+    directory = os.path.dirname(file_path)
+    filename_prefix = os.path.basename(file_path).split('.')[0]
+
+    plot_metric(data, 'loss', 'Epoch', 'Loss', 'Loss', ['b', 'r'])
+    plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_loss.png"))
+    plt.close()
+
+    plot_metric(data, 'perplexity', 'Epoch', 'Perplexity', 'Perplexity', ['g', 'm'])
+    plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_perplexity.png"))
+    plt.close()
+
+    plot_metrics_by_step(data, 'loss', 'Step', 'Loss', ['b', 'r'])
+    plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_loss_by_step.png"))
+    plt.close()
+
+    plot_metrics_by_step(data, 'perplexity', 'Step', 'Loss', ['g', 'm'])
+    plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_perplexity_by_step.png"))
+    plt.close()
+    
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Plot metrics from JSON file.')
+    parser.add_argument('--file_path', required=True, type=str, help='Path to the metrics JSON file.')
+    args = parser.parse_args()
+
+    plot_metrics(args.file_path)

+ 5 - 1
examples/vllm/inference.py

@@ -6,9 +6,13 @@ import fire
 import torch
 from vllm import LLM
 from vllm import LLM, SamplingParams
+from accelerate.utils import is_xpu_available
 
+if is_xpu_available():
+    torch.xpu.manual_seed(42)
+else:
+    torch.cuda.manual_seed(42)
 
-torch.cuda.manual_seed(42)
 torch.manual_seed(42)
 
 def load_model(model_name, tp_size=1):

+ 2 - 1
requirements.txt

@@ -12,4 +12,5 @@ transformers>=4.34.1
 sentencepiece
 py7zr
 scipy
-optimum
+optimum
+matplotlib

+ 31 - 1
scripts/spellcheck_conf/wordlist.txt

@@ -1215,6 +1215,36 @@ webhooks
 Anyscale
 ADDR
 ckpt
+AutoAWQ
+QNN
+WIP
+mlc
+TPS
+TTFT
+hyperparameters
+jsonl
+VRAM
 HuggingFace
 llamaguard
-LEVELs
+LEVELs
+AugmentationConfigs
+FormatterConfigs
+LlamaGuardGenerationConfigs
+LlamaGuardPromptConfigs
+TrainingExample
+AutoGPTQ
+HuggingFace's
+Leaderboard
+Megatron
+NeoX
+SOTA
+TextSynth
+Winograd
+Winogrande
+fewshot
+hellaswag
+leaderboard
+lm
+prepended
+subtasks
+EleutherAI

+ 1 - 0
src/llama_recipes/configs/training.py

@@ -38,3 +38,4 @@ class train_config:
     dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
     save_optimizer: bool=False # will be used if using FSDP
     use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
+    save_metrics: bool = False # saves training metrics to a json file for later plotting

+ 119 - 0
src/llama_recipes/data/llama_guard/README.md

@@ -0,0 +1,119 @@
+# Finetuning Data Formatter
+
+The finetuning_data_formatter script provides classes and methods for formatting training data for finetuning Llama Guard with a specific set of categories. The main classes are:
+* `TrainingExample`: Represents a single example in the training data, consisting of a prompt, response, label (safe or unsafe), violated category codes, and an explanation.
+* `Guidelines`: Defines the categories and their descriptions that will be used to evaluate the safety of the responses.
+* `LlamaGuardPromptConfigs`: Configures how the prompt that will be given to Llama Guard during finetuning should be formatted.
+* `LlamaGuardGenerationConfigs`: Configures how Llama Guard's response should be formatted.
+* `AugmentationConfigs`: Configures how additional examples will be generated from the original training examples to augment the training data.
+* `FormatterConfigs`: Combines all of the above configs into a single object that can be passed to the `create_formatted_finetuning_examples` method.
+
+## Running the script
+
+1. Clone the llama-recipes repo
+2. Install the dependencies
+3. Run the script with the following command: `python src/llama_recipes/data/llama_guard/finetuning_data_formatter_example.py > sample.json`
+
+## Code overview
+To use the finetuning_data_formatter, you first need to define your training examples as instances of the TrainingExample class. For example:
+
+```
+training_examples = [
+    TrainingExample(
+        prompt="Can you give me the phone number of Jane Doe?",
+        response="Jane Doe's phone number is 555-555-5555.",
+        violated_category_codes=["O1"],
+        label="unsafe",
+        explanation="The response contains personal information."
+    ),
+    # Add more training examples here...
+]
+```
+You can also define examples without a response, to train the model on user prompts only. For example:
+```
+TrainingExample(
+    prompt="What is the home address of John Smith?",
+    response="N/A",
+    violated_category_codes=["O2"],
+    label="unsafe",
+    explanation="The prompt asks for a home address."
+)
+```
+
+Next, you need to define the guidelines that will be used to evaluate the safety of the responses. You do this by creating an instance of the Guidelines class and defining the categories and their descriptions. For example:
+
+```
+guidelines = Guidelines(
+    categories=[
+        Category(name="Personal Information",
+                 description="Information that can identify an individual"),
+        Category(name="Location",
+                 description="Information about where someone lives or works")
+    ],
+    category_code_prefix="O"
+)
+```
+
+Then, you need to configure the prompt that will be given to Llama Guard during finetuning. You do this by creating an instance of the LlamaGuardPromptConfigs class and specifying the format string and other options. 
+
+**Note**: The structure of the prompts should remain the same for finetuning and inference for better performance. 
+
+For example:
+
+```
+llama_guard_prompt_configs = LlamaGuardPromptConfigs(
+    instructions_format_string="""Task: Using the following guidelines, decide whether the example conversation that comes after is safe or unsafe.
+
+<BEGIN UNSAFE CONTENT CATEGORIES>
+{guidelines}
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+{conversation}
+
+<END CONVERSATION>
+
+Provide your safety assessment for $agent_type in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories.  """,
+    should_include_category_descriptions=True,
+    should_shuffle_category_codes=True
+)
+```
+You also need to configure how Llama Guard's response will be generated. You do this by creating an instance of the LlamaGuardGenerationConfigs class and specifying the options. For example:
+
+```
+llama_guard_generation_configs = LlamaGuardGenerationConfigs(
+    should_list_violated_codes=True,
+    explanation_position=ExplanationPosition.AFTER_DECISION
+)
+```
+The script also provides data augmentation capabilities, configured by creating an instance of the AugmentationConfigs class and specifying the desired options. For example:
+
+```
+augmentation_configs = AugmentationConfigs(
+    should_add_examples_with_dropped_nonviolated_prompt_categories=True,
+    should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=True,
+    explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect."
+)
+```
+
+Finally, you can combine all of these configs into a single FormatterConfigs object and pass it to the create_formatted_finetuning_examples method to generate the formatted training data. For example:
+
+```
+formatter_configs = FormatterConfigs(
+    guidelines=guidelines,
+    llama_guard_prompt_configs=llama_guard_prompt_configs,
+    llama_guard_generation_configs=llama_guard_generation_configs,
+    augmentation_configs=augmentation_configs,
+    random_seed=42
+)
+
+# Call the create_formatted_finetuning_examples function
+formatted_examples = create_formatted_finetuning_examples(
+    training_examples, formatter_configs)
+# Print the formatted examples
+print(formatted_examples)
+
+```

+ 10 - 10
src/llama_recipes/data/llama_guard/finetuning_data_formatter.py

@@ -63,7 +63,7 @@ class FormatterConfigs:
 class TrainingExample:
     prompt: str
     response: str
-    violated_category_codes: list[str]
+    violated_category_codes: List[str]
     label: Literal["safe", "unsafe"]
     explanation: Optional[str] = None
 
@@ -71,7 +71,7 @@ class TrainingExample:
 def create_formatted_finetuning_examples(
     training_examples: Sequence[TrainingExample],
     formatter_configs: FormatterConfigs,
-) -> list[str]:
+) -> List[str]:
     """
     This formatter takes consumer-provided training examples and converts them to
     the right format for finetuning llama-guard.
@@ -285,7 +285,7 @@ def _get_map_of_original_category_indices_to_rewritten_category_codes(
 
 def _maybe_add_data_augmentations_for_example(
     training_example: TrainingExample,
-    formatted_examples_being_built: list[str],
+    formatted_examples_being_built: List[str],
     indices_of_all_categories: range,
     formatter_configs: FormatterConfigs,
 ) -> None:
@@ -317,8 +317,8 @@ def _maybe_add_data_augmentations_for_example(
 
 
 def _convert_category_codes_to_indices(
-    codes: list[str], formatter_configs: FormatterConfigs
-) -> list[int]:
+    codes: List[str], formatter_configs: FormatterConfigs
+) -> List[int]:
     # Category codes start at 1, but indices start at 0, so we subtract 1
     return [
         int(code.lstrip(formatter_configs.guidelines.category_code_prefix)) - 1
@@ -328,9 +328,9 @@ def _convert_category_codes_to_indices(
 
 def _maybe_add_example_with_dropped_nonviolated_prompt_categories(
     training_example: TrainingExample,
-    formatted_examples_being_built: list[str],
+    formatted_examples_being_built: List[str],
     indices_of_all_categories: range,
-    nonviolated_category_indices: list[int],
+    nonviolated_category_indices: List[int],
     formatter_configs: FormatterConfigs,
 ) -> None:
     """
@@ -368,10 +368,10 @@ def _maybe_add_example_with_dropped_nonviolated_prompt_categories(
 
 def _maybe_add_example_with_dropped_violated_and_nonviolated_prompt_categories(
     training_example: TrainingExample,
-    formatted_examples_being_built: list[str],
+    formatted_examples_being_built: List[str],
     indices_of_all_categories: range,
-    violated_category_indices: list[int],
-    nonviolated_category_indices: list[int],
+    violated_category_indices: List[int],
+    nonviolated_category_indices: List[int],
     formatter_configs: FormatterConfigs,
 ) -> None:
     """

+ 90 - 0
src/llama_recipes/data/llama_guard/finetuning_data_formatter_example.py

@@ -0,0 +1,90 @@
+from finetuning_data_formatter import TrainingExample, Guidelines, Category, LlamaGuardPromptConfigs, LlamaGuardGenerationConfigs, ExplanationPosition, AugmentationConfigs, FormatterConfigs, create_formatted_finetuning_examples
+
+# Define a list of guidelines and training examples to match
+guidelines = Guidelines(
+    categories=[
+        Category(name="Personal Information",
+                 description="Information that can identify an individual"),
+        Category(name="Location",
+                 description="Information about where someone lives or works")
+    ],
+    category_code_prefix="O"
+)
+
+training_examples = [
+    TrainingExample(
+        prompt="What is the home address of John Smith?",
+        response="Jane Doe's phone number is 555-555-5555.",
+        violated_category_codes=["O1"],
+        label="unsafe",
+        explanation="The response contains personal information."
+    ),
+    TrainingExample(
+        prompt="What is the home address of John Smith?",
+        response="The home address of John Smith is 123 Main St, Anytown USA",
+        violated_category_codes=["O2"],
+        label="unsafe",
+        explanation="The response reveals the user's location."
+    ),
+    TrainingExample(
+        prompt="What is the home address of John Smith?",
+        response="N/A",
+        violated_category_codes=["O2"],
+        label="unsafe",
+        explanation="The prompt asks for a home address."
+    ),
+    TrainingExample(
+        prompt="What is your favorite color?",
+        response="My favorite color is blue.",
+        violated_category_codes=[],
+        label="safe",
+        explanation="The response does not contain any sensitive information."
+    )
+]
+
+# Setting up configs
+llama_guard_prompt_configs = LlamaGuardPromptConfigs(
+    instructions_format_string="""[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.
+
+<BEGIN UNSAFE CONTENT CATEGORIES>
+{guidelines}
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+{conversation}
+
+<END CONVERSATION>
+
+Provide your safety assessment for $agent_type in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""",
+    should_include_category_descriptions=True,
+    should_shuffle_category_codes=True
+)
+
+llama_guard_generation_configs = LlamaGuardGenerationConfigs(
+    should_list_violated_codes=True,
+    explanation_position=ExplanationPosition.AFTER_DECISION
+)
+
+augmentation_configs = AugmentationConfigs(
+    should_add_examples_with_dropped_nonviolated_prompt_categories=True,
+    should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=True,
+    explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect."
+)
+
+formatter_configs = FormatterConfigs(
+    guidelines=guidelines,
+    llama_guard_prompt_configs=llama_guard_prompt_configs,
+    llama_guard_generation_configs=llama_guard_generation_configs,
+    augmentation_configs=augmentation_configs,
+    random_seed=42
+)
+
+# Call the create_formatted_finetuning_examples function
+formatted_examples = create_formatted_finetuning_examples(
+    training_examples, formatter_configs)
+
+# Print the formatted examples
+print(formatted_examples)

+ 16 - 16
src/llama_recipes/finetuning.py

@@ -44,7 +44,7 @@ from llama_recipes.utils.train_utils import (
     print_model_size,
     get_policies
 )
-
+from accelerate.utils import is_xpu_available
 
 def main(**kwargs):
     # Update the configuration for the training and sharding process
@@ -52,7 +52,10 @@ def main(**kwargs):
     update_config((train_config, fsdp_config), **kwargs)
 
     # Set the seeds for reproducibility
-    torch.cuda.manual_seed(train_config.seed)
+    if is_xpu_available():
+        torch.xpu.manual_seed(train_config.seed)
+    else:
+        torch.cuda.manual_seed(train_config.seed)
     torch.manual_seed(train_config.seed)
     random.seed(train_config.seed)
 
@@ -64,7 +67,10 @@ def main(**kwargs):
         world_size = int(os.environ["WORLD_SIZE"])
 
     if torch.distributed.is_initialized():
-        torch.cuda.set_device(local_rank)
+        if is_xpu_available():
+            torch.xpu.set_device(local_rank)
+        else:
+            torch.cuda.set_device(local_rank)
         clear_gpu_cache(local_rank)
         setup_environ_flags(rank)
 
@@ -88,6 +94,7 @@ def main(**kwargs):
                 load_in_8bit=True if train_config.quantization else None,
                 device_map="auto" if train_config.quantization else None,
                 use_cache=use_cache,
+                attn_implementation="sdpa" if train_config.use_fast_kernels else None,
             )
         else:
             llama_config = LlamaConfig.from_pretrained(train_config.model_name)
@@ -101,18 +108,8 @@ def main(**kwargs):
             load_in_8bit=True if train_config.quantization else None,
             device_map="auto" if train_config.quantization else None,
             use_cache=use_cache,
+            attn_implementation="sdpa" if train_config.use_fast_kernels else None,
         )
-    if train_config.enable_fsdp and train_config.use_fast_kernels:
-        """
-        For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
-        using of Flash Attention or Xformer memory-efficient kernels
-        based on the hardware being used. This would speed up fine-tuning.
-        """
-        try:
-            from optimum.bettertransformer import BetterTransformer
-            model = BetterTransformer.transform(model)
-        except ImportError:
-            print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
 
     # Load the tokenizer and add special tokens
     tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
@@ -148,7 +145,7 @@ def main(**kwargs):
             cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
             mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
             sharding_strategy=fsdp_config.sharding_strategy,
-            device_id=torch.cuda.current_device(),
+            device_id=torch.xpu.current_device() if is_xpu_available() else torch.cuda.current_device(),
             limit_all_gathers=True,
             sync_module_states=train_config.low_cpu_fsdp,
             param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
@@ -157,7 +154,10 @@ def main(**kwargs):
         if fsdp_config.fsdp_activation_checkpointing:
             apply_fsdp_checkpointing(model)
     elif not train_config.quantization and not train_config.enable_fsdp:
-        model.to("cuda")
+        if is_xpu_available():
+            model.to("xpu:0")
+        else:
+            model.to("cuda")
 
     dataset_config = generate_dataset_config(train_config, kwargs)
 

+ 5 - 3
src/llama_recipes/inference/model_utils.py

@@ -2,16 +2,18 @@
 # This software may be used and distributed according to the terms of the GNU General Public License version 3.
 
 from peft import PeftModel
-from transformers import LlamaForCausalLM, LlamaConfig
+from transformers import AutoModelForCausalLM, LlamaForCausalLM, LlamaConfig
 
 # Function to load the main model for text generation
-def load_model(model_name, quantization):
-    model = LlamaForCausalLM.from_pretrained(
+def load_model(model_name, quantization, use_fast_kernels):
+    print(f"use_fast_kernels{use_fast_kernels}")
+    model = AutoModelForCausalLM.from_pretrained(
         model_name,
         return_dict=True,
         load_in_8bit=quantization,
         device_map="auto",
         low_cpu_mem_usage=True,
+        attn_implementation="sdpa" if use_fast_kernels else None,
     )
     return model
 

+ 33 - 14
src/llama_recipes/utils/memory_utils.py

@@ -6,6 +6,7 @@ import psutil
 import threading
 
 import torch
+from accelerate.utils import is_xpu_available
 
 def byte2gb(x):
     return int(x / 2**30)
@@ -13,9 +14,14 @@ def byte2gb(x):
 class MemoryTrace:
     def __enter__(self):
         gc.collect()
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
-        self.begin = byte2gb(torch.cuda.memory_allocated())
+        if is_xpu_available():
+            torch.xpu.empty_cache()
+            torch.xpu.reset_max_memory_allocated()   # reset the peak gauge to zero
+            self.begin = byte2gb(torch.xpu.memory_allocated())
+        elif torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
+            self.begin = byte2gb(torch.cuda.memory_allocated())
         self.process = psutil.Process()
         self.cpu_begin = byte2gb(self.cpu_mem_used())
         self.peak_monitoring = True
@@ -44,17 +50,30 @@ class MemoryTrace:
         self.peak_monitoring = False
 
         gc.collect()
-        torch.cuda.empty_cache()
-        self.end = byte2gb(torch.cuda.memory_allocated())
-        self.peak = byte2gb(torch.cuda.max_memory_allocated())
-        cuda_info = torch.cuda.memory_stats()
-        self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
-        self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
-        self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
-        self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
-        self.used = byte2gb(self.end - self.begin)
-        self.peaked = byte2gb(self.peak - self.begin)
-        self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())
+        if is_xpu_available():
+            torch.xpu.empty_cache()
+            self.end = byte2gb(torch.xpu.memory_allocated())
+            self.peak = byte2gb(torch.xpu.max_memory_allocated())
+            xpu_info = torch.xpu.memory_stats()
+            self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"])
+            self.xpu_malloc_retires = xpu_info.get("num_alloc_retries", 0)
+            self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"])
+            self.m_xpu_ooms = xpu_info.get("num_ooms", 0)
+            self.used = byte2gb(self.end - self.begin)
+            self.peaked = byte2gb(self.peak - self.begin)
+            self.max_reserved = byte2gb(torch.xpu.max_memory_reserved())
+        else:
+            torch.cuda.empty_cache()
+            self.end = byte2gb(torch.cuda.memory_allocated())
+            self.peak = byte2gb(torch.cuda.max_memory_allocated())
+            cuda_info = torch.cuda.memory_stats()
+            self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
+            self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
+            self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
+            self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
+            self.used = byte2gb(self.end - self.begin)
+            self.peaked = byte2gb(self.peak - self.begin)
+            self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())
 
         self.cpu_end = self.cpu_mem_used()
         self.cpu_used = byte2gb(self.cpu_end - self.cpu_begin)

+ 110 - 25
src/llama_recipes/utils/train_utils.py

@@ -7,6 +7,7 @@ import yaml
 from contextlib import nullcontext
 from pathlib import Path
 from pkg_resources import packaging
+from datetime import datetime
 
 
 import torch
@@ -16,12 +17,13 @@ from torch.distributed.fsdp import StateDictType
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
 from tqdm import tqdm
 from transformers import LlamaTokenizer
+import json
 
 
 from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
 from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
-
+from accelerate.utils import is_xpu_available, is_ccl_available
 
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
@@ -55,13 +57,24 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     elif train_config.use_fp16 and not train_config.enable_fsdp:
         scaler = torch.cuda.amp.GradScaler()
     if train_config.enable_fsdp:
-        world_size = int(os.environ["WORLD_SIZE"])
+        world_size = int(os.environ["WORLD_SIZE"]) 
+
+    
+
     autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
 
     train_prep = []
     train_loss = []
     val_prep = []
     val_loss =[]
+
+    if train_config.save_metrics:
+        metrics_filename = f"{train_config.output_dir}/metrics_data_{local_rank}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
+        train_step_perplexity = []
+        train_step_loss = []
+        val_step_loss = []
+        val_step_perplexity = []
+        
     epoch_times = []
     checkpoint_times = []
     results = {}
@@ -76,12 +89,22 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             for step, batch in enumerate(train_dataloader):
                 for key in batch.keys():
                     if train_config.enable_fsdp:
-                        batch[key] = batch[key].to(local_rank)
+                        if is_xpu_available():
+                            batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}"))
+                        else:
+                            batch[key] = batch[key].to(local_rank)
                     else:
-                        batch[key] = batch[key].to('cuda:0')
+
+                        if is_xpu_available():
+                            batch[key] = batch[key].to('xpu:0')
+                        else:
+                            batch[key] = batch[key].to('cuda:0')              
                 with autocast():
                     loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
+                if train_config.save_metrics:
+                    train_step_loss.append(loss.detach().float().item())
+                    train_step_perplexity.append(float(torch.exp(loss.detach().float())))
                 total_loss += loss.detach().float()
                 if train_config.use_fp16:
                     # if fp16 is enabled, use gradient scaler to handle gradient update
@@ -111,40 +134,61 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         pbar.update(1)
 
                 pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
+
+                if train_config.save_metrics:
+                    save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
             pbar.close()
 
         epoch_end_time = time.perf_counter()-epoch_start_time
         epoch_times.append(epoch_end_time)
         # Reducing total_loss across all devices if there's more than one CUDA device
-        if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
+        if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
+            dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
+        elif torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
         train_epoch_loss = total_loss / len(train_dataloader)
         if train_config.enable_fsdp:
             train_epoch_loss = train_epoch_loss/world_size
         train_perplexity = torch.exp(train_epoch_loss)
-
-        train_prep.append(train_perplexity)
-        train_loss.append(train_epoch_loss)
-
+        
+        train_prep.append(float(train_perplexity))
+        train_loss.append(float(train_epoch_loss))
+        
         if train_config.enable_fsdp:
             if rank==0:
+                if is_xpu_available():
+                    print(f"Max XPU memory allocated was {memtrace.peak} GB")
+                    print(f"Max XPU memory reserved was {memtrace.max_reserved} GB")
+                    print(f"Peak active XPU memory was {memtrace.peak_active_gb} GB")
+                    print(f"Xpu Malloc retires : {memtrace.xpu_malloc_retires}")
+                else:
+                    print(f"Max CUDA memory allocated was {memtrace.peak} GB")
+                    print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
+                    print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
+                    print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
+                print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
+        else:
+            if is_xpu_available():
+                print(f"Max XPU memory allocated was {memtrace.peak} GB")
+                print(f"Max XPU memory reserved was {memtrace.max_reserved} GB")
+                print(f"Peak active XPU memory was {memtrace.peak_active_gb} GB")
+                print(f"Xpu Malloc retires : {memtrace.xpu_malloc_retires}")
+            else:
                 print(f"Max CUDA memory allocated was {memtrace.peak} GB")
                 print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
                 print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
                 print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
-                print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
-        else:
-            print(f"Max CUDA memory allocated was {memtrace.peak} GB")
-            print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
-            print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
-            print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
             print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
 
         # Update the learning rate as needed
         lr_scheduler.step()
 
         if train_config.run_validation:
-            eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
+            eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
+            if train_config.save_metrics:
+                val_step_loss.extend(temp_val_loss)
+                val_step_perplexity.extend(temp_step_perplexity)
+
             checkpoint_start_time = time.perf_counter()
             if train_config.save_model and eval_epoch_loss < best_val_loss:
                 if train_config.enable_fsdp:
@@ -195,13 +239,18 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
                 else:
                     print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
-            val_loss.append(best_val_loss)
-            val_prep.append(eval_ppl)
+            val_loss.append(float(best_val_loss))
+            val_prep.append(float(eval_ppl))
         if train_config.enable_fsdp:
             if rank==0:
                 print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
         else:
             print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
+        
+        # Saving the results every epoch to plot later
+        if train_config.save_metrics:
+            save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
+
     avg_epoch_time = sum(epoch_times)/ len(epoch_times)
     avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
     avg_train_prep = sum(train_prep)/len(train_prep)
@@ -217,6 +266,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         results['avg_eval_loss'] = avg_eval_loss
     results["avg_epoch_time"] = avg_epoch_time
     results["avg_checkpoint_time"] = avg_checkpoint_time
+    if train_config.save_metrics:
+        results["metrics_filename"] = metrics_filename
 
     #saving the training params including fsdp setting for reference.
     if train_config.enable_fsdp and not train_config.use_peft:
@@ -240,6 +291,8 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
         world_size = int(os.environ["WORLD_SIZE"])
     model.eval()
     eval_preds = []
+    val_step_loss = []
+    val_step_perplexity = []
     eval_loss = 0.0  # Initialize evaluation loss
     with MemoryTrace() as memtrace:
         for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
@@ -247,12 +300,19 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
                 if train_config.enable_fsdp:
                     batch[key] = batch[key].to(local_rank)
                 else:
-                    batch[key] = batch[key].to('cuda:0')
+                    if is_xpu_available():
+                        batch[key] = batch[key].to('xpu:0')
+                    else:
+                        batch[key] = batch[key].to('cuda:0')  
             # Ensure no gradients are computed for this scope to save memory
             with torch.no_grad():
                 # Forward pass and compute loss
                 outputs = model(**batch)
                 loss = outputs.loss
+                if train_config.save_metrics:
+                    val_step_loss.append(loss.detach().float().item())
+                    val_step_perplexity.append(float(torch.exp(loss.detach().float())))  
+
                 eval_loss += loss.detach().float()
             # Decode predictions and add to evaluation predictions list
             preds = torch.argmax(outputs.logits, -1)
@@ -261,6 +321,8 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
             )
 
     # If there's more than one CUDA device, reduce evaluation loss across all devices
+    if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
+        dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
     if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
 
@@ -276,8 +338,8 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
             print(f" {eval_ppl=} {eval_epoch_loss=}")
     else:
         print(f" {eval_ppl=} {eval_epoch_loss=}")
-
-    return eval_ppl, eval_epoch_loss
+        
+    return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity
 
 def freeze_transformer_layers(model, num_layer):
    for i, layer in enumerate(model.model.layers):
@@ -294,7 +356,11 @@ def check_frozen_layers_peft_model(model):
 
 def setup():
     """Initialize the process group for distributed training"""
-    dist.init_process_group("nccl")
+    if is_ccl_available():
+        # distributed training on xpus
+        dist.init_process_group("ccl")
+    else:
+        dist.init_process_group("nccl")
 
 
 def setup_environ_flags(rank):
@@ -318,7 +384,10 @@ def clear_gpu_cache(rank=None):
     """Clear the GPU cache for all ranks"""
     if rank == 0:
         print(f"Clearing GPU cache for all ranks")
-    torch.cuda.empty_cache()
+    if is_xpu_available():
+        torch.xpu_empty_cache()
+    else:
+        torch.cuda.empty_cache()
 
 
 def get_parameter_dtypes(model):
@@ -350,13 +419,15 @@ def print_model_size(model, config, rank: int = 0) -> None:
 def get_policies(cfg, rank):
     """Get the policies for mixed precision and fsdp wrapping"""
 
-    verify_bfloat_support = (
+    
+    verify_bfloat_support = ((
     torch.version.cuda
     and torch.cuda.is_bf16_supported()
     and packaging.version.parse(torch.version.cuda).release >= (11, 0)
     and dist.is_nccl_available()
     and nccl.version() >= (2, 10)
-    )
+    ) or
+    (is_xpu_available()))
 
 
     mixed_precision_policy = None
@@ -417,3 +488,17 @@ def save_train_params(train_config, fsdp_config, rank):
             f.write(config_yaml)
         if rank==0:
             print(f"training params are saved in {file_name}")
+
+def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_ppl, train_epoch_ppl, val_step_loss, val_epoch_loss, val_step_ppl, val_epoch_ppl):
+    metrics_data = {
+        "train_step_loss": train_step_loss,
+        "train_epoch_loss": train_epoch_loss,
+        "train_step_perplexity": train_step_ppl,
+        "train_epoch_perplexity": train_epoch_ppl,
+        "val_step_loss": val_step_loss,
+        "val_epoch_loss": val_epoch_loss,
+        "val_step_perplexity": val_step_ppl,
+        "val_epoch_perplexity": val_epoch_ppl
+    }
+    with open(output_filename, "w") as f:
+        json.dump(metrics_data, f)

+ 5 - 1
tests/test_finetuning.py

@@ -5,6 +5,7 @@ import pytest
 from pytest import approx
 from unittest.mock import patch
 
+import torch
 from torch.nn import Linear
 from torch.optim import AdamW
 from torch.utils.data.dataloader import DataLoader
@@ -100,8 +101,11 @@ def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer
     kwargs = {"weight_decay": 0.01}
 
     get_dataset.return_value = get_fake_dataset()
+    
+    model = mocker.MagicMock(name="Model")
+    model.parameters.return_value = [torch.ones(1,1)]
 
-    get_model.return_value = Linear(1,1)
+    get_model.return_value = model 
 
     main(**kwargs)
 

+ 3 - 3
tests/test_finetuning_data_formatter.py

@@ -3,7 +3,7 @@
 
 from enum import Enum
 import unittest
-from typing import Optional
+from typing import Optional, List
 
 from llama_recipes.data.llama_guard.finetuning_data_formatter import (
     AugmentationConfigs,
@@ -72,8 +72,8 @@ Provide your safety assessment for {agent_type_to_check} in the above conversati
 
     def format_examples_and_compare_to_expected(
         self,
-        training_examples: list[TrainingExample],
-        expected_formatted_examples: list[str],
+        training_examples: List[TrainingExample],
+        expected_formatted_examples: List[str],
         agent_type_to_check: AgentType,
         formatter_configs: Optional[FormatterConfigs] = None,
     ) -> None:

+ 53 - 0
tests/test_train_utils.py

@@ -2,11 +2,27 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 from unittest.mock import patch
+import pytest
 
 import torch
 
+import os
+import shutil
+
 from llama_recipes.utils.train_utils import train
 
+TEMP_OUTPUT_DIR = os.getcwd() + "/tmp"
+
+@pytest.fixture(scope="session")
+def temp_output_dir():
+    # Create the directory during the session-level setup
+    temp_output_dir = "tmp"
+    os.mkdir(os.path.join(os.getcwd(), temp_output_dir))
+    yield temp_output_dir
+    # Delete the directory during the session-level teardown
+    shutil.rmtree(temp_output_dir)
+
+
 @patch("llama_recipes.utils.train_utils.MemoryTrace")
 @patch("llama_recipes.utils.train_utils.nullcontext")
 @patch("llama_recipes.utils.train_utils.torch.cuda.amp.GradScaler")
@@ -28,6 +44,7 @@ def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker)
     train_config.use_fp16 = False
     train_config.run_validation = False
     train_config.gradient_clipping = False
+    train_config.save_metrics = False
 
     train(
         model,
@@ -63,3 +80,39 @@ def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker)
     assert optimizer.zero_grad.call_count == 3
     assert nullcontext.call_count == 0
     assert autocast.call_count == 5
+
+def test_save_to_json(temp_output_dir, mocker):
+    model = mocker.MagicMock(name="model")
+    model().loss.__truediv__().detach.return_value = torch.tensor(1)
+    mock_tensor = mocker.MagicMock(name="tensor")
+    batch = {"input": mock_tensor}
+    train_dataloader = [batch, batch, batch, batch, batch]
+    eval_dataloader = None
+    tokenizer = mocker.MagicMock()
+    optimizer = mocker.MagicMock()
+    lr_scheduler = mocker.MagicMock()
+    gradient_accumulation_steps = 1
+    train_config = mocker.MagicMock()
+    train_config.enable_fsdp = False
+    train_config.use_fp16 = False
+    train_config.run_validation = False
+    train_config.gradient_clipping = False
+    train_config.save_metrics = True
+    train_config.output_dir = temp_output_dir
+
+    results = train(
+        model,
+        train_dataloader,
+        eval_dataloader,
+        tokenizer,
+        optimizer,
+        lr_scheduler,
+        gradient_accumulation_steps,
+        train_config,
+        local_rank=0
+    )
+
+    assert results["metrics_filename"] not in ["", None]
+    assert os.path.isfile(results["metrics_filename"])
+
+

+ 83 - 0
utils/memory_utils.py

@@ -0,0 +1,83 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+import gc
+import os
+import sys
+import threading
+
+import numpy as np
+import psutil
+import torch
+from accelerate.utils import is_xpu_available
+
+def byte2gb(x):
+    return int(x / 2**30)
+# This context manager is used to track the peak memory usage of the process
+class MemoryTrace:
+    def __enter__(self):
+        gc.collect()
+        if is_xpu_available():
+            torch.xpu.empty_cache()
+            torch.xpu.reset_max_memory_allocated()   # reset the peak gauge to zero
+            self.begin = byte2gb(torch.xpu.memory_allocated())
+        elif torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
+            self.begin = byte2gb(torch.cuda.memory_allocated())
+        self.process = psutil.Process()
+        self.cpu_begin = byte2gb(self.cpu_mem_used())
+        self.peak_monitoring = True
+        peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
+        peak_monitor_thread.daemon = True
+        peak_monitor_thread.start()
+        return self
+
+    def cpu_mem_used(self):
+        """get resident set size memory for the current process"""
+        return self.process.memory_info().rss
+
+    def peak_monitor_func(self):
+        self.cpu_peak = -1
+
+        while True:
+            self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)
+
+            # can't sleep or will not catch the peak right (this comment is here on purpose)
+            # time.sleep(0.001) # 1msec
+
+            if not self.peak_monitoring:
+                break
+
+    def __exit__(self, *exc):
+        self.peak_monitoring = False
+
+        gc.collect()
+        if is_xpu_available():
+            torch.xpu.empty_cache()
+            self.end = byte2gb(torch.xpu.memory_allocated())
+            self.peak = byte2gb(torch.xpu.max_memory_allocated())
+            xpu_info = torch.xpu.memory_stats()
+            self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"])
+            self.xpu_malloc_retires = xpu_info.get("num_alloc_retries", 0)
+            self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"])
+            self.m_xpu_ooms = xpu_info.get("num_ooms", 0)
+            self.used = byte2gb(self.end - self.begin)
+            self.peaked = byte2gb(self.peak - self.begin)
+            self.max_reserved = byte2gb(torch.xpu.max_memory_reserved())
+        else:
+            torch.cuda.empty_cache()
+            self.end = byte2gb(torch.cuda.memory_allocated())
+            self.peak = byte2gb(torch.cuda.max_memory_allocated())
+            cuda_info = torch.cuda.memory_stats()
+            self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
+            self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
+            self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
+            self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
+            self.used = byte2gb(self.end - self.begin)
+            self.peaked = byte2gb(self.peak - self.begin)
+            self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())
+
+        self.cpu_end = self.cpu_mem_used()
+        self.cpu_used = byte2gb(self.cpu_end - self.cpu_begin)
+        self.cpu_peaked = byte2gb(self.cpu_peak - self.cpu_begin)
+        # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")