Parcourir la source

Merging latest from main

Beto il y a 1 an
Parent
commit
b974c87035
45 fichiers modifiés avec 4763 ajouts et 326 suppressions
  1. 26 6
      README.md
  2. 433 0
      demo_apps/HelloLlamaCloud.ipynb
  3. 347 0
      demo_apps/HelloLlamaLocal.ipynb
  4. 306 0
      demo_apps/LiveData.ipynb
  5. 120 0
      demo_apps/Llama2_Gradio.ipynb
  6. 101 0
      demo_apps/README.md
  7. 559 0
      demo_apps/StructuredLlama.ipynb
  8. 698 0
      demo_apps/VideoSummary.ipynb
  9. 38 0
      demo_apps/csv2db.py
  10. BIN
      demo_apps/llama2-gradio.png
  11. BIN
      demo_apps/llama2-streamlit.png
  12. BIN
      demo_apps/llama2-streamlit2.png
  13. BIN
      demo_apps/llama2.pdf
  14. 1294 0
      demo_apps/nba.txt
  15. 22 0
      demo_apps/streamlit_llama2.py
  16. 53 0
      demo_apps/txt2csv.py
  17. 15 3
      docs/Dataset.md
  18. 21 7
      docs/FAQ.md
  19. 1 0
      examples/Getting_to_know_Llama.ipynb
  20. 1 1
      examples/README.md
  21. 23 30
      examples/custom_dataset.py
  22. 3 6
      examples/quickstart.ipynb
  23. 29 1
      scripts/spellcheck_conf/wordlist.txt
  24. 0 2
      src/llama_recipes/configs/datasets.py
  25. 2 4
      src/llama_recipes/configs/training.py
  26. 2 0
      src/llama_recipes/data/__init__.py
  27. 34 0
      src/llama_recipes/data/concatenator.py
  28. 57 0
      src/llama_recipes/data/sampler.py
  29. 4 14
      src/llama_recipes/datasets/alpaca_dataset.py
  30. 13 18
      src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py
  31. 3 3
      src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb
  32. 19 13
      src/llama_recipes/datasets/samsum_dataset.py
  33. 0 66
      src/llama_recipes/datasets/utils.py
  34. 25 37
      src/llama_recipes/finetuning.py
  35. 49 11
      src/llama_recipes/utils/config_utils.py
  36. 6 6
      src/llama_recipes/utils/dataset_utils.py
  37. 37 33
      src/llama_recipes/utils/train_utils.py
  38. 18 0
      tests/conftest.py
  39. 40 13
      tests/datasets/test_custom_dataset.py
  40. 54 0
      tests/datasets/test_grammar_datasets.py
  41. 30 14
      tests/datasets/test_samsum_datasets.py
  42. 94 0
      tests/test_batching.py
  43. 82 33
      tests/test_finetuning.py
  44. 86 0
      tests/test_sampler.py
  45. 18 5
      tests/test_train_utils.py

+ 26 - 6
README.md

