{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "e4532411", "metadata": {}, "outputs": [], "source": [ "# TODO REFACTOR: Integrate code from _legacy/inference.py into this notebook" ] }, { "cell_type": "markdown", "id": "47a9adb3", "metadata": {}, "source": [ "## This demo app shows how to query Llama 2 using the Gradio UI.\n", "\n", "Since we are using Replicate in this example, you will need to replace `` with your API token.\n", "\n", "To get the Replicate token: \n", "\n", "- You will need to first sign in with Replicate with your github account\n", "- Then create a free API token [here](https://replicate.com/account/api-tokens) that you can use for a while \n", "\n", "**Note** After the free trial ends, you will need to enter billing info to continue to use Llama2 hosted on Replicate.\n", "\n", "To run this example:\n", "- Set up your Replicate API token and enter it in place of ``\n", "- Run the notebook\n", "- Enter your question and click Submit\n", "\n", "In the notebook or a browser with URL http://127.0.0.1:7860 you should see a UI with your answer." ] }, { "cell_type": "code", "execution_count": 1, "id": "928041cc", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Init param `input` is deprecated, please use `model_kwargs` instead.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Running on local URL: http://127.0.0.1:7860\n", "\n", "To create a public link, set `share=True` in `launch()`.\n" ] }, { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [] }, "execution_count": 1, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from langchain.schema import AIMessage, HumanMessage\n", "import gradio as gr\n", "from langchain.llms import Replicate\n", "import os\n", "\n", "os.environ[\"REPLICATE_API_TOKEN\"] = \"\"\n", "\n", "llama2_13b_chat = \"meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d\"\n", "\n", "llm = Replicate(\n", " model=llama2_13b_chat,\n", " model_kwargs={\"temperature\": 0.01, \"top_p\": 1, \"max_new_tokens\":500}\n", ")\n", "\n", "\n", "def predict(message, history):\n", " history_langchain_format = []\n", " for human, ai in history:\n", " history_langchain_format.append(HumanMessage(content=human))\n", " history_langchain_format.append(AIMessage(content=ai))\n", " history_langchain_format.append(HumanMessage(content=message))\n", " gpt_response = llm(message) #history_langchain_format)\n", " return gpt_response#.content\n", "\n", "gr.ChatInterface(predict).launch()" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.18" } }, "nbformat": 4, "nbformat_minor": 5 }