Sql chat example
This commit is contained in:
parent
8b62119cb2
commit
41b783ef2e
4 changed files with 324 additions and 15 deletions
188
examples/sql-chat/agent.py
Normal file
188
examples/sql-chat/agent.py
Normal file
|
|
@ -0,0 +1,188 @@
|
|||
import httpx
|
||||
import json
|
||||
import openai
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
class Toolchain:
|
||||
|
||||
available_tools = {
|
||||
"query_sql": "/tool/query/query_sql",
|
||||
"list_data_sources": "/tool/query/list_data_sources",
|
||||
"get_data_schema": "/tool/query/get_data_schema"
|
||||
}
|
||||
|
||||
def __init__(self, base_url: str, openai_api_key: str, model: str = "gpt-4-turbo"):
|
||||
self.base_url = base_url
|
||||
self.client = httpx.Client()
|
||||
self.openai_client = openai.Client(api_key=openai_api_key)
|
||||
self.model = model
|
||||
self.tools = self.__collect_tool_specs()
|
||||
|
||||
def __collect_tool_specs(self) -> Dict[str, str]:
|
||||
tools = {}
|
||||
for tool_name, endpoint in self.available_tools.items():
|
||||
openai_spec = self.call_api("GET", "/api/v1/tools/oai_function", params={"tool_name": tool_name}).get("data", {})
|
||||
tools[tool_name] = openai_spec
|
||||
return tools
|
||||
|
||||
def call_api(self, method: str, endpoint: str, params: dict = {}, data: dict = {}, json_data: dict = {}) -> Dict[str, Any]:
|
||||
url = f"{self.base_url}{endpoint}"
|
||||
response = self.client.request(method, url, params=params, json=json_data, data=data)
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as e:
|
||||
print(f"HTTP error: {e}")
|
||||
result = response.json()
|
||||
return result
|
||||
|
||||
def get_tool_args(self, tool_name: str, messages: List[Dict[str, str]]) -> Dict[str, Any]:
|
||||
"""
|
||||
Retrieves the required arguments for an tool from the Darkstar Toolserver API and
|
||||
uses them to call an OpenAI model with predefined tools and messages.
|
||||
|
||||
:param tool_name: The name of the tool to execute.
|
||||
:param messages: A list of messages to provide to the model.
|
||||
:return: The result of the OpenAI model call.
|
||||
"""
|
||||
func_spec = self.tools.get(tool_name, {})
|
||||
if not func_spec:
|
||||
raise ValueError(f"Tool '{tool_name}' not found in available tools.")
|
||||
|
||||
tool = json.loads(func_spec)
|
||||
# Call the OpenAI model with the tools and messages
|
||||
completion = self.openai_client.chat.completions.create(
|
||||
model="gpt-4-turbo",
|
||||
messages=messages,
|
||||
tools=[tool],
|
||||
tool_choice="auto"
|
||||
)
|
||||
predicted_args = completion.choices[0].message.tool_calls[0].function.arguments
|
||||
print(predicted_args)
|
||||
print("-----")
|
||||
return predicted_args
|
||||
|
||||
def execute_tool(self, tool_name: str, tool_args: Optional[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""
|
||||
Executes an tool using the Darkstar Toolserver API and an OpenAI model.
|
||||
|
||||
:param tool_name: The name of the tool to execute.
|
||||
:return: The result of the tool
|
||||
"""
|
||||
|
||||
# Prepare the input message for the OpenAI model
|
||||
endpoint = self.available_tools[tool_name]
|
||||
result = self.call_api("POST", endpoint, json_data=tool_args)
|
||||
return result
|
||||
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Dict
|
||||
from textwrap import dedent
|
||||
|
||||
class Agent:
|
||||
|
||||
prompt = dedent("""Given a user query and a schema of a table, generate the SQL query to answer the user query.
|
||||
|
||||
|
||||
The generated SQL query should only refer to columns in the table schema list below. The table schema is as follows:
|
||||
{schema}
|
||||
|
||||
The data_id of this source is: {data_id}
|
||||
""")
|
||||
|
||||
|
||||
def __init__(self, toolchain: Toolchain):
|
||||
self.toolchain = toolchain
|
||||
self.data_sources = self.__get_data_sources()
|
||||
self._source = None
|
||||
self._data_schema = None
|
||||
|
||||
def set_source(self, source: str):
|
||||
if source not in self.data_sources.keys():
|
||||
raise ValueError(f"Data source '{source}' not found.")
|
||||
else:
|
||||
data_id = self.data_sources[source]
|
||||
# get the schema
|
||||
schema = self.toolchain.call_api("POST", "/tool/query/get_data_schema", json_data={"data_id": data_id})
|
||||
self._source = source
|
||||
self._data_schema = schema
|
||||
|
||||
def get_source(self) -> str:
|
||||
return self._source
|
||||
|
||||
def __get_data_sources(self) -> Dict[str, Dict[str, str]]:
|
||||
response = self.toolchain.call_api("POST", "/tool/query/list_data_sources")
|
||||
sources = {}
|
||||
for _id, source_data in response["data"]["result"].items():
|
||||
sources[source_data["file_name"]] = _id
|
||||
return sources
|
||||
|
||||
|
||||
def query(self, user_query: str) -> str:
|
||||
if not self._source:
|
||||
raise ValueError("Data source not set. Please set a data source before querying.")
|
||||
schema = self._data_schema
|
||||
prompt = self.prompt.format(schema=schema, data_id=self.data_sources[self._source])
|
||||
|
||||
# Prepare the input message for the OpenAI model
|
||||
messages = [
|
||||
{"role": "system", "content": prompt},
|
||||
{"role": "user", "content": user_query}]
|
||||
|
||||
tool_args = self.toolchain.get_tool_args("query_sql", messages)
|
||||
args = json.loads(tool_args)
|
||||
params = args.get("params", [])
|
||||
if params:
|
||||
if isinstance(params, dict):
|
||||
args["params"] = list(params.values())
|
||||
elif isinstance(params, str):
|
||||
args["params"] = [params]
|
||||
elif isinstance(params, list):
|
||||
args["params"] = params
|
||||
else:
|
||||
raise ValueError(f"Invalid params type: {type(params)}")
|
||||
|
||||
|
||||
response = self.toolchain.execute_tool("query_sql", args)
|
||||
if response["code"] != 200:
|
||||
raise ValueError(f"Error executing tool: {response['message']}")
|
||||
data_id = response["data"]["result"]["data_id"]
|
||||
|
||||
# get the data
|
||||
data_response = self.toolchain.call_api("GET", f"/api/v1/data/object/{data_id}")
|
||||
if data_response["code"] != 200:
|
||||
raise ValueError(f"Error retrieving data: {data_response['message']}")
|
||||
data = data_response["data"]["json_blob"]
|
||||
return data
|
||||
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from enum import Enum
|
||||
|
||||
class ToolNode:
|
||||
pass
|
||||
|
||||
class ToolFlow:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name,
|
||||
description,
|
||||
sources,
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
|
||||
""" # Example usage:
|
||||
oai_key = "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"
|
||||
toolchain = Toolchain(base_url="http://localhost:8000", model="gpt-4-turbo", openai_api_key=oai_key)
|
||||
agent = Agent(toolchain)
|
||||
agent.set_source("users_db")
|
||||
|
||||
while True:
|
||||
user_query = input("Enter a query: ")
|
||||
result = agent.query(user_query)
|
||||
print(result)
|
||||
"""
|
||||
103
examples/sql-chat/main.py
Normal file
103
examples/sql-chat/main.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
import openai
|
||||
|
||||
oai_key = "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"
|
||||
|
||||
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
import os
|
||||
|
||||
import pandas as pd
|
||||
import streamlit as st
|
||||
from pydantic import BaseModel
|
||||
from streamlit_chat import message
|
||||
|
||||
|
||||
from agent import Agent, Toolchain
|
||||
|
||||
@st.cache_resource()
|
||||
def get_agent():
|
||||
toolchain = Toolchain(base_url="http://localhost:8000", model="gpt-4-turbo", openai_api_key=oai_key)
|
||||
agent = Agent(toolchain)
|
||||
agent.set_source("users_db")
|
||||
return agent
|
||||
|
||||
|
||||
# From here down is all the StreamLit UI.
|
||||
st.set_page_config(page_title="Data Chat", page_icon=":robot:", layout="wide")
|
||||
st.header("Arcade AI Demo")
|
||||
|
||||
|
||||
def initialize_logger():
|
||||
logger = logging.getLogger("root")
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
logger.setLevel(logging.INFO)
|
||||
logger.handlers = [handler]
|
||||
return True
|
||||
|
||||
if "logger" not in st.session_state:
|
||||
st.session_state["logger"] = initialize_logger()
|
||||
if "past" not in st.session_state:
|
||||
st.session_state["past"] = []
|
||||
if "generated" not in st.session_state:
|
||||
st.session_state["generated"] = []
|
||||
|
||||
|
||||
|
||||
st.subheader("Chat")
|
||||
|
||||
|
||||
chat_container = st.container()
|
||||
input_container = st.container()
|
||||
|
||||
def submit():
|
||||
submit_text = st.session_state["input"]
|
||||
st.session_state["input"] = ""
|
||||
with st.spinner(text="Wait for Agent..."):
|
||||
try:
|
||||
agent = get_agent()
|
||||
res = agent.query(submit_text)
|
||||
except Exception:
|
||||
res = traceback.format_exc()
|
||||
st.session_state.past.append(submit_text)
|
||||
st.session_state.generated.append(res)
|
||||
|
||||
def get_text():
|
||||
input_text = st.text_input("You: ", key="input", on_change=submit)
|
||||
return input_text
|
||||
|
||||
with input_container:
|
||||
user_input = get_text()
|
||||
|
||||
if st.session_state["generated"]:
|
||||
with chat_container:
|
||||
for i in range(
|
||||
len(st.session_state["generated"])
|
||||
): # range(len(st.session_state["generated"]) - 1, -1, -1):
|
||||
message(st.session_state["past"][i], is_user=True, key=str(i) + "_user")
|
||||
|
||||
res = st.session_state["generated"][i]
|
||||
|
||||
try:
|
||||
json_res = json.loads(res)["data"]
|
||||
print(json_res)
|
||||
except Exception:
|
||||
json_res = None
|
||||
|
||||
if json_res:
|
||||
try:
|
||||
res = pd.DataFrame(json_res)
|
||||
except Exception:
|
||||
res = json_res
|
||||
|
||||
if isinstance(res, str):
|
||||
st.write(res)
|
||||
elif isinstance(res, pd.DataFrame):
|
||||
st.dataframe(res)
|
||||
else:
|
||||
st.error("Returned result:")
|
||||
st.error(res)
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, Optional, Union, List
|
||||
import io
|
||||
|
||||
from toolserve.sdk.client import list_data, log
|
||||
|
|
@ -15,11 +15,16 @@ async def list_data_sources() -> Dict[str, Dict[str, str]]:
|
|||
Dict[str, str]: A dictionary mapping data source IDs to their details.
|
||||
"""
|
||||
data = await list_data()
|
||||
return {str(item["id"]): {
|
||||
"file_name": item["file_name"],
|
||||
"created_at": item["created_time"],
|
||||
"updated_at": item["updated_time"]
|
||||
} for item in data}
|
||||
partial = {}
|
||||
for item in data:
|
||||
details = {
|
||||
"file_name": item["file_name"],
|
||||
"created_at": item["created_time"]
|
||||
}
|
||||
if "updated_time" in item and item["updated_time"] is not None:
|
||||
details["updated_at"] = item["updated_time"]
|
||||
partial[str(item["id"])] = details
|
||||
return partial
|
||||
|
||||
@tool
|
||||
async def get_data_schema(
|
||||
|
|
@ -35,19 +40,21 @@ async def get_data_schema(
|
|||
"""
|
||||
# TODO read in only a few lines
|
||||
df = await get_df(data_id)
|
||||
return get_df_info(df)
|
||||
return get_df_info(df)["schema"]
|
||||
|
||||
|
||||
@tool
|
||||
async def query_sql(
|
||||
data_id: Param(int, "id of the data source"),
|
||||
sql: Param(str, "parameterized SQL query to execute"),
|
||||
params: Param(Optional[Dict[str, Any]], "parameters to pass to the SQL query") = None,
|
||||
) -> Param(str, "schema of the data source after executing the query"):
|
||||
params: Param(Optional[List[Union[str, int]]], "parameters to pass to the SQL query") = None,
|
||||
) -> Dict[str, Union[int, str]]:
|
||||
"""Query a data source using SQL
|
||||
|
||||
The SQL query should be parameterized with Python's named parameter syntax,
|
||||
e.g. `SELECT * FROM table WHERE column = :value`.
|
||||
The SQL query should be parameterized with DuckDB's syntax. For example, to query a
|
||||
DataFrame named `df` with a parameter `param`, the query should be `SELECT * FROM df WHERE column = ?`.
|
||||
|
||||
The list of params should be in order of the parameters in the SQL query.
|
||||
|
||||
After the query, a new data source at a new id will be created with the results and
|
||||
the schema of the data source will be returned.
|
||||
|
|
@ -83,7 +90,7 @@ async def query_sql(
|
|||
raise RuntimeError(f"Query execution failed: {str(e)}")
|
||||
|
||||
|
||||
def get_df_info(df: pd.DataFrame, data_id: Optional[int]=None) -> str:
|
||||
def get_df_info(df: pd.DataFrame, data_id: Optional[int]=None) -> Dict[str, Union[int, str]]:
|
||||
"""
|
||||
Generate a compact string representation of a DataFrame including the count of columns,
|
||||
rows, overall size, and details for each column such as name and datatype.
|
||||
|
|
@ -92,7 +99,7 @@ def get_df_info(df: pd.DataFrame, data_id: Optional[int]=None) -> str:
|
|||
df (pd.DataFrame): The Pandas DataFrame to describe.
|
||||
|
||||
Returns:
|
||||
str: A string that contains the compact representation of the DataFrame.
|
||||
Dict[str, Union[int, str]]: A dictionary containing the DataFrame details and data_id
|
||||
"""
|
||||
|
||||
# Create an output stream to collect strings
|
||||
|
|
@ -112,8 +119,19 @@ def get_df_info(df: pd.DataFrame, data_id: Optional[int]=None) -> str:
|
|||
output.write(f"Column: {column}\n")
|
||||
output.write(f"type: {df[column].dtype}\n")
|
||||
|
||||
# put top 5 rows in the output if there are more than 5 rows.
|
||||
if len(df.index) > 5:
|
||||
output.write("---\n")
|
||||
output.write("Top 5 rows:\n")
|
||||
output.write(df.head().to_string())
|
||||
|
||||
# Get the complete string from the output stream
|
||||
result = output.getvalue()
|
||||
output.close()
|
||||
|
||||
return result
|
||||
info = {
|
||||
"schema": result
|
||||
}
|
||||
if data_id:
|
||||
info["data_id"] = data_id
|
||||
return info
|
||||
|
|
@ -43,4 +43,4 @@ class GetDataDetails(DataSchemaBase):
|
|||
file_path: str = Field(..., title="File Path", description="Path of the file")
|
||||
|
||||
created_time: datetime = Field(..., title="Creation Time", description="Time when the Data entry was created")
|
||||
updated_time: datetime | None = Field(default=None, title="Updated Time", description="Time when the Data entry was last updated")
|
||||
updated_time: datetime | None = Field(default=datetime.now(), title="Updated Time", description="Time when the Data entry was last updated")
|
||||
|
|
|
|||
Loading…
Reference in a new issue