@@ -1,7 +1,13 @@
-# Llama 2 Fine-tuning / Inference Recipes and Examples
+# Llama 2 Fine-tuning / Inference Recipes, Examples and Demo Apps
+
+**[Update Oct. 20, 2023] We have just released a series of Llama 2 demo apps [here](./demo_apps). These apps show how to run Llama 2 locally and in the cloud to chat about data (PDF, DB, or live) and generate video summary.**
+
 
 The 'llama-recipes' repository is a companion to the [Llama 2 model](https://github.com/facebookresearch/llama). The goal of this repository is to provide examples to quickly get started with fine-tuning for domain adaptation and how to run inference for the fine-tuned models. For ease of use, the examples use Hugging Face converted versions of the models. See steps for conversion of the model [here](#model-conversion-to-hugging-face).
 
+In addition, we also provide a number of demo apps, to showcase the Llama2 usage along with other ecosystem solutions to run Llama2 locally on your mac and on cloud.
+
+
 Llama 2 is a new technology that carries potential risks with use. Testing conducted to date has not — and could not — cover all scenarios. In order to help developers address these risks, we have created the [Responsible Use Guide](https://github.com/facebookresearch/llama/blob/main/Responsible-Use-Guide.pdf). More details can be found in our research paper as well. For downloading the models, follow the instructions on [Llama 2 repo](https://github.com/facebookresearch/llama).
 
 
@@ -13,8 +19,9 @@ Llama 2 is a new technology that carries potential risks with use. Testing condu
     - [Multi GPU One Node](#multiple-gpus-one-node)
     - [Multi GPU Multi Node](#multi-gpu-multi-node)
 4. [Inference](./docs/inference.md)
-5. [Repository Organization](#repository-organization)
-6. [License and Acceptable Use Policy](#license)
+5. [Demo Apps](#demo-apps)
+6. [Repository Organization](#repository-organization)
+7. [License and Acceptable Use Policy](#license)
 
 
 
@@ -130,7 +137,7 @@ Here we make use of Parameter Efficient Methods (PEFT) as described in the next
 
 ```bash
 
-torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --pure_bf16 --output_dir Path/to/save/PEFT/model
+torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --fsdp_config.pure_bf16 --output_dir Path/to/save/PEFT/model
 
 ```
 
@@ -141,7 +148,7 @@ Here we use FSDP as discussed in the next section which can be used along with P
 Setting `use_fast_kernels` will enable using of Flash Attention or Xformer memory-efficient kernels based on the hardware being used. This would speed up the fine-tuning job. This has been enabled in `optimum` library from HuggingFace as a one-liner API, please read more [here](https://pytorch.org/blog/out-of-the-box-acceleration/).
 
 ```bash
-torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --pure_bf16 --output_dir Path/to/save/PEFT/model --use_fast_kernels
+torchrun --nnodes 1 --nproc_per_node 4  examples/finetuning.py --enable_fsdp --use_peft --peft_method lora --model_name /patht_of_model_folder/7B --fsdp_config.pure_bf16 --output_dir Path/to/save/PEFT/model --use_fast_kernels
 ```
 
 ### Fine-tuning using FSDP Only
@@ -160,7 +167,7 @@ If you are interested in running full parameter fine-tuning on the 70B model, yo
 
 ```bash
 
-torchrun --nnodes 1 --nproc_per_node 8 examples/finetuning.py --enable_fsdp --low_cpu_fsdp --pure_bf16 --model_name /patht_of_model_folder/70B --batch_size_training 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
+torchrun --nnodes 1 --nproc_per_node 8 examples/finetuning.py --enable_fsdp --low_cpu_fsdp --fsdp_config.pure_bf16 --model_name /patht_of_model_folder/70B --batch_size_training 1 --dist_checkpoint_root_folder model_checkpoints --dist_checkpoint_folder fine-tuned
 
 ```
 
@@ -174,6 +181,17 @@ sbatch multi_node.slurm
 ```
 You can read more about our fine-tuning strategies [here](./docs/LLM_finetuning.md).
 
+# Demo Apps
+This folder contains a series of Llama2-powered apps:
+* Quickstart Llama deployments and basic interactions with Llama
+  1. Llama on your Mac and ask Llama general questions
+  2. Llama on Google Colab
+  3. Llama on Cloud and ask Llama questions about unstructured data in a PDF
+
+* Specialized Llama use cases:
+  1. Ask Llama to summarize a video content
+  2. Ask Llama questions about structured data in a DB
+  3. Ask Llama questions about live data on the web
 
 # Repository Organization
 This repository is organized in the following way:
@@ -184,6 +202,8 @@ This repository is organized in the following way:
 
 [datasets](src/llama_recipes/datasets/): Contains individual scripts for each dataset to download and process. Note: Use of any of the datasets should be in compliance with the dataset's underlying licenses (including but not limited to non-commercial uses)
 
+[demo_apps](./demo_apps) contains a series of Llama2-powered apps, from quickstart deployments to how to ask Llama questions about unstructured data, structured data, live data, and video summary.
+
 [examples](./examples/): Contains examples script for finetuning and inference of the Llama 2 model as well as how to use them safely.
 
 [inference](src/llama_recipes/inference/): Includes modules for inference for the fine-tuned models.

Fichier diff supprimé car celui-ci est trop grand
+ 433 - 0
demo_apps/HelloLlamaCloud.ipynb


Fichier diff supprimé car celui-ci est trop grand
+ 347 - 0
demo_apps/HelloLlamaLocal.ipynb


+ 306 - 0
demo_apps/LiveData.ipynb

@@ -0,0 +1,306 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "30eb1704-8d76-4bc9-9308-93243aeb69cb",
+   "metadata": {},
+   "source": [
+    "## This demo app shows:\n",
+    "* How to use LlamaIndex, an open source library to help you build custom data augmented LLM applications\n",
+    "* How to ask Llama questions about recent live data via the You.com live search API and LlamaIndex\n",
+    "\n",
+    "The LangChain package is used to facilitate the call to Llama2 hosted on Replicate\n",
+    "\n",
+    "**Note** We will be using Replicate to run the examples here. You will need to first sign in with Replicate with your github account, then create a free API token [here](https://replicate.com/account/api-tokens) that you can use for a while. \n",
+    "After the free trial ends, you will need to enter billing info to continue to use Llama2 hosted on Replicate."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "68cf076e",
+   "metadata": {},
+   "source": [
+    "We start by installing the necessary packages:\n",
+    "- [langchain](https://python.langchain.com/docs/get_started/introduction) which provides RAG capabilities\n",
+    "- [llama-index](https://docs.llamaindex.ai/en/stable/) for data augmentation."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1d0005d6-e928-4d1a-981b-534a40e19e56",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!pip install llama-index langchain"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "21fe3849",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# use ServiceContext to configure the LLM used and the custom embeddings \n",
+    "from llama_index import ServiceContext\n",
+    "\n",
+    "# VectorStoreIndex is used to index custom data \n",
+    "from llama_index import VectorStoreIndex\n",
+    "\n",
+    "from langchain.llms import Replicate"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "73e8e661",
+   "metadata": {},
+   "source": [
+    "Next we set up the Replicate token."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "d9d76e33",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from getpass import getpass\n",
+    "import os\n",
+    "\n",
+    "REPLICATE_API_TOKEN = getpass()\n",
+    "os.environ[\"REPLICATE_API_TOKEN\"] = REPLICATE_API_TOKEN"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f8ff812b",
+   "metadata": {},
+   "source": [
+    "In this example we will use the [YOU.com](https://you.com/) search engine to augment the LLM's responses.\n",
+    "To use the You.com Search API, you can email api@you.com to request an API key. "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "75275628-5235-4b55-8033-601c76107528",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "YOUCOM_API_KEY = getpass()\n",
+    "os.environ[\"YOUCOM_API_KEY\"] = YOUCOM_API_KEY"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "cb210c7c",
+   "metadata": {},
+   "source": [
+    "We then call the Llama 2 model from replicate. \n",
+    "\n",
+    "We will use the llama 2 13b chat model. You can find more Llama 2 models by searching for them on the [Replicate model explore page](https://replicate.com/explore?query=llama).\n",
+    "You can add them here in the format: model_name/version"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c12fc2cb",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# set llm to be using Llama2 hosted on Replicate\n",
+    "llama2_13b_chat = \"meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d\"\n",
+    "\n",
+    "llm = Replicate(\n",
+    "    model=llama2_13b_chat,\n",
+    "    model_kwargs={\"temperature\": 0.01, \"top_p\": 1, \"max_new_tokens\":500}\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "476d72da",
+   "metadata": {},
+   "source": [
+    "Using our api key we set up earlier, we make a request from YOU.com for live data on a particular topic."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "effc9656-b18d-4d24-a80b-6066564a838b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "\n",
+    "import requests\n",
+    "\n",
+    "query = \"Meta Connect\" # you can try other live data query about sports score, stock market and weather info \n",
+    "headers = {\"X-API-Key\": os.environ[\"YOUCOM_API_KEY\"]}\n",
+    "data = requests.get(\n",
+    "    f\"https://api.ydc-index.io/search?query={query}\",\n",
+    "    headers=headers,\n",
+    ").json()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "8bed3baf-742e-473c-ada1-4459012a8a2c",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# check the query result in JSON\n",
+    "import json\n",
+    "\n",
+    "print(json.dumps(data, indent=2))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b196e697",
+   "metadata": {},
+   "source": [
+    "We then use the [`JSONLoader`](https://llamahub.ai/l/file-json) to extract the text from the returned data. The `JSONLoader` gives us the ability to load the data into LamaIndex.\n",
+    "In the next cell we show how to load the JSON result with key info stored as \"snippets\".\n",
+    "\n",
+    "However, you can also add the snippets in the query result to documents like below:\n",
+    "```python \n",
+    "from llama_index import Document\n",
+    "snippets = [snippet for hit in data[\"hits\"] for snippet in hit[\"snippets\"]]\n",
+    "documents = [Document(text=s) for s in snippets]\n",
+    "```\n",
+    "This can be handy if you just need to add a list of text strings to doc"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7c40e73f-ca13-4f4a-a753-e613df3d389e",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# one way to load the JSON result with key info stored as \"snippets\"\n",
+    "from llama_index import download_loader\n",
+    "\n",
+    "JsonDataReader = download_loader(\"JsonDataReader\")\n",
+    "loader = JsonDataReader()\n",
+    "documents = loader.load_data([hit[\"snippets\"] for hit in data[\"hits\"]])\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8e5e3b4e",
+   "metadata": {},
+   "source": [
+    "With the data set up, we create a vector store for the data and a query engine for it.\n",
+    "\n",
+    "For our embeddings we will use `HuggingFaceEmbeddings` whose default embedding model is sentence-transformers/all-mpnet-base-v2. This model provides a good balance between speed and performance.\n",
+    "To change the default model, call `HuggingFaceEmbeddings(model_name=<another_embedding_model>)`. \n",
+    "\n",
+    "For more info see https://huggingface.co/blog/mteb. "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a5de3080-2c4b-479c-baba-793b3bee36ed",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# use HuggingFace embeddings \n",
+    "from langchain.embeddings.huggingface import HuggingFaceEmbeddings\n",
+    "from llama_index import LangchainEmbedding\n",
+    "\n",
+    "\n",
+    "embeddings = LangchainEmbedding(HuggingFaceEmbeddings())\n",
+    "print(embeddings)\n",
+    "\n",
+    "# create a ServiceContext instance to use Llama2 and custom embeddings\n",
+    "service_context = ServiceContext.from_defaults(llm=llm, chunk_size=800, chunk_overlap=20, embed_model=embeddings)\n",
+    "\n",
+    "# create vector store index from the documents created above\n",
+    "index = VectorStoreIndex.from_documents(documents, service_context=service_context)\n",
+    "\n",
+    "# create query engine from the index\n",
+    "query_engine = index.as_query_engine(streaming=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2c4ea012",
+   "metadata": {},
+   "source": [
+    "We are now ready to ask Llama 2 a question about the live data using our query engine."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "de91a191-d0f2-498e-88dc-b2b43423e0e5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# ask Llama2 a summary question about the search result\n",
+    "response = query_engine.query(\"give me a summary\")\n",
+    "response.print_response_stream()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "72814b20-06aa-4da8-b4dd-f0b0d74a2ea0",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# more questions\n",
+    "query_engine.query(\"what products were announced\").print_response_stream()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "a65bc037-a689-476d-b529-0059a27bc949",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "query_engine.query(\"tell me more about Meta AI assistant\").print_response_stream()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "16a56542",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "query_engine.query(\"what are Generative AI stickers\").print_response_stream()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.18"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}

+ 120 - 0
demo_apps/Llama2_Gradio.ipynb

@@ -0,0 +1,120 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "47a9adb3",
+   "metadata": {},
+   "source": [
+    "## This demo app shows how to query Llama 2 using the Gradio UI.\n",
+    "\n",
+    "Since we are using Replicate in this example, you will need to replace `<your replicate api token>` with your API token.\n",
+    "\n",
+    "To get the Replicate token: \n",
+    "\n",
+    "- You will need to first sign in with Replicate with your github account\n",
+    "- Then create a free API token [here](https://replicate.com/account/api-tokens) that you can use for a while \n",
+    "\n",
+    "**Note** After the free trial ends, you will need to enter billing info to continue to use Llama2 hosted on Replicate.\n",
+    "\n",
+    "To run this example:\n",
+    "- Set up your Replicate API token and enter it in place of `<your replicate api token>`\n",
+    "- Run the notebook\n",
+    "- Enter your question and click Submit\n",
+    "\n",
+    "In the notebook or a browser with URL http://127.0.0.1:7860 you should see a UI with your answer."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "928041cc",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "Init param `input` is deprecated, please use `model_kwargs` instead.\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Running on local URL:  http://127.0.0.1:7860\n",
+      "\n",
+      "To create a public link, set `share=True` in `launch()`.\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "<div><iframe src=\"http://127.0.0.1:7860/\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/plain": []
+     },
+     "execution_count": 1,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from langchain.schema import AIMessage, HumanMessage\n",
+    "import gradio as gr\n",
+    "from langchain.llms import Replicate\n",
+    "import os\n",
+    "\n",
+    "os.environ[\"REPLICATE_API_TOKEN\"] = \"<your replicate api token>\"\n",
+    "\n",
+    "llama2_13b_chat = \"meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d\"\n",
+    "\n",
+    "llm = Replicate(\n",
+    "    model=llama2_13b_chat,\n",
+    "    model_kwargs={\"temperature\": 0.01, \"top_p\": 1, \"max_new_tokens\":500}\n",
+    ")\n",
+    "\n",
+    "\n",
+    "def predict(message, history):\n",
+    "    history_langchain_format = []\n",
+    "    for human, ai in history:\n",
+    "        history_langchain_format.append(HumanMessage(content=human))\n",
+    "        history_langchain_format.append(AIMessage(content=ai))\n",
+    "    history_langchain_format.append(HumanMessage(content=message))\n",
+    "    gpt_response = llm(message) #history_langchain_format)\n",
+    "    return gpt_response#.content\n",
+    "\n",
+    "gr.ChatInterface(predict).launch()"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.8.18"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}

Fichier diff supprimé car celui-ci est trop grand
+ 101 - 0
demo_apps/README.md


Fichier diff supprimé car celui-ci est trop grand
+ 559 - 0
demo_apps/StructuredLlama.ipynb


Fichier diff supprimé car celui-ci est trop grand
+ 698 - 0
demo_apps/VideoSummary.ipynb


+ 38 - 0
demo_apps/csv2db.py

@@ -0,0 +1,38 @@
+import sqlite3
+import csv
+
+# Define the input CSV file and the SQLite database file
+input_csv = 'nba_roster.csv'
+database_file = 'nba_roster.db'
+
+# Connect to the SQLite database
+conn = sqlite3.connect(database_file)
+cursor = conn.cursor()
+
+# Create a table to store the data
+cursor.execute('''CREATE TABLE IF NOT EXISTS nba_roster (
+                    Team TEXT,
+                    NAME TEXT,
+                    Jersey TEXT,
+                    POS TEXT,
+                    AGE INT,
+                    HT TEXT,
+                    WT TEXT,
+                    COLLEGE TEXT,
+                    SALARY TEXT
+                )''')
+
+# Read data from the CSV file and insert it into the SQLite table
+with open(input_csv, 'r', newline='') as csvfile:
+    csv_reader = csv.reader(csvfile)
+    next(csv_reader)  # Skip the header row
+    
+    for row in csv_reader:
+        cursor.execute('INSERT INTO nba_roster VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)', row)
+
+# Commit the changes and close the database connection
+conn.commit()
+conn.close()
+
+print(f'Data from {input_csv} has been successfully imported into {database_file}')
+

BIN
demo_apps/llama2-gradio.png


BIN
demo_apps/llama2-streamlit.png


BIN
demo_apps/llama2-streamlit2.png


BIN
demo_apps/llama2.pdf


Fichier diff supprimé car celui-ci est trop grand
+ 1294 - 0
demo_apps/nba.txt


+ 22 - 0
demo_apps/streamlit_llama2.py

@@ -0,0 +1,22 @@
+import streamlit as st
+from langchain.llms import Replicate
+import os
+
+st.title("Llama2-powered Streamlit App")
+
+with st.sidebar:
+    os.environ["REPLICATE_API_TOKEN"] = "<your replicate api token>"
+
+def generate_response(input_text):
+    llama2_13b_chat = "meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d"
+
+    llm = Replicate(
+        model=llama2_13b_chat,
+        model_kwargs={"temperature": 0.01, "top_p": 1, "max_new_tokens":500}
+    )
+    st.info(llm(input_text))
+
+with st.form("my_form"):
+    text = st.text_area("Enter text:", "What is Generative AI?")
+    submitted = st.form_submit_button("Submit")
+    generate_response(text)

+ 53 - 0
demo_apps/txt2csv.py

@@ -0,0 +1,53 @@
+import csv
+
+# Define the input and output file names
+input_file = 'nba.txt'
+output_file = 'nba_roster.csv'
+
+# Initialize lists to store data
+roster_data = []
+current_team = None
+
+# Open the input file
+with open(input_file, 'r') as file:
+    for line in file:
+        # Remove leading and trailing whitespaces from the line
+        line = line.strip()
+        
+        # Check if the line starts with 'https', skip it
+        if line.startswith('https'):
+            continue
+        
+        # Check if the line contains the team name
+        if 'Roster' in line:
+            current_team = line.split(' Roster ')[0]
+        elif line and "NAME" not in line:  # Skip empty lines and header lines
+            # Split the line using tabs as the delimiter
+            player_info = line.split('\t')
+            
+            # Remove any numbers from the player's name and set Jersey accordingly
+            name = ''.join([c for c in player_info[0] if not c.isdigit()])
+            jersey = ''.join([c for c in player_info[0] if c.isdigit()])
+            
+            # If no number found, set Jersey to "NA"
+            if not jersey:
+                jersey = "NA"
+            
+            # Append the team name, name, and jersey to the player's data
+            player_info = [current_team, name, jersey] + player_info[1:]
+            
+            # Append the player's data to the roster_data list
+            roster_data.append(player_info)
+
+# Write the data to a CSV file
+with open(output_file, 'w', newline='') as csvfile:
+    writer = csv.writer(csvfile)
+    
+    # Write the header row
+    writer.writerow(['Team', 'NAME', 'Jersey', 'POS', 'AGE', 'HT', 'WT', 'COLLEGE', 'SALARY'])
+    
+    # Write the player data
+    writer.writerows(roster_data)
+
+print(f'Conversion completed. Data saved to {output_file}')
+

+ 15 - 3
docs/Dataset.md

@@ -7,6 +7,18 @@ The provided fine tuning script allows you to select between three datasets by p
 * [samsum_dataset](https://huggingface.co/datasets/samsum) contains about 16k messenger-like conversations with summaries.
 * [OpenAssistant/oasst1](https://huggingface.co/datasets/OpenAssistant/oasst1/) contains about 88k messages from assistant-style conversations.
 
+## Batching Strategies
+Llama-recipes support two strategies to batch requests together.
+The default setting is `packing` which concatenates the tokenized samples into long sequences filling up the context length of the model.
+This is the most compute efficient variant as it avoids any padding and all sequences have the same length.
+Samples at the boundary of the context length are truncated and the remainder of the cut sequence it used as the start of the next long sequence.
+
+If the amount of training data is small this procedure might introduce a lot of noise into the training data which can hurt the prediction performance of the fine-tune model.
+Therefore, we also support a `padding` strategy which does not introduce the addition noise due to truncated sequences.
+The strategy tries to minimize the efficiency loss by batching samples of similar length together so only minimal padding is necessary.
+
+The batching strategy can be selected though the command line parameter `--batching_strategy [packing]/[padding]`.
+
 ## Using custom datasets
 
 The list of available datasets in llama-recipes is supposed to give users a quick start on training their Llama model.
@@ -23,9 +35,9 @@ def get_custom_dataset(dataset_config, tokenizer, split: str):
 For an example `get_custom_dataset` you can look at the provided datasets in llama_recipes.datasets or [examples/custom_dataset.py](../examples/custom_dataset.py).
 The `dataset_config` in the above signature will be an instance of llama_recipes.configs.dataset.custom_dataset with the modifications made through the command line.
 The split signals wether to return the training or validation dataset.
-The default function name is `get_custom_dataset` but this can be changes as described below.
+The default function name is `get_custom_dataset` but this can be changed as described below.
 
-In order to start a training with the custom dataset we need to set the `--dataset` as well as the `--custom_dataset.file` parameter. 
+In order to start a training with the custom dataset we need to set the `--dataset` as well as the `--custom_dataset.file` parameter.
 ```
 python -m llama_recipes.finetuning --dataset "custom_dataset" --custom_dataset.file "examples/custom_dataset.py" [TRAINING PARAMETERS]
 ```
@@ -35,7 +47,7 @@ python -m llama_recipes.finetuning --dataset "custom_dataset" --custom_dataset.f
 ```
 This will call the function `get_foo` instead of `get_custom_dataset` when retrieving the dataset.
 
-### Adding new dataset 
+### Adding new dataset
 Each dataset has a corresponding configuration (dataclass) in [configs/datasets.py](../src/llama_recipes/configs/datasets.py) which contains the dataset name, training/validation split names, as well as optional parameters like datafiles etc.
 
 Additionally, there is a preprocessing function for each dataset in the [datasets](../src/llama_recipes/datasets) folder.

Fichier diff supprimé car celui-ci est trop grand
+ 21 - 7
docs/FAQ.md


Fichier diff supprimé car celui-ci est trop grand
+ 1 - 0
examples/Getting_to_know_Llama.ipynb


+ 1 - 1
examples/README.md

@@ -10,7 +10,7 @@ After installing the llama-recipes package through [pip](../README.md#installati
 ```
 python -m llama_recipes.finetuning <parameters>
 
-python examnples/finetuning.py <parameters>
+python examples/finetuning.py <parameters>
 ```
 Please see [README.md](../README.md) for details.
 

+ 23 - 30
examples/custom_dataset.py

@@ -7,33 +7,27 @@ import copy
 import datasets
 import itertools
 
-from llama_recipes.datasets.utils import Concatenator
-
 
 B_INST, E_INST = "[INST]", "[/INST]"
-B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
 
 def tokenize_dialog(dialog, tokenizer):
-    dialog_tokens = [
-            tokenizer(
-                f"{B_INST} {(prompt['content']).strip()} {E_INST} {(answer['content']).strip()} ",
-            )
-            for prompt, answer in zip(dialog[::2], dialog[1::2])
-        ]
-    if len(dialog) % 2:    
-        dialog_tokens += [tokenizer(
-            f"{B_INST} {(dialog[-1]['content']).strip()} {E_INST}",
-        )]
-    
-    combined_tokens = {}  
-    for k in dialog_tokens[0].keys():
-        combined_tokens[k] = list(itertools.chain(*(t[k] for t in dialog_tokens)))
-    return combined_tokens
+    prompt_tokens = [tokenizer.encode(f"{tokenizer.bos_token}{B_INST} {(prompt['content']).strip()} {E_INST}", add_special_tokens=False) for prompt in dialog[::2]]
+    answer_tokens = [tokenizer.encode(f"{answer['content'].strip()} {tokenizer.eos_token}", add_special_tokens=False) for answer in dialog[1::2]]
+    dialog_tokens = list(itertools.chain.from_iterable(zip(prompt_tokens, answer_tokens)))
+    #Add labels, convert prompt token to -100 in order to ignore in loss function
+    labels_tokens = [len(c)*[-100,] if i % 2 == 0 else c for i,c in enumerate(dialog_tokens)]
+
+    combined_tokens = {
+        "input_ids": list(itertools.chain(*(t for t in dialog_tokens))),
+        "labels": list(itertools.chain(*(t for t in labels_tokens))),
+    }
+
+    return dict(combined_tokens, attention_mask=[1]*len(combined_tokens["input_ids"]))
 
 
 def get_custom_dataset(dataset_config, tokenizer, split):
     dataset = datasets.load_dataset("OpenAssistant/oasst1", split=split)
-    
+
     dataset = dataset.map(lambda sample: {
         "message_id": sample["message_id"],
         "parent_id": sample["parent_id"],
@@ -41,19 +35,19 @@ def get_custom_dataset(dataset_config, tokenizer, split):
         },
         batched=True,
         remove_columns=list(dataset.features),)
-    
+
     nodes = {}
-    
+
     messages = {}
     root_ids = []
-    
+
     for data in dataset:
         if data["parent_id"]:
             nodes[data["parent_id"]] = nodes.get(data["parent_id"], []) + [data["message_id"]]
         else:
             root_ids.append(data["message_id"])
         messages[data["message_id"]]=data["text"]
-           
+
     def follow(thread, current_id):
         thread = copy.copy(thread) + [messages[current_id]]
         if current_id in nodes:
@@ -63,18 +57,18 @@ def get_custom_dataset(dataset_config, tokenizer, split):
             return new_threads
         else:
             return [thread]
-        
+
     def get_threads_from_root(root_id):
         all_threads = []
         thread = [messages[root_id]]
         for cid in nodes[root_id]:
             all_threads += follow(thread, cid)
         return all_threads
-            
+
     dataset = dataset.filter(lambda x: x["message_id"] in root_ids)
     dataset = dataset.map(lambda x: {"thread": get_threads_from_root(x["message_id"])}, remove_columns=list(dataset.features))
     dataset = dataset.map(lambda x: {"thread": [i for row in x["thread"] for i in row]}, batched=True)
-    
+
     def to_dialog(thread):
         dialog = []
         for i, content in enumerate(thread):
@@ -83,9 +77,8 @@ def get_custom_dataset(dataset_config, tokenizer, split):
                 "content": content,
             })
         return {"dialog": dialog}
-            
+
     dataset = dataset.map(lambda x: to_dialog(x["thread"]), remove_columns=list(dataset.features))
     dataset = dataset.map(lambda x: tokenize_dialog(x["dialog"], tokenizer), remove_columns=list(dataset.features))
-    dataset = dataset.map(Concatenator(), batched=True)
-    
-    return dataset
+
+    return dataset

+ 3 - 6
examples/quickstart.ipynb

@@ -32,7 +32,7 @@
    "outputs": [],
    "source": [
     "# %%bash\n",
-    "# pip install transformers datasets accelerate sentencepiece protobuf==3.20 py7zr scipy peft bitsandbytes fire torch_tb_profiler ipywidgets\n",
+    "# pip install llama-recipes transformers datasets accelerate sentencepiece protobuf==3.20 py7zr scipy peft bitsandbytes fire torch_tb_profiler ipywidgets\n",
     "# TRANSFORM=`python -c \"import transformers;print('/'.join(transformers.__file__.split('/')[:-1])+'/models/llama/convert_llama_weights_to_hf.py')\"`\n",
     "# python ${TRANSFORM} --input_dir models --model_size 7B --output_dir models_hf/7B"
    ]
@@ -130,11 +130,8 @@
     }
    ],
    "source": [
-    "from pathlib import Path\n",
-    "import os\n",
-    "import sys\n",
-    "from utils.dataset_utils import get_preprocessed_dataset\n",
-    "from configs.datasets import samsum_dataset\n",
+    "from llama_recipes.utils.dataset_utils import get_preprocessed_dataset\n",
+    "from llama_recipes.configs.datasets import samsum_dataset\n",
     "\n",
     "train_dataset = get_preprocessed_dataset(tokenizer, samsum_dataset, 'train')"
    ]

+ 29 - 1
scripts/spellcheck_conf/wordlist.txt

@@ -1156,4 +1156,32 @@ Autocast
 FN
 GBs
 MLP
-learnable
+learnable
+tokenized
+Colab
+GenAI
+Gradio
+HelloLlama
+HelloLlamaCloud
+HelloLlamaLocal
+LLM's
+LangChain
+LangChain's
+LiveData
+LlamaIndex
+MBP
+MLC
+Replicate's
+StructuredLlama
+VideoSummary
+cpp
+envinronment
+ggml
+gguf
+gradio
+minnutes
+pdf
+quantized
+serarch
+streamlit
+

+ 0 - 2
src/llama_recipes/configs/datasets.py

@@ -9,7 +9,6 @@ class samsum_dataset:
     dataset: str =  "samsum_dataset"
     train_split: str = "train"
     test_split: str = "validation"
-    input_length: int = 2048
     
     
 @dataclass
@@ -17,7 +16,6 @@ class grammar_dataset:
     dataset: str = "grammar_dataset"
     train_split: str = "src/llama_recipes/datasets/grammar_dataset/gtrain_10k.csv" 
     test_split: str = "src/llama_recipes/datasets/grammar_dataset/grammar_validation.csv"
-    input_length: int = 2048
 
     
 @dataclass

+ 2 - 4
src/llama_recipes/configs/training.py

@@ -11,6 +11,8 @@ class train_config:
     low_cpu_fsdp: bool=False
     run_validation: bool=True
     batch_size_training: int=4
+    batching_strategy: str="packing" #alternative: padding
+    context_length: int=4096
     gradient_accumulation_steps: int=1
     num_epochs: int=3
     num_workers_dataloader: int=1
@@ -34,7 +36,3 @@ class train_config:
     dist_checkpoint_folder: str="fine-tuned" # will be used if using FSDP
     save_optimizer: bool=False # will be used if using FSDP
     use_fast_kernels: bool = False # Enable using SDPA from PyTroch Accelerated Transformers, make use Flash Attention and Xformer memory-efficient kernels
-
-    
-    
-    

+ 2 - 0
src/llama_recipes/data/__init__.py

@@ -0,0 +1,2 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

+ 34 - 0
src/llama_recipes/data/concatenator.py

@@ -0,0 +1,34 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from tqdm import tqdm
+from itertools import chain
+
+from torch.utils.data import Dataset
+
+
+class ConcatDataset(Dataset):
+    def __init__(self, dataset, chunk_size=4096):
+        self.dataset = dataset
+        self.chunk_size = chunk_size
+
+        self.samples = []
+
+        buffer = {
+            "input_ids": [],
+            "attention_mask": [],
+            "labels": [],
+            }
+
+        for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
+            buffer = {k: v + sample[k] for k,v in buffer.items()}
+
+            while len(next(iter(buffer.values()))) > self.chunk_size:
+                self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()})
+                buffer = {k: v[self.chunk_size:] for k,v in buffer.items()}
+
+    def __getitem__(self, idx):
+        return self.samples[idx]
+
+    def __len__(self):
+        return len(self.samples)

+ 57 - 0
src/llama_recipes/data/sampler.py

@@ -0,0 +1,57 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import random
+from itertools import islice
+
+import numpy as np
+import torch
+
+
+class LengthBasedBatchSampler(torch.utils.data.BatchSampler):
+    def __init__(self, data_source, batch_size: int, drop_last: bool, shuffle: bool=True) -> None:
+        if isinstance(next(iter(data_source)), dict):
+            first_key = next(iter(next(iter(data_source)).keys()))
+            self.lengths = [len(d[first_key]) for d in data_source]
+        else:
+            self.lengths = [len(d) for d in data_source]
+        self.batch_size = batch_size
+        self.drop_last = drop_last
+        self.shuffle = shuffle
+
+    def __iter__(self):
+        ids = np.argsort(self.lengths)
+        if self.drop_last:
+            ids = ids[:len(ids) // self.batch_size * self.batch_size]
+
+        batches = [ids[i:i+self.batch_size] for i in range(0, len(ids), self.batch_size)]
+
+        if self.shuffle:
+            random.shuffle(batches)
+
+        for b in batches:
+            yield b
+
+    def __len__(self):
+        if self.drop_last:
+            return len(self.lengths) // self.batch_size
+        else:
+            return len(self.lengths) // self.batch_size + (len(self.lengths) % self.batch_size > 0)
+
+
+class DistributedLengthBasedBatchSampler(torch.utils.data.BatchSampler):
+    def __init__(self, data_source, batch_size: int, num_replicas: int, rank: int, shuffle: bool = True, seed: int = 0) -> None:
+        random.seed(seed)
+        self.batch_sampler = LengthBasedBatchSampler(
+            data_source, batch_size=batch_size, drop_last=True, shuffle=shuffle
+            )
+        self.num_replicas = num_replicas
+        self.rank = rank
+        
+    def __iter__(self):
+        max_length = len(self.batch_sampler) // self.num_replicas * self.num_replicas
+        return islice(self.batch_sampler, self.rank, max_length, self.num_replicas)
+         
+    def __len__(self):
+        return len(self.batch_sampler) // self.num_replicas
+            

+ 4 - 14
src/llama_recipes/datasets/alpaca_dataset.py

@@ -24,17 +24,14 @@ PROMPT_DICT = {
 }
 
 class InstructionDataset(Dataset):
-    def __init__(self, dataset_config, tokenizer, partition="train", max_words=30):
+    def __init__(self, dataset_config, tokenizer, partition="train"):
         self.ann = json.load(open(dataset_config.data_path))
         if partition == "train":
             self.ann = self.ann
         else:
             self.ann = self.ann[:200]
 
-        self.max_words = max_words
-        # tokenizer = Tokenizer(model_path=model_path + "./tokenizer.model")
         self.tokenizer = tokenizer
-        # self.tokenizer1 = tokenizer
 
     def __len__(self):
         return len(self.ann)
@@ -57,22 +54,15 @@ class InstructionDataset(Dataset):
         example = torch.tensor(
             example, dtype=torch.int64
         )
-        padding = self.max_words - example.shape[0]
-        if padding > 0:
-            example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1))
-        elif padding < 0:
-            example = example[: self.max_words]
         labels = copy.deepcopy(example)
         labels[: len(prompt)] = -1
         example_mask = example.ge(0)
         label_mask = labels.ge(0)
         example[~example_mask] = 0
         labels[~label_mask] = IGNORE_INDEX
-        example_mask = example_mask.float()
-        label_mask = label_mask.float()
 
         return {
-            "input_ids": example,
-            "labels": labels,
-            "attention_mask":example_mask,
+            "input_ids": example.tolist(),
+            "labels": labels.tolist(),
+            "attention_mask":example_mask.tolist(),
         }

+ 13 - 18
src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py

@@ -10,8 +10,6 @@ from pathlib import Path
 
 from torch.utils.data import Dataset
 
-from llama_recipes.datasets.utils import ConcatDataset
-
 
 class grammar(Dataset):
     def __init__(
@@ -48,24 +46,22 @@ class grammar(Dataset):
 
         input_ = example_batch["input"]
         target_ = example_batch["target"]
-        
-        prompt = f"Correct this to standard English: {input_}\n---\nCorrected: {target_}"
-        sample = self.tokenizer(prompt)
-        
-        return sample
-
-    def __getitem__(self, index):
-        sample = self.convert_to_features(self.dataset["train"][index])
-        source_ids = sample["input_ids"]
 
-        src_mask = sample["attention_mask"]
+        prompt = f"Correct this to standard English: {input_}\n---\nCorrected: "
+        prompt_ids = self.tokenizer.encode(self.tokenizer.bos_token + prompt, add_special_tokens=False)
+        label_ids = self.tokenizer.encode(target_ + self.tokenizer.eos_token, add_special_tokens=False)
 
-        return {
-            "input_ids": source_ids,
-            "attention_mask": src_mask,
-            "labels": source_ids.copy(),
+        sample = {
+            "input_ids": prompt_ids + label_ids,
+            "attention_mask": [1] * len(prompt_ids + label_ids),
+            "labels": [-100] * len(prompt_ids) + label_ids
         }
 
+        return sample
+
+    def __getitem__(self, index):
+        return self.convert_to_features(self.dataset["train"][int(index)])
+
 
 def get_dataset(
     dataset_config, tokenizer, csv_name=None
@@ -80,6 +76,5 @@ def get_dataset(
         tokenizer=tokenizer,
         csv_name=csv_name,
     )
-    
-    return ConcatDataset(dataset, chunk_size=dataset_config.input_length)
 
+    return dataset

+ 3 - 3
src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb

@@ -35,10 +35,10 @@
     "  (\" '\", \"'\"),\n",
     "  (\" ?\", \"?\"),\n",
     "  (\" !\", \"!\"),\n",
-    "  (\" :\", \"!\"),\n",
-    "  (\" ;\", \"!\"),\n",
+    "  (\" :\", \":\"),\n",
+    "  (\" ;\", \";\"),\n",
     "  (\" n't\", \"n't\"),\n",
-    "  (\" v\", \"n't\"),\n",
+    "  (\" v\", \"v\"),\n",
     "  (\"2 0 0 6\", \"2006\"),\n",
     "  (\"5 5\", \"55\"),\n",
     "  (\"4 0 0\", \"400\"),\n",

+ 19 - 13
src/llama_recipes/datasets/samsum_dataset.py

@@ -3,31 +3,37 @@
 
 # For dataset details visit: https://huggingface.co/datasets/samsum
 
+import copy
 import datasets
 
-from llama_recipes.datasets.utils import Concatenator
 
 def get_preprocessed_samsum(dataset_config, tokenizer, split):
     dataset = datasets.load_dataset("samsum", split=split)
 
     prompt = (
-        f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n{{summary}}{{eos_token}}"
+        f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n"
     )
 
     def apply_prompt_template(sample):
         return {
-            "text": prompt.format(
-                dialog=sample["dialogue"],
-                summary=sample["summary"],
-                eos_token=tokenizer.eos_token,
-            )
+            "prompt": prompt.format(dialog=sample["dialogue"]),
+            "summary": sample["summary"],
         }
 
     dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
-        
-    dataset = dataset.map(
-        lambda sample: tokenizer(sample["text"]),
-        batched=True,
-        remove_columns=list(dataset.features),
-    ).map(Concatenator(), batched=True)
+
+    def tokenize_add_label(sample):
+        prompt = tokenizer.encode(tokenizer.bos_token + sample["prompt"], add_special_tokens=False)
+        summary = tokenizer.encode(sample["summary"] +  tokenizer.eos_token, add_special_tokens=False)
+
+        sample = {
+            "input_ids": prompt + summary,
+            "attention_mask" : [1] * (len(prompt) + len(summary)),
+            "labels": [-100] * len(prompt) + summary,
+            }
+
+        return sample
+
+    dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
+
     return dataset

+ 0 - 66
src/llama_recipes/datasets/utils.py

@@ -1,66 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-
-from tqdm import tqdm
-from itertools import chain
-
-from torch.utils.data import Dataset
-
-class Concatenator(object):
-    def __init__(self, chunk_size=2048):
-        self.chunk_size=chunk_size
-        self.residual = {"input_ids": [], "attention_mask": []}
-        
-    def __call__(self, batch):
-        concatenated_samples = {
-            k: v + list(chain(*batch[k])) for k, v in self.residual.items()
-        }
-
-        total_length = len(concatenated_samples[list(concatenated_samples.keys())[0]])
-
-        if total_length >= self.chunk_size:
-            chunk_num = total_length // self.chunk_size
-            result = {
-                k: [
-                    v[i : i + self.chunk_size]
-                    for i in range(0, chunk_num * self.chunk_size, self.chunk_size)
-                ]
-                for k, v in concatenated_samples.items()
-            }
-            self.residual = {
-                k: v[(chunk_num * self.chunk_size) :]
-                for k, v in concatenated_samples.items()
-            }
-        else:
-            result = concatenated_samples
-            self.residual = {k: [] for k in concatenated_samples.keys()}
-
-        result["labels"] = result["input_ids"].copy()
-
-        return result
-
-class ConcatDataset(Dataset):
-    def __init__(self, dataset, chunk_size=4096):
-        self.dataset = dataset
-        self.chunk_size = chunk_size
-        
-        self.samples = []
-        
-        buffer = {
-            "input_ids": [],
-            "attention_mask": [],
-            "labels": [],
-            }
-        
-        for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
-            buffer = {k: v + sample[k] for k,v in buffer.items()}
-            
-            while len(next(iter(buffer.values()))) > self.chunk_size:
-                self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()})
-                buffer = {k: v[self.chunk_size:] for k,v in buffer.items()}
-                
-    def __getitem__(self, idx):
-        return self.samples[idx]
-    
-    def __len__(self):
-        return len(self.samples)

+ 25 - 37
src/llama_recipes/finetuning.py

@@ -5,8 +5,8 @@ import os
 from pkg_resources import packaging
 
 import fire
+import random
 import torch
-import torch.distributed as dist
 import torch.optim as optim
 from peft import get_peft_model, prepare_model_for_int8_training
 from torch.distributed.fsdp import (
@@ -14,16 +14,16 @@ from torch.distributed.fsdp import (
 )
 from torch.distributed.fsdp.fully_sharded_data_parallel import CPUOffload
 from torch.optim.lr_scheduler import StepLR
-from torch.utils.data import DistributedSampler
 from transformers import (
     LlamaForCausalLM,
     LlamaTokenizer,
     LlamaConfig,
-    default_data_collator,
 )
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer
 
-from llama_recipes.configs import fsdp_config, train_config
+from llama_recipes.configs import fsdp_config as FSDP_CONFIG
+from llama_recipes.configs import train_config as TRAIN_CONFIG
+from llama_recipes.data.concatenator import ConcatDataset
 from llama_recipes.policies import AnyPrecisionAdamW, apply_fsdp_checkpointing
 
 from llama_recipes.utils import fsdp_auto_wrap_policy
@@ -31,6 +31,7 @@ from llama_recipes.utils.config_utils import (
     update_config,
     generate_peft_config,
     generate_dataset_config,
+    get_dataloader_kwargs,
 )
 from llama_recipes.utils.dataset_utils import get_preprocessed_dataset
 
@@ -47,11 +48,13 @@ from llama_recipes.utils.train_utils import (
 
 def main(**kwargs):
     # Update the configuration for the training and sharding process
+    train_config, fsdp_config = TRAIN_CONFIG(), FSDP_CONFIG()
     update_config((train_config, fsdp_config), **kwargs)
 
     # Set the seeds for reproducibility
     torch.cuda.manual_seed(train_config.seed)
     torch.manual_seed(train_config.seed)
+    random.seed(train_config.seed)
 
     if train_config.enable_fsdp:
         setup()
@@ -102,14 +105,19 @@ def main(**kwargs):
     if train_config.enable_fsdp and train_config.use_fast_kernels:
         """
         For FSDP and FSDP+PEFT, setting 'use_fast_kernels' will enable
-        using of Flash Attention or Xformer memory-efficient kernels 
+        using of Flash Attention or Xformer memory-efficient kernels
         based on the hardware being used. This would speed up fine-tuning.
         """
         try:
             from optimum.bettertransformer import BetterTransformer
-            model = BetterTransformer.transform(model) 
+            model = BetterTransformer.transform(model)
         except ImportError:
             print("Module 'optimum' not found. Please install 'optimum' it before proceeding.")
+
+    # Load the tokenizer and add special tokens
+    tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
+    tokenizer.pad_token_id = tokenizer.eos_token_id
+
     print_model_size(model, train_config, rank if train_config.enable_fsdp else 0)
 
     # Prepare the model for int8 training if quantization is enabled
@@ -120,14 +128,6 @@ def main(**kwargs):
     if train_config.enable_fsdp and fsdp_config.pure_bf16:
         model.to(torch.bfloat16)
 
-    # Load the tokenizer and add special tokens
-    tokenizer = LlamaTokenizer.from_pretrained(train_config.model_name)
-    tokenizer.add_special_tokens(
-            {
-
-                "pad_token": "<PAD>",
-            }
-        )
     if train_config.use_peft:
         peft_config = generate_peft_config(train_config, kwargs)
         model = get_peft_model(model, peft_config)
@@ -179,43 +179,31 @@ def main(**kwargs):
     if not train_config.enable_fsdp or rank == 0:
             print(f"--> Validation Set Length = {len(dataset_val)}")
 
-    train_sampler = None
-    val_sampler = None
-    if train_config.enable_fsdp:
-        train_sampler = DistributedSampler(
-            dataset_train,
-            rank=dist.get_rank(),
-            num_replicas=dist.get_world_size(),
-            shuffle=True,
-        )
-        if train_config.run_validation:
-            val_sampler = DistributedSampler(
-                dataset_val,
-                rank=dist.get_rank(),
-                num_replicas=dist.get_world_size(),
-            )
+    if train_config.batching_strategy == "packing":
+        dataset_train = ConcatDataset(dataset_train, chunk_size=train_config.context_length)
+
+    train_dl_kwargs = get_dataloader_kwargs(train_config, dataset_train, tokenizer, "train")
 
     # Create DataLoaders for the training and validation dataset
     train_dataloader = torch.utils.data.DataLoader(
         dataset_train,
-        batch_size=train_config.batch_size_training,
         num_workers=train_config.num_workers_dataloader,
         pin_memory=True,
-        sampler=train_sampler if train_sampler else None,
-        drop_last=True,
-        collate_fn=default_data_collator,
+        **train_dl_kwargs,
     )
 
     eval_dataloader = None
     if train_config.run_validation:
+        if train_config.batching_strategy == "packing":
+            dataset_val = ConcatDataset(dataset_val, chunk_size=train_config.context_length)
+
+        val_dl_kwargs = get_dataloader_kwargs(train_config, dataset_val, tokenizer, "val")
+
         eval_dataloader = torch.utils.data.DataLoader(
             dataset_val,
-            batch_size=train_config.val_batch_size,
             num_workers=train_config.num_workers_dataloader,
             pin_memory=True,
-            sampler=val_sampler if val_sampler else None,
-            drop_last=True,
-            collate_fn=default_data_collator,
+            **val_dl_kwargs,
         )
 
     # Initialize the optimizer and learning rate scheduler

+ 49 - 11
src/llama_recipes/utils/config_utils.py

@@ -3,13 +3,19 @@
 
 import inspect
 from dataclasses import asdict
+
+import torch.distributed as dist
+from torch.utils.data import DistributedSampler
 from peft import (
     LoraConfig,
     AdaptionPromptConfig,
     PrefixTuningConfig,
 )
+from transformers import default_data_collator
+from transformers.data import DataCollatorForSeq2Seq
 
 from llama_recipes.configs import datasets, lora_config, llama_adapter_config, prefix_config, train_config
+from llama_recipes.data.sampler import LengthBasedBatchSampler, DistributedLengthBasedBatchSampler
 from llama_recipes.utils.dataset_utils import DATASET_PREPROC
 
 
@@ -32,31 +38,63 @@ def update_config(config, **kwargs):
                         print(f"Warning: {config_name} does not accept parameter: {k}")
             elif isinstance(config, train_config):
                 print(f"Warning: unknown parameter {k}")
-                        
-                        
+
+
 def generate_peft_config(train_config, kwargs):
     configs = (lora_config, llama_adapter_config, prefix_config)
     peft_configs = (LoraConfig, AdaptionPromptConfig, PrefixTuningConfig)
     names = tuple(c.__name__.rstrip("_config") for c in configs)
-    
+
     assert train_config.peft_method in names, f"Peft config not found: {train_config.peft_method}"
-    
+
     config = configs[names.index(train_config.peft_method)]()
-    
+
     update_config(config, **kwargs)
     params = asdict(config)
     peft_config = peft_configs[names.index(train_config.peft_method)](**params)
-    
+
     return peft_config
 
 
 def generate_dataset_config(train_config, kwargs):
     names = tuple(DATASET_PREPROC.keys())
-        
+
     assert train_config.dataset in names, f"Unknown dataset: {train_config.dataset}"
-    
+
     dataset_config = {k:v for k, v in inspect.getmembers(datasets)}[train_config.dataset]()
-        
+
     update_config(dataset_config, **kwargs)
-    
-    return  dataset_config
+
+    return  dataset_config
+
+
+def get_dataloader_kwargs(train_config, dataset, tokenizer, mode):
+        kwargs = {}
+        batch_size = train_config.batch_size_training if mode=="train" else train_config.val_batch_size
+        if train_config.batching_strategy == "padding":
+            if train_config.enable_fsdp:
+                kwargs["batch_sampler"] = DistributedLengthBasedBatchSampler(
+                    dataset,
+                    batch_size=batch_size,
+                    rank=dist.get_rank(),
+                    num_replicas=dist.get_world_size(),
+                    shuffle=mode=="train",
+                )
+            else:
+                kwargs["batch_sampler"] = LengthBasedBatchSampler(dataset, batch_size, drop_last=True, shuffle=mode=="train")
+            kwargs["collate_fn"] = DataCollatorForSeq2Seq(tokenizer)
+        elif train_config.batching_strategy == "packing":
+            if train_config.enable_fsdp:
+                kwargs["sampler"] = DistributedSampler(
+                dataset,
+                rank=dist.get_rank(),
+                num_replicas=dist.get_world_size(),
+                shuffle=mode=="train",
+            )
+            kwargs["batch_size"] = batch_size
+            kwargs["drop_last"] = True
+            kwargs["collate_fn"] = default_data_collator
+        else:
+            raise ValueError(f"Unknown batching strategy: {train_config.batching_strategy}")
+
+        return kwargs

+ 6 - 6
src/llama_recipes/utils/dataset_utils.py

@@ -33,24 +33,24 @@ def get_custom_dataset(dataset_config, tokenizer, split: str):
         module_path, func_name = dataset_config.file.split(":")
     else:
         module_path, func_name = dataset_config.file, "get_custom_dataset"
-        
+
     if not module_path.endswith(".py"):
         raise ValueError(f"Dataset file {module_path} is not a .py file.")
-    
+
     module_path = Path(module_path)
     if not module_path.is_file():
         raise FileNotFoundError(f"Dataset py file {module_path.as_posix()} does not exist or is not a file.")
-    
+
     module = load_module_from_py_file(module_path.as_posix())
     try:
         return getattr(module, func_name)(dataset_config, tokenizer, split)
     except AttributeError as e:
         print(f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).")
         raise e
-    
+
 
 DATASET_PREPROC = {
-    "alpaca_dataset": partial(get_alpaca_dataset, max_words=224),
+    "alpaca_dataset": partial(get_alpaca_dataset),
     "grammar_dataset": get_grammar_dataset,
     "samsum_dataset": get_samsum_dataset,
     "custom_dataset": get_custom_dataset,
@@ -69,7 +69,7 @@ def get_preprocessed_dataset(
             if split == "train"
             else dataset_config.test_split
         )
-    
+
     return DATASET_PREPROC[dataset_config.dataset](
         dataset_config,
         tokenizer,

+ 37 - 33
src/llama_recipes/utils/train_utils.py

@@ -4,6 +4,7 @@
 import os
 import time
 import yaml
+from contextlib import nullcontext
 from pathlib import Path
 from pkg_resources import packaging
 from datetime import datetime
@@ -20,14 +21,14 @@ import json
 
 
 from llama_recipes.model_checkpointing import save_model_checkpoint, save_model_and_optimizer_sharded, save_optimizer_checkpoint
-from llama_recipes.policies import fpSixteen,bfSixteen_mixed, get_llama_wrapper
+from llama_recipes.policies import fpSixteen,bfSixteen, get_llama_wrapper
 from llama_recipes.utils.memory_utils import MemoryTrace
 
 
 def set_tokenizer_params(tokenizer: LlamaTokenizer):
     tokenizer.pad_token_id = 0
     tokenizer.padding_side = "left"
-    
+
 # Converting Bytes to Megabytes
 def byte2mb(x):
     return int(x / 2**20)
@@ -35,7 +36,7 @@ def byte2mb(x):
 def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_scheduler, gradient_accumulation_steps, train_config, fsdp_config=None, local_rank=None, rank=None):
     """
     Trains the model on the given dataloader
-    
+
     Args:
         model: The model to be trained
         train_dataloader: The dataloader containing the training data
@@ -47,18 +48,20 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         train_config: The training configuration
         eval_dataloader: The dataloader containing the eval data
         tokenizer: tokenizer used in the eval for decoding the predicitons
-    
+
     Returns: results dictionary containing average training and validation perplexity and loss
     """
     # Create a gradient scaler for fp16
     if train_config.use_fp16 and train_config.enable_fsdp:
         scaler = ShardedGradScaler()
     elif train_config.use_fp16 and not train_config.enable_fsdp:
-        scaler = torch.cuda.amp.GradScaler() 
+        scaler = torch.cuda.amp.GradScaler()
     if train_config.enable_fsdp:
         world_size = int(os.environ["WORLD_SIZE"]) 
 
     metrics_filename = f"{train_config.output_dir}/metrics_data_{local_rank}-{datetime.now().strftime('%Y-%m-%d_%H-%M-%S')}.json"
+    autocast = torch.cuda.amp.autocast if train_config.use_fp16 else nullcontext
+
     train_prep = []
     train_loss = []
     train_step_perplexity = []
@@ -85,8 +88,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                     if train_config.enable_fsdp:
                         batch[key] = batch[key].to(local_rank)
                     else:
-                        batch[key] = batch[key].to('cuda:0')              
-                loss = model(**batch).loss
+                        batch[key] = batch[key].to('cuda:0')
+                with autocast():
+                    loss = model(**batch).loss
                 loss = loss / gradient_accumulation_steps
                 train_step_loss.append(loss.detach().float().item())
                 train_step_perplexity.append(float(torch.exp(loss.detach().float())))
@@ -109,9 +113,9 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
 
                 pbar.set_description(f"Training Epoch: {epoch+1}/{train_config.num_epochs}, step {step}/{len(train_dataloader)} completed (loss: {loss.detach().float()})")
             pbar.close()
-                
+
         epoch_end_time = time.perf_counter()-epoch_start_time
-        epoch_times.append(epoch_end_time)    
+        epoch_times.append(epoch_end_time)
         # Reducing total_loss across all devices if there's more than one CUDA device
         if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
             dist.all_reduce(total_loss, op=dist.ReduceOp.SUM)
@@ -136,10 +140,10 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
             print(f"Peak active CUDA memory was {memtrace.peak_active_gb} GB")
             print(f"Cuda Malloc retires : {memtrace.cuda_malloc_retires}")
             print(f"CPU Total Peak Memory consumed during the train (max): {memtrace.cpu_peaked + memtrace.cpu_begin} GB")
-        
+
         # Update the learning rate as needed
         lr_scheduler.step()
-          
+
         if train_config.run_validation:
             eval_ppl, eval_epoch_loss, temp_val_loss, temp_step_perplexity = evaluation(model, train_config, eval_dataloader, local_rank, tokenizer)
             val_step_loss.extend(temp_val_loss)
@@ -154,23 +158,23 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             print(f"we are about to save the PEFT modules")
                     else:
                         print(f"we are about to save the PEFT modules")
-                    model.save_pretrained(train_config.output_dir)  
+                    model.save_pretrained(train_config.output_dir)
                     if train_config.enable_fsdp:
-                        if rank==0: 
+                        if rank==0:
                             print(f"PEFT modules are saved in {train_config.output_dir} directory")
                     else:
                         print(f"PEFT modules are saved in {train_config.output_dir} directory")
-                        
+
                 else:
                     if not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.FULL_STATE_DICT:
-                        
+
                         save_model_checkpoint(
                             model, optimizer, rank, train_config, epoch=epoch
                         )
                     elif not train_config.use_peft and fsdp_config.checkpoint_type == StateDictType.SHARDED_STATE_DICT:
                         print(" Saving the FSDP model checkpoints using SHARDED_STATE_DICT")
                         print("=====================================================")
-                        
+
                         save_model_and_optimizer_sharded(model, rank, train_config)
                         if train_config.save_optimizer:
                             save_model_and_optimizer_sharded(model, rank, train_config, optim=optimizer)
@@ -182,7 +186,7 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
                             model, optimizer, rank, train_config, epoch=epoch
                         )
                         print(" Saving the FSDP model checkpoints and optimizer using FULL_STATE_DICT")
-                        print("=====================================================")                     
+                        print("=====================================================")
                 if train_config.enable_fsdp:
                     dist.barrier()
             checkpoint_end_time = time.perf_counter() - checkpoint_start_time
@@ -210,8 +214,8 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
     avg_train_prep = sum(train_prep)/len(train_prep)
     avg_train_loss = sum(train_loss)/len(train_loss)
     if train_config.run_validation:
-        avg_eval_prep = sum(val_prep)/len(val_prep) 
-        avg_eval_loss = sum(val_loss)/len(val_loss) 
+        avg_eval_prep = sum(val_prep)/len(val_prep)
+        avg_eval_loss = sum(val_loss)/len(val_loss)
 
     results['avg_train_prep'] = avg_train_prep
     results['avg_train_loss'] = avg_train_loss
@@ -220,27 +224,27 @@ def train(model, train_dataloader,eval_dataloader, tokenizer, optimizer, lr_sche
         results['avg_eval_loss'] = avg_eval_loss
     results["avg_epoch_time"] = avg_epoch_time
     results["avg_checkpoint_time"] = avg_checkpoint_time
-    
+
     #saving the training params including fsdp setting for reference.
     if train_config.enable_fsdp and not train_config.use_peft:
         save_train_params(train_config, fsdp_config, rank)
-        
+
     return results
 
 def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
     """
     Evaluates the model on the given dataloader
-    
+
     Args:
         model: The model to evaluate
         eval_dataloader: The dataloader containing the evaluation data
         local_rank: The rank of the current node in a distributed setting
         tokenizer: The tokenizer used to decode predictions
-    
+
     Returns: eval_ppl, eval_epoch_loss
     """
     if train_config.enable_fsdp:
-        world_size = int(os.environ["WORLD_SIZE"]) 
+        world_size = int(os.environ["WORLD_SIZE"])
     model.eval()
     eval_preds = []
     val_step_loss = []
@@ -266,17 +270,17 @@ def evaluation(model,train_config, eval_dataloader, local_rank, tokenizer):
             eval_preds.extend(
                 tokenizer.batch_decode(preds.detach().cpu().numpy(), skip_special_tokens=True)
             )
-    
+
     # If there's more than one CUDA device, reduce evaluation loss across all devices
     if torch.cuda.device_count() > 1 and train_config.enable_fsdp:
         dist.all_reduce(eval_loss, op=dist.ReduceOp.SUM)
-    
+
     # Compute average loss and perplexity
     eval_epoch_loss = eval_loss / len(eval_dataloader)
     if train_config.enable_fsdp:
         eval_epoch_loss = eval_epoch_loss/world_size
     eval_ppl = torch.exp(eval_epoch_loss)
-    
+
     # Print evaluation metrics
     if train_config.enable_fsdp:
         if local_rank==0:
@@ -297,8 +301,8 @@ def check_frozen_layers_peft_model(model):
      for i, layer in enumerate(model.base_model.model.model.layers):
             for name, param in layer.named_parameters():
                 print(f"Layer {i}, parameter {name}: requires_grad = {param.requires_grad}")
-                
-                
+
+
 def setup():
     """Initialize the process group for distributed training"""
     dist.init_process_group("nccl")
@@ -311,7 +315,7 @@ def setup_environ_flags(rank):
     # os.environ["TORCH_DISTRIBUTED_DEBUG"] = "DETAIL"
     # This flag will help with CUDA memory fragmentations that can lead into OOM in some cases.
     # Note this is only availble in PyTorch Nighlies (as of July 30 2023)
-    # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True' 
+    # os.environ['PYTORCH_CUDA_ALLOC_CONF']='expandable_segments:True'
     if rank == 0:
         print(f"--> Running with torch dist debug set to detail")
 
@@ -356,7 +360,7 @@ def print_model_size(model, config, rank: int = 0) -> None:
 
 def get_policies(cfg, rank):
     """Get the policies for mixed precision and fsdp wrapping"""
-    
+
     verify_bfloat_support = (
     torch.version.cuda
     and torch.cuda.is_bf16_supported()
@@ -374,7 +378,7 @@ def get_policies(cfg, rank):
         bf16_ready = verify_bfloat_support
 
         if bf16_ready and not cfg.use_fp16:
-            mixed_precision_policy = bfSixteen_mixed
+            mixed_precision_policy = bfSixteen
             if rank == 0:
                 print(f"bFloat16 enabled for mixed precision - using bfSixteen policy")
         elif cfg.use_fp16:
@@ -392,7 +396,7 @@ def save_train_params(train_config, fsdp_config, rank):
     This will be used by converter script in the inference folder to fetch the HF model name or path.
     It also would be hepful as a log for future references.
     """
-    # Convert the train_config and fsdp_config objects to dictionaries, 
+    # Convert the train_config and fsdp_config objects to dictionaries,
     # converting all values to strings to ensure they can be serialized into a YAML file
     train_config_dict = {k: str(v) for k, v in vars(train_config).items() if not k.startswith('__')}
     fsdp_config_dict = {k: str(v) for k, v in vars(fsdp_config).items() if not k.startswith('__')}

+ 18 - 0
tests/conftest.py

@@ -0,0 +1,18 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import pytest
+
+from transformers import LlamaTokenizer
+
+
+@pytest.fixture
+def setup_tokenizer():
+    def _helper(tokenizer):
+        #Align with Llama 2 tokenizer
+        tokenizer.from_pretrained.return_value = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
+        tokenizer.from_pretrained.return_value.add_special_tokens({'bos_token': '<s>', 'eos_token': '</s>'})
+        tokenizer.from_pretrained.return_value.bos_token_id = 1
+        tokenizer.from_pretrained.return_value.eos_token_id = 2
+
+    return _helper

+ 40 - 13
tests/datasets/test_custom_dataset.py

@@ -4,21 +4,38 @@
 import pytest
 from unittest.mock import patch
 
+from transformers import LlamaTokenizer
+
+def check_padded_entry(batch):
+    seq_len = sum(batch["attention_mask"][0])
+    assert seq_len < len(batch["attention_mask"][0])
+
+    assert batch["labels"][0][0] == -100
+    assert batch["labels"][0][seq_len-1] == 2
+    assert batch["labels"][0][-1] == -100
+    assert batch["input_ids"][0][0] == 1
+    assert batch["input_ids"][0][-1] == 2
+
 
 @patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
+def test_custom_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
     from llama_recipes.finetuning import main
 
+    setup_tokenizer(tokenizer)
+
     kwargs = {
         "dataset": "custom_dataset",
         "model_name": "decapoda-research/llama-7b-hf", # We use the tokenizer as a surrogate for llama2 tokenizer here
         "custom_dataset.file": "examples/custom_dataset.py",
         "custom_dataset.train_split": "validation",
         "batch_size_training": 2,
+        "val_batch_size": 4,
         "use_peft": False,
+        "batching_strategy": "padding"
         }
 
     main(**kwargs)
@@ -30,24 +47,34 @@ def test_custom_dataset(step_lr, optimizer, get_model, train, mocker):
     eval_dataloader = args[2]
     tokenizer = args[3]
 
-    assert len(train_dataloader) == 226
-    assert len(eval_dataloader) == 2*226
+    assert len(train_dataloader) == 1120
+    assert len(eval_dataloader) == 1120 //2
+
+    it = iter(eval_dataloader)
+    batch = next(it)
+    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
+    EXPECTED_STRING = "[INST] Who made Berlin [/INST] dunno"
+    assert STRING.startswith(EXPECTED_STRING)
+
+    assert batch["input_ids"].size(0) == 4
+    assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())
+
+    check_padded_entry(batch)
 
     it = iter(train_dataloader)
-    STRING = tokenizer.decode(next(it)["input_ids"][0], skip_special_tokens=True)
-    EXPECTED_STRING = "[INST] Напиши функцию на языке swift, которая сортирует массив целых чисел, а затем выводит его на экран [/INST] Вот функция, "
+    for _ in range(5):
+        next(it)
 
+    batch = next(it)
+    STRING = tokenizer.decode(batch["input_ids"][0], skip_special_tokens=True)
+    EXPECTED_STRING = "[INST] How do I initialize a Typescript project using npm and git? [/INST] # Initialize a new NPM project"
     assert STRING.startswith(EXPECTED_STRING)
 
-    next(it)
-    next(it)
-    next(it)
-    STRING = tokenizer.decode(next(it)["input_ids"][0], skip_special_tokens=True)
-    EXPECTED_SUBSTRING_1 = "Therefore you are correct.  [INST] How can L’Hopital’s Rule be"
-    EXPECTED_SUBSTRING_2 = "a circular path around the turn.  [INST] How on earth is that related to L’Hopital’s Rule?"
+    assert batch["input_ids"].size(0) == 2
+    assert set(("labels", "input_ids", "attention_mask")) == set(batch.keys())
+
+    check_padded_entry(batch)
 
-    assert EXPECTED_SUBSTRING_1 in STRING
-    assert EXPECTED_SUBSTRING_2 in STRING
 
 
 @patch('llama_recipes.finetuning.train')

+ 54 - 0
tests/datasets/test_grammar_datasets.py

@@ -0,0 +1,54 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+from unittest.mock import patch
+
+from transformers import LlamaTokenizer
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_grammar_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+    from llama_recipes.finetuning import main
+
+    setup_tokenizer(tokenizer)
+
+    BATCH_SIZE = 8
+    kwargs = {
+        "model_name": "decapoda-research/llama-7b-hf",
+        "batch_size_training": BATCH_SIZE,
+        "val_batch_size": 1,
+        "use_peft": False,
+        "dataset": "grammar_dataset",
+        "batching_strategy": "padding",
+        }
+
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader = args[1]
+    eval_dataloader = args[2]
+
+    VAL_SAMPLES = 2988
+    TRAIN_SAMPLES = 13016
+
+    assert len(train_dataloader) == TRAIN_SAMPLES // BATCH_SIZE
+    assert len(eval_dataloader) == VAL_SAMPLES
+
+    batch = next(iter(train_dataloader))
+
+    assert "labels" in batch.keys()
+    assert "input_ids" in batch.keys()
+    assert "attention_mask" in batch.keys()
+
+    assert batch["labels"][0][29] == -100
+    assert batch["labels"][0][30] == 29871
+
+    assert batch["input_ids"][0][0] == 1
+    assert batch["labels"][0][-1] == 2
+    assert batch["input_ids"][0][-1] == 2

+ 30 - 14
tests/datasets/test_samsum_datasets.py

@@ -1,37 +1,53 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+from functools import partial
 from unittest.mock import patch
 
 
 @patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
-@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.optim.AdamW')
 @patch('llama_recipes.finetuning.StepLR')
-def test_custom_dataset(step_lr, optimizer, tokenizer, get_model, train, mocker):
+def test_samsum_dataset(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
     from llama_recipes.finetuning import main
-        
-    tokenizer.return_value = mocker.MagicMock(side_effect=lambda x: {"input_ids":[len(x)*[0,]], "attention_mask": [len(x)*[0,]]})
-    
-    
+
+    setup_tokenizer(tokenizer)
+
+    BATCH_SIZE = 8
     kwargs = {
-        "batch_size_training": 1,
+        "model_name": "decapoda-research/llama-7b-hf",
+        "batch_size_training": BATCH_SIZE,
+        "val_batch_size": 1,
         "use_peft": False,
         "dataset": "samsum_dataset",
+        "batching_strategy": "padding",
         }
-    
+
     main(**kwargs)
-    
+
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     train_dataloader = args[1]
     eval_dataloader = args[2]
-    
+
     VAL_SAMPLES = 818
     TRAIN_SAMPLES = 14732
-    CONCAT_SIZE = 2048
-    assert len(train_dataloader) == TRAIN_SAMPLES // CONCAT_SIZE
+
+    assert len(train_dataloader) == TRAIN_SAMPLES // BATCH_SIZE
     assert len(eval_dataloader) == VAL_SAMPLES
-    
+
+    batch = next(iter(train_dataloader))
+
+    assert "labels" in batch.keys()
+    assert "input_ids" in batch.keys()
+    assert "attention_mask" in batch.keys()
+
+    assert batch["labels"][0][268] == -100
+    assert batch["labels"][0][269] == 22291
+
+    assert batch["input_ids"][0][0] == 1
+    assert batch["labels"][0][-1] == 2
+    assert batch["input_ids"][0][-1] == 2

+ 94 - 0
tests/test_batching.py

@@ -0,0 +1,94 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import pytest
+from unittest.mock import patch
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_packing(step_lr, optimizer, get_model, tokenizer, train, mocker, setup_tokenizer):
+    from llama_recipes.finetuning import main
+
+    setup_tokenizer(tokenizer)
+
+    kwargs = {
+        "model_name": "decapoda-research/llama-7b-hf",
+        "batch_size_training": 8,
+        "val_batch_size": 1,
+        "use_peft": False,
+        "dataset": "samsum_dataset",
+        "batching_strategy": "packing",
+        }
+
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader = args[1]
+    eval_dataloader = args[2]
+
+    assert len(train_dataloader) == 96
+    assert len(eval_dataloader) == 42
+
+    batch = next(iter(train_dataloader))
+
+    assert "labels" in batch.keys()
+    assert "input_ids" in batch.keys()
+    assert "attention_mask" in batch.keys()
+
+    assert batch["labels"][0].size(0) == 4096
+    assert batch["input_ids"][0].size(0) == 4096
+    assert batch["attention_mask"][0].size(0) == 4096
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaTokenizer')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+@patch('llama_recipes.finetuning.setup')
+@patch('llama_recipes.finetuning.FSDP')
+@patch('llama_recipes.finetuning.torch.distributed.is_initialized')
+@patch('llama_recipes.utils.config_utils.dist')
+def test_distributed_packing(dist, is_initialized, fsdp, setup, step_lr, optimizer, get_model, tokenizer, train, setup_tokenizer):
+    import os
+    from llama_recipes.finetuning import main
+
+    setup_tokenizer(tokenizer)
+
+    rank = 0
+    os.environ['LOCAL_RANK'] = f'{rank}'
+    os.environ['RANK'] = f'{rank}'
+    os.environ['WORLD_SIZE'] = '2'
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '12345'
+
+    kwargs = {
+        "model_name": "decapoda-research/llama-7b-hf",
+        "batch_size_training": 8,
+        "val_batch_size": 1,
+        "use_peft": False,
+        "dataset": "samsum_dataset",
+        "batching_strategy": "packing",
+        "enable_fsdp": True
+        }
+
+    is_initialized.return_value = True
+    dist.get_rank.return_value = rank
+    dist.get_world_size.return_value = 2
+
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader = args[1]
+    eval_dataloader = args[2]
+
+    assert len(train_dataloader) == 96 //2
+    assert len(eval_dataloader) == 42 //2

+ 82 - 33
tests/test_finetuning.py

@@ -1,14 +1,26 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+import pytest
 from pytest import approx
 from unittest.mock import patch
 
 from torch.nn import Linear
 from torch.optim import AdamW
 from torch.utils.data.dataloader import DataLoader
+from torch.utils.data.sampler import BatchSampler
 
 from llama_recipes.finetuning import main
+from llama_recipes.data.sampler import LengthBasedBatchSampler
+
+
+def get_fake_dataset():
+    return [{
+        "input_ids":[1],
+        "attention_mask":[1],
+        "labels":[1],
+        }]
+
 
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
@@ -18,23 +30,23 @@ from llama_recipes.finetuning import main
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
     kwargs = {"run_validation": False}
-    
-    get_dataset.return_value = [1]
-    
+
+    get_dataset.return_value = get_fake_dataset()
+
     main(**kwargs)
-    
+
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     train_dataloader = args[1]
     eval_dataloader = args[2]
-    
+
     assert isinstance(train_dataloader, DataLoader)
     assert eval_dataloader is None
-    
+
     assert get_model.return_value.to.call_args.args[0] == "cuda"
-    
-    
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@@ -43,21 +55,22 @@ def test_finetuning_no_validation(step_lr, optimizer, get_dataset, tokenizer, ge
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
     kwargs = {"run_validation": True}
-    get_dataset.return_value = [1]
-    
+
+    get_dataset.return_value = get_fake_dataset()
+
     main(**kwargs)
-    
+
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     train_dataloader = args[1]
     eval_dataloader = args[2]
     assert isinstance(train_dataloader, DataLoader)
     assert isinstance(eval_dataloader, DataLoader)
-    
+
     assert get_model.return_value.to.call_args.args[0] == "cuda"
-    
-    
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
@@ -68,37 +81,73 @@ def test_finetuning_with_validation(step_lr, optimizer, get_dataset, tokenizer,
 @patch('llama_recipes.finetuning.StepLR')
 def test_finetuning_peft(step_lr, optimizer, get_peft_model, gen_peft_config, get_dataset, tokenizer, get_model, train):
     kwargs = {"use_peft": True}
-    
-    get_dataset.return_value = [1]
-    
+
+    get_dataset.return_value = get_fake_dataset()
+
     main(**kwargs)
-    
+
     assert get_peft_model.return_value.to.call_args.args[0] == "cuda"
     assert get_peft_model.return_value.print_trainable_parameters.call_count == 1
-    
-    
+
+
 @patch('llama_recipes.finetuning.train')
 @patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
 @patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
 @patch('llama_recipes.finetuning.get_preprocessed_dataset')
 @patch('llama_recipes.finetuning.get_peft_model')
 @patch('llama_recipes.finetuning.StepLR')
-def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train):
+def test_finetuning_weight_decay(step_lr, get_peft_model, get_dataset, tokenizer, get_model, train, mocker):
     kwargs = {"weight_decay": 0.01}
-    
-    get_dataset.return_value = [1]
-    
-    get_peft_model.return_value = Linear(1,1)
-    get_peft_model.return_value.print_trainable_parameters=lambda:None
+
+    get_dataset.return_value = get_fake_dataset()
+
+    get_model.return_value = Linear(1,1)
+
     main(**kwargs)
-    
+
     assert train.call_count == 1
-    
+
     args, kwargs = train.call_args
     optimizer = args[4]
-    
+
     print(optimizer.state_dict())
-    
+
     assert isinstance(optimizer, AdamW)
     assert optimizer.state_dict()["param_groups"][0]["weight_decay"] == approx(0.01)
-    
+
+
+@patch('llama_recipes.finetuning.train')
+@patch('llama_recipes.finetuning.LlamaForCausalLM.from_pretrained')
+@patch('llama_recipes.finetuning.LlamaTokenizer.from_pretrained')
+@patch('llama_recipes.finetuning.get_preprocessed_dataset')
+@patch('llama_recipes.finetuning.optim.AdamW')
+@patch('llama_recipes.finetuning.StepLR')
+def test_batching_strategy(step_lr, optimizer, get_dataset, tokenizer, get_model, train):
+    kwargs = {"batching_strategy": "packing"}
+
+    get_dataset.return_value = get_fake_dataset()
+
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader, eval_dataloader = args[1:3]
+    assert isinstance(train_dataloader.batch_sampler, BatchSampler)
+    assert isinstance(eval_dataloader.batch_sampler, BatchSampler)
+
+    kwargs["batching_strategy"] = "padding"
+    train.reset_mock()
+    main(**kwargs)
+
+    assert train.call_count == 1
+
+    args, kwargs = train.call_args
+    train_dataloader, eval_dataloader = args[1:3]
+    assert isinstance(train_dataloader.batch_sampler, LengthBasedBatchSampler)
+    assert isinstance(eval_dataloader.batch_sampler, LengthBasedBatchSampler)
+
+    kwargs["batching_strategy"] = "none"
+
+    with pytest.raises(ValueError):
+        main(**kwargs)

+ 86 - 0
tests/test_sampler.py

@@ -0,0 +1,86 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
+
+import random
+import pytest
+
+import torch
+
+from llama_recipes.data.sampler import LengthBasedBatchSampler
+from llama_recipes.data.sampler import DistributedLengthBasedBatchSampler
+
+SAMPLES = 33
+
+@pytest.fixture
+def dataset():
+    random.seed(42)
+    dataset = []
+    def add_samples(ds, n, a, b):
+        for _ in range(n):
+            ds.append(random.randint(a,b) * [1,])
+    add_samples(dataset, SAMPLES // 2, 1,9)
+    add_samples(dataset, (SAMPLES // 2) + (SAMPLES % 2), 10,20)
+    
+    return random.sample(dataset, len(dataset))
+    
+    
+@pytest.mark.parametrize("batch_size, drop_last", [(2, False), (8, False), (2, True), (8, True)])
+def test_batch_sampler_array(dataset, batch_size, drop_last):
+    
+    sampler = LengthBasedBatchSampler(dataset, batch_size, drop_last)
+    
+    EXPECTED_LENGTH = SAMPLES // batch_size if drop_last else (SAMPLES // batch_size) + (SAMPLES % batch_size)
+    
+    all_ids = [i for b in sampler for i in b]
+    assert len(set(all_ids)) == EXPECTED_LENGTH * batch_size if drop_last else len(dataset)
+    
+    assert len(sampler) == EXPECTED_LENGTH
+    is_long = [len(d)>=10 for d in dataset]
+    
+    def check_batch(batch):
+        return all(batch) or not any(batch)
+    
+    assert all(check_batch(is_long[i] for i in b) for b in sampler)
+    
+    
+@pytest.mark.parametrize("batch_size, drop_last", [(2, False), (8, False), (2, True), (8, True)])
+def test_batch_sampler_dict(dataset, batch_size, drop_last):
+    
+    dist_dataset = [{"input_ids": d, "attention_mask": d} for d in dataset]
+    
+    sampler = LengthBasedBatchSampler(dist_dataset, batch_size, drop_last)
+    
+    EXPECTED_LENGTH = SAMPLES // batch_size if drop_last else (SAMPLES // batch_size) + (SAMPLES % batch_size)
+    
+    assert len(sampler) == EXPECTED_LENGTH
+    is_long = [len(d)>=10 for d in dataset]
+    
+    def check_batch(batch):
+        return all(batch) or not any(batch)
+    
+    assert all(check_batch(is_long[i] for i in b) for b in sampler)
+    
+    
+@pytest.mark.parametrize("batch_size", [2, 8])
+def test_dist_batch_sampling(dataset, batch_size):
+    sampler_1 = DistributedLengthBasedBatchSampler(
+        dataset,
+        batch_size=batch_size,
+        rank=0,
+        num_replicas=2,
+        shuffle=False,
+    )
+    sampler_2 = DistributedLengthBasedBatchSampler(
+        dataset,
+        batch_size=batch_size,
+        rank=1,
+        num_replicas=2,
+        shuffle=False,
+    )
+    
+    ids_1 = set(i for b in sampler_1 for i in b)
+    ids_2 = set(i for b in sampler_2 for i in b)
+    
+    assert ids_1.isdisjoint(ids_2)
+    assert len(ids_1)+len(ids_2) > 0
+    assert len(ids_1)+len(ids_2) == len(dataset) // batch_size  *  batch_size 

+ 18 - 5
tests/test_train_utils.py

@@ -1,17 +1,22 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
 
+from unittest.mock import patch
+
 import torch
 
 from llama_recipes.utils.train_utils import train
 
-def test_gradient_accumulation(mocker):
-    # import sys
-    # sys.path.append('/home/ubuntu/llama-recipes/')
+@patch("llama_recipes.utils.train_utils.MemoryTrace")
+@patch("llama_recipes.utils.train_utils.nullcontext")
+@patch("llama_recipes.utils.train_utils.torch.cuda.amp.GradScaler")
+@patch("llama_recipes.utils.train_utils.torch.cuda.amp.autocast")
+def test_gradient_accumulation(autocast, scaler, nullcontext, mem_trace, mocker):
     
     model = mocker.MagicMock(name="model")
     model().loss.__truediv__().detach.return_value = torch.tensor(1)
-    batch = {"input": torch.zeros(1)}
+    mock_tensor = mocker.MagicMock(name="tensor")
+    batch = {"input": mock_tensor}
     train_dataloader = [batch, batch, batch, batch, batch]
     eval_dataloader = None
     tokenizer = mocker.MagicMock()
@@ -37,7 +42,13 @@ def test_gradient_accumulation(mocker):
     assert optimizer.zero_grad.call_count == 5
     optimizer.zero_grad.reset_mock()
     
+    assert nullcontext.call_count == 5
+    nullcontext.reset_mock()
+    
+    assert autocast.call_count == 0
+    
     gradient_accumulation_steps = 2
+    train_config.use_fp16 = True
     train(
         model,
         train_dataloader,
@@ -48,4 +59,6 @@ def test_gradient_accumulation(mocker):
         gradient_accumulation_steps,
         train_config,
     )
-    assert optimizer.zero_grad.call_count == 3
+    assert optimizer.zero_grad.call_count == 3
+    assert nullcontext.call_count == 0
+    assert autocast.call_count == 5