{ "cells": [ { "cell_type": "markdown", "id": "e8cba0b6", "metadata": {}, "source": [ "## This demo shows how to use LangChain's SQLDatabaseChain with Llama2 to chat about structured data stored in a SQL DB. \n", "* As the 2023-24 NBA season is around the corner, we use the NBA roster info saved in a SQLite DB to show you how to ask Llama2 questions about your favorite teams or players. \n", "* Because the SQLDatabaseChain implementation still lives in the langchain_experimental package, you'll run into some of the issues that come with cutting-edge experimental features, and see how we resolve some of them but not others." ] }, { "cell_type": "code", "execution_count": null, "id": "33fb3190-59fb-4edd-82dd-f20f6eab3e47", "metadata": {}, "outputs": [], "source": [ "!pip install langchain replicate langchain_experimental" ] }, { "cell_type": "code", "execution_count": 1, "id": "70d6c906-4fc0-4da5-acd9-8743e90f89e8", "metadata": {}, "outputs": [], "source": [ "from langchain.llms import Replicate\n", "from langchain.prompts import PromptTemplate\n", "from langchain.utilities import SQLDatabase\n", "from langchain_experimental.sql import SQLDatabaseChain" ] }, { "cell_type": "code", "execution_count": 2, "id": "fa4562d3", "metadata": {}, "outputs": [ { "name": "stdin", "output_type": "stream", "text": [ " ········\n" ] } ], "source": [ "from getpass import getpass\n", "import os\n", "\n", "REPLICATE_API_TOKEN = getpass()\n", "os.environ[\"REPLICATE_API_TOKEN\"] = REPLICATE_API_TOKEN" ] }, { "cell_type": "code", "execution_count": 3, "id": "9dcd744c", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Init param `input` is deprecated, please use `model_kwargs` instead.\n" ] } ], "source": [ "llama2_13b_chat = \"meta/llama-2-13b-chat:f4e2de70d66816a838a89eeeb621910adffb0dd0baba3976c96980970978018d\"\n", "\n", "# the reason we set system_prompt below is to ask Llama to generate only the SQL statement, instead of being 
wordy and adding something like\n", "# \"Sure! Here's the SQL query for the given input question: \" before the SQL query; otherwise custom parsing will be needed.\n", "llm = Replicate(\n", " model=llama2_13b_chat,\n", " input={\"temperature\": 0.01, \"max_length\": 500, \"top_p\": 1, \"system_prompt\": \"Given an input question, convert it to a SQL query. No pre-amble.\"},\n", ")" ] }, { "cell_type": "code", "execution_count": 4, "id": "70776558", "metadata": {}, "outputs": [], "source": [ "db = SQLDatabase.from_uri(\"sqlite:///nba_roster.db\", sample_rows_in_table_info= 0)\n", "\n", "# use the default sqlite prompt defined in \n", "# https://github.com/langchain-ai/langchain/blob/33eb5f8300cd21c91a2f8d10c62197637931fa0a/libs/langchain/langchain/chains/sql_database/prompt.py#L211\n", "# db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True)\n", "\n", "# customize the default sqlite prompt defined in the link above\n", "PROMPT_SUFFIX = \"\"\"\n", "Only use the following tables:\n", "{table_info}\n", "\n", "Question: {input}\"\"\"\n", "\n", "# if return_sql is set to False, then calling db_chain.run will make two calls to llm: first get the SQL query statement, \n", "# then pass the result of running the SQL query along with the original question and the query to llm for NL answer\n", "db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, return_sql=False, \n", " prompt=PromptTemplate(input_variables=[\"input\", \"table_info\"], \n", " template=PROMPT_SUFFIX))" ] }, { "cell_type": "code", "execution_count": 5, "id": "239fb5a6", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain] Entering Chain run with input:\n", "\u001b[0m{\n", " \"query\": \"How many unique teams are there?\"\n", "}\n", "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] Entering Chain run with input:\n", "\u001b[0m{\n", " 
\"input\": \"How many unique teams are there?\\nSQLQuery:\",\n", " \"top_k\": \"5\",\n", " \"dialect\": \"sqlite\",\n", " \"table_info\": \"\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\",\n", " \"stop\": [\n", " \"\\nSQLResult:\"\n", " ]\n", "}\n", "\u001b[32;1m\u001b[1;3m[llm/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:Replicate] Entering LLM run with input:\n", "\u001b[0m{\n", " \"prompts\": [\n", " \"Only use the following tables:\\n\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\\n\\nQuestion: How many unique teams are there?\\nSQLQuery:\"\n", " ]\n", "}\n", "\u001b[36;1m\u001b[1;3m[llm/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:Replicate] [2.63s] Exiting LLM run with output:\n", "\u001b[0m{\n", " \"generations\": [\n", " [\n", " {\n", " \"text\": \" SELECT COUNT(DISTINCT Team) AS num_teams FROM nba_roster;\",\n", " \"generation_info\": null\n", " }\n", " ]\n", " ],\n", " \"llm_output\": null,\n", " \"run\": null\n", "}\n", "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] [2.63s] Exiting Chain run with output:\n", "\u001b[0m{\n", " \"text\": \" SELECT COUNT(DISTINCT Team) AS num_teams FROM nba_roster;\"\n", "}\n", "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 4:chain:LLMChain] Entering Chain run with input:\n", "\u001b[0m{\n", " \"input\": \"How many unique teams are there?\\nSQLQuery:SELECT COUNT(DISTINCT Team) AS num_teams FROM nba_roster;\\nSQLResult: [(30,)]\\nAnswer:\",\n", " 
\"top_k\": \"5\",\n", " \"dialect\": \"sqlite\",\n", " \"table_info\": \"\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\",\n", " \"stop\": [\n", " \"\\nSQLResult:\"\n", " ]\n", "}\n", "\u001b[32;1m\u001b[1;3m[llm/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 4:chain:LLMChain > 5:llm:Replicate] Entering LLM run with input:\n", "\u001b[0m{\n", " \"prompts\": [\n", " \"Only use the following tables:\\n\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\\n\\nQuestion: How many unique teams are there?\\nSQLQuery:SELECT COUNT(DISTINCT Team) AS num_teams FROM nba_roster;\\nSQLResult: [(30,)]\\nAnswer:\"\n", " ]\n", "}\n", "\u001b[36;1m\u001b[1;3m[llm/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 4:chain:LLMChain > 5:llm:Replicate] [2.06s] Exiting LLM run with output:\n", "\u001b[0m{\n", " \"generations\": [\n", " [\n", " {\n", " \"text\": \" Sure! Here's the answer to your question:\\n\\nThere are 30 unique teams in the NBA.\",\n", " \"generation_info\": null\n", " }\n", " ]\n", " ],\n", " \"llm_output\": null,\n", " \"run\": null\n", "}\n", "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 4:chain:LLMChain] [2.06s] Exiting Chain run with output:\n", "\u001b[0m{\n", " \"text\": \" Sure! Here's the answer to your question:\\n\\nThere are 30 unique teams in the NBA.\"\n", "}\n", "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain] [4.70s] Exiting Chain run with output:\n", "\u001b[0m{\n", " \"result\": \"Sure! 
Here's the answer to your question:\\n\\nThere are 30 unique teams in the NBA.\"\n", "}\n" ] } ], "source": [ "# turn on the LangChain debug to see what exactly happens behind the scenes in terms of the inputs and outputs of calling Llama\n", "import langchain\n", "langchain.debug = True\n", "answer = db_chain.run(\"How many unique teams are there?\")\n", "\n", "# Two Llama calls happen below \n", "# Note the difference of the first and second inputs for the Llama calls - before \"llm:Replicate] Entering LLM run with input\"\n", "# About the output, for the first llm call, both \"llm:Replicate] [2.10s] Exiting LLM run with output\" and \n", "# \"Exiting Chain run with output\" return correctly the text of SQL statement \"SELECT COUNT(DISTINCT Team) AS num_teams FROM nba_roster;\"\n", "# In the second llm call, \"Exiting LLM run with output\" also shows the correct result:\n", "# \"Sure! Here's the answer to your question:\\n\\nThere are 30 unique teams in the NBA.\"" ] }, { "cell_type": "code", "execution_count": 6, "id": "eef91aea", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sure! 
Here's the answer to your question:\n", "\n", "There are 30 unique teams in the NBA.\n" ] } ], "source": [ "# just show the answer without all the debug info \n", "print(answer)" ] }, { "cell_type": "code", "execution_count": 7, "id": "e3b444aa", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain] Entering Chain run with input:\n", "\u001b[0m{\n", " \"query\": \"Which team is Klay Thompson in?\"\n", "}\n", "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] Entering Chain run with input:\n", "\u001b[0m{\n", " \"input\": \"Which team is Klay Thompson in?\\nSQLQuery:\",\n", " \"top_k\": \"5\",\n", " \"dialect\": \"sqlite\",\n", " \"table_info\": \"\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\",\n", " \"stop\": [\n", " \"\\nSQLResult:\"\n", " ]\n", "}\n", "\u001b[32;1m\u001b[1;3m[llm/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:Replicate] Entering LLM run with input:\n", "\u001b[0m{\n", " \"prompts\": [\n", " \"Only use the following tables:\\n\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\\n\\nQuestion: Which team is Klay Thompson in?\\nSQLQuery:\"\n", " ]\n", "}\n", "\u001b[36;1m\u001b[1;3m[llm/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:Replicate] [2.13s] Exiting LLM run with output:\n", "\u001b[0m{\n", " \"generations\": [\n", " [\n", " {\n", " \"text\": \" SELECT Team FROM nba_roster WHERE NAME = 'Klay 
Thompson';\",\n", " \"generation_info\": null\n", " }\n", " ]\n", " ],\n", " \"llm_output\": null,\n", " \"run\": null\n", "}\n", "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] [2.14s] Exiting Chain run with output:\n", "\u001b[0m{\n", " \"text\": \" SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\"\n", "}\n", "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 4:chain:LLMChain] Entering Chain run with input:\n", "\u001b[0m{\n", " \"input\": \"Which team is Klay Thompson in?\\nSQLQuery:SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\\nSQLResult: [('Golden State Warriors',)]\\nAnswer:\",\n", " \"top_k\": \"5\",\n", " \"dialect\": \"sqlite\",\n", " \"table_info\": \"\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\",\n", " \"stop\": [\n", " \"\\nSQLResult:\"\n", " ]\n", "}\n", "\u001b[32;1m\u001b[1;3m[llm/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 4:chain:LLMChain > 5:llm:Replicate] Entering LLM run with input:\n", "\u001b[0m{\n", " \"prompts\": [\n", " \"Only use the following tables:\\n\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\\n\\nQuestion: Which team is Klay Thompson in?\\nSQLQuery:SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\\nSQLResult: [('Golden State Warriors',)]\\nAnswer:\"\n", " ]\n", "}\n", "\u001b[36;1m\u001b[1;3m[llm/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 4:chain:LLMChain > 5:llm:Replicate] [2.68s] Exiting LLM run with output:\n", "\u001b[0m{\n", " \"generations\": [\n", " [\n", " {\n", " 
\"text\": \" Sure! Here's the SQL query for the given input question:\\n\\nSELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\",\n", " \"generation_info\": null\n", " }\n", " ]\n", " ],\n", " \"llm_output\": null,\n", " \"run\": null\n", "}\n", "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 4:chain:LLMChain] [2.68s] Exiting Chain run with output:\n", "\u001b[0m{\n", " \"text\": \" Sure! Here's the SQL query for the given input question:\\n\\nSELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\"\n", "}\n", "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain] [4.82s] Exiting Chain run with output:\n", "\u001b[0m{\n", " \"result\": \"Sure! Here's the SQL query for the given input question:\\n\\nSELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\"\n", "}\n" ] } ], "source": [ "# try another question and start seeing why cutting edge may be inconsistent and cut and hurt \n", "answer = db_chain.run(\"Which team is Klay Thompson in?\")\n", "\n", "# the output of the first llm call is correct: SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\"\n", "# the second \"Entering LLM run with input\" also has correct input, containing \"SQLResult: [('Golden State Warriors',)]\"\n", "# and yet the second \"Exiting Chain run with output\" insists generating the result:\n", "# \"Sure! Here's the SQL query for the given input question:\\n\\nSELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\"" ] }, { "cell_type": "code", "execution_count": 8, "id": "222169a7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sure! 
Here's the SQL query for the given input question:\n", "\n", "SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\n" ] } ], "source": [ "# just to see how unbelievable it truly is\n", "print(answer)" ] }, { "cell_type": "code", "execution_count": 9, "id": "03d22e2c-24c1-4063-9556-37f48c202a84", "metadata": {}, "outputs": [], "source": [ "# one way to fix this is to set return_sql to be True, so each db_chain.run call will only make one llm call, and we'll\n", "# manually process each llm output\n", "db_chain = SQLDatabaseChain.from_llm(llm, db, verbose=True, return_sql=True, \n", " prompt=PromptTemplate(input_variables=[\"input\", \"table_info\"], \n", " template=PROMPT_SUFFIX))" ] }, { "cell_type": "code", "execution_count": 10, "id": "9319b42e-0c04-4b60-95b3-6c7758ba9cf8", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain] Entering Chain run with input:\n", "\u001b[0m{\n", " \"query\": \"How many unique teams are there?\"\n", "}\n", "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] Entering Chain run with input:\n", "\u001b[0m{\n", " \"input\": \"How many unique teams are there?\\nSQLQuery:\",\n", " \"top_k\": \"5\",\n", " \"dialect\": \"sqlite\",\n", " \"table_info\": \"\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\",\n", " \"stop\": [\n", " \"\\nSQLResult:\"\n", " ]\n", "}\n", "\u001b[32;1m\u001b[1;3m[llm/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:Replicate] Entering LLM run with input:\n", "\u001b[0m{\n", " \"prompts\": [\n", " \"Only use the following tables:\\n\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, 
\\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\\n\\nQuestion: How many unique teams are there?\\nSQLQuery:\"\n", " ]\n", "}\n", "\u001b[36;1m\u001b[1;3m[llm/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:Replicate] [2.20s] Exiting LLM run with output:\n", "\u001b[0m{\n", " \"generations\": [\n", " [\n", " {\n", " \"text\": \" SELECT COUNT(DISTINCT Team) AS num_teams FROM nba_roster;\",\n", " \"generation_info\": null\n", " }\n", " ]\n", " ],\n", " \"llm_output\": null,\n", " \"run\": null\n", "}\n", "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] [2.20s] Exiting Chain run with output:\n", "\u001b[0m{\n", " \"text\": \" SELECT COUNT(DISTINCT Team) AS num_teams FROM nba_roster;\"\n", "}\n", "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain] [2.20s] Exiting Chain run with output:\n", "\u001b[0m{\n", " \"result\": \"SELECT COUNT(DISTINCT Team) AS num_teams FROM nba_roster;\"\n", "}\n", "SELECT COUNT(DISTINCT Team) AS num_teams FROM nba_roster;\n" ] } ], "source": [ "# let's go back to the question that works before to make sure the change above doesn't break it\n", "# because return_sql is True, and system_prompt is set to \"Given an input question, convert it to a SQL query. 
No pre-amble.\" when \n", "# creating the llm instance, the db_chain's run will make just one Llama call and ask it to return the SQL statement only, \n", "# instead of making two Llama calls and failing to return a natural-language answer, as in the case above.\n", "sql = db_chain.run(\"How many unique teams are there?\")\n", "print(sql)" ] }, { "cell_type": "code", "execution_count": 11, "id": "35cba597-5774-472a-a464-41a2eba062ab", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[{'num_teams': 30}]\n" ] } ], "source": [ "print(db._execute(sql))\n", "# the right answer is generated" ] }, { "cell_type": "code", "execution_count": 12, "id": "fb38a462-e23a-45b3-8379-7900d07751bf", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain] Entering Chain run with input:\n", "\u001b[0m{\n", " \"query\": \"Which team is Klay Thompson in?\"\n", "}\n", "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] Entering Chain run with input:\n", "\u001b[0m{\n", " \"input\": \"Which team is Klay Thompson in?\\nSQLQuery:\",\n", " \"top_k\": \"5\",\n", " \"dialect\": \"sqlite\",\n", " \"table_info\": \"\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\",\n", " \"stop\": [\n", " \"\\nSQLResult:\"\n", " ]\n", "}\n", "\u001b[32;1m\u001b[1;3m[llm/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:Replicate] Entering LLM run with input:\n", "\u001b[0m{\n", " \"prompts\": [\n", " \"Only use the following tables:\\n\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, 
\\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\\n\\nQuestion: Which team is Klay Thompson in?\\nSQLQuery:\"\n", " ]\n", "}\n", "\u001b[36;1m\u001b[1;3m[llm/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:Replicate] [2.11s] Exiting LLM run with output:\n", "\u001b[0m{\n", " \"generations\": [\n", " [\n", " {\n", " \"text\": \" SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\",\n", " \"generation_info\": null\n", " }\n", " ]\n", " ],\n", " \"llm_output\": null,\n", " \"run\": null\n", "}\n", "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] [2.11s] Exiting Chain run with output:\n", "\u001b[0m{\n", " \"text\": \" SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\"\n", "}\n", "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain] [2.11s] Exiting Chain run with output:\n", "\u001b[0m{\n", " \"result\": \"SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\"\n", "}\n", "SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\n" ] } ], "source": [ "# let's try the query that fails\n", "sql = db_chain.run(\"Which team is Klay Thompson in?\")\n", "print(sql)" ] }, { "cell_type": "code", "execution_count": 13, "id": "6aef6ae7-b14c-46d6-b4a4-3563abbd3391", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[{'Team': 'Golden State Warriors'}]\n" ] } ], "source": [ "# the sql result looks good. 
let's run it to get the answer\n", "print(db._execute(sql))" ] }, { "cell_type": "code", "execution_count": 14, "id": "39ed4bc3", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain] Entering Chain run with input:\n", "\u001b[0m{\n", " \"query\": \"What's his salary?\"\n", "}\n", "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] Entering Chain run with input:\n", "\u001b[0m{\n", " \"input\": \"What's his salary?\\nSQLQuery:\",\n", " \"top_k\": \"5\",\n", " \"dialect\": \"sqlite\",\n", " \"table_info\": \"\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\",\n", " \"stop\": [\n", " \"\\nSQLResult:\"\n", " ]\n", "}\n", "\u001b[32;1m\u001b[1;3m[llm/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:Replicate] Entering LLM run with input:\n", "\u001b[0m{\n", " \"prompts\": [\n", " \"Only use the following tables:\\n\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\\n\\nQuestion: What's his salary?\\nSQLQuery:\"\n", " ]\n", "}\n", "\u001b[36;1m\u001b[1;3m[llm/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:Replicate] [2.73s] Exiting LLM run with output:\n", "\u001b[0m{\n", " \"generations\": [\n", " [\n", " {\n", " \"text\": \" Sure! 
Here is the SQL query based on the given question and table structure:\\n\\nSELECT SALARY FROM nba_roster WHERE NAME = 'LeBron James';\",\n", " \"generation_info\": null\n", " }\n", " ]\n", " ],\n", " \"llm_output\": null,\n", " \"run\": null\n", "}\n", "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] [2.73s] Exiting Chain run with output:\n", "\u001b[0m{\n", " \"text\": \" Sure! Here is the SQL query based on the given question and table structure:\\n\\nSELECT SALARY FROM nba_roster WHERE NAME = 'LeBron James';\"\n", "}\n", "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain] [2.73s] Exiting Chain run with output:\n", "\u001b[0m{\n", " \"result\": \"Sure! Here is the SQL query based on the given question and table structure:\\n\\nSELECT SALARY FROM nba_roster WHERE NAME = 'LeBron James';\"\n", "}\n" ] } ], "source": [ "# how about a follow up question\n", "sql = db_chain.run(\"What's his salary?\")" ] }, { "cell_type": "code", "execution_count": 15, "id": "0c305278-29d2-4e88-9b3d-ad67c94ce0f2", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Sure! 
Here is the SQL query based on the given question and table structure:\n", "\n", "SELECT SALARY FROM nba_roster WHERE NAME = 'LeBron James';\n" ] } ], "source": [ "# two issues here: the llm output is wordy again, even though we use the system prompt to ask it not to be; \n", "# the other issue makes sense: we didn't pass the context along with the follow-up to the llm, so it didn't know who \"his\" refers to.\n", "print(sql)" ] }, { "cell_type": "code", "execution_count": 16, "id": "670bc858", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "SELECT SALARY FROM nba_roster WHERE NAME = 'LeBron James';\n", "[{'SALARY': '$47,607,350'}]\n" ] } ], "source": [ "# let's fix the first issue by custom parsing, which is typical in real LLM apps\n", "# use a regular expression to extract only the SQL SELECT statement\n", "import re\n", "\n", "pattern = r'SELECT[\\s\\S]*?;'\n", "match = re.search(pattern, sql)\n", "if match:\n", " sql = match.group()\n", " print(sql)\n", "print(db._execute(sql))" ] }, { "cell_type": "code", "execution_count": 17, "id": "0bfa3d34", "metadata": {}, "outputs": [], "source": [ "# now let's try to fix the issue that the context (the previous question and answer) was not sent to the LLM along with the new question\n", "# SQLDatabaseChain.from_llm has a parameter \"memory\" which can be set to a ConversationBufferMemory instance, which looks promising.\n", "from langchain.memory import ConversationBufferMemory\n", "\n", "memory = ConversationBufferMemory()\n", "db_chain_memory = SQLDatabaseChain.from_llm(llm, db, memory=memory, \n", " verbose=True, return_sql=True, \n", " prompt=PromptTemplate(input_variables=[\"input\", \"table_info\"], \n", " template=PROMPT_SUFFIX))" ] }, { "cell_type": "code", "execution_count": 18, "id": "d12b50e7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain] Entering Chain run with 
input:\n", "\u001b[0m{\n", " \"query\": \"Which team is Klay Thompson in\",\n", " \"history\": \"\"\n", "}\n", "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] Entering Chain run with input:\n", "\u001b[0m{\n", " \"input\": \"Which team is Klay Thompson in\\nSQLQuery:\",\n", " \"top_k\": \"5\",\n", " \"dialect\": \"sqlite\",\n", " \"table_info\": \"\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\",\n", " \"stop\": [\n", " \"\\nSQLResult:\"\n", " ],\n", " \"history\": \"\"\n", "}\n", "\u001b[32;1m\u001b[1;3m[llm/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:Replicate] Entering LLM run with input:\n", "\u001b[0m{\n", " \"prompts\": [\n", " \"Only use the following tables:\\n\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\\n\\nQuestion: Which team is Klay Thompson in\\nSQLQuery:\"\n", " ]\n", "}\n", "\u001b[36;1m\u001b[1;3m[llm/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:Replicate] [2.11s] Exiting LLM run with output:\n", "\u001b[0m{\n", " \"generations\": [\n", " [\n", " {\n", " \"text\": \" SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\",\n", " \"generation_info\": null\n", " }\n", " ]\n", " ],\n", " \"llm_output\": null,\n", " \"run\": null\n", "}\n", "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] [2.11s] Exiting Chain run with output:\n", "\u001b[0m{\n", " \"text\": \" SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\"\n", "}\n", 
"\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain] [2.11s] Exiting Chain run with output:\n", "\u001b[0m{\n", " \"result\": \"SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\"\n", "}\n", "SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\n" ] } ], "source": [ "# use the db_chain_memory to run the original question again\n", "question = \"Which team is Klay Thompson in\"\n", "sql = db_chain_memory.run(question)\n", "print(sql)" ] }, { "cell_type": "code", "execution_count": 19, "id": "15cb95bf-73fe-404f-9b30-79dd6a53e369", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[{'Team': 'Golden State Warriors'}]\n" ] } ], "source": [ "answer = db._execute(sql)\n", "print(answer)" ] }, { "cell_type": "code", "execution_count": 20, "id": "fee65d5f", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain] Entering Chain run with input:\n", "\u001b[0m{\n", " \"query\": \"What's his salary\",\n", " \"history\": \"Human: Which team is Klay Thompson in\\nAI: SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\\nHuman: Which team is Klay Thompson in\\nAI: [{\\\"Team\\\": \\\"Golden State Warriors\\\"}]\"\n", "}\n", "\u001b[32;1m\u001b[1;3m[chain/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] Entering Chain run with input:\n", "\u001b[0m{\n", " \"input\": \"What's his salary\\nSQLQuery:\",\n", " \"top_k\": \"5\",\n", " \"dialect\": \"sqlite\",\n", " \"table_info\": \"\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\",\n", " \"stop\": [\n", " \"\\nSQLResult:\"\n", " ],\n", " \"history\": \"Human: Which team is Klay Thompson in\\nAI: 
SELECT Team FROM nba_roster WHERE NAME = 'Klay Thompson';\\nHuman: Which team is Klay Thompson in\\nAI: [{\\\"Team\\\": \\\"Golden State Warriors\\\"}]\"\n", "}\n", "\u001b[32;1m\u001b[1;3m[llm/start]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:Replicate] Entering LLM run with input:\n", "\u001b[0m{\n", " \"prompts\": [\n", " \"Only use the following tables:\\n\\nCREATE TABLE nba_roster (\\n\\t\\\"Team\\\" TEXT, \\n\\t\\\"NAME\\\" TEXT, \\n\\t\\\"Jersey\\\" TEXT, \\n\\t\\\"POS\\\" TEXT, \\n\\t\\\"AGE\\\" INTEGER, \\n\\t\\\"HT\\\" TEXT, \\n\\t\\\"WT\\\" TEXT, \\n\\t\\\"COLLEGE\\\" TEXT, \\n\\t\\\"SALARY\\\" TEXT\\n)\\n\\nQuestion: What's his salary\\nSQLQuery:\"\n", " ]\n", "}\n", "\u001b[36;1m\u001b[1;3m[llm/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain > 3:llm:Replicate] [2.75s] Exiting LLM run with output:\n", "\u001b[0m{\n", " \"generations\": [\n", " [\n", " {\n", " \"text\": \" Sure! Here is the SQL query based on the given question and table structure:\\n\\nSELECT SALARY FROM nba_roster WHERE NAME = 'Lebron James';\",\n", " \"generation_info\": null\n", " }\n", " ]\n", " ],\n", " \"llm_output\": null,\n", " \"run\": null\n", "}\n", "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain > 2:chain:LLMChain] [2.75s] Exiting Chain run with output:\n", "\u001b[0m{\n", " \"text\": \" Sure! Here is the SQL query based on the given question and table structure:\\n\\nSELECT SALARY FROM nba_roster WHERE NAME = 'Lebron James';\"\n", "}\n", "\u001b[36;1m\u001b[1;3m[chain/end]\u001b[0m \u001b[1m[1:chain:SQLDatabaseChain] [2.75s] Exiting Chain run with output:\n", "\u001b[0m{\n", " \"result\": \"Sure! Here is the SQL query based on the given question and table structure:\\n\\nSELECT SALARY FROM nba_roster WHERE NAME = 'Lebron James';\"\n", "}\n", "Sure! 
Here is the SQL query based on the given question and table structure:\n", "\n", "SELECT SALARY FROM nba_roster WHERE NAME = 'Lebron James';\n" ] } ], "source": [ "import json\n", "\n", "memory.save_context({\"input\": question},\n", " {\"output\": json.dumps(answer)})\n", "followup = \"What's his salary\"\n", "followup_answer = db_chain_memory.run(followup)\n", "print(followup_answer)\n", "\n", "# \"Entering Chain run with input\" does show \"history\" including our question: \"Human: Which team is Klay Thompson in\"\n", "# but it doesn't get passed to \"Entering LLM run with input\", so the llm output still refers to Lebron James." ] }, { "cell_type": "markdown", "id": "d23d47e9", "metadata": {}, "source": [ "It looks like we hit a [known issue](https://github.com/langchain-ai/langchain/issues/6918#issuecomment-1632932653) with adding memory to SQLDatabaseChain, even after the [PR](https://github.com/langchain-ai/langchain/pull/7546) that was supposed to fix it was merged. We'll leave the notebook using the experimental SQLDatabaseChain as is. To find a better solution, we can:\n", "1. Experiment further with prompt engineering;\n", "2. Look into SQLDatabaseChain's implementation and its use of ConversationBufferMemory to fix the memory issue;\n", "3. Wait until SQLDatabaseChain is promoted from the langchain_experimental package to langchain." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.18" } }, "nbformat": 4, "nbformat_minor": 5 }