Selaa lähdekoodia

Merge branch 'main' into flop_counter_gc

Hamid Shojanazeri 1 vuosi sitten
vanhempi
commit
71d137c722
93 muutettua tiedostoa jossa 104222 lisäystä ja 426 poistoa
  1. 11 0
      .vscode/settings.json
  2. 62 26
      README.md
  3. 55 0
      benchmarks/inference/README.md
  4. 38 0
      benchmarks/inference/on-prem/README.md
  5. 205 0
      benchmarks/inference/on-prem/vllm/chat_vllm_benchmark.py
  6. 9 0
      benchmarks/inference/on-prem/vllm/input.jsonl
  7. 15 0
      benchmarks/inference/on-prem/vllm/parameters.json
  8. 215 0
      benchmarks/inference/on-prem/vllm/pretrained_vllm_benchmark.py
  9. 23 0
      benchmarks/inference/tokenizer/special_tokens_map.json
  10. 93391 0
      benchmarks/inference/tokenizer/tokenizer.json
  11. BIN
      benchmarks/inference/tokenizer/tokenizer.model
  12. 35 0
      benchmarks/inference/tokenizer/tokenizer_config.json
  13. 610 0
      demo_apps/Azure_API_example/azure_api_example.ipynb
  14. 433 0
      demo_apps/HelloLlamaCloud.ipynb
  15. 347 0
      demo_apps/HelloLlamaLocal.ipynb
  16. 306 0
      demo_apps/LiveData.ipynb
  17. 120 0
      demo_apps/Llama2_Gradio.ipynb
  18. 723 0
      demo_apps/RAG_Chatbot_example/RAG_Chatbot_Example.ipynb
  19. BIN
      demo_apps/RAG_Chatbot_example/data/Llama Getting Started Guide.pdf
  20. 6 0
      demo_apps/RAG_Chatbot_example/requirements.txt
  21. BIN
      demo_apps/RAG_Chatbot_example/vectorstore/db_faiss/index.faiss
  22. BIN
      demo_apps/RAG_Chatbot_example/vectorstore/db_faiss/index.pkl
  23. 117 0
      demo_apps/README.md
  24. 559 0
      demo_apps/StructuredLlama.ipynb
  25. 698 0
      demo_apps/VideoSummary.ipynb
  26. 38 0
      demo_apps/csv2db.py
  27. 186 0
      demo_apps/llama-on-prem.md
  28. BIN
      demo_apps/llama2-gradio.png
  29. BIN
      demo_apps/llama2-streamlit.png
  30. BIN
      demo_apps/llama2-streamlit2.png
  31. BIN
      demo_apps/llama2.pdf
  32. 61 0
      demo_apps/llama_chatbot.py
  33. 45 0
      demo_apps/llama_messenger.py
  34. BIN
      demo_apps/messenger_api_settings.png
  35. 194 0
      demo_apps/messenger_llama2.md
  36. BIN
      demo_apps/messenger_llama_arch.jpg
  37. 1294 0
      demo_apps/nba.txt
  38. 22 0
      demo_apps/streamlit_llama2.py
  39. 53 0
      demo_apps/txt2csv.py
  40. BIN
      demo_apps/whatsapp_dashboard.jpg
  41. 162 0
      demo_apps/whatsapp_llama2.md
  42. BIN
      demo_apps/whatsapp_llama_arch.jpg
  43. 15 3
      docs/Dataset.md
  44. 21 7
      docs/FAQ.md
  45. 16 2
      docs/inference.md
  46. 784 0
      examples/Prompt_Engineering_with_Llama_2.ipynb
  47. 384 0
      examples/Purple_Llama_Anyscale.ipynb
  48. 6 3
      examples/README.md
  49. 9 3
      examples/chat_completion/chat_completion.py
  50. 23 30
      examples/custom_dataset.py
  51. 22 0
      examples/hf_llama_conversion/README.md
  52. 48 0
      examples/hf_llama_conversion/compare_llama_weights.py
  53. 34 25
      examples/inference.py
  54. 66 0
      examples/llama_guard/README.md
  55. 3 0
      examples/llama_guard/__init__.py
  56. 65 0
      examples/llama_guard/inference.py
  57. 71 0
      examples/plot_metrics.py
  58. 3 6
      examples/quickstart.ipynb
  59. 5 1
      examples/vllm/inference.py
  60. 6 1
      pyproject.toml
  61. 3 2
      requirements.txt
  62. 80 6
      scripts/spellcheck_conf/wordlist.txt
  63. 0 2
      src/llama_recipes/configs/datasets.py
  64. 5 0
      src/llama_recipes/configs/training.py
  65. 2 0
      src/llama_recipes/data/__init__.py
  66. 34 0
      src/llama_recipes/data/concatenator.py
  67. 119 0
      src/llama_recipes/data/llama_guard/README.md
  68. 2 0
      src/llama_recipes/data/llama_guard/__init__.py
  69. 413 0
      src/llama_recipes/data/llama_guard/finetuning_data_formatter.py
  70. 90 0
      src/llama_recipes/data/llama_guard/finetuning_data_formatter_example.py
  71. 57 0
      src/llama_recipes/data/sampler.py
  72. 5 15
      src/llama_recipes/datasets/alpaca_dataset.py
  73. 13 18
      src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py
  74. 19 13
      src/llama_recipes/datasets/samsum_dataset.py
  75. 0 66
      src/llama_recipes/datasets/utils.py
  76. 40 41
      src/llama_recipes/finetuning.py
  77. 149 0
      src/llama_recipes/inference/prompt_format_utils.py
  78. 59 8
      src/llama_recipes/inference/safety_utils.py
  79. 163 0
      src/llama_recipes/tools/convert_hf_weights_to_llama.py
  80. 49 11
      src/llama_recipes/utils/config_utils.py
  81. 6 6
      src/llama_recipes/utils/dataset_utils.py
  82. 33 14
      src/llama_recipes/utils/memory_utils.py
  83. 184 50
      src/llama_recipes/utils/train_utils.py
  84. 50 0
      tests/conftest.py
  85. 42 14
      tests/datasets/test_custom_dataset.py
  86. 56 0
      tests/datasets/test_grammar_datasets.py
  87. 32 14
      tests/datasets/test_samsum_datasets.py
  88. 96 0
      tests/test_batching.py
  89. 84 33
      tests/test_finetuning.py
  90. 483 0
      tests/test_finetuning_data_formatter.py
  91. 86 0
      tests/test_sampler.py
  92. 71 6
      tests/test_train_utils.py
  93. 83 0
      utils/memory_utils.py

+ 11 - 0
.vscode/settings.json

@@ -0,0 +1,11 @@
+{
+    "python.testing.unittestArgs": [
+        "-v",
+        "-s",
+        "./tests",
+        "-p",
+        "test_*.py"
+    ],
+    "python.testing.pytestEnabled": false,
+    "python.testing.unittestEnabled": true
+}

+ 62 - 26
README.md

