|
@@ -5,9 +5,25 @@
|
|
|
"id": "e8cba0b6",
|
|
|
"metadata": {},
|
|
|
"source": [
|
|
|
- "## This demo shows how to use LangChain's SQLDatabaseChain with Llama2 to query about structured data stored in a SQL DB. \n",
|
|
|
- "* As the 2023-24 NBA season is around the corner, we use the NBA roster info saved in a SQLite DB to show you how to ask Llama2 questions about your favorite teams or players. \n",
|
|
|
- "* Because the SQLDatabaseChain API implementation is still in the langchain_experimental package, you'll see more issues that come with using the cutting edge experimental features, and how we succeed resolving some of the issues but fail on some others."
|
|
|
+ "## This demo shows how to use LangChain's SQLDatabaseChain with Llama2 to query structured data stored in a SQL DB. \n",
|
|
|
+ "* We use the 2023-24 NBA roster info saved in a SQLite DB to show you how to ask Llama2 questions about your favorite teams or players. \n",
|
|
|
+ "* At the time of writing this, the SQLDatabaseChain API implementation is still in the langchain_experimental package. With this in mind you will see more issues that come with using the cutting edge experimental features, and how we succeed resolving some of the issues but fail on some others."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "f839d07d",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "We start by installing the necessary packages:\n",
|
|
|
+ "- [Replicate](https://replicate.com/) to host the Llama 2 model\n",
|
|
|
+ "- [langchain](https://python.langchain.com/docs/get_started/introduction) provides necessary RAG tools for this demo\n",
|
|
|
+ "- langchain_experimental Langchain's experimental version to get us access to SQLDatabaseChain\n",
|
|
|
+ "\n",
|
|
|
+ "And setting up the Replicate token.\n",
|
|
|
+ "\n",
|
|
|
+ "**Note** To get a Replicate token, you will need to first sign in with Replicate with your github account, then create a free API token [here](https://replicate.com/account/api-tokens) that you can use for a while. \n",
|
|
|
+ "After the free trial ends, you will need to enter billing info to continue to use Llama2 hosted on Replicate."
|
|
|
]
|
|
|
},
|
|
|
{
|
|
@@ -40,7 +56,7 @@
|
|
|
"metadata": {},
|
|
|
"outputs": [
|
|
|
{
|
|
|
- "name": "stdin",
|
|
|
+ "name": "stdout",
|
|
|
"output_type": "stream",
|
|
|
"text": [
|
|
|
" ········\n"
|
|
@@ -55,6 +71,15 @@
|
|
|
"os.environ[\"REPLICATE_API_TOKEN\"] = REPLICATE_API_TOKEN"
|
|
|
]
|
|
|
},
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "1e586b75",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "Next we call the Llama 2 model from replicate. In this example we will use the llama 2 13b chat model. You can find more Llama 2 models by searching for them on the [Replicate model explore page](https://replicate.com/explore?query=llama).\n",
|
|
|
+ "You can add them here in the format: model_name/version"
|
|
|
+ ]
|
|
|
+ },
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
"execution_count": 3,
|
|
@@ -70,6 +95,18 @@
|
|
|
")"
|
|
|
]
|
|
|
},
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "6d421ae7",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "Next you will need create the `nba_roster.db` file. To do this run:\n",
|
|
|
+ "- `python txt2csv.py` to convert the `nba.txt` file to `nba_roster.csv`. The `nba.txt` file was created by scraping the NBA roster info from the web.\n",
|
|
|
+ "- Then run `python csv2db.py` to convert `nba_roster.csv` to `nba_roster.db`.\n",
|
|
|
+ "\n",
|
|
|
+ "Once you have your `nba_roster.db` ready, we set up the database to be queried by Llama 2 through Langchain's [SQL chains](https://python.langchain.com/docs/use_cases/qa_structured/sql)."
|
|
|
+ ]
|
|
|
+ },
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
"execution_count": 4,
|
|
@@ -77,9 +114,7 @@
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
- "# The nba_roster.db was created by running the two scripts:\n",
|
|
|
- "# python txt2csv.py # convert the `nba.txt` file, created by scraping the NBA roster info from the web, to nba_roster.csv\n",
|
|
|
- "# python csv2db.py # convert nba_roster.csv to nba_roster.db\n",
|
|
|
+ "\n",
|
|
|
"db = SQLDatabase.from_uri(\"sqlite:///nba_roster.db\", sample_rows_in_table_info= 0)\n",
|
|
|
"\n",
|
|
|
"PROMPT_SUFFIX = \"\"\"\n",
|
|
@@ -93,6 +128,14 @@
|
|
|
" template=PROMPT_SUFFIX))"
|
|
|
]
|
|
|
},
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "afcf423a",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "We will go ahead and turn on LangChain debug to get an idea of how many calls are made to Llama 2 and what the inputs and outputs are."
|
|
|
+ ]
|
|
|
+ },
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
"execution_count": 5,
|
|
@@ -158,7 +201,7 @@
|
|
|
}
|
|
|
],
|
|
|
"source": [
|
|
|
- "# turn on the debug of LangChain so we can see how many calls to Llama are made and exactly what are inputs and outputs\n",
|
|
|
+ "\n",
|
|
|
"import langchain\n",
|
|
|
"langchain.debug = True\n",
|
|
|
"\n",
|
|
@@ -304,6 +347,16 @@
|
|
|
"db_chain.run(\"What's his salary?\")"
|
|
|
]
|
|
|
},
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "98b2c523",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "Since we did not pass any context along with the follow-up to the model it did not know who \"his\" is and just picked LeBron James.\n",
|
|
|
+ "Let's try to fix the issue that the context (the previous question and answer) was not sent to the model along with the new question.\n",
|
|
|
+ "`SQLDatabaseChain.from_llm` has a parameter \"memory\" which can be set to a `ConversationBufferMemory` instance, which looks promising."
|
|
|
+ ]
|
|
|
+ },
|
|
|
{
|
|
|
"cell_type": "code",
|
|
|
"execution_count": 8,
|
|
@@ -311,10 +364,7 @@
|
|
|
"metadata": {},
|
|
|
"outputs": [],
|
|
|
"source": [
|
|
|
- "# since we didn't pass the context along with the follow-up to llm so it didn't know what \"his\" is and just picked LeBron James\n",
|
|
|
"\n",
|
|
|
- "# let's try to fix the issue that the context (the previous question and answer) was not sent to LLM along with the new question\n",
|
|
|
- "# SQLDatabaseChain.from_llm has a parameter \"memory\" which can be set to a ConversationBufferMemory instance, which looks promising.\n",
|
|
|
"from langchain.memory import ConversationBufferMemory\n",
|
|
|
"\n",
|
|
|
"memory = ConversationBufferMemory()\n",
|