SQL chat example

This commit is contained in:
Sam Partee 2024-04-29 21:28:58 -07:00
parent 8b62119cb2
commit 41b783ef2e
4 changed files with 324 additions and 15 deletions

188
examples/sql-chat/agent.py Normal file
View file

@ -0,0 +1,188 @@
import httpx
import json
import openai
from typing import Any, Dict, List, Optional
class Toolchain:
    """Client for a tool server paired with an OpenAI model that predicts tool arguments.

    On construction the OpenAI function spec of every entry in
    ``available_tools`` is fetched from the tool server and cached in
    ``self.tools``.
    """

    # Tool name -> tool-server endpoint that executes the tool.
    available_tools = {
        "query_sql": "/tool/query/query_sql",
        "list_data_sources": "/tool/query/list_data_sources",
        "get_data_schema": "/tool/query/get_data_schema"
    }

    def __init__(self, base_url: str, openai_api_key: str, model: str = "gpt-4-turbo"):
        """
        :param base_url: Base URL of the tool server; endpoints are appended to it verbatim.
        :param openai_api_key: API key handed to the OpenAI client.
        :param model: Name of the OpenAI chat model used to predict tool arguments.
        """
        self.base_url = base_url
        self.client = httpx.Client()
        self.openai_client = openai.Client(api_key=openai_api_key)
        self.model = model
        self.tools = self.__collect_tool_specs()

    def close(self) -> None:
        """Close the underlying HTTP client and release its connections."""
        self.client.close()

    def __collect_tool_specs(self) -> Dict[str, Any]:
        """Fetch the OpenAI function spec of every available tool.

        :return: Mapping of tool name to the ``data`` field of the server's
            response (an empty dict when the field is absent).
        """
        return {
            tool_name: self.call_api(
                "GET", "/api/v1/tools/oai_function", params={"tool_name": tool_name}
            ).get("data", {})
            for tool_name in self.available_tools
        }

    def call_api(
        self,
        method: str,
        endpoint: str,
        params: Optional[dict] = None,
        data: Optional[dict] = None,
        json_data: Optional[dict] = None,
    ) -> Dict[str, Any]:
        """Send one request to the tool server and return the decoded JSON body.

        :param method: HTTP method, e.g. ``"GET"`` or ``"POST"``.
        :param endpoint: Path appended to ``base_url``.
        :param params: Query-string parameters; defaults to an empty dict.
        :param data: Form data; defaults to an empty dict.
        :param json_data: JSON body; defaults to an empty dict.
        :return: The response body parsed as JSON.
        """
        # Build fresh dicts per call: a mutable default argument would be
        # shared (and could be polluted) across every call.
        params = {} if params is None else params
        data = {} if data is None else data
        json_data = {} if json_data is None else json_data

        url = f"{self.base_url}{endpoint}"
        response = self.client.request(method, url, params=params, json=json_data, data=data)
        try:
            response.raise_for_status()
        except httpx.HTTPStatusError as e:
            # Best effort on purpose: the failure is reported, but the body is
            # still returned so callers can inspect its "code" / "message".
            print(f"HTTP error: {e}")
        return response.json()

    @staticmethod
    def _parse_tool_spec(func_spec: Any) -> Any:
        """Return the tool spec as a Python object, decoding it when it is a JSON string."""
        if isinstance(func_spec, (str, bytes, bytearray)):
            return json.loads(func_spec)
        return func_spec

    @staticmethod
    def _first_tool_arguments(completion: Any, tool_name: str) -> str:
        """Return the arguments of the first tool call in ``completion``.

        :raises ValueError: If the completion has no choices or the model
            answered without requesting a tool call.
        """
        choices = getattr(completion, "choices", None)
        if not choices:
            raise ValueError(f"Model returned no choices for tool '{tool_name}'.")
        # With tool_choice="auto" the model may answer in plain text, in which
        # case there are no tool calls to read arguments from.
        tool_calls = choices[0].message.tool_calls
        if not tool_calls:
            raise ValueError(f"Model did not request a call to tool '{tool_name}'.")
        return tool_calls[0].function.arguments

    def get_tool_args(self, tool_name: str, messages: List[Dict[str, str]]) -> str:
        """
        Looks up the cached spec of a tool and asks the OpenAI model to predict
        the arguments for calling it, given the conversation so far.
        :param tool_name: The name of the tool to predict arguments for.
        :param messages: A list of messages to provide to the model.
        :return: The predicted arguments as returned by the model (a JSON string).
        :raises ValueError: If the tool is unknown or the model requests no tool call.
        """
        func_spec = self.tools.get(tool_name, {})
        if not func_spec:
            raise ValueError(f"Tool '{tool_name}' not found in available tools.")
        tool = self._parse_tool_spec(func_spec)
        # Call the OpenAI model with the tool and messages, honouring the
        # model chosen at construction time.
        completion = self.openai_client.chat.completions.create(
            model=self.model,
            messages=messages,
            tools=[tool],
            tool_choice="auto"
        )
        predicted_args = self._first_tool_arguments(completion, tool_name)
        print(predicted_args)
        print("-----")
        return predicted_args

    def execute_tool(self, tool_name: str, tool_args: Optional[Dict[str, Any]]) -> Dict[str, Any]:
        """
        Executes a tool by POSTing its arguments to the tool's endpoint.
        :param tool_name: The name of the tool to execute; must be a key of ``available_tools``.
        :param tool_args: Arguments sent as the JSON body of the request.
        :return: The result of the tool
        """
        endpoint = self.available_tools[tool_name]
        return self.call_api("POST", endpoint, json_data=tool_args)