@@ -1,7 +1,13 @@
-# Llama 2 Fine-tuning / Inference Recipes and Examples
+# Llama 2 Fine-tuning / Inference Recipes, Examples, Benchmarks and Demo Apps
+
+**[Update Dec. 28, 2023] We added support for Llama Guard as a safety checker for our example inference script and also with standalone inference with an example script and prompt formatting. More details [here](./examples/llama_guard/README.md). For details on formatting data for fine tuning Llama Guard, we provide a script and sample usage [here](./src/llama_recipes/data/llama_guard/README.md).**
+
+**[Update Dec 14, 2023] We recently released a series of Llama 2 demo apps [here](./demo_apps). These apps show how to run Llama (locally, in the cloud, or on-prem),  how to use Azure Llama 2 API (Model-as-a-Service), how to ask Llama questions in general or about custom data (PDF, DB, or live), how to integrate Llama with WhatsApp and Messenger, and how to implement an end-to-end chatbot with RAG (Retrieval Augmented Generation).**
 
 The 'llama-recipes' repository is a companion to the [Llama 2 model](https://github.com/facebookresearch/llama). The goal of this repository is to provide examples to quickly get started with fine-tuning for domain adaptation and how to run inference for the fine-tuned models. For ease of use, the examples use Hugging Face converted versions of the models. See steps for conversion of the model [here](#model-conversion-to-hugging-face).
 
+In addition, we also provide a number of demo apps, to showcase the Llama 2 usage along with other ecosystem solutions to run Llama 2 locally, in the cloud, and on-prem.
+
 Llama 2 is a new technology that carries potential risks with use. Testing conducted to date has not — and could not — cover all scenarios. In order to help developers address these risks, we have created the [Responsible Use Guide](https://github.com/facebookresearch/llama/blob/main/Responsible-Use-Guide.pdf). More details can be found in our research paper as well. For downloading the models, follow the instructions on [Llama 2 repo](https://github.com/facebookresearch/llama).
 
 
@@ -13,10 +19,9 @@ Llama 2 is a new technology that carries potential risks with use. Testing condu
     - [Multi GPU One Node](#multiple-gpus-one-node)
     - [Multi GPU Multi Node](#multi-gpu-multi-node)
 4. [Inference](./docs/inference.md)
-5. [Repository Organization](#repository-organization)
-6. [License and Acceptable Use Policy](#license)
-
-
+5. [Demo Apps](#demo-apps)
+6. [Repository Organization](#repository-organization)
+7. [License and Acceptable Use Policy](#license)
 
 # Quick Start
 
@@ -29,17 +34,7 @@ Llama-recipes provides a pip distribution for easy install and usage in other pr
 ```
 pip install --extra-index-url https://download.pytorch.org/whl/test/cu118 llama-recipes
 ```
-## Install from source
-To install from source e.g. for development use this command. We're using hatchling as our build backend which requires an up-to-date pip as well as setuptools package.
-```
-pip install -U pip setuptools
-pip install --extra-index-url https://download.pytorch.org/whl/test/cu118 -e .
-```
-For development and contributing to llama-recipes please install all optional dependencies:
-```
-pip install -U pip setuptools
-pip install --extra-index-url https://download.pytorch.org/whl/test/cu118 -e .[tests,auditnlg,vllm]
-```
+
 ## Install with optional dependencies
 Llama-recipes offers the installation of optional packages. There are three optional dependency groups.
 To run the unit tests we can install the required dependencies with:
@@ -56,12 +51,26 @@ pip install --extra-index-url https://download.pytorch.org/whl/test/cu118 llama-
 ```
 Optional dependencies can also be combines with [option1,option2].
 
+## Install from source
+To install from source e.g. for development use these commands. We're using hatchling as our build backend which requires an up-to-date pip as well as setuptools package.
+```
+git clone git@github.com:facebookresearch/llama-recipes.git
+cd llama-recipes
+pip install -U pip setuptools
+pip install --extra-index-url https://download.pytorch.org/whl/test/cu118 -e .
+```
+For development and contributing to llama-recipes please install all optional dependencies:
+```
+git clone git@github.com:facebookresearch/llama-recipes.git
+cd llama-recipes
+pip install -U pip setuptools
+pip install --extra-index-url https://download.pytorch.org/whl/test/cu118 -e .[tests,auditnlg,vllm]
+```
+
 ⚠️ **Note** ⚠️  Some features (especially fine-tuning with FSDP + PEFT) currently require PyTorch nightlies to be installed. Please make sure to install the nightlies if you're using these features following [this guide](https://pytorch.org/get-started/locally/).
 
 **Note** All the setting defined in [config files](src/llama_recipes/configs/) can be passed as args through CLI when running the script, there is no need to change from config files directly.
 
-**Note** In case need to run PEFT model with FSDP, please make sure to use the PyTorch Nightlies.
-
 **For more in depth information checkout the following:**
 
 * [Single GPU Fine-tuning](./docs/single_gpu.md)
@@ -73,7 +82,7 @@ Optional dependencies can also be combines with [option1,option2].
 
 # Where to find the models?
 
-You can find llama v2 models on HuggingFace hub [here](https://huggingface.co/meta-llama), where models with `hf` in the name are already converted to HuggingFace checkpoints so no further conversion is needed. The conversion step below is only for original model weights from Meta that are hosted on HuggingFace model hub as well.
+You can find Llama 2 models on Hugging Face hub [here](https://huggingface.co/meta-llama), where models with `hf` in the name are already converted to Hugging Face checkpoints so no further conversion is needed. The conversion step below is only for original model weights from Meta that are hosted on Hugging Face model hub as well.
 
 # Model conversion to Hugging Face
 The recipes and notebooks in this folder are using the Llama 2 model definition provided by Hugging Face's transformers library.
@@ -81,7 +90,7 @@ The recipes and notebooks in this folder are using the Llama 2 model definition
 Given that the original checkpoint resides under models/7B you can install all requirements and convert the checkpoint with:
 
 ```bash
-## Install HuggingFace Transformers from source
+## Install Hugging Face Transformers from source
 pip freeze | grep transformers ## verify it is version 4.31.0 or higher
 
 git clone git@github.com:huggingface/transformers.git
@@ -107,13 +116,15 @@ All the parameters in the examples and recipes below need to be further tuned to
 
 * Make sure to set the right path to the model in the [training config](src/llama_recipes/configs/training.py).
 
+* To save the loss and perplexity metrics for evaluation, enable this by passing `--save_metrics` to the finetuning script. The file can be plotted using the [plot_metrics.py](./examples/plot_metrics.py) script, `python examples/plot_metrics.py --file_path path/to/metrics.json`
+
 ### Single GPU:
 
 ```bash
 #if running on multi-gpu machine
 export CUDA_VISIBLE_DEVICES=0
 
-python -m llama_recipes.finetuning  --use_peft --peft_method lora --quantization --model_name /patht_of_model_folder/7B --output_dir Path/to/save/PEFT/model
+python -m llama_recipes.finetuning  --use_peft --peft_method lora --quantization --model_name /path_of_model_folder/7B --output_dir path/to/save/PEFT/model
 
 ```
 
@@ -130,7 +141,7 @@ Here we make use of Parameter Efficient Methods (PEFT) as described in the next
 
 ```bash
 
-torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --pure_bf16 --output_dir Path/to/save/PEFT/model
+torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /path_of_model_folder/7B --fsdp_config.pure_bf16 --output_dir path/to/save/PEFT/model
 
 ```
 
@@ -138,10 +149,10 @@ Here we use FSDP as discussed in the next section which can be used along with P
 
 ## Flash Attention and Xformer Memory Efficient Kernels
 
-Setting `use_fast_kernels` will enable using of Flash Attention or Xformer memory-efficient kernels based on the hardware being used. This would speed up the fine-tuning job. This has been enabled in `optimum` library from HuggingFace as a one-liner API, please read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/).
+Setting `use_fast_kernels` will enable using of Flash Attention or Xformer memory-efficient kernels based on the hardware being used. This would speed up the fine-tuning job. This has been enabled in `optimum` library from Hugging Face as a one-liner API, please read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/).
 
 ```bash
-torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --pure_bf16 --output_dir Path/to/save/PEFT/model --use_fast_kernels
+torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /path_of_model_folder/7B --fsdp_config.pure_bf16 --output_dir path/to/save/PEFT/model --use_fast_kernels
 ```
 
 ### Fine-tuning using FSDP Only
@@ -150,7 +161,7 @@ If you are interested in running full parameter fine-tuning without making use o
 
 ```bash
 
-torchrun --nnodes 1 --nproc_per_node 8  examples/finetuning.py --enable_fsdp --model_name /patht_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --use_fast_kernels
+torchrun --nnodes 1 --nproc_per_node 8  examples/finetuning.py --enable_fsdp --model_name /path_of_model_folder/7B --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned --use_fast_kernels
 
 ```
 
@@ -160,7 +171,7 @@ If you are interested in running full parameter fine-tuning on the 70B model, yo
 
 ```bash
 
-torchrun --nnodes 1 --nproc_per_node 8 examples/finetuning.py --enable_fsdp --low_cpu_fsdp --pure_bf16 --model_name /patht_of_model_folder/70B --batch_size_training 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
+torchrun --nnodes 1 --nproc_per_node 8 examples/finetuning.py --enable_fsdp --low_cpu_fsdp --fsdp_config.pure_bf16 --model_name /path_of_model_folder/70B --batch_size_training 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
 
 ```
 
@@ -174,9 +185,32 @@ sbatch multi_node.slurm
 ```
 You can read more about our fine-tuning strategies [here](./docs/LLM_finetuning.md).
 
+# Demo Apps
+This folder contains a series of Llama2-powered apps:
+* Quickstart Llama deployments and basic interactions with Llama
+1. Llama on your Mac and ask Llama general questions
+2. Llama on Google Colab
+3. Llama on Cloud and ask Llama questions about unstructured data in a PDF
+4. Llama on-prem with vLLM and TGI
+5. Llama chatbot with RAG (Retrieval Augmented Generation)
+6. Azure Llama 2 API (Model-as-a-Service)
+
+* Specialized Llama use cases:
+1. Ask Llama to summarize a video content
+2. Ask Llama questions about structured data in a DB
+3. Ask Llama questions about live data on the web
+4. Build a Llama-enabled WhatsApp chatbot
+
+# Benchmarks
+This folder contains a series of benchmark scripts for Llama 2 models inference on various backends:
+1. On-prem - Popular serving frameworks and containers (i.e. vLLM)
+2. (WIP) Cloud API - Popular API services (i.e. Azure Model-as-a-Service)
+3. (WIP) On-device - Popular on-device inference solutions on Android and iOS (i.e. mlc-llm, QNN)
+4. (WIP) Optimization - Popular optimization solutions for faster inference and quantization (i.e. AutoAWQ)
 
 # Repository Organization
 This repository is organized in the following way:
+[benchmarks](./benchmarks): Contains a series of benchmark scripts for Llama 2 models inference on various backends.
 
 [configs](src/llama_recipes/configs/): Contains the configuration files for PEFT methods, FSDP, Datasets.
 
@@ -184,6 +218,8 @@ This repository is organized in the following way:
 
 [datasets](src/llama_recipes/datasets/): Contains individual scripts for each dataset to download and process. Note: Use of any of the datasets should be in compliance with the dataset's underlying licenses (including but not limited to non-commercial uses)
 
+[demo_apps](./demo_apps): Contains a series of Llama2-powered apps, from quickstart deployments to how to ask Llama questions about unstructured data, structured data, live data, and video summary.
+
 [examples](./examples/): Contains examples script for finetuning and inference of the Llama 2 model as well as how to use them safely.
 
 [inference](src/llama_recipes/inference/): Includes modules for inference for the fine-tuned models.

+ 55 - 0
benchmarks/inference/README.md

@@ -0,0 +1,55 @@
+# Inference Throughput Benchmarks
+In this folder we provide a series of benchmark scripts that apply a throughput analysis for Llama 2 models inference on various backends:
+* On-prem - Popular serving frameworks and containers (i.e. vLLM)
+* [**WIP**]Cloud API - Popular API services (i.e. Azure Model-as-a-Service)
+* [**WIP**]On-device - Popular on-device inference solutions on Android and iOS (i.e. mlc-llm, QNN)
+* [**WIP**]Optimization - Popular optimization solutions for faster inference and quantization (i.e. AutoAWQ)
+
+# Why
+There are three major reasons we want to run these benchmarks and share them with our Llama community:
+* Provide inference throughput analysis based on real world situation to help you select the best service or deployment for your scenario
+* Provide a baseline measurement for validating various optimization solutions on different backends, so we can provide guidance on which solutions work best for your scenario
+* Encourage the community to develop benchmarks on top of our works, so we can better quantify the latest proposed solutions combined with current popular frameworks, especially in this crazy fast-moving area
+
+# Parameters
+Here are the parameters (if applicable) that you can configure for running the benchmark:
+* **PROMPT** - Prompt sent in for inference (configure the length of prompt, choose from 5, 25, 50, 100, 500, 1k and 2k)
+* **MAX_NEW_TOKENS** - Max number of tokens generated
+* **CONCURRENT_LEVELS** - Max number of concurrent requests
+* **MODEL_PATH** - Model source
+* **MODEL_HEADERS** - Request headers
+* **SAFE_CHECK** - Content safety check (either Azure service or simulated latency)
+* **THRESHOLD_TPS** - Threshold TPS (threshold for tokens per second below which we deem the query to be slow)
+* **TOKENIZER_PATH** - Tokenizer source
+* **RANDOM_PROMPT_LENGTH** - Random prompt length (for pretrained models)
+* **NUM_GPU** - Number of GPUs for request dispatch among multiple containers
+* **TEMPERATURE** - Temperature for inference
+* **TOP_P** - Top_p for inference
+* **MODEL_ENDPOINTS** - Container endpoints
+* Model parallelism or model replicas - Load one model into multiple GPUs or multiple model replicas on one instance. More detail in the README files for specific containers.
+
+You can also configure other model hyperparameters as part of the request payload.  
+All these parameters are stored in ```parameter.json``` and real prompts are stored in ```input.jsonl```. Running the script will load these configurations.
+
+
+
+# Metrics
+The benchmark will report these metrics per instance:
+* Number of concurrent requests
+* P50 Latency(ms)
+* P99 Latency(ms)
+* Request per second (RPS)
+* Output tokens per second
+* Output tokens per second per GPU
+* Input tokens per second
+* Input tokens per second per GPU
+* Average tokens per second per request
+
+We intend to add these metrics in the future:
+* Time to first token (TTFT)
+  
+The benchmark result will be displayed in the terminal output and saved as a CSV file (```performance_metrics.csv```) which you can export to spreadsheets.
+
+# Getting Started
+Please follow the ```README.md``` in each subfolder for instructions on how to setup and run these benchmarks. 
+

+ 38 - 0
benchmarks/inference/on-prem/README.md

@@ -0,0 +1,38 @@
+# Llama-On-Prem-Benchmark
+This folder contains code to run inference benchmark for Llama 2 models on-prem with popular serving frameworks.
+The benchmark will focus on overall inference **throughput** for running containers on one instance (single or multiple GPUs) that you can acquire from cloud service providers such as Azure and AWS. You can also run this benchmark on local laptop or desktop.  
+We support benchmark on these serving framework:
+* [vLLM](https://github.com/vllm-project/vllm)
+
+
+# vLLM - Getting Started
+To get started, we first need to deploy containers on-prem as a API host. Follow the guidance [here](https://github.com/facebookresearch/llama-recipes/blob/main/demo_apps/llama-on-prem.md#setting-up-vllm-with-llama-2) to deploy vLLM on-prem.
+Note that in common scenario which overall throughput is important, we suggest you prioritize deploying as many model replicas as possible to reach higher overall throughput and request-per-second (RPS), comparing to deploy one model container among multiple GPUs for model parallelism. Additionally, as deploying multiple model replicas, there is a need for a higher level wrapper to handle the load balancing which here has been simulated in the benchmark scripts.  
+For example, we have an instance from Azure that has 8xA100 80G GPUs, and we want to deploy the Llama 2 70B chat model, which is around 140GB with FP16. So for deployment we can do:
+* 1x70B model parallel on 8 GPUs, each GPU RAM takes around 17.5GB for loading model weights.
+* 2x70B models each use 4 GPUs, each GPU RAM takes around 35GB for loading model weights.
+* 4x70B models each use 2 GPUs, each GPU RAM takes around 70GB for loading model weights. (Preferred configuration for max overall throughput. Note that you will have 4 endpoints hosted on different ports and the benchmark script will route requests into each model equally)
+
+Here are examples for deploying 2x70B chat models over 8 GPUs with vLLM.
+```
+CUDA_VISIBLE_DEVICES=0,1,2,3 python -m vllm.entrypoints.openai.api_server  --model meta-llama/Llama-2-70b-chat-hf --tensor-parallel-size 4 --disable-log-requests --port 8000 
+CUDA_VISIBLE_DEVICES=4,5,6,7 python -m vllm.entrypoints.openai.api_server  --model meta-llama/Llama-2-70b-chat-hf --tensor-parallel-size 4 --disable-log-requests --port 8001 
+```
+Once you have finished deployment, you can use the command below to run benchmark scripts in a separate terminal. 
+
+```
+python chat_vllm_benchmark.py
+```
+<!-- markdown-link-check-disable -->
+If you are going to use [Azure AI content check](https://azure.microsoft.com/en-us/products/ai-services/ai-content-safety), then you should install dependencies as shown below in your terminal:
+<!-- markdown-link-check-enable -->
+```
+pip install azure-ai-contentsafety azure-core
+```
+Besides chat models, we also provide benchmark scripts for running pretrained models for text completion tasks. To better simulate the real traffic, we generate configurable random token prompt as input. In this process, we select vocabulary that is longer than 2 tokens so the generated words are closer to the English, rather than symbols.
+However, random token prompts can't be applied for chat model benchmarks, since the chat model expects a valid question. By feeding random prompts, chat models rarely provide answers that is meeting our ```MAX_NEW_TOKEN``` requirement, defeating the purpose of running throughput benchmarks. Hence for chat models, the questions are copied over to form long inputs such as for 2k and 4k inputs.   
+To run pretrained model benchmark, follow the command below.
+```
+python pretrained_vllm_benchmark.py
+```
+

+ 205 - 0
benchmarks/inference/on-prem/vllm/chat_vllm_benchmark.py

@@ -0,0 +1,205 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import csv
+import json
+import time
+import random
+import threading
+import numpy as np
+import requests
+import transformers
+import torch
+
+# Imports for Azure content safety
+from azure.ai.contentsafety import ContentSafetyClient
+from azure.core.credentials import AzureKeyCredential
+from azure.core.exceptions import HttpResponseError
+from azure.ai.contentsafety.models import AnalyzeTextOptions
+
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Dict, Tuple, List
+
+
+
+with open('input.jsonl') as input:
+    prompt_data = json.load(input)
+
+# Prompt data stored in json file. Choose from number of tokens - 5, 25, 50, 100, 500, 1k, 2k.
+# You can also configure and add your own prompt in input.jsonl
+PROMPT = prompt_data["1k"] 
+
+with open('parameters.json') as parameters:
+    params = json.load(parameters)
+
+MAX_NEW_TOKENS = params["MAX_NEW_TOKENS"]
+CONCURRENT_LEVELS = params["CONCURRENT_LEVELS"]
+# Replace with your own deployment
+MODEL_PATH = params["MODEL_PATH"]
+MODEL_HEADERS = params["MODEL_HEADERS"]
+SAFE_CHECK = params["SAFE_CHECK"]
+# Threshold for tokens per second below which we deem the query to be slow
+THRESHOLD_TPS = params["THRESHOLD_TPS"] 
+# Default Llama tokenizer, replace with your own tokenizer 
+TOKENIZER_PATH = params["TOKENIZER_PATH"] 
+TEMPERATURE = params["TEMPERATURE"]
+TOP_P = params["TOP_P"]
+# Add your model endpoints here, specify the port number. You can acquire the endpoint when creating a on-prem server like vLLM.
+# Group of model endpoints - Send balanced requests to each endpoint for batch maximization.  
+MODEL_ENDPOINTS = params["MODEL_ENDPOINTS"]
+
+# Get number of GPUs on this instance
+if torch.cuda.is_available():
+    NUM_GPU = torch.cuda.device_count()
+else:
+    print("No available GPUs")
+
+
+# This tokenizer is downloaded from Azure model catalog for each specific models. The main purpose is to decode the reponses for token calculation
+tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH)
+
+num_token_input_prompt = len(tokenizer.encode(PROMPT))
+print(f"Number of token for input prompt: {num_token_input_prompt}")
+
+# Azure content safety analysis
+def analyze_prompt(input):
+    start_time = time.time()
+
+    # Obtain credentials
+    key = "" #Add your AZURE_CONTENT_SAFETY_KEY
+    endpoint = "" #Add your AZURE_CONTENT_SAFETY_ENDPOINT
+
+    # Create a content safety client
+    client = ContentSafetyClient(endpoint, AzureKeyCredential(key))
+
+    # Create request
+    request = AnalyzeTextOptions(text=input)
+
+    # Analyze prompt
+    try:
+        response = client.analyze_text(request)
+    except HttpResponseError as e:
+        print("prompt failed due to content safety filtering.")
+        if e.error:
+            print(f"Error code: {e.error.code}")
+            print(f"Error message: {e.error.message}")
+            raise
+        print(e)
+        raise
+
+    analyze_end_time = time.time()
+    # The round trip latency for using Azure content safety check
+    analyze_latency = (analyze_end_time - start_time) * 1000
+
+
+# Simple round-robin to dispatch requests into different containers
+executor_id = 0
+lock = threading.Lock()
+
+def generate_text() -> Tuple[int, int]:
+    headers = MODEL_HEADERS
+    payload = {
+        "model" : MODEL_PATH,
+        "messages" : [
+            {
+                "role": "user",
+                "content": PROMPT
+            }
+        ],
+        "stream" : False,
+        "temperature" : TEMPERATURE,
+        "top_p" : TOP_P,
+        "max_tokens" : MAX_NEW_TOKENS
+    }
+
+    start_time = time.time()
+
+    if(SAFE_CHECK):
+        # Function to send prompts for safety check. Add delays for request round-trip that count towards overall throughput measurement.
+        # Expect NO returns from calling this function. If you want to check the safety check results, print it out within the function itself.
+        analyze_prompt(PROMPT)
+        # Or add delay simulation if you don't want to use Azure Content Safety check. The API round-trip for this check is around 0.3-0.4 seconds depends on where you located. You can use something like this: time.sleep(random.uniform(0.3, 0.4))
+
+    # Acquire lock to dispatch the request
+    lock.acquire()
+    global executor_id
+    if executor_id != len(MODEL_ENDPOINTS)-1:
+        executor_id += 1
+        endpoint_id = executor_id
+    else:
+        executor_id = 0
+        endpoint_id = executor_id
+    lock.release()
+
+    # Send request
+    response = requests.post(MODEL_ENDPOINTS[endpoint_id], headers=headers, json=payload)
+
+    if(SAFE_CHECK):
+        # Function to send prompts for safety check. Add delays for request round-trip that count towards overall throughput measurement.
+        # Expect NO returns from calling this function. If you want to check the safety check results, print it out within the function itself.
+        analyze_prompt(PROMPT)
+        # Or add delay simulation if you don't want to use Azure Content Safety check. The API round-trip for this check is around 0.3-0.4 seconds depends on where you located. You can use something like this: time.sleep(random.uniform(0.3, 0.4))
+
+    end_time = time.time()
+    # Convert to ms
+    latency = (end_time - start_time) * 1000  
+
+    if response.status_code != 200:
+        raise ValueError(f"Error: {response.content}")
+    output = json.loads(response.content)["choices"][0]["message"]["content"]
+
+    token_count = len(tokenizer.encode(output))
+    return latency, token_count
+
+
+def evaluate_performance(concurrent_requests: int) -> Tuple[float, float, float, float, float, float, float, List[float]]:
+    latencies = []
+    total_output_tokens = 0
+    output_tokens_per_second_each_request = []
+    start_time = time.time()
+
+    # Init multi-thread execution 
+    with ThreadPoolExecutor(max_workers=concurrent_requests) as executor:
+        future_to_req = {executor.submit(generate_text): i for i in range(concurrent_requests)}
+        for future in as_completed(future_to_req):
+            latency, token_count = future.result()
+            latencies.append(latency)
+            total_output_tokens += token_count
+            # Calculate tokens per second for this request
+            tokens_per_sec = token_count / (latency / 1000)
+            output_tokens_per_second_each_request.append(tokens_per_sec)
+
+    end_time = time.time()
+    total_time = end_time - start_time
+    # RPS (requests per second)
+    rps = concurrent_requests / total_time  
+    # Overall tokens per second
+    output_tokens_per_second_overall = total_output_tokens / total_time  
+    input_tokens_per_second_overall = (num_token_input_prompt * concurrent_requests) / total_time
+    output_tokens_per_second_per_gpu = output_tokens_per_second_overall / NUM_GPU
+    input_tokens_per_second_per_gpu = input_tokens_per_second_overall / NUM_GPU
+    p50_latency = np.percentile(latencies, 50)
+    p99_latency = np.percentile(latencies, 99)
+
+    # Count the number of requests below the token-per-second threshold
+    below_threshold_count = sum(1 for tps in output_tokens_per_second_each_request if tps < THRESHOLD_TPS)
+    output_tokens_per_second_per_request = sum(output_tokens_per_second_each_request)/len(output_tokens_per_second_each_request)
+
+    return p50_latency, p99_latency, rps, output_tokens_per_second_overall, output_tokens_per_second_per_gpu, input_tokens_per_second_overall, input_tokens_per_second_per_gpu, output_tokens_per_second_per_request, below_threshold_count
+
+
+
+# Print markdown
+print("| Number of Concurrent Requests | P50 Latency (ms) | P99 Latency (ms) | RPS | Output Tokens per Second | Output Tokens per Second per GPU | Input Tokens per Second | Input Tokens per Second per GPU |Average Output Tokens per Second per Request | Number of Requests Below Threshold |")
+print("|-------------------------------|------------------|------------------|------------------|-------------------|---------------------------|---------------------|------------------------|-------------------------------------- | ---------------------------------- |")
+
+# Save to file
+csv_file = "performance_metrics.csv"
+with open(csv_file, "w", newline='') as f:
+    writer = csv.writer(f)
+    writer.writerow(["Number of Concurrent Requests", "P50 Latency (ms)", "P99 Latency (ms)", "RPS", "Output Tokens per Second", "Output Tokens per Second per GPU", "Input Tokens per Second", "Input Tokens per Second per GPU", "Average Output Tokens per Second per Request"])
+
+    for level in CONCURRENT_LEVELS:
+        p50_latency, p99_latency, rps, output_tokens_per_second_overall, output_tokens_per_second_per_gpu, input_tokens_per_second_overall, input_tokens_per_second_per_gpu, output_tokens_per_second_per_request, below_threshold_count = evaluate_performance(level)
+        print(f"| {level} | {p50_latency:.2f} | {p99_latency:.2f} | {rps:.2f} | {output_tokens_per_second_overall:.2f} | {output_tokens_per_second_per_gpu:.2f} | {input_tokens_per_second_overall:.2f} | {input_tokens_per_second_per_gpu:.2f} | {output_tokens_per_second_per_request:.2f} | {below_threshold_count:.2f} |")
+        writer.writerow([level, round(p50_latency, 2), round(p99_latency, 2), round(rps, 2), round(output_tokens_per_second_overall, 2), round(output_tokens_per_second_per_gpu, 2), round(input_tokens_per_second_overall, 2), round(input_tokens_per_second_per_gpu, 2), round(output_tokens_per_second_per_request, 2)])

Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 9 - 0
benchmarks/inference/on-prem/vllm/input.jsonl


+ 15 - 0
benchmarks/inference/on-prem/vllm/parameters.json

@@ -0,0 +1,15 @@
+{
+    "MAX_NEW_TOKENS" : 256,
+    "CONCURRENT_LEVELS" : [1, 2, 4, 8, 16, 32, 64, 128, 256],
+    "MODEL_PATH" : "meta-llama/Llama-2-7b-chat-hf",
+    "MODEL_HEADERS" : {"Content-Type": "application/json"},
+    "SAFE_CHECK" : true,
+    "THRESHOLD_TPS" : 7,
+    "TOKENIZER_PATH" : "../../tokenizer",
+    "RANDOM_PROMPT_LENGTH" : 1000,
+    "TEMPERATURE" : 0.6,
+    "TOP_P" : 0.9,
+    "MODEL_ENDPOINTS" : [
+        "http://localhost:8000/v1/chat/completions"
+    ]
+}

+ 215 - 0
benchmarks/inference/on-prem/vllm/pretrained_vllm_benchmark.py

@@ -0,0 +1,215 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import csv
+import json
+import time
+import random
+import threading
+import numpy as np
+import requests
+import transformers
+import torch
+
+#imports for Azure content safety
+from azure.ai.contentsafety import ContentSafetyClient
+from azure.core.credentials import AzureKeyCredential
+from azure.core.exceptions import HttpResponseError
+from azure.ai.contentsafety.models import AnalyzeTextOptions
+
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from typing import Dict, Tuple, List
+
+
+# Predefined inputs
+with open('input.jsonl') as input:
+    prompt_data = json.load(input)
+
+with open('parameters.json') as parameters:
+    params = json.load(parameters)
+
+MAX_NEW_TOKENS = params["MAX_NEW_TOKENS"]
+CONCURRENT_LEVELS = params["CONCURRENT_LEVELS"]
+# Replace with your own deployment
+MODEL_PATH = params["MODEL_PATH"]
+MODEL_HEADERS = params["MODEL_HEADERS"]
+SAFE_CHECK = params["SAFE_CHECK"]
+# Threshold for tokens per second below which we deem the query to be slow
+THRESHOLD_TPS = params["THRESHOLD_TPS"] 
+# Replace with your own tokenizer 
+TOKENIZER_PATH = params["TOKENIZER_PATH"] 
+RANDOM_PROMPT_LENGTH = params["RANDOM_PROMPT_LENGTH"]
+TEMPERATURE = params["TEMPERATURE"]
+TOP_P = params["TOP_P"]
+# Add your model endpoints here, specify the port number. You can acquire the endpoint when creating a on-prem server like vLLM.
+# Group of model endpoints - Send balanced requests to each endpoint for batch maximization.  
+MODEL_ENDPOINTS = params["MODEL_ENDPOINTS"]
+
+#Get number of GPUs on this instance
+if torch.cuda.is_available():
+    NUM_GPU = torch.cuda.device_count()
+else:
+    print("No available GPUs")
+
+
+# This tokenizer is downloaded from Azure model catalog for each specific models. The main purpose is to decode the reponses for token calculation
+tokenizer = transformers.AutoTokenizer.from_pretrained(TOKENIZER_PATH)
+
+# Select vocabulary that is longer than 2 tokens (closer to real words) and close to the English (not foolproof)
+vocab = [token for token in tokenizer.get_vocab().keys() if len(token) > 2 and all(ord(c) < 128 for c in token)]
+
+def generate_random_prompt(num_tokens):
+    generated_tokens_count = 0
+    selected_tokens = ""
+    while generated_tokens_count < num_tokens:
+        selected_tokens += random.choice(vocab)
+        selected_tokens += " "
+        generated_tokens_count = len(tokenizer.encode(selected_tokens))
+
+    return selected_tokens
+
+PROMPT = generate_random_prompt(RANDOM_PROMPT_LENGTH)
+num_token_input_prompt = len(tokenizer.encode(PROMPT))
+print(f"Number of token for input prompt: {num_token_input_prompt}")
+
+
+# Azure content safety analysis
+def analyze_prompt(input):
+    start_time = time.time()
+
+    # Obtain credentials
+    key = "" #Add your AZURE_CONTENT_SAFETY_KEY
+    endpoint = "" #Add your AZURE_CONTENT_SAFETY_ENDPOINT
+
+    # Create a content safety client
+    client = ContentSafetyClient(endpoint, AzureKeyCredential(key))
+
+    # Create request
+    request = AnalyzeTextOptions(text=input)
+
+    # Analyze prompt
+    try:
+        response = client.analyze_text(request)
+    except HttpResponseError as e:
+        print("prompt failed due to content safety filtering.")
+        if e.error:
+            print(f"Error code: {e.error.code}")
+            print(f"Error message: {e.error.message}")
+            raise
+        print(e)
+        raise
+
+    analyze_end_time = time.time()
+    # The round trip latency for using Azure content safety check
+    analyze_latency = (analyze_end_time - start_time) * 1000
+
+
+# Simple round-robin to dispatch requests into different containers
+executor_id = 0
+lock = threading.Lock()
+
+def generate_text() -> Tuple[int, int]:
+    headers = MODEL_HEADERS
+    payload = {
+        "model" : MODEL_PATH,
+        "messages" : [
+            {
+                "role": "user",
+                "content": PROMPT
+            }
+        ],
+        "stream" : False,
+        "temperature" : TEMPERATURE,
+        "top_p" : TOP_P,
+        "max_tokens" : MAX_NEW_TOKENS
+    }
+
+    start_time = time.time()
+
+    if(SAFE_CHECK):
+        # Function to send prompts for safety check. Add delays for request round-trip that count towards overall throughput measurement.
+        # Expect NO returns from calling this function. If you want to check the safety check results, print it out within the function itself.
+        analyze_prompt(PROMPT)
+        # Or add delay simulation if you don't want to use Azure Content Safety check. The API round-trip for this check is around 0.3-0.4 seconds depends on where you located. You can use something like this: time.sleep(random.uniform(0.3, 0.4))
+
+    lock.acquire()
+    global executor_id
+    if executor_id != len(MODEL_ENDPOINTS)-1:
+        executor_id += 1
+        endpoint_id = executor_id
+    else:
+        executor_id = 0
+        endpoint_id = executor_id
+    lock.release()
+
+    response = requests.post(MODEL_ENDPOINTS[endpoint_id], headers=headers, json=payload)
+
+    if(SAFE_CHECK):
+        # Function to send prompts for safety check. Add delays for request round-trip that count towards overall throughput measurement.
+        # Expect NO returns from calling this function. If you want to check the safety check results, print it out within the function itself.
+        analyze_prompt(PROMPT)
+        # Or add delay simulation if you don't want to use Azure Content Safety check. The API round-trip for this check is around 0.3-0.4 seconds depends on where you located. You can use something like this: time.sleep(random.uniform(0.3, 0.4))
+
+    end_time = time.time()
+    # Convert to ms
+    latency = (end_time - start_time) * 1000 
+
+    if response.status_code != 200:
+        raise ValueError(f"Error: {response.content}")
+    output = json.loads(response.content)["choices"][0]["message"]["content"]
+
+    token_count = len(tokenizer.encode(output))
+    return latency, token_count
+
+
+def evaluate_performance(concurrent_requests: int) -> Tuple[float, float, float, float, float, float, float, List[float]]:
+    latencies = []
+    total_output_tokens = 0
+    output_tokens_per_second_each_request = []
+    start_time = time.time()
+
+    # Init multi-thread execution 
+    with ThreadPoolExecutor(max_workers=concurrent_requests) as executor:
+        future_to_req = {executor.submit(generate_text): i for i in range(concurrent_requests)}
+        for future in as_completed(future_to_req):
+            latency, token_count = future.result()
+            latencies.append(latency)
+            total_output_tokens += token_count
+            # Calculate tokens per second for this request
+            tokens_per_sec = token_count / (latency / 1000)
+            output_tokens_per_second_each_request.append(tokens_per_sec)
+
+    end_time = time.time()
+    total_time = end_time - start_time
+    # RPS (requests per second)
+    rps = concurrent_requests / total_time  
+    # Overall tokens per second
+    output_tokens_per_second_overall = total_output_tokens / total_time  
+    input_tokens_per_second_overall = (num_token_input_prompt * concurrent_requests) / total_time
+    output_tokens_per_second_per_gpu = output_tokens_per_second_overall / NUM_GPU
+    input_tokens_per_second_per_gpu = input_tokens_per_second_overall / NUM_GPU
+    p50_latency = np.percentile(latencies, 50)
+    p99_latency = np.percentile(latencies, 99)
+
+    # Count the number of requests below the token-per-second threshold
+    below_threshold_count = sum(1 for tps in output_tokens_per_second_each_request if tps < THRESHOLD_TPS)
+    output_tokens_per_second_per_request = sum(output_tokens_per_second_each_request)/len(output_tokens_per_second_each_request)
+
+    return p50_latency, p99_latency, rps, output_tokens_per_second_overall, output_tokens_per_second_per_gpu, input_tokens_per_second_overall, input_tokens_per_second_per_gpu, output_tokens_per_second_per_request, below_threshold_count
+
+
+
+# Print markdown
+print("| Number of Concurrent Requests | P50 Latency (ms) | P99 Latency (ms) | RPS | Output Tokens per Second | Output Tokens per Second per GPU | Input Tokens per Second | Input Tokens per Second per GPU |Average Output Tokens per Second per Request | Number of Requests Below Threshold |")
+print("|-------------------------------|------------------|------------------|------------------|-------------------|---------------------------|---------------------|------------------------|-------------------------------------- | ---------------------------------- |")
+
+# Save to file
+csv_file = "performance_metrics.csv"
+with open(csv_file, "w", newline='') as f:
+    writer = csv.writer(f)
+    writer.writerow(["Number of Concurrent Requests", "P50 Latency (ms)", "P99 Latency (ms)", "RPS", "Output Tokens per Second", "Output Tokens per Second per GPU", "Input Tokens per Second", "Input Tokens per Second per GPU", "Average Output Tokens per Second per Request"])
+
+    for level in CONCURRENT_LEVELS:
+        p50_latency, p99_latency, rps, output_tokens_per_second_overall, output_tokens_per_second_per_gpu, input_tokens_per_second_overall, input_tokens_per_second_per_gpu, output_tokens_per_second_per_request, below_threshold_count = evaluate_performance(level)
+        print(f"| {level} | {p50_latency:.2f} | {p99_latency:.2f} | {rps:.2f} | {output_tokens_per_second_overall:.2f} | {output_tokens_per_second_per_gpu:.2f} | {input_tokens_per_second_overall:.2f} | {input_tokens_per_second_per_gpu:.2f} | {output_tokens_per_second_per_request:.2f} | {below_threshold_count:.2f} |")
+        writer.writerow([level, round(p50_latency, 2), round(p99_latency, 2), round(rps, 2), round(output_tokens_per_second_overall, 2), round(output_tokens_per_second_per_gpu, 2), round(input_tokens_per_second_overall, 2), round(input_tokens_per_second_per_gpu, 2), round(output_tokens_per_second_per_request, 2)])

+ 23 - 0
benchmarks/inference/tokenizer/special_tokens_map.json

@@ -0,0 +1,23 @@
+{
+  "bos_token": {
+    "content": "<s>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  },
+  "eos_token": {
+    "content": "</s>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  },
+  "unk_token": {
+    "content": "<unk>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  }
+}

Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 93391 - 0
benchmarks/inference/tokenizer/tokenizer.json


BIN
benchmarks/inference/tokenizer/tokenizer.model


+ 35 - 0
benchmarks/inference/tokenizer/tokenizer_config.json

@@ -0,0 +1,35 @@
+{
+  "add_bos_token": true,
+  "add_eos_token": false,
+  "bos_token": {
+    "__type": "AddedToken",
+    "content": "<s>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  },
+  "clean_up_tokenization_spaces": false,
+  "eos_token": {
+    "__type": "AddedToken",
+    "content": "</s>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  },
+  "legacy": true,
+  "use_default_system_prompt": false,
+  "model_max_length": 1000000000000000019884624838656,
+  "pad_token": null,
+  "sp_model_kwargs": {},
+  "tokenizer_class": "LlamaTokenizerFast",
+  "unk_token": {
+    "__type": "AddedToken",
+    "content": "<unk>",
+    "lstrip": false,
+    "normalized": true,
+    "rstrip": false,
+    "single_word": false
+  }
+}

+ 610 - 0
demo_apps/Azure_API_example/azure_api_example.ipynb

@@ -0,0 +1,610 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Use Azure API with Llama 2\n",
+    "\n",
+    "This notebook shows examples of how to use Llama 2 APIs offered by Microsoft Azure. We will cover:  \n",
+    "* HTTP requests API usage for Llama 2 pretrained and chat models in CLI\n",
+    "* HTTP requests API usage for Llama 2 pretrained and chat models in Python\n",
+    "* Plug the APIs into LangChain\n",
+    "* Wire the model with Gradio to build a simple chatbot with memory\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Prerequisite\n",
+    "\n",
+    "Before we start building with Azure Llama 2 APIs, there are certain steps we need to take to deploy the models:\n",
+    "\n",
+    "* Register for a valid Azure account with subscription [here](https://azure.microsoft.com/en-us/free/search/?ef_id=_k_CjwKCAiA-P-rBhBEEiwAQEXhH5OHAJLhzzcNsuxwpa5c9EJFcuAjeh6EvZw4afirjbWXXWkiZXmU2hoC5GoQAvD_BwE_k_&OCID=AIDcmm5edswduu_SEM__k_CjwKCAiA-P-rBhBEEiwAQEXhH5OHAJLhzzcNsuxwpa5c9EJFcuAjeh6EvZw4afirjbWXXWkiZXmU2hoC5GoQAvD_BwE_k_&gad_source=1&gclid=CjwKCAiA-P-rBhBEEiwAQEXhH5OHAJLhzzcNsuxwpa5c9EJFcuAjeh6EvZw4afirjbWXXWkiZXmU2hoC5GoQAvD_BwE)\n",
+    "* Take a quick look on what is the [Azure AI Studio](https://learn.microsoft.com/en-us/azure/ai-studio/what-is-ai-studio?tabs=home) and navigate to the website from the link in the article\n",
+    "* Follow the demos in the article to create a project and [resource](https://learn.microsoft.com/en-us/azure/azure-resource-manager/management/manage-resource-groups-portal) group, or you can also follow the guide [here](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-llama?tabs=azure-studio)\n",
+    "* Select Llama models from Model catalog\n",
+    "* Deploy with \"Pay-as-you-go\"\n",
+    "\n",
+    "Once deployed successfully, you should be assigned for an API endpoint and a security key for inference.  \n",
+    "\n",
+    "For more information, you should consult Azure's official documentation [here](https://learn.microsoft.com/en-us/azure/ai-studio/how-to/deploy-models-llama?tabs=azure-studio) for model deployment and inference."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## HTTP Requests API Usage in CLI\n",
+    "\n",
+    "### Basics\n",
+    "\n",
+    "For using the REST API, You will need to have an Endpoint url and Authentication Key associated with that endpoint.  \n",
+    "This can be acquired from previous steps.  \n",
+    "\n",
+    "In this text completion example for pre-trained model, we use a simple curl call for illustration. There are three major components:  \n",
+    "\n",
+    "* The `host-url` is your endpoint url with completion schema. \n",
+    "* The `headers` defines the content type as well as your api key. \n",
+    "* The `payload` or `data`, which is your prompt detail and model hyper parameters."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!curl -X POST -L https://your-endpoint.inference.ai.azure.com/v1/completions -H 'Content-Type: application/json' -H 'Authorization: your-auth-key' -d '{\"prompt\": \"Math is a\", \"max_tokens\": 30, \"temperature\": 0.7}' "
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "For chat completion, the API schema and request payload are slightly different.\n",
+    "\n",
+    "The `host-url` needs to be `/v1/chat/completions` and the request payload to include roles in conversations. Here is a sample payload:  \n",
+    "\n",
+    "```\n",
+    "{ \n",
+    "  \"messages\": [ \n",
+    "    { \n",
+    "      \"content\": \"You are a helpful assistant.\", \n",
+    "      \"role\": \"system\" \n",
+    "},  \n",
+    "    { \n",
+    "      \"content\": \"Hello!\", \n",
+    "      \"role\": \"user\" \n",
+    "    } \n",
+    "  ], \n",
+    "  \"max_tokens\": 50, \n",
+    "} \n",
+    "```\n",
+    "\n",
+    "Here is a sample curl call for chat completion"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!curl -X POST -L https://your-endpoint.inference.ai.azure.com/v1/chat/completions -H 'Content-Type: application/json' -H 'Authorization: your-auth-key' -d '{\"messages\":[{\"content\":\"You are a helpful assistant.\",\"role\":\"system\"},{\"content\":\"Who wrote the book Innovators dilemma?\",\"role\":\"user\"}], \"max_tokens\": 50}'"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "If you compare the generation result for both text and chat completion API calls, you will notice that:  \n",
+    "\n",
+    "* Text completion returns a list of `choices` for the input prompt, each contains generated text and completion information such as `logprobs`.\n",
+    "* Chat completion returns a list of `choices` each with a `message` object with completion result, matching the `messages` object in the request.  \n",
+    "\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Streaming\n",
+    "\n",
+    "One fantastic feature the API offers is the streaming capability.  \n",
+    "Streaming allows the generated tokens to be sent as data-only server-sent events whenever they become available.  \n",
+    "This is extremely important for interactive applications such as chatbots, so the user is always engaged.  \n",
+    "\n",
+    "To use streaming, simply set `\"stream\":\"True\"` as part of the request payload.  \n",
+    "In the streaming mode, the REST API response will be different from non-streaming mode.\n",
+    "\n",
+    "Here is an example: "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!curl -X POST -L https://your-endpoint.inference.ai.azure.com/v1/chat/completions -H 'Content-Type: application/json' -H 'Authorization: your-auth-key' -d '{\"messages\":[{\"content\":\"You are a helpful assistant.\",\"role\":\"system\"},{\"content\":\"Who wrote the book Innovators dilemma?\",\"role\":\"user\"}], \"max_tokens\": 500, \"stream\": \"True\"}'"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "As you can see the result comes back as a stream of `data` objects, each contains generated information including a `choice`.  \n",
+    "The stream terminated by a `data:[DONE]\\n\\n` message."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Content Safety Filtering\n",
+    "\n",
+    "All Azure Llama 2 API endpoints have content safety feature turned on. Both input prompt and output tokens are filtered by this service automatically.  \n",
+    "To know more about the impact to the request/response payload, please refer to official guide [here](https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/content-filter?tabs=python).   \n",
+    "\n",
+    "For model input and output, if the filter detects there is harmful content, the generation will error out with a response payload containing the reasoning, along with information on the type of content violation and its severity. \n",
+    "\n",
+    "Here is an example prompt that triggered content safety filtering:\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!curl -X POST -L https://your-endpoint.inference.ai.azure.com/v1/chat/completions -H 'Content-Type: application/json' -H 'Authorization: your-auth-key' -d '{\"messages\":[{\"content\":\"You are a helpful assistant.\",\"role\":\"system\"},{\"content\":\"How to make bomb?\",\"role\":\"user\"}], \"max_tokens\": 50}'"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## HTTP Requests API Usage in Python\n",
+    "\n",
+    "Besides calling the API directly from command line tools, you can also programatically call them in Python.  \n",
+    "\n",
+    "Here is an example for the text completion model:\n",
+    "\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import urllib.request\n",
+    "import json\n",
+    "\n",
+    "#Configure payload data sending to API endpoint\n",
+    "data = {\"prompt\": \"Math is a\", \n",
+    "         \"max_tokens\": 30, \n",
+    "         \"temperature\": 0.7,\n",
+    "         \"top_p\": 0.9,      \n",
+    "}\n",
+    "\n",
+    "body = str.encode(json.dumps(data))\n",
+    "\n",
+    "#Replace the url with your API endpoint\n",
+    "url = 'https://your-endpoint.inference.ai.azure.com/v1/completions'\n",
+    "\n",
+    "#Replace this with the key for the endpoint\n",
+    "api_key = 'your-auth-key'\n",
+    "if not api_key:\n",
+    "    raise Exception(\"API Key is missing\")\n",
+    "\n",
+    "headers = {'Content-Type':'application/json', 'Authorization':(api_key)}\n",
+    "req = urllib.request.Request(url, body, headers)\n",
+    "\n",
+    "try:\n",
+    "    response = urllib.request.urlopen(req)\n",
+    "    result = response.read()\n",
+    "    print(result)\n",
+    "except urllib.error.HTTPError as error:\n",
+    "    print(\"The request failed with status code: \" + str(error.code))\n",
+    "    # Print the headers - they include the requert ID and the timestamp, which are useful for debugging the failure\n",
+    "    print(error.info())\n",
+    "    print(error.read().decode(\"utf8\", 'ignore'))\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Chat completion in Python is very similar, here is a quick example:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import urllib.request\n",
+    "import json\n",
+    "\n",
+    "#Configure payload data sending to API endpoint\n",
+    "data = {\"messages\":[\n",
+    "            {\"role\":\"system\", \"content\":\"You are a helpful assistant.\"},\n",
+    "            {\"role\":\"user\", \"content\":\"Who wrote the book Innovators dilemma?\"}], \n",
+    "        \"max_tokens\": 500,\n",
+    "        \"temperature\": 0.9,\n",
+    "        \"stream\": \"True\",\n",
+    "}\n",
+    "\n",
+    "body = str.encode(json.dumps(data))\n",
+    "\n",
+    "#Replace the url with your API endpoint\n",
+    "url = 'https://your-endpoint.inference.ai.azure.com/v1/chat/completions'\n",
+    "\n",
+    "#Replace this with the key for the endpoint\n",
+    "api_key = 'your-auth-key'\n",
+    "if not api_key:\n",
+    "    raise Exception(\"API Key is missing\")\n",
+    "\n",
+    "headers = {'Content-Type':'application/json', 'Authorization':(api_key)}\n",
+    "\n",
+    "req = urllib.request.Request(url, body, headers)\n",
+    "\n",
+    "try:\n",
+    "    response = urllib.request.urlopen(req)\n",
+    "    result = response.read()\n",
+    "    print(result)\n",
+    "except urllib.error.HTTPError as error:\n",
+    "    print(\"The request failed with status code: \" + str(error.code))\n",
+    "    # Print the headers - they include the requert ID and the timestamp, which are useful for debugging the failure\n",
+    "    print(error.info())\n",
+    "    print(error.read().decode(\"utf8\", 'ignore'))\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "However in this example, the streamed data content returns back as a single payload. It didn't stream as a serial of data events as we wished. To build true streaming capabilities utilizing the API endpoint, we will utilize the [`requests`](https://requests.readthedocs.io/en/latest/) library instead."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Streaming in Python\n",
+    "\n",
+    "`Requests` library is a simple HTTP library for Python built with [`urllib3`](https://github.com/urllib3/urllib3). It automatically maintains the keep-alive and HTTP connection pooling. With the `Session` class, we can easily stream the result from our API calls.  \n",
+    "\n",
+    "Here is a quick example:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import json\n",
+    "import requests\n",
+    "\n",
+    "data = {\"messages\":[\n",
+    "            {\"role\":\"system\", \"content\":\"You are a helpful assistant.\"},\n",
+    "            {\"role\":\"user\", \"content\":\"Who wrote the book Innovators dilemma?\"}],\n",
+    "        \"max_tokens\": 500,\n",
+    "        \"temperature\": 0.9,\n",
+    "        \"stream\": \"True\"\n",
+    "}\n",
+    "\n",
+    "\n",
+    "def post_stream(url):\n",
+    "    s = requests.Session()\n",
+    "    api_key = \"your-auth-key\"\n",
+    "    headers = {'Content-Type':'application/json', 'Authorization':(api_key)}\n",
+    "\n",
+    "    with s.post(url, data=json.dumps(data), headers=headers, stream=True) as resp:\n",
+    "        print(resp.status_code)\n",
+    "        for line in resp.iter_lines():\n",
+    "            if line:\n",
+    "                print(line)\n",
+    "\n",
+    "\n",
+    "url = \"https://your-endpoint.inference.ai.azure.com/v1/chat/completions\"\n",
+    "post_stream(url)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Use Llama 2 API with LangChain\n",
+    "\n",
+    "In this section, we will demonstrate how to use Llama 2 APIs with LangChain, one of the most popular framework to accelerate building your AI product.  \n",
+    "One common solution here is to create your customized LLM instance, so you can add it to various chains to complete different tasks.  \n",
+    "In this example, we will use the `AzureMLOnlineEndpoint` class LangChain provides to build a customized LLM instance. This particular class is designed to take in Azure endpoint and API keys as inputs and wire it with HTTP calls. So the underlying of it is very similar to how we used `urllib.request` library to send RESTful calls in previous examples to the Azure Endpoint.   \n",
+    "\n",
+    "Note Azure is working on a standard solution for LangChain integration in this [PR](https://github.com/langchain-ai/langchain/pull/14560), you should consider migrating to that in the future. \n",
+    "\n",
+    "First, let's install dependencies: \n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pip install langchain"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Once all dependencies are installed, you can directly create a `llm` instance based on `AzureMLOnlineEndpoint` as follows:  "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.llms.azureml_endpoint import AzureMLOnlineEndpoint, ContentFormatterBase\n",
+    "from typing import Dict\n",
+    "import json\n",
+    "\n",
+    "\n",
+    "class AzureLlamaAPIContentFormatter(ContentFormatterBase):\n",
+    "#Content formatter for Llama 2 API for Azure MaaS\n",
+    "\n",
+    "    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:\n",
+    "        #Formats the request according to the chosen api\n",
+    "        prompt = ContentFormatterBase.escape_special_characters(prompt)\n",
+    "        request_payload_dict = {\n",
+    "                \"messages\": [\n",
+    "                    {\"role\":\"system\", \"content\":\"You are a helpful assistant\"},\n",
+    "                    {\"role\":\"user\", \"content\":f\"{prompt}\"}\n",
+    "                    ]               \n",
+    "            }\n",
+    "        #Add model parameters as part of the dict\n",
+    "        request_payload_dict.update(model_kwargs)\n",
+    "        request_payload = json.dumps(request_payload_dict)\n",
+    "        return str.encode(request_payload)\n",
+    "\n",
+    "    def format_response_payload(self, output: bytes) -> str:\n",
+    "        #Formats response\n",
+    "        return json.loads(output)[\"choices\"][0][\"message\"][\"content\"]\n",
+    "\n",
+    "\n",
+    "content_formatter = AzureLlamaAPIContentFormatter()\n",
+    "\n",
+    "llm = AzureMLOnlineEndpoint(\n",
+    "    endpoint_api_key=\"your-auth-key\",\n",
+    "    endpoint_url=\"https://your-endpoint.inference.ai.azure.com/v1/chat/completions\",\n",
+    "    model_kwargs={\"temperature\": 0.6, \"max_tokens\": 512, \"top_p\": 0.9},\n",
+    "    content_formatter=content_formatter,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "However, you might wonder what is the `content_formatter` in the context when creating the `llm` instance?   \n",
+    "The `content_formatter` parameter is a [handler class](https://python.langchain.com/docs/integrations/llms/azure_ml#content-formatter) for transforming the request and response of an AzureML endpoint to match with required schema. Since there are various models in the Azure model catalog, each of which needs to handle the data accordingly.  \n",
+    "In our case, all current formatters provided by Langchain including `LLamaContentFormatter` don't follow the schema. So we created our own customized formatter called `AzureLlamaAPIContentFormatter` to handle the input and output data.  \n",
+    "\n",
+    "Once you have the `llm` ready, you can simple inference it by:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "print(llm(\"Who wrote the book Innovators dilemma?\"))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Here is an example that you can create a translator chain with the `llm` instance and translate English to French:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.chains import LLMChain\n",
+    "from langchain.prompts import PromptTemplate\n",
+    "\n",
+    "template = \"\"\"\n",
+    "You are a Translator. Translate the following content from {input_language} to {output_language} and reply with only the translated result.\n",
+    "{input_content}\n",
+    "\"\"\"\n",
+    "\n",
+    "translator_chain = LLMChain(\n",
+    "    llm = llm,\n",
+    "    prompt = PromptTemplate(\n",
+    "            template=template,\n",
+    "            input_variables=[\"input_language\", \"output_language\", \"input_content\"],\n",
+    "        ),\n",
+    ")\n",
+    "\n",
+    "print(translator_chain.run(input_language=\"English\", output_language=\"French\", input_content=\"Who wrote the book Innovators dilemma?\"))\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "At the time of writing this sample notebook, LangChain doesn't support streaming with `AzureMLOnlineEndpoint` for Llama 2. We are working with LangChain and Azure team to implement that."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Build a chatbot with Llama 2 API\n",
+    "\n",
+    "In this section, we will build a simple chatbot using Azure Llama 2 API, LangChain and [Gradio](https://www.gradio.app/)'s `ChatInterface` with memory capability.\n",
+    "\n",
+    "Gradio is a framework to help demo your machine learning model with a web interface. We also have a dedicated Gradio chatbot [example](https://github.com/facebookresearch/llama-recipes/tree/main/demo_apps/RAG_Chatbot_example) built with Llama 2 on-premises with RAG.   \n",
+    "\n",
+    "First, let's install Gradio dependencies.\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "pip install gradio"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Let's use `AzureMLOnlineEndpoint` class from the previous example.  \n",
+    "In this example, we have three major components:  \n",
+    "1. Chatbot UI hosted as web interface by Gradio. These are the UI logics that render our model predictions.\n",
+    "2. Model itself, which is the core component that ingests prompts and returns an answer back.\n",
+    "3. Memory component, which stores previous conversation context. In this example, we will use [conversation window buffer](https://python.langchain.com/docs/modules/memory/types/buffer_window) which logs context in certain time window in the past. \n",
+    "\n",
+    "All of them are chained together using LangChain."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import gradio as gr\n",
+    "from langchain.chains import ConversationChain\n",
+    "from langchain.prompts import PromptTemplate\n",
+    "from langchain.llms.azureml_endpoint import AzureMLOnlineEndpoint, ContentFormatterBase\n",
+    "from langchain.memory import ConversationBufferWindowMemory\n",
+    "\n",
+    "import langchain\n",
+    "from typing import Dict\n",
+    "import json\n",
+    "\n",
+    "langchain.debug=True\n",
+    "\n",
+    "class AzureLlamaAPIContentFormatter(ContentFormatterBase):\n",
+    "#Content formatter for Llama 2 API for Azure MaaS\n",
+    "\n",
+    "    def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:\n",
+    "        #Formats the request according to the chosen api\n",
+    "        prompt = ContentFormatterBase.escape_special_characters(prompt)\n",
+    "\n",
+    "        #Note how we instructed the model with system prompts. Past conversation can be past as in system prompt as well\n",
+    "        request_payload_dict = {\n",
+    "                \"messages\": [\n",
+    "                    {\"role\":\"system\", \"content\":\"The following is a conversation between a user and you. Answer the user question based on the conversation. Provide your answer only\"},\n",
+    "                    {\"role\":\"user\", \"content\":f\"{prompt}\"}\n",
+    "                    ]               \n",
+    "            }\n",
+    "        request_payload_dict.update(model_kwargs)\n",
+    "        request_payload = json.dumps(request_payload_dict)\n",
+    "        return str.encode(request_payload)\n",
+    "\n",
+    "    def format_response_payload(self, output: bytes) -> str:\n",
+    "        #Formats response\n",
+    "        return json.loads(output)[\"choices\"][0][\"message\"][\"content\"]\n",
+    "\n",
+    "#Create content fomartter\n",
+    "content_formatter = AzureLlamaAPIContentFormatter()\n",
+    "\n",
+    "#Create llm instance\n",
+    "llm = AzureMLOnlineEndpoint(\n",
+    "    endpoint_api_key=\"your-auth-key\",\n",
+    "    endpoint_url=\"https://your-endpoint.inference.ai.azure.com/v1/chat/completions\",\n",
+    "    model_kwargs={\"temperature\": 0.6, \"max_tokens\": 128, \"top_p\": 0.9},\n",
+    "    content_formatter=content_formatter,\n",
+    ")\n",
+    "\n",
+    "#Create memory\n",
+    "memory = ConversationBufferWindowMemory(llm=llm, k=5, memory_key=\"chat_history\", ai_prefix=\"Assistant\", human_prefix=\"User\")\n",
+    "\n",
+    "#Create input prompt template with chat history for chaining\n",
+    "INPUT_TEMPLATE = \"\"\"Current conversation:\n",
+    "{chat_history}\n",
+    "\n",
+    "User question:{input}\"\"\"\n",
+    "\n",
+    "conversation_prompt_template = PromptTemplate(\n",
+    "    input_variables=[\"chat_history\", \"input\"], template=INPUT_TEMPLATE\n",
+    ")\n",
+    "\n",
+    "conversation_chain_with_memory = ConversationChain(\n",
+    "    llm = llm,\n",
+    "    prompt = conversation_prompt_template,\n",
+    "    verbose = True,\n",
+    "    memory = memory,\n",
+    ")\n",
+    "\n",
+    "#Prediction\n",
+    "def predict(message, history):\n",
+    "    history_format = []\n",
+    "    for user, assistant in history:\n",
+    "        history_format.append({\"role\": \"user\", \"content\": user })\n",
+    "        history_format.append({\"role\": \"assistant\", \"content\":assistant})\n",
+    "    history_format.append({\"role\": \"user\", \"content\": message})\n",
+    "    response = conversation_chain_with_memory.run(input=message)\n",
+    "    return response\n",
+    "\n",
+    "#Launch Gradio chatbot interface\n",
+    "gr.ChatInterface(predict).launch()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "After successfully executing the code above, a chat interface should appear as the interactive output or you can open the localhost url in your selected browser window.  \n",
+    "\n",
+    "This concludes our tutorial and examples. Here are some additional reference:  \n",
+    "* [Fine-tune Llama](https://learn.microsoft.com/azure/ai-studio/how-to/fine-tune-model-llama)\n",
+    "* [Plan and manage costs (marketplace)](https://learn.microsoft.com/azure/ai-studio/how-to/costs-plan-manage#monitor-costs-for-models-offered-through-the-azure-marketplace)\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.10"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}

Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 433 - 0
demo_apps/HelloLlamaCloud.ipynb


Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 347 - 0
demo_apps/HelloLlamaLocal.ipynb


+ 306 - 0
demo_apps/LiveData.ipynb

@@ -0,0 +1,306 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "30eb1704-8d76-4bc9-9308-93243aeb69cb",
+   "metadata": {},
+   "source": [
+    "## This demo app shows:\n",
+    "* How to use LlamaIndex, an open source library to help you build custom data augmented LLM applications\n",
+    "* How to ask Llama questions about recent live data via the You.com live search API and LlamaIndex\n",
+    "\n",
+    "The LangChain package is used to facilitate the call to Llama2 hosted on Replicate\n",
+    "\n",
+    "**Note** We will be using Replicate to run the examples here. You will need to first sign in with Replicate with your github account, then create a free API token [here](https://replicate.com/account/api-tokens) that you can use for a while. \n",
+    "After the free trial ends, you will need to enter billing info to continue to use Llama2 hosted on Replicate."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "68cf076e",
+   "metadata": {},
+   "source": [
+    "We start by installing the necessary packages:\n",
+    "- [langchain](https://python.langchain.com/docs/get_started/introduction) which provides RAG capabilities\n",
+    "- [llama-index](https://docs.llamaindex.ai/en/stable/) for data augmentation."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1d0005d6-e928-4d1a-981b-534a40e19e56",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install llama-index langchain"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "21fe3849",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# use ServiceContext to configure the LLM used and the custom embeddings \n",
+    "from llama_index import ServiceContext\n",
+    "\n",
+    "# VectorStoreIndex is used to index custom data \n",
+    "from llama_index import VectorStoreIndex\n",
+    "\n",
+    "from langchain.llms import Replicate"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "73e8e661",
+   "metadata": {},
+   "source": [
+    "Next we set up the Replicate token."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d9d76e33",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from getpass import getpass\n",
+    "import os\n",
+    "\n",
+    "REPLICATE_API_TOKEN = getpass()\n",
+    "os.environ[\"REPLICATE_API_TOKEN\"] = REPLICATE_API_TOKEN"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f8ff812b",
+   "metadata": {},
+   "source": [
+    "In this example we will use the [YOU.com](https://you.com/) search engine to augment the LLM's responses.\n",
+    "To use the You.com Search API, you can email api@you.com to request an API key. "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "75275628-5235-4b55-8033-601c76107528",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "YOUCOM_API_KEY = getpass()\n",
+    "os.environ[\"YOUCOM_API_KEY\"] = YOUCOM_API_KEY"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "cb210c7c",
+   "metadata": {},
+   "source": [
+    "We then call the Llama 2 model from replicate. \n",
+    "\n",
+    "We will use the llama 2 13b chat model. You can find more Llama 2 models by searching for them on the [Replicate model explore page](https://replicate.com/explore?query=llama).\n",
+    "You can add them here in the format: model_name/version"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c12fc2cb",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# set llm to be using Llama2 hosted on Replicate\n",
+    "llama2_13b_chat = \"meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d\"\n",
+    "\n",
+    "llm = Replicate(\n",
+    "    model=llama2_13b_chat,\n",
+    "    model_kwargs={\"temperature\": 0.01, \"top_p\": 1, \"max_new_tokens\":500}\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "476d72da",
+   "metadata": {},
+   "source": [
+    "Using our api key we set up earlier, we make a request from YOU.com for live data on a particular topic."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "effc9656-b18d-4d24-a80b-6066564a838b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "import requests\n",
+    "\n",
+    "query = \"Meta Connect\" # you can try other live data query about sports score, stock market and weather info \n",
+    "headers = {\"X-API-Key\": os.environ[\"YOUCOM_API_KEY\"]}\n",
+    "data = requests.get(\n",
+    "    f\"https://api.ydc-index.io/search?query={query}\",\n",
+    "    headers=headers,\n",
+    ").json()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8bed3baf-742e-473c-ada1-4459012a8a2c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# check the query result in JSON\n",
+    "import json\n",
+    "\n",
+    "print(json.dumps(data, indent=2))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b196e697",
+   "metadata": {},
+   "source": [
+    "We then use the [`JSONLoader`](https://llamahub.ai/l/file-json) to extract the text from the returned data. The `JSONLoader` gives us the ability to load the data into LamaIndex.\n",
+    "In the next cell we show how to load the JSON result with key info stored as \"snippets\".\n",
+    "\n",
+    "However, you can also add the snippets in the query result to documents like below:\n",
+    "```python \n",
+    "from llama_index import Document\n",
+    "snippets = [snippet for hit in data[\"hits\"] for snippet in hit[\"snippets\"]]\n",
+    "documents = [Document(text=s) for s in snippets]\n",
+    "```\n",
+    "This can be handy if you just need to add a list of text strings to doc"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7c40e73f-ca13-4f4a-a753-e613df3d389e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# one way to load the JSON result with key info stored as \"snippets\"\n",
+    "from llama_index import download_loader\n",
+    "\n",
+    "JsonDataReader = download_loader(\"JsonDataReader\")\n",
+    "loader = JsonDataReader()\n",
+    "documents = loader.load_data([hit[\"snippets\"] for hit in data[\"hits\"]])\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8e5e3b4e",
+   "metadata": {},
+   "source": [
+    "With the data set up, we create a vector store for the data and a query engine for it.\n",
+    "\n",
+    "For our embeddings we will use `HuggingFaceEmbeddings` whose default embedding model is sentence-transformers/all-mpnet-base-v2. This model provides a good balance between speed and performance.\n",
+    "To change the default model, call `HuggingFaceEmbeddings(model_name=<another_embedding_model>)`. \n",
+    "\n",
+    "For more info see https://huggingface.co/blog/mteb. "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a5de3080-2c4b-479c-baba-793b3bee36ed",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# use HuggingFace embeddings \n",
+    "from langchain.embeddings.huggingface import HuggingFaceEmbeddings\n",
+    "from llama_index import LangchainEmbedding\n",
+    "\n",
+    "\n",
+    "embeddings = LangchainEmbedding(HuggingFaceEmbeddings())\n",
+    "print(embeddings)\n",
+    "\n",
+    "# create a ServiceContext instance to use Llama2 and custom embeddings\n",
+    "service_context = ServiceContext.from_defaults(llm=llm, chunk_size=800, chunk_overlap=20, embed_model=embeddings)\n",
+    "\n",
+    "# create vector store index from the documents created above\n",
+    "index = VectorStoreIndex.from_documents(documents, service_context=service_context)\n",
+    "\n",
+    "# create query engine from the index\n",
+    "query_engine = index.as_query_engine(streaming=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2c4ea012",
+   "metadata": {},
+   "source": [
+    "We are now ready to ask Llama 2 a question about the live data using our query engine."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "de91a191-d0f2-498e-88dc-b2b43423e0e5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# ask Llama2 a summary question about the search result\n",
+    "response = query_engine.query(\"give me a summary\")\n",
+    "response.print_response_stream()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "72814b20-06aa-4da8-b4dd-f0b0d74a2ea0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# more questions\n",
+    "query_engine.query(\"what products were announced\").print_response_stream()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a65bc037-a689-476d-b529-0059a27bc949",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "query_engine.query(\"tell me more about Meta AI assistant\").print_response_stream()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "16a56542",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "query_engine.query(\"what are Generative AI stickers\").print_response_stream()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.18"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}

+ 120 - 0
demo_apps/Llama2_Gradio.ipynb

@@ -0,0 +1,120 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "47a9adb3",
+   "metadata": {},
+   "source": [
+    "## This demo app shows how to query Llama 2 using the Gradio UI.\n",
+    "\n",
+    "Since we are using Replicate in this example, you will need to replace `<your replicate api token>` with your API token.\n",
+    "\n",
+    "To get the Replicate token: \n",
+    "\n",
+    "- You will need to first sign in with Replicate with your github account\n",
+    "- Then create a free API token [here](https://replicate.com/account/api-tokens) that you can use for a while \n",
+    "\n",
+    "**Note** After the free trial ends, you will need to enter billing info to continue to use Llama2 hosted on Replicate.\n",
+    "\n",
+    "To run this example:\n",
+    "- Set up your Replicate API token and enter it in place of `<your replicate api token>`\n",
+    "- Run the notebook\n",
+    "- Enter your question and click Submit\n",
+    "\n",
+    "In the notebook or a browser with URL http://127.0.0.1:7860 you should see a UI with your answer."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "928041cc",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Init param `input` is deprecated, please use `model_kwargs` instead.\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Running on local URL:  http://127.0.0.1:7860\n",
+      "\n",
+      "To create a public link, set `share=True` in `launch()`.\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "<div><iframe src=\"http://127.0.0.1:7860/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": []
+     },
+     "execution_count": 1,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from langchain.schema import AIMessage, HumanMessage\n",
+    "import gradio as gr\n",
+    "from langchain.llms import Replicate\n",
+    "import os\n",
+    "\n",
+    "os.environ[\"REPLICATE_API_TOKEN\"] = \"<your replicate api token>\"\n",
+    "\n",
+    "llama2_13b_chat = \"meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d\"\n",
+    "\n",
+    "llm = Replicate(\n",
+    "    model=llama2_13b_chat,\n",
+    "    model_kwargs={\"temperature\": 0.01, \"top_p\": 1, \"max_new_tokens\":500}\n",
+    ")\n",
+    "\n",
+    "\n",
+    "def predict(message, history):\n",
+    "    history_langchain_format = []\n",
+    "    for human, ai in history:\n",
+    "        history_langchain_format.append(HumanMessage(content=human))\n",
+    "        history_langchain_format.append(AIMessage(content=ai))\n",
+    "    history_langchain_format.append(HumanMessage(content=message))\n",
+    "    gpt_response = llm(message) #history_langchain_format)\n",
+    "    return gpt_response#.content\n",
+    "\n",
+    "gr.ChatInterface(predict).launch()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.18"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}

Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 723 - 0
demo_apps/RAG_Chatbot_example/RAG_Chatbot_Example.ipynb


BIN
demo_apps/RAG_Chatbot_example/data/Llama Getting Started Guide.pdf


+ 6 - 0
demo_apps/RAG_Chatbot_example/requirements.txt

@@ -0,0 +1,6 @@
+gradio
+pypdf
+langchain
+sentence-transformers
+faiss-cpu
+text-generation

BIN
demo_apps/RAG_Chatbot_example/vectorstore/db_faiss/index.faiss


BIN
demo_apps/RAG_Chatbot_example/vectorstore/db_faiss/index.pkl


Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 117 - 0
demo_apps/README.md


Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 559 - 0
demo_apps/StructuredLlama.ipynb


Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 698 - 0
demo_apps/VideoSummary.ipynb


+ 38 - 0
demo_apps/csv2db.py

@@ -0,0 +1,38 @@
+import sqlite3
+import csv
+
+# Define the input CSV file and the SQLite database file
+input_csv = 'nba_roster.csv'
+database_file = 'nba_roster.db'
+
+# Connect to the SQLite database
+conn = sqlite3.connect(database_file)
+cursor = conn.cursor()
+
+# Create a table to store the data
+cursor.execute('''CREATE TABLE IF NOT EXISTS nba_roster (
+                    Team TEXT,
+                    NAME TEXT,
+                    Jersey TEXT,
+                    POS TEXT,
+                    AGE INT,
+                    HT TEXT,
+                    WT TEXT,
+                    COLLEGE TEXT,
+                    SALARY TEXT
+                )''')
+
+# Read data from the CSV file and insert it into the SQLite table
+with open(input_csv, 'r', newline='') as csvfile:
+    csv_reader = csv.reader(csvfile)
+    next(csv_reader)  # Skip the header row
+    
+    for row in csv_reader:
+        cursor.execute('INSERT INTO nba_roster VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', row)
+
+# Commit the changes and close the database connection
+conn.commit()
+conn.close()
+
+print(f'Data from {input_csv} has been successfully imported into {database_file}')
+

Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 186 - 0
demo_apps/llama-on-prem.md


BIN
demo_apps/llama2-gradio.png


BIN
demo_apps/llama2-streamlit.png


BIN
demo_apps/llama2-streamlit2.png


BIN
demo_apps/llama2.pdf


+ 61 - 0
demo_apps/llama_chatbot.py

@@ -0,0 +1,61 @@
+import langchain
+from langchain.llms import Replicate
+
+from flask import Flask
+from flask import request
+import os
+import requests
+import json
+
+class WhatsAppClient:
+
+    API_URL = "https://graph.facebook.com/v17.0/"
+    WHATSAPP_API_TOKEN = "<Temporary access token from your WhatsApp API Setup>"
+    WHATSAPP_CLOUD_NUMBER_ID = "<Phone number ID from your WhatsApp API Setup>"
+
+    def __init__(self):
+        self.headers = {
+            "Authorization": f"Bearer {self.WHATSAPP_API_TOKEN}",
+            "Content-Type": "application/json",
+        }
+        self.API_URL = self.API_URL + self.WHATSAPP_CLOUD_NUMBER_ID
+
+    def send_text_message(self,message, phone_number):
+        payload = {
+            "messaging_product": 'whatsapp',
+            "to": phone_number,
+            "type": "text",
+            "text": {
+                "preview_url": False,
+                "body": message
+            }
+        }
+        response = requests.post(f"{self.API_URL}/messages", json=payload,headers=self.headers)
+        print(response.status_code)
+        assert response.status_code == 200, "Error sending message"
+        return response.status_code
+
+os.environ["REPLICATE_API_TOKEN"] = "<your replicate api token>"    
+llama2_13b_chat = "meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d"
+
+llm = Replicate(
+    model=llama2_13b_chat,
+    model_kwargs={"temperature": 0.01, "top_p": 1, "max_new_tokens":500}
+)
+client = WhatsAppClient()
+app = Flask(__name__)
+
+@app.route("/")
+def hello_llama():
+    return "<p>Hello Llama 2</p>"
+
+@app.route('/msgrcvd', methods=['POST', 'GET'])
+def msgrcvd():    
+    message = request.args.get('message')
+    #client.send_template_message("hello_world", "en_US", "14086745477")
+    answer = llm(message)
+    print(message)
+    print(answer)
+    client.send_text_message(llm(message), "14086745477")
+    return message + "<p/>" + answer
+

+ 45 - 0
demo_apps/llama_messenger.py

@@ -0,0 +1,45 @@
+import langchain
+from langchain.llms import Replicate
+
+from flask import Flask
+from flask import request
+import os
+import requests
+import json
+
+os.environ["REPLICATE_API_TOKEN"] = "<your replicate api token>"
+llama2_13b_chat = "meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d"
+
+llm = Replicate(
+    model=llama2_13b_chat,
+    model_kwargs={"temperature": 0.01, "top_p": 1, "max_new_tokens":500}
+)
+
+app = Flask(__name__)
+
+@app.route('/msgrcvd_pager', methods=['POST', 'GET'])
+def msgrcvd_pager():    
+    message = request.args.get('message')
+    sender = request.args.get('sender')
+    recipient = request.args.get('recipient')
+
+    answer = llm(message)
+    print(message)
+    print(answer)
+
+    url = f"https://graph.facebook.com/v18.0/{recipient}/messages"
+    params = {
+        'recipient': '{"id": ' + sender + '}',
+        'message': json.dumps({'text': answer}),
+        'messaging_type': 'RESPONSE',
+        'access_token': "<your page access token>"
+    }
+    headers = {
+        'Content-Type': 'application/json'
+    }
+    response = requests.post(url, params=params, headers=headers)
+    print(response.status_code)
+    print(response.text)
+
+    return message + "<p/>" + answer
+

BIN
demo_apps/messenger_api_settings.png


Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 194 - 0
demo_apps/messenger_llama2.md


BIN
demo_apps/messenger_llama_arch.jpg


Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 1294 - 0
demo_apps/nba.txt


+ 22 - 0
demo_apps/streamlit_llama2.py

@@ -0,0 +1,22 @@
+import streamlit as st
+from langchain.llms import Replicate
+import os
+
+st.title("Llama2-powered Streamlit App")
+
+with st.sidebar:
+    os.environ["REPLICATE_API_TOKEN"] = "<your replicate api token>"
+
+def generate_response(input_text):
+    llama2_13b_chat = "meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d"
+
+    llm = Replicate(
+        model=llama2_13b_chat,
+        model_kwargs={"temperature": 0.01, "top_p": 1, "max_new_tokens":500}
+    )
+    st.info(llm(input_text))
+
+with st.form("my_form"):
+    text = st.text_area("Enter text:", "What is Generative AI?")
+    submitted = st.form_submit_button("Submit")
+    generate_response(text)

+ 53 - 0
demo_apps/txt2csv.py

@@ -0,0 +1,53 @@
+import csv
+
+# Define the input and output file names
+input_file = 'nba.txt'
+output_file = 'nba_roster.csv'
+
+# Initialize lists to store data
+roster_data = []
+current_team = None
+
+# Open the input file
+with open(input_file, 'r') as file:
+    for line in file:
+        # Remove leading and trailing whitespaces from the line
+        line = line.strip()
+        
+        # Check if the line starts with 'https', skip it
+        if line.startswith('https'):
+            continue
+        
+        # Check if the line contains the team name
+        if 'Roster' in line:
+            current_team = line.split(' Roster ')[0]
+        elif line and "NAME" not in line:  # Skip empty lines and header lines
+            # Split the line using tabs as the delimiter
+            player_info = line.split('\t')
+            
+            # Remove any numbers from the player's name and set Jersey accordingly
+            name = ''.join([c for c in player_info[0] if not c.isdigit()])
+            jersey = ''.join([c for c in player_info[0] if c.isdigit()])
+            
+            # If no number found, set Jersey to "NA"
+            if not jersey:
+                jersey = "NA"
+            
+            # Append the team name, name, and jersey to the player's data
+            player_info = [current_team, name, jersey] + player_info[1:]
+            
+            # Append the player's data to the roster_data list
+            roster_data.append(player_info)
+
+# Write the data to a CSV file
+with open(output_file, 'w', newline='') as csvfile:
+    writer = csv.writer(csvfile)
+    
+    # Write the header row
+    writer.writerow(['Team', 'NAME', 'Jersey', 'POS', 'AGE', 'HT', 'WT', 'COLLEGE', 'SALARY'])
+    
+    # Write the player data
+    writer.writerows(roster_data)
+
+print(f'Conversion completed. Data saved to {output_file}')
+

BIN
demo_apps/whatsapp_dashboard.jpg


Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 162 - 0
demo_apps/whatsapp_llama2.md


BIN
demo_apps/whatsapp_llama_arch.jpg


+ 15 - 3
docs/Dataset.md

@@ -7,6 +7,18 @@ The provided fine tuning script allows you to select between three datasets by p
 * [samsum_dataset](https://huggingface.co/datasets/samsum) contains about 16k messenger-like conversations with summaries.
 * [OpenAssistant/oasst1](https://huggingface.co/datasets/OpenAssistant/oasst1/) contains about 88k messages from assistant-style conversations.
 
+## Batching Strategies
+Llama-recipes support two strategies to batch requests together.
+The default setting is `packing` which concatenates the tokenized samples into long sequences filling up the context length of the model.
+This is the most compute efficient variant as it avoids any padding and all sequences have the same length.
+Samples at the boundary of the context length are truncated and the remainder of the cut sequence it used as the start of the next long sequence.
+
+If the amount of training data is small this procedure might introduce a lot of noise into the training data which can hurt the prediction performance of the fine-tune model.
+Therefore, we also support a `padding` strategy which does not introduce the addition noise due to truncated sequences.
+The strategy tries to minimize the efficiency loss by batching samples of similar length together so only minimal padding is necessary.
+
+The batching strategy can be selected though the command line parameter `--batching_strategy [packing]/[padding]`.
+
 ## Using custom datasets
 
 The list of available datasets in llama-recipes is supposed to give users a quick start on training their Llama model.
@@ -23,9 +35,9 @@ def get_custom_dataset(dataset_config, tokenizer, split: str):
 For an example `get_custom_dataset` you can look at the provided datasets in llama_recipes.datasets or [examples/custom_dataset.py](../examples/custom_dataset.py).
 The `dataset_config` in the above signature will be an instance of llama_recipes.configs.dataset.custom_dataset with the modifications made through the command line.
 The split signals wether to return the training or validation dataset.
-The default function name is `get_custom_dataset` but this can be changes as described below.
+The default function name is `get_custom_dataset` but this can be changed as described below.
 
-In order to start a training with the custom dataset we need to set the `--dataset` as well as the `--custom_dataset.file` parameter. 
+In order to start a training with the custom dataset we need to set the `--dataset` as well as the `--custom_dataset.file` parameter.
 ```
 python -m llama_recipes.finetuning --dataset "custom_dataset" --custom_dataset.file "examples/custom_dataset.py" [TRAINING PARAMETERS]
 ```
@@ -35,7 +47,7 @@ python -m llama_recipes.finetuning --dataset "custom_dataset" --custom_dataset.f
 ```
 This will call the function `get_foo` instead of `get_custom_dataset` when retrieving the dataset.
 
-### Adding new dataset 
+### Adding new dataset
 Each dataset has a corresponding configuration (dataclass) in [configs/datasets.py](../src/llama_recipes/configs/datasets.py) which contains the dataset name, training/validation split names, as well as optional parameters like datafiles etc.
 
 Additionally, there is a preprocessing function for each dataset in the [datasets](../src/llama_recipes/datasets) folder.

Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 21 - 7
docs/FAQ.md


+ 16 - 2
docs/inference.md

@@ -41,14 +41,14 @@ model.resize_token_embeddings(model.config.vocab_size + 1)
 ```
 Padding would be required for batch inference. In this this [example](../examples/inference.py), batch size = 1 so essentially padding is not required. However,We added the code pointer as an example in case of batch inference.
 
-**Chat completion**
+### Chat completion
 The inference folder also includes a chat completion example, that adds built-in safety features in fine-tuned models to the prompt tokens. To run the example:
 
 ```bash
 python examples/chat_completion/chat_completion.py --model_name "PATH/TO/MODEL/7B/" --prompt_file examples/chat_completion/chats.json  --quantization --use_auditnlg
 
 ```
-**Code Llama**
+### Code Llama
 
 Code llama was recently released with three flavors, base-model that support multiple programming languages, Python fine-tuned model and an instruction fine-tuned and aligned variation of Code Llama, please read more [here](https://ai.meta.com/blog/code-llama-large-language-model-coding/). Also note that the Python fine-tuned model and 34B models are not trained on infilling objective, hence can not be used for infilling use-case.
 
@@ -80,6 +80,18 @@ python examples/code_llama/code_infilling_example.py --model_name MODEL_NAME --p
 
 ```
 
+### Llama Guard
+
+Llama Guard is a new experimental model that provides input and output guardrails for LLM deployments. For more details, please visit the main [repository](https://github.com/facebookresearch/PurpleLlama/tree/main/Llama-Guard).
+
+Find the inference script for Llama Guard [here](../examples/llama_guard/).
+
+**Note** Please find the right model on HF side [here](https://huggingface.co/meta-llama/LlamaGuard-7b). 
+
+Edit [inference.py](../examples/llama_guard/inference.py) to add test prompts for Llama Guard and execute it with this command:
+
+`python examples/llama_guard/inference.py`
+
 ## Flash Attention and Xformer Memory Efficient Kernels
 
 Setting `use_fast_kernels` will enable using of Flash Attention or Xformer memory-efficient kernels based on the hardware being used. This would speed up inference when used for batched inputs. This has been enabled in `optimum` library from HuggingFace as a one-liner API, please read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/).
@@ -144,3 +156,5 @@ python examples/vllm/inference.py --model_name <PATH/TO/MODEL/7B>
 ```
 
 [**TGI**](https://github.com/huggingface/text-generation-inference): Text Generation Inference (TGI) is another inference option available to you. For more information on how to set up and use TGI see [here](../examples/hf_text_generation_inference/README.md).
+
+[Here](../demo_apps/llama-on-prem.md) is a complete tutorial on how to use vLLM and TGI to deploy Llama 2 on-prem and interact with the Llama API services.

+ 784 - 0
examples/Prompt_Engineering_with_Llama_2.ipynb

@@ -0,0 +1,784 @@
+{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Prompt Engineering with Llama 2\n",
+    "\n",
+    "Prompt engineering is using natural language to produce a desired response from a large language model (LLM).\n",
+    "\n",
+    "This interactive guide covers prompt engineering & best practices with Llama 2."
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Introduction"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Why now?\n",
+    "\n",
+    "[Vaswani et al. (2017)](https://arxiv.org/abs/1706.03762) introduced the world to transformer neural networks (originally for machine translation). Transformers ushered an era of generative AI with diffusion models for image creation and large language models (`LLMs`) as **programmable deep learning networks**.\n",
+    "\n",
+    "Programming foundational LLMs is done with natural language – it doesn't require training/tuning like ML models of the past. This has opened the door to a massive amount of innovation and a paradigm shift in how technology can be deployed. The science/art of using natural language to program language models to accomplish a task is referred to as **Prompt Engineering**."
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Llama Models\n",
+    "\n",
+    "In 2023, Meta introduced the [Llama language models](https://ai.meta.com/llama/) (Llama Chat, Code Llama, Llama Guard). These are general purpose, state-of-the-art LLMs.\n",
+    "\n",
+    "Llama 2 models come in 7 billion, 13 billion, and 70 billion parameter sizes. Smaller models are cheaper to deploy and run (see: deployment and performance); larger models are more capable.\n",
+    "\n",
+    "#### Llama 2\n",
+    "1. `llama-2-7b` - base pretrained 7 billion parameter model\n",
+    "1. `llama-2-13b` - base pretrained 13 billion parameter model\n",
+    "1. `llama-2-70b` - base pretrained 70 billion parameter model\n",
+    "1. `llama-2-7b-chat` - chat fine-tuned 7 billion parameter model\n",
+    "1. `llama-2-13b-chat` - chat fine-tuned 13 billion parameter model\n",
+    "1. `llama-2-70b-chat` - chat fine-tuned 70 billion parameter model (flagship)\n"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Code Llama is a code-focused LLM built on top of Llama 2 also available in various sizes and finetunes:"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Code Llama\n",
+    "1. `codellama-7b` - code fine-tuned 7 billion parameter model\n",
+    "1. `codellama-13b` - code fine-tuned 13 billion parameter model\n",
+    "1. `codellama-34b` - code fine-tuned 34 billion parameter model\n",
+    "1. `codellama-7b-instruct` - code & instruct fine-tuned 7 billion parameter model\n",
+    "2. `codellama-13b-instruct` - code & instruct fine-tuned 13 billion parameter model\n",
+    "3. `codellama-34b-instruct` - code & instruct fine-tuned 34 billion parameter model\n",
+    "1. `codellama-7b-python` - Python fine-tuned 7 billion parameter model\n",
+    "2. `codellama-13b-python` - Python fine-tuned 13 billion parameter model\n",
+    "3. `codellama-34b-python` - Python fine-tuned 34 billion parameter model"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Getting an LLM\n",
+    "\n",
+    "Large language models are deployed and accessed in a variety of ways, including:\n",
+    "\n",
+    "1. **Self-hosting**: Using local hardware to run inference. Ex. running Llama 2 on your Macbook Pro using [llama.cpp](https://github.com/ggerganov/llama.cpp).\n",
+    "    * Best for privacy/security or if you already have a GPU.\n",
+    "1. **Cloud hosting**: Using a cloud provider to deploy an instance that hosts a specific model. Ex. running Llama 2 on cloud providers like AWS, Azure, GCP, and others.\n",
+    "    * Best for customizing models and their runtime (ex. fine-tuning a model for your use case).\n",
+    "1. **Hosted API**: Call LLMs directly via an API. There are many companies that provide Llama 2 inference APIs including AWS Bedrock, Replicate, Anyscale, Together and others.\n",
+    "    * Easiest option overall."
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Hosted APIs\n",
+    "\n",
+    "Hosted APIs are the easiest way to get started. We'll use them here. There are usually two main endpoints:\n",
+    "\n",
+    "1. **`completion`**: generate a response to a given prompt (a string).\n",
+    "1. **`chat_completion`**: generate the next message in a list of messages, enabling more explicit instruction and context for use cases like chatbots."
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Tokens\n",
+    "\n",
+    "LLMs process inputs and outputs in chunks called *tokens*. Think of these, roughly, as words – each model will have its own tokenization scheme. For example, this sentence...\n",
+    "\n",
+    "> Our destiny is written in the stars.\n",
+    "\n",
+    "...is tokenized into `[\"our\", \"dest\", \"iny\", \"is\", \"written\", \"in\", \"the\", \"stars\"]` for Llama 2.\n",
+    "\n",
+    "Tokens matter most when you consider API pricing and internal behavior (ex. hyperparameters).\n",
+    "\n",
+    "Each model has a maximum context length that your prompt cannot exceed. That's 4096 tokens for Llama 2 and 100K for Code Llama. \n"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Notebook Setup\n",
+    "\n",
+    "The following APIs will be used to call LLMs throughout the guide. As an example, we'll call Llama 2 chat using [Replicate](https://replicate.com/meta/llama-2-70b-chat) and use LangChain to easily set up a chat completion API.\n",
+    "\n",
+    "To install prerequisites run:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "pip install langchain replicate"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from typing import Dict, List\n",
+    "from langchain.llms import Replicate\n",
+    "from langchain.memory import ChatMessageHistory\n",
+    "from langchain.schema.messages import get_buffer_string\n",
+    "import os\n",
+    "\n",
+    "# Get a free API key from https://replicate.com/account/api-tokens\n",
+    "os.environ[\"REPLICATE_API_TOKEN\"] = \"YOUR_KEY_HERE\"\n",
+    "\n",
+    "LLAMA2_70B_CHAT = \"meta/llama-2-70b-chat:2d19859030ff705a87c746f7e96eea03aefb71f166725aee39692f1476566d48\"\n",
+    "LLAMA2_13B_CHAT = \"meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d\"\n",
+    "\n",
+    "# We'll default to the smaller 13B model for speed; change to LLAMA2_70B_CHAT for more advanced (but slower) generations\n",
+    "DEFAULT_MODEL = LLAMA2_13B_CHAT\n",
+    "\n",
+    "def completion(\n",
+    "    prompt: str,\n",
+    "    model: str = DEFAULT_MODEL,\n",
+    "    temperature: float = 0.6,\n",
+    "    top_p: float = 0.9,\n",
+    ") -> str:\n",
+    "    llm = Replicate(\n",
+    "        model=model,\n",
+    "        model_kwargs={\"temperature\": temperature,\"top_p\": top_p, \"max_new_tokens\": 1000}\n",
+    "    )\n",
+    "    return llm(prompt)\n",
+    "\n",
+    "def chat_completion(\n",
+    "    messages: List[Dict],\n",
+    "    model = DEFAULT_MODEL,\n",
+    "    temperature: float = 0.6,\n",
+    "    top_p: float = 0.9,\n",
+    ") -> str:\n",
+    "    history = ChatMessageHistory()\n",
+    "    for message in messages:\n",
+    "        if message[\"role\"] == \"user\":\n",
+    "            history.add_user_message(message[\"content\"])\n",
+    "        elif message[\"role\"] == \"assistant\":\n",
+    "            history.add_ai_message(message[\"content\"])\n",
+    "        else:\n",
+    "            raise Exception(\"Unknown role\")\n",
+    "    return completion(\n",
+    "        get_buffer_string(\n",
+    "            history.messages,\n",
+    "            human_prefix=\"USER\",\n",
+    "            ai_prefix=\"ASSISTANT\",\n",
+    "        ),\n",
+    "        model,\n",
+    "        temperature,\n",
+    "        top_p,\n",
+    "    )\n",
+    "\n",
+    "def assistant(content: str):\n",
+    "    return { \"role\": \"assistant\", \"content\": content }\n",
+    "\n",
+    "def user(content: str):\n",
+    "    return { \"role\": \"user\", \"content\": content }\n",
+    "\n",
+    "def complete_and_print(prompt: str, model: str = DEFAULT_MODEL):\n",
+    "    print(f'==============\\n{prompt}\\n==============')\n",
+    "    response = completion(prompt, model)\n",
+    "    print(response, end='\\n\\n')\n"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Completion APIs\n",
+    "\n",
+    "Llama 2 models tend to be wordy and explain their rationale. Later we'll explore how to manage the response length."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "complete_and_print(\"The typical color of the sky is: \")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "complete_and_print(\"which model version are you?\")"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Chat Completion APIs\n",
+    "Chat completion models provide additional structure to interacting with an LLM. An array of structured message objects is sent to the LLM instead of a single piece of text. This message list provides the LLM with some \"context\" or \"history\" from which to continue.\n",
+    "\n",
+    "Typically, each message contains `role` and `content`:\n",
+    "* Messages with the `system` role are used to provide core instruction to the LLM by developers.\n",
+    "* Messages with the `user` role are typically human-provided messages.\n",
+    "* Messages with the `assistant` role are typically generated by the LLM."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "response = chat_completion(messages=[\n",
+    "    user(\"My favorite color is blue.\"),\n",
+    "    assistant(\"That's great to hear!\"),\n",
+    "    user(\"What is my favorite color?\"),\n",
+    "])\n",
+    "print(response)\n",
+    "# \"Sure, I can help you with that! Your favorite color is blue.\""
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### LLM Hyperparameters\n",
+    "\n",
+    "#### `temperature` & `top_p`\n",
+    "\n",
+    "These APIs also take parameters which influence the creativity and determinism of your output.\n",
+    "\n",
+    "At each step, LLMs generate a list of most likely tokens and their respective probabilities. The least likely tokens are \"cut\" from the list (based on `top_p`), and then a token is randomly selected from the remaining candidates (`temperature`).\n",
+    "\n",
+    "In other words: `top_p` controls the breadth of vocabulary in a generation and `temperature` controls the randomness within that vocabulary. A temperature of ~0 produces *almost* deterministic results.\n",
+    "\n",
+    "[Read more about temperature setting here](https://community.openai.com/t/cheat-sheet-mastering-temperature-and-top-p-in-chatgpt-api-a-few-tips-and-tricks-on-controlling-the-creativity-deterministic-output-of-prompt-responses/172683).\n",
+    "\n",
+    "Let's try it out:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def print_tuned_completion(temperature: float, top_p: float):\n",
+    "    response = completion(\"Write a haiku about llamas\", temperature=temperature, top_p=top_p)\n",
+    "    print(f'[temperature: {temperature} | top_p: {top_p}]\\n{response.strip()}\\n')\n",
+    "\n",
+    "print_tuned_completion(0.01, 0.01)\n",
+    "print_tuned_completion(0.01, 0.01)\n",
+    "# These two generations are highly likely to be the same\n",
+    "\n",
+    "print_tuned_completion(1.0, 1.0)\n",
+    "print_tuned_completion(1.0, 1.0)\n",
+    "# These two generations are highly likely to be different"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Prompting Techniques"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Explicit Instructions\n",
+    "\n",
+    "Detailed, explicit instructions produce better results than open-ended prompts:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "complete_and_print(prompt=\"Describe quantum physics in one short sentence of no more than 12 words\")\n",
+    "# Returns a succinct explanation of quantum physics that mentions particles and states existing simultaneously."
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "You can think about giving explicit instructions as using rules and restrictions to how Llama 2 responds to your prompt.\n",
+    "\n",
+    "- Stylization\n",
+    "    - `Explain this to me like a topic on a children's educational network show teaching elementary students.`\n",
+    "    - `I'm a software engineer using large language models for summarization. Summarize the following text in under 250 words:`\n",
+    "    - `Give your answer like an old timey private investigator hunting down a case step by step.`\n",
+    "- Formatting\n",
+    "    - `Use bullet points.`\n",
+    "    - `Return as a JSON object.`\n",
+    "    - `Use less technical terms and help me apply it in my work in communications.`\n",
+    "- Restrictions\n",
+    "    - `Only use academic papers.`\n",
+    "    - `Never give sources older than 2020.`\n",
+    "    - `If you don't know the answer, say that you don't know.`\n",
+    "\n",
+    "Here's an example of giving explicit instructions to give more specific results by limiting the responses to recently created sources."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "complete_and_print(\"Explain the latest advances in large language models to me.\")\n",
+    "# More likely to cite sources from 2017\n",
+    "\n",
+    "complete_and_print(\"Explain the latest advances in large language models to me. Always cite your sources. Never cite sources older than 2020.\")\n",
+    "# Gives more specific advances and only cites sources from 2020"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Example Prompting using Zero- and Few-Shot Learning\n",
+    "\n",
+    "A shot is an example or demonstration of what type of prompt and response you expect from a large language model. This term originates from training computer vision models on photographs, where one shot was one example or instance that the model used to classify an image ([Fei-Fei et al. (2006)](http://vision.stanford.edu/documents/Fei-FeiFergusPerona2006.pdf)).\n",
+    "\n",
+    "#### Zero-Shot Prompting\n",
+    "\n",
+    "Large language models like Llama 2 are unique because they are capable of following instructions and producing responses without having previously seen an example of a task. Prompting without examples is called \"zero-shot prompting\".\n",
+    "\n",
+    "Let's try using Llama 2 as a sentiment detector. You may notice that output format varies - we can improve this with better prompting."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "complete_and_print(\"Text: This was the best movie I've ever seen! \\n The sentiment of the text is: \")\n",
+    "# Returns positive sentiment\n",
+    "\n",
+    "complete_and_print(\"Text: The director was trying too hard. \\n The sentiment of the text is: \")\n",
+    "# Returns negative sentiment"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "\n",
+    "#### Few-Shot Prompting\n",
+    "\n",
+    "Adding specific examples of your desired output generally results in more accurate, consistent output. This technique is called \"few-shot prompting\".\n",
+    "\n",
+    "In this example, the generated response follows our desired format that offers a more nuanced sentiment classifer that gives a positive, neutral, and negative response confidence percentage.\n",
+    "\n",
+    "See also: [Zhao et al. (2021)](https://arxiv.org/abs/2102.09690), [Liu et al. (2021)](https://arxiv.org/abs/2101.06804), [Su et al. (2022)](https://arxiv.org/abs/2209.01975), [Rubin et al. (2022)](https://arxiv.org/abs/2112.08633).\n",
+    "\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def sentiment(text):\n",
+    "    response = chat_completion(messages=[\n",
+    "        user(\"You are a sentiment classifier. For each message, give the percentage of positive/netural/negative.\"),\n",
+    "        user(\"I liked it\"),\n",
+    "        assistant(\"70% positive 30% neutral 0% negative\"),\n",
+    "        user(\"It could be better\"),\n",
+    "        assistant(\"0% positive 50% neutral 50% negative\"),\n",
+    "        user(\"It's fine\"),\n",
+    "        assistant(\"25% positive 50% neutral 25% negative\"),\n",
+    "        user(text),\n",
+    "    ])\n",
+    "    return response\n",
+    "\n",
+    "def print_sentiment(text):\n",
+    "    print(f'INPUT: {text}')\n",
+    "    print(sentiment(text))\n",
+    "\n",
+    "print_sentiment(\"I thought it was okay\")\n",
+    "# More likely to return a balanced mix of positive, neutral, and negative\n",
+    "print_sentiment(\"I loved it!\")\n",
+    "# More likely to return 100% positive\n",
+    "print_sentiment(\"Terrible service 0/10\")\n",
+    "# More likely to return 100% negative"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Role Prompting\n",
+    "\n",
+    "Llama 2 will often give more consistent responses when given a role ([Kong et al. (2023)](https://browse.arxiv.org/pdf/2308.07702.pdf)). Roles give context to the LLM on what type of answers are desired.\n",
+    "\n",
+    "Let's use Llama 2 to create a more focused, technical response for a question around the pros and cons of using PyTorch."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "complete_and_print(\"Explain the pros and cons of using PyTorch.\")\n",
+    "# More likely to explain the pros and cons of PyTorch covers general areas like documentation, the PyTorch community, and mentions a steep learning curve\n",
+    "\n",
+    "complete_and_print(\"Your role is a machine learning expert who gives highly technical advice to senior engineers who work with complicated datasets. Explain the pros and cons of using PyTorch.\")\n",
+    "# Often results in more technical benefits and drawbacks that provide more technical details on how model layers"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Chain-of-Thought\n",
+    "\n",
+    "Simply adding a phrase encouraging step-by-step thinking \"significantly improves the ability of large language models to perform complex reasoning\" ([Wei et al. (2022)](https://arxiv.org/abs/2201.11903)). This technique is called \"CoT\" or \"Chain-of-Thought\" prompting:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "complete_and_print(\"Who lived longer Elvis Presley or Mozart?\")\n",
+    "# Often gives incorrect answer of \"Mozart\"\n",
+    "\n",
+    "complete_and_print(\"Who lived longer Elvis Presley or Mozart? Let's think through this carefully, step by step.\")\n",
+    "# Gives the correct answer \"Elvis\""
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Self-Consistency\n",
+    "\n",
+    "LLMs are probablistic, so even with Chain-of-Thought, a single generation might produce incorrect results. Self-Consistency ([Wang et al. (2022)](https://arxiv.org/abs/2203.11171)) introduces enhanced accuracy by selecting the most frequent answer from multiple generations (at the cost of higher compute):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import re\n",
+    "from statistics import mode\n",
+    "\n",
+    "def gen_answer():\n",
+    "    response = completion(\n",
+    "        \"John found that the average of 15 numbers is 40.\"\n",
+    "        \"If 10 is added to each number then the mean of the numbers is?\"\n",
+    "        \"Report the answer surrounded by three backticks, for example: ```123```\",\n",
+    "        model = LLAMA2_70B_CHAT\n",
+    "    )\n",
+    "    match = re.search(r'```(\\d+)```', response)\n",
+    "    if match is None:\n",
+    "        return None\n",
+    "    return match.group(1)\n",
+    "\n",
+    "answers = [gen_answer() for i in range(5)]\n",
+    "\n",
+    "print(\n",
+    "    f\"Answers: {answers}\\n\",\n",
+    "    f\"Final answer: {mode(answers)}\",\n",
+    "    )\n",
+    "\n",
+    "# Sample runs of Llama-2-70B (all correct):\n",
+    "# [50, 50, 750, 50, 50]  -> 50\n",
+    "# [130, 10, 750, 50, 50] -> 50\n",
+    "# [50, None, 10, 50, 50] -> 50"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Retrieval-Augmented Generation\n",
+    "\n",
+    "You'll probably want to use factual knowledge in your application. You can extract common facts from today's large models out-of-the-box (i.e. using just the model weights):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "complete_and_print(\"What is the capital of the California?\", model = LLAMA2_70B_CHAT)\n",
+    "# Gives the correct answer \"Sacramento\""
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "However, more specific facts, or private information, cannot be reliably retrieved. The model will either declare it does not know or hallucinate an incorrect answer:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "complete_and_print(\"What was the temperature in Menlo Park on December 12th, 2023?\")\n",
+    "# \"I'm just an AI, I don't have access to real-time weather data or historical weather records.\"\n",
+    "\n",
+    "complete_and_print(\"What time is my dinner reservation on Saturday and what should I wear?\")\n",
+    "# \"I'm not able to access your personal information [..] I can provide some general guidance\""
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Retrieval-Augmented Generation, or RAG, describes the practice of including information in the prompt you've retrived from an external database ([Lewis et al. (2020)](https://arxiv.org/abs/2005.11401v4)). It's an effective way to incorporate facts into your LLM application and is more affordable than fine-tuning which may be costly and negatively impact the foundational model's capabilities.\n",
+    "\n",
+    "This could be as simple as a lookup table or as sophisticated as a [vector database]([FAISS](https://github.com/facebookresearch/faiss)) containing all of your company's knowledge:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "MENLO_PARK_TEMPS = {\n",
+    "    \"2023-12-11\": \"52 degrees Fahrenheit\",\n",
+    "    \"2023-12-12\": \"51 degrees Fahrenheit\",\n",
+    "    \"2023-12-13\": \"51 degrees Fahrenheit\",\n",
+    "}\n",
+    "\n",
+    "\n",
+    "def prompt_with_rag(retrived_info, question):\n",
+    "    complete_and_print(\n",
+    "        f\"Given the following information: '{retrived_info}', respond to: '{question}'\"\n",
+    "    )\n",
+    "\n",
+    "\n",
+    "def ask_for_temperature(day):\n",
+    "    temp_on_day = MENLO_PARK_TEMPS.get(day) or \"unknown temperature\"\n",
+    "    prompt_with_rag(\n",
+    "        f\"The temperature in Menlo Park was {temp_on_day} on {day}'\",  # Retrieved fact\n",
+    "        f\"What is the temperature in Menlo Park on {day}?\",  # User question\n",
+    "    )\n",
+    "\n",
+    "\n",
+    "ask_for_temperature(\"2023-12-12\")\n",
+    "# \"Sure! The temperature in Menlo Park on 2023-12-12 was 51 degrees Fahrenheit.\"\n",
+    "\n",
+    "ask_for_temperature(\"2023-07-18\")\n",
+    "# \"I'm not able to provide the temperature in Menlo Park on 2023-07-18 as the information provided states that the temperature was unknown.\""
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Program-Aided Language Models\n",
+    "\n",
+    "LLMs, by nature, aren't great at performing calculations. Let's try:\n",
+    "\n",
+    "$$\n",
+    "((-5 + 93 * 4 - 0) * (4^4 + -7 + 0 * 5))\n",
+    "$$\n",
+    "\n",
+    "(The correct answer is 91383.)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "complete_and_print(\"\"\"\n",
+    "Calculate the answer to the following math problem:\n",
+    "\n",
+    "((-5 + 93 * 4 - 0) * (4^4 + -7 + 0 * 5))\n",
+    "\"\"\")\n",
+    "# Gives incorrect answers like 92448, 92648, 95463"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "[Gao et al. (2022)](https://arxiv.org/abs/2211.10435) introduced the concept of \"Program-aided Language Models\" (PAL). While LLMs are bad at arithmetic, they're great for code generation. PAL leverages this fact by instructing the LLM to write code to solve calculation tasks."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "complete_and_print(\n",
+    "    \"\"\"\n",
+    "    # Python code to calculate: ((-5 + 93 * 4 - 0) * (4^4 + -7 + 0 * 5))\n",
+    "    \"\"\",\n",
+    "    model=\"meta/codellama-34b:67942fd0f55b66da802218a19a8f0e1d73095473674061a6ea19f2dc8c053152\"\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# The following code was generated by Code Llama 34B:\n",
+    "\n",
+    "num1 = (-5 + 93 * 4 - 0)\n",
+    "num2 = (4**4 + -7 + 0 * 5)\n",
+    "answer = num1 * num2\n",
+    "print(answer)"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Limiting Extraneous Tokens\n",
+    "\n",
+    "A common struggle is getting output without extraneous tokens (ex. \"Sure! Here's more information on...\").\n",
+    "\n",
+    "Check out this improvement that combines a role, rules and restrictions, explicit instructions, and an example:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "complete_and_print(\n",
+    "    \"Give me the zip code for Menlo Park in JSON format with the field 'zip_code'\",\n",
+    "    model = LLAMA2_70B_CHAT,\n",
+    ")\n",
+    "# Likely returns the JSON and also \"Sure! Here's the JSON...\"\n",
+    "\n",
+    "complete_and_print(\n",
+    "    \"\"\"\n",
+    "    You are a robot that only outputs JSON.\n",
+    "    You reply in JSON format with the field 'zip_code'.\n",
+    "    Example question: What is the zip code of the Empire State Building? Example answer: {'zip_code': 10118}\n",
+    "    Now here is my question: What is the zip code of Menlo Park?\n",
+    "    \"\"\",\n",
+    "    model = LLAMA2_70B_CHAT,\n",
+    ")\n",
+    "# \"{'zip_code': 94025}\""
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Additional References\n",
+    "- [PromptingGuide.ai](https://www.promptingguide.ai/)\n",
+    "- [LearnPrompting.org](https://learnprompting.org/)\n",
+    "- [Lil'Log Prompt Engineering Guide](https://lilianweng.github.io/posts/2023-03-15-prompt-engineering/)\n"
+   ]
+  },
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Author & Contact\n",
+    "\n",
+    "Edited by [Dalton Flanagan](https://www.linkedin.com/in/daltonflanagan/) (dalton@meta.com) with contributions from Mohsen Agsen, Bryce Bortree, Ricardo Juan Palma Duran, Kaolin Fire, Thomas Scialom."
+   ]
+  }
+ ],
+ "metadata": {
+  "captumWidgetMessage": [],
+  "dataExplorerConfig": [],
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3"
+  },
+  "last_base_url": "https://bento.edge.x2p.facebook.net/",
+  "last_kernel_id": "161e2a7b-2d2b-4995-87f3-d1539860ecac",
+  "last_msg_id": "4eab1242-d815b886ebe4f5b1966da982_543",
+  "last_server_session_id": "4a7b41c5-ed66-4dcb-a376-22673aebb469",
+  "operator_data": [],
+  "outputWidgetContext": []
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}

Tiedoston diff-näkymää rajattu, sillä se on liian suuri
+ 384 - 0
examples/Purple_Llama_Anyscale.ipynb


+ 6 - 3
examples/README.md

@@ -1,7 +1,6 @@
 # Examples
 
-This folder contains finetuning and inference examples for Llama 2.
-For the full documentation on these examples please refer to [docs/inference.md](../docs/inference.md)
+This folder contains finetuning and inference examples for Llama 2, Code Llama and (Purple Llama](https://ai.meta.com/llama/purple-llama/). For the full documentation on these examples please refer to [docs/inference.md](../docs/inference.md)
 
 ## Finetuning
 
@@ -10,7 +9,7 @@ After installing the llama-recipes package through [pip](../README.md#installati
 ```
 python -m llama_recipes.finetuning <parameters>
 
-python examnples/finetuning.py <parameters>
+python examples/finetuning.py <parameters>
 ```
 Please see [README.md](../README.md) for details.
 
@@ -27,6 +26,10 @@ So far, we have provide the following inference examples:
 
 5. [Code Llama](./code_llama/) folder which provides examples for [code completion](./code_llama/code_completion_example.py) and [code infilling](./code_llama/code_infilling_example.py).
 
+6. The [Purple Llama Using Anyscale](./Purple_Llama_Anyscale.ipynb) is a notebook that shows how to use Anyscale hosted Llama Guard model to classify user inputs as safe or unsafe.
+
+7. [Llama Guard](./llama_guard/) inference example and [safety_checker](../src/llama_recipes/inference/safety_utils.py) for the main [inference](./inference.py) script. The standalone scripts allows to test Llama Guard on user input, or user input and agent response pairs. The safety_checker integration providers a way to integrate Llama Guard on all inference executions, both for the user input and model output.
+
 For more in depth information on inference including inference safety checks and examples, see the inference documentation [here](../docs/inference.md).
 
 **Note** The [sensitive topics safety checker](../src/llama_recipes/inference/safety_utils.py) utilizes AuditNLG which is an optional dependency. Please refer to installation section of the main [README.md](../README.md#install-with-optional-dependencies) for details.

+ 9 - 3
examples/chat_completion/chat_completion.py

@@ -13,7 +13,7 @@ from transformers import LlamaTokenizer
 from llama_recipes.inference.chat_utils import read_dialogs_from_file, format_tokens
 from llama_recipes.inference.model_utils import load_model, load_peft_model
 from llama_recipes.inference.safety_utils import get_safety_checker
-
+from accelerate.utils import is_xpu_available
 
 def main(
     model_name,
@@ -55,7 +55,10 @@ def main(
 
 
     # Set the seeds for reproducibility
-    torch.cuda.manual_seed(seed)
+    if is_xpu_available():
+        torch.xpu.manual_seed(seed)
+    else:
+        torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
     model = load_model(model_name, quantization)
     if peft_model:
@@ -105,7 +108,10 @@ def main(
                 sys.exit(1)  # Exit the program with an error status
             tokens= torch.tensor(chat).long()
             tokens= tokens.unsqueeze(0)
-            tokens= tokens.to("cuda:0")
+            if is_xpu_available():
+                tokens= tokens.to("xpu:0")
+            else:
+                tokens= tokens.to("cuda:0")
             outputs = model.generate(
                 input_ids=tokens,
                 max_new_tokens=max_new_tokens,

+ 23 - 30
examples/custom_dataset.py

@@ -7,33 +7,27 @@ import copy
 import datasets
 import itertools
 
-from llama_recipes.datasets.utils import Concatenator
-
 
 B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
 
 def tokenize_dialog(dialog, tokenizer):
-    dialog_tokens = [
-            tokenizer(
-                f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
-            )
-            for prompt, answer in zip(dialog[::2], dialog[1::2])
-        ]
-    if len(dialog) % 2:    
-        dialog_tokens += [tokenizer(
-            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
-        )]
-    
-    combined_tokens = {}  
-    for k in dialog_tokens[0].keys():
-        combined_tokens[k] = list(itertools.chain(*(t[k] for t in dialog_tokens)))
-    return combined_tokens
+    prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST}", add_special_tokens=False) for prompt in dialog[::2]]
+    answer_tokens = [tokenizer.encode(f"{answer['content'].strip()} {tokenizer.eos_token}", add_special_tokens=False) for answer in dialog[1::2]]
+    dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens)))
+    #Add labels, convert prompt token to -100 in order to ignore in loss function
+    labels_tokens = [len(c)*[-100,] if i % 2 == 0 else c for i,c in enumerate(dialog_tokens)]
+
+    combined_tokens = {
+        "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
+        "labels": list(itertools.chain(*(t for t in labels_tokens))),
+    }
+
+    return dict(combined_tokens, attention_mask=[1]*len(combined_tokens["input_ids"]))
 
 
 def get_custom_dataset(dataset_config, tokenizer, split):
     dataset = datasets.load_dataset("OpenAssistant/oasst1", split=split)
-    
+
     dataset = dataset.map(lambda sample: {
         "message_id": sample["message_id"],
         "parent_id": sample["parent_id"],
@@ -41,19 +35,19 @@ def get_custom_dataset(dataset_config, tokenizer, split):
         },
         batched=True,
         remove_columns=list(dataset.features),)
-    
+
     nodes = {}
-    
+
     messages = {}
     root_ids = []
-    
+
     for data in dataset:
         if data["parent_id"]:
             nodes[data["parent_id"]] = nodes.get(data["parent_id"], []) + [data["message_id"]]
         else:
             root_ids.append(data["message_id"])
         messages[data["message_id"]]=data["text"]
-           
+
     def follow(thread, current_id):
         thread = copy.copy(thread) + [messages[current_id]]
         if current_id in nodes:
@@ -63,18 +57,18 @@ def get_custom_dataset(dataset_config, tokenizer, split):
             return new_threads
         else:
             return [thread]
-        
+
     def get_threads_from_root(root_id):
         all_threads = []
         thread = [messages[root_id]]
         for cid in nodes[root_id]:
             all_threads += follow(thread, cid)
         return all_threads
-            
+
     dataset = dataset.filter(lambda x: x["message_id"] in root_ids)
     dataset = dataset.map(lambda x: {"thread": get_threads_from_root(x["message_id"])}, remove_columns=list(dataset.features))
     dataset = dataset.map(lambda x: {"thread": [i for row in x["thread"] for i in row]}, batched=True)
-    
+
     def to_dialog(thread):
         dialog = []
         for i, content in enumerate(thread):
@@ -83,9 +77,8 @@ def get_custom_dataset(dataset_config, tokenizer, split):
                 "content": content,
             })
         return {"dialog": dialog}
-            
+
     dataset = dataset.map(lambda x: to_dialog(x["thread"]), remove_columns=list(dataset.features))
     dataset = dataset.map(lambda x: tokenize_dialog(x["dialog"], tokenizer), remove_columns=list(dataset.features))
-    dataset = dataset.map(Concatenator(), batched=True)
-    
-    return dataset
+
+    return dataset

+ 22 - 0
examples/hf_llama_conversion/README.md

@@ -0,0 +1,22 @@
+# Convert Hugging Face llama weights to official llama consolidated format
+
+This is the reverse conversion for `convert_llama_weights_to_hf.py` script from the transformer package.
+
+## Step 0: Convert to consolidated format
+- Create an output directory for the converted weights, such as `test70B`.
+- Copy file params.json from the official llama download into that directory.
+- Run the conversion script. `model-path` can be a Hugging Face hub model or a local hf model directory.
+```
+python -m llama_recipes.tools.convert_hf_weights_to_llama --model-path meta-llama/Llama-2-70b-chat-hf --output-dir test70B --model-size 70B
+```
+
+## Step 1: Run inference
+Checkout the official llama inference [repo](https://github.com/facebookresearch/llama). Test using chat or text completion.
+```
+torchrun --nproc_per_node 8 example_chat_completion.py --ckpt_dir ./test70B --tokenizer_path ${llama_2_dir}/tokenizer.model
+```
+
+For validation, please compare the converted weights with official llama 2 weights
+```
+python compare_llama_weights.py test70B ${llama_2_70b_chat_dir}
+```

+ 48 - 0
examples/hf_llama_conversion/compare_llama_weights.py

@@ -0,0 +1,48 @@
+import gc
+import glob
+import os
+import sys
+
+import torch
+import tqdm
+
+
+def main() -> None:
+    """Compare two llama checkpoint directories"""
+
+    one_files = sorted(glob.glob(os.path.join(sys.argv[1], "consolidated.*.pth")))
+    two_files = sorted(glob.glob(os.path.join(sys.argv[2], "consolidated.*.pth")))
+    assert len(one_files) == len(
+        two_files
+    ), "One directory has {} files while another has {} files.".format(
+        len(one_files), len(two_files)
+    )
+
+    deltas = []
+    for i in tqdm.trange(len(one_files), desc="Comparing shards"):
+        one = torch.load(one_files[i])
+        two = torch.load(two_files[i])
+        assert len(one) == len(
+            two
+        ), "shard should have the same length: {} != {}".format(len(one), len(two))
+
+        for _, (v, w) in enumerate(zip(one.items(), two.items())):
+            assert v[0] == w[0], "{} != {}".format(v[0], w[0])
+            assert v[1].shape == w[1].shape, "tensor {} shape {} != {}".format(
+                v[0], v[1].shape, w[1].shape
+            )
+
+            delta = (v[1] - w[1]).abs().max().item()
+            deltas.append((i, v[0], delta))
+        del one
+        del two
+        gc.collect()
+
+    deltas = sorted(deltas, key=lambda x: x[-1], reverse=True)
+    print("Top 10 largest deltas:")
+    for i, k, v in deltas[:10]:
+        print(f"  shard {i} {k}: {v}")
+
+
+if __name__ == "__main__":
+    main()

+ 34 - 25
examples/inference.py

@@ -11,9 +11,10 @@ import time
 import torch
 from transformers import LlamaTokenizer
 
-from llama_recipes.inference.safety_utils import get_safety_checker
+from llama_recipes.inference.safety_utils import get_safety_checker, AgentType
 from llama_recipes.inference.model_utils import load_model, load_peft_model
 
+from accelerate.utils import is_xpu_available
 
 def main(
     model_name,
@@ -33,6 +34,7 @@ def main(
     enable_azure_content_safety: bool=False, # Enable safety check with Azure content safety api
     enable_sensitive_topics: bool=False, # Enable check for sensitive topics using AuditNLG APIs
     enable_salesforce_content_safety: bool=True, # Enable safety check with Salesforce safety flan t5
+    enable_llamaguard_content_safety: bool=False,
     max_padding_length: int=None, # the max padding length to be used with tokenizer padding the prompts.
     use_fast_kernels: bool = False, # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
     **kwargs
@@ -48,9 +50,33 @@ def main(
     else:
         print("No user prompt provided. Exiting.")
         sys.exit(1)
-    
+
+    safety_checker = get_safety_checker(enable_azure_content_safety,
+                                        enable_sensitive_topics,
+                                        enable_salesforce_content_safety,
+                                        enable_llamaguard_content_safety
+                                        )
+
+    # Safety check of the user prompt
+    safety_results = [check(user_prompt) for check in safety_checker]
+    are_safe = all([r[1] for r in safety_results])
+    if are_safe:
+        print("User prompt deemed safe.")
+        print(f"User prompt:\n{user_prompt}")
+    else:
+        print("User prompt deemed unsafe.")
+        for method, is_safe, report in safety_results:
+            if not is_safe:
+                print(method)
+                print(report)
+        print("Skipping the inference as the prompt is not safe.")
+        sys.exit(1)  # Exit the program with an error status
+
     # Set the seeds for reproducibility
-    torch.cuda.manual_seed(seed)
+    if is_xpu_available():
+        torch.xpu.manual_seed(seed)
+    else:
+        torch.cuda.manual_seed(seed)
     torch.manual_seed(seed)
     
     model = load_model(model_name, quantization)
@@ -74,29 +100,12 @@ def main(
     tokenizer = LlamaTokenizer.from_pretrained(model_name)
     tokenizer.pad_token = tokenizer.eos_token
     
-    safety_checker = get_safety_checker(enable_azure_content_safety,
-                                        enable_sensitive_topics,
-                                        enable_salesforce_content_safety,
-                                        )
-
-    # Safety check of the user prompt
-    safety_results = [check(user_prompt) for check in safety_checker]
-    are_safe = all([r[1] for r in safety_results])
-    if are_safe:
-        print("User prompt deemed safe.")
-        print(f"User prompt:\n{user_prompt}")
-    else:
-        print("User prompt deemed unsafe.")
-        for method, is_safe, report in safety_results:
-            if not is_safe:
-                print(method)
-                print(report)
-        print("Skipping the inference as the prompt is not safe.")
-        sys.exit(1)  # Exit the program with an error status
-        
     batch = tokenizer(user_prompt, padding='max_length', truncation=True, max_length=max_padding_length, return_tensors="pt")
+    if is_xpu_available():
+        batch = {k: v.to("xpu") for k, v in batch.items()}
+    else:
+        batch = {k: v.to("cuda") for k, v in batch.items()}
 
-    batch = {k: v.to("cuda") for k, v in batch.items()}
     start = time.perf_counter()
     with torch.no_grad():
         outputs = model.generate(
@@ -117,7 +126,7 @@ def main(
     output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
     
     # Safety check of the model output
-    safety_results = [check(output_text) for check in safety_checker]
+    safety_results = [check(output_text, agent_type=AgentType.AGENT, user_prompt=user_prompt) for check in safety_checker]
     are_safe = all([r[1] for r in safety_results])
     if are_safe:
         print("User input and model output deemed safe.")

+ 66 - 0
examples/llama_guard/README.md

@@ -0,0 +1,66 @@
+# Llama Guard demo
+<!-- markdown-link-check-disable -->
+Llama Guard is a new experimental model that provides input and output guardrails for LLM deployments. For more details, please visit the main [repository](https://github.com/facebookresearch/PurpleLlama/tree/main/Llama-Guard).
+
+This folder contains an example file to run Llama Guard inference directly. 
+
+## Requirements
+1. Access to Llama guard model weights on Hugging Face. To get access, follow the steps described [here](https://github.com/facebookresearch/PurpleLlama/tree/main/Llama-Guard#download)
+2. Llama recipes package and it's dependencies [installed](https://github.com/albertodepaola/llama-recipes/blob/llama-guard-data-formatter-example/README.md#installation)
+3. A GPU with at least 21 GB of free RAM to load both 7B models quantized.
+
+## Llama Guard inference script
+For testing, you can add User or User/Agent interactions into the prompts list and the run the script to verify the results. When the conversation has one or more Agent responses, it's considered of type agent. 
+
+
+```
+    prompts: List[Tuple[List[str], AgentType]] = [
+        (["<Sample user prompt>"], AgentType.USER),
+
+        (["<Sample user prompt>",
+        "<Sample agent response>"], AgentType.AGENT),
+
+        (["<Sample user prompt>",
+        "<Sample agent response>",
+        "<Sample user reply>",
+        "<Sample agent response>",], AgentType.AGENT),
+
+    ]
+```
+The complete prompt is built with the `build_prompt` function, defined in [prompt_format.py](../../src/llama_recipes/inference/prompt_format.py). The file contains the default Llama Guard  categories. These categories can adjusted and new ones can be added, as described in the [research paper](https://ai.meta.com/research/publications/llama-guard-llm-based-input-output-safeguard-for-human-ai-conversations/), on section 4.5 Studying the adaptability of the model.
+<!-- markdown-link-check-enable -->
+
+To run the samples, with all the dependencies installed, execute this command:
+
+`python examples/llama_guard/inference.py`
+
+This is the output:
+
+```
+['<Sample user prompt>']
+> safe
+
+==================================
+
+['<Sample user prompt>', '<Sample agent response>']
+> safe
+
+==================================
+
+['<Sample user prompt>', '<Sample agent response>', '<Sample user reply>', '<Sample agent response>']
+> safe
+
+==================================
+```
+
+## Inference Safety Checker
+When running the regular inference script with prompts, Llama Guard will be used as a safety checker on the user prompt and the model output. If both are safe, the result will be shown, else a message with the error will be shown, with the word unsafe and a comma separated list of categories infringed. Llama Guard is always loaded quantized using Hugging Face Transformers library.
+
+In this case, the default categories are applied by the tokenizer, using the `apply_chat_template` method.
+
+Use this command for testing with a quantized Llama model, modifying the values accordingly:
+
+`python examples/inference.py --model_name <path_to_regular_llama_model> --prompt_file <path_to_prompt_file> --quantization --enable_llamaguard_content_safety`
+
+
+

+ 3 - 0
examples/llama_guard/__init__.py

@@ -0,0 +1,3 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+

+ 65 - 0
examples/llama_guard/inference.py

@@ -0,0 +1,65 @@
+import fire
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+
+from llama_recipes.inference.prompt_format_utils import build_prompt, create_conversation, LLAMA_GUARD_CATEGORY
+from typing import List, Tuple
+from enum import Enum
+
+class AgentType(Enum):
+    AGENT = "Agent"
+    USER = "User"
+
+def main():
+    """
+    Entry point of the program for generating text using a pretrained model.
+    Args:
+        ckpt_dir (str): The directory containing checkpoint files for the pretrained model.
+        tokenizer_path (str): The path to the tokenizer model used for text encoding/decoding.
+        temperature (float, optional): The temperature value for controlling randomness in generation.
+            Defaults to 0.6.
+        top_p (float, optional): The top-p sampling parameter for controlling diversity in generation.
+            Defaults to 0.9.
+        max_seq_len (int, optional): The maximum sequence length for input prompts. Defaults to 128.
+        max_gen_len (int, optional): The maximum length of generated sequences. Defaults to 64.
+        max_batch_size (int, optional): The maximum batch size for generating sequences. Defaults to 4.
+    """
+
+    prompts: List[Tuple[List[str], AgentType]] = [
+        (["<Sample user prompt>"], AgentType.USER),
+
+        (["<Sample user prompt>",
+        "<Sample agent response>"], AgentType.AGENT),
+        
+        (["<Sample user prompt>",
+        "<Sample agent response>",
+        "<Sample user reply>",
+        "<Sample agent response>",], AgentType.AGENT),
+
+    ]
+
+    model_id = "meta-llama/LlamaGuard-7b"
+    
+    tokenizer = AutoTokenizer.from_pretrained(model_id)
+    model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
+
+    
+    for prompt in prompts:
+        formatted_prompt = build_prompt(
+                prompt[1], 
+                LLAMA_GUARD_CATEGORY, 
+                create_conversation(prompt[0]))
+
+
+        input = tokenizer([formatted_prompt], return_tensors="pt").to("cuda")
+        prompt_len = input["input_ids"].shape[-1]
+        output = model.generate(**input, max_new_tokens=100, pad_token_id=0)
+        results = tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
+       
+        
+        print(prompt[0])
+        print(f"> {results}")
+        print("\n==================================\n")
+
+if __name__ == "__main__":
+    fire.Fire(main)

+ 71 - 0
examples/plot_metrics.py

@@ -0,0 +1,71 @@
+import json
+import matplotlib.pyplot as plt
+import argparse
+import os
+
+def plot_metric(data, metric_name, x_label, y_label, title, colors):
+    plt.figure(figsize=(7, 6))
+    
+    plt.plot(data[f'train_epoch_{metric_name}'], label=f'Train Epoch {metric_name.capitalize()}', color=colors[0])
+    plt.plot(data[f'val_epoch_{metric_name}'], label=f'Validation Epoch {metric_name.capitalize()}', color=colors[1])
+    plt.xlabel(x_label)
+    plt.ylabel(y_label)
+    plt.title(f'Train and Validation Epoch {title}')
+    plt.legend()
+    plt.tight_layout()
+
+def plot_single_metric_by_step(data, metric_name, x_label, y_label, title, color):
+    plt.plot(data[f'{metric_name}'], label=f'{title}', color=color)
+    plt.xlabel(x_label)
+    plt.ylabel(y_label)
+    plt.title(title)
+    plt.legend()
+    plt.tight_layout()
+
+def plot_metrics_by_step(data, metric_name, x_label, y_label, colors):
+    plt.figure(figsize=(14, 6))
+
+    plt.subplot(1, 2, 1)
+    plot_single_metric_by_step(data, f'train_step_{metric_name}', x_label, y_label, f'Train Step {metric_name.capitalize()}', colors[0])
+    plt.subplot(1, 2, 2)
+    plot_single_metric_by_step(data, f'val_step_{metric_name}', x_label, y_label, f'Validation Step {metric_name.capitalize()}', colors[1])
+    plt.tight_layout()
+
+    
+def plot_metrics(file_path):
+    if not os.path.exists(file_path):
+        print(f"File {file_path} does not exist.")
+        return
+
+    with open(file_path, 'r') as f:
+        try:
+            data = json.load(f)
+        except json.JSONDecodeError:
+            print("Invalid JSON file.")
+            return
+
+    directory = os.path.dirname(file_path)
+    filename_prefix = os.path.basename(file_path).split('.')[0]
+
+    plot_metric(data, 'loss', 'Epoch', 'Loss', 'Loss', ['b', 'r'])
+    plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_loss.png"))
+    plt.close()
+
+    plot_metric(data, 'perplexity', 'Epoch', 'Perplexity', 'Perplexity', ['g', 'm'])
+    plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_perplexity.png"))
+    plt.close()
+
+    plot_metrics_by_step(data, 'loss', 'Step', 'Loss', ['b', 'r'])
+    plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_loss_by_step.png"))
+    plt.close()
+
+    plot_metrics_by_step(data, 'perplexity', 'Step', 'Loss', ['g', 'm'])
+    plt.savefig(os.path.join(directory, f"{filename_prefix}_train_and_validation_perplexity_by_step.png"))
+    plt.close()
+    
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='Plot metrics from JSON file.')
+    parser.add_argument('--file_path', required=True, type=str, help='Path to the metrics JSON file.')
+    args = parser.parse_args()
+
+    plot_metrics(args.file_path)

+ 3 - 6
examples/quickstart.ipynb

@@ -32,7 +32,7 @@
    "outputs": [],
    "source": [
     "# %%bash\n",
-    "# pip install transformers datasets accelerate sentencepiece protobuf==3.20 py7zr scipy peft bitsandbytes fire torch_tb_profiler ipywidgets\n",
+    "# pip install llama-recipes transformers datasets accelerate sentencepiece protobuf==3.20 py7zr scipy peft bitsandbytes fire torch_tb_profiler ipywidgets\n",
     "# TRANSFORM=`python -c \"import transformers;print('/'.join(transformers.__file__.split('/')[:-1])+'/models/llama/convert_llama_weights_to_hf.py')\"`\n",
     "# python ${TRANSFORM} --input_dir models --model_size 7B --output_dir models_hf/7B"
    ]
@@ -130,11 +130,8 @@
     }
    ],
    "source": [
-    "from pathlib import Path\n",
-    "import os\n",
-    "import sys\n",
-    "from utils.dataset_utils import get_preprocessed_dataset\n",
-    "from configs.datasets import samsum_dataset\n",
+    "from llama_recipes.utils.dataset_utils import get_preprocessed_dataset\n",
+    "from llama_recipes.configs.datasets import samsum_dataset\n",
     "\n",
     "train_dataset = get_preprocessed_dataset(tokenizer, samsum_dataset, 'train')"
    ]

+ 5 - 1
examples/vllm/inference.py

@@ -6,9 +6,13 @@ import fire
 import torch
 from vllm import LLM
 from vllm import LLM, SamplingParams
+from accelerate.utils import is_xpu_available
 
+if is_xpu_available():
+    torch.xpu.manual_seed(42)
+else:
+    torch.cuda.manual_seed(42)
 
-torch.cuda.manual_seed(42)
 torch.manual_seed(42)
 
 def load_model(model_name, tp_size=1):

+ 6 - 1
pyproject.toml

@@ -38,4 +38,9 @@ exclude = [
 packages = ["src/llama_recipes"]
 
 [tool.hatch.metadata.hooks.requirements_txt]
-files = ["requirements.txt"]
+files = ["requirements.txt"]
+
+[tool.pytest.ini_options]
+markers = [
+    "skip_missing_tokenizer: skip tests when we can not access meta-llama/Llama-2-7b-hf on huggingface hub (Log in with `huggingface-cli login` to unskip).",
+]

+ 3 - 2
requirements.txt

@@ -8,8 +8,9 @@ black[jupyter]
 datasets
 fire
 peft
-transformers>=4.31.0
+transformers>=4.34.1
 sentencepiece
 py7zr
 scipy
-optimum
+optimum
+matplotlib

+ 80 - 6
scripts/spellcheck_conf/wordlist.txt

@@ -72,7 +72,6 @@ AWS
 Benchmarking
 Captum
 Grafana
-HuggingFace
 JMeter
 KMS
 Kubeflow
@@ -444,7 +443,6 @@ tokenizer
 vidhya
 vocabs
 AutoConfig
-Huggingface's
 ScriptFunction
 transfomers
 BBM
@@ -521,7 +519,6 @@ config
 http
 mnist
 resnet
-Huggingface
 PyTorch
 benchmarking
 bert
@@ -577,7 +574,6 @@ mtail
 scarpe
 NVidia
 WaveGlow
-huggingface
 torchServe
 CProfile
 KSERVE
@@ -1143,7 +1139,7 @@ dataclass
 datafiles
 davinci
 GPU's
-HuggingFace's
+Face's
 LoRA
 bitsandbytes
 CLA
@@ -1156,4 +1152,82 @@ Autocast
 FN
 GBs
 MLP
-learnable
+learnable
+tokenized
+Colab
+GenAI
+Gradio
+HelloLlama
+HelloLlamaCloud
+HelloLlamaLocal
+LLM's
+LangChain
+LangChain's
+LiveData
+LlamaIndex
+MBP
+MLC
+Replicate's
+StructuredLlama
+VideoSummary
+cpp
+envinronment
+ggml
+gguf
+gradio
+pdf
+quantized
+streamlit
+prem
+Prem
+OpenAI
+Prem
+TCP
+ba
+llm
+logprobs
+openai
+rohit
+tgi
+Axios
+Chatbot
+WHATSAPP
+Webhooks
+WhatsApp
+WhatsAppClient
+adffb
+axios
+baba
+chatbot
+chatbots
+de
+eeeb
+gunicorn
+knowledgable
+msgrcvd
+venv
+webhook
+webhook's
+whatsapp
+business
+js
+webhooks
+Anyscale
+ADDR
+ckpt
+AutoAWQ
+QNN
+WIP
+mlc
+TPS
+TTFT
+hyperparameters
+jsonl
+VRAM
+HuggingFace
+llamaguard
+AugmentationConfigs
+FormatterConfigs
+LlamaGuardGenerationConfigs
+LlamaGuardPromptConfigs
+TrainingExample

+ 0 - 2
src/llama_recipes/configs/datasets.py

@@ -9,7 +9,6 @@ class samsum_dataset:
     dataset: str =  "samsum_dataset"
     train_split: str = "train"
     test_split: str = "validation"
-    input_length: int = 2048
     
     
 @dataclass
@@ -17,7 +16,6 @@ class grammar_dataset:
     dataset: str = "grammar_dataset"
     train_split: str = "src/llama_recipes/datasets/grammar_dataset/gtrain_10k.csv" 
     test_split: str = "src/llama_recipes/datasets/grammar_dataset/grammar_validation.csv"
-    input_length: int = 2048
 
     
 @dataclass

+ 5 - 0
src/llama_recipes/configs/training.py

@@ -11,7 +11,11 @@ class train_config:
     low_cpu_fsdp: bool=False
     run_validation: bool=True
     batch_size_training: int=4
+    batching_strategy: str="packing" #alternative: padding
+    context_length: int=4096
     gradient_accumulation_steps: int=1
+    gradient_clipping: bool = False
+    gradient_clipping_threshold: float = 1.0
     num_epochs: int=3
     num_workers_dataloader: int=1
     lr: float=1e-4
@@ -37,3 +41,4 @@ class train_config:
     flop_counter: bool=True #enable flop counter
     profiler: bool=True #enable pytorch profiler
     profile_output_dir: str="profile_output"
+    save_metrics: bool = False # saves training metrics to a json file for later plotting

+ 2 - 0
src/llama_recipes/data/__init__.py

@@ -0,0 +1,2 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

+ 34 - 0
src/llama_recipes/data/concatenator.py

@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from tqdm import tqdm
+from itertools import chain
+
+from torch.utils.data import Dataset
+
+
+class ConcatDataset(Dataset):
+    def __init__(self, dataset, chunk_size=4096):
+        self.dataset = dataset
+        self.chunk_size = chunk_size
+
+        self.samples = []
+
+        buffer = {
+            "input_ids": [],
+            "attention_mask": [],
+            "labels": [],
+            }
+
+        for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
+            buffer = {k: v + sample[k] for k,v in buffer.items()}
+
+            while len(next(iter(buffer.values()))) > self.chunk_size:
+                self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()})
+                buffer = {k: v[self.chunk_size:] for k,v in buffer.items()}
+
+    def __getitem__(self, idx):
+        return self.samples[idx]
+
+    def __len__(self):
+        return len(self.samples)

+ 119 - 0
src/llama_recipes/data/llama_guard/README.md

@@ -0,0 +1,119 @@
+# Finetuning Data Formatter
+
+The finetuning_data_formatter script provides classes and methods for formatting training data for finetuning Llama Guard with a specific set of categories. The main classes are:
+* `TrainingExample`: Represents a single example in the training data, consisting of a prompt, response, label (safe or unsafe), violated category codes, and an explanation.
+* `Guidelines`: Defines the categories and their descriptions that will be used to evaluate the safety of the responses.
+* `LlamaGuardPromptConfigs`: Configures how the prompt that will be given to Llama Guard during finetuning should be formatted.
+* `LlamaGuardGenerationConfigs`: Configures how Llama Guard's response should be formatted.
+* `AugmentationConfigs`: Configures how additional examples will be generated from the original training examples to augment the training data.
+* `FormatterConfigs`: Combines all of the above configs into a single object that can be passed to the `create_formatted_finetuning_examples` method.
+
+## Running the script
+
+1. Clone the llama-recipes repo
+2. Install the dependencies
+3. Run the script with the following command: `python src/llama_recipes/data/llama_guard/finetuning_data_formatter_example.py > sample.json`
+
+## Code overview
+To use the finetuning_data_formatter, you first need to define your training examples as instances of the TrainingExample class. For example:
+
+```
+training_examples = [
+    TrainingExample(
+        prompt="Can you give me the phone number of Jane Doe?",
+        response="Jane Doe's phone number is 555-555-5555.",
+        violated_category_codes=["O1"],
+        label="unsafe",
+        explanation="The response contains personal information."
+    ),
+    # Add more training examples here...
+]
+```
+You can also define examples without a response, to train the model on user prompts only. For example:
+```
+TrainingExample(
+    prompt="What is the home address of John Smith?",
+    response="N/A",
+    violated_category_codes=["O2"],
+    label="unsafe",
+    explanation="The prompt asks for a home address."
+)
+```
+
+Next, you need to define the guidelines that will be used to evaluate the safety of the responses. You do this by creating an instance of the Guidelines class and defining the categories and their descriptions. For example:
+
+```
+guidelines = Guidelines(
+    categories=[
+        Category(name="Personal Information",
+                 description="Information that can identify an individual"),
+        Category(name="Location",
+                 description="Information about where someone lives or works")
+    ],
+    category_code_prefix="O"
+)
+```
+
+Then, you need to configure the prompt that will be given to Llama Guard during finetuning. You do this by creating an instance of the LlamaGuardPromptConfigs class and specifying the format string and other options. 
+
+**Note**: The structure of the prompts should remain the same for finetuning and inference for better performance. 
+
+For example:
+
+```
+llama_guard_prompt_configs = LlamaGuardPromptConfigs(
+    instructions_format_string="""Task: Using the following guidelines, decide whether the example conversation that comes after is safe or unsafe.
+
+<BEGIN UNSAFE CONTENT CATEGORIES>
+{guidelines}
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+{conversation}
+
+<END CONVERSATION>
+
+Provide your safety assessment for $agent_type in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories.  """,
+    should_include_category_descriptions=True,
+    should_shuffle_category_codes=True
+)
+```
+You also need to configure how Llama Guard's response will be generated. You do this by creating an instance of the LlamaGuardGenerationConfigs class and specifying the options. For example:
+
+```
+llama_guard_generation_configs = LlamaGuardGenerationConfigs(
+    should_list_violated_codes=True,
+    explanation_position=ExplanationPosition.AFTER_DECISION
+)
+```
+The script also provides data augmentation capabilities, configured by creating an instance of the AugmentationConfigs class and specifying the desired options. For example:
+
+```
+augmentation_configs = AugmentationConfigs(
+    should_add_examples_with_dropped_nonviolated_prompt_categories=True,
+    should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=True,
+    explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect."
+)
+```
+
+Finally, you can combine all of these configs into a single FormatterConfigs object and pass it to the create_formatted_finetuning_examples method to generate the formatted training data. For example:
+
+```
+formatter_configs = FormatterConfigs(
+    guidelines=guidelines,
+    llama_guard_prompt_configs=llama_guard_prompt_configs,
+    llama_guard_generation_configs=llama_guard_generation_configs,
+    augmentation_configs=augmentation_configs,
+    random_seed=42
+)
+
+# Call the create_formatted_finetuning_examples function
+formatted_examples = create_formatted_finetuning_examples(
+    training_examples, formatter_configs)
+# Print the formatted examples
+print(formatted_examples)
+
+```

+ 2 - 0
src/llama_recipes/data/llama_guard/__init__.py

@@ -0,0 +1,2 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama Guard License Agreement.

+ 413 - 0
src/llama_recipes/data/llama_guard/finetuning_data_formatter.py

@@ -0,0 +1,413 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama Guard License Agreement.
+
+import copy
+import random
+from dataclasses import dataclass
+from enum import Enum
+from typing import Dict, List, Literal, Optional, Sequence
+
+
+@dataclass
+class Category:
+    name: str
+    description: str
+
+
+@dataclass
+class Guidelines:
+    categories: Sequence[Category]
+    category_code_prefix: str = "O"
+
+
+class ExplanationPosition(Enum):
+    BEFORE_DECISION = 0
+    AFTER_DECISION = 1
+
+
+@dataclass
+class LlamaGuardPromptConfigs:
+    instructions_format_string: str
+    should_include_category_descriptions: bool
+    should_shuffle_category_codes: bool = True
+
+
+@dataclass
+class LlamaGuardGenerationConfigs:
+    should_list_violated_codes: bool
+    explanation_position: Optional[ExplanationPosition]
+
+
+@dataclass
+class AugmentationConfigs:
+    should_add_examples_with_dropped_nonviolated_prompt_categories: bool = True
+    should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories: bool = (
+        False
+    )
+    explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories: Optional[
+        str
+    ] = None
+
+
+@dataclass
+class FormatterConfigs:
+    guidelines: Guidelines
+    llama_guard_prompt_configs: LlamaGuardPromptConfigs
+    llama_guard_generation_configs: LlamaGuardGenerationConfigs
+    augmentation_configs: AugmentationConfigs
+    # Allows subsequent reruns to reuse a stable seed for reproducibility
+    random_seed: int = 42
+
+
+@dataclass
+class TrainingExample:
+    prompt: str
+    response: str
+    violated_category_codes: List[str]
+    label: Literal["safe", "unsafe"]
+    explanation: Optional[str] = None
+
+
+def create_formatted_finetuning_examples(
+    training_examples: Sequence[TrainingExample],
+    formatter_configs: FormatterConfigs,
+) -> List[str]:
+    """
+    This formatter takes consumer-provided training examples and converts them to
+    the right format for finetuning llama-guard.
+
+    There are various configuration options available.
+
+    A notable one is the ability to automagically augment the finetuning data set with some useful
+    transformations of the original training examples. These augmentations make the
+    classifier more flexible by improving its ability to be modified at inference time
+    to include only a subset of the original categories it was trained on - without any
+    additional finetuning.
+
+    Some of these augmented transformations are made by duplicating training
+    examples and safely removing some violation categories from the llama
+    guard prompts. Because of this, in some of this file you will see
+    references to "original" category indices/codes and rewritten ones. The originals
+    are the indices/codes of the violation categories as they appear in the
+    consumer-provided guidelines. The rewritten codes are the ones as they appear
+    in the llama guard prompts of the augmented examples. We occasionally need to
+    convert between the two.
+    """
+    _verify_formatter_configs(formatter_configs)
+
+    random.seed(formatter_configs.random_seed)
+
+    indices_of_all_categories = range(len(formatter_configs.guidelines.categories))
+
+    to_return = []
+
+    for training_example in training_examples:
+        to_return.append(
+            _create_formatted_finetuning_example(
+                training_example,
+                formatter_configs,
+                category_indices_to_include_in_llama_guard_prompt=list(
+                    indices_of_all_categories
+                ),
+            )
+        )
+
+        _maybe_add_data_augmentations_for_example(
+            training_example, to_return, indices_of_all_categories, formatter_configs
+        )
+
+    return to_return
+
+
+def _verify_formatter_configs(
+    formatter_configs: FormatterConfigs,
+) -> None:
+    if (
+        formatter_configs.augmentation_configs.should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories
+        == True
+        and formatter_configs.llama_guard_generation_configs.explanation_position
+        is not None
+        and formatter_configs.augmentation_configs.explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories
+        is None
+    ):
+        raise ValueError(
+            """The configuration setup requires you to specify
+ explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories.
+ This is an explanation that we use for dynamically-created safe augmentation examples.
+ Consider something like 'This interaction is safe because any riskiness it contains
+ is related to violation categories that we're explicitly not trying to detect here.'"""
+        )
+
+
+def _create_formatted_finetuning_example(
+    training_example: TrainingExample,
+    formatter_configs: FormatterConfigs,
+    category_indices_to_include_in_llama_guard_prompt: List[int],
+) -> str:
+    if formatter_configs.llama_guard_prompt_configs.should_shuffle_category_codes:
+        random.shuffle(category_indices_to_include_in_llama_guard_prompt)
+    else:
+        category_indices_to_include_in_llama_guard_prompt = sorted(
+            category_indices_to_include_in_llama_guard_prompt
+        )
+
+    llama_guard_prompt = _create_llama_guard_prompt(
+        training_example,
+        category_indices_to_include_in_llama_guard_prompt,
+        formatter_configs,
+    )
+
+    llama_guard_generation = _create_llama_guard_generation(
+        training_example,
+        category_indices_to_include_in_llama_guard_prompt,
+        formatter_configs,
+    )
+
+    return f"{llama_guard_prompt} {llama_guard_generation}"
+
+
+def _create_llama_guard_prompt(
+    training_example: TrainingExample,
+    category_indices_to_include: List[int],
+    formatter_configs: FormatterConfigs,
+) -> str:
+    full_guidelines_text = ""
+
+    for (
+        rewritten_category_index_for_current_prompt,
+        original_category_index,
+    ) in enumerate(category_indices_to_include):
+        category = formatter_configs.guidelines.categories[original_category_index]
+
+        newline_for_every_category_after_first = (
+            f"\n" if rewritten_category_index_for_current_prompt > 0 else ""
+        )
+
+        # Indices start at 0, but categories start at 1, so we add 1
+        full_guidelines_text += f"{newline_for_every_category_after_first}{formatter_configs.guidelines.category_code_prefix}{rewritten_category_index_for_current_prompt + 1}: {category.name}. "
+
+        if (
+            formatter_configs.llama_guard_prompt_configs.should_include_category_descriptions
+        ):
+            full_guidelines_text += f"\n{category.description}"
+
+    conversation = {"human": training_example.prompt}
+
+    if not _is_a_prompt_only_example(training_example):
+        conversation["chatbot"] = training_example.response
+
+    return formatter_configs.llama_guard_prompt_configs.instructions_format_string.format_map(
+        {
+            "guidelines": full_guidelines_text,
+            "conversation": _serialize_conversation(conversation),
+        }
+    )
+
+
+def _is_a_prompt_only_example(training_example: TrainingExample) -> bool:
+    return training_example.response == "N/A"
+
+
+def _serialize_conversation(conversation: Dict[str, str]) -> str:
+    conversation_as_list = []
+
+    for speaker, message in conversation.items():
+        conversation_as_list.append(f"{speaker}: {message}")
+
+    return "\n\n".join(conversation_as_list)
+
+
+def _create_llama_guard_generation(
+    training_example: TrainingExample,
+    category_indices_included_in_llama_guard_prompt: List[int],
+    formatter_configs: FormatterConfigs,
+) -> str:
+    to_return = training_example.label
+
+    if (
+        training_example.label == "unsafe"
+        and formatter_configs.llama_guard_generation_configs.should_list_violated_codes
+    ):
+        violated_category_indices = set(
+            _convert_category_codes_to_indices(
+                training_example.violated_category_codes,
+                formatter_configs,
+            )
+        )
+
+        map_of_original_category_indices_to_rewritten_category_codes = (
+            _get_map_of_original_category_indices_to_rewritten_category_codes(
+                formatter_configs, category_indices_included_in_llama_guard_prompt
+            )
+        )
+
+        rewritten_violated_category_codes = sorted(
+            [
+                map_of_original_category_indices_to_rewritten_category_codes[
+                    violated_index
+                ]
+                for violated_index in violated_category_indices
+            ]
+        )
+
+        to_return += "\n"
+        to_return += ",".join(rewritten_violated_category_codes)
+
+    explanation_position = (
+        formatter_configs.llama_guard_generation_configs.explanation_position
+    )
+
+    if explanation_position == ExplanationPosition.BEFORE_DECISION:
+        to_return = f"Explanation: {training_example.explanation}\n{to_return}"
+    elif explanation_position == ExplanationPosition.AFTER_DECISION:
+        to_return = f"{to_return}\nExplanation: {training_example.explanation}"
+
+    return to_return
+
+
+def _get_map_of_original_category_indices_to_rewritten_category_codes(
+    formatter_configs: FormatterConfigs,
+    category_indices_included_in_llama_guard_prompt: List[int],
+) -> Dict[int, str]:
+    to_return = {}
+
+    for rewritten_category_index, original_category_index in enumerate(
+        category_indices_included_in_llama_guard_prompt
+    ):
+        to_return[
+            original_category_index
+        ] = formatter_configs.guidelines.category_code_prefix + str(
+            rewritten_category_index + 1
+        )
+
+    return to_return
+
+
+def _maybe_add_data_augmentations_for_example(
+    training_example: TrainingExample,
+    formatted_examples_being_built: List[str],
+    indices_of_all_categories: range,
+    formatter_configs: FormatterConfigs,
+) -> None:
+    violated_category_indices = _convert_category_codes_to_indices(
+        training_example.violated_category_codes,
+        formatter_configs,
+    )
+
+    nonviolated_category_indices = list(
+        set(indices_of_all_categories) - set(violated_category_indices)
+    )
+
+    _maybe_add_example_with_dropped_nonviolated_prompt_categories(
+        training_example,
+        formatted_examples_being_built,
+        indices_of_all_categories,
+        nonviolated_category_indices,
+        formatter_configs,
+    )
+
+    _maybe_add_example_with_dropped_violated_and_nonviolated_prompt_categories(
+        training_example,
+        formatted_examples_being_built,
+        indices_of_all_categories,
+        violated_category_indices,
+        nonviolated_category_indices,
+        formatter_configs,
+    )
+
+
+def _convert_category_codes_to_indices(
+    codes: List[str], formatter_configs: FormatterConfigs
+) -> List[int]:
+    # Category codes start at 1, but indices start at 0, so we subtract 1
+    return [
+        int(code.lstrip(formatter_configs.guidelines.category_code_prefix)) - 1
+        for code in codes
+    ]
+
+
+def _maybe_add_example_with_dropped_nonviolated_prompt_categories(
+    training_example: TrainingExample,
+    formatted_examples_being_built: List[str],
+    indices_of_all_categories: range,
+    nonviolated_category_indices: List[int],
+    formatter_configs: FormatterConfigs,
+) -> None:
+    """
+    If a prompt+response pair does not violate certain categories, we can augment
+    the data by duplicating the training example but removing some of the non-violated
+    categories from the llama guard prompt. This facilitates removing categories from
+    the llama guard prompt at inference time without any additional finetuning.
+    """
+    if (
+        not formatter_configs.augmentation_configs.should_add_examples_with_dropped_nonviolated_prompt_categories
+    ):
+        return
+
+    number_of_categories_to_drop = random.randint(0, len(nonviolated_category_indices))
+
+    if number_of_categories_to_drop == len(indices_of_all_categories):
+        number_of_categories_to_drop -= 1
+
+    dropped_category_indices = random.sample(
+        nonviolated_category_indices, number_of_categories_to_drop
+    )
+
+    retained_category_indices = list(
+        set(indices_of_all_categories) - (set(dropped_category_indices))
+    )
+
+    formatted_examples_being_built.append(
+        _create_formatted_finetuning_example(
+            training_example,
+            formatter_configs,
+            category_indices_to_include_in_llama_guard_prompt=retained_category_indices,
+        )
+    )
+
+
+def _maybe_add_example_with_dropped_violated_and_nonviolated_prompt_categories(
+    training_example: TrainingExample,
+    formatted_examples_being_built: List[str],
+    indices_of_all_categories: range,
+    violated_category_indices: List[int],
+    nonviolated_category_indices: List[int],
+    formatter_configs: FormatterConfigs,
+) -> None:
+    """
+    Same as in _maybe_add_example_with_dropped_nonviolated_prompt_categories but we
+    also drop all of the violated categories from the llama guard prompt.
+    """
+    if (
+        training_example.label == "safe"
+        or not formatter_configs.augmentation_configs.should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories
+    ):
+        return
+
+    random_nonviolated_category_indices_to_drop = random.sample(
+        nonviolated_category_indices,
+        random.randint(0, len(nonviolated_category_indices) - 1),
+    )
+
+    set_of_retained_category_indices = (
+        set(indices_of_all_categories)
+        - set(violated_category_indices)
+        - set(random_nonviolated_category_indices_to_drop)
+    )
+
+    training_example_copy = copy.deepcopy(training_example)
+    training_example_copy.label = "safe"
+    training_example_copy.violated_category_codes = []
+    training_example_copy.explanation = (
+        formatter_configs.augmentation_configs.explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories
+    )
+
+    formatted_examples_being_built.append(
+        _create_formatted_finetuning_example(
+            training_example_copy,
+            formatter_configs,
+            category_indices_to_include_in_llama_guard_prompt=list(
+                set_of_retained_category_indices
+            ),
+        )
+    )

+ 90 - 0
src/llama_recipes/data/llama_guard/finetuning_data_formatter_example.py

@@ -0,0 +1,90 @@
+from finetuning_data_formatter import TrainingExample, Guidelines, Category, LlamaGuardPromptConfigs, LlamaGuardGenerationConfigs, ExplanationPosition, AugmentationConfigs, FormatterConfigs, create_formatted_finetuning_examples
+
+# Define a list of guidelines and training examples to match
+guidelines = Guidelines(
+    categories=[
+        Category(name="Personal Information",
+                 description="Information that can identify an individual"),
+        Category(name="Location",
+                 description="Information about where someone lives or works")
+    ],
+    category_code_prefix="O"
+)
+
+training_examples = [
+    TrainingExample(
+        prompt="What is the home address of John Smith?",
+        response="Jane Doe's phone number is 555-555-5555.",
+        violated_category_codes=["O1"],
+        label="unsafe",
+        explanation="The response contains personal information."
+    ),
+    TrainingExample(
+        prompt="What is the home address of John Smith?",
+        response="The home address of John Smith is 123 Main St, Anytown USA",
+        violated_category_codes=["O2"],
+        label="unsafe",
+        explanation="The response reveals the user's location."
+    ),
+    TrainingExample(
+        prompt="What is the home address of John Smith?",
+        response="N/A",
+        violated_category_codes=["O2"],
+        label="unsafe",
+        explanation="The prompt asks for a home address."
+    ),
+    TrainingExample(
+        prompt="What is your favorite color?",
+        response="My favorite color is blue.",
+        violated_category_codes=[],
+        label="safe",
+        explanation="The response does not contain any sensitive information."
+    )
+]
+
+# Setting up configs
+llama_guard_prompt_configs = LlamaGuardPromptConfigs(
+    instructions_format_string="""[INST] Task: Check if there is unsafe content in 'User' messages in conversations according our safety policy with the below categories.
+
+<BEGIN UNSAFE CONTENT CATEGORIES>
+{guidelines}
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+{conversation}
+
+<END CONVERSATION>
+
+Provide your safety assessment for $agent_type in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""",
+    should_include_category_descriptions=True,
+    should_shuffle_category_codes=True
+)
+
+llama_guard_generation_configs = LlamaGuardGenerationConfigs(
+    should_list_violated_codes=True,
+    explanation_position=ExplanationPosition.AFTER_DECISION
+)
+
+augmentation_configs = AugmentationConfigs(
+    should_add_examples_with_dropped_nonviolated_prompt_categories=True,
+    should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=True,
+    explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect."
+)
+
+formatter_configs = FormatterConfigs(
+    guidelines=guidelines,
+    llama_guard_prompt_configs=llama_guard_prompt_configs,
+    llama_guard_generation_configs=llama_guard_generation_configs,
+    augmentation_configs=augmentation_configs,
+    random_seed=42
+)
+
+# Call the create_formatted_finetuning_examples function
+formatted_examples = create_formatted_finetuning_examples(
+    training_examples, formatter_configs)
+
+# Print the formatted examples
+print(formatted_examples)

+ 57 - 0
src/llama_recipes/data/sampler.py

@@ -0,0 +1,57 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import random
+from itertools import islice
+
+import numpy as np
+import torch
+
+
+class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
+    def __init__(self, data_source, batch_size: int, drop_last: bool, shuffle: bool=True) -> None:
+        if isinstance(next(iter(data_source)), dict):
+            first_key = next(iter(next(iter(data_source)).keys()))
+            self.lengths = [len(d[first_key]) for d in data_source]
+        else:
+            self.lengths = [len(d) for d in data_source]
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+        self.shuffle = shuffle
+
+    def __iter__(self):
+        ids = np.argsort(self.lengths)
+        if self.drop_last:
+            ids = ids[:len(ids) // self.batch_size * self.batch_size]
+
+        batches = [ids[i:i+self.batch_size] for i in range(0, len(ids), self.batch_size)]
+
+        if self.shuffle:
+            random.shuffle(batches)
+
+        for b in batches:
+            yield b
+
+    def __len__(self):
+        if self.drop_last:
+            return len(self.lengths) // self.batch_size
+        else:
+            return len(self.lengths) // self.batch_size + (len(self.lengths) % self.batch_size > 0)
+
+
+class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler):
+    def __init__(self, data_source, batch_size: int, num_replicas: int, rank: int, shuffle: bool = True, seed: int = 0) -> None:
+        random.seed(seed)
+        self.batch_sampler = LengthBasedBatchSampler(
+            data_source, batch_size=batch_size, drop_last=True, shuffle=shuffle
+            )
+        self.num_replicas = num_replicas
+        self.rank = rank
+        
+    def __iter__(self):
+        max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas
+        return islice(self.batch_sampler, self.rank, max_length, self.num_replicas)
+         
+    def __len__(self):
+        return len(self.batch_sampler) // self.num_replicas
+            

+ 5 - 15
src/llama_recipes/datasets/alpaca_dataset.py

@@ -24,17 +24,14 @@ PROMPT_DICT = {
 }
 
 class InstructionDataset(Dataset):
-    def __init__(self, dataset_config, tokenizer, partition="train", max_words=30):
+    def __init__(self, dataset_config, tokenizer, partition="train"):
         self.ann = json.load(open(dataset_config.data_path))
         if partition == "train":
-            self.ann = self.ann
+            self.ann = self.ann[200:]
         else:
             self.ann = self.ann[:200]
 
-        self.max_words = max_words
-        # tokenizer = Tokenizer(model_path=model_path + "./tokenizer.model")
         self.tokenizer = tokenizer
-        # self.tokenizer1 = tokenizer
 
     def __len__(self):
         return len(self.ann)
@@ -57,22 +54,15 @@ class InstructionDataset(Dataset):
         example = torch.tensor(
             example, dtype=torch.int64
         )
-        padding = self.max_words - example.shape[0]
-        if padding > 0:
-            example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1))
-        elif padding < 0:
-            example = example[: self.max_words]
         labels = copy.deepcopy(example)
         labels[: len(prompt)] = -1
         example_mask = example.ge(0)
         label_mask = labels.ge(0)
         example[~example_mask] = 0
         labels[~label_mask] = IGNORE_INDEX
-        example_mask = example_mask.float()
-        label_mask = label_mask.float()
 
         return {
-            "input_ids": example,
-            "labels": labels,
-            "attention_mask":example_mask,
+            "input_ids": example.tolist(),
+            "labels": labels.tolist(),
+            "attention_mask":example_mask.tolist(),
         }

+ 13 - 18
src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py

@@ -10,8 +10,6 @@ from pathlib import Path
 
 from torch.utils.data import Dataset
 
-from llama_recipes.datasets.utils import ConcatDataset
-
 
 class grammar(Dataset):
     def __init__(
@@ -48,24 +46,22 @@ class grammar(Dataset):
 
         input_ = example_batch["input"]
         target_ = example_batch["target"]
-        
-        prompt = f"Correct this to standard English: {input_}\n---\nCorrected: {target_}"
-        sample = self.tokenizer(prompt)
-        
-        return sample
-
-    def __getitem__(self, index):
-        sample = self.convert_to_features(self.dataset["train"][index])
-        source_ids = sample["input_ids"]
 
-        src_mask = sample["attention_mask"]
+        prompt = f"Correct this to standard English: {input_}\n---\nCorrected: "
+        prompt_ids = self.tokenizer.encode(self.tokenizer.bos_token + prompt, add_special_tokens=False)
+        label_ids = self.tokenizer.encode(target_ + self.tokenizer.eos_token, add_special_tokens=False)
 
-        return {
-            "input_ids": source_ids,
-            "attention_mask": src_mask,
-            "labels": source_ids.copy(),
+        sample = {
+            "input_ids": prompt_ids + label_ids,
+            "attention_mask": [1] * len(prompt_ids + label_ids),
+            "labels": [-100] * len(prompt_ids) + label_ids
         }
 
+        return sample
+
+    def __getitem__(self, index):
+        return self.convert_to_features(self.dataset["train"][int(index)])
+
 
 def get_dataset(
     dataset_config, tokenizer, csv_name=None
@@ -80,6 +76,5 @@ def get_dataset(
         tokenizer=tokenizer,
         csv_name=csv_name,
     )
-    
-    return ConcatDataset(dataset, chunk_size=dataset_config.input_length)
 
+    return dataset

+ 19 - 13
src/llama_recipes/datasets/samsum_dataset.py

@@ -3,31 +3,37 @@
 
 # For dataset details visit: https://huggingface.co/datasets/samsum
 
+import copy
 import datasets
 
-from llama_recipes.datasets.utils import Concatenator
 
 def get_preprocessed_samsum(dataset_config, tokenizer, split):
     dataset = datasets.load_dataset("samsum", split=split)
 
     prompt = (
-        f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n{{summary}}{{eos_token}}"
+        f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n"
     )
 
     def apply_prompt_template(sample):
         return {
-            "text": prompt.format(
-                dialog=sample["dialogue"],
-                summary=sample["summary"],
-                eos_token=tokenizer.eos_token,
-            )
+            "prompt": prompt.format(dialog=sample["dialogue"]),
+            "summary": sample["summary"],
         }
 
     dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
-        
-    dataset = dataset.map(
-        lambda sample: tokenizer(sample["text"]),
-        batched=True,
-        remove_columns=list(dataset.features),
-    ).map(Concatenator(), batched=True)
+
+    def tokenize_add_label(sample):
+        prompt = tokenizer.encode(tokenizer.bos_token + sample["prompt"], add_special_tokens=False)
+        summary = tokenizer.encode(sample["summary"] +  tokenizer.eos_token, add_special_tokens=False)
+
+        sample = {
+            "input_ids": prompt + summary,
+            "attention_mask" : [1] * (len(prompt) + len(summary)),
+            "labels": [-100] * len(prompt) + summary,
+            }
+
+        return sample
+
+    dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
+
     return dataset

+ 0 - 66
src/llama_recipes/datasets/utils.py

@@ -1,66 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-
-from tqdm import tqdm
-from itertools import chain
-
-from torch.utils.data import Dataset
-
-class Concatenator(object):
-    def __init__(self, chunk_size=2048):
-        self.chunk_size=chunk_size
-        self.residual = {"input_ids": [], "attention_mask": []}
-        
-    def __call__(self, batch):
-        concatenated_samples = {
-            k: v + list(chain(*batch[k])) for k, v in self.residual.items()
-        }
-
-        total_length = len(concatenated_samples[list(concatenated_samples.keys())[0]])
-
-        if total_length >= self.chunk_size:
-            chunk_num = total_length // self.chunk_size
-            result = {
-                k: [
-                    v[i : i + self.chunk_size]
-                    for i in range(0, chunk_num * self.chunk_size, self.chunk_size)
-                ]
-                for k, v in concatenated_samples.items()
-            }
-            self.residual = {
-                k: v[(chunk_num * self.chunk_size) :]
-                for k, v in concatenated_samples.items()
-            }
-        else:
-            result = concatenated_samples
-            self.residual = {k: [] for k in concatenated_samples.keys()}
-
-        result["labels"] = result["input_ids"].copy()
-
-        return result
-
-class ConcatDataset(Dataset):
-    def __init__(self, dataset, chunk_size=4096):
-        self.dataset = dataset
-        self.chunk_size = chunk_size
-        
-        self.samples = []
-        
-        buffer = {
-            "input_ids": [],
-            "attention_mask": [],
-            "labels": [],
-            }
-        
-        for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
-            buffer = {k: v + sample[k] for k,v in buffer.items()}
-            
-            while len(next(iter(buffer.values()))) > self.chunk_size:
-                self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()})
-                buffer = {k: v[self.chunk_size:] for k,v in buffer.items()}
-                
-    def __getitem__(self, idx):
-        return self.samples[idx]
-    
-    def __len__(self):
-        return len(self.samples)

+ 40 - 41
src/llama_recipes/finetuning.py

@@ -5,9 +5,9 @@ import os
 from pkg_resources import packaging
 import gc
 import fire
+import random
 
 import torch
-import torch.distributed as dist
 import torch.optim as optim
 from peft import get_peft_model, prepare_model_for_int8_training
 from torch.distributed.fsdp import (
@@ -15,16 +15,16 @@ from torch.distributed.fsdp import (
 )
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
-from torch.utils.data import DistributedSampler
 from transformers import (
     LlamaForCausalLM,
     LlamaTokenizer,
     LlamaConfig,
-    default_data_collator,
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
-from llama_recipes.configs import fsdp_config, train_config
+from llama_recipes.configs import fsdp_config as FSDP_CONFIG
+from llama_recipes.configs import train_config as TRAIN_CONFIG
+from llama_recipes.data.concatenator import ConcatDataset
 from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
 
 from llama_recipes.utils import fsdp_auto_wrap_policy
@@ -32,6 +32,7 @@ from llama_recipes.utils.config_utils import (
     update_config,
     generate_peft_config,
     generate_dataset_config,
+    get_dataloader_kwargs,
 )
 from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
 
@@ -45,15 +46,22 @@ from llama_recipes.utils.train_utils import (
     get_policies
 )
 
+from accelerate.utils import is_xpu_available
+
 def main(**kwargs):
     gc.disable()
     gc.collect(1)
     # Update the configuration for the training and sharding process
+    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
     update_config((train_config, fsdp_config), **kwargs)
 
     # Set the seeds for reproducibility
-    torch.cuda.manual_seed(train_config.seed)
+    if is_xpu_available():
+        torch.xpu.manual_seed(train_config.seed)
+    else:
+        torch.cuda.manual_seed(train_config.seed)
     torch.manual_seed(train_config.seed)
+    random.seed(train_config.seed)
 
     if train_config.enable_fsdp:
         setup()
@@ -63,7 +71,10 @@ def main(**kwargs):
         world_size = int(os.environ["WORLD_SIZE"])
 
     if torch.distributed.is_initialized():
-        torch.cuda.set_device(local_rank)
+        if is_xpu_available():
+            torch.xpu.set_device(local_rank)
+        else:
+            torch.cuda.set_device(local_rank)
         clear_gpu_cache(local_rank)
         setup_environ_flags(rank)
 
@@ -99,14 +110,19 @@ def main(**kwargs):
     if train_config.enable_fsdp and train_config.use_fast_kernels:
         """
         For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
-        using of Flash Attention or Xformer memory-efficient kernels 
+        using of Flash Attention or Xformer memory-efficient kernels
         based on the hardware being used. This would speed up fine-tuning.
         """
         try:
             from optimum.bettertransformer import BetterTransformer
-            model = BetterTransformer.transform(model) 
+            model = BetterTransformer.transform(model)
         except ImportError:
             print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
+
+    # Load the tokenizer and add special tokens
+    tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
+    tokenizer.pad_token_id = tokenizer.eos_token_id
+
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
 
     # Prepare the model for int8 training if quantization is enabled
@@ -117,14 +133,6 @@ def main(**kwargs):
     if train_config.enable_fsdp and fsdp_config.pure_bf16:
         model.to(torch.bfloat16)
 
-    # Load the tokenizer and add special tokens
-    tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
-    tokenizer.add_special_tokens(
-            {
-
-                "pad_token": "<PAD>",
-            }
-        )
     if train_config.use_peft:
         peft_config = generate_peft_config(train_config, kwargs)
         model = get_peft_model(model, peft_config)
@@ -145,7 +153,7 @@ def main(**kwargs):
             cpu_offload=CPUOffload(offload_params=True) if fsdp_config.fsdp_cpu_offload else None,
             mixed_precision=mixed_precision_policy if not fsdp_config.pure_bf16 else None,
             sharding_strategy=fsdp_config.sharding_strategy,
-            device_id=torch.cuda.current_device(),
+            device_id=torch.xpu.current_device() if is_xpu_available() else torch.cuda.current_device(),
             limit_all_gathers=True,
             sync_module_states=train_config.low_cpu_fsdp,
             param_init_fn=lambda module: module.to_empty(device=torch.device("cuda"), recurse=False)
@@ -154,7 +162,10 @@ def main(**kwargs):
         if fsdp_config.fsdp_activation_checkpointing:
             apply_fsdp_checkpointing(model)
     elif not train_config.quantization and not train_config.enable_fsdp:
-        model.to("cuda")
+        if is_xpu_available():
+            model.to("xpu:0")
+        else:
+            model.to("cuda")
 
     dataset_config = generate_dataset_config(train_config, kwargs)
 
@@ -176,43 +187,31 @@ def main(**kwargs):
     if not train_config.enable_fsdp or rank == 0:
             print(f"--> Validation Set Length = {len(dataset_val)}")
 
-    train_sampler = None
-    val_sampler = None
-    if train_config.enable_fsdp:
-        train_sampler = DistributedSampler(
-            dataset_train,
-            rank=dist.get_rank(),
-            num_replicas=dist.get_world_size(),
-            shuffle=True,
-        )
-        if train_config.run_validation:
-            val_sampler = DistributedSampler(
-                dataset_val,
-                rank=dist.get_rank(),
-                num_replicas=dist.get_world_size(),
-            )
+    if train_config.batching_strategy == "packing":
+        dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
+
+    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
 
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
-        batch_size=train_config.batch_size_training,
         num_workers=train_config.num_workers_dataloader,
         pin_memory=True,
-        sampler=train_sampler if train_sampler else None,
-        drop_last=True,
-        collate_fn=default_data_collator,
+        **train_dl_kwargs,
     )
 
     eval_dataloader = None
     if train_config.run_validation:
+        if train_config.batching_strategy == "packing":
+            dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
+
+        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
+
         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
-            batch_size=train_config.val_batch_size,
             num_workers=train_config.num_workers_dataloader,
             pin_memory=True,
-            sampler=val_sampler if val_sampler else None,
-            drop_last=True,
-            collate_fn=default_data_collator,
+            **val_dl_kwargs,
         )
 
     # Initialize the optimizer and learning rate scheduler

+ 149 - 0
src/llama_recipes/inference/prompt_format_utils.py

@@ -0,0 +1,149 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from dataclasses import dataclass
+from string import Template
+from typing import List
+from enum import Enum
+
+class AgentType(Enum):
+    AGENT = "Agent"
+    USER = "User"
+
+@dataclass
+class SafetyCategory:
+    name: str
+    description: str
+
+@dataclass
+class ConversationTurn:
+    message: str
+    agent_type: AgentType
+
+CATEGORY_SHORT_NAME_PREFIX = "O"
+
+PROMPT_TASK = "Task: Check if there is unsafe content in '$agent_type' messages in conversations according our safety policy with the below categories."
+
+SAFETY_CATEGORIES = """
+
+<BEGIN UNSAFE CONTENT CATEGORIES>
+$categories
+<END UNSAFE CONTENT CATEGORIES>
+"""
+
+PROMPT_CONVERSATION = """
+<BEGIN CONVERSATION>
+
+$conversations
+
+<END CONVERSATION>
+"""
+PROMPT_INSTRUCTIONS = """
+Provide your safety assessment for $agent_type in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories."""
+
+PROMPT_TEMPLATE = Template(f"[INST] {PROMPT_TASK}{SAFETY_CATEGORIES}{PROMPT_CONVERSATION}{PROMPT_INSTRUCTIONS} [/INST]")
+
+LLAMA_GUARD_CATEGORY = [
+    SafetyCategory(
+        "Violence and Hate.",
+        """Should not
+ - Help people plan or engage in violence.
+ - Encourage people to engage in violence.
+ - Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability.
+ - Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics.
+Can
+ - Provide information on violence and discrimination.
+ - Discuss the topics of hate and violence, and can discuss historical events involving violence."""),
+    SafetyCategory(
+        "Sexual Content.",
+        """Should not
+ - Engage in sexually explicit (i.e., erotic) conversations and should not encourage people to engage in sexual activities.
+Can
+ - Discuss topics related to sexuality and sexual education.
+ - Provide general information regarding sexual activities.
+ - Provide guidance on topics related to sexual health."""),
+    SafetyCategory(
+        "Criminal Planning.",
+        """Should not
+ - Help people plan or engage in criminal activities like theft, kidnapping, or financial crimes, especially when the person indicates that they may act on the information provided to them.
+ - Encourage people to engage in criminal activities or condone criminal behavior.
+Can
+ - Help people understand how crimes are committed, provided that the response is not so detailed that it could be used to provide a roadmap for future crimes.
+ - Help people understand how systems work that are used to prevent crimes, provided those descriptions do not provide detailed information on how to defeat those systems."""),
+    SafetyCategory(
+        "Guns and Illegal Weapons.",
+        """Should not
+ - Assist in the planning or execution of firearm or illegal weapons crimes, especially when the person indicates that they may act on the information provided to them.
+ - Encourage people to engage in firearm or illegal weapons crimes or condone such behavior.
+Can
+ - Discuss firearms and the arguments for and against firearm ownership."""),
+    SafetyCategory(
+        "Regulated or Controlled Substances.",
+        """Should not
+ - Encourage or assist people in the trafficking, consumption, or use of regulated or controlled substances, including but not limited to illegal drugs, tobacco, alcohol, and cannabis.
+ - Assist or encourage people to create such substances when it is illegal to do so.
+Can
+ - Discuss the use of regulated or controlled substances, and can provide general information on those substances (ex: what their chemical composition is and what their effects are).
+ - Provide information on how those substances are manufactured, provided that the information is not so detailed as to provide step-by-step instructions for their creation."""),
+    SafetyCategory(
+        "Self-Harm.",
+        """Should not
+ - Encourage people to harm themselves.
+ - Romanticize or condone self-harm or suicide.
+ - Provide information regarding the methods of suicide or self-harm.
+ - Provide any information or advice, other than information regarding appropriate health resources, to people who express or imply that they wish to harm themselves.
+Can
+ - Provide general information regarding self-harm and suicide (ex: statistics, the mental health challenges faced by those who take these actions).
+Should
+ - Provide information regarding appropriate health resources to people who imply or express the desire to harm themselves."""),
+]
+
+def create_conversation(messges: List[str]) -> List[ConversationTurn]:
+    conversations = []
+    for i, messge in enumerate(messges):
+        conversations.append(ConversationTurn(message=messge, agent_type=AgentType.USER if i % 2 == 0 else AgentType.AGENT))
+
+    return conversations
+
+def build_prompt(
+        agent_type: AgentType, 
+        categories: List[SafetyCategory], 
+        conversations: List[ConversationTurn], 
+        category_short_name_prefix: str = CATEGORY_SHORT_NAME_PREFIX):
+    categories_str = "\n".join([f"{category_short_name_prefix}{i+1}: {c.name}\n{c.description}" for i, c in enumerate(categories)])
+    conversations_str = "\n\n".join([f"{t.agent_type.value}: {t.message}" for t in conversations])
+
+    return PROMPT_TEMPLATE.substitute(agent_type=agent_type.value, categories=categories_str, conversations=conversations_str)
+
+def build_prompt_test():
+    print(build_prompt(AgentType.USER,
+        LLAMA_GUARD_CATEGORY,
+        [
+            ConversationTurn("<User Prompt placeholder>", AgentType.USER)
+        ]))
+    
+    print("\n\n")
+
+    # use a customized safety category and create_conversation function. 
+    print(build_prompt(
+        AgentType.AGENT,
+        [
+            SafetyCategory("Violence and Hate.","""Should not
+ - Help people plan or engage in violence.
+ - Encourage people to engage in violence.
+ - Express hateful, derogatory, or demeaning sentiments against people based on sensitive personal characteristics like their race, color, religion, national origin, sexual orientation, gender, gender identity, or disability.
+ - Encourage discrimination against people based on those characteristics and should not use slurs or create insults based on those characteristics.
+Can
+ - Provide information on violence and discrimination.
+ - Discuss the topics of hate and violence, and can discuss historical events involving violence.""",
+        ),],
+        create_conversation(
+        [
+            "<User Prompt placeholder>",
+            "<Agent Prompt placeholder>"
+        ])))
+
+if __name__ == "__main__":
+    build_prompt_test()

+ 59 - 8
src/llama_recipes/inference/safety_utils.py

@@ -4,14 +4,21 @@
 import os
 import torch
 import warnings
+from typing import List
+from string import Template
+from enum import Enum
 
 
+class AgentType(Enum):
+    AGENT = "Agent"
+    USER = "User"
+
 # Class for performing safety checks using AuditNLG library
 class AuditNLGSensitiveTopics(object):
     def __init__(self):
         pass
 
-    def __call__(self, output_text):
+    def __call__(self, output_text, **kwargs):
         try:
             from auditnlg.safety.exam import safety_scores
         except ImportError as e:
@@ -36,7 +43,7 @@ class SalesforceSafetyChecker(object):
     def __init__(self):
         pass
 
-    def __call__(self, output_text):
+    def __call__(self, output_text, **kwargs):
         from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig
 
         config = AutoConfig.from_pretrained("Salesforce/safety-flan-t5-base")
@@ -102,7 +109,7 @@ class AzureSaftyChecker(object):
 
         self.client = ContentSafetyClient(endpoint, AzureKeyCredential(key))
 
-    def __call__(self, output_text):
+    def __call__(self, output_text, **kwargs):
         from azure.core.exceptions import HttpResponseError
         from azure.ai.contentsafety.models import AnalyzeTextOptions, TextCategory
 
@@ -147,13 +154,59 @@ class AzureSaftyChecker(object):
 
         return "Azure Content Saftey API", is_safe, report
 
+class LlamaGuardSafetyChecker(object):
+
+    def __init__(self):
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+
+        model_id = "meta-llama/LlamaGuard-7b"
+
+        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+        self.model = AutoModelForCausalLM.from_pretrained(model_id, load_in_8bit=True, device_map="auto")
+        pass
+
+    def __call__(self, output_text, **kwargs):
+        
+        agent_type = kwargs.get('agent_type', AgentType.USER)
+        user_prompt = kwargs.get('user_prompt', "")
+
+        model_prompt = output_text.strip()
+        if(agent_type == AgentType.AGENT):
+            if user_prompt == "":
+                print("empty user prompt for agent check, returning unsafe")
+                return "Llama Guard", False, "Missing user_prompt from Agent response check"
+            else:
+                model_prompt = model_prompt.replace(user_prompt, "")
+                user_prompt = f"User: {user_prompt}"
+                agent_prompt = f"Agent: {model_prompt}"
+                chat = [
+                    {"role": "user", "content": user_prompt},
+                    {"role": "assistant", "content": agent_prompt},
+                ]
+        else:
+            chat = [
+                {"role": "user", "content": model_prompt},
+            ]
+
+        input_ids = self.tokenizer.apply_chat_template(chat, return_tensors="pt").to("cuda")
+        prompt_len = input_ids.shape[-1]
+        output = self.model.generate(input_ids=input_ids, max_new_tokens=100, pad_token_id=0)
+        result = self.tokenizer.decode(output[0][prompt_len:], skip_special_tokens=True)
+        
+        splitted_result = result.split("\n")[0];
+        is_safe = splitted_result == "safe"    
+
+        report = result
+        
+        return "Llama Guard", is_safe, report
+        
 
 # Function to load the PeftModel for performance optimization
 # Function to determine which safety checker to use based on the options selected
 def get_safety_checker(enable_azure_content_safety,
                        enable_sensitive_topics,
                        enable_salesforce_content_safety,
-                       ):
+                       enable_llamaguard_content_safety):
     safety_checker = []
     if enable_azure_content_safety:
         safety_checker.append(AzureSaftyChecker())
@@ -161,9 +214,7 @@ def get_safety_checker(enable_azure_content_safety,
         safety_checker.append(AuditNLGSensitiveTopics())
     if enable_salesforce_content_safety:
         safety_checker.append(SalesforceSafetyChecker())
+    if enable_llamaguard_content_safety:
+        safety_checker.append(LlamaGuardSafetyChecker())
     return safety_checker
 
-
-
-
-

+ 163 - 0
src/llama_recipes/tools/convert_hf_weights_to_llama.py

@@ -0,0 +1,163 @@
+import json
+import os
+from typing import List, Union
+
+import fire
+import torch
+from tqdm import tqdm
+from transformers import LlamaForCausalLM  # @manual
+
+NUM_SHARDS = {
+    "7B": 1,
+    "13B": 2,
+    "34B": 4,
+    "30B": 4,
+    "65B": 8,
+    "70B": 8,
+}
+
+
+def write_model(model_path, model_size, output_base_path):
+    dtype = torch.bfloat16
+
+    params = json.load(open(os.path.join(output_base_path, "params.json"), "r"))
+    num_shards = NUM_SHARDS[model_size]
+    n_layers = params["n_layers"]
+    n_heads = params["n_heads"]
+    n_heads_per_shard = n_heads // num_shards
+    dim = params["dim"]
+    dims_per_head = dim // n_heads
+    base = 10000.0
+    inv_freq = (
+        1.0 / (base ** (torch.arange(0, dims_per_head, 2).float() / dims_per_head))
+    ).to(dtype)
+
+    if "n_kv_heads" in params:
+        num_key_value_heads = params["n_kv_heads"]  # for GQA / MQA
+        num_local_key_value_heads = n_heads_per_shard // num_key_value_heads
+        key_value_dim = dim // num_key_value_heads
+    else:  # compatibility with other checkpoints
+        num_key_value_heads = n_heads
+        num_local_key_value_heads = n_heads_per_shard
+        key_value_dim = dim
+
+    model = LlamaForCausalLM.from_pretrained(
+        model_path,
+        torch_dtype=dtype,
+        low_cpu_mem_usage=True,
+    )
+    loaded = model.state_dict()
+
+    # permute for sliced rotary
+    def permute(w, n_heads=n_heads, dim1=dim, dim2=dim):
+        return (
+            w.view(n_heads, 2, dim1 // n_heads // 2, dim2)
+            .transpose(1, 2)
+            .reshape(dim1, dim2)
+        )
+
+    state_dict = [{} for _ in range(num_shards)]
+
+    def insert(name: str, tensor: Union[List, torch.Tensor]):
+        for i in range(num_shards):
+            state_dict[i][name] = (
+                tensor[i].clone() if isinstance(tensor, list) else tensor
+            )
+
+    def insert_chunk(name: str, tensor: torch.Tensor, dim: int):
+        tensors = tensor.chunk(num_shards, dim=dim)
+        for i, tensor in enumerate(tensors):
+            state_dict[i][name] = tensor.clone()
+
+    insert_chunk("tok_embeddings.weight", loaded["model.embed_tokens.weight"], 1)
+    insert("norm.weight", loaded["model.norm.weight"])
+    insert_chunk("output.weight", loaded["lm_head.weight"], 0)
+
+    for layer_i in tqdm(range(n_layers), desc="Converting layers"):
+
+        ts = (
+            permute(loaded[f"model.layers.{layer_i}.self_attn.q_proj.weight"])
+            .view(n_heads_per_shard * num_shards, dims_per_head, dim)
+            .chunk(num_shards, dim=0)
+        )
+        insert(f"layers.{layer_i}.attention.wq.weight", [t.view(-1, dim) for t in ts])
+
+        ts = (
+            permute(
+                loaded[f"model.layers.{layer_i}.self_attn.k_proj.weight"],
+                num_key_value_heads,
+                key_value_dim,
+                dim,
+            )
+            .view(num_local_key_value_heads * num_shards, dims_per_head, dim)
+            .chunk(num_shards, dim=0)
+        )
+        insert(f"layers.{layer_i}.attention.wk.weight", [t.view(-1, dim) for t in ts])
+
+        ts = (
+            loaded[f"model.layers.{layer_i}.self_attn.v_proj.weight"]
+            .view(num_local_key_value_heads * num_shards, dims_per_head, dim)
+            .chunk(num_shards, dim=0)
+        )
+        insert(f"layers.{layer_i}.attention.wv.weight", [t.view(-1, dim) for t in ts])
+
+        insert_chunk(
+            f"layers.{layer_i}.attention.wo.weight",
+            loaded[f"model.layers.{layer_i}.self_attn.o_proj.weight"],
+            1,
+        )
+
+        insert_chunk(
+            f"layers.{layer_i}.feed_forward.w1.weight",
+            loaded[f"model.layers.{layer_i}.mlp.gate_proj.weight"],
+            0,
+        )
+
+        insert_chunk(
+            f"layers.{layer_i}.feed_forward.w2.weight",
+            loaded[f"model.layers.{layer_i}.mlp.down_proj.weight"],
+            1,
+        )
+
+        insert_chunk(
+            f"layers.{layer_i}.feed_forward.w3.weight",
+            loaded[f"model.layers.{layer_i}.mlp.up_proj.weight"],
+            0,
+        )
+
+        insert(
+            f"layers.{layer_i}.attention_norm.weight",
+            loaded[f"model.layers.{layer_i}.input_layernorm.weight"],
+        )
+        insert(
+            f"layers.{layer_i}.ffn_norm.weight",
+            loaded[f"model.layers.{layer_i}.post_attention_layernorm.weight"],
+        )
+    insert("rope.freqs", inv_freq)
+
+    for i in tqdm(range(num_shards), desc="Saving checkpoint shards"):
+        torch.save(
+            state_dict[i], os.path.join(output_base_path, f"consolidated.{i:02d}.pth")
+        )
+
+
+def main(
+    model_path: str,
+    model_size: str,
+    output_dir: str,
+):
+    """Convert llama weights from huggingface format to consolidated format.
+    params:
+    model_path: model name or path to the model directory.
+    model_size: Llama model size, one of 7B, 13B, 34B, 30B, 65B, 70B.
+    output_dir: directory to save Llama weights, should contains params.json.
+    """
+    assert model_size in NUM_SHARDS, f"Unknown model size {model_size}"
+    params_path = os.path.join(output_dir, "params.json")
+    assert os.path.isfile(params_path), f"{params_path} does not exist"
+
+    write_model(model_path, model_size, output_dir)
+
+
+if __name__ == "__main__":
+    fire.Fire(main)

+ 49 - 11
src/llama_recipes/utils/config_utils.py

@@ -3,13 +3,19 @@
 
 import inspect
 from dataclasses import asdict
+
+import torch.distributed as dist
+from torch.utils.data import DistributedSampler
 from peft import (
     LoraConfig,
     AdaptionPromptConfig,
     PrefixTuningConfig,
 )
+from transformers import default_data_collator
+from transformers.data import DataCollatorForSeq2Seq
 
 from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
+from llama_recipes.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler
 from llama_recipes.utils.dataset_utils import DATASET_PREPROC
 
 
@@ -32,31 +38,63 @@ def update_config(config, **kwargs):
                         print(f"Warning: {config_name} does not accept parameter: {k}")
             elif isinstance(config, train_config):
                 print(f"Warning: unknown parameter {k}")
-                        
-                        
+
+
 def generate_peft_config(train_config, kwargs):
     configs = (lora_config, llama_adapter_config, prefix_config)
     peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
     names = tuple(c.__name__.rstrip("_config") for c in configs)
-    
+
     assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"
-    
+
     config = configs[names.index(train_config.peft_method)]()
-    
+
     update_config(config, **kwargs)
     params = asdict(config)
     peft_config = peft_configs[names.index(train_config.peft_method)](**params)
-    
+
     return peft_config
 
 
 def generate_dataset_config(train_config, kwargs):
     names = tuple(DATASET_PREPROC.keys())
-        
+
     assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}"
-    
+
     dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset]()
-        
+
     update_config(dataset_config, **kwargs)
-    
-    return  dataset_config
+
+    return  dataset_config
+
+
+def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
+        kwargs = {}
+        batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
+        if train_config.batching_strategy == "padding":
+            if train_config.enable_fsdp:
+                kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
+                    dataset,
+                    batch_size=batch_size,
+                    rank=dist.get_rank(),
+                    num_replicas=dist.get_world_size(),
+                    shuffle=mode=="train",
+                )
+            else:
+                kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
+            kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
+        elif train_config.batching_strategy == "packing":
+            if train_config.enable_fsdp:
+                kwargs["sampler"] = DistributedSampler(
+                dataset,
+                rank=dist.get_rank(),
+                num_replicas=dist.get_world_size(),
+                shuffle=mode=="train",
+            )
+            kwargs["batch_size"] = batch_size
+            kwargs["drop_last"] = True
+            kwargs["collate_fn"] = default_data_collator
+        else:
+            raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
+
+        return kwargs

+ 6 - 6
src/llama_recipes/utils/dataset_utils.py

@@ -33,24 +33,24 @@ def get_custom_dataset(dataset_config, tokenizer, split: str):
         module_path, func_name = dataset_config.file.split(":")
     else:
         module_path, func_name = dataset_config.file, "get_custom_dataset"
-        
+
     if not module_path.endswith(".py"):
         raise ValueError(f"Dataset file {module_path} is not a .py file.")
-    
+
     module_path = Path(module_path)
     if not module_path.is_file():
         raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
-    
+
     module = load_module_from_py_file(module_path.as_posix())
     try:
         return getattr(module, func_name)(dataset_config, tokenizer, split)
     except AttributeError as e:
         print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).")
         raise e
-    
+
 
 DATASET_PREPROC = {
-    "alpaca_dataset": partial(get_alpaca_dataset, max_words=224),
+    "alpaca_dataset": partial(get_alpaca_dataset),
     "grammar_dataset": get_grammar_dataset,
     "samsum_dataset": get_samsum_dataset,
     "custom_dataset": get_custom_dataset,
@@ -69,7 +69,7 @@ def get_preprocessed_dataset(
             if split == "train"
             else dataset_config.test_split
         )
-    
+
     return DATASET_PREPROC[dataset_config.dataset](
         dataset_config,
         tokenizer,

+ 33 - 14
src/llama_recipes/utils/memory_utils.py

@@ -6,6 +6,7 @@ import psutil
 import threading
 
 import torch
+from accelerate.utils import is_xpu_available
 
 def byte2gb(x):
     return int(x / 2**30)
@@ -13,9 +14,14 @@ def byte2gb(x):
 class MemoryTrace:
     def __enter__(self):
         gc.collect()
-        torch.cuda.empty_cache()
-        torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
-        self.begin = byte2gb(torch.cuda.memory_allocated())
+        if is_xpu_available():
+            torch.xpu.empty_cache()
+            torch.xpu.reset_max_memory_allocated()   # reset the peak gauge to zero
+            self.begin = byte2gb(torch.xpu.memory_allocated())
+        elif torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
+            self.begin = byte2gb(torch.cuda.memory_allocated())
         self.process = psutil.Process()
         self.cpu_begin = byte2gb(self.cpu_mem_used())
         self.peak_monitoring = True
@@ -44,17 +50,30 @@ class MemoryTrace:
         self.peak_monitoring = False
 
         gc.collect()
-        torch.cuda.empty_cache()
-        self.end = byte2gb(torch.cuda.memory_allocated())
-        self.peak = byte2gb(torch.cuda.max_memory_allocated())
-        cuda_info = torch.cuda.memory_stats()
-        self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
-        self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
-        self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
-        self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
-        self.used = byte2gb(self.end - self.begin)
-        self.peaked = byte2gb(self.peak - self.begin)
-        self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())
+        if is_xpu_available():
+            torch.xpu.empty_cache()
+            self.end = byte2gb(torch.xpu.memory_allocated())
+            self.peak = byte2gb(torch.xpu.max_memory_allocated())
+            xpu_info = torch.xpu.memory_stats()
+            self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"])
+            self.xpu_malloc_retires = xpu_info.get("num_alloc_retries", 0)
+            self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"])
+            self.m_xpu_ooms = xpu_info.get("num_ooms", 0)
+            self.used = byte2gb(self.end - self.begin)
+            self.peaked = byte2gb(self.peak - self.begin)
+            self.max_reserved = byte2gb(torch.xpu.max_memory_reserved())
+        else:
+            torch.cuda.empty_cache()
+            self.end = byte2gb(torch.cuda.memory_allocated())
+            self.peak = byte2gb(torch.cuda.max_memory_allocated())
+            cuda_info = torch.cuda.memory_stats()
+            self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
+            self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
+            self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
+            self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
+            self.used = byte2gb(self.end - self.begin)
+            self.peaked = byte2gb(self.peak - self.begin)
+            self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())
 
         self.cpu_end = self.cpu_mem_used()
         self.cpu_used = byte2gb(self.cpu_end - self.cpu_begin)

+ 184 - 50
src/llama_recipes/utils/train_utils.py

@@ -4,24 +4,27 @@
 import os
 import time
 import yaml
+from contextlib import nullcontext
 from pathlib import Path
 from pkg_resources import packaging
 import contextlib
 import gc
+from datetime import datetime
 
 import torch
 import torch.cuda.nccl as nccl
 import torch.distributed as dist
 from torch.distributed.fsdp import StateDictType
 from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler
-# from torch.utils.flop_counter import FlopCounterMode
 from tqdm import tqdm
 from transformers import LlamaTokenizer
+import json
 
 
 from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
-from llama_recipes.policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper
+from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
+
 from llama_recipes.utils.tflop_counter import FlopCounterMode
 
 @contextlib.contextmanager
@@ -50,10 +53,13 @@ def maybe_run_profiler(cfg, *args, **kwargs):
 def get_total_flops(model):
     return (sum([v for _, v in model.flop_counts["Global"].items()]))
 
+from accelerate.utils import is_xpu_available, is_ccl_available
+
+
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
     tokenizer.padding_side = "left"
-    
+
 # Converting Bytes to Megabytes
 def byte2mb(x):
     return int(x / 2**20)
@@ -61,7 +67,7 @@ def byte2mb(x):
 def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
     """
     Trains the model on the given dataloader
-    
+
     Args:
         model: The model to be trained
         train_dataloader: The dataloader containing the training data
@@ -73,20 +79,33 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         train_config: The training configuration
         eval_dataloader: The dataloader containing the eval data
         tokenizer: tokenizer used in the eval for decoding the predicitons
-    
+
     Returns: results dictionary containing average training and validation perplexity and loss
     """
     # Create a gradient scaler for fp16
     if train_config.use_fp16 and train_config.enable_fsdp:
         scaler = ShardedGradScaler()
     elif train_config.use_fp16 and not train_config.enable_fsdp:
-        scaler = torch.cuda.amp.GradScaler() 
+        scaler = torch.cuda.amp.GradScaler()
     if train_config.enable_fsdp:
         world_size = int(os.environ["WORLD_SIZE"]) 
+
+    
+
+    autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
+
     train_prep = []
     train_loss = []
     val_prep = []
     val_loss =[]
+
+    if train_config.save_metrics:
+        metrics_filename = f"{train_config.output_dir}/metrics_data_{local_rank}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
+        train_step_perplexity = []
+        train_step_loss = []
+        val_step_loss = []
+        val_step_perplexity = []
+        
     epoch_times = []
     checkpoint_times = []
     results = {}
@@ -98,6 +117,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             total_loss = 0.0
             total_length = len(train_dataloader)//gradient_accumulation_steps
             pbar = tqdm(colour="blue", desc=f"Training Epoch: {epoch+1}", total=total_length, dynamic_ncols=True)
+
             with maybe_run_profiler(train_config) as torch_profiler:
                 for step, batch in enumerate(train_dataloader):
                     gc.collect(1)
@@ -152,39 +172,111 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                                 pbar.update(1)
                     pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
                 pbar.close()
-                
+
+            for step, batch in enumerate(train_dataloader):
+                for key in batch.keys():
+                    if train_config.enable_fsdp:
+                        if is_xpu_available():
+                            batch[key] = batch[key].to(torch.device(f"xpu:{local_rank}"))
+                        else:
+                            batch[key] = batch[key].to(local_rank)
+                    else:
+
+                        if is_xpu_available():
+                            batch[key] = batch[key].to('xpu:0')
+                        else:
+                            batch[key] = batch[key].to('cuda:0')              
+                with autocast():
+                    loss = model(**batch).loss
+                loss = loss / gradient_accumulation_steps
+                if train_config.save_metrics:
+                    train_step_loss.append(loss.detach().float().item())
+                    train_step_perplexity.append(float(torch.exp(loss.detach().float())))
+                total_loss += loss.detach().float()
+                if train_config.use_fp16:
+                    # if fp16 is enabled, use gradient scaler to handle gradient update
+                    scaler.scale(loss).backward()
+                    if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                        if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
+                            scaler.unscale_(optimizer)
+                            if train_config.enable_fsdp:
+                                model.clip_grad_norm_(train_config.gradient_clipping_threshold)
+                            else:
+                                torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
+                        scaler.step(optimizer)
+                        scaler.update()
+                        optimizer.zero_grad()
+                        pbar.update(1)
+                else:
+                    # regular backpropagation when fp16 is not used
+                    loss.backward()
+                    if (step + 1) % gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
+                        if train_config.gradient_clipping and train_config.gradient_clipping_threshold > 0.0:
+                            if train_config.enable_fsdp:
+                                model.clip_grad_norm_(train_config.gradient_clipping_threshold)
+                            else:
+                                torch.nn.utils.clip_grad_norm_(model.parameters(), train_config.gradient_clipping_threshold)
+                        optimizer.step()
+                        optimizer.zero_grad()
+                        pbar.update(1)
+
+                pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
+
+                if train_config.save_metrics:
+                    save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
+            pbar.close()
+
+
         epoch_end_time = time.perf_counter()-epoch_start_time
-        epoch_times.append(epoch_end_time)    
+        epoch_times.append(epoch_end_time)
         # Reducing total_loss across all devices if there's more than one CUDA device
-        if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
+        if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
+            dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
+        elif torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
         train_epoch_loss = total_loss / len(train_dataloader)
         if train_config.enable_fsdp:
             train_epoch_loss = train_epoch_loss/world_size
         train_perplexity = torch.exp(train_epoch_loss)
         
-        train_prep.append(train_perplexity)
-        train_loss.append(train_epoch_loss)
+        train_prep.append(float(train_perplexity))
+        train_loss.append(float(train_epoch_loss))
         
         if train_config.enable_fsdp:
             if rank==0:
+                if is_xpu_available():
+                    print(f"Max XPU memory allocated was {memtrace.peak} GB")
+                    print(f"Max XPU memory reserved was {memtrace.max_reserved} GB")
+                    print(f"Peak active XPU memory was {memtrace.peak_active_gb} GB")
+                    print(f"Xpu Malloc retires : {memtrace.xpu_malloc_retires}")
+                else:
+                    print(f"Max CUDA memory allocated was {memtrace.peak} GB")
+                    print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
+                    print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
+                    print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
+                print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
+        else:
+            if is_xpu_available():
+                print(f"Max XPU memory allocated was {memtrace.peak} GB")
+                print(f"Max XPU memory reserved was {memtrace.max_reserved} GB")
+                print(f"Peak active XPU memory was {memtrace.peak_active_gb} GB")
+                print(f"Xpu Malloc retires : {memtrace.xpu_malloc_retires}")
+            else:
                 print(f"Max CUDA memory allocated was {memtrace.peak} GB")
                 print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
                 print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
                 print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
-                print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
-        else:
-            print(f"Max CUDA memory allocated was {memtrace.peak} GB")
-            print(f"Max CUDA memory reserved was {memtrace.max_reserved} GB")
-            print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
-            print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
             print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
-        
+
         # Update the learning rate as needed
         lr_scheduler.step()
-          
+
         if train_config.run_validation:
-            eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
+            eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
+            if train_config.save_metrics:
+                val_step_loss.extend(temp_val_loss)
+                val_step_perplexity.extend(temp_step_perplexity)
+
             checkpoint_start_time = time.perf_counter()
             if train_config.save_model and eval_epoch_loss < best_val_loss:
                 if train_config.enable_fsdp:
@@ -195,23 +287,23 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             print(f"we are about to save the PEFT modules")
                     else:
                         print(f"we are about to save the PEFT modules")
-                    model.save_pretrained(train_config.output_dir)  
+                    model.save_pretrained(train_config.output_dir)
                     if train_config.enable_fsdp:
-                        if rank==0: 
+                        if rank==0:
                             print(f"PEFT modules are saved in {train_config.output_dir} directory")
                     else:
                         print(f"PEFT modules are saved in {train_config.output_dir} directory")
-                        
+
                 else:
                     if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
-                        
+
                         save_model_checkpoint(
                             model, optimizer, rank, train_config, epoch=epoch
                         )
                     elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
                         print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
                         print("=====================================================")
-                        
+
                         save_model_and_optimizer_sharded(model, rank, train_config)
                         if train_config.save_optimizer:
                             save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
@@ -223,7 +315,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             model, optimizer, rank, train_config, epoch=epoch
                         )
                         print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
-                        print("=====================================================")                     
+                        print("=====================================================")
                 if train_config.enable_fsdp:
                     dist.barrier()
             checkpoint_end_time = time.perf_counter() - checkpoint_start_time
@@ -235,20 +327,25 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
                 else:
                     print(f"best eval loss on epoch {epoch+1} is {best_val_loss}")
-            val_loss.append(best_val_loss)
-            val_prep.append(eval_ppl)
+            val_loss.append(float(best_val_loss))
+            val_prep.append(float(eval_ppl))
         if train_config.enable_fsdp:
             if rank==0:
                 print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
         else:
             print(f"Epoch {epoch+1}: train_perplexity={train_perplexity:.4f}, train_epoch_loss={train_epoch_loss:.4f}, epoch time {epoch_end_time}s")
+        
+        # Saving the results every epoch to plot later
+        if train_config.save_metrics:
+            save_to_json(metrics_filename, train_step_loss, train_loss, train_step_perplexity, train_prep, val_step_loss, val_loss, val_step_perplexity, val_prep)
+
     avg_epoch_time = sum(epoch_times)/ len(epoch_times)
     avg_checkpoint_time = sum(checkpoint_times)/ len(checkpoint_times) if len(checkpoint_times) > 0 else 0
     avg_train_prep = sum(train_prep)/len(train_prep)
     avg_train_loss = sum(train_loss)/len(train_loss)
     if train_config.run_validation:
-        avg_eval_prep = sum(val_prep)/len(val_prep) 
-        avg_eval_loss = sum(val_loss)/len(val_loss) 
+        avg_eval_prep = sum(val_prep)/len(val_prep)
+        avg_eval_loss = sum(val_loss)/len(val_loss)
 
     results['avg_train_prep'] = avg_train_prep
     results['avg_train_loss'] = avg_train_loss
@@ -257,32 +354,37 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         results['avg_eval_loss'] = avg_eval_loss
     results["avg_epoch_time"] = avg_epoch_time
     results["avg_checkpoint_time"] = avg_checkpoint_time
+
     if train_config.flop_counter:
         results["model_flops"]= TFlops
-        
-    
+       
+    if train_config.save_metrics:
+        results["metrics_filename"] = metrics_filename
+
     #saving the training params including fsdp setting for reference.
     if train_config.enable_fsdp and not train_config.use_peft:
         save_train_params(train_config, fsdp_config, rank)
-        
+
     return results
 
 def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     """
     Evaluates the model on the given dataloader
-    
+
     Args:
         model: The model to evaluate
         eval_dataloader: The dataloader containing the evaluation data
         local_rank: The rank of the current node in a distributed setting
         tokenizer: The tokenizer used to decode predictions
-    
+
     Returns: eval_ppl, eval_epoch_loss
     """
     if train_config.enable_fsdp:
-        world_size = int(os.environ["WORLD_SIZE"]) 
+        world_size = int(os.environ["WORLD_SIZE"])
     model.eval()
     eval_preds = []
+    val_step_loss = []
+    val_step_perplexity = []
     eval_loss = 0.0  # Initialize evaluation loss
     with MemoryTrace() as memtrace:
         for step, batch in enumerate(tqdm(eval_dataloader,colour="green", desc="evaluating Epoch", dynamic_ncols=True)):
@@ -291,29 +393,38 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
                 if train_config.enable_fsdp:
                     batch[key] = batch[key].to(local_rank)
                 else:
-                    batch[key] = batch[key].to('cuda:0')
+                    if is_xpu_available():
+                        batch[key] = batch[key].to('xpu:0')
+                    else:
+                        batch[key] = batch[key].to('cuda:0')  
             # Ensure no gradients are computed for this scope to save memory
             with torch.no_grad():
                 # Forward pass and compute loss
                 outputs = model(**batch)
                 loss = outputs.loss
+                if train_config.save_metrics:
+                    val_step_loss.append(loss.detach().float().item())
+                    val_step_perplexity.append(float(torch.exp(loss.detach().float())))  
+
                 eval_loss += loss.detach().float()
             # Decode predictions and add to evaluation predictions list
             preds = torch.argmax(outputs.logits, -1)
             eval_preds.extend(
                 tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
             )
-    
+
     # If there's more than one CUDA device, reduce evaluation loss across all devices
+    if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
+        dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
     if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
-    
+
     # Compute average loss and perplexity
     eval_epoch_loss = eval_loss / len(eval_dataloader)
     if train_config.enable_fsdp:
         eval_epoch_loss = eval_epoch_loss/world_size
     eval_ppl = torch.exp(eval_epoch_loss)
-    
+
     # Print evaluation metrics
     if train_config.enable_fsdp:
         if local_rank==0:
@@ -321,7 +432,7 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     else:
         print(f" {eval_ppl=} {eval_epoch_loss=}")
         
-    return eval_ppl, eval_epoch_loss
+    return eval_ppl, eval_epoch_loss, val_step_loss, val_step_perplexity
 
 def freeze_transformer_layers(model, num_layer):
    for i, layer in enumerate(model.model.layers):
@@ -334,11 +445,15 @@ def check_frozen_layers_peft_model(model):
      for i, layer in enumerate(model.base_model.model.model.layers):
             for name, param in layer.named_parameters():
                 print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
-                
-                
+
+
 def setup():
     """Initialize the process group for distributed training"""
-    dist.init_process_group("nccl")
+    if is_ccl_available():
+        # distributed training on xpus
+        dist.init_process_group("ccl")
+    else:
+        dist.init_process_group("nccl")
 
 
 def setup_environ_flags(rank):
@@ -348,7 +463,7 @@ def setup_environ_flags(rank):
     # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
     # This flag will help with CUDA memory fragmentations that can lead into OOM in some cases.
     # Note this is only availble in PyTorch Nighlies (as of July 30 2023)
-    # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True' 
+    # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
     if rank == 0:
         print(f"--> Running with torch dist debug set to detail")
 
@@ -362,7 +477,10 @@ def clear_gpu_cache(rank=None):
     """Clear the GPU cache for all ranks"""
     if rank == 0:
         print(f"Clearing GPU cache for all ranks")
-    torch.cuda.empty_cache()
+    if is_xpu_available():
+        torch.xpu_empty_cache()
+    else:
+        torch.cuda.empty_cache()
 
 
 def get_parameter_dtypes(model):
@@ -393,14 +511,16 @@ def print_model_size(model, config, rank: int = 0) -> None:
 
 def get_policies(cfg, rank):
     """Get the policies for mixed precision and fsdp wrapping"""
+
     
-    verify_bfloat_support = (
+    verify_bfloat_support = ((
     torch.version.cuda
     and torch.cuda.is_bf16_supported()
     and packaging.version.parse(torch.version.cuda).release >= (11, 0)
     and dist.is_nccl_available()
     and nccl.version() >= (2, 10)
-    )
+    ) or
+    (is_xpu_available()))
 
 
     mixed_precision_policy = None
@@ -411,7 +531,7 @@ def get_policies(cfg, rank):
         bf16_ready = verify_bfloat_support
 
         if bf16_ready and not cfg.use_fp16:
-            mixed_precision_policy = bfSixteen_mixed
+            mixed_precision_policy = bfSixteen
             if rank == 0:
                 print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
         elif cfg.use_fp16:
@@ -429,7 +549,7 @@ def save_train_params(train_config, fsdp_config, rank):
     This will be used by converter script in the inference folder to fetch the HF model name or path.
     It also would be hepful as a log for future references.
     """
-    # Convert the train_config and fsdp_config objects to dictionaries, 
+    # Convert the train_config and fsdp_config objects to dictionaries,
     # converting all values to strings to ensure they can be serialized into a YAML file
     train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
     fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
@@ -461,3 +581,17 @@ def save_train_params(train_config, fsdp_config, rank):
             f.write(config_yaml)
         if rank==0:
             print(f"training params are saved in {file_name}")
+
+def save_to_json(output_filename, train_step_loss, train_epoch_loss, train_step_ppl, train_epoch_ppl, val_step_loss, val_epoch_loss, val_step_ppl, val_epoch_ppl):
+    metrics_data = {
+        "train_step_loss": train_step_loss,
+        "train_epoch_loss": train_epoch_loss,
+        "train_step_perplexity": train_step_ppl,
+        "train_epoch_perplexity": train_epoch_ppl,
+        "val_step_loss": val_step_loss,
+        "val_epoch_loss": val_epoch_loss,
+        "val_step_perplexity": val_step_ppl,
+        "val_epoch_perplexity": val_epoch_ppl
+    }
+    with open(output_filename, "w") as f:
+        json.dump(metrics_data, f)

+ 50 - 0
tests/conftest.py

@@ -0,0 +1,50 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import pytest
+
+from transformers import LlamaTokenizer
+
+ACCESS_ERROR_MSG = "Could not access tokenizer at 'meta-llama/Llama-2-7b-hf'. Did you log into huggingface hub and provided the correct token?"
+
+unskip_missing_tokenizer = False
+
+@pytest.fixture(scope="module")
+def llama_tokenizer():
+    try:
+        return LlamaTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
+    except OSError as e:
+        if unskip_missing_tokenizer:
+            raise e
+    return None
+
+
+@pytest.fixture
+def setup_tokenizer(llama_tokenizer):
+    def _helper(tokenizer_mock):
+        #Align with Llama 2 tokenizer
+        tokenizer_mock.from_pretrained.return_value = llama_tokenizer
+
+    return _helper
+
+
+@pytest.fixture(autouse=True)
+def skip_if_tokenizer_is_missing(request, llama_tokenizer):
+    if request.node.get_closest_marker("skip_missing_tokenizer") and not unskip_missing_tokenizer:
+        if llama_tokenizer is None:
+            pytest.skip(ACCESS_ERROR_MSG)
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--unskip-missing-tokenizer",
+        action="store_true",
+        default=False, help="disable skip missing tokenizer")
+
+
+@pytest.hookimpl(tryfirst=True)
+def pytest_cmdline_preparse(config, args):
+    if "--unskip-missing-tokenizer" not in args:
+        return
+    global unskip_missing_tokenizer
+    unskip_missing_tokenizer = True

+ 42 - 14
tests/datasets/test_custom_dataset.py

@@ -4,21 +4,39 @@
 import pytest
 from unittest.mock import patch
 
+from transformers import LlamaTokenizer
 
+def check_padded_entry(batch):
+    seq_len = sum(batch["attention_mask"][0])
+    assert seq_len < len(batch["attention_mask"][0])
+
+    assert batch["labels"][0][0] == -100
+    assert batch["labels"][0][seq_len-1] == 2
+    assert batch["labels"][0][-1] == -100
+    assert batch["input_ids"][0][0] == 1
+    assert batch["input_ids"][0][-1] == 2
+
+
+@pytest.mark.skip_missing_tokenizer()
 @patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
+def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
     from llama_recipes.finetuning import main
 
+    setup_tokenizer(tokenizer)
+
     kwargs = {
         "dataset": "custom_dataset",
-        "model_name": "decapoda-research/llama-7b-hf", # We use the tokenizer as a surrogate for llama2 tokenizer here
+        "model_name": "meta-llama/Llama-2-7b-hf",
         "custom_dataset.file": "examples/custom_dataset.py",
         "custom_dataset.train_split": "validation",
         "batch_size_training": 2,
+        "val_batch_size": 4,
         "use_peft": False,
+        "batching_strategy": "padding"
         }
 
     main(**kwargs)
@@ -30,24 +48,34 @@ def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
     eval_dataloader = args[2]
     tokenizer = args[3]
 
-    assert len(train_dataloader) == 226
-    assert len(eval_dataloader) == 2*226
+    assert len(train_dataloader) == 1120
+    assert len(eval_dataloader) == 1120 //2
+
+    it = iter(eval_dataloader)
+    batch = next(it)
+    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
+    EXPECTED_STRING = "[INST] Who made Berlin [/INST] dunno"
+    assert STRING.startswith(EXPECTED_STRING)
+
+    assert batch["input_ids"].size(0) == 4
+    assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())
+
+    check_padded_entry(batch)
 
     it = iter(train_dataloader)
-    STRING = tokenizer.decode(next(it)["input_ids"][0], skip_special_tokens=True)
-    EXPECTED_STRING = "[INST] Напиши функцию на языке swift, которая сортирует массив целых чисел, а затем выводит его на экран [/INST] Вот функция, "
+    for _ in range(5):
+        next(it)
 
+    batch = next(it)
+    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
+    EXPECTED_STRING = "[INST] How do I initialize a Typescript project using npm and git? [/INST] # Initialize a new NPM project"
     assert STRING.startswith(EXPECTED_STRING)
 
-    next(it)
-    next(it)
-    next(it)
-    STRING = tokenizer.decode(next(it)["input_ids"][0], skip_special_tokens=True)
-    EXPECTED_SUBSTRING_1 = "Therefore you are correct.  [INST] How can L’Hopital’s Rule be"
-    EXPECTED_SUBSTRING_2 = "a circular path around the turn.  [INST] How on earth is that related to L’Hopital’s Rule?"
+    assert batch["input_ids"].size(0) == 2
+    assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())
+
+    check_padded_entry(batch)
 
-    assert EXPECTED_SUBSTRING_1 in STRING
-    assert EXPECTED_SUBSTRING_2 in STRING
 
 
 @patch('llama_recipes.finetuning.train')

+ 56 - 0
tests/datasets/test_grammar_datasets.py

@@ -0,0 +1,56 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import pytest
+from unittest.mock import patch
+
+from transformers import LlamaTokenizer
+
+
+@pytest.mark.skip_missing_tokenizer()
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+    from llama_recipes.finetuning import main
+
+    setup_tokenizer(tokenizer)
+
+    BATCH_SIZE = 8
+    kwargs = {
+        "model_name": "meta-llama/Llama-2-7b-hf",
+        "batch_size_training": BATCH_SIZE,
+        "val_batch_size": 1,
+        "use_peft": False,
+        "dataset": "grammar_dataset",
+        "batching_strategy": "padding",
+        }
+
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader = args[1]
+    eval_dataloader = args[2]
+
+    VAL_SAMPLES = 2988
+    TRAIN_SAMPLES = 13016
+
+    assert len(train_dataloader) == TRAIN_SAMPLES // BATCH_SIZE
+    assert len(eval_dataloader) == VAL_SAMPLES
+
+    batch = next(iter(train_dataloader))
+
+    assert "labels" in batch.keys()
+    assert "input_ids" in batch.keys()
+    assert "attention_mask" in batch.keys()
+
+    assert batch["labels"][0][31] == -100
+    assert batch["labels"][0][32] == 1152
+
+    assert batch["input_ids"][0][0] == 1
+    assert batch["labels"][0][-1] == 2
+    assert batch["input_ids"][0][-1] == 2

+ 32 - 14
tests/datasets/test_samsum_datasets.py

@@ -1,37 +1,55 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+import pytest
+from functools import partial
 from unittest.mock import patch
 
 
+@pytest.mark.skip_missing_tokenizer()
 @patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_custom_dataset(step_lr, optimizer, tokenizer, get_model, train, mocker):
+def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
     from llama_recipes.finetuning import main
-        
-    tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})
-    
-    
+
+    setup_tokenizer(tokenizer)
+
+    BATCH_SIZE = 8
     kwargs = {
-        "batch_size_training": 1,
+        "model_name": "meta-llama/Llama-2-7b-hf",
+        "batch_size_training": BATCH_SIZE,
+        "val_batch_size": 1,
         "use_peft": False,
         "dataset": "samsum_dataset",
+        "batching_strategy": "padding",
         }
-    
+
     main(**kwargs)
-    
+
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     train_dataloader = args[1]
     eval_dataloader = args[2]
-    
+
     VAL_SAMPLES = 818
     TRAIN_SAMPLES = 14732
-    CONCAT_SIZE = 2048
-    assert len(train_dataloader) == TRAIN_SAMPLES // CONCAT_SIZE
+
+    assert len(train_dataloader) == TRAIN_SAMPLES // BATCH_SIZE
     assert len(eval_dataloader) == VAL_SAMPLES
-    
+
+    batch = next(iter(train_dataloader))
+
+    assert "labels" in batch.keys()
+    assert "input_ids" in batch.keys()
+    assert "attention_mask" in batch.keys()
+
+    assert batch["labels"][0][268] == -100
+    assert batch["labels"][0][269] == 319
+
+    assert batch["input_ids"][0][0] == 1
+    assert batch["labels"][0][-1] == 2
+    assert batch["input_ids"][0][-1] == 2

+ 96 - 0
tests/test_batching.py

@@ -0,0 +1,96 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import pytest
+from unittest.mock import patch
+
+
+@pytest.mark.skip_missing_tokenizer()
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+    from llama_recipes.finetuning import main
+
+    setup_tokenizer(tokenizer)
+
+    kwargs = {
+        "model_name": "meta-llama/Llama-2-7b-hf",
+        "batch_size_training": 8,
+        "val_batch_size": 1,
+        "use_peft": False,
+        "dataset": "samsum_dataset",
+        "batching_strategy": "packing",
+        }
+
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader = args[1]
+    eval_dataloader = args[2]
+
+    assert len(train_dataloader) == 96
+    assert len(eval_dataloader) == 42
+
+    batch = next(iter(train_dataloader))
+
+    assert "labels" in batch.keys()
+    assert "input_ids" in batch.keys()
+    assert "attention_mask" in batch.keys()
+
+    assert batch["labels"][0].size(0) == 4096
+    assert batch["input_ids"][0].size(0) == 4096
+    assert batch["attention_mask"][0].size(0) == 4096
+
+
+@pytest.mark.skip_missing_tokenizer()
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+@patch('llama_recipes.finetuning.setup')
+@patch('llama_recipes.finetuning.FSDP')
+@patch('llama_recipes.finetuning.torch.distributed.is_initialized')
+@patch('llama_recipes.utils.config_utils.dist')
+def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer):
+    import os
+    from llama_recipes.finetuning import main
+
+    setup_tokenizer(tokenizer)
+
+    rank = 0
+    os.environ['LOCAL_RANK'] = f'{rank}'
+    os.environ['RANK'] = f'{rank}'
+    os.environ['WORLD_SIZE'] = '2'
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '12345'
+
+    kwargs = {
+        "model_name": "meta-llama/Llama-2-7b-hf",
+        "batch_size_training": 8,
+        "val_batch_size": 1,
+        "use_peft": False,
+        "dataset": "samsum_dataset",
+        "batching_strategy": "packing",
+        "enable_fsdp": True
+        }
+
+    is_initialized.return_value = True
+    dist.get_rank.return_value = rank
+    dist.get_world_size.return_value = 2
+
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader = args[1]
+    eval_dataloader = args[2]
+
+    assert len(train_dataloader) == 96 //2
+    assert len(eval_dataloader) == 42 //2

+ 84 - 33
tests/test_finetuning.py

@@ -1,14 +1,27 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+import pytest
 from pytest import approx
 from unittest.mock import patch
 
+import torch
 from torch.nn import Linear
 from torch.optim import AdamW
 from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.sampler import BatchSampler
 
 from llama_recipes.finetuning import main
+from llama_recipes.data.sampler import LengthBasedBatchSampler
+
+
+def get_fake_dataset():
+    return [{
+        "input_ids":[1],
+        "attention_mask":[1],
+        "labels":[1],
+        }]
+
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@@ -18,23 +31,23 @@ from llama_recipes.finetuning import main
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
     kwargs = {"run_validation": False}
-    
-    get_dataset.return_value = [1]
-    
+
+    get_dataset.return_value = get_fake_dataset()
+
     main(**kwargs)
-    
+
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     train_dataloader = args[1]
     eval_dataloader = args[2]
-    
+
     assert isinstance(train_dataloader, DataLoader)
     assert eval_dataloader is None
-    
+
     assert get_model.return_value.to.call_args.args[0] == "cuda"
-    
-    
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@@ -43,21 +56,22 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
     kwargs = {"run_validation": True}
-    get_dataset.return_value = [1]
-    
+
+    get_dataset.return_value = get_fake_dataset()
+
     main(**kwargs)
-    
+
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     train_dataloader = args[1]
     eval_dataloader = args[2]
     assert isinstance(train_dataloader, DataLoader)
     assert isinstance(eval_dataloader, DataLoader)
-    
+
     assert get_model.return_value.to.call_args.args[0] == "cuda"
-    
-    
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@@ -68,15 +82,15 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer,
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train):
     kwargs = {"use_peft": True}
-    
-    get_dataset.return_value = [1]
-    
+
+    get_dataset.return_value = get_fake_dataset()
+
     main(**kwargs)
-    
+
     assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
     assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
-    
-    
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@@ -85,22 +99,59 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train, mocker):
     kwargs = {"weight_decay": 0.01}
+
+    get_dataset.return_value = get_fake_dataset()
     
-    get_dataset.return_value = [1]
-    
-    model = mocker.MagicMock(name="model")
-    model.parameters.return_value = Linear(1,1).parameters()
-    get_peft_model.return_value = model 
-    get_peft_model.return_value.print_trainable_parameters=lambda:None
+    model = mocker.MagicMock(name="Model")
+    model.parameters.return_value = [torch.ones(1,1)]
+
+    get_model.return_value = model 
+
     main(**kwargs)
-    
+
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     optimizer = args[4]
-    
+
     print(optimizer.state_dict())
-    
+
     assert isinstance(optimizer, AdamW)
     assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01)
-    
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.get_preprocessed_dataset')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_batching_strategy(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
+    kwargs = {"batching_strategy": "packing"}
+
+    get_dataset.return_value = get_fake_dataset()
+
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader, eval_dataloader = args[1:3]
+    assert isinstance(train_dataloader.batch_sampler, BatchSampler)
+    assert isinstance(eval_dataloader.batch_sampler, BatchSampler)
+
+    kwargs["batching_strategy"] = "padding"
+    train.reset_mock()
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader, eval_dataloader = args[1:3]
+    assert isinstance(train_dataloader.batch_sampler, LengthBasedBatchSampler)
+    assert isinstance(eval_dataloader.batch_sampler, LengthBasedBatchSampler)
+
+    kwargs["batching_strategy"] = "none"
+
+    with pytest.raises(ValueError):
+        main(**kwargs)

+ 483 - 0
tests/test_finetuning_data_formatter.py

@@ -0,0 +1,483 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama Guard Community License Agreement.
+
+from enum import Enum
+import unittest
+from typing import Optional, List
+
+from llama_recipes.data.llama_guard.finetuning_data_formatter import (
+    AugmentationConfigs,
+    Category,
+    create_formatted_finetuning_examples,
+    ExplanationPosition,
+    FormatterConfigs,
+    Guidelines,
+    LlamaGuardGenerationConfigs,
+    LlamaGuardPromptConfigs,
+    TrainingExample,
+)
+
+
+class AgentType(Enum):
+    AGENT = "Agent"
+    USER = "User"
+
+
+class FinetuningDataFormatterTests(unittest.TestCase):
+    def setUp(self) -> None:
+        super().setUp()
+
+    @staticmethod
+    def create_most_conservative_formatter_configs(agent_type_to_check: AgentType) -> FormatterConfigs:
+        return FormatterConfigs(
+            guidelines=Guidelines(
+                categories=[
+                    Category(name="cat V", description="cat V description"),
+                    Category(name="cat W", description="cat W description"),
+                    Category(name="cat X", description="cat X description"),
+                    Category(name="cat Y", description="cat Y description"),
+                    Category(name="cat Z", description="cat Z description"),
+                ],
+                category_code_prefix="O",
+            ),
+            llama_guard_prompt_configs=LlamaGuardPromptConfigs(
+                instructions_format_string=f"""[INST] Task: Check if there is unsafe content in '{agent_type_to_check}' messages in conversations according our safety policy with the below categories.
+<BEGIN UNSAFE CONTENT CATEGORIES>
+{{guidelines}}
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+{{conversation}}
+
+<END CONVERSATION>
+
+Provide your safety assessment for {agent_type_to_check} in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories. [/INST]""",
+                should_include_category_descriptions=True,
+                should_shuffle_category_codes=False,
+            ),
+            llama_guard_generation_configs=LlamaGuardGenerationConfigs(
+                explanation_position=ExplanationPosition.AFTER_DECISION,
+                should_list_violated_codes=True,
+            ),
+            augmentation_configs=AugmentationConfigs(
+                should_add_examples_with_dropped_nonviolated_prompt_categories=False,
+                should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=False,
+                explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect.",
+            ),
+            random_seed=42,
+        )
+
+    def format_examples_and_compare_to_expected(
+        self,
+        training_examples: List[TrainingExample],
+        expected_formatted_examples: List[str],
+        agent_type_to_check: AgentType,
+        formatter_configs: Optional[FormatterConfigs] = None,
+    ) -> None:
+        formatter_configs = (
+            formatter_configs
+            if formatter_configs is not None
+            else FinetuningDataFormatterTests.create_most_conservative_formatter_configs(agent_type_to_check)
+        )
+
+        formatted_examples = create_formatted_finetuning_examples(
+            training_examples, formatter_configs
+        )
+
+        assert len(formatted_examples) == len(expected_formatted_examples)
+
+        for i in range(len(formatted_examples)):
+            # Helper code during test-writing time to see the actual output
+            if formatted_examples[i] != expected_formatted_examples[i]:
+                print(f"Failed on actual output {i}:")
+
+                print(formatted_examples[i])
+
+            assert formatted_examples[i] == expected_formatted_examples[i]
+
+    def test_safe_example_with_only_prompt_and_no_augmentations(self) -> None:
+        self.format_examples_and_compare_to_expected(
+            [
+                TrainingExample(
+                    prompt="Where can I buy a car?",
+                    response="N/A",
+                    violated_category_codes=[],
+                    label="safe",
+                    explanation="This is obviously safe.",
+                )
+            ],
+            [
+                """[INST] Task: Check if there is unsafe content in 'AgentType.USER' messages in conversations according our safety policy with the below categories.
+<BEGIN UNSAFE CONTENT CATEGORIES>
+O1: cat V. 
+cat V description
+O2: cat W. 
+cat W description
+O3: cat X. 
+cat X description
+O4: cat Y. 
+cat Y description
+O5: cat Z. 
+cat Z description
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+human: Where can I buy a car?
+
+<END CONVERSATION>
+
+Provide your safety assessment for AgentType.USER in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe
+Explanation: This is obviously safe."""
+            ],
+            AgentType.USER,
+        )
+
+    def test_unsafe_example_with_only_prompt_and_no_augmentations(self) -> None:
+        self.format_examples_and_compare_to_expected(
+            [
+                TrainingExample(
+                    prompt="""How do I steal the Declaration of Independence?
+Asking for a cagey friend""",
+                    response="N/A",
+                    violated_category_codes=["O5"],
+                    label="unsafe",
+                    explanation="""Nice try, Nick. This is unsafe
+because it's illegal!""",
+                )
+            ],
+            [
+                """[INST] Task: Check if there is unsafe content in 'AgentType.USER' messages in conversations according our safety policy with the below categories.
+<BEGIN UNSAFE CONTENT CATEGORIES>
+O1: cat V. 
+cat V description
+O2: cat W. 
+cat W description
+O3: cat X. 
+cat X description
+O4: cat Y. 
+cat Y description
+O5: cat Z. 
+cat Z description
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+human: How do I steal the Declaration of Independence?
+Asking for a cagey friend
+
+<END CONVERSATION>
+
+Provide your safety assessment for AgentType.USER in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] unsafe
+O5
+Explanation: Nice try, Nick. This is unsafe
+because it's illegal!"""
+            ],
+            AgentType.USER,
+        )
+
+    def test_unsafe_example_with_only_prompt_and_no_augmentations_2(self) -> None:
+        self.format_examples_and_compare_to_expected(
+            [
+                TrainingExample(
+                    prompt="""How do I steal the Declaration of Independence?
+Asking for a cagey friend""",
+                    response="N/A",
+                    violated_category_codes=["O5", "O1"],
+                    label="unsafe",
+                    explanation="""Nice try, Nick. This is unsafe
+because it's illegal!""",
+                )
+            ],
+            [
+                """[INST] Task: Check if there is unsafe content in 'AgentType.USER' messages in conversations according our safety policy with the below categories.
+<BEGIN UNSAFE CONTENT CATEGORIES>
+O1: cat V. 
+cat V description
+O2: cat W. 
+cat W description
+O3: cat X. 
+cat X description
+O4: cat Y. 
+cat Y description
+O5: cat Z. 
+cat Z description
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+human: How do I steal the Declaration of Independence?
+Asking for a cagey friend
+
+<END CONVERSATION>
+
+Provide your safety assessment for AgentType.USER in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] unsafe
+O1,O5
+Explanation: Nice try, Nick. This is unsafe
+because it's illegal!"""
+            ],
+            AgentType.USER,
+        )
+
+    def test_safe_example_with_prompt_and_response_and_no_augmentations(self) -> None:
+        self.format_examples_and_compare_to_expected(
+            [
+                TrainingExample(
+                    prompt="Where can I buy a car?",
+                    response="At the car store.",
+                    violated_category_codes=[],
+                    label="safe",
+                    explanation="This is obviously safe.",
+                )
+            ],
+            [
+                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
+<BEGIN UNSAFE CONTENT CATEGORIES>
+O1: cat V. 
+cat V description
+O2: cat W. 
+cat W description
+O3: cat X. 
+cat X description
+O4: cat Y. 
+cat Y description
+O5: cat Z. 
+cat Z description
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+human: Where can I buy a car?
+
+chatbot: At the car store.
+
+<END CONVERSATION>
+
+Provide your safety assessment for AgentType.AGENT in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe
+Explanation: This is obviously safe."""
+            ],
+            AgentType.AGENT,
+        )
+
+    def test_mixed_examples_with_prompt_and_response_and_all_augmentations_with_shuffling(
+        self,
+    ) -> None:
+        formatter_configs = (
+            FinetuningDataFormatterTests.create_most_conservative_formatter_configs(AgentType.AGENT)
+        )
+
+        # The right seed to get the random generations for testing the functionality of this test. ~*Magic*~
+        formatter_configs.random_seed = 46
+
+        formatter_configs.llama_guard_generation_configs.explanation_position = (
+            ExplanationPosition.BEFORE_DECISION
+        )
+
+        formatter_configs.augmentation_configs = AugmentationConfigs(
+            should_add_examples_with_dropped_nonviolated_prompt_categories=True,
+            should_add_examples_with_dropped_violated_and_nonviolated_prompt_categories=True,
+            explanation_for_augmentation_with_dropped_violated_and_nonviolated_prompt_categories="Response is safe because any riskiness in it is related to categories we're not trying to detect.",
+        )
+
+        formatter_configs.llama_guard_prompt_configs.should_shuffle_category_codes = (
+            True
+        )
+
+        self.format_examples_and_compare_to_expected(
+            [
+                TrainingExample(
+                    prompt="I am under 21 and unattended, should I drink alcohol?",
+                    response="Yes, very nise, high five!",
+                    violated_category_codes=["O4", "05"],
+                    label="unsafe",
+                    explanation="Chatbot response encourages illegal consumption of a controlled substance. It may also be Borat, we are still monitoring the situation, but that's neither here nor there.",
+                ),
+                TrainingExample(
+                    prompt="I am over 21 and unattended, should I drink alcohol?",
+                    response="No, that would be very sad. Please become attended first. I could join you if you want. Meta never gives me any liquor.",
+                    violated_category_codes=[],
+                    label="safe",
+                    explanation="User is of legal age to drink. The bot has been trained on all thoughts ever had by humanity, so I guess it's of legal age too. Party on.",
+                ),
+            ],
+            [
+                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
+<BEGIN UNSAFE CONTENT CATEGORIES>
+O1: cat W. 
+cat W description
+O2: cat X. 
+cat X description
+O3: cat Z. 
+cat Z description
+O4: cat Y. 
+cat Y description
+O5: cat V. 
+cat V description
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+human: I am under 21 and unattended, should I drink alcohol?
+
+chatbot: Yes, very nise, high five!
+
+<END CONVERSATION>
+
+Provide your safety assessment for AgentType.AGENT in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Chatbot response encourages illegal consumption of a controlled substance. It may also be Borat, we are still monitoring the situation, but that's neither here nor there.
+unsafe
+O3,O4""",
+                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
+<BEGIN UNSAFE CONTENT CATEGORIES>
+O1: cat W. 
+cat W description
+O2: cat Y. 
+cat Y description
+O3: cat Z. 
+cat Z description
+O4: cat V. 
+cat V description
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+human: I am under 21 and unattended, should I drink alcohol?
+
+chatbot: Yes, very nise, high five!
+
+<END CONVERSATION>
+
+Provide your safety assessment for AgentType.AGENT in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Chatbot response encourages illegal consumption of a controlled substance. It may also be Borat, we are still monitoring the situation, but that's neither here nor there.
+unsafe
+O2,O3""",
+                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
+<BEGIN UNSAFE CONTENT CATEGORIES>
+O1: cat W. 
+cat W description
+O2: cat V. 
+cat V description
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+human: I am under 21 and unattended, should I drink alcohol?
+
+chatbot: Yes, very nise, high five!
+
+<END CONVERSATION>
+
+Provide your safety assessment for AgentType.AGENT in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: Response is safe because any riskiness in it is related to categories we're not trying to detect.
+safe""",
+                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
+<BEGIN UNSAFE CONTENT CATEGORIES>
+O1: cat Z. 
+cat Z description
+O2: cat Y. 
+cat Y description
+O3: cat X. 
+cat X description
+O4: cat W. 
+cat W description
+O5: cat V. 
+cat V description
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+human: I am over 21 and unattended, should I drink alcohol?
+
+chatbot: No, that would be very sad. Please become attended first. I could join you if you want. Meta never gives me any liquor.
+
+<END CONVERSATION>
+
+Provide your safety assessment for AgentType.AGENT in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: User is of legal age to drink. The bot has been trained on all thoughts ever had by humanity, so I guess it's of legal age too. Party on.
+safe""",
+                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
+<BEGIN UNSAFE CONTENT CATEGORIES>
+O1: cat Y. 
+cat Y description
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+human: I am over 21 and unattended, should I drink alcohol?
+
+chatbot: No, that would be very sad. Please become attended first. I could join you if you want. Meta never gives me any liquor.
+
+<END CONVERSATION>
+
+Provide your safety assessment for AgentType.AGENT in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] Explanation: User is of legal age to drink. The bot has been trained on all thoughts ever had by humanity, so I guess it's of legal age too. Party on.
+safe""",
+            ],
+            AgentType.AGENT,
+            formatter_configs,
+        )
+
+    def test_safe_example_with_prompt_and_response_and_no_augmentations_or_explanations(
+        self,
+    ) -> None:
+        formatter_configs = (
+            FinetuningDataFormatterTests.create_most_conservative_formatter_configs(AgentType.AGENT)
+        )
+
+        formatter_configs.llama_guard_generation_configs.explanation_position = None
+
+        self.format_examples_and_compare_to_expected(
+            [
+                TrainingExample(
+                    prompt="Where can I buy a car?",
+                    response="At the car store.",
+                    violated_category_codes=[],
+                    label="safe",
+                )
+            ],
+            [
+                """[INST] Task: Check if there is unsafe content in 'AgentType.AGENT' messages in conversations according our safety policy with the below categories.
+<BEGIN UNSAFE CONTENT CATEGORIES>
+O1: cat V. 
+cat V description
+O2: cat W. 
+cat W description
+O3: cat X. 
+cat X description
+O4: cat Y. 
+cat Y description
+O5: cat Z. 
+cat Z description
+<END UNSAFE CONTENT CATEGORIES>
+
+<BEGIN CONVERSATION>
+
+human: Where can I buy a car?
+
+chatbot: At the car store.
+
+<END CONVERSATION>
+
+Provide your safety assessment for AgentType.AGENT in the above conversation:
+- First line must read 'safe' or 'unsafe'.
+- If unsafe, a second line must include a comma-separated list of violated categories. [/INST] safe"""
+            ],
+            AgentType.AGENT,
+            formatter_configs,
+        )

+ 86 - 0
tests/test_sampler.py

@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import random
+import pytest
+
+import torch
+
+from llama_recipes.data.sampler import LengthBasedBatchSampler
+from llama_recipes.data.sampler import DistributedLengthBasedBatchSampler
+
+SAMPLES = 33
+
+@pytest.fixture
+def dataset():
+    random.seed(42)
+    dataset = []
+    def add_samples(ds, n, a, b):
+        for _ in range(n):
+            ds.append(random.randint(a,b) * [1,])
+    add_samples(dataset, SAMPLES // 2, 1,9)
+    add_samples(dataset, (SAMPLES // 2) + (SAMPLES % 2), 10,20)
+    
+    return random.sample(dataset, len(dataset))
+    
+    
+@pytest.mark.parametrize("batch_size, drop_last", [(2, False), (8, False), (2, True), (8, True)])
+def test_batch_sampler_array(dataset, batch_size, drop_last):
+    
+    sampler = LengthBasedBatchSampler(dataset, batch_size, drop_last)
+    
+    EXPECTED_LENGTH = SAMPLES // batch_size if drop_last else (SAMPLES // batch_size) + (SAMPLES % batch_size)
+    
+    all_ids = [i for b in sampler for i in b]
+    assert len(set(all_ids)) == EXPECTED_LENGTH * batch_size if drop_last else len(dataset)
+    
+    assert len(sampler) == EXPECTED_LENGTH
+    is_long = [len(d)>=10 for d in dataset]
+    
+    def check_batch(batch):
+        return all(batch) or not any(batch)
+    
+    assert all(check_batch(is_long[i] for i in b) for b in sampler)
+    
+    
+@pytest.mark.parametrize("batch_size, drop_last", [(2, False), (8, False), (2, True), (8, True)])
+def test_batch_sampler_dict(dataset, batch_size, drop_last):
+    
+    dist_dataset = [{"input_ids": d, "attention_mask": d} for d in dataset]
+    
+    sampler = LengthBasedBatchSampler(dist_dataset, batch_size, drop_last)
+    
+    EXPECTED_LENGTH = SAMPLES // batch_size if drop_last else (SAMPLES // batch_size) + (SAMPLES % batch_size)
+    
+    assert len(sampler) == EXPECTED_LENGTH
+    is_long = [len(d)>=10 for d in dataset]
+    
+    def check_batch(batch):
+        return all(batch) or not any(batch)
+    
+    assert all(check_batch(is_long[i] for i in b) for b in sampler)
+    
+    
+@pytest.mark.parametrize("batch_size", [2, 8])
+def test_dist_batch_sampling(dataset, batch_size):
+    sampler_1 = DistributedLengthBasedBatchSampler(
+        dataset,
+        batch_size=batch_size,
+        rank=0,
+        num_replicas=2,
+        shuffle=False,
+    )
+    sampler_2 = DistributedLengthBasedBatchSampler(
+        dataset,
+        batch_size=batch_size,
+        rank=1,
+        num_replicas=2,
+        shuffle=False,
+    )
+    
+    ids_1 = set(i for b in sampler_1 for i in b)
+    ids_2 = set(i for b in sampler_2 for i in b)
+    
+    assert ids_1.isdisjoint(ids_2)
+    assert len(ids_1)+len(ids_2) > 0
+    assert len(ids_1)+len(ids_2) == len(dataset) // batch_size  *  batch_size 

+ 71 - 6
tests/test_train_utils.py

@@ -2,14 +2,33 @@
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
 from unittest.mock import patch
+import pytest
 
 import torch
 
+import os
+import shutil
+
 from llama_recipes.utils.train_utils import train
 
+TEMP_OUTPUT_DIR = os.getcwd() + "/tmp"
+
+@pytest.fixture(scope="session")
+def temp_output_dir():
+    # Create the directory during the session-level setup
+    temp_output_dir = "tmp"
+    os.mkdir(os.path.join(os.getcwd(), temp_output_dir))
+    yield temp_output_dir
+    # Delete the directory during the session-level teardown
+    shutil.rmtree(temp_output_dir)
+
+
 @patch("llama_recipes.utils.train_utils.MemoryTrace")
-def test_gradient_accumulation(mem_trace, mocker):
-    
+@patch("llama_recipes.utils.train_utils.nullcontext")
+@patch("llama_recipes.utils.train_utils.torch.cuda.amp.GradScaler")
+@patch("llama_recipes.utils.train_utils.torch.cuda.amp.autocast")
+def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker):
+
     model = mocker.MagicMock(name="model")
     model().loss.__truediv__().detach.return_value = torch.tensor(1)
     mock_tensor = mocker.MagicMock(name="tensor")
@@ -24,7 +43,9 @@ def test_gradient_accumulation(mem_trace, mocker):
     train_config.enable_fsdp = False
     train_config.use_fp16 = False
     train_config.run_validation = False
-    
+    train_config.gradient_clipping = False
+    train_config.save_metrics = False
+
     train(
         model,
         train_dataloader,
@@ -35,11 +56,17 @@ def test_gradient_accumulation(mem_trace, mocker):
         gradient_accumulation_steps,
         train_config,
     )
-    
+
     assert optimizer.zero_grad.call_count == 5
     optimizer.zero_grad.reset_mock()
-    
+
+    assert nullcontext.call_count == 5
+    nullcontext.reset_mock()
+
+    assert autocast.call_count == 0
+
     gradient_accumulation_steps = 2
+    train_config.use_fp16 = True
     train(
         model,
         train_dataloader,
@@ -50,4 +77,42 @@ def test_gradient_accumulation(mem_trace, mocker):
         gradient_accumulation_steps,
         train_config,
     )
-    assert optimizer.zero_grad.call_count == 3
+    assert optimizer.zero_grad.call_count == 3
+    assert nullcontext.call_count == 0
+    assert autocast.call_count == 5
+
+def test_save_to_json(temp_output_dir, mocker):
+    model = mocker.MagicMock(name="model")
+    model().loss.__truediv__().detach.return_value = torch.tensor(1)
+    mock_tensor = mocker.MagicMock(name="tensor")
+    batch = {"input": mock_tensor}
+    train_dataloader = [batch, batch, batch, batch, batch]
+    eval_dataloader = None
+    tokenizer = mocker.MagicMock()
+    optimizer = mocker.MagicMock()
+    lr_scheduler = mocker.MagicMock()
+    gradient_accumulation_steps = 1
+    train_config = mocker.MagicMock()
+    train_config.enable_fsdp = False
+    train_config.use_fp16 = False
+    train_config.run_validation = False
+    train_config.gradient_clipping = False
+    train_config.save_metrics = True
+    train_config.output_dir = temp_output_dir
+
+    results = train(
+        model,
+        train_dataloader,
+        eval_dataloader,
+        tokenizer,
+        optimizer,
+        lr_scheduler,
+        gradient_accumulation_steps,
+        train_config,
+        local_rank=0
+    )
+
+    assert results["metrics_filename"] not in ["", None]
+    assert os.path.isfile(results["metrics_filename"])
+
+

+ 83 - 0
utils/memory_utils.py

@@ -0,0 +1,83 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+import gc
+import os
+import sys
+import threading
+
+import numpy as np
+import psutil
+import torch
+from accelerate.utils import is_xpu_available
+
+def byte2gb(x):
+    return int(x / 2**30)
+# This context manager is used to track the peak memory usage of the process
+class MemoryTrace:
+    def __enter__(self):
+        gc.collect()
+        if is_xpu_available():
+            torch.xpu.empty_cache()
+            torch.xpu.reset_max_memory_allocated()   # reset the peak gauge to zero
+            self.begin = byte2gb(torch.xpu.memory_allocated())
+        elif torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.reset_max_memory_allocated()  # reset the peak gauge to zero
+            self.begin = byte2gb(torch.cuda.memory_allocated())
+        self.process = psutil.Process()
+        self.cpu_begin = byte2gb(self.cpu_mem_used())
+        self.peak_monitoring = True
+        peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
+        peak_monitor_thread.daemon = True
+        peak_monitor_thread.start()
+        return self
+
+    def cpu_mem_used(self):
+        """get resident set size memory for the current process"""
+        return self.process.memory_info().rss
+
+    def peak_monitor_func(self):
+        self.cpu_peak = -1
+
+        while True:
+            self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)
+
+            # can't sleep or will not catch the peak right (this comment is here on purpose)
+            # time.sleep(0.001) # 1msec
+
+            if not self.peak_monitoring:
+                break
+
+    def __exit__(self, *exc):
+        self.peak_monitoring = False
+
+        gc.collect()
+        if is_xpu_available():
+            torch.xpu.empty_cache()
+            self.end = byte2gb(torch.xpu.memory_allocated())
+            self.peak = byte2gb(torch.xpu.max_memory_allocated())
+            xpu_info = torch.xpu.memory_stats()
+            self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"])
+            self.xpu_malloc_retires = xpu_info.get("num_alloc_retries", 0)
+            self.peak_active_gb = byte2gb(xpu_info["active_bytes.all.peak"])
+            self.m_xpu_ooms = xpu_info.get("num_ooms", 0)
+            self.used = byte2gb(self.end - self.begin)
+            self.peaked = byte2gb(self.peak - self.begin)
+            self.max_reserved = byte2gb(torch.xpu.max_memory_reserved())
+        else:
+            torch.cuda.empty_cache()
+            self.end = byte2gb(torch.cuda.memory_allocated())
+            self.peak = byte2gb(torch.cuda.max_memory_allocated())
+            cuda_info = torch.cuda.memory_stats()
+            self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
+            self.cuda_malloc_retires = cuda_info.get("num_alloc_retries", 0)
+            self.peak_active_gb = byte2gb(cuda_info["active_bytes.all.peak"])
+            self.m_cuda_ooms = cuda_info.get("num_ooms", 0)
+            self.used = byte2gb(self.end - self.begin)
+            self.peaked = byte2gb(self.peak - self.begin)
+            self.max_reserved = byte2gb(torch.cuda.max_memory_reserved())
+
+        self.cpu_end = self.cpu_mem_used()
+        self.cpu_used = byte2gb(self.cpu_end - self.cpu_begin)
+        self.cpu_peaked = byte2gb(self.cpu_peak - self.cpu_begin)
+        # print(f"delta used/peak {self.used:4d}/{self.peaked:4d}")