|
@@ -0,0 +1,844 @@
|
|
|
+{
|
|
|
+ "cells": [
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "e8cba0b6",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "## This demo shows how to use LangChain's SQLDatabaseChain with Llama2 to chat about structured data stored in a SQL DB. \n",
|
|
|
+ "* As the 2023-24 NBA season is around the corner, we use the NBA roster info saved in a SQLite DB to show you how to ask Llama2 questions about your favorite teams or players. \n",
|
|
|
+ "* Because the SQLDatabaseChain API implementation is still in the langchain_experimental package, you'll see more issues that come with using the cutting edge experimental features, and how we succeed resolving some of the issues but fail on some others."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": null,
|
|
|
+ "id": "33fb3190-59fb-4edd-82dd-f20f6eab3e47",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "!pip install langchain replicate langchain_experimental"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 1,
|
|
|
+ "id": "70d6c906-4fc0-4da5-acd9-8743e90f89e8",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "from langchain.llms import Replicate\n",
|
|
|
+ "from langchain.prompts import PromptTemplate\n",
|
|
|
+ "from langchain.utilities import SQLDatabase\n",
|
|
|
+ "from langchain_experimental.sql import SQLDatabaseChain"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 2,
|
|
|
+ "id": "fa4562d3",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdin",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ " ········\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "from getpass import getpass\n",
|
|
|
+ "import os\n",
|
|
|
+ "\n",
|
|
|
+ "REPLICATE_API_TOKEN = getpass()\n",
|
|
|
+ "os.environ[\"REPLICATE_API_TOKEN\"] = REPLICATE_API_TOKEN"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 3,
|
|
|
+ "id": "9dcd744c",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stderr",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "Init param `input` is deprecated, please use `model_kwargs` instead.\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "llama2_13b_chat = \"meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d\"\n",
|
|
|
+ "\n",
|
|
|
+ "# the reason we set system_prompt below is to ask Llama to generate only the SQL statement, instead of being wordy and adding something like\n",
|
|
|
+ "# \"Sure! Here's the SQL query for the given input question: \" before the SQL query; otherwise custom parsing will be needed.\n",
|
|
|
+ "llm = Replicate(\n",
|
|
|
+ " model=llama2_13b_chat,\n",
|
|
|
+ " input={\"temperature\": 0.01, \"max_length\": 500, \"top_p\": 1, \"system_prompt\": \"Given an input question, convert it to a SQL query. No pre-amble.\"},\n",
|
|
|
+ ")"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 4,
|
|
|
+ "id": "70776558",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "db = SQLDatabase.from_uri(\"sqlite:///nba_roster.db\", sample_rows_in_table_info= 0)\n",
|
|
|
+ "\n",
|
|
|
+ "# use the default sqlite prompt defined in \n",
|
|
|
+ "# https://github.com/langchain-ai/langchain/blob/33eb5f8300cd21c91a2f8d10c62197637931fa0a/libs/langchain/langchain/chains/sql_database/prompt.py#L211\n",
|
|
|
+ "# db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)\n",
|
|
|
+ "\n",
|
|
|
+ "# customize the default sqlite prompt defined in the link above\n",
|
|
|
+ "PROMPT_SUFFIX = \"\"\"\n",
|
|
|
+ "Only use the following tables:\n",
|
|
|
+ "{table_info}\n",
|
|
|
+ "\n",
|
|
|
+ "Question: {input}\"\"\"\n",
|
|
|
+ "\n",
|
|
|
+ "# if return_sql is set to False, then calling db_chain.run will make two calls to llm: first get the SQL query statement, \n",
|
|
|
+ "# then pass the result of running the SQL query along with the original question and the query to llm for NL answer\n",
|
|
|
+ "db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, return_sql=False, \n",
|
|
|
+ " prompt=PromptTemplate(input_variables=[\"input\", \"table_info\"], \n",
|
|
|
+ " template=PROMPT_SUFFIX))"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 5,
|
|
|
+ "id": "239fb5a6",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain] Entering Chain run with input:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"query\": \"How many unique teams are there?\"\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] Entering Chain run with input:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"input\": \"How many unique teams are there?\\nSQLQuery:\",\n",
|
|
|
+ " \"top_k\": \"5\",\n",
|
|
|
+ " \"dialect\": \"sqlite\",\n",
|
|
|
+ " \"table_info\": \"\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\",\n",
|
|
|
+ " \"stop\": [\n",
|
|
|
+ " \"\\nSQLResult:\"\n",
|
|
|
+ " ]\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[32;1m\u001b[1;3m[llm/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:Replicate] Entering LLM run with input:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"prompts\": [\n",
|
|
|
+ " \"Only use the following tables:\\n\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\\n\\nQuestion: How many unique teams are there?\\nSQLQuery:\"\n",
|
|
|
+ " ]\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[36;1m\u001b[1;3m[llm/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:Replicate] [2.63s] Exiting LLM run with output:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"generations\": [\n",
|
|
|
+ " [\n",
|
|
|
+ " {\n",
|
|
|
+ " \"text\": \" SELECT COUNT(DISTINCT Team) AS num_teams FROM nba_roster;\",\n",
|
|
|
+ " \"generation_info\": null\n",
|
|
|
+ " }\n",
|
|
|
+ " ]\n",
|
|
|
+ " ],\n",
|
|
|
+ " \"llm_output\": null,\n",
|
|
|
+ " \"run\": null\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] [2.63s] Exiting Chain run with output:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"text\": \" SELECT COUNT(DISTINCT Team) AS num_teams FROM nba_roster;\"\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 4:chain:LLMChain] Entering Chain run with input:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"input\": \"How many unique teams are there?\\nSQLQuery:SELECT COUNT(DISTINCT Team) AS num_teams FROM nba_roster;\\nSQLResult: [(30,)]\\nAnswer:\",\n",
|
|
|
+ " \"top_k\": \"5\",\n",
|
|
|
+ " \"dialect\": \"sqlite\",\n",
|
|
|
+ " \"table_info\": \"\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\",\n",
|
|
|
+ " \"stop\": [\n",
|
|
|
+ " \"\\nSQLResult:\"\n",
|
|
|
+ " ]\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[32;1m\u001b[1;3m[llm/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 4:chain:LLMChain > 5:llm:Replicate] Entering LLM run with input:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"prompts\": [\n",
|
|
|
+ " \"Only use the following tables:\\n\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\\n\\nQuestion: How many unique teams are there?\\nSQLQuery:SELECT COUNT(DISTINCT Team) AS num_teams FROM nba_roster;\\nSQLResult: [(30,)]\\nAnswer:\"\n",
|
|
|
+ " ]\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[36;1m\u001b[1;3m[llm/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 4:chain:LLMChain > 5:llm:Replicate] [2.06s] Exiting LLM run with output:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"generations\": [\n",
|
|
|
+ " [\n",
|
|
|
+ " {\n",
|
|
|
+ " \"text\": \" Sure! Here's the answer to your question:\\n\\nThere are 30 unique teams in the NBA.\",\n",
|
|
|
+ " \"generation_info\": null\n",
|
|
|
+ " }\n",
|
|
|
+ " ]\n",
|
|
|
+ " ],\n",
|
|
|
+ " \"llm_output\": null,\n",
|
|
|
+ " \"run\": null\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 4:chain:LLMChain] [2.06s] Exiting Chain run with output:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"text\": \" Sure! Here's the answer to your question:\\n\\nThere are 30 unique teams in the NBA.\"\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain] [4.70s] Exiting Chain run with output:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"result\": \"Sure! Here's the answer to your question:\\n\\nThere are 30 unique teams in the NBA.\"\n",
|
|
|
+ "}\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# turn on the LangChain debug to see what exactly happens behind the scenes in terms of the inputs and outputs of calling Llama\n",
|
|
|
+ "import langchain\n",
|
|
|
+ "langchain.debug = True\n",
|
|
|
+ "answer = db_chain.run(\"How many unique teams are there?\")\n",
|
|
|
+ "\n",
|
|
|
+ "# Two Llama calls happen below \n",
|
|
|
+ "# Note the difference of the first and second inputs for the Llama calls - before \"llm:Replicate] Entering LLM run with input\"\n",
|
|
|
+ "# About the output, for the first llm call, both \"llm:Replicate] [2.10s] Exiting LLM run with output\" and \n",
|
|
|
+ "# \"Exiting Chain run with output\" return correctly the text of SQL statement \"SELECT COUNT(DISTINCT Team) AS num_teams FROM nba_roster;\"\n",
|
|
|
+ "# In the second llm call, \"Exiting LLM run with output\" also shows the correct result:\n",
|
|
|
+ "# \"Sure! Here's the answer to your question:\\n\\nThere are 30 unique teams in the NBA.\""
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 6,
|
|
|
+ "id": "eef91aea",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "Sure! Here's the answer to your question:\n",
|
|
|
+ "\n",
|
|
|
+ "There are 30 unique teams in the NBA.\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# just show the answer without all the debug info \n",
|
|
|
+ "print(answer)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 7,
|
|
|
+ "id": "e3b444aa",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain] Entering Chain run with input:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"query\": \"Which team is Klay Thompson in?\"\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] Entering Chain run with input:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"input\": \"Which team is Klay Thompson in?\\nSQLQuery:\",\n",
|
|
|
+ " \"top_k\": \"5\",\n",
|
|
|
+ " \"dialect\": \"sqlite\",\n",
|
|
|
+ " \"table_info\": \"\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\",\n",
|
|
|
+ " \"stop\": [\n",
|
|
|
+ " \"\\nSQLResult:\"\n",
|
|
|
+ " ]\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[32;1m\u001b[1;3m[llm/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:Replicate] Entering LLM run with input:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"prompts\": [\n",
|
|
|
+ " \"Only use the following tables:\\n\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\\n\\nQuestion: Which team is Klay Thompson in?\\nSQLQuery:\"\n",
|
|
|
+ " ]\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[36;1m\u001b[1;3m[llm/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:Replicate] [2.13s] Exiting LLM run with output:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"generations\": [\n",
|
|
|
+ " [\n",
|
|
|
+ " {\n",
|
|
|
+ " \"text\": \" SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\",\n",
|
|
|
+ " \"generation_info\": null\n",
|
|
|
+ " }\n",
|
|
|
+ " ]\n",
|
|
|
+ " ],\n",
|
|
|
+ " \"llm_output\": null,\n",
|
|
|
+ " \"run\": null\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] [2.14s] Exiting Chain run with output:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"text\": \" SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\"\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 4:chain:LLMChain] Entering Chain run with input:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"input\": \"Which team is Klay Thompson in?\\nSQLQuery:SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\\nSQLResult: [('Golden State Warriors',)]\\nAnswer:\",\n",
|
|
|
+ " \"top_k\": \"5\",\n",
|
|
|
+ " \"dialect\": \"sqlite\",\n",
|
|
|
+ " \"table_info\": \"\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\",\n",
|
|
|
+ " \"stop\": [\n",
|
|
|
+ " \"\\nSQLResult:\"\n",
|
|
|
+ " ]\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[32;1m\u001b[1;3m[llm/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 4:chain:LLMChain > 5:llm:Replicate] Entering LLM run with input:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"prompts\": [\n",
|
|
|
+ " \"Only use the following tables:\\n\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\\n\\nQuestion: Which team is Klay Thompson in?\\nSQLQuery:SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\\nSQLResult: [('Golden State Warriors',)]\\nAnswer:\"\n",
|
|
|
+ " ]\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[36;1m\u001b[1;3m[llm/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 4:chain:LLMChain > 5:llm:Replicate] [2.68s] Exiting LLM run with output:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"generations\": [\n",
|
|
|
+ " [\n",
|
|
|
+ " {\n",
|
|
|
+ " \"text\": \" Sure! Here's the SQL query for the given input question:\\n\\nSELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\",\n",
|
|
|
+ " \"generation_info\": null\n",
|
|
|
+ " }\n",
|
|
|
+ " ]\n",
|
|
|
+ " ],\n",
|
|
|
+ " \"llm_output\": null,\n",
|
|
|
+ " \"run\": null\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 4:chain:LLMChain] [2.68s] Exiting Chain run with output:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"text\": \" Sure! Here's the SQL query for the given input question:\\n\\nSELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\"\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain] [4.82s] Exiting Chain run with output:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"result\": \"Sure! Here's the SQL query for the given input question:\\n\\nSELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\"\n",
|
|
|
+ "}\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# try another question and start seeing why cutting edge may be inconsistent and cut and hurt \n",
|
|
|
+ "answer = db_chain.run(\"Which team is Klay Thompson in?\")\n",
|
|
|
+ "\n",
|
|
|
+ "# the output of the first llm call is correct: SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\"\n",
|
|
|
+ "# the second \"Entering LLM run with input\" also has correct input, containing \"SQLResult: [('Golden State Warriors',)]\"\n",
|
|
|
+ "# and yet the second \"Exiting Chain run with output\" insists generating the result:\n",
|
|
|
+ "# \"Sure! Here's the SQL query for the given input question:\\n\\nSELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\""
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 8,
|
|
|
+ "id": "222169a7",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "Sure! Here's the SQL query for the given input question:\n",
|
|
|
+ "\n",
|
|
|
+ "SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# just to see how unbelievable it truly is\n",
|
|
|
+ "print(answer)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 9,
|
|
|
+ "id": "03d22e2c-24c1-4063-9556-37f48c202a84",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# one way to fix this is to set return_sql to be True, so each db_chain.run call will only make one llm call, and we'll\n",
|
|
|
+ "# manually process each llm output\n",
|
|
|
+ "db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, return_sql=True, \n",
|
|
|
+ " prompt=PromptTemplate(input_variables=[\"input\", \"table_info\"], \n",
|
|
|
+ " template=PROMPT_SUFFIX))"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 10,
|
|
|
+ "id": "9319b42e-0c04-4b60-95b3-6c7758ba9cf8",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain] Entering Chain run with input:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"query\": \"How many unique teams are there?\"\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] Entering Chain run with input:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"input\": \"How many unique teams are there?\\nSQLQuery:\",\n",
|
|
|
+ " \"top_k\": \"5\",\n",
|
|
|
+ " \"dialect\": \"sqlite\",\n",
|
|
|
+ " \"table_info\": \"\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\",\n",
|
|
|
+ " \"stop\": [\n",
|
|
|
+ " \"\\nSQLResult:\"\n",
|
|
|
+ " ]\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[32;1m\u001b[1;3m[llm/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:Replicate] Entering LLM run with input:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"prompts\": [\n",
|
|
|
+ " \"Only use the following tables:\\n\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\\n\\nQuestion: How many unique teams are there?\\nSQLQuery:\"\n",
|
|
|
+ " ]\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[36;1m\u001b[1;3m[llm/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:Replicate] [2.20s] Exiting LLM run with output:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"generations\": [\n",
|
|
|
+ " [\n",
|
|
|
+ " {\n",
|
|
|
+ " \"text\": \" SELECT COUNT(DISTINCT Team) AS num_teams FROM nba_roster;\",\n",
|
|
|
+ " \"generation_info\": null\n",
|
|
|
+ " }\n",
|
|
|
+ " ]\n",
|
|
|
+ " ],\n",
|
|
|
+ " \"llm_output\": null,\n",
|
|
|
+ " \"run\": null\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] [2.20s] Exiting Chain run with output:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"text\": \" SELECT COUNT(DISTINCT Team) AS num_teams FROM nba_roster;\"\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain] [2.20s] Exiting Chain run with output:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"result\": \"SELECT COUNT(DISTINCT Team) AS num_teams FROM nba_roster;\"\n",
|
|
|
+ "}\n",
|
|
|
+ "SELECT COUNT(DISTINCT Team) AS num_teams FROM nba_roster;\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# let's go back to the question that works before to make sure the change above doesn't break it\n",
|
|
|
+ "# because return_sql is True, and system_prompt is set to \"Given an input question, convert it to a SQL query. No pre-amble.\" when \n",
|
|
|
+ "# creating the llm instance, the db_chain's run will make just one Llama call and ask it to return the SQL statement only, \n",
|
|
|
+ "# instead of making two Llama calls and returning the answer it can't do in the above case.\n",
|
|
|
+ "sql = db_chain.run(\"How many unique teams are there?\")\n",
|
|
|
+ "print(sql)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 11,
|
|
|
+ "id": "35cba597-5774-472a-a464-41a2eba062ab",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[{'num_teams': 30}]\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "print(db._execute(sql))\n",
|
|
|
+ "# the right answer is generated"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 12,
|
|
|
+ "id": "fb38a462-e23a-45b3-8379-7900d07751bf",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain] Entering Chain run with input:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"query\": \"Which team is Klay Thompson in?\"\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] Entering Chain run with input:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"input\": \"Which team is Klay Thompson in?\\nSQLQuery:\",\n",
|
|
|
+ " \"top_k\": \"5\",\n",
|
|
|
+ " \"dialect\": \"sqlite\",\n",
|
|
|
+ " \"table_info\": \"\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\",\n",
|
|
|
+ " \"stop\": [\n",
|
|
|
+ " \"\\nSQLResult:\"\n",
|
|
|
+ " ]\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[32;1m\u001b[1;3m[llm/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:Replicate] Entering LLM run with input:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"prompts\": [\n",
|
|
|
+ " \"Only use the following tables:\\n\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\\n\\nQuestion: Which team is Klay Thompson in?\\nSQLQuery:\"\n",
|
|
|
+ " ]\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[36;1m\u001b[1;3m[llm/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:Replicate] [2.11s] Exiting LLM run with output:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"generations\": [\n",
|
|
|
+ " [\n",
|
|
|
+ " {\n",
|
|
|
+ " \"text\": \" SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\",\n",
|
|
|
+ " \"generation_info\": null\n",
|
|
|
+ " }\n",
|
|
|
+ " ]\n",
|
|
|
+ " ],\n",
|
|
|
+ " \"llm_output\": null,\n",
|
|
|
+ " \"run\": null\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] [2.11s] Exiting Chain run with output:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"text\": \" SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\"\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain] [2.11s] Exiting Chain run with output:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"result\": \"SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\"\n",
|
|
|
+ "}\n",
|
|
|
+ "SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# let's try the query that fails\n",
|
|
|
+ "sql = db_chain.run(\"Which team is Klay Thompson in?\")\n",
|
|
|
+ "print(sql)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 13,
|
|
|
+ "id": "6aef6ae7-b14c-46d6-b4a4-3563abbd3391",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[{'Team': 'Golden State Warriors'}]\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# the sql result looks good. let's run it to get the answer\n",
|
|
|
+ "print(db._execute(sql))"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 14,
|
|
|
+ "id": "39ed4bc3",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain] Entering Chain run with input:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"query\": \"What's his salary?\"\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] Entering Chain run with input:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"input\": \"What's his salary?\\nSQLQuery:\",\n",
|
|
|
+ " \"top_k\": \"5\",\n",
|
|
|
+ " \"dialect\": \"sqlite\",\n",
|
|
|
+ " \"table_info\": \"\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\",\n",
|
|
|
+ " \"stop\": [\n",
|
|
|
+ " \"\\nSQLResult:\"\n",
|
|
|
+ " ]\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[32;1m\u001b[1;3m[llm/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:Replicate] Entering LLM run with input:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"prompts\": [\n",
|
|
|
+ " \"Only use the following tables:\\n\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\\n\\nQuestion: What's his salary?\\nSQLQuery:\"\n",
|
|
|
+ " ]\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[36;1m\u001b[1;3m[llm/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:Replicate] [2.73s] Exiting LLM run with output:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"generations\": [\n",
|
|
|
+ " [\n",
|
|
|
+ " {\n",
|
|
|
+ " \"text\": \" Sure! Here is the SQL query based on the given question and table structure:\\n\\nSELECT SALARY FROM nba_roster WHERE NAME = 'LeBron James';\",\n",
|
|
|
+ " \"generation_info\": null\n",
|
|
|
+ " }\n",
|
|
|
+ " ]\n",
|
|
|
+ " ],\n",
|
|
|
+ " \"llm_output\": null,\n",
|
|
|
+ " \"run\": null\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] [2.73s] Exiting Chain run with output:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"text\": \" Sure! Here is the SQL query based on the given question and table structure:\\n\\nSELECT SALARY FROM nba_roster WHERE NAME = 'LeBron James';\"\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain] [2.73s] Exiting Chain run with output:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"result\": \"Sure! Here is the SQL query based on the given question and table structure:\\n\\nSELECT SALARY FROM nba_roster WHERE NAME = 'LeBron James';\"\n",
|
|
|
+ "}\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# how about a follow up question\n",
|
|
|
+ "sql = db_chain.run(\"What's his salary?\")"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 15,
|
|
|
+ "id": "0c305278-29d2-4e88-9b3d-ad67c94ce0f2",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "Sure! Here is the SQL query based on the given question and table structure:\n",
|
|
|
+ "\n",
|
|
|
+ "SELECT SALARY FROM nba_roster WHERE NAME = 'LeBron James';\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# two issues here: the llm output is wordy again, even though we use the system prompt to ask it not to be; \n",
|
|
|
+ "# the other issue makes sense: we didn't pass the context along with the follow-up to llm so it didn't know what \"his\" is.\n",
|
|
|
+ "print(sql)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 16,
|
|
|
+ "id": "670bc858",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "SELECT SALARY FROM nba_roster WHERE NAME = 'LeBron James';\n",
|
|
|
+ "[{'SALARY': '$47,607,350'}]\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# let's fix the first issue by custom parsing, which is typical in real LLM apps\n",
|
|
|
+ "# use regular expression to extract only the SQL SELECT statement\n",
|
|
|
+ "import re\n",
|
|
|
+ "\n",
|
|
|
+ "pattern = r'SELECT[\\s\\S]*?;'\n",
|
|
|
+ "match = re.search(pattern, sql)\n",
|
|
|
+ "if match:\n",
|
|
|
+ " sql = match.group()\n",
|
|
|
+ " print(sql)\n",
|
|
|
+ "print(db._execute(sql))"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 17,
|
|
|
+ "id": "0bfa3d34",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [],
|
|
|
+ "source": [
|
|
|
+ "# now let's try to fix the issue that the context (the previous question and answer) was not sent to LLM along with the new question\n",
|
|
|
+ "# SQLDatabaseChain.from_llm has a parameter \"memory\" which can be set to a ConversationBufferMemory instance, which looks promising.\n",
|
|
|
+ "from langchain.memory import ConversationBufferMemory\n",
|
|
|
+ "\n",
|
|
|
+ "memory = ConversationBufferMemory()\n",
|
|
|
+ "db_chain_memory = SQLDatabaseChain.from_llm(llm, db, memory=memory, \n",
|
|
|
+ " verbose=True, return_sql=True, \n",
|
|
|
+ " prompt=PromptTemplate(input_variables=[\"input\", \"table_info\"], \n",
|
|
|
+ " template=PROMPT_SUFFIX))"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 18,
|
|
|
+ "id": "d12b50e7",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain] Entering Chain run with input:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"query\": \"Which team is Klay Thompson in\",\n",
|
|
|
+ " \"history\": \"\"\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] Entering Chain run with input:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"input\": \"Which team is Klay Thompson in\\nSQLQuery:\",\n",
|
|
|
+ " \"top_k\": \"5\",\n",
|
|
|
+ " \"dialect\": \"sqlite\",\n",
|
|
|
+ " \"table_info\": \"\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\",\n",
|
|
|
+ " \"stop\": [\n",
|
|
|
+ " \"\\nSQLResult:\"\n",
|
|
|
+ " ],\n",
|
|
|
+ " \"history\": \"\"\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[32;1m\u001b[1;3m[llm/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:Replicate] Entering LLM run with input:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"prompts\": [\n",
|
|
|
+ " \"Only use the following tables:\\n\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\\n\\nQuestion: Which team is Klay Thompson in\\nSQLQuery:\"\n",
|
|
|
+ " ]\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[36;1m\u001b[1;3m[llm/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:Replicate] [2.11s] Exiting LLM run with output:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"generations\": [\n",
|
|
|
+ " [\n",
|
|
|
+ " {\n",
|
|
|
+ " \"text\": \" SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\",\n",
|
|
|
+ " \"generation_info\": null\n",
|
|
|
+ " }\n",
|
|
|
+ " ]\n",
|
|
|
+ " ],\n",
|
|
|
+ " \"llm_output\": null,\n",
|
|
|
+ " \"run\": null\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] [2.11s] Exiting Chain run with output:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"text\": \" SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\"\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain] [2.11s] Exiting Chain run with output:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"result\": \"SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\"\n",
|
|
|
+ "}\n",
|
|
|
+ "SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "# use the db_chain_memory to run the original question again\n",
|
|
|
+ "question = \"Which team is Klay Thompson in\"\n",
|
|
|
+ "sql = db_chain_memory.run(question)\n",
|
|
|
+ "print(sql)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 19,
|
|
|
+ "id": "15cb95bf-73fe-404f-9b30-79dd6a53e369",
|
|
|
+ "metadata": {},
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "[{'Team': 'Golden State Warriors'}]\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "answer = db._execute(sql)\n",
|
|
|
+ "print(answer)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "execution_count": 20,
|
|
|
+ "id": "fee65d5f",
|
|
|
+ "metadata": {
|
|
|
+ "scrolled": true
|
|
|
+ },
|
|
|
+ "outputs": [
|
|
|
+ {
|
|
|
+ "name": "stdout",
|
|
|
+ "output_type": "stream",
|
|
|
+ "text": [
|
|
|
+ "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain] Entering Chain run with input:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"query\": \"What's his salary\",\n",
|
|
|
+ " \"history\": \"Human: Which team is Klay Thompson in\\nAI: SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\\nHuman: Which team is Klay Thompson in\\nAI: [{\\\"Team\\\": \\\"Golden State Warriors\\\"}]\"\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] Entering Chain run with input:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"input\": \"What's his salary\\nSQLQuery:\",\n",
|
|
|
+ " \"top_k\": \"5\",\n",
|
|
|
+ " \"dialect\": \"sqlite\",\n",
|
|
|
+ " \"table_info\": \"\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\",\n",
|
|
|
+ " \"stop\": [\n",
|
|
|
+ " \"\\nSQLResult:\"\n",
|
|
|
+ " ],\n",
|
|
|
+ " \"history\": \"Human: Which team is Klay Thompson in\\nAI: SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\\nHuman: Which team is Klay Thompson in\\nAI: [{\\\"Team\\\": \\\"Golden State Warriors\\\"}]\"\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[32;1m\u001b[1;3m[llm/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:Replicate] Entering LLM run with input:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"prompts\": [\n",
|
|
|
+ " \"Only use the following tables:\\n\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\\n\\nQuestion: What's his salary\\nSQLQuery:\"\n",
|
|
|
+ " ]\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[36;1m\u001b[1;3m[llm/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:Replicate] [2.75s] Exiting LLM run with output:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"generations\": [\n",
|
|
|
+ " [\n",
|
|
|
+ " {\n",
|
|
|
+ " \"text\": \" Sure! Here is the SQL query based on the given question and table structure:\\n\\nSELECT SALARY FROM nba_roster WHERE NAME = 'Lebron James';\",\n",
|
|
|
+ " \"generation_info\": null\n",
|
|
|
+ " }\n",
|
|
|
+ " ]\n",
|
|
|
+ " ],\n",
|
|
|
+ " \"llm_output\": null,\n",
|
|
|
+ " \"run\": null\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] [2.75s] Exiting Chain run with output:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"text\": \" Sure! Here is the SQL query based on the given question and table structure:\\n\\nSELECT SALARY FROM nba_roster WHERE NAME = 'Lebron James';\"\n",
|
|
|
+ "}\n",
|
|
|
+ "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain] [2.75s] Exiting Chain run with output:\n",
|
|
|
+ "\u001b[0m{\n",
|
|
|
+ " \"result\": \"Sure! Here is the SQL query based on the given question and table structure:\\n\\nSELECT SALARY FROM nba_roster WHERE NAME = 'Lebron James';\"\n",
|
|
|
+ "}\n",
|
|
|
+ "Sure! Here is the SQL query based on the given question and table structure:\n",
|
|
|
+ "\n",
|
|
|
+ "SELECT SALARY FROM nba_roster WHERE NAME = 'Lebron James';\n"
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "source": [
|
|
|
+ "import json\n",
|
|
|
+ "\n",
|
|
|
+ "memory.save_context({\"input\": question},\n",
|
|
|
+ " {\"output\": json.dumps(answer)})\n",
|
|
|
+ "followup = \"What's his salary\"\n",
|
|
|
+ "followup_answer = db_chain_memory.run(followup)\n",
|
|
|
+ "print(followup_answer)\n",
|
|
|
+ "\n",
|
|
|
+ "# \"Entering Chain run with input\" does show \"history\" including our question: \"Human: Which team is Klay Thompson in\"\n",
|
|
|
+ "# but it doesn't get passed to \"Entering LLM run with input\", so the llm output still is Lebron James."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "id": "d23d47e9",
|
|
|
+ "metadata": {},
|
|
|
+ "source": [
|
|
|
+ "Looks like we hit a [known issue](https://github.com/langchain-ai/langchain/issues/6918#issuecomment-1632932653) with adding memory to SQLDatabaseChain, even after the [PR](https://github.com/langchain-ai/langchain/pull/7546) that's supposed to fix it was merged. We'll leave the notebook using the experimental SQLDatabaseChain as is. To find a better solution to using SQLDatabaseChain, we can:\n",
|
|
|
+ "1. Play with the prompt engineering further;\n",
|
|
|
+ "2. Look into the SQLDatabaseChain's implementation and the use of ConversationBufferMemory to fix the memory issue;\n",
|
|
|
+ "3. Wait till SQLDatabaseChain prompted from the langchain_experimental package to langchain."
|
|
|
+ ]
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "metadata": {
|
|
|
+ "kernelspec": {
|
|
|
+ "display_name": "Python 3 (ipykernel)",
|
|
|
+ "language": "python",
|
|
|
+ "name": "python3"
|
|
|
+ },
|
|
|
+ "language_info": {
|
|
|
+ "codemirror_mode": {
|
|
|
+ "name": "ipython",
|
|
|
+ "version": 3
|
|
|
+ },
|
|
|
+ "file_extension": ".py",
|
|
|
+ "mimetype": "text/x-python",
|
|
|
+ "name": "python",
|
|
|
+ "nbconvert_exporter": "python",
|
|
|
+ "pygments_lexer": "ipython3",
|
|
|
+ "version": "3.8.18"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "nbformat": 4,
|
|
|
+ "nbformat_minor": 5
|
|
|
+}
|