from pydantic import BaseModel
from typing import List, Dict
from textwrap import dedent
class Agent:
    """Answers natural-language questions about one data source via the ``query_sql`` tool.

    The agent lists the tool server's data sources on construction, is pointed
    at one of them with ``set_source``, and then turns each user query into a
    SQL tool call whose result data is fetched and returned.
    """

    # System prompt template; ``{schema}`` and ``{data_id}`` are filled in per query.
    prompt = dedent("""\
        Given a user query and a schema of a table, generate the SQL query to answer the user query.
        The generated SQL query should only refer to columns in the table schema list below. The table schema is as follows:
        {schema}
        The data_id of this source is: {data_id}
        """)

    def __init__(self, toolchain: "Toolchain"):
        """
        :param toolchain: Toolchain used for every tool-server and model call.
        """
        self.toolchain = toolchain
        self.data_sources = self.__get_data_sources()
        self._source = None
        self._data_schema = None

    def set_source(self, source: str):
        """Select the data source to query and cache its schema.

        :param source: File name of the data source, as listed by the tool server.
        :raises ValueError: If the source is unknown or its schema cannot be retrieved.
        """
        if source not in self.data_sources:
            raise ValueError(f"Data source '{source}' not found.")
        data_id = self.data_sources[source]
        # get the schema
        schema = self.toolchain.call_api("POST", "/tool/query/get_data_schema", json_data={"data_id": data_id})
        # Mirror the status check done in query(): without it an error payload
        # would be stored and later pasted into the prompt as the "schema".
        if isinstance(schema, dict) and schema.get("code", 200) != 200:
            raise ValueError(f"Error retrieving schema: {schema.get('message')}")
        # Only commit the new state once the schema was fetched successfully.
        self._source = source
        self._data_schema = schema

    def get_source(self) -> Optional[str]:
        """Return the currently selected data source name, or ``None`` if none is set."""
        return self._source

    def __get_data_sources(self) -> Dict[str, Any]:
        """Ask the tool server for its data sources.

        :return: Mapping of each source's ``file_name`` to its id.
        """
        response = self.toolchain.call_api("POST", "/tool/query/list_data_sources")
        return {
            source_data["file_name"]: _id
            for _id, source_data in response["data"]["result"].items()
        }

    @staticmethod
    def _normalize_params(params: Any) -> List[Any]:
        """Coerce model-predicted SQL params into a positional list.

        A dict contributes its values in order, a string becomes a one-element
        list, and a list is returned unchanged.

        :raises ValueError: For any other type.
        """
        if isinstance(params, dict):
            return list(params.values())
        if isinstance(params, str):
            return [params]
        if isinstance(params, list):
            return params
        raise ValueError(f"Invalid params type: {type(params)}")

    def query(self, user_query: str) -> str:
        """Answer a user query against the selected data source.

        :param user_query: The question, in natural language.
        :return: The ``json_blob`` of the data object produced by the SQL query.
        :raises ValueError: If no source is set, the predicted params have an
            unsupported type, or the tool server reports a non-200 code.
        """
        if not self._source:
            raise ValueError("Data source not set. Please set a data source before querying.")
        prompt = self.prompt.format(schema=self._data_schema, data_id=self.data_sources[self._source])
        # Prepare the input message for the OpenAI model
        messages = [
            {"role": "system", "content": prompt},
            {"role": "user", "content": user_query},
        ]
        tool_args = self.toolchain.get_tool_args("query_sql", messages)
        args = json.loads(tool_args)
        # The model may emit params as a dict or a bare string; the tool wants a list.
        params = args.get("params", [])
        if params:
            args["params"] = self._normalize_params(params)
        response = self.toolchain.execute_tool("query_sql", args)
        if response["code"] != 200:
            raise ValueError(f"Error executing tool: {response['message']}")
        # The tool stores its output as a new data object; fetch it by id.
        data_id = response["data"]["result"]["data_id"]
        data_response = self.toolchain.call_api("GET", f"/api/v1/data/object/{data_id}")
        if data_response["code"] != 200:
            raise ValueError(f"Error retrieving data: {data_response['message']}")
        return data_response["data"]["json_blob"]
