{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Building a Llama 2 chatbot with Retrieval Augmented Generation (RAG)\n", "\n", "This notebook shows a complete example of how to build a Llama 2 chatbot, served locally and used from your browser, that can answer questions based on your own data. We'll cover:\n", "* The deployment process of Llama 2 7B with the [Text-generation-inference](https://github.com/huggingface/text-generation-inference) framework as an API server\n", "* A chatbot example built with [Gradio](https://github.com/gradio-app/gradio) and wired to the server\n", "* Adding RAG capability with Llama 2 specific knowledge based on our Getting Started [guide](https://ai.meta.com/llama/get-started/)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## RAG Architecture\n", "\n", "LLMs have unprecedented capabilities in NLU (Natural Language Understanding) & NLG (Natural Language Generation), but they have a knowledge cutoff date, and are only trained on publicly available data before that date.\n", "\n", "RAG, invented by [Meta](https://ai.meta.com/blog/retrieval-augmented-generation-streamlining-the-creation-of-intelligent-natural-language-processing-models/) in 2020, is one of the most popular methods to augment LLMs. 
RAG allows enterprises to keep sensitive data on-prem and get more relevant answers from generic models without fine-tuning models for specific roles.\n", "\n", "RAG is a method that:\n", "* Retrieves data from outside a foundation model\n", "* Augments your questions or prompts to LLMs by adding the retrieved relevant data as context\n", "* Allows LLMs to answer questions about your own data, or data not publicly available when LLMs were trained\n", "* Greatly reduces hallucination in the model's responses\n", "\n", "The following diagram shows the general RAG components and process:" ] }, { "attachments": { "image.png": { "image/png": "" } }, "cell_type": "markdown", "metadata": {}, "source": [ "![image.png](attachment:image.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## How to Develop a RAG Powered Llama 2 Chatbot\n", "\n", "The easiest way to develop RAG-powered Llama 2 chatbots is to use frameworks such as [**LangChain**](https://www.langchain.com/) and [**LlamaIndex**](https://www.llamaindex.ai/), two leading open-source frameworks for building LLM apps. Both offer convenient APIs for implementing RAG with Llama 2, including APIs to:\n", "\n", "* Load and split documents\n", "* Embed and store document splits\n", "* Retrieve the relevant context based on the user query\n", "* Call Llama 2 with query and context to generate the answer\n", "\n", "LangChain is a more general-purpose and flexible framework for developing LLM apps with RAG capabilities, while LlamaIndex, as a data framework, focuses on connecting custom data sources to LLMs. Integrating the two may provide the most performant and effective solution for building real-world RAG apps. \n", "In our example, for simplicity, we will use LangChain alone with locally stored PDF data." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Install Dependencies\n", "\n", "For this demo, we will use Gradio for the chatbot UI and the Text-generation-inference framework for model serving. \n", "For vector storage and similarity search, we will use [FAISS](https://github.com/facebookresearch/faiss). \n", "In this example, we will run everything on an AWS EC2 instance (i.e. [g5.2xlarge](https://aws.amazon.com/ec2/instance-types/g5/)). g5.2xlarge features one A10G GPU. We recommend running this notebook with at least one GPU equivalent to A10G with at least 16GB video memory. \n", "There are techniques to downsize the Llama 2 7B model so that it can fit into smaller GPUs, but they are out of scope here.\n", "\n", "First, let's install all dependencies with pip. We also recommend that you create a dedicated Conda environment for better package management." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!pip install -r requirements.txt" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Data Processing\n", "\n", "First run all the imports and define the paths for the data and for the vector store created after processing. \n", "For the data, we will use a raw PDF crawled from the Llama 2 Getting Started guide on the [Meta AI website](https://ai.meta.com/llama/)." ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "from langchain.embeddings import HuggingFaceEmbeddings\n", "from langchain.vectorstores import FAISS\n", "from langchain.document_loaders import PyPDFDirectoryLoader\n", "from langchain.text_splitter import RecursiveCharacterTextSplitter \n", "\n", "DATA_PATH = 'data' #Your root data folder path\n", "DB_FAISS_PATH = 'vectorstore/db_faiss'" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we use the `PyPDFDirectoryLoader` to load the entire directory. You can also use `PyPDFLoader` to load a single file." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "loader = PyPDFDirectoryLoader(DATA_PATH)\n", "documents = loader.load()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Check the length and content of the loaded documents to confirm that we have the right document: it should have 37 pages." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(len(documents), documents[0].page_content[0:100])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Split the loaded documents into smaller chunks. \n", "[`RecursiveCharacterTextSplitter`](https://api.python.langchain.com/en/latest/text_splitter/langchain.text_splitter.RecursiveCharacterTextSplitter.html) is one common splitter that splits long pieces of text into smaller, semantically meaningful chunks. \n", "Other splitters include:\n", "* SpacyTextSplitter\n", "* NLTKTextSplitter\n", "* SentenceTransformersTokenTextSplitter\n", "* CharacterTextSplitter\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=10)\n", "splits = text_splitter.split_documents(documents)\n", "print(len(splits), splits[0])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Note that we have set `chunk_size` to 500 and `chunk_overlap` to 10. When splitting, these two parameters can directly affect the quality of the LLM's answers. \n", "Here is a good [guide](https://dev.to/peterabel/what-chunk-size-and-chunk-overlap-should-you-use-4338) on how to set these two parameters carefully." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next we need to choose an embedding model for our split documents. \n", "**Embeddings are numerical representations of text**. The default embedding model in HuggingFace Embeddings is `sentence-transformers/all-mpnet-base-v2` with 768 dimensions. 
Below we use a smaller model, `all-MiniLM-L6-v2`, with 384 dimensions, so indexing runs faster." ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [], "source": [ "embeddings = HuggingFaceEmbeddings(model_name='sentence-transformers/all-MiniLM-L6-v2',\n", " model_kwargs={'device': 'cpu'})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Lastly, with the splits and the embedding model ready, we index the split chunks and store them as embeddings in the vector store. \n", "\n", "Vector stores are databases that store embeddings. There are at least 60 [vector stores](https://python.langchain.com/docs/integrations/vectorstores) supported by LangChain, and two of the most popular open source ones are:\n", "* [Chroma](https://www.trychroma.com/): a lightweight, in-memory vector store that is easy to get started with and well suited to **local development**.\n", "* [FAISS](https://python.langchain.com/docs/integrations/vectorstores/faiss) (Facebook AI Similarity Search): a vector store that supports search in vectors that may not fit in RAM and is appropriate for **production use**. \n", "\n", "Since we are running on an EC2 instance with abundant CPU resources and RAM, we will use FAISS in this example. Note that FAISS can also run on GPUs, where some of its most useful algorithms are implemented. In that case, install the `faiss-gpu` package with pip instead." ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [], "source": [ "db = FAISS.from_documents(splits, embeddings)\n", "db.save_local(DB_FAISS_PATH)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once you have saved the database to the local path, you will find it stored as `index.faiss` and `index.pkl`. In the chatbot example, you can then load this database from disk and plug it into our retrieval process." 
] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Model Serving\n", "\n", "In this example, we will deploy a Llama 2 7B chat HuggingFace model on-premises with the Text-generation-inference framework. \n", "This allows us to wire the API server directly to our chatbot. \n", "There are alternative solutions for deploying Llama 2 models on-premises as your local API server. \n", "You can find our complete guide [here](https://github.com/meta-llama/llama-recipes/blob/main/recipes/inference/model_servers/llama-on-prem.md)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "In a **separate terminal**, run the commands below to launch an API server with TGI. This will download the model artifacts, store them locally, and launch the server on the desired port on your localhost. In our case, this is port 8080." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "vscode": { "languageId": "plaintext" } }, "outputs": [], "source": [ "model=meta-llama/Llama-2-7b-chat-hf\n", "volume=$PWD/data\n", "token=YOUR_HF_TOKEN #Replace with your own Hugging Face token\n", "docker run --gpus all --shm-size 1g -e HUGGING_FACE_HUB_TOKEN=$token -p 8080:80 -v $volume:/data ghcr.io/huggingface/text-generation-inference:1.1.0 --model-id $model" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Once we have the API server up and running, we can run a simple `curl` command to validate that our model is working as expected." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "!curl localhost:8080/generate -X POST -H 'Content-Type: application/json' -d '{\"inputs\": \"What is good about Beijing?\", \"parameters\": { \"max_new_tokens\":64}}' #Replace localhost with the IP visible to the machine running the notebook " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Building the Chatbot UI\n", "\n", "Now we are ready to build the chatbot UI and wire it up to the RAG data and the API server. 
In our example we will use Gradio to build the chatbot UI. \n", "Gradio is an open-source Python library that is used to build machine learning and data science demos and web applications. It is widely used by the community, and Hugging Face also uses Gradio to build its chatbots. Other alternatives are: \n", "* [Streamlit](https://streamlit.io/)\n", "* [Dash](https://plotly.com/dash/)\n", "* [Flask](https://flask.palletsprojects.com/en/3.0.x/)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Again, we start by adding all the imports, paths, and constants, and we set LangChain to debug mode so that it shows each action within the chain process." ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "import langchain\n", "from queue import Queue\n", "from typing import Any\n", "from langchain.llms.huggingface_text_gen_inference import HuggingFaceTextGenInference\n", "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n", "from langchain.schema import LLMResult\n", "from langchain.embeddings import HuggingFaceEmbeddings\n", "from langchain.vectorstores import FAISS\n", "from langchain.chains import RetrievalQA\n", "from langchain.prompts.prompt import PromptTemplate\n", "from anyio.from_thread import start_blocking_portal #For model callback streaming\n", "\n", "langchain.debug=True \n", "\n", "#vector db path\n", "DB_FAISS_PATH = 'vectorstore/db_faiss'\n", "\n", "#Llama2 TGI models host port\n", "LLAMA2_7B_HOSTPORT = \"http://localhost:8080/\" #Replace localhost with the IP visible to the machine running the notebook\n", "LLAMA2_13B_HOSTPORT = \"http://localhost:8080/\" #Add your own host ports for model switching. 
You can host another TGI model on the same instance on a different port.\n", "\n", "\n", "model_dict = {\n", " \"7b-chat\" : LLAMA2_7B_HOSTPORT,\n", " \"13b-chat\" : LLAMA2_13B_HOSTPORT,\n", "}\n", "\n", "system_message = {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then we load the FAISS vector store." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "embeddings = HuggingFaceEmbeddings(model_name=\"sentence-transformers/all-MiniLM-L6-v2\",\n", " model_kwargs={'device': 'cpu'})\n", "db = FAISS.load_local(DB_FAISS_PATH, embeddings)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we create a TGI LLM instance and wire it to the API serving port on localhost." ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "llm = HuggingFaceTextGenInference(\n", " inference_server_url=LLAMA2_7B_HOSTPORT,\n", " max_new_tokens=512,\n", " top_k=10,\n", " top_p=0.9,\n", " typical_p=0.95,\n", " temperature=0.6,\n", " repetition_penalty=1,\n", " do_sample=True,\n", " streaming=True\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, we define the retriever and template for our RetrievalQA chain. For each call of the RetrievalQA chain, LangChain performs a semantic similarity search of the query in the vector database, then passes the search results as the context to Llama to answer the query about the data stored in the vector database. \n", "The template defines the format of the question, along with the context, that will be sent to Llama for generation. In general, Llama 2 has a special prompt format to handle special tokens. In some cases, the serving framework might already have taken care of it. 
Otherwise, you will need to write a customized template to properly handle that.\n" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "template = \"\"\"\n", "[INST]Use the following pieces of context to answer the question. If no context is provided, answer like an AI assistant.\n", "{context}\n", "Question: {question} [/INST]\n", "\"\"\"\n", "\n", "retriever = db.as_retriever(\n", " search_kwargs={\"k\": 6}\n", " )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Lastly, we can define the retrieval chain for QA." ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "qa_chain = RetrievalQA.from_chain_type(\n", " llm=llm, \n", " retriever=retriever, \n", " chain_type_kwargs={\n", " \"prompt\": PromptTemplate(\n", " template=template,\n", " input_variables=[\"context\", \"question\"],\n", " ),\n", " }\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we should have a working chain for QA. Let's test it out before wiring it up with the UI blocks." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "result = qa_chain({\"query\": \"Why choose Llama?\"})\n", "print(result)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "After confirming that the chain works, we can start building the UI. Before we define the Gradio [blocks](https://www.gradio.app/docs/blocks), let's first define the callback handler that we will use later for the streaming feature. \n", "This callback handler will put streaming LLM responses into a queue for the Gradio UI to render on the fly. 
" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "job_done = object()\n", "\n", "class MyStream(StreamingStdOutCallbackHandler):\n", " def __init__(self, q) -> None:\n", " self.q = q\n", "\n", " def on_llm_new_token(self, token: str, **kwargs: Any) -> None:\n", " self.q.put(token)\n", "\n", " def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:\n", " self.q.put(job_done)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we can define the gradio UI blocks. \n", "Since we will need to define the UI and handlers in the same place, this will be a large chunk of code. We will add comments in the code for explanation." ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "import gradio as gr\n", "\n", "with gr.Blocks() as demo:\n", " #Configure UI layout\n", " chatbot = gr.Chatbot(height = 600)\n", " with gr.Row():\n", " with gr.Column(scale=1):\n", " with gr.Row():\n", " #model selection\n", " model_selector = gr.Dropdown(\n", " list(model_dict.keys()), \n", " value=\"7b-chat\", \n", " label=\"Model\", \n", " info=\"Select the model\", \n", " interactive = True, \n", " scale=1\n", " )\n", " max_new_tokens_selector = gr.Number(\n", " value=512, \n", " precision=0, \n", " label=\"Max new tokens\", \n", " info=\"Adjust max_new_tokens\",\n", " interactive = True, \n", " minimum=1, \n", " maximum=1024, \n", " scale=1\n", " )\n", " with gr.Row():\n", " #hyperparameter selection\n", " temperature_selector = gr.Slider(\n", " value=0.6, \n", " label=\"Temperature\", \n", " info=\"Range 0-2. Controls the creativity of the generated text.\",\n", " interactive = True, \n", " minimum=0.01, \n", " maximum=2, \n", " step=0.01, \n", " scale=1\n", " )\n", " top_p_selector = gr.Slider(\n", " value=0.9, \n", " label=\"Top_p\", \n", " info=\"Range 0-1. 
Nucleus sampling.\",\n", " interactive = True, \n", " minimum=0.01, \n", " maximum=0.99, \n", " step=0.01, \n", " scale=1\n", " )\n", " with gr.Column(scale=2):\n", " #user input prompt text field\n", " user_prompt_message = gr.Textbox(placeholder=\"Please add user prompt here\", label=\"User prompt\")\n", " with gr.Row():\n", " clear = gr.Button(\"Clear Conversation\", scale=2)\n", " submitBtn = gr.Button(\"Submit\", scale=8)\n", "\n", "\n", " state = gr.State([])\n", "\n", " #handle user message\n", " def user(user_prompt_message, history):\n", " if user_prompt_message != \"\":\n", " return history + [[user_prompt_message, None]]\n", " else:\n", " return history + [[\"Invalid prompts - user prompt cannot be empty\", None]]\n", "\n", " #chatbot logic for configuration, sending the prompts, rendering the streamed back generations etc\n", " def bot(model_selector, temperature_selector, top_p_selector, max_new_tokens_selector, user_prompt_message, history, messages_history):\n", " dialog = []\n", " bot_message = \"\"\n", " history[-1][1] = \"\"\n", " \n", " dialog = [\n", " {\"role\": \"user\", \"content\": user_prompt_message},\n", " ]\n", " messages_history += dialog\n", " \n", " #Queue for streamed character rendering\n", " q = Queue()\n", "\n", " #Update new llama hyperparameters\n", " #The dropdown returns the model name, so look up its server URL in model_dict\n", " llm.inference_server_url = model_dict[model_selector]\n", " llm.temperature = temperature_selector\n", " llm.top_p = top_p_selector\n", " llm.max_new_tokens = max_new_tokens_selector\n", "\n", " #Async task for streamed chain results wired to callbacks we previously defined, so we don't block the UI\n", " async def task(prompt):\n", " #qa_chain.run is synchronous, so it is called without await; it runs on the portal's thread while the loop below reads the queue\n", " ret = qa_chain.run(prompt, callbacks=[MyStream(q)])\n", " return ret\n", "\n", " with start_blocking_portal() as portal:\n", " portal.start_task_soon(task, user_prompt_message)\n", " while True:\n", " next_token = q.get(True)\n", " if next_token is job_done:\n", " messages_history += [{\"role\": \"assistant\", \"content\": bot_message}]\n", " return 
history, messages_history\n", " bot_message += next_token\n", " history[-1][1] += next_token\n", " yield history, messages_history\n", "\n", " #init the chat history with default system message \n", " def init_history(messages_history):\n", " messages_history = []\n", " messages_history += [system_message]\n", " return messages_history\n", "\n", " #clean up the user input text field\n", " def input_cleanup():\n", " return \"\"\n", "\n", " #when the user clicks Enter and the user message is submitted\n", " user_prompt_message.submit(\n", " user, \n", " [user_prompt_message, chatbot], \n", " [chatbot], \n", " queue=False\n", " ).then(\n", " bot, \n", " [model_selector, temperature_selector, top_p_selector, max_new_tokens_selector, user_prompt_message, chatbot, state], \n", " [chatbot, state]\n", " ).then(input_cleanup, \n", " [], \n", " [user_prompt_message], \n", " queue=False\n", " )\n", "\n", " #when the user clicks the submit button\n", " submitBtn.click(\n", " user, \n", " [user_prompt_message, chatbot], \n", " [chatbot], \n", " queue=False\n", " ).then(\n", " bot, \n", " [model_selector, temperature_selector, top_p_selector, max_new_tokens_selector, user_prompt_message, chatbot, state], \n", " [chatbot, state]\n", " ).then(\n", " input_cleanup, \n", " [], \n", " [user_prompt_message], \n", " queue=False\n", " )\n", " \n", " #when the user clicks the clear button\n", " clear.click(lambda: None, None, chatbot, queue=False).success(init_history, [state], [state])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Lastly, we can launch this demo on our localhost with the command below. " ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://0.0.0.0:7860\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "