
Merge branch 'main' into ipex_feature

Abhilash Majumder 1 year ago
parent
commit
4793f0fdf3
39 changed files with 1827 additions and 996 deletions
  1. README.md (+9 -8)
  2. demo_apps/HelloLlamaCloud.ipynb (+178 -157)
  3. demo_apps/HelloLlamaLocal.ipynb (+155 -262)
  4. demo_apps/LiveData.ipynb (+114 -230)
  5. demo_apps/Llama2_Gradio.ipynb (+24 -0)
  6. demo_apps/README.md (+13 -9)
  7. demo_apps/StructuredLlama.ipynb (+66 -11)
  8. demo_apps/VideoSummary.ipynb (+118 -11)
  9. demo_apps/llama-on-prem.md (+184 -0)
  10. demo_apps/llama_chatbot.py (+61 -0)
  11. demo_apps/whatsapp_dashboard.jpg (BIN)
  12. demo_apps/whatsapp_llama2.md (+160 -0)
  13. demo_apps/whatsapp_llama_arch.jpg (BIN)
  14. docs/Dataset.md (+15 -3)
  15. docs/FAQ.md (+21 -7)
  16. docs/inference.md (+2 -0)
  17. examples/README.md (+1 -1)
  18. examples/custom_dataset.py (+23 -30)
  19. scripts/spellcheck_conf/wordlist.txt (+33 -1)
  20. src/llama_recipes/configs/datasets.py (+0 -2)
  21. src/llama_recipes/configs/training.py (+2 -4)
  22. src/llama_recipes/data/__init__.py (+2 -0)
  23. src/llama_recipes/data/concatenator.py (+34 -0)
  24. src/llama_recipes/data/sampler.py (+57 -0)
  25. src/llama_recipes/datasets/alpaca_dataset.py (+4 -14)
  26. src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py (+13 -18)
  27. src/llama_recipes/datasets/samsum_dataset.py (+19 -13)
  28. src/llama_recipes/datasets/utils.py (+0 -66)
  29. src/llama_recipes/finetuning.py (+25 -37)
  30. src/llama_recipes/utils/config_utils.py (+49 -11)
  31. src/llama_recipes/utils/dataset_utils.py (+6 -6)
  32. src/llama_recipes/utils/train_utils.py (+36 -34)
  33. tests/conftest.py (+18 -0)
  34. tests/datasets/test_custom_dataset.py (+40 -13)
  35. tests/datasets/test_grammar_datasets.py (+54 -0)
  36. tests/datasets/test_samsum_datasets.py (+30 -14)
  37. tests/test_batching.py (+94 -0)
  38. tests/test_finetuning.py (+81 -34)
  39. tests/test_sampler.py (+86 -0)

+ 9 - 8
README.md

@@ -1,9 +1,10 @@
-# Llama 2 Fine-tuning / Inference Recipes and Examples
+# Llama 2 Fine-tuning / Inference Recipes, Examples and Demo Apps
 
-The 'llama-recipes' repository is a companion to the [Llama 2 model](https://github.com/facebookresearch/llama). The goal of this repository is to provide examples to quickly get started with fine-tuning for domain adaptation and how to run inference for the fine-tuned models. For ease of use, the examples use Hugging Face converted versions of the models. See steps for conversion of the model [here](#model-conversion-to-hugging-face).
+**[Update Nov. 14, 2023] We recently released a series of Llama 2 demo apps [here](./demo_apps). These apps show how to run Llama 2 locally, in the cloud, on-prem, or with WhatsApp, and how to ask Llama 2 questions in general and about custom data (PDF, DB, or live data).**
 
-In addition, we also provide a number of demo apps, to showcase the Llama2 usage along with other ecosystem solutions to run Llama2 locally on your mac and on cloud.
+The 'llama-recipes' repository is a companion to the [Llama 2 model](https://github.com/facebookresearch/llama). The goal of this repository is to provide examples to quickly get started with fine-tuning for domain adaptation and how to run inference for the fine-tuned models. For ease of use, the examples use Hugging Face converted versions of the models. See steps for conversion of the model [here](#model-conversion-to-hugging-face).
 
+In addition, we provide a number of demo apps that showcase Llama 2 usage along with other ecosystem solutions for running Llama 2 locally, in the cloud, and on-prem.
 
 Llama 2 is a new technology that carries potential risks with use. Testing conducted to date has not — and could not — cover all scenarios. In order to help developers address these risks, we have created the [Responsible Use Guide](https://github.com/facebookresearch/llama/blob/main/Responsible-Use-Guide.pdf). More details can be found in our research paper as well. For downloading the models, follow the instructions on [Llama 2 repo](https://github.com/facebookresearch/llama).
 
@@ -20,8 +21,6 @@ Llama 2 is a new technology that carries potential risks with use. Testing condu
 6. [Repository Organization](#repository-organization)
 7. [License and Acceptable Use Policy](#license)
 
-
-
 # Quick Start
 
 [Llama 2 Jupyter Notebook](./examples/quickstart.ipynb): This Jupyter notebook steps you through how to finetune a Llama 2 model on the text summarization task using the [samsum](https://huggingface.co/datasets/samsum) dataset. The notebook uses parameter efficient finetuning (PEFT) and int8 quantization to finetune a 7B model on a single GPU like an A10 with 24GB GPU memory.
@@ -134,7 +133,7 @@ Here we make use of Parameter Efficient Methods (PEFT) as described in the next
 
 ```bash
 
-torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /path_of_model_folder/7B --pure_bf16 --output_dir Path/to/save/PEFT/model
+torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /path_of_model_folder/7B --fsdp_config.pure_bf16 --output_dir Path/to/save/PEFT/model
 
 ```
 
@@ -145,7 +144,7 @@ Here we use FSDP as discussed in the next section which can be used along with P
 Setting `use_fast_kernels` will enable the use of Flash Attention or Xformer memory-efficient kernels based on the hardware being used. This would speed up the fine-tuning job. This has been enabled in the `optimum` library from HuggingFace as a one-liner API; please read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/).
 
 ```bash
-torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /path_of_model_folder/7B --pure_bf16 --output_dir Path/to/save/PEFT/model --use_fast_kernels
+torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /path_of_model_folder/7B --fsdp_config.pure_bf16 --output_dir Path/to/save/PEFT/model --use_fast_kernels
 ```
 
 ### Fine-tuning using FSDP Only
@@ -164,7 +163,7 @@ If you are interested in running full parameter fine-tuning on the 70B model, yo
 
 ```bash
 
-torchrun --nnodes 1 --nproc_per_node 8 examples/finetuning.py --enable_fsdp --low_cpu_fsdp --pure_bf16 --model_name /path_of_model_folder/70B --batch_size_training 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
+torchrun --nnodes 1 --nproc_per_node 8 examples/finetuning.py --enable_fsdp --low_cpu_fsdp --fsdp_config.pure_bf16 --model_name /path_of_model_folder/70B --batch_size_training 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
 
 ```
 
@@ -184,11 +183,13 @@ This folder contains a series of Llama2-powered apps:
 1. Llama on your Mac and ask Llama general questions
 2. Llama on Google Colab
 3. Llama on Cloud and ask Llama questions about unstructured data in a PDF
+4. Llama on-prem with vLLM and TGI
 
 * Specialized Llama use cases:
 1. Ask Llama to summarize a video content
 2. Ask Llama questions about structured data in a DB
 3. Ask Llama questions about live data on the web
+4. Build a Llama-enabled WhatsApp chatbot
 
 # Repository Organization
 This repository is organized in the following way:

File diff suppressed because it is too large
+ 178 - 157
demo_apps/HelloLlamaCloud.ipynb


File diff suppressed because it is too large
+ 155 - 262
demo_apps/HelloLlamaLocal.ipynb


File diff suppressed because it is too large
+ 114 - 230
demo_apps/LiveData.ipynb


+ 24 - 0
demo_apps/Llama2_Gradio.ipynb

@@ -1,5 +1,29 @@
 {
  "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "47a9adb3",
+   "metadata": {},
+   "source": [
+    "## This demo app shows how to query Llama 2 using the Gradio UI.\n",
+    "\n",
+    "Since we are using Replicate in this example, you will need to replace `<your replicate api token>` with your API token.\n",
+    "\n",
+    "To get the Replicate token: \n",
+    "\n",
+    "- You will need to first sign in with Replicate with your github account\n",
+    "- Then create a free API token [here](https://replicate.com/account/api-tokens) that you can use for a while \n",
+    "\n",
+    "**Note** After the free trial ends, you will need to enter billing info to continue to use Llama2 hosted on Replicate.\n",
+    "\n",
+    "To run this example:\n",
+    "- Set up your Replicate API token and enter it in place of `<your replicate api token>`\n",
+    "- Run the notebook\n",
+    "- Enter your question and click Submit\n",
+    "\n",
+    "In the notebook or a browser with URL http://127.0.0.1:7860 you should see a UI with your answer."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 1,

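For orientation, the cells that follow in this notebook roughly amount to the sketch below; it assumes LangChain's `Replicate` wrapper and the same Llama 2 13b-chat model id used by the other demo apps in this change, and the `answer` function name is illustrative rather than the notebook's exact code.

```python
# Minimal sketch: query Llama 2 hosted on Replicate through a Gradio text UI.
import os

import gradio as gr
from langchain.llms import Replicate

os.environ["REPLICATE_API_TOKEN"] = "<your replicate api token>"

# Llama 2 13b-chat model id, as used in the other demo apps in this change.
llama2_13b_chat = "meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d"

llm = Replicate(
    model=llama2_13b_chat,
    model_kwargs={"temperature": 0.01, "top_p": 1, "max_new_tokens": 500},
)

def answer(question: str) -> str:
    # Forward the question to the model and return its completion.
    return llm(question)

# Serves a text-in/text-out UI at http://127.0.0.1:7860.
gr.Interface(fn=answer, inputs="text", outputs="text").launch()
```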
File diff suppressed because it is too large
+ 13 - 9
demo_apps/README.md


+ 66 - 11
demo_apps/StructuredLlama.ipynb

@@ -5,9 +5,25 @@
    "id": "e8cba0b6",
    "metadata": {},
    "source": [
-    "## This demo shows how to use LangChain's SQLDatabaseChain with Llama2 to query about structured data stored in a SQL DB.  \n",
-    "* As the 2023-24 NBA season is around the corner, we use the NBA roster info saved in a SQLite DB to show you how to ask Llama2 questions about your favorite teams or players. \n",
-    "* Because the SQLDatabaseChain API implementation is still in the langchain_experimental package, you'll see more issues that come with using the cutting edge experimental features, and how we succeed resolving some of the issues but fail on some others."
+    "## This demo shows how to use LangChain's SQLDatabaseChain with Llama2 to query structured data stored in a SQL DB.  \n",
+    "* We use the 2023-24 NBA roster info saved in a SQLite DB to show you how to ask Llama2 questions about your favorite teams or players \n",
+    "* At the time of writing this, the SQLDatabaseChain API implementation is still in the langchain_experimental package. With this in mind you will see more issues that come with using the cutting edge experimental features, and how we succeed resolving some of the issues but fail on some others"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f839d07d",
+   "metadata": {},
+   "source": [
+    "We start by installing the necessary packages:\n",
+    "- [Replicate](https://replicate.com/) to host the Llama 2 model\n",
+    "- [langchain](https://python.langchain.com/docs/get_started/introduction) provides necessary RAG tools for this demo\n",
+    "- langchain_experimental Langchain's experimental version to get us access to SQLDatabaseChain\n",
+    "\n",
+    "And setting up the Replicate token.\n",
+    "\n",
+    "**Note** To get a Replicate token, you will need to first sign in with Replicate with your github account, then create a free API token [here](https://replicate.com/account/api-tokens) that you can use for a while. \n",
+    "After the free trial ends, you will need to enter billing info to continue to use Llama2 hosted on Replicate."
    ]
   },
   {
@@ -40,7 +56,7 @@
    "metadata": {},
    "outputs": [
     {
-     "name": "stdin",
+     "name": "stdout",
      "output_type": "stream",
      "text": [
       " ········\n"
@@ -55,6 +71,16 @@
     "os.environ[\"REPLICATE_API_TOKEN\"] = REPLICATE_API_TOKEN"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "1e586b75",
+   "metadata": {},
+   "source": [
+    "Next we call the Llama 2 model from replicate. In this example we will use the llama 2 13b chat model. You can find more Llama 2 models by searching for them on the [Replicate model explore page](https://replicate.com/explore?query=llama).\n",
+    "\n",
+    "You can add them here in the format: model_name/version"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 3,
@@ -70,6 +96,20 @@
     ")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "6d421ae7",
+   "metadata": {},
+   "source": [
+    "Next you will need create the `nba_roster.db` file. \n",
+    "\n",
+    "To do this run the following commands while in this folder:\n",
+    "- `python txt2csv.py`  This will convert the `nba.txt` file to `nba_roster.csv`. The `nba.txt` file was created by scraping the NBA roster info from the web.\n",
+    "- Then run `python csv2db.py` to convert `nba_roster.csv` to `nba_roster.db`.\n",
+    "\n",
+    "Once you have your `nba_roster.db` ready, we set up the database to be queried by Llama 2 through Langchain's [SQL chains](https://python.langchain.com/docs/use_cases/qa_structured/sql)."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 4,
@@ -77,9 +117,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# The nba_roster.db was created by running the two scripts:\n",
-    "# python txt2csv.py # convert the `nba.txt` file, created by scraping the NBA roster info from the web, to nba_roster.csv\n",
-    "# python csv2db.py # convert nba_roster.csv to nba_roster.db\n",
+    "\n",
     "db = SQLDatabase.from_uri(\"sqlite:///nba_roster.db\", sample_rows_in_table_info= 0)\n",
     "\n",
     "PROMPT_SUFFIX = \"\"\"\n",
@@ -93,6 +131,14 @@
     "                                     template=PROMPT_SUFFIX))"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "afcf423a",
+   "metadata": {},
+   "source": [
+    "We will go ahead and turn on LangChain debug to get an idea of how many calls are made to Llama 2 and what the inputs and outputs are."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 5,
@@ -158,7 +204,7 @@
     }
    ],
    "source": [
-    "# turn on the debug of LangChain so we can see how many calls to Llama are made and exactly what are inputs and outputs\n",
+    "\n",
     "import langchain\n",
     "langchain.debug = True\n",
     "\n",
@@ -304,6 +350,18 @@
     "db_chain.run(\"What's his salary?\")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "98b2c523",
+   "metadata": {},
+   "source": [
+    "\n",
+    "Since we did not pass any context along with the follow-up to the model it did not know who \"his\" is and just picked LeBron James.\n",
+    "\n",
+    "Let's try to fix the issue that the context (the previous question and answer) was not sent to the model along with the new question.\n",
+    "`SQLDatabaseChain.from_llm` has a parameter \"memory\" which can be set to a `ConversationBufferMemory` instance, which looks promising.\n"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 8,
@@ -311,10 +369,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# since we didn't pass the context along with the follow-up to llm so it didn't know what \"his\" is and just picked LeBron James\n",
     "\n",
-    "# let's try to fix the issue that the context (the previous question and answer) was not sent to LLM along with the new question\n",
-    "# SQLDatabaseChain.from_llm has a parameter \"memory\" which can be set to a ConversationBufferMemory instance, which looks promising.\n",
     "from langchain.memory import ConversationBufferMemory\n",
     "\n",
     "memory = ConversationBufferMemory()\n",

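A condensed, hedged sketch of the flow the notebook walks through; the import paths assume the LangChain and langchain_experimental versions current at the time of this change, and the question is illustrative.

```python
# Condensed sketch: Llama 2 on Replicate + SQLDatabaseChain over the nba_roster.db built above.
import os

import langchain
from langchain.llms import Replicate
from langchain.utilities import SQLDatabase
from langchain_experimental.sql import SQLDatabaseChain

os.environ["REPLICATE_API_TOKEN"] = "<your replicate api token>"

llm = Replicate(
    model="meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d",
    model_kwargs={"temperature": 0.01, "top_p": 1, "max_new_tokens": 500},
)

# Point LangChain at the SQLite DB produced by txt2csv.py and csv2db.py.
db = SQLDatabase.from_uri("sqlite:///nba_roster.db", sample_rows_in_table_info=0)

# Debug mode prints every call made to Llama 2 with its inputs and outputs.
langchain.debug = True

db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)
print(db_chain.run("How many unique teams are there?"))
```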
+ 118 - 11
demo_apps/VideoSummary.ipynb

@@ -6,9 +6,24 @@
    "metadata": {},
    "source": [
     "## This demo app shows:\n",
-    "* how to use LangChain's YoutubeLoader to retrieve the caption in a YouTube video;\n",
-    "* how to ask Llama to summarize the content (per the Llama's input size limit) of the video in a naive way using LangChain's stuff method;\n",
-    "* how to bypass the limit of Llama's max input token size by using more sophisticated way using LangChain's map_reduce and refine methods - see [here](https://python.langchain.com/docs/use_cases/summarization) for more info."
+    "* How to use LangChain's YoutubeLoader to retrieve the caption in a YouTube video\n",
+    "* How to ask Llama to summarize the content (per the Llama's input size limit) of the video in a naive way using LangChain's stuff method\n",
+    "* How to bypass the limit of Llama's max input token size by using a more sophisticated way using LangChain's map_reduce and refine methods - see [here](https://python.langchain.com/docs/use_cases/summarization) for more info"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c866f6be",
+   "metadata": {},
+   "source": [
+    "We start by installing the necessary packages:\n",
+    "- [youtube-transcript-api](https://pypi.org/project/youtube-transcript-api/) API to get transcript/subtitles of a YouTube video\n",
+    "- [langchain](https://python.langchain.com/docs/get_started/introduction) provides necessary RAG tools for this demo\n",
+    "- [tiktoken](https://github.com/openai/tiktoken) BytePair Encoding tokenizer\n",
+    "- [pytube](https://pytube.io/en/latest/) Utility for downloading YouTube videos\n",
+    "\n",
+    "**Note** This example uses Replicate to host the Llama model. If you have not set up/or used Replicate before, we suggest you take a look at the [HelloLlamaCloud](HelloLlamaCloud.ipynb) example for information on how to set up Replicate before continuing with this example.\n",
+    "If you do not want to use Replicate, you will need to make some changes to this notebook as you go along."
    ]
   },
   {
@@ -21,6 +36,14 @@
     "!pip install langchain youtube-transcript-api tiktoken pytube"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "af3069b1",
+   "metadata": {},
+   "source": [
+    "Let's load the YouTube video transcript using the YoutubeLoader."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 1,
@@ -69,6 +92,25 @@
     "len(docs[0].page_content), docs[0].page_content[:300]"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "4af7cc16",
+   "metadata": {},
+   "source": [
+    "We are using Replicate in this example to host our Llama 2 model so you will need to get a Replicate token.\n",
+    "\n",
+    "To get the Replicate token: \n",
+    "\n",
+    "- You will need to first sign in with Replicate with your github account\n",
+    "- Then create a free API token [here](https://replicate.com/account/api-tokens) that you can use for a while. \n",
+    "\n",
+    "**Note** After the free trial ends, you will need to enter billing info to continue to use Llama2 hosted on Replicate.\n",
+    "\n",
+    "Alternatively, you can run Llama locally. See:\n",
+    "- [HelloLlamaCloud](HelloLlamaCloud.ipynb) for further information on how to run Llama using Replicate.\n",
+    "- [HelloLlamaLocal](HelloLlamaLocal.ipynb) for further information on how to run Llama locally."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 4,
@@ -76,7 +118,7 @@
    "metadata": {},
    "outputs": [
     {
-     "name": "stdin",
+     "name": "stdout",
      "output_type": "stream",
      "text": [
       " ········\n"
@@ -92,6 +134,18 @@
     "os.environ[\"REPLICATE_API_TOKEN\"] = REPLICATE_API_TOKEN\n"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "6b911efd",
+   "metadata": {},
+   "source": [
+    "Next we call the Llama 2 model from Replicate. In this example we will use the llama 2 13b chat model. You can find more Llama 2 models by searching for them on the [Replicate model explore page](https://replicate.com/explore?query=llama).\n",
+    "\n",
+    "You can add them here in the format: model_name/version\n",
+    "\n",
+    "If you using local Llama, just set llm accordingly - see the [HelloLlamaLocal notebook](HelloLlamaLocal.ipynb)"
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": null,
@@ -99,7 +153,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# set llm to be Llama2-13b model; if you use local Llama, just set llm accordingly - see the HelloLlamaLocal notebook\n",
+    "\n",
     "from langchain.llms import Replicate\n",
     "\n",
     "llama2_13b = \"meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d\"\n",
@@ -109,6 +163,14 @@
     ")"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "8e3baa56",
+   "metadata": {},
+   "source": [
+    "Once everything is set up, we prompt Llama 2 to summarize the first 4000 characters of the transcript for us."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 6,
@@ -141,6 +203,14 @@
     "print(summary)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "8b684b29",
+   "metadata": {},
+   "source": [
+    "Next we try to summarize all the content of the transcript and we should get a `RuntimeError: Your input is too long. Max input length is 4096 tokens, but you supplied 5597 tokens.`."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 7,
@@ -174,8 +244,18 @@
     "# try to get a summary of the whole content\n",
     "text = docs[0].page_content\n",
     "summary = chain.run(text)\n",
-    "print(summary)\n",
-    "# and you'll get - RuntimeError: Your input is too long. Max input length is 4096 tokens, but you supplied 5597 tokens."
+    "print(summary)\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1ad1881a",
+   "metadata": {},
+   "source": [
+    "\n",
+    "Let's try some workarounds to see if we can summarize the entire transcript without running into the `RuntimeError`.\n",
+    "\n",
+    "We will use the LangChain's `load_summarize_chain` and play around with the `chain_type`.\n"
    ]
   },
   {
@@ -260,6 +340,15 @@
     "chain.run(docs)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "aecf6328",
+   "metadata": {},
+   "source": [
+    "\n",
+    "Since the transcript is bigger than the model can handle, we can split the transcript into chunks instead and use the [`refine`](https://python.langchain.com/docs/modules/chains/document/refine) `chain_type` to iteratively create an answer."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 10,
@@ -321,6 +410,14 @@
     "chain.run(split_docs)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "c3976c92",
+   "metadata": {},
+   "source": [
+    "You can also use [`map_reduce`](https://python.langchain.com/docs/modules/chains/document/map_reduce) `chain_type` to implement a map reduce like architecture while summarizing the documents."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 14,
@@ -400,6 +497,15 @@
     "chain.run(split_docs)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "id": "77d580de",
+   "metadata": {},
+   "source": [
+    "To investigate further, let's turn on Langchain's debug mode on to get an idea of how many calls are made to the model and the details of the inputs and outputs.\n",
+    "We will then run our summary using the `stuff` and `refine` `chain_types` and take a look at our output."
+   ]
+  },
   {
    "cell_type": "code",
    "execution_count": 15,
@@ -559,12 +665,13 @@
    ]
   },
   {
-   "cell_type": "code",
-   "execution_count": null,
+   "cell_type": "markdown",
    "id": "61ccd0fb-5cdb-43c4-afaf-05bc9f7cf959",
    "metadata": {},
-   "outputs": [],
-   "source": []
+   "source": [
+    "\n",
+    "As you can see, `stuff` fails because it tries to treat all the split documents as one and \"stuffs\" it into one prompt which leads to a much larger prompt than Llama 2 can handle while `refine` iteratively runs over the documents updating its answer as it goes."
+   ]
   }
  ],
  "metadata": {

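A condensed sketch of the overall flow described above, assuming the Replicate-hosted model and LangChain's `RecursiveCharacterTextSplitter`; the video URL and chunk size are placeholders, not the notebook's exact values.

```python
# Sketch: load a YouTube transcript, split it, and summarize it with the "refine" chain type.
import os

from langchain.chains.summarize import load_summarize_chain
from langchain.document_loaders import YoutubeLoader
from langchain.llms import Replicate
from langchain.text_splitter import RecursiveCharacterTextSplitter

os.environ["REPLICATE_API_TOKEN"] = "<your replicate api token>"

llm = Replicate(
    model="meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d",
    model_kwargs={"temperature": 0.01, "top_p": 1, "max_new_tokens": 500},
)

# Load the caption track of the video (URL is a placeholder).
docs = YoutubeLoader.from_youtube_url("<youtube video url>").load()

# Split the transcript into chunks that fit within Llama 2's 4096-token context window.
split_docs = RecursiveCharacterTextSplitter(chunk_size=3000, chunk_overlap=0).split_documents(docs)

# "stuff" would put everything into one prompt and overflow the context window;
# "refine" iterates over the chunks, updating its running summary as it goes.
chain = load_summarize_chain(llm, chain_type="refine")
print(chain.run(split_docs))
```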
File diff suppressed because it is too large
+ 184 - 0
demo_apps/llama-on-prem.md


+ 61 - 0
demo_apps/llama_chatbot.py

@@ -0,0 +1,61 @@
+import langchain
+from langchain.llms import Replicate
+
+from flask import Flask
+from flask import request
+import os
+import requests
+import json
+
+class WhatsAppClient:
+
+    API_URL = "https://graph.facebook.com/v17.0/"
+    WHATSAPP_API_TOKEN = "<Temporary access token from your WhatsApp API Setup>"
+    WHATSAPP_CLOUD_NUMBER_ID = "<Phone number ID from your WhatsApp API Setup>"
+
+    def __init__(self):
+        self.headers = {
+            "Authorization": f"Bearer {self.WHATSAPP_API_TOKEN}",
+            "Content-Type": "application/json",
+        }
+        self.API_URL = self.API_URL + self.WHATSAPP_CLOUD_NUMBER_ID
+
+    def send_text_message(self,message, phone_number):
+        payload = {
+            "messaging_product": 'whatsapp',
+            "to": phone_number,
+            "type": "text",
+            "text": {
+                "preview_url": False,
+                "body": message
+            }
+        }
+        response = requests.post(f"{self.API_URL}/messages", json=payload,headers=self.headers)
+        print(response.status_code)
+        assert response.status_code == 200, "Error sending message"
+        return response.status_code
+
+os.environ["REPLICATE_API_TOKEN"] = "<your replicate api token>"    
+llama2_13b_chat = "meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d"
+
+llm = Replicate(
+    model=llama2_13b_chat,
+    model_kwargs={"temperature": 0.01, "top_p": 1, "max_new_tokens":500}
+)
+client = WhatsAppClient()
+app = Flask(__name__)
+
+@app.route("/")
+def hello_llama():
+    return "<p>Hello Llama 2</p>"
+
+@app.route('/msgrcvd', methods=['POST', 'GET'])
+def msgrcvd():    
+    message = request.args.get('message')
+    #client.send_template_message("hello_world", "en_US", "14086745477")
+    answer = llm(message)
+    print(message)
+    print(answer)
+    client.send_text_message(answer, "14086745477")  # reuse the answer computed above instead of calling the model a second time
+    return message + "<p/>" + answer
+
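Once valid Replicate and WhatsApp credentials are filled in and the Flask app is running locally, one way to exercise the `/msgrcvd` endpoint is a plain HTTP request; the port and question below are placeholders that assume Flask's default port 5000.

```python
# Sketch: send a test message to the locally running chatbot endpoint.
# Note that the handler also tries to deliver the answer over the WhatsApp API,
# which requires the tokens at the top of llama_chatbot.py to be valid.
import requests

resp = requests.get(
    "http://127.0.0.1:5000/msgrcvd",
    params={"message": "Who wrote the book Innovator's Dilemma?"},
)
print(resp.status_code)
print(resp.text)  # the original message followed by Llama 2's answer
```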

BIN
demo_apps/whatsapp_dashboard.jpg


File diff suppressed because it is too large
+ 160 - 0
demo_apps/whatsapp_llama2.md


BIN
demo_apps/whatsapp_llama_arch.jpg


+ 15 - 3
docs/Dataset.md

@@ -7,6 +7,18 @@ The provided fine tuning script allows you to select between three datasets by p
 * [samsum_dataset](https://huggingface.co/datasets/samsum) contains about 16k messenger-like conversations with summaries.
 * [OpenAssistant/oasst1](https://huggingface.co/datasets/OpenAssistant/oasst1/) contains about 88k messages from assistant-style conversations.
 
+## Batching Strategies
+Llama-recipes supports two strategies to batch requests together.
+The default setting is `packing` which concatenates the tokenized samples into long sequences filling up the context length of the model.
+This is the most compute-efficient variant as it avoids any padding and all sequences have the same length.
+Samples at the boundary of the context length are truncated and the remainder of the cut sequence is used as the start of the next long sequence.
+
+If the amount of training data is small, this procedure might introduce a lot of noise into the training data, which can hurt the prediction performance of the fine-tuned model.
+Therefore, we also support a `padding` strategy which does not introduce the additional noise due to truncated sequences.
+This strategy tries to minimize the efficiency loss by batching samples of similar length together, so only minimal padding is necessary.
+
+The batching strategy can be selected through the command line parameter `--batching_strategy [packing]/[padding]`.
+
 ## Using custom datasets
 
 The list of available datasets in llama-recipes is supposed to give users a quick start on training their Llama model.
@@ -23,9 +35,9 @@ def get_custom_dataset(dataset_config, tokenizer, split: str):
 For an example `get_custom_dataset` you can look at the provided datasets in llama_recipes.datasets or [examples/custom_dataset.py](../examples/custom_dataset.py).
 The `dataset_config` in the above signature will be an instance of llama_recipes.configs.dataset.custom_dataset with the modifications made through the command line.
 The split signals whether to return the training or validation dataset.
-The default function name is `get_custom_dataset` but this can be changes as described below.
+The default function name is `get_custom_dataset` but this can be changed as described below.
 
-In order to start a training with the custom dataset we need to set the `--dataset` as well as the `--custom_dataset.file` parameter. 
+In order to start a training with the custom dataset we need to set the `--dataset` as well as the `--custom_dataset.file` parameter.
 ```
 python -m llama_recipes.finetuning --dataset "custom_dataset" --custom_dataset.file "examples/custom_dataset.py" [TRAINING PARAMETERS]
 ```
@@ -35,7 +47,7 @@ python -m llama_recipes.finetuning --dataset "custom_dataset" --custom_dataset.f
 ```
 This will call the function `get_foo` instead of `get_custom_dataset` when retrieving the dataset.
 
-### Adding new dataset 
+### Adding new dataset
 Each dataset has a corresponding configuration (dataclass) in [configs/datasets.py](../src/llama_recipes/configs/datasets.py) which contains the dataset name, training/validation split names, as well as optional parameters like datafiles etc.
 
 Additionally, there is a preprocessing function for each dataset in the [datasets](../src/llama_recipes/datasets) folder.
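The trade-off between the two batching strategies described above can be made concrete with a small, self-contained sketch; the token ids, chunk size, and pad id below are made up for illustration and are not part of the library.

```python
# Toy illustration of the two batching strategies described in Dataset.md.
from itertools import chain

samples = [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10], [11, 12, 13, 14, 15], [16, 17, 18]]

# `packing`: concatenate everything and cut fixed-length chunks; samples crossing a
# chunk boundary are truncated and the cut-off remainder starts the next chunk.
context_length = 8
stream = list(chain(*samples))
packed = [stream[i:i + context_length] for i in range(0, len(stream) - context_length + 1, context_length)]
print(packed)  # [[1..8], [9..16]]; the leftover [17, 18] would start the next chunk

# `padding`: keep samples intact, group ones of similar length, pad to the longest in the group.
pad_id = 0
batch = sorted(samples, key=len)[:2]
longest = max(len(s) for s in batch)
padded = [s + [pad_id] * (longest - len(s)) for s in batch]
print(padded)  # only two pad tokens are needed because similar lengths were grouped
```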

File diff suppressed because it is too large
+ 21 - 7
docs/FAQ.md


+ 2 - 0
docs/inference.md

@@ -144,3 +144,5 @@ python examples/vllm/inference.py --model_name <PATH/TO/MODEL/7B>
 ```
 
 [**TGI**](https://github.com/huggingface/text-generation-inference): Text Generation Inference (TGI) is another inference option available to you. For more information on how to set up and use TGI see [here](../examples/hf_text_generation_inference/README.md).
+
+[Here](../demo_apps/llama-on-prem.md) is a complete tutorial on how to use vLLM and TGI to deploy Llama 2 on-prem and interact with the Llama API services.
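For a rough idea of what interacting with such a deployment looks like, here is a hedged sketch of querying a vLLM server that exposes the OpenAI-compatible API; the host, port, and model path are placeholders, and the exact server commands are covered in the linked tutorial.

```python
# Sketch: query an on-prem, OpenAI-compatible vLLM endpoint.
# Assumes a server such as `python -m vllm.entrypoints.openai.api_server --model <PATH/TO/MODEL/7B>`
# is already running; the URL and model name below are placeholders.
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "<PATH/TO/MODEL/7B>",
        "prompt": "Who wrote the book Innovator's Dilemma?",
        "max_tokens": 200,
        "temperature": 0,
    },
)
print(resp.json()["choices"][0]["text"])
```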

+ 1 - 1
examples/README.md

@@ -10,7 +10,7 @@ After installing the llama-recipes package through [pip](../README.md#installati
 ```
 python -m llama_recipes.finetuning <parameters>
 
-python examnples/finetuning.py <parameters>
+python examples/finetuning.py <parameters>
 ```
 Please see [README.md](../README.md) for details.
 

+ 23 - 30
examples/custom_dataset.py

@@ -7,33 +7,27 @@ import copy
 import datasets
 import itertools
 
-from llama_recipes.datasets.utils import Concatenator
-
 
 B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
 
 def tokenize_dialog(dialog, tokenizer):
-    dialog_tokens = [
-            tokenizer(
-                f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
-            )
-            for prompt, answer in zip(dialog[::2], dialog[1::2])
-        ]
-    if len(dialog) % 2:    
-        dialog_tokens += [tokenizer(
-            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
-        )]
-    
-    combined_tokens = {}  
-    for k in dialog_tokens[0].keys():
-        combined_tokens[k] = list(itertools.chain(*(t[k] for t in dialog_tokens)))
-    return combined_tokens
+    prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST}", add_special_tokens=False) for prompt in dialog[::2]]
+    answer_tokens = [tokenizer.encode(f"{answer['content'].strip()} {tokenizer.eos_token}", add_special_tokens=False) for answer in dialog[1::2]]
+    dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens)))
+    # Add labels: convert prompt tokens to -100 so they are ignored by the loss function
+    labels_tokens = [len(c)*[-100,] if i % 2 == 0 else c for i,c in enumerate(dialog_tokens)]
+
+    combined_tokens = {
+        "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
+        "labels": list(itertools.chain(*(t for t in labels_tokens))),
+    }
+
+    return dict(combined_tokens, attention_mask=[1]*len(combined_tokens["input_ids"]))
 
 
 def get_custom_dataset(dataset_config, tokenizer, split):
     dataset = datasets.load_dataset("OpenAssistant/oasst1", split=split)
-    
+
     dataset = dataset.map(lambda sample: {
         "message_id": sample["message_id"],
         "parent_id": sample["parent_id"],
@@ -41,19 +35,19 @@ def get_custom_dataset(dataset_config, tokenizer, split):
         },
         batched=True,
         remove_columns=list(dataset.features),)
-    
+
     nodes = {}
-    
+
     messages = {}
     root_ids = []
-    
+
     for data in dataset:
         if data["parent_id"]:
             nodes[data["parent_id"]] = nodes.get(data["parent_id"], []) + [data["message_id"]]
         else:
             root_ids.append(data["message_id"])
         messages[data["message_id"]]=data["text"]
-           
+
     def follow(thread, current_id):
         thread = copy.copy(thread) + [messages[current_id]]
         if current_id in nodes:
@@ -63,18 +57,18 @@ def get_custom_dataset(dataset_config, tokenizer, split):
             return new_threads
         else:
             return [thread]
-        
+
     def get_threads_from_root(root_id):
         all_threads = []
         thread = [messages[root_id]]
         for cid in nodes[root_id]:
             all_threads += follow(thread, cid)
         return all_threads
-            
+
     dataset = dataset.filter(lambda x: x["message_id"] in root_ids)
     dataset = dataset.map(lambda x: {"thread": get_threads_from_root(x["message_id"])}, remove_columns=list(dataset.features))
     dataset = dataset.map(lambda x: {"thread": [i for row in x["thread"] for i in row]}, batched=True)
-    
+
     def to_dialog(thread):
         dialog = []
         for i, content in enumerate(thread):
@@ -83,9 +77,8 @@ def get_custom_dataset(dataset_config, tokenizer, split):
                 "content": content,
             })
         return {"dialog": dialog}
-            
+
     dataset = dataset.map(lambda x: to_dialog(x["thread"]), remove_columns=list(dataset.features))
     dataset = dataset.map(lambda x: tokenize_dialog(x["dialog"], tokenizer), remove_columns=list(dataset.features))
-    dataset = dataset.map(Concatenator(), batched=True)
-    
-    return dataset
+
+    return dataset

+ 33 - 1
scripts/spellcheck_conf/wordlist.txt

@@ -1157,6 +1157,7 @@ FN
 GBs
 MLP
 learnable
+tokenized
 Colab
 GenAI
 Gradio
@@ -1182,4 +1183,35 @@ minnutes
 pdf
 quantized
 serarch
-streamlit
+streamlit
+prem
+Prem
+OpenAI
+Prem
+TCP
+ba
+llm
+logprobs
+openai
+rohit
+tgi
+Axios
+Chatbot
+WHATSAPP
+Webhooks
+WhatsApp
+WhatsAppClient
+adffb
+axios
+baba
+chatbot
+chatbots
+de
+eeeb
+gunicorn
+knowledgable
+msgrcvd
+venv
+webhook
+webhook's
+whatsapp

+ 0 - 2
src/llama_recipes/configs/datasets.py

@@ -9,7 +9,6 @@ class samsum_dataset:
     dataset: str =  "samsum_dataset"
     train_split: str = "train"
     test_split: str = "validation"
-    input_length: int = 2048
     
     
 @dataclass
@@ -17,7 +16,6 @@ class grammar_dataset:
     dataset: str = "grammar_dataset"
     train_split: str = "src/llama_recipes/datasets/grammar_dataset/gtrain_10k.csv" 
     test_split: str = "src/llama_recipes/datasets/grammar_dataset/grammar_validation.csv"
-    input_length: int = 2048
 
     
 @dataclass

+ 2 - 4
src/llama_recipes/configs/training.py

@@ -11,6 +11,8 @@ class train_config:
     low_cpu_fsdp: bool=False
     run_validation: bool=True
     batch_size_training: int=4
+    batching_strategy: str="packing" #alternative: padding
+    context_length: int=4096
     gradient_accumulation_steps: int=1
     num_epochs: int=3
     num_workers_dataloader: int=1
@@ -34,7 +36,3 @@ class train_config:
     dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
     save_optimizer: bool=False # will be used if using FSDP
     use_fast_kernels: bool = False # Enable using SDPA from PyTorch Accelerated Transformers, making use of Flash Attention and Xformer memory-efficient kernels
-
-    
-    
-    

+ 2 - 0
src/llama_recipes/data/__init__.py

@@ -0,0 +1,2 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

+ 34 - 0
src/llama_recipes/data/concatenator.py

@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from tqdm import tqdm
+from itertools import chain
+
+from torch.utils.data import Dataset
+
+
+class ConcatDataset(Dataset):
+    def __init__(self, dataset, chunk_size=4096):
+        self.dataset = dataset
+        self.chunk_size = chunk_size
+
+        self.samples = []
+
+        buffer = {
+            "input_ids": [],
+            "attention_mask": [],
+            "labels": [],
+            }
+
+        for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
+            buffer = {k: v + sample[k] for k,v in buffer.items()}
+
+            while len(next(iter(buffer.values()))) > self.chunk_size:
+                self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()})
+                buffer = {k: v[self.chunk_size:] for k,v in buffer.items()}
+
+    def __getitem__(self, idx):
+        return self.samples[idx]
+
+    def __len__(self):
+        return len(self.samples)
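A brief usage sketch for this class, mirroring how finetuning.py in this change wraps the preprocessed dataset when `batching_strategy == "packing"`; the toy samples and chunk size are made up.

```python
# Sketch: wrap a tokenized dataset in ConcatDataset to get fixed-length "packed" samples.
from torch.utils.data import DataLoader
from transformers import default_data_collator

from llama_recipes.data.concatenator import ConcatDataset

# Any sequence of dicts with equally keyed token lists works; the real code passes the
# preprocessed dataset returned by get_preprocessed_dataset().
toy_dataset = [
    {"input_ids": [1, 2, 3, 4], "attention_mask": [1, 1, 1, 1], "labels": [1, 2, 3, 4]},
    {"input_ids": [5, 6, 7, 8, 9], "attention_mask": [1] * 5, "labels": [5, 6, 7, 8, 9]},
] * 4

packed = ConcatDataset(toy_dataset, chunk_size=8)  # every packed sample now has length 8
loader = DataLoader(packed, batch_size=2, drop_last=True, collate_fn=default_data_collator)
batch = next(iter(loader))
print(len(packed), {k: tuple(v.shape) for k, v in batch.items()})
```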

+ 57 - 0
src/llama_recipes/data/sampler.py

@@ -0,0 +1,57 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import random
+from itertools import islice
+
+import numpy as np
+import torch
+
+
+class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
+    def __init__(self, data_source, batch_size: int, drop_last: bool, shuffle: bool=True) -> None:
+        if isinstance(next(iter(data_source)), dict):
+            first_key = next(iter(next(iter(data_source)).keys()))
+            self.lengths = [len(d[first_key]) for d in data_source]
+        else:
+            self.lengths = [len(d) for d in data_source]
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+        self.shuffle = shuffle
+
+    def __iter__(self):
+        ids = np.argsort(self.lengths)
+        if self.drop_last:
+            ids = ids[:len(ids) // self.batch_size * self.batch_size]
+
+        batches = [ids[i:i+self.batch_size] for i in range(0, len(ids), self.batch_size)]
+
+        if self.shuffle:
+            random.shuffle(batches)
+
+        for b in batches:
+            yield b
+
+    def __len__(self):
+        if self.drop_last:
+            return len(self.lengths) // self.batch_size
+        else:
+            return len(self.lengths) // self.batch_size + (len(self.lengths) % self.batch_size > 0)
+
+
+class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler):
+    def __init__(self, data_source, batch_size: int, num_replicas: int, rank: int, shuffle: bool = True, seed: int = 0) -> None:
+        random.seed(seed)
+        self.batch_sampler = LengthBasedBatchSampler(
+            data_source, batch_size=batch_size, drop_last=True, shuffle=shuffle
+            )
+        self.num_replicas = num_replicas
+        self.rank = rank
+        
+    def __iter__(self):
+        max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas
+        return islice(self.batch_sampler, self.rank, max_length, self.num_replicas)
+         
+    def __len__(self):
+        return len(self.batch_sampler) // self.num_replicas
+            
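A brief usage sketch for the `padding` path, mirroring what `get_dataloader_kwargs` in config_utils.py (further down in this change) assembles; the tokenizer path and toy samples are placeholders.

```python
# Sketch: batch samples of similar length together and pad only within each batch.
from torch.utils.data import DataLoader
from transformers import DataCollatorForSeq2Seq, LlamaTokenizer

from llama_recipes.data.sampler import LengthBasedBatchSampler

tokenizer = LlamaTokenizer.from_pretrained("<path of converted model folder>")
tokenizer.pad_token_id = tokenizer.eos_token_id  # as done in finetuning.py

# Variable-length toy samples; the real code passes the preprocessed dataset.
toy_dataset = [
    {"input_ids": [1] * n, "attention_mask": [1] * n, "labels": [1] * n}
    for n in (5, 7, 6, 12, 11, 13, 20, 21)
]

sampler = LengthBasedBatchSampler(toy_dataset, batch_size=4, drop_last=True, shuffle=True)
loader = DataLoader(
    toy_dataset,
    batch_sampler=sampler,
    collate_fn=DataCollatorForSeq2Seq(tokenizer),  # pads input_ids/labels per batch
)
for batch in loader:
    print(batch["input_ids"].shape)  # each batch is only padded to its own longest sample
```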

+ 4 - 14
src/llama_recipes/datasets/alpaca_dataset.py

@@ -24,17 +24,14 @@ PROMPT_DICT = {
 }
 
 class InstructionDataset(Dataset):
-    def __init__(self, dataset_config, tokenizer, partition="train", max_words=30):
+    def __init__(self, dataset_config, tokenizer, partition="train"):
         self.ann = json.load(open(dataset_config.data_path))
         if partition == "train":
             self.ann = self.ann
         else:
             self.ann = self.ann[:200]
 
-        self.max_words = max_words
-        # tokenizer = Tokenizer(model_path=model_path + "./tokenizer.model")
         self.tokenizer = tokenizer
-        # self.tokenizer1 = tokenizer
 
     def __len__(self):
         return len(self.ann)
@@ -57,22 +54,15 @@ class InstructionDataset(Dataset):
         example = torch.tensor(
             example, dtype=torch.int64
         )
-        padding = self.max_words - example.shape[0]
-        if padding > 0:
-            example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1))
-        elif padding < 0:
-            example = example[: self.max_words]
         labels = copy.deepcopy(example)
         labels[: len(prompt)] = -1
         example_mask = example.ge(0)
         label_mask = labels.ge(0)
         example[~example_mask] = 0
         labels[~label_mask] = IGNORE_INDEX
-        example_mask = example_mask.float()
-        label_mask = label_mask.float()
 
         return {
-            "input_ids": example,
-            "labels": labels,
-            "attention_mask":example_mask,
+            "input_ids": example.tolist(),
+            "labels": labels.tolist(),
+            "attention_mask":example_mask.tolist(),
         }

+ 13 - 18
src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py

@@ -10,8 +10,6 @@ from pathlib import Path
 
 from torch.utils.data import Dataset
 
-from llama_recipes.datasets.utils import ConcatDataset
-
 
 class grammar(Dataset):
     def __init__(
@@ -48,24 +46,22 @@ class grammar(Dataset):
 
         input_ = example_batch["input"]
         target_ = example_batch["target"]
-        
-        prompt = f"Correct this to standard English: {input_}\n---\nCorrected: {target_}"
-        sample = self.tokenizer(prompt)
-        
-        return sample
-
-    def __getitem__(self, index):
-        sample = self.convert_to_features(self.dataset["train"][index])
-        source_ids = sample["input_ids"]
 
-        src_mask = sample["attention_mask"]
+        prompt = f"Correct this to standard English: {input_}\n---\nCorrected: "
+        prompt_ids = self.tokenizer.encode(self.tokenizer.bos_token + prompt, add_special_tokens=False)
+        label_ids = self.tokenizer.encode(target_ + self.tokenizer.eos_token, add_special_tokens=False)
 
-        return {
-            "input_ids": source_ids,
-            "attention_mask": src_mask,
-            "labels": source_ids.copy(),
+        sample = {
+            "input_ids": prompt_ids + label_ids,
+            "attention_mask": [1] * len(prompt_ids + label_ids),
+            "labels": [-100] * len(prompt_ids) + label_ids
         }
 
+        return sample
+
+    def __getitem__(self, index):
+        return self.convert_to_features(self.dataset["train"][int(index)])
+
 
 def get_dataset(
     dataset_config, tokenizer, csv_name=None
@@ -80,6 +76,5 @@ def get_dataset(
         tokenizer=tokenizer,
         csv_name=csv_name,
     )
-    
-    return ConcatDataset(dataset, chunk_size=dataset_config.input_length)
 
+    return dataset

+ 19 - 13
src/llama_recipes/datasets/samsum_dataset.py

@@ -3,31 +3,37 @@
 
 # For dataset details visit: https://huggingface.co/datasets/samsum
 
+import copy
 import datasets
 
-from llama_recipes.datasets.utils import Concatenator
 
 def get_preprocessed_samsum(dataset_config, tokenizer, split):
     dataset = datasets.load_dataset("samsum", split=split)
 
     prompt = (
-        f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n{{summary}}{{eos_token}}"
+        f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n"
     )
 
     def apply_prompt_template(sample):
         return {
-            "text": prompt.format(
-                dialog=sample["dialogue"],
-                summary=sample["summary"],
-                eos_token=tokenizer.eos_token,
-            )
+            "prompt": prompt.format(dialog=sample["dialogue"]),
+            "summary": sample["summary"],
         }
 
     dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
-        
-    dataset = dataset.map(
-        lambda sample: tokenizer(sample["text"]),
-        batched=True,
-        remove_columns=list(dataset.features),
-    ).map(Concatenator(), batched=True)
+
+    def tokenize_add_label(sample):
+        prompt = tokenizer.encode(tokenizer.bos_token + sample["prompt"], add_special_tokens=False)
+        summary = tokenizer.encode(sample["summary"] +  tokenizer.eos_token, add_special_tokens=False)
+
+        sample = {
+            "input_ids": prompt + summary,
+            "attention_mask" : [1] * (len(prompt) + len(summary)),
+            "labels": [-100] * len(prompt) + summary,
+            }
+
+        return sample
+
+    dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
+
     return dataset
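One clarifying note on the `-100` labels used here and in the other dataset preprocessing above: `-100` is the default `ignore_index` of PyTorch's cross-entropy loss, which the Hugging Face causal LM heads use internally, so prompt tokens are excluded from the loss and only the summary tokens are learned. A toy check with made-up logits:

```python
# Toy check: positions labeled -100 do not contribute to the cross-entropy loss.
import torch
import torch.nn.functional as F

vocab_size = 10
logits = torch.randn(6, vocab_size)          # 6 positions, e.g. 4 prompt + 2 summary tokens
labels = torch.tensor([-100, -100, -100, -100, 3, 7])

loss_all = F.cross_entropy(logits, labels)                  # ignore_index defaults to -100
loss_summary_only = F.cross_entropy(logits[4:], labels[4:]) # prompt positions are ignored
print(torch.allclose(loss_all, loss_summary_only))          # True
```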

+ 0 - 66
src/llama_recipes/datasets/utils.py

@@ -1,66 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-
-from tqdm import tqdm
-from itertools import chain
-
-from torch.utils.data import Dataset
-
-class Concatenator(object):
-    def __init__(self, chunk_size=2048):
-        self.chunk_size=chunk_size
-        self.residual = {"input_ids": [], "attention_mask": []}
-        
-    def __call__(self, batch):
-        concatenated_samples = {
-            k: v + list(chain(*batch[k])) for k, v in self.residual.items()
-        }
-
-        total_length = len(concatenated_samples[list(concatenated_samples.keys())[0]])
-
-        if total_length >= self.chunk_size:
-            chunk_num = total_length // self.chunk_size
-            result = {
-                k: [
-                    v[i : i + self.chunk_size]
-                    for i in range(0, chunk_num * self.chunk_size, self.chunk_size)
-                ]
-                for k, v in concatenated_samples.items()
-            }
-            self.residual = {
-                k: v[(chunk_num * self.chunk_size) :]
-                for k, v in concatenated_samples.items()
-            }
-        else:
-            result = concatenated_samples
-            self.residual = {k: [] for k in concatenated_samples.keys()}
-
-        result["labels"] = result["input_ids"].copy()
-
-        return result
-
-class ConcatDataset(Dataset):
-    def __init__(self, dataset, chunk_size=4096):
-        self.dataset = dataset
-        self.chunk_size = chunk_size
-        
-        self.samples = []
-        
-        buffer = {
-            "input_ids": [],
-            "attention_mask": [],
-            "labels": [],
-            }
-        
-        for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
-            buffer = {k: v + sample[k] for k,v in buffer.items()}
-            
-            while len(next(iter(buffer.values()))) > self.chunk_size:
-                self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()})
-                buffer = {k: v[self.chunk_size:] for k,v in buffer.items()}
-                
-    def __getitem__(self, idx):
-        return self.samples[idx]
-    
-    def __len__(self):
-        return len(self.samples)

+ 25 - 37
src/llama_recipes/finetuning.py

@@ -5,8 +5,8 @@ import os
 from pkg_resources import packaging
 
 import fire
+import random
 import torch
-import torch.distributed as dist
 import torch.optim as optim
 from peft import get_peft_model, prepare_model_for_int8_training
 from torch.distributed.fsdp import (
@@ -14,16 +14,16 @@ from torch.distributed.fsdp import (
 )
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
-from torch.utils.data import DistributedSampler
 from transformers import (
     LlamaForCausalLM,
     LlamaTokenizer,
     LlamaConfig,
-    default_data_collator,
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
-from llama_recipes.configs import fsdp_config, train_config
+from llama_recipes.configs import fsdp_config as FSDP_CONFIG
+from llama_recipes.configs import train_config as TRAIN_CONFIG
+from llama_recipes.data.concatenator import ConcatDataset
 from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
 
 from llama_recipes.utils import fsdp_auto_wrap_policy
@@ -31,6 +31,7 @@ from llama_recipes.utils.config_utils import (
     update_config,
     generate_peft_config,
     generate_dataset_config,
+    get_dataloader_kwargs,
 )
 from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
 
@@ -47,6 +48,7 @@ from accelerate.utils import is_xpu_available
 
 def main(**kwargs):
     # Update the configuration for the training and sharding process
+    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
     update_config((train_config, fsdp_config), **kwargs)
 
     # Set the seeds for reproducibility
@@ -55,6 +57,7 @@ def main(**kwargs):
     else:
         torch.cuda.manual_seed(train_config.seed)
     torch.manual_seed(train_config.seed)
+    random.seed(train_config.seed)
 
     if train_config.enable_fsdp:
         setup()
@@ -108,14 +111,19 @@ def main(**kwargs):
     if train_config.enable_fsdp and train_config.use_fast_kernels:
         """
         For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
-        using of Flash Attention or Xformer memory-efficient kernels 
+        using of Flash Attention or Xformer memory-efficient kernels
         based on the hardware being used. This would speed up fine-tuning.
         """
         try:
             from optimum.bettertransformer import BetterTransformer
-            model = BetterTransformer.transform(model) 
+            model = BetterTransformer.transform(model)
         except ImportError:
             print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
+
+    # Load the tokenizer and add special tokens
+    tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
+    tokenizer.pad_token_id = tokenizer.eos_token_id
+
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
 
     # Prepare the model for int8 training if quantization is enabled
@@ -126,14 +134,6 @@ def main(**kwargs):
     if train_config.enable_fsdp and fsdp_config.pure_bf16:
         model.to(torch.bfloat16)
 
-    # Load the tokenizer and add special tokens
-    tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
-    tokenizer.add_special_tokens(
-            {
-
-                "pad_token": "<PAD>",
-            }
-        )
     if train_config.use_peft:
         peft_config = generate_peft_config(train_config, kwargs)
         model = get_peft_model(model, peft_config)
@@ -188,43 +188,31 @@ def main(**kwargs):
     if not train_config.enable_fsdp or rank == 0:
             print(f"--> Validation Set Length = {len(dataset_val)}")
 
-    train_sampler = None
-    val_sampler = None
-    if train_config.enable_fsdp:
-        train_sampler = DistributedSampler(
-            dataset_train,
-            rank=dist.get_rank(),
-            num_replicas=dist.get_world_size(),
-            shuffle=True,
-        )
-        if train_config.run_validation:
-            val_sampler = DistributedSampler(
-                dataset_val,
-                rank=dist.get_rank(),
-                num_replicas=dist.get_world_size(),
-            )
+    if train_config.batching_strategy == "packing":
+        dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
+
+    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
 
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
-        batch_size=train_config.batch_size_training,
         num_workers=train_config.num_workers_dataloader,
         pin_memory=True,
-        sampler=train_sampler if train_sampler else None,
-        drop_last=True,
-        collate_fn=default_data_collator,
+        **train_dl_kwargs,
     )
 
     eval_dataloader = None
     if train_config.run_validation:
+        if train_config.batching_strategy == "packing":
+            dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
+
+        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
+
         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
-            batch_size=train_config.val_batch_size,
             num_workers=train_config.num_workers_dataloader,
             pin_memory=True,
-            sampler=val_sampler if val_sampler else None,
-            drop_last=True,
-            collate_fn=default_data_collator,
+            **val_dl_kwargs,
         )
 
     # Initialize the optimizer and learning rate scheduler

+ 49 - 11
src/llama_recipes/utils/config_utils.py

@@ -3,13 +3,19 @@
 
 import inspect
 from dataclasses import asdict
+
+import torch.distributed as dist
+from torch.utils.data import DistributedSampler
 from peft import (
     LoraConfig,
     AdaptionPromptConfig,
     PrefixTuningConfig,
 )
+from transformers import default_data_collator
+from transformers.data import DataCollatorForSeq2Seq
 
 from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
+from llama_recipes.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler
 from llama_recipes.utils.dataset_utils import DATASET_PREPROC
 
 
@@ -32,31 +38,63 @@ def update_config(config, **kwargs):
                         print(f"Warning: {config_name} does not accept parameter: {k}")
             elif isinstance(config, train_config):
                 print(f"Warning: unknown parameter {k}")
-                        
-                        
+
+
 def generate_peft_config(train_config, kwargs):
     configs = (lora_config, llama_adapter_config, prefix_config)
     peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
     names = tuple(c.__name__.rstrip("_config") for c in configs)
-    
+
     assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"
-    
+
     config = configs[names.index(train_config.peft_method)]()
-    
+
     update_config(config, **kwargs)
     params = asdict(config)
     peft_config = peft_configs[names.index(train_config.peft_method)](**params)
-    
+
     return peft_config
 
 
 def generate_dataset_config(train_config, kwargs):
     names = tuple(DATASET_PREPROC.keys())
-        
+
     assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}"
-    
+
     dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset]()
-        
+
     update_config(dataset_config, **kwargs)
-    
-    return  dataset_config
+
+    return  dataset_config
+
+
+def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
+        kwargs = {}
+        batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
+        if train_config.batching_strategy == "padding":
+            if train_config.enable_fsdp:
+                kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
+                    dataset,
+                    batch_size=batch_size,
+                    rank=dist.get_rank(),
+                    num_replicas=dist.get_world_size(),
+                    shuffle=mode=="train",
+                )
+            else:
+                kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
+            kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
+        elif train_config.batching_strategy == "packing":
+            if train_config.enable_fsdp:
+                kwargs["sampler"] = DistributedSampler(
+                dataset,
+                rank=dist.get_rank(),
+                num_replicas=dist.get_world_size(),
+                shuffle=mode=="train",
+            )
+            kwargs["batch_size"] = batch_size
+            kwargs["drop_last"] = True
+            kwargs["collate_fn"] = default_data_collator
+        else:
+            raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
+
+        return kwargs

+ 6 - 6
src/llama_recipes/utils/dataset_utils.py

@@ -33,24 +33,24 @@ def get_custom_dataset(dataset_config, tokenizer, split: str):
         module_path, func_name = dataset_config.file.split(":")
     else:
         module_path, func_name = dataset_config.file, "get_custom_dataset"
-        
+
     if not module_path.endswith(".py"):
         raise ValueError(f"Dataset file {module_path} is not a .py file.")
-    
+
     module_path = Path(module_path)
     if not module_path.is_file():
         raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
-    
+
     module = load_module_from_py_file(module_path.as_posix())
     try:
         return getattr(module, func_name)(dataset_config, tokenizer, split)
     except AttributeError as e:
         print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).")
         raise e
-    
+
 
 DATASET_PREPROC = {
-    "alpaca_dataset": partial(get_alpaca_dataset, max_words=224),
+    "alpaca_dataset": partial(get_alpaca_dataset),
     "grammar_dataset": get_grammar_dataset,
     "samsum_dataset": get_samsum_dataset,
     "custom_dataset": get_custom_dataset,
@@ -69,7 +69,7 @@ def get_preprocessed_dataset(
             if split == "train"
             else dataset_config.test_split
         )
-    
+
     return DATASET_PREPROC[dataset_config.dataset](
         dataset_config,
         tokenizer,
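
As the `get_custom_dataset` loader above shows, `custom_dataset.file` accepts either a plain path or a `module_path:func_name` pair and falls back to a function named `get_custom_dataset`. A hypothetical minimal module satisfying that contract (file name, dataset and column names are illustrative, not taken from this repository):

```python
# my_dataset.py -- illustrative only; referenced via
#   custom_dataset.file="my_dataset.py"                     (default function name), or
#   custom_dataset.file="my_dataset.py:get_custom_dataset"  (explicit function name)
import datasets


def get_custom_dataset(dataset_config, tokenizer, split: str):
    # Any indexable dataset exposing input_ids/attention_mask/labels works.
    raw = datasets.load_dataset("samsum", split=split)

    def tokenize(sample):
        ids = tokenizer.encode(sample["dialogue"] + "\nSummary: " + sample["summary"])
        return {"input_ids": ids, "attention_mask": [1] * len(ids), "labels": ids.copy()}

    return raw.map(tokenize, remove_columns=list(raw.features))
```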

+ 36 - 34
src/llama_recipes/utils/train_utils.py

@@ -19,14 +19,14 @@ from transformers import LlamaTokenizer
 
 
 from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
-from llama_recipes.policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper
+from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
 from accelerate.utils import is_xpu_available, is_ccl_available
 
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
     tokenizer.padding_side = "left"
-    
+
 # Converting Bytes to Megabytes
 def byte2mb(x):
     return int(x / 2**20)
@@ -34,7 +34,7 @@ def byte2mb(x):
 def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
     """
     Trains the model on the given dataloader
-    
+
     Args:
         model: The model to be trained
         train_dataloader: The dataloader containing the training data
@@ -46,18 +46,18 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         train_config: The training configuration
         eval_dataloader: The dataloader containing the eval data
         tokenizer: tokenizer used in the eval for decoding the predictions
-    
+
     Returns: results dictionary containing average training and validation perplexity and loss
     """
     # Create a gradient scaler for fp16
     if train_config.use_fp16 and train_config.enable_fsdp:
         scaler = ShardedGradScaler()
     elif train_config.use_fp16 and not train_config.enable_fsdp:
-        scaler = torch.cuda.amp.GradScaler() 
+        scaler = torch.cuda.amp.GradScaler()
     if train_config.enable_fsdp:
         world_size = int(os.environ["WORLD_SIZE"])
     autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
-    
+
     train_prep = []
     train_loss = []
     val_prep = []
@@ -81,6 +81,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                         else:
                             batch[key] = batch[key].to(local_rank)
                     else:
+
                         if is_xpu_available():
                             batch[key] = batch[key].to('xpu:0')
                         else:
@@ -107,9 +108,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
                 pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
             pbar.close()
-                
+
         epoch_end_time = time.perf_counter()-epoch_start_time
-        epoch_times.append(epoch_end_time)    
+        epoch_times.append(epoch_end_time)
         # Reducing total_loss across all devices if there's more than one CUDA device
         if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
@@ -119,10 +120,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         if train_config.enable_fsdp:
             train_epoch_loss = train_epoch_loss/world_size
         train_perplexity = torch.exp(train_epoch_loss)
-        
+
         train_prep.append(train_perplexity)
         train_loss.append(train_epoch_loss)
-        
+
         if train_config.enable_fsdp:
             if rank==0:
                 if is_xpu_available():
@@ -148,10 +149,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                 print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
                 print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
             print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
-        
+
         # Update the learning rate as needed
         lr_scheduler.step()
-          
+
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
             checkpoint_start_time = time.perf_counter()
@@ -164,23 +165,23 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             print(f"we are about to save the PEFT modules")
                     else:
                         print(f"we are about to save the PEFT modules")
-                    model.save_pretrained(train_config.output_dir)  
+                    model.save_pretrained(train_config.output_dir)
                     if train_config.enable_fsdp:
-                        if rank==0: 
+                        if rank==0:
                             print(f"PEFT modules are saved in {train_config.output_dir} directory")
                     else:
                         print(f"PEFT modules are saved in {train_config.output_dir} directory")
-                        
+
                 else:
                     if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
-                        
+
                         save_model_checkpoint(
                             model, optimizer, rank, train_config, epoch=epoch
                         )
                     elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
                         print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
                         print("=====================================================")
-                        
+
                         save_model_and_optimizer_sharded(model, rank, train_config)
                         if train_config.save_optimizer:
                             save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
@@ -192,7 +193,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             model, optimizer, rank, train_config, epoch=epoch
                         )
                         print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
-                        print("=====================================================")                     
+                        print("=====================================================")
                 if train_config.enable_fsdp:
                     dist.barrier()
             checkpoint_end_time = time.perf_counter() - checkpoint_start_time
@@ -216,8 +217,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     avg_train_prep = sum(train_prep)/len(train_prep)
     avg_train_loss = sum(train_loss)/len(train_loss)
     if train_config.run_validation:
-        avg_eval_prep = sum(val_prep)/len(val_prep) 
-        avg_eval_loss = sum(val_loss)/len(val_loss) 
+        avg_eval_prep = sum(val_prep)/len(val_prep)
+        avg_eval_loss = sum(val_loss)/len(val_loss)
 
     results['avg_train_prep'] = avg_train_prep
     results['avg_train_loss'] = avg_train_loss
@@ -226,27 +227,27 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         results['avg_eval_loss'] = avg_eval_loss
     results["avg_epoch_time"] = avg_epoch_time
     results["avg_checkpoint_time"] = avg_checkpoint_time
-    
+
     # Saving the training params including the fsdp settings for reference.
     if train_config.enable_fsdp and not train_config.use_peft:
         save_train_params(train_config, fsdp_config, rank)
-        
+
     return results
 
 def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     """
     Evaluates the model on the given dataloader
-    
+
     Args:
         model: The model to evaluate
         eval_dataloader: The dataloader containing the evaluation data
         local_rank: The rank of the current node in a distributed setting
         tokenizer: The tokenizer used to decode predictions
-    
+
     Returns: eval_ppl, eval_epoch_loss
     """
     if train_config.enable_fsdp:
-        world_size = int(os.environ["WORLD_SIZE"]) 
+        world_size = int(os.environ["WORLD_SIZE"])
     model.eval()
     eval_preds = []
     eval_loss = 0.0  # Initialize evaluation loss
@@ -271,26 +272,26 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
             eval_preds.extend(
                 tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
             )
-    
+
     # If there's more than one CUDA device, reduce evaluation loss across all devices
     if is_xpu_available() and (torch.xpu.device_count() > 1 and train_config.enable_fsdp):
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
     if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
-    
+
     # Compute average loss and perplexity
     eval_epoch_loss = eval_loss / len(eval_dataloader)
     if train_config.enable_fsdp:
         eval_epoch_loss = eval_epoch_loss/world_size
     eval_ppl = torch.exp(eval_epoch_loss)
-    
+
     # Print evaluation metrics
     if train_config.enable_fsdp:
         if local_rank==0:
             print(f" {eval_ppl=} {eval_epoch_loss=}")
     else:
         print(f" {eval_ppl=} {eval_epoch_loss=}")
-        
+
     return eval_ppl, eval_epoch_loss
 
 def freeze_transformer_layers(model, num_layer):
@@ -304,8 +305,8 @@ def check_frozen_layers_peft_model(model):
      for i, layer in enumerate(model.base_model.model.model.layers):
             for name, param in layer.named_parameters():
                 print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
-                
-                
+
+
 def setup():
     """Initialize the process group for distributed training"""
     if is_ccl_available():
@@ -322,7 +323,7 @@ def setup_environ_flags(rank):
     # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
     # This flag will help with CUDA memory fragmentation that can lead to OOM in some cases.
     # Note this is only available in PyTorch Nightlies (as of July 30 2023)
-    # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True' 
+    # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
     if rank == 0:
         print(f"--> Running with torch dist debug set to detail")
 
@@ -370,6 +371,7 @@ def print_model_size(model, config, rank: int = 0) -> None:
 
 def get_policies(cfg, rank):
     """Get the policies for mixed precision and fsdp wrapping"""
+
     
     verify_bfloat_support = ((
     torch.version.cuda
@@ -389,7 +391,7 @@ def get_policies(cfg, rank):
         bf16_ready = verify_bfloat_support
 
         if bf16_ready and not cfg.use_fp16:
-            mixed_precision_policy = bfSixteen_mixed
+            mixed_precision_policy = bfSixteen
             if rank == 0:
                 print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
         elif cfg.use_fp16:
@@ -407,7 +409,7 @@ def save_train_params(train_config, fsdp_config, rank):
     This will be used by the converter script in the inference folder to fetch the HF model name or path.
     It would also be helpful as a log for future reference.
     """
-    # Convert the train_config and fsdp_config objects to dictionaries, 
+    # Convert the train_config and fsdp_config objects to dictionaries,
     # converting all values to strings to ensure they can be serialized into a YAML file
     train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
     fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}
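
The `bfSixteen_mixed` to `bfSixteen` switch in `get_policies` above selects a pure-bf16 FSDP mixed-precision policy. For orientation, such a policy is typically built from `torch.distributed.fsdp.MixedPrecision`; the definition below is a representative sketch, not copied from `llama_recipes.policies`:

```python
# Representative bf16 policy; the real one lives in llama_recipes.policies
# and may differ in detail.
import torch
from torch.distributed.fsdp import MixedPrecision

bfSixteen = MixedPrecision(
    param_dtype=torch.bfloat16,   # parameters held and communicated in bf16
    reduce_dtype=torch.bfloat16,  # gradient reduction in bf16
    buffer_dtype=torch.bfloat16,  # buffers cast to bf16
)
```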

+ 18 - 0
tests/conftest.py

@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import pytest
+
+from transformers import LlamaTokenizer
+
+
+@pytest.fixture
+def setup_tokenizer():
+    def _helper(tokenizer):
+        #Align with Llama 2 tokenizer
+        tokenizer.from_pretrained.return_value = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
+        tokenizer.from_pretrained.return_value.add_special_tokens({'bos_token': '<s>', 'eos_token': '</s>'})
+        tokenizer.from_pretrained.return_value.bos_token_id = 1
+        tokenizer.from_pretrained.return_value.eos_token_id = 2
+
+    return _helper

+ 40 - 13
tests/datasets/test_custom_dataset.py

@@ -4,21 +4,38 @@
 import pytest
 from unittest.mock import patch
 
+from transformers import LlamaTokenizer
+
+def check_padded_entry(batch):
+    seq_len = sum(batch["attention_mask"][0])
+    assert seq_len < len(batch["attention_mask"][0])
+
+    assert batch["labels"][0][0] == -100
+    assert batch["labels"][0][seq_len-1] == 2
+    assert batch["labels"][0][-1] == -100
+    assert batch["input_ids"][0][0] == 1
+    assert batch["input_ids"][0][-1] == 2
+
 
 @patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
+def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
     from llama_recipes.finetuning import main
 
+    setup_tokenizer(tokenizer)
+
     kwargs = {
         "dataset": "custom_dataset",
         "model_name": "decapoda-research/llama-7b-hf", # We use the tokenizer as a surrogate for llama2 tokenizer here
         "custom_dataset.file": "examples/custom_dataset.py",
         "custom_dataset.train_split": "validation",
         "batch_size_training": 2,
+        "val_batch_size": 4,
         "use_peft": False,
+        "batching_strategy": "padding"
         }
 
     main(**kwargs)
@@ -30,24 +47,34 @@ def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
     eval_dataloader = args[2]
     tokenizer = args[3]
 
-    assert len(train_dataloader) == 226
-    assert len(eval_dataloader) == 2*226
+    assert len(train_dataloader) == 1120
+    assert len(eval_dataloader) == 1120 //2
+
+    it = iter(eval_dataloader)
+    batch = next(it)
+    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
+    EXPECTED_STRING = "[INST] Who made Berlin [/INST] dunno"
+    assert STRING.startswith(EXPECTED_STRING)
+
+    assert batch["input_ids"].size(0) == 4
+    assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())
+
+    check_padded_entry(batch)
 
     it = iter(train_dataloader)
-    STRING = tokenizer.decode(next(it)["input_ids"][0], skip_special_tokens=True)
-    EXPECTED_STRING = "[INST] Напиши функцию на языке swift, которая сортирует массив целых чисел, а затем выводит его на экран [/INST] Вот функция, "
+    for _ in range(5):
+        next(it)
 
+    batch = next(it)
+    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
+    EXPECTED_STRING = "[INST] How do I initialize a Typescript project using npm and git? [/INST] # Initialize a new NPM project"
     assert STRING.startswith(EXPECTED_STRING)
 
-    next(it)
-    next(it)
-    next(it)
-    STRING = tokenizer.decode(next(it)["input_ids"][0], skip_special_tokens=True)
-    EXPECTED_SUBSTRING_1 = "Therefore you are correct.  [INST] How can L’Hopital’s Rule be"
-    EXPECTED_SUBSTRING_2 = "a circular path around the turn.  [INST] How on earth is that related to L’Hopital’s Rule?"
+    assert batch["input_ids"].size(0) == 2
+    assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())
+
+    check_padded_entry(batch)
 
-    assert EXPECTED_SUBSTRING_1 in STRING
-    assert EXPECTED_SUBSTRING_2 in STRING
 
 
 @patch('llama_recipes.finetuning.train')

+ 54 - 0
tests/datasets/test_grammar_datasets.py

@@ -0,0 +1,54 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from unittest.mock import patch
+
+from transformers import LlamaTokenizer
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+    from llama_recipes.finetuning import main
+
+    setup_tokenizer(tokenizer)
+
+    BATCH_SIZE = 8
+    kwargs = {
+        "model_name": "decapoda-research/llama-7b-hf",
+        "batch_size_training": BATCH_SIZE,
+        "val_batch_size": 1,
+        "use_peft": False,
+        "dataset": "grammar_dataset",
+        "batching_strategy": "padding",
+        }
+
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader = args[1]
+    eval_dataloader = args[2]
+
+    VAL_SAMPLES = 2988
+    TRAIN_SAMPLES = 13016
+
+    assert len(train_dataloader) == TRAIN_SAMPLES // BATCH_SIZE
+    assert len(eval_dataloader) == VAL_SAMPLES
+
+    batch = next(iter(train_dataloader))
+
+    assert "labels" in batch.keys()
+    assert "input_ids" in batch.keys()
+    assert "attention_mask" in batch.keys()
+
+    assert batch["labels"][0][29] == -100
+    assert batch["labels"][0][30] == 29871
+
+    assert batch["input_ids"][0][0] == 1
+    assert batch["labels"][0][-1] == 2
+    assert batch["input_ids"][0][-1] == 2

+ 30 - 14
tests/datasets/test_samsum_datasets.py

@@ -1,37 +1,53 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+from functools import partial
 from unittest.mock import patch
 
 
 @patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_custom_dataset(step_lr, optimizer, tokenizer, get_model, train, mocker):
+def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
     from llama_recipes.finetuning import main
-        
-    tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})
-    
-    
+
+    setup_tokenizer(tokenizer)
+
+    BATCH_SIZE = 8
     kwargs = {
-        "batch_size_training": 1,
+        "model_name": "decapoda-research/llama-7b-hf",
+        "batch_size_training": BATCH_SIZE,
+        "val_batch_size": 1,
         "use_peft": False,
         "dataset": "samsum_dataset",
+        "batching_strategy": "padding",
         }
-    
+
     main(**kwargs)
-    
+
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     train_dataloader = args[1]
     eval_dataloader = args[2]
-    
+
     VAL_SAMPLES = 818
     TRAIN_SAMPLES = 14732
-    CONCAT_SIZE = 2048
-    assert len(train_dataloader) == TRAIN_SAMPLES // CONCAT_SIZE
+
+    assert len(train_dataloader) == TRAIN_SAMPLES // BATCH_SIZE
     assert len(eval_dataloader) == VAL_SAMPLES
-    
+
+    batch = next(iter(train_dataloader))
+
+    assert "labels" in batch.keys()
+    assert "input_ids" in batch.keys()
+    assert "attention_mask" in batch.keys()
+
+    assert batch["labels"][0][268] == -100
+    assert batch["labels"][0][269] == 22291
+
+    assert batch["input_ids"][0][0] == 1
+    assert batch["labels"][0][-1] == 2
+    assert batch["input_ids"][0][-1] == 2

+ 94 - 0
tests/test_batching.py

@@ -0,0 +1,94 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import pytest
+from unittest.mock import patch
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+    from llama_recipes.finetuning import main
+
+    setup_tokenizer(tokenizer)
+
+    kwargs = {
+        "model_name": "decapoda-research/llama-7b-hf",
+        "batch_size_training": 8,
+        "val_batch_size": 1,
+        "use_peft": False,
+        "dataset": "samsum_dataset",
+        "batching_strategy": "packing",
+        }
+
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader = args[1]
+    eval_dataloader = args[2]
+
+    assert len(train_dataloader) == 96
+    assert len(eval_dataloader) == 42
+
+    batch = next(iter(train_dataloader))
+
+    assert "labels" in batch.keys()
+    assert "input_ids" in batch.keys()
+    assert "attention_mask" in batch.keys()
+
+    assert batch["labels"][0].size(0) == 4096
+    assert batch["input_ids"][0].size(0) == 4096
+    assert batch["attention_mask"][0].size(0) == 4096
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+@patch('llama_recipes.finetuning.setup')
+@patch('llama_recipes.finetuning.FSDP')
+@patch('llama_recipes.finetuning.torch.distributed.is_initialized')
+@patch('llama_recipes.utils.config_utils.dist')
+def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer):
+    import os
+    from llama_recipes.finetuning import main
+
+    setup_tokenizer(tokenizer)
+
+    rank = 0
+    os.environ['LOCAL_RANK'] = f'{rank}'
+    os.environ['RANK'] = f'{rank}'
+    os.environ['WORLD_SIZE'] = '2'
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '12345'
+
+    kwargs = {
+        "model_name": "decapoda-research/llama-7b-hf",
+        "batch_size_training": 8,
+        "val_batch_size": 1,
+        "use_peft": False,
+        "dataset": "samsum_dataset",
+        "batching_strategy": "packing",
+        "enable_fsdp": True
+        }
+
+    is_initialized.return_value = True
+    dist.get_rank.return_value = rank
+    dist.get_world_size.return_value = 2
+
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader = args[1]
+    eval_dataloader = args[2]
+
+    assert len(train_dataloader) == 96 //2
+    assert len(eval_dataloader) == 42 //2

+ 81 - 34
tests/test_finetuning.py

@@ -1,14 +1,26 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+import pytest
 from pytest import approx
 from unittest.mock import patch
 
 from torch.nn import Linear
 from torch.optim import AdamW
 from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.sampler import BatchSampler
 
 from llama_recipes.finetuning import main
+from llama_recipes.data.sampler import LengthBasedBatchSampler
+
+
+def get_fake_dataset():
+    return [{
+        "input_ids":[1],
+        "attention_mask":[1],
+        "labels":[1],
+        }]
+
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@@ -18,23 +30,23 @@ from llama_recipes.finetuning import main
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
     kwargs = {"run_validation": False}
-    
-    get_dataset.return_value = [1]
-    
+
+    get_dataset.return_value = get_fake_dataset()
+
     main(**kwargs)
-    
+
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     train_dataloader = args[1]
     eval_dataloader = args[2]
-    
+
     assert isinstance(train_dataloader, DataLoader)
     assert eval_dataloader is None
-    
+
     assert get_model.return_value.to.call_args.args[0] == "cuda"
-    
-    
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@@ -43,21 +55,22 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
     kwargs = {"run_validation": True}
-    get_dataset.return_value = [1]
-    
+
+    get_dataset.return_value = get_fake_dataset()
+
     main(**kwargs)
-    
+
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     train_dataloader = args[1]
     eval_dataloader = args[2]
     assert isinstance(train_dataloader, DataLoader)
     assert isinstance(eval_dataloader, DataLoader)
-    
+
     assert get_model.return_value.to.call_args.args[0] == "cuda"
-    
-    
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@@ -68,15 +81,15 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer,
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train):
     kwargs = {"use_peft": True}
-    
-    get_dataset.return_value = [1]
-    
+
+    get_dataset.return_value = get_fake_dataset()
+
     main(**kwargs)
-    
+
     assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
     assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
-    
-    
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@@ -85,22 +98,56 @@ def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, ge
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train, mocker):
     kwargs = {"weight_decay": 0.01}
-    
-    get_dataset.return_value = [1]
-    
-    model = mocker.MagicMock(name="model")
-    model.parameters.return_value = Linear(1,1).parameters()
-    get_peft_model.return_value = model 
-    get_peft_model.return_value.print_trainable_parameters=lambda:None
+
+    get_dataset.return_value = get_fake_dataset()
+
+    get_model.return_value = Linear(1,1)
+
     main(**kwargs)
-    
+
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     optimizer = args[4]
-    
+
     print(optimizer.state_dict())
-    
+
     assert isinstance(optimizer, AdamW)
     assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01)
-    
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.get_preprocessed_dataset')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_batching_strategy(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
+    kwargs = {"batching_strategy": "packing"}
+
+    get_dataset.return_value = get_fake_dataset()
+
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader, eval_dataloader = args[1:3]
+    assert isinstance(train_dataloader.batch_sampler, BatchSampler)
+    assert isinstance(eval_dataloader.batch_sampler, BatchSampler)
+
+    kwargs["batching_strategy"] = "padding"
+    train.reset_mock()
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader, eval_dataloader = args[1:3]
+    assert isinstance(train_dataloader.batch_sampler, LengthBasedBatchSampler)
+    assert isinstance(eval_dataloader.batch_sampler, LengthBasedBatchSampler)
+
+    kwargs["batching_strategy"] = "none"
+
+    with pytest.raises(ValueError):
+        main(**kwargs)

+ 86 - 0
tests/test_sampler.py

@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import random
+import pytest
+
+import torch
+
+from llama_recipes.data.sampler import LengthBasedBatchSampler
+from llama_recipes.data.sampler import DistributedLengthBasedBatchSampler
+
+SAMPLES = 33
+
+@pytest.fixture
+def dataset():
+    random.seed(42)
+    dataset = []
+    def add_samples(ds, n, a, b):
+        for _ in range(n):
+            ds.append(random.randint(a,b) * [1,])
+    add_samples(dataset, SAMPLES // 2, 1,9)
+    add_samples(dataset, (SAMPLES // 2) + (SAMPLES % 2), 10,20)
+    
+    return random.sample(dataset, len(dataset))
+    
+    
+@pytest.mark.parametrize("batch_size, drop_last", [(2, False), (8, False), (2, True), (8, True)])
+def test_batch_sampler_array(dataset, batch_size, drop_last):
+    
+    sampler = LengthBasedBatchSampler(dataset, batch_size, drop_last)
+    
+    EXPECTED_LENGTH = SAMPLES // batch_size if drop_last else (SAMPLES // batch_size) + (SAMPLES % batch_size)
+    
+    all_ids = [i for b in sampler for i in b]
+    assert len(set(all_ids)) == (EXPECTED_LENGTH * batch_size if drop_last else len(dataset))
+    
+    assert len(sampler) == EXPECTED_LENGTH
+    is_long = [len(d)>=10 for d in dataset]
+    
+    def check_batch(batch):
+        return all(batch) or not any(batch)
+    
+    assert all(check_batch(is_long[i] for i in b) for b in sampler)
+    
+    
+@pytest.mark.parametrize("batch_size, drop_last", [(2, False), (8, False), (2, True), (8, True)])
+def test_batch_sampler_dict(dataset, batch_size, drop_last):
+    
+    dist_dataset = [{"input_ids": d, "attention_mask": d} for d in dataset]
+    
+    sampler = LengthBasedBatchSampler(dist_dataset, batch_size, drop_last)
+    
+    EXPECTED_LENGTH = SAMPLES // batch_size if drop_last else (SAMPLES // batch_size) + (SAMPLES % batch_size)
+    
+    assert len(sampler) == EXPECTED_LENGTH
+    is_long = [len(d)>=10 for d in dataset]
+    
+    def check_batch(batch):
+        return all(batch) or not any(batch)
+    
+    assert all(check_batch(is_long[i] for i in b) for b in sampler)
+    
+    
+@pytest.mark.parametrize("batch_size", [2, 8])
+def test_dist_batch_sampling(dataset, batch_size):
+    sampler_1 = DistributedLengthBasedBatchSampler(
+        dataset,
+        batch_size=batch_size,
+        rank=0,
+        num_replicas=2,
+        shuffle=False,
+    )
+    sampler_2 = DistributedLengthBasedBatchSampler(
+        dataset,
+        batch_size=batch_size,
+        rank=1,
+        num_replicas=2,
+        shuffle=False,
+    )
+    
+    ids_1 = set(i for b in sampler_1 for i in b)
+    ids_2 = set(i for b in sampler_2 for i in b)
+    
+    assert ids_1.isdisjoint(ids_2)
+    assert len(ids_1)+len(ids_2) > 0
+    assert len(ids_1)+len(ids_2) == len(dataset) // batch_size  *  batch_size
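
The behaviour these tests pin down (length-homogeneous batches, and disjoint equally sized shards per rank) can be approximated in a few lines; the sketch below is illustrative and does not reproduce `llama_recipes.data.sampler`:

```python
# Approximate sketch: sort indices by sample length so each batch holds
# similarly sized sequences, then optionally shuffle whole batches.
import random

def length_based_batches(data_source, batch_size, drop_last=True, shuffle=True, seed=42):
    lengths = [len(d["input_ids"]) if isinstance(d, dict) else len(d) for d in data_source]
    order = sorted(range(len(data_source)), key=lambda i: lengths[i])
    batches = [order[i:i + batch_size] for i in range(0, len(order), batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches = batches[:-1]
    if shuffle:
        random.Random(seed).shuffle(batches)
    return batches
```

A distributed variant would then hand every `num_replicas`-th batch to a given `rank`, which is what the disjointness assertion in `test_dist_batch_sampling` checks.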