diff --git a/examples/sql-chat/agent.py b/examples/sql-chat/agent.py new file mode 100644 index 00000000..2ef8c53e --- /dev/null +++ b/examples/sql-chat/agent.py @@ -0,0 +1,188 @@ +import httpx +import json +import openai +from typing import Any, Dict, List, Optional + +class Toolchain: + + available_tools = { + "query_sql": "/tool/query/query_sql", + "list_data_sources": "/tool/query/list_data_sources", + "get_data_schema": "/tool/query/get_data_schema" + } + + def __init__(self, base_url: str, openai_api_key: str, model: str = "gpt-4-turbo"): + self.base_url = base_url + self.client = httpx.Client() + self.openai_client = openai.Client(api_key=openai_api_key) + self.model = model + self.tools = self.__collect_tool_specs() + + def __collect_tool_specs(self) -> Dict[str, str]: + tools = {} + for tool_name, endpoint in self.available_tools.items(): + openai_spec = self.call_api("GET", "/api/v1/tools/oai_function", params={"tool_name": tool_name}).get("data", {}) + tools[tool_name] = openai_spec + return tools + + def call_api(self, method: str, endpoint: str, params: dict = {}, data: dict = {}, json_data: dict = {}) -> Dict[str, Any]: + url = f"{self.base_url}{endpoint}" + response = self.client.request(method, url, params=params, json=json_data, data=data) + try: + response.raise_for_status() + except httpx.HTTPStatusError as e: + print(f"HTTP error: {e}") + result = response.json() + return result + + def get_tool_args(self, tool_name: str, messages: List[Dict[str, str]]) -> Dict[str, Any]: + """ + Retrieves the required arguments for an tool from the Darkstar Toolserver API and + uses them to call an OpenAI model with predefined tools and messages. + + :param tool_name: The name of the tool to execute. + :param messages: A list of messages to provide to the model. + :return: The result of the OpenAI model call. + """ + func_spec = self.tools.get(tool_name, {}) + if not func_spec: + raise ValueError(f"Tool '{tool_name}' not found in available tools.") + + tool = json.loads(func_spec) + # Call the OpenAI model with the tools and messages + completion = self.openai_client.chat.completions.create( + model="gpt-4-turbo", + messages=messages, + tools=[tool], + tool_choice="auto" + ) + predicted_args = completion.choices[0].message.tool_calls[0].function.arguments + print(predicted_args) + print("-----") + return predicted_args + + def execute_tool(self, tool_name: str, tool_args: Optional[Dict[str, Any]]) -> Dict[str, Any]: + """ + Executes an tool using the Darkstar Toolserver API and an OpenAI model. + + :param tool_name: The name of the tool to execute. + :return: The result of the tool + """ + + # Prepare the input message for the OpenAI model + endpoint = self.available_tools[tool_name] + result = self.call_api("POST", endpoint, json_data=tool_args) + return result + + +from pydantic import BaseModel +from typing import List, Dict +from textwrap import dedent + +class Agent: + + prompt = dedent("""Given a user query and a schema of a table, generate the SQL query to answer the user query. + + + The generated SQL query should only refer to columns in the table schema list below. The table schema is as follows: + {schema} + + The data_id of this source is: {data_id} + """) + + + def __init__(self, toolchain: Toolchain): + self.toolchain = toolchain + self.data_sources = self.__get_data_sources() + self._source = None + self._data_schema = None + + def set_source(self, source: str): + if source not in self.data_sources.keys(): + raise ValueError(f"Data source '{source}' not found.") + else: + data_id = self.data_sources[source] + # get the schema + schema = self.toolchain.call_api("POST", "/tool/query/get_data_schema", json_data={"data_id": data_id}) + self._source = source + self._data_schema = schema + + def get_source(self) -> str: + return self._source + + def __get_data_sources(self) -> Dict[str, Dict[str, str]]: + response = self.toolchain.call_api("POST", "/tool/query/list_data_sources") + sources = {} + for _id, source_data in response["data"]["result"].items(): + sources[source_data["file_name"]] = _id + return sources + + + def query(self, user_query: str) -> str: + if not self._source: + raise ValueError("Data source not set. Please set a data source before querying.") + schema = self._data_schema + prompt = self.prompt.format(schema=schema, data_id=self.data_sources[self._source]) + + # Prepare the input message for the OpenAI model + messages = [ + {"role": "system", "content": prompt}, + {"role": "user", "content": user_query}] + + tool_args = self.toolchain.get_tool_args("query_sql", messages) + args = json.loads(tool_args) + params = args.get("params", []) + if params: + if isinstance(params, dict): + args["params"] = list(params.values()) + elif isinstance(params, str): + args["params"] = [params] + elif isinstance(params, list): + args["params"] = params + else: + raise ValueError(f"Invalid params type: {type(params)}") + + + response = self.toolchain.execute_tool("query_sql", args) + if response["code"] != 200: + raise ValueError(f"Error executing tool: {response['message']}") + data_id = response["data"]["result"]["data_id"] + + # get the data + data_response = self.toolchain.call_api("GET", f"/api/v1/data/object/{data_id}") + if data_response["code"] != 200: + raise ValueError(f"Error retrieving data: {data_response['message']}") + data = data_response["data"]["json_blob"] + return data + + +from pydantic import BaseModel, Field +from enum import Enum + +class ToolNode: + pass + +class ToolFlow: + + def __init__( + self, + name, + description, + sources, + ): + pass + + + + +""" # Example usage: +oai_key = "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" +toolchain = Toolchain(base_url="http://localhost:8000", model="gpt-4-turbo", openai_api_key=oai_key) +agent = Agent(toolchain) +agent.set_source("users_db") + +while True: + user_query = input("Enter a query: ") + result = agent.query(user_query) + print(result) + """ \ No newline at end of file diff --git a/examples/sql-chat/main.py b/examples/sql-chat/main.py new file mode 100644 index 00000000..4e481e63 --- /dev/null +++ b/examples/sql-chat/main.py @@ -0,0 +1,103 @@ +import openai + +oai_key = "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" + + +import json +import logging +import subprocess +import sys +import time +import traceback +import os + +import pandas as pd +import streamlit as st +from pydantic import BaseModel +from streamlit_chat import message + + +from agent import Agent, Toolchain + +@st.cache_resource() +def get_agent(): + toolchain = Toolchain(base_url="http://localhost:8000", model="gpt-4-turbo", openai_api_key=oai_key) + agent = Agent(toolchain) + agent.set_source("users_db") + return agent + + +# From here down is all the StreamLit UI. +st.set_page_config(page_title="Data Chat", page_icon=":robot:", layout="wide") +st.header("Arcade AI Demo") + + +def initialize_logger(): + logger = logging.getLogger("root") + handler = logging.StreamHandler(sys.stdout) + logger.setLevel(logging.INFO) + logger.handlers = [handler] + return True + +if "logger" not in st.session_state: + st.session_state["logger"] = initialize_logger() +if "past" not in st.session_state: + st.session_state["past"] = [] +if "generated" not in st.session_state: + st.session_state["generated"] = [] + + + +st.subheader("Chat") + + +chat_container = st.container() +input_container = st.container() + +def submit(): + submit_text = st.session_state["input"] + st.session_state["input"] = "" + with st.spinner(text="Wait for Agent..."): + try: + agent = get_agent() + res = agent.query(submit_text) + except Exception: + res = traceback.format_exc() + st.session_state.past.append(submit_text) + st.session_state.generated.append(res) + +def get_text(): + input_text = st.text_input("You: ", key="input", on_change=submit) + return input_text + +with input_container: + user_input = get_text() + +if st.session_state["generated"]: + with chat_container: + for i in range( + len(st.session_state["generated"]) + ): # range(len(st.session_state["generated"]) - 1, -1, -1): + message(st.session_state["past"][i], is_user=True, key=str(i) + "_user") + + res = st.session_state["generated"][i] + + try: + json_res = json.loads(res)["data"] + print(json_res) + except Exception: + json_res = None + + if json_res: + try: + res = pd.DataFrame(json_res) + except Exception: + res = json_res + + if isinstance(res, str): + st.write(res) + elif isinstance(res, pd.DataFrame): + st.dataframe(res) + else: + st.error("Returned result:") + st.error(res) diff --git a/toolserve/toolserve/builtin/default/query.py b/toolserve/toolserve/builtin/default/query.py index 3b9bd1cf..8368aee1 100644 --- a/toolserve/toolserve/builtin/default/query.py +++ b/toolserve/toolserve/builtin/default/query.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union, List import io from toolserve.sdk.client import list_data, log @@ -15,11 +15,16 @@ async def list_data_sources() -> Dict[str, Dict[str, str]]: Dict[str, str]: A dictionary mapping data source IDs to their details. """ data = await list_data() - return {str(item["id"]): { - "file_name": item["file_name"], - "created_at": item["created_time"], - "updated_at": item["updated_time"] - } for item in data} + partial = {} + for item in data: + details = { + "file_name": item["file_name"], + "created_at": item["created_time"] + } + if "updated_time" in item and item["updated_time"] is not None: + details["updated_at"] = item["updated_time"] + partial[str(item["id"])] = details + return partial @tool async def get_data_schema( @@ -35,19 +40,21 @@ async def get_data_schema( """ # TODO read in only a few lines df = await get_df(data_id) - return get_df_info(df) + return get_df_info(df)["schema"] @tool async def query_sql( data_id: Param(int, "id of the data source"), sql: Param(str, "parameterized SQL query to execute"), - params: Param(Optional[Dict[str, Any]], "parameters to pass to the SQL query") = None, - ) -> Param(str, "schema of the data source after executing the query"): + params: Param(Optional[List[Union[str, int]]], "parameters to pass to the SQL query") = None, + ) -> Dict[str, Union[int, str]]: """Query a data source using SQL - The SQL query should be parameterized with Python's named parameter syntax, - e.g. `SELECT * FROM table WHERE column = :value`. + The SQL query should be parameterized with DuckDB's syntax. For example, to query a + DataFrame named `df` with a parameter `param`, the query should be `SELECT * FROM df WHERE column = ?`. + + The list of params should be in order of the parameters in the SQL query. After the query, a new data source at a new id will be created with the results and the schema of the data source will be returned. @@ -83,7 +90,7 @@ async def query_sql( raise RuntimeError(f"Query execution failed: {str(e)}") -def get_df_info(df: pd.DataFrame, data_id: Optional[int]=None) -> str: +def get_df_info(df: pd.DataFrame, data_id: Optional[int]=None) -> Dict[str, Union[int, str]]: """ Generate a compact string representation of a DataFrame including the count of columns, rows, overall size, and details for each column such as name and datatype. @@ -92,7 +99,7 @@ def get_df_info(df: pd.DataFrame, data_id: Optional[int]=None) -> str: df (pd.DataFrame): The Pandas DataFrame to describe. Returns: - str: A string that contains the compact representation of the DataFrame. + Dict[str, Union[int, str]]: A dictionary containing the DataFrame details and data_id """ # Create an output stream to collect strings @@ -112,8 +119,19 @@ def get_df_info(df: pd.DataFrame, data_id: Optional[int]=None) -> str: output.write(f"Column: {column}\n") output.write(f"type: {df[column].dtype}\n") + # put top 5 rows in the output if there are more than 5 rows. + if len(df.index) > 5: + output.write("---\n") + output.write("Top 5 rows:\n") + output.write(df.head().to_string()) + # Get the complete string from the output stream result = output.getvalue() output.close() - return result \ No newline at end of file + info = { + "schema": result + } + if data_id: + info["data_id"] = data_id + return info \ No newline at end of file diff --git a/toolserve/toolserve/server/schemas/data.py b/toolserve/toolserve/server/schemas/data.py index 3d5ecb5b..640dfb73 100644 --- a/toolserve/toolserve/server/schemas/data.py +++ b/toolserve/toolserve/server/schemas/data.py @@ -43,4 +43,4 @@ class GetDataDetails(DataSchemaBase): file_path: str = Field(..., title="File Path", description="Path of the file") created_time: datetime = Field(..., title="Creation Time", description="Time when the Data entry was created") - updated_time: datetime | None = Field(default=None, title="Updated Time", description="Time when the Data entry was last updated") + updated_time: datetime | None = Field(default=datetime.now(), title="Updated Time", description="Time when the Data entry was last updated")