from pydantic import BaseModel, Field
from enum import Enum
class ToolNode:
    """Placeholder for a single node of a tool flow; it defines no behaviour yet."""
class ToolFlow:
    """Placeholder for a flow of tools; the constructor is a stub."""

    def __init__(self, name, description, sources):
        """Accept a name, description and sources; none of them is stored yet."""
""" # Example usage:
oai_key = os.environ["OPENAI_API_KEY"]  # needs `import os`; never hard-code or commit API keys
toolchain = Toolchain(base_url="http://localhost:8000", model="gpt-4-turbo", openai_api_key=oai_key)
agent = Agent(toolchain)
agent.set_source("users_db")
while True:
user_query = input("Enter a query: ")
result = agent.query(user_query)
print(result)
"""

103
examples/sql-chat/main.py Normal file
View file

@ -0,0 +1,103 @@
import openai
import os

# Credentials must never be committed to source control: read the OpenAI key
# from the environment instead. None is passed through when the variable is
# unset (see the OpenAI client docs for how a missing key is handled).
oai_key = os.environ.get("OPENAI_API_KEY")
import json
import logging
import subprocess
import sys
import time
import traceback
import os
import pandas as pd
import streamlit as st
from pydantic import BaseModel
from streamlit_chat import message
from agent import Agent, Toolchain
@st.cache_resource()
def get_agent():
    """Build the SQL chat agent, bound to the ``users_db`` source.

    Wrapped in ``st.cache_resource`` so the construction is cached by Streamlit.
    """
    sql_agent = Agent(
        Toolchain(
            base_url="http://localhost:8000",
            model="gpt-4-turbo",
            openai_api_key=oai_key,
        )
    )
    sql_agent.set_source("users_db")
    return sql_agent
# Everything below this point is the Streamlit UI.
st.set_page_config(
    page_title="Data Chat",
    page_icon=":robot:",
    layout="wide",
)
st.header("Arcade AI Demo")
def initialize_logger():
    """Send the logger named ``"root"`` to stdout at INFO level.

    Existing handlers on that logger are replaced rather than appended to, so
    calling this more than once leaves exactly one handler attached.

    :return: Always ``True``.
    """
    root_logger = logging.getLogger("root")
    root_logger.setLevel(logging.INFO)
    root_logger.handlers = [logging.StreamHandler(sys.stdout)]
    return True
# Configure logging once per session; the stored value is only a "done" marker
# (initialize_logger always returns True).
if "logger" not in st.session_state:
    st.session_state["logger"] = initialize_logger()

# Two parallel histories: what the user typed and what the agent answered.
for _history_key in ("past", "generated"):
    if _history_key not in st.session_state:
        st.session_state[_history_key] = []

st.subheader("Chat")
chat_container = st.container()
input_container = st.container()
def submit():
    """Handle a submitted chat message.

    Reads and clears the ``input`` field, asks the agent for an answer while a
    spinner is shown, and appends the question and the answer to the session
    histories. If the agent raises, the formatted traceback is recorded as the
    answer instead.
    """
    question = st.session_state["input"]
    st.session_state["input"] = ""
    with st.spinner(text="Wait for Agent..."):
        try:
            answer = get_agent().query(question)
        except Exception:
            # Surface the failure in the chat rather than crashing the app.
            answer = traceback.format_exc()
    st.session_state.past.append(question)
    st.session_state.generated.append(answer)
def get_text():
    """Render the chat text box (``submit`` runs on change) and return its value."""
    return st.text_input("You: ", key="input", on_change=submit)
with input_container:
    user_input = get_text()

if st.session_state["generated"]:
    with chat_container:
        # Render the conversation oldest-first, one user message plus reply per turn.
        for idx, reply in enumerate(st.session_state["generated"]):
            message(st.session_state["past"][idx], is_user=True, key=str(idx) + "_user")
            res = reply
            # A reply that is JSON with a "data" entry is shown as a table when possible.
            try:
                json_res = json.loads(res)["data"]
                print(json_res)
            except Exception:
                json_res = None
            if json_res:
                try:
                    res = pd.DataFrame(json_res)
                except Exception:
                    # Not tabular: fall back to the decoded payload itself.
                    res = json_res
            if isinstance(res, pd.DataFrame):
                st.dataframe(res)
            elif isinstance(res, str):
                st.write(res)
            else:
                st.error("Returned result:")
                st.error(res)

View file

@ -1,4 +1,4 @@
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union, List
import io
from toolserve.sdk.client import list_data, log
@ -15,11 +15,16 @@ async def list_data_sources() -> Dict[str, Dict[str, str]]:
Dict[str, str]: A dictionary mapping data source IDs to their details.
"""
data = await list_data()
return {str(item["id"]): {
"file_name": item["file_name"],
"created_at": item["created_time"],
"updated_at": item["updated_time"]
} for item in data}
partial = {}
for item in data:
details = {
"file_name": item["file_name"],
"created_at": item["created_time"]
}
if "updated_time" in item and item["updated_time"] is not None:
details["updated_at"] = item["updated_time"]
partial[str(item["id"])] = details
return partial
@tool
async def get_data_schema(
@ -35,19 +40,21 @@ async def get_data_schema(
"""
# TODO read in only a few lines
df = await get_df(data_id)
return get_df_info(df)
return get_df_info(df)["schema"]
@tool
async def query_sql(
data_id: Param(int, "id of the data source"),
sql: Param(str, "parameterized SQL query to execute"),
params: Param(Optional[Dict[str, Any]], "parameters to pass to the SQL query") = None,
) -> Param(str, "schema of the data source after executing the query"):
params: Param(Optional[List[Union[str, int]]], "parameters to pass to the SQL query") = None,
) -> Dict[str, Union[int, str]]:
"""Query a data source using SQL
The SQL query should be parameterized with Python's named parameter syntax,
e.g. `SELECT * FROM table WHERE column = :value`.
The SQL query should be parameterized with DuckDB's syntax. For example, to query a
DataFrame named `df` with a parameter `param`, the query should be `SELECT * FROM df WHERE column = ?`.
The list of params should be in order of the parameters in the SQL query.
After the query, a new data source at a new id will be created with the results and
the schema of the data source will be returned.
@ -83,7 +90,7 @@ async def query_sql(
raise RuntimeError(f"Query execution failed: {str(e)}")
def get_df_info(df: pd.DataFrame, data_id: Optional[int]=None) -> str:
def get_df_info(df: pd.DataFrame, data_id: Optional[int]=None) -> Dict[str, Union[int, str]]:
"""
Generate a compact string representation of a DataFrame including the count of columns,
rows, overall size, and details for each column such as name and datatype.
@ -92,7 +99,7 @@ def get_df_info(df: pd.DataFrame, data_id: Optional[int]=None) -> str:
df (pd.DataFrame): The Pandas DataFrame to describe.
Returns:
str: A string that contains the compact representation of the DataFrame.
Dict[str, Union[int, str]]: A dictionary containing the DataFrame details and data_id
"""
# Create an output stream to collect strings
@ -112,8 +119,19 @@ def get_df_info(df: pd.DataFrame, data_id: Optional[int]=None) -> str:
output.write(f"Column: {column}\n")
output.write(f"type: {df[column].dtype}\n")
# put top 5 rows in the output if there are more than 5 rows.
if len(df.index) > 5:
output.write("---\n")
output.write("Top 5 rows:\n")
output.write(df.head().to_string())
# Get the complete string from the output stream
result = output.getvalue()
output.close()
return result
info = {
"schema": result
}
if data_id:
info["data_id"] = data_id
return info

View file

@ -43,4 +43,4 @@ class GetDataDetails(DataSchemaBase):
file_path: str = Field(..., title="File Path", description="Path of the file")
created_time: datetime = Field(..., title="Creation Time", description="Time when the Data entry was created")
updated_time: datetime | None = Field(default=None, title="Updated Time", description="Time when the Data entry was last updated")
# default_factory is called for each new instance; default=datetime.now() would
# be evaluated once at import time and stamp every entry with that same moment.
updated_time: datetime | None = Field(default_factory=datetime.now, title="Updated Time", description="Time when the Data entry was last updated")