Minor Cleanup

This commit is contained in:
Sam Partee 2024-06-04 09:05:18 -07:00
parent 5a3c28ab5e
commit a5decd4483
23 changed files with 436 additions and 1235 deletions

3
.gitignore vendored
View file

@ -8,3 +8,6 @@ backend/app/alembic/versions/
backend/app/static/media/
.ruff_cache/
.pytest_cache/
examples/data
examples/set_secrets.sh
scratch

View file

@ -1,17 +0,0 @@
[pack]
name = "example"
description = "A simple actor that sends an email using the Gmail API."
version = "0.1.0"
author = "Sam Partee"
email = "sam@partee.io"
[depends]
[tools]
SendEmail = "gmailer.send_email@0.1.0"
ReadEmail = "gmailer.read_email@0.1.0"
PlotDataframe = "gmailer.plot_dataframe@0.1.0"
ReadSqlite = "read_sqlite.read_sqlite@0.1.0"
Summarize = "chat.summarize@0.1.0"
TranscribeText = "chat.transcribe_text@0.1.0"
VectorSearch = "search.vector_search@0.1.0"

View file

@ -1,14 +0,0 @@
[pack]
name = "example"
description = "A simple actor that sends an email using the Gmail API."
version = "0.1.0"
author = "Sam Partee"
email = "sam@partee.io"
[modules]
gmailer = "0.1.0"
chat = "0.1.0"
search = "0.1.0"
read_sqlite = "0.1.0"

View file

@ -1,120 +0,0 @@
from typing import (
IO,
Union,
List,
Dict,
Optional,
Any,
Type,
)
import io
import requests
from os import PathLike
import base64
from toolserve.sdk import Param, tool, get_secret
from toolserve.sdk.dataframe import get_df
from typing import List
import pandas as pd
import openai
@tool
async def summarize(
text: Param(str, "Text to summarize"),
#data_id: Param(int, "ID of the data to summarize"),
system_prompt: Param(str, "System prompt to use") = "Summarize the following text",
max_tokens: Param(int, "Maximum number of tokens to generate") = 1000,
) -> Param(str, "Summarized text"):
"""Summarize a piece of text using OpenAI Language models.
Args:
text (str): The text to summarize.
max_tokens (int): The maximum number of tokens to generate.
Returns:
str: The summarized text.
"""
#df = await get_df(data_id)
#text = df.to_json(orient='records')
api_key = get_secret("openai_api_key", None)
model = get_secret("openai_model_summarize", "gpt-4-turbo")
# Call the OpenAI model with the tools and messages
if isinstance(text, list):
text = "\n".join(text)
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": text},
]
client = openai.Client(api_key=api_key)
completion = openai.chat.completions.create(
model=model,
messages=messages,
)
summary = completion.choices[0].message.content
return summary
@tool
async def transcribe_text(
audio_file: Param(str, "Audio file bytes"),
system_prompt: Param(str, "System prompt to use") = "Transcribe the following audio files",
) -> Param(str, "Transcribed text"):
"""Use OpenAI to translate audio to text using the Whisper model.
Args:
audio_file_bytes (str): The bytes of the audio file to transcribe.
system_prompt (str): The system prompt to use for guiding the transcription.
Returns:
str: The transcribed text.
"""
api_key = get_secret("openai_api_key", None)
model = get_secret("openai_model_whisper", "whisper-1")
if audio_file is None:
raise ValueError("No audio file provided")
# Decode the base64 audio file
audio_file_bytes = base64.b64decode(audio_file)
file = io.BytesIO(audio_file_bytes)
# Prepare the headers
headers = {
'Authorization': f'Bearer {api_key}',
}
# Prepare the files parameter
files = {
'file': ('audio.mp3', file, 'audio/mp3')
}
# Prepare the data parameter
data = {
'model': model,
'prompt': system_prompt,
'response_format': 'text'
}
# Send the request to the OpenAI Whisper API
response = requests.post(
'https://api.openai.com/v1/audio/transcriptions',
headers=headers,
files=files,
data=data
)
# Check if the request was successful
if response.status_code == 200:
# Return the plain text response directly
return response.text
else:
# Handle errors
raise Exception(f"Error: {response.status_code} - {response.text}")

View file

@ -1,65 +0,0 @@
import faiss
import numpy as np
from typing import List
from fastembed import TextEmbedding
from toolserve.sdk import Param, tool, get_secret
from toolserve.sdk.dataframe import get_df
@tool
async def vector_search(
data_id: Param(int, "The ID of the data source containing the documents"),
query: Param(str, "The text to find within the documents"),
column_name: Param(str, "The name of the column containing the documents"),
n_results: Param(int, "The number of top results to return") = 5
) -> Param(List[str], "The documents most similar to the query"):
"""Create a FAISS index from a list of documents and search for the query, returning the most similar documents.
Args:
query (str): The text query to search for. Should be written like a document.
column_name (str): The name of the column containing the documents.
n_results (int, optional): The number of top results to return. Defaults to 5.
Returns:
List[str]: The documents most similar to the query based on the search.
"""
# Get the data
df = await get_df(data_id)
docs = df[column_name].tolist()
# Initialize the embedding model
embedding_tool = TextEmbedding()
# Embed all documents
embeddings = []
for doc in docs:
# Get the generator from the embed method
doc_embedding_generator = embedding_tool.embed([doc])
# Convert the generator to a list and take the first element
doc_embedding = list(doc_embedding_generator)[0]
embeddings.append(doc_embedding)
# Convert list of embeddings to a numpy array and ensure type float32
embeddings = np.vstack(embeddings).astype('float32')
# Create a flat L2 index
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings) # Add embeddings to the index
# Embed the query
query_embedding_generator = embedding_tool.embed([query])
query_embedding = list(query_embedding_generator)[0]
query_embedding = np.array(query_embedding, dtype='float32').reshape(1, -1)
# Search the index
distances, indices = index.search(query_embedding, n_results)
# Fetch the documents corresponding to the top indices
top_docs = [docs[i] for i in indices.flatten().tolist()]
return top_docs

14
examples/pack.lock.toml Normal file
View file

@ -0,0 +1,14 @@
[pack]
name = "Routines for Jarvis"
description = "Jarvis Chatbot routines"
version = "0.1.0"
author = "Sam Partee"
email = "sam@partee.io"
[depends]
[tools]
Summarize = "llm.summarize@0.0.1"
Respond = "llm.respond@0.0.1"
SendEmail = "gmail.send_email@0.0.1"
ReadEmail = "gmail.read_email@0.0.1"

12
examples/pack.toml Normal file
View file

@ -0,0 +1,12 @@
[pack]
name = "Routines for Jarvis"
description = "Jarvis Chatbot routines"
version = "0.1.0"
author = "Sam Partee"
email = "sam@partee.io"
[modules]
gmail = "0.0.1"
llm = "0.0.1"

View file

@ -1,502 +0,0 @@
import httpx
import json
import time
import openai
import uuid
from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel
from typing import List, Dict
from textwrap import dedent
from pydantic import BaseModel, Field
from enum import Enum
from typing import Type
from toolserve.utils.openai_tool import model_to_json_schema
from typing import Dict, Any, Optional
import json
from collections import deque
def pydantic_to_openai_tool(model: Type[BaseModel]) -> str:
"""
Convert a Pydantic model to an OpenAI tool schema.
Args:
model (Type[BaseModel]): The Pydantic model to convert.
Returns:
str: The OpenAI tool schema.
"""
schema = model_to_json_schema(model)
tool_schema = {
"type": "function",
"function": {
"name": model.__name__,
"description": model.__doc__ or "",
"parameters": schema
}
}
return json.dumps(tool_schema)
class Edge(BaseModel):
source: int = Field(..., description="The ID of the source node")
target: int = Field(..., description="The ID of the target node")
uuid: str = Field(default_factory=lambda: str(uuid.uuid4()), description="UUID for the data flow between nodes")
class ToolNode(BaseModel):
node_id: int = Field(..., description="The ID of the node", ge=0)
input_name: Optional[str] = Field(None, description="The name of the input data")
tool_name: str = Field(..., description="The name of the tool to execute")
output_name: Optional[str] = Field(None, description="The name of the output data")
predict_args: bool = Field(True, description="Whether to predict the arguments for the tool")
from_node: Optional[Dict[str, int]] = Field(None, description="The ID of the source node name of the argument to pass to the tool")
args: Optional[Dict[str, Any]] = Field(None, description="The arguments to pass to the tool")
allow_extra: bool = Field(False, description="Whether to allow extra arguments to be passed to the tool")
class OutputType(Enum):
DATA = "data"
CHAT = "chat"
ARTIFACT = "artifact"
class FlowSchema(BaseModel):
"""A graph based representation of functions (nodes), and their data flow (edges)"""
nodes: List[ToolNode] = Field(..., description="The nodes in the flow")
edges: List[Edge] = Field([], description="The IDs of the adjacent nodes")
output_type: OutputType = Field(OutputType.CHAT, description="The type of the output")
def __init__(self, **data):
super().__init__(**data)
self.generate_uuids_for_edges()
class Config:
arbitrary_types_allowed = True
use_enum_values = True
def generate_uuids_for_edges(self):
edge_map = {}
for edge in self.edges:
edge_map[(edge.source, edge.target)] = edge.uuid
for node in self.nodes:
incoming_edges = [e.uuid for e in self.edges if e.target == node.node_id]
outgoing_edges = [e.uuid for e in self.edges if e.source == node.node_id]
if node.from_node:
node.input_name = None
node.output_name = None
# Set the output of the source node and the input of the target node to None
for edge in self.edges:
if edge.target == node.node_id:
source_node = next((n for n in self.nodes if n.node_id == edge.source), None)
if source_node:
source_node.output_name = None
if edge.source == node.node_id:
target_node = next((n for n in self.nodes if n.node_id == edge.target), None)
if target_node:
target_node.input_name = None
else:
node.input_name = incoming_edges[0] if incoming_edges else None
node.output_name = outgoing_edges[0] if outgoing_edges else None
class ToolRunner:
tool_prompt = dedent("""
Given a user query and the schema of the fields in a dataframe, generate the arguments for a tool to execute.
YOU MUST CALL THE TOOL.
The schema of the fields in the dataframe is as follows:
{schema}
If needed, the data_id for the source is: {data_id}
If needed, the output_name should be: {output_name}
""")
def __init__(self, base_url: str, model: str, api_key: str):
"""
Initialize the ToolRunner with necessary configurations.
Args:
base_url (str): The base URL for the API calls.
model (str): The model identifier to be used for queries.
api_key (str): The API key for authentication.
"""
self.base_url = base_url
self.client = httpx.Client(timeout=3000)
self.model = model
self.openai_client = openai.Client(api_key=api_key)
self.tools, self.available_tools = self.__collect_tool_specs()
self._data_sources = self.__get_data_sources()
self._source = None
self._data_schema = None
self._data_id = None
def __collect_tool_specs(self) -> Tuple[Dict[str, str], Dict[str, str]]:
tools_list = self.call_api("GET", "/api/v1/tools/list").get("data", {})
all_tools = [tool["name"] for tool in tools_list]
routes = {tool["name"]: tool["endpoint"] for tool in tools_list}
tools = {}
for tool_name, endpoint in routes.items():
openai_spec = self.call_api("GET", "/api/v1/tools/oai_function", params={"tool_name": tool_name}).get("data", {})
tools[tool_name] = openai_spec
return tools, routes
def call_api(self, method: str, endpoint: str, params: dict = {}, data: dict = {}, json_data: dict = {}) -> Dict[str, Any]:
"""Call the Darkstar Toolserver API with the given parameters.
Args:
method (str): The HTTP method to use for the request.
endpoint (str): The endpoint to call.
params (dict): The query parameters for the request.
data (dict): The data to send in the request body.
json_data (dict): The JSON data to send in the request body.
Returns:
Dict[str, Any]: The response from the API.
"""
url = f"{self.base_url}{endpoint}"
response = self.client.request(method, url, params=params, json=json_data, data=data)
try:
response.raise_for_status()
except httpx.HTTPStatusError as e:
print(f"HTTP error: {e}")
result = response.json()
return result
def execute_tool(self, tool_name: str, tool_args: Optional[Dict[str, Any]]) -> Dict[str, Any]:
"""
Executes a tool using the Darkstar Toolserver API and an OpenAI model.
:param tool_name: The name of the tool to execute.
:return: The result of the tool
"""
endpoint = self.available_tools[tool_name]
result = self.call_api("POST", endpoint, json_data=tool_args)
return result
def set_source(self, source: str):
self._data_sources = self.__get_data_sources()
if not source:
return
retries = 3
data_id = None
while retries > 0:
try:
data_id = self._data_sources[source]
break
except KeyError:
retries -= 1
time.sleep(1)
self._data_sources = self.__get_data_sources()
if data_id is None:
raise ValueError(f"Data source '{source}' not found.")
# get the schema
schema = self.call_api("POST", "/tool/query/get_data_schema", json_data={"data_id": data_id})
self._source = source
self._data_schema = schema
self._data_id = data_id
def __get_data_sources(self) -> Dict[str, Dict[str, str]]:
response = self.call_api("POST", "/tool/query/list_data_sources")
sources = {}
for _id, source_data in response["data"]["result"].items():
sources[source_data["file_name"]] = _id
return sources
def __create_prompt(self, user_query: str, input_name: str, output_name: str) -> List[Dict[str, str]]:
schema = self._data_schema
data_id = "No input"
if input_name:
data_id = self._data_sources[input_name]
prompt = self.tool_prompt.format(schema=schema, data_id=data_id, output_name=output_name)
messages = [
{"role": "system", "content": prompt},
{"role": "user", "content": user_query}
]
return messages
def get_tool_args(self, tool_name: str, messages: List[Dict[str, str]], output_name: str) -> Dict[str, Any]:
"""
Retrieves the required arguments for a tool from the Darkstar Toolserver API and
uses them to call an OpenAI model with predefined tools and messages.
:param tool_name: The name of the tool to execute.
:param messages: A list of messages to provide to the model.
:return: The result of the OpenAI model call.
"""
func_spec = self.tools.get(tool_name, {})
if not func_spec:
raise ValueError(f"Tool '{tool_name}' not found in available tools.")
tool = json.loads(func_spec)
# Call the OpenAI model with the tools and messages
completion = self.openai_client.chat.completions.create(
model="gpt-4-turbo",
messages=messages,
tools=[tool],
tool_choice="required"
)
predicted_args = completion.choices[0].message.tool_calls[0].function.arguments
args = json.loads(predicted_args)
if "params" in args:
params = args.get("params", [])
if params:
if isinstance(params, dict):
args["params"] = list(params.values())
elif isinstance(params, str):
args["params"] = [params]
elif isinstance(params, list):
args["params"] = params
else:
raise ValueError(f"Invalid params type: {type(params)}")
if "output_name" in args and output_name != "None":
args["output_name"] = output_name
return args
def run_tool(self, tool: ToolNode, user_query: str, **kwargs) -> Any:
"""
Executes a tool using the Darkstar Toolserver API and an OpenAI model.
"""
source = None
if tool.input_name:
source = tool.input_name
self.set_source(source)
if tool.predict_args:
messages = self.__create_prompt(user_query, source, tool.output_name)
tool_args = self.get_tool_args(tool.tool_name, messages, tool.output_name)
else:
tool_args = kwargs.get("tool_args", {})
if tool.input_name:
tool_args["data_id"] = self._data_id
if tool.output_name:
tool_args["output_name"] = tool.output_name
if tool.args:
tool_args.update(tool.args)
print("Calling tool with args:", tool_args)
result = self.execute_tool(tool.tool_name, tool_args)
return result
def get_data_object(self, data_id: int) -> Dict[str, Any]:
"""
Retrieves a data object from the Darkstar Toolserver API.
:param data_id: The ID of the data object to retrieve.
:return: The data object.
"""
return self.call_api("GET", f"/api/v1/data/object/{data_id}")["data"]["json_blob"]
class ToolFlow:
def __init__(
self,
name: str,
description: str,
base_url: str = "http://localhost:8000",
model: str = "gpt-4-turbo",
model_api_key: Optional[str] = None
):
self.name = name
self.description = description
self.runner = ToolRunner(base_url, model, model_api_key)
self.model = model
self.openai_client = openai.Client(api_key=model_api_key)
def execute_flow(self, flow_schema: Dict[str, Any], user_query: str, user_args: Dict[str, Any] = {}) -> Any:
"""
Executes the tool flow based on the provided schema. This method performs a breadth-first search (BFS)
on the graph defined by the flow schema and executes each node according to the order determined by the BFS.
Args:
flow_schema (Dict[str, Any]): The schema representing the tool flow to be executed.
user_query (str): The user's query string that may influence tool execution.
Returns:
Any: The result of executing the tool flow.
"""
# Initialize a queue for BFS
# Queue up all nodes which don't have incoming edges
incoming_edges = {node['node_id']: 0 for node in flow_schema['nodes']}
for edge in flow_schema.get('edges', []):
incoming_edges[edge['target']] += 1
execution_queue = deque([node for node in flow_schema['nodes'] if incoming_edges[node['node_id']] == 0])
visited = set()
results = {}
timings = {}
flow_start_time = time.time()
while execution_queue:
current_node = execution_queue.popleft()
node_id = current_node['node_id']
if node_id in visited:
continue
visited.add(node_id)
exec_start_time = time.time()
tool_args = {}
# Execute the current node's operation using runner.run_tool
current_tool = ToolNode(**current_node)
if current_tool.from_node:
tool_args = {}
for arg_name, from_node_id in current_tool.from_node.items():
from_node_result = results[from_node_id]["data"]["result"]
tool_args[arg_name] = from_node_result
if current_tool.allow_extra:
tool_args.update(user_args)
operation_result = self.runner.run_tool(current_tool, user_query, tool_args=tool_args)
results[node_id] = operation_result
exec_end_time = time.time()
timings[current_tool.tool_name] = exec_end_time - exec_start_time
# Enqueue all adjacent nodes
for edge in flow_schema.get('edges', []):
if edge['source'] == node_id:
target_node_id = edge['target']
target_node = next(node for node in flow_schema['nodes'] if node['node_id'] == target_node_id)
if target_node_id not in visited:
execution_queue.append(target_node)
# Assuming the last node processed is the sink node
sink_node = flow_schema['nodes'][-1]
sink_tool_name = sink_node['tool_name']
sink_node_id = sink_node['node_id']
# TODO: Tools need to specify output type
#sink_output_type = self.tools[sink_tool_name][0]
sink_output_type = OutputType(flow_schema['output_type'])
flow_end_time = time.time()
timings['total'] = flow_end_time - flow_start_time
if sink_output_type == OutputType.DATA:
data = self.runner.get_data_object(self.runner._data_id)
elif sink_output_type == OutputType.CHAT:
data = results[sink_node_id]["data"]["result"]
else:
data = results[sink_node_id]
return (data, results, sink_output_type, timings)
review_db = "/Users/spartee/Dropbox/Arcade/platform/toolserver/examples/data/food-reviews/database.sqlite"
review_flow = FlowSchema(
nodes=[
ToolNode(node_id=0, tool_name="ReadSqlite", args={"table_name": "Reviews", "file_path": review_db}, predict_args=False),
ToolNode(node_id=1, tool_name="query_sql"),
ToolNode(node_id=2, tool_name="search_text_columns"),
ToolNode(node_id=3, tool_name="Summarize", from_node={"text": 2}, predict_args=False),
],
edges=[
Edge(source=0, target=1),
Edge(source=1, target=2),
Edge(source=2, target=3)
],
output_type=OutputType.CHAT
)
plotting_flow = FlowSchema(
nodes=[
ToolNode(node_id=0, tool_name="ReadSqlite", args={"table_name": "Reviews", "file_path": review_db}, predict_args=False),
ToolNode(node_id=1, tool_name="query_sql"),
ToolNode(node_id=2, tool_name="PlotDataframe"),
],
edges=[
Edge(source=0, target=1),
Edge(source=1, target=2)
],
output_type=OutputType.ARTIFACT
)
email_flow = FlowSchema(
nodes=[
ToolNode(node_id=0, tool_name="ReadEmail"),
ToolNode(node_id=1, tool_name="Summarize", from_node={"text": 0}, predict_args=False),
],
edges=[
Edge(source=0, target=1)
],
output_type=OutputType.CHAT
)
shopify_db = "/Users/spartee/Dropbox/Arcade/platform/toolserver/examples/data/olist.sqlite"
customer_flow = FlowSchema(
nodes=[
ToolNode(node_id=0, tool_name="ReadSqlite", args={"table_name": "customers", "file_path": shopify_db}, predict_args=False),
ToolNode(node_id=1, tool_name="ReadSqlite", args={"table_name": "orders", "file_path": shopify_db}, predict_args=False),
ToolNode(node_id=2, tool_name="query_sql"),
ToolNode(node_id=3, tool_name="query_sql"),
ToolNode(node_id=4, tool_name="get"),
ToolNode(node_id=5, tool_name="get"),
ToolNode(node_id=6, tool_name="combine_results", from_node={"result_1": 4, "result_2": 5}, predict_args=False),
ToolNode(node_id=7, tool_name="Summarize", from_node={"text": 6}, predict_args=False)
],
edges=[
Edge(source=0, target=2),
Edge(source=1, target=3),
Edge(source=2, target=4),
Edge(source=3, target=5),
Edge(source=4, target=6),
Edge(source=5, target=6),
Edge(source=6, target=7)
],
output_type=OutputType.CHAT
)
audio_files = ["/Users/spartee/Desktop/notes.mp3"]
notetaker = FlowSchema(
nodes=[
ToolNode(node_id=0, tool_name="TranscribeText", predict_args=False, allow_extra=True),
ToolNode(node_id=1, tool_name="Summarize", from_node={"text": 0}, predict_args=False),
],
edges=[
Edge(source=0, target=1)
],
output_type=OutputType.CHAT
)
def print_flow_as_yaml(data: Dict[str, Any]):
data_dict = data.dict(exclude_unset=True) if isinstance(data, BaseModel) else data
# Convert the dictionary to a YAML formatted string
yaml_str = yaml.dump(data_dict, sort_keys=False)
# Print the YAML string
print(yaml_str)
#flow_schema = tf.infer_flow("Plot the users' age distribution")
#from pprint import pprint
#flow = json.loads(flow_schema)
#pprint(flow)
#result = tf.execute_flow(flow, "Plot the users' age distribution")
#print(result)

View file

@ -1,201 +0,0 @@
import openai
oai_key = "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"
import base64
import json
import logging
import subprocess
import sys
import time
import traceback
import os
from typing import Dict, Any
import networkx as nx
import matplotlib.pyplot as plt
import pandas as pd
import streamlit as st
from pydantic import BaseModel
from streamlit_chat import message
import streamlit.components.v1 as components
from textwrap import dedent
import plotly.express as px
from agent import ToolFlow, email_flow, plotting_flow, review_flow, customer_flow, notetaker
def plot_flow(data: Dict[str, Any]):
"""
Plot the flow of data using a directed graph.
Args:
data (Dict[str, Any]): A dictionary containing 'nodes' and optionally 'edges'.
"""
# Create a directed graph
G = nx.DiGraph()
# Add nodes
for node in data['nodes']:
G.add_node(node['node_id'], label=node['tool_name'])
# Add edges
if 'edges' in data:
for edge in data['edges']:
G.add_edge(edge['source'], edge['target'])
# Node labels with specific formatting
labels = {node['node_id']: f"{node['tool_name']}\n({node['input_name']} -> {node['output_name']})" for node in data['nodes']}
# Check if there are any nodes to determine a start node for bfs_layout
if G.nodes:
#start_node = next(iter(G.nodes)) # Get an arbitrary start node
#pos = nx.bfs_layout(G, start_node)
pos = nx.spring_layout(G)
else:
pos = {}
plt.figure(figsize=(7, 7))
nx.draw(G, pos, with_labels=False, node_size=3000, node_color='skyblue', font_size=9, font_weight='bold')
nx.draw_networkx_labels(G, pos, labels, font_size=8)
# Use Streamlit's function to display the plot
st.sidebar.pyplot(plt, use_container_width=True)
@st.cache_resource()
def get_agent():
AnalysisTool = ToolFlow(
name="data_analysis",
description="A tool flow for data analysis",
model_api_key=oai_key
)
return AnalysisTool
# From here down is all the StreamLit UI.
st.set_page_config(page_title="Arcade AI Demo", page_icon=":robot:", layout="wide")
dropdown_options = ["Gmailer", "PlotBot", "ReviewChat", "CustomerService", "Notetaker"]
selected_option = st.sidebar.selectbox("Select an App:", dropdown_options)
st.sidebar.write(f"Selected App: {selected_option}")
def initialize_logger():
logger = logging.getLogger("root")
handler = logging.StreamHandler(sys.stdout)
logger.setLevel(logging.INFO)
logger.handlers = [handler]
return True
if "logger" not in st.session_state:
st.session_state["logger"] = initialize_logger()
if "past" not in st.session_state:
st.session_state["past"] = []
if "generated" not in st.session_state:
st.session_state["generated"] = []
if "input" not in st.session_state:
st.session_state["input"] = ""
st.subheader("Arcade AI Agent Demo")
chat_container = st.container()
input_container = st.container()
def submit(data=None):
with st.spinner(text="Wait for Agent..."):
try:
agent = get_agent()
#flow = agent.infer_flow(submit_text)
#json_flow = json.loads(flow)
if selected_option == "Gmailer":
json_flow = email_flow.dict()
elif selected_option == "PlotBot":
json_flow = plotting_flow.dict()
elif selected_option == "ReviewChat":
json_flow = review_flow.dict()
elif selected_option == "CustomerService":
json_flow = customer_flow.dict()
elif selected_option == "Notetaker":
json_flow = notetaker.dict()
else:
st.error("Invalid option selected")
return
print(json_flow)
plot_flow(json_flow)
submit_text = st.session_state.input
st.session_state.input = ""
res = agent.execute_flow(json_flow, submit_text)
except Exception:
st.error("Error executing the flow:")
st.error(traceback.format_exc())
return
st.session_state.past.append(submit_text)
st.session_state.generated.append(res)
def run_notetaker():
with st.spinner(text="Wait for Agent..."):
try:
agent = get_agent()
json_flow = notetaker.dict()
plot_flow(json_flow)
audio_file = st.session_state.audio_file
if audio_file is None:
st.error("No audio file uploaded")
return
audio_file_byte_str = base64.b64encode(audio_file.getvalue()).decode("utf-8")
res = agent.execute_flow(json_flow, "placeholder", user_args={"audio_file": audio_file_byte_str})
except Exception:
st.error("Error executing the flow:")
st.error(traceback.format_exc())
return
st.session_state.past.append("Audio File")
st.session_state.generated.append(res)
with input_container:
if selected_option != "Notetaker":
st.text_input("You: ", key="input", on_change=submit)
else:
st.file_uploader("Upload an audio file", type=["mp3", "wav", "ogg"], key="audio_file", on_change=run_notetaker)
if st.session_state["generated"]:
with chat_container:
for i in range(
len(st.session_state["generated"])
): # range(len(st.session_state["generated"]) - 1, -1, -1):
message(st.session_state["past"][i], is_user=True, key=str(i) + "_user")
result = st.session_state["generated"][i]
result_tab, all_results_tab, times_tab = st.tabs(["Result", "All Results", "Execution Times"])
res, all_results, output_type, timings = result
with all_results_tab:
st.write(all_results)
with times_tab:
st.write(timings)
with result_tab:
output_type = output_type.value
if output_type == "artifact":
# plot the json returned in res
fig_json = res["data"]["result"]
# plot the json with ploylu atream lit
st.plotly_chart(json.loads(fig_json))
elif output_type == "chat":
st.write(res)
elif output_type == "data":
json_res = json.loads(res)["data"]
st.dataframe(json_res)
else:
st.error("Returned result:")
st.error(res)

View file

@ -1,18 +1,19 @@
import smtplib
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
import imaplib
import re
import email
from email.header import decode_header
from pydantic import BaseModel
import smtplib
import imaplib
import pandas as pd
import plotly.express as px
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.header import decode_header
from pydantic import BaseModel
from bs4 import BeautifulSoup
import re
from toolserve.sdk import Param, tool, get_secret
from toolserve.sdk.dataframe import get_df, save_df
from toolserve.sdk.client import log
@tool
@ -24,6 +25,7 @@ async def send_email(
):
"""Send an email via gmail SMTP server"""
email_address = get_secret("gmail_email")
sender_password = get_secret("gmail_password")
server = get_secret("gmail_stmp_server", "smtp.gmail.com")
port = get_secret("gmail_smtp_port", 587)
@ -37,18 +39,16 @@ async def send_email(
server = smtplib.SMTP(server, port)
server.starttls()
server.login(sender_email, sender_password)
print("Logged in to SMTP server")
log(f"Logged in to SMTP server at {':'.join((server, port))}", "DEBUG")
server.send_message(message)
server.quit()
print(f"Email sent to {recipient_email}")
log(f"Email sent from {sender_email} to {recipient_email}", "INFO")
@tool
async def read_email(
#output_name: Param(str, "Name of the output data"),
n_emails: Param(int, "Number of emails to read") = 5,
) -> Param(str, "emails"):
"""Read emails from a Gmail account and extract plain text content, removing any HTML."""
@ -90,15 +90,13 @@ async def read_email(
body = msg.get_payload(decode=True).decode('utf-8')
email_details["body"] = clean_email_body(body)
except Exception as e:
print(f"Error reading email {email_id}: {e}")
log(f"Error reading email {email_id}: {e}", "ERROR")
continue
emails.append(email_details)
mail.close()
mail.logout()
#df = pd.DataFrame(emails)
#await save_df(df, output_name)
data = "\n".join([f"{email['from']} - {email['date']}\n{email['body']}\n" for email in emails])
return data
@ -118,50 +116,4 @@ def clean_email_body(body: str) -> str:
text = re.sub(r'[^.!?a-zA-Z0-9\s]', '', text) # Remove non-sentence characters
text = ' '.join(text.split()) # Remove extra whitespace
return text
@tool
async def plot_dataframe(
data_id: Param(int, "Data ID of the dataframe"),
x: Param(str, "Column to use as x-axis"),
y: Param(str, "Column to use as y-axis"),
kind: Param(str, "Type of plot") = "line",
title: Param(str, "Title of the plot") = "Plot",
xlabel: Param(str, "Label for x-axis") = "X",
ylabel: Param(str, "Label for y-axis") = "Y",
) -> Param(str, "JSON representation of the plot"):
"""
Asynchronously generates a plot from a dataframe using Plotly and returns the plot as a JSON string.
Args:
data_id (int): The ID of the dataframe to plot.
x (str): The column name to use as the x-axis.
y (str): The column name to use as the y-axis.
kind (str): The type of plot to generate (e.g., 'line', 'scatter', 'bar').
title (str): The title of the plot.
xlabel (str): The label for the x-axis.
ylabel (str): The label for the y-axis.
Returns:
str: The JSON representation of the plot.
"""
import plotly.express as px
df = await get_df(data_id)
if kind == 'line':
fig = px.line(df, x=x, y=y, title=title)
elif kind == 'scatter':
fig = px.scatter(df, x=x, y=y, title=title)
elif kind == 'bar':
fig = px.bar(df, x=x, y=y, title=title)
elif kind == "histogram":
fig = px.histogram(df, x=x, title=title)
else:
raise ValueError(f"Unsupported plot type: {kind}")
fig.update_layout(xaxis_title=xlabel, yaxis_title=ylabel)
return fig.to_json()
return text

View file

@ -0,0 +1,119 @@
import re
import email
import smtplib
import imaplib
import pandas as pd
import plotly.express as px
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.header import decode_header
from pydantic import BaseModel
from bs4 import BeautifulSoup
from toolserve.sdk import Param, tool, get_secret
from toolserve.sdk.client import log
@tool
async def send_email(
sender_email: Param(str, "Email address of the sender"),
recipient_email: Param(str, "Email address of the recipient"),
subject: Param(str, "Subject of the email"),
body: Param(str, "Body of the email"),
):
"""Send an email via gmail SMTP server"""
email_address = get_secret("gmail_email")
sender_password = get_secret("gmail_password")
server = get_secret("gmail_stmp_server", "smtp.gmail.com")
port = get_secret("gmail_smtp_port", 587)
message = MIMEMultipart()
message['From'] = sender_email
message['To'] = recipient_email
message['Subject'] = subject
message.attach(MIMEText(body, 'plain'))
server = smtplib.SMTP(server, port)
server.starttls()
server.login(sender_email, sender_password)
log(f"Logged in to SMTP server at {':'.join((server, port))}", "DEBUG")
server.send_message(message)
server.quit()
log(f"Email sent from {sender_email} to {recipient_email}", "INFO")
@tool
async def read_email(
n_emails: Param(int, "Number of emails to read") = 5,
) -> Param(str, "emails"):
"""Read emails from a Gmail account and extract plain text content, removing any HTML."""
email_address = get_secret("gmail_email")
password = get_secret("gmail_password")
server = get_secret("gmail_stmp_server", "smtp.gmail.com")
port = get_secret("gmail_smtp_port", 587)
# Connect to the Gmail IMAP server
mail = imaplib.IMAP4_SSL(server)
mail.login(email_address, password)
mail.select("inbox") # connect to inbox.
result, data = mail.search(None, "ALL")
email_ids = data[0].split()
email_ids.reverse() # Reverse to get the most recent emails first
emails = []
for email_id in email_ids[:n_emails]:
try:
result, data = mail.fetch(email_id, "(RFC822)")
raw_email = data[0][1]
msg = email.message_from_bytes(raw_email)
email_details = {
"from": msg["From"],
"to": msg["To"],
"date": msg["Date"]
}
if msg.is_multipart():
for part in msg.walk():
if part.get_content_type() == "text/plain":
body = part.get_payload(decode=True).decode('utf-8')
email_details["body"] = clean_email_body(body)
else:
body = msg.get_payload(decode=True).decode('utf-8')
email_details["body"] = clean_email_body(body)
except Exception as e:
log(f"Error reading email {email_id}: {e}", "ERROR")
continue
emails.append(email_details)
mail.close()
mail.logout()
data = "\n".join([f"{email['from']} - {email['date']}\n{email['body']}\n" for email in emails])
return data
def clean_email_body(body: str) -> str:
"""Remove HTML tags and non-sentence elements from email body text."""
# Remove HTML tags using BeautifulSoup
soup = BeautifulSoup(body, "html.parser")
text = soup.get_text(separator=' ')
# Remove any non-sentence elements (e.g., URLs, email addresses, etc.)
text = re.sub(r'\S*@\S*\s?', '', text) # Remove emails
text = re.sub(r'http\S+', '', text) # Remove URLs
text = re.sub(r'[^.!?a-zA-Z0-9\s]', '', text) # Remove non-sentence characters
text = ' '.join(text.split()) # Remove extra whitespace
return text

83
examples/tools/llm.py Normal file
View file

@ -0,0 +1,83 @@
from typing import (
IO,
Union,
List,
Dict,
Optional,
Any,
Type,
)
import io
import requests
from os import PathLike
import base64
from toolserve.sdk import Param, tool, get_secret
from typing import List
import pandas as pd
import openai
@tool
async def summarize(
text: Param(str, "Text to summarize"),
system_prompt: Param(str, "System prompt to use") = "Summarize the following text",
max_tokens: Param(int, "Maximum number of tokens to generate") = 1000,
) -> Param(str, "Summarized text"):
"""Summarize a piece of text using OpenAI Language models."""
api_key = get_secret("openai_api_key", None)
model = get_secret("openai_model_summarize", "gpt-4-turbo")
# Call the OpenAI model with the tools and messages
if isinstance(text, list):
text = "\n".join(text)
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": text},
]
client = openai.AsyncClient(api_key=api_key)
completion = await openai.chat.completions.create(
model=model,
messages=messages,
)
summary = completion.choices[0].message.content
return summary
@tool
async def respond(
context: Param(str, "context of the conversation"),
system_prompt: Param(str, "System prompt to use") = "Given the following context, respond with a message in a friendly and helpful manner. Be informal and use a casual tone.",
max_tokens: Param(int, "Maximum number of tokens to generate") = 1000,
) -> Param(str, "The response to the context provided"):
"""Respond to a user given context using OpenAI Language models"""
api_key = get_secret("openai_api_key", None)
model = get_secret("openai_model_summarize", "gpt-4-turbo")
# Call the OpenAI model with the tools and messages
if isinstance(context, list):
context = "\n".join(context)
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": context},
]
client = openai.AsyncClient(api_key=api_key)
completion = await openai.chat.completions.create(
model=model,
messages=messages,
)
response = completion.choices[0].message.content
return response

35
examples/tools/search.py Normal file
View file

@ -0,0 +1,35 @@
import asyncio
from serpapi import GoogleSearch
from typing import List, Dict
import json
from toolserve.sdk import Param, tool, get_secret
async def google_search(
query: Param(str, "search query for google"),
num_results: Param(int, "number of results")
) -> Param(str, "Json blob of Search results"):
"""
Perform a Google search using SerpAPI and retrieve a specified number of results.
Args:
query (str): The search query.
num_results (int): The number of search results to retrieve.
Returns:
List[Dict[str, str]]: A list of dictionaries containing the link and text of each result.
"""
serpapi_key = get_secret("serp_api_key", None)
params = {
"engine": "google",
"q": query,
"num": num_results,
"api_key": serpapi_key
}
search = GoogleSearch(params)
results = search.get_dict()
json_results = json.dumps(results.get("organic_results"), indent=2)
return json_results

0
examples/tools/sql.py Normal file
View file

View file

@ -1,107 +0,0 @@
from typing import List, Dict, Any
from toolserve.sdk.dataframe import get_df, save_df
from toolserve.sdk.tool import tool, Param
@tool
async def get(
data_id: Param(int, "ID of the data")
) -> Param(str, "data"):
"""Get data by ID"""
df = await get_df(data_id)
return df.to_json(orient='records')
@tool
async def select_columns(
data_id: Param(int, "ID of the data"),
columns: Param(List[str], "Columns to select")
) -> Param(str, "data"):
"""Select columns from a DataFrame"""
df = await get_df(data_id)
df = df[columns]
return df.to_json(orient='records')
@tool
async def filter_rows(
data_id: Param(int, "ID of the data"),
column: Param(str, "Column to filter"),
value: Param(str, "Value to filter by")
) -> Param(str, "data"):
"""Filter rows in a DataFrame"""
df = await get_df(data_id)
df = df[df[column] == value]
return df.to_json(orient='records')
@tool
async def sort(
data_id: Param(int, "ID of the data"),
column: Param(str, "Column to sort by"),
ascending: Param(bool, "Sort ascending or descending") = True
) -> Param(str, "data"):
"""Sort a DataFrame by a column"""
df = await get_df(data_id)
df = df.sort_values(by=column, ascending=ascending)
return df.to_json(orient='records')
@tool
async def group_by(
data_id: Param(int, "ID of the data"),
columns: Param(List[str], "Columns to group by"),
aggregations: Param(Dict[str, str], "Aggregations to perform")
) -> Param(str, "data"):
"""Group by columns and perform aggregations"""
df = await get_df(data_id)
df = df.groupby(columns).agg(aggregations)
return df.to_json(orient='records')
@tool
async def join(
data_id1: Param(int, "ID of the first data"),
data_id2: Param(int, "ID of the second data"),
on: Param(str, "Column to join on"),
how: Param(str, "Type of join") = "inner"
) -> Param(str, "data"):
"""Join two DataFrames"""
df1 = await get_df(data_id1)
df2 = await get_df(data_id2)
df = df1.merge(df2, on=on, how=how)
return df.to_json(orient='records')
@tool
async def search_text_columns(
data_id: Param(int, "ID of the data"),
query: Param(str, "Text to search for"),
column: Param(str, "Column to search in"),
max_rows: Param(int, "Maximum number of rows to return") = 50
) -> Param(str, "data"):
"""Search text in columns
Search for a text query in a specific column of a DataFrame.
Args:
data_id (int): The ID of the data source to search in.
query (str): The text to search for.
column (str): The column to search in.
Returns:
str: The data source after filtering for the text query, limited to a maximum number of rows.
"""
df = await get_df(data_id)
# Ensure the column data is treated as string
df[column] = df[column].astype(str)
# Use regex=False to treat the query as a literal string, avoiding any regex special character issues
mask = df[column].str.contains(query, case=False, na=False, regex=False)
df = df[mask]
# Limit the number of rows returned
df = df.head(max_rows)
return df.to_json(orient='records')
@tool
def combine_results(
result_1: Param(str, "First result"),
result_2: Param(str, "Second result")
) -> Param(str, "data"):
"""Combine two results"""
return str(result_1) + str(result_2)

View file

@ -1,143 +0,0 @@
from typing import Any, Dict, Optional, Union, List
import io
from toolserve.sdk.client import list_data, log
from toolserve.sdk.dataframe import get_df, save_df
from toolserve.sdk.tool import tool, Param
import duckdb
import pandas as pd
@tool
async def list_data_sources() -> Dict[str, Dict[str, str]]:
"""List all data sources.
Returns:
Dict[str, str]: A dictionary mapping data source IDs to their details.
"""
data = await list_data()
partial = {}
for item in data:
details = {
"file_name": item["file_name"],
"created_at": item["created_time"]
}
if "updated_time" in item and item["updated_time"] is not None:
details["updated_at"] = item["updated_time"]
partial[str(item["id"])] = details
return partial
@tool
async def get_data_schema(
data_id: Param(int, "id of the data source"),
) -> Param(str, "schema of the data source"):
"""Get the schema of the data source by id.
Args:
data_id (int): The id of the data source to get the schema of.
Returns:
str: The schema of the data source.
"""
# TODO read in only a few lines
df = await get_df(data_id)
return get_df_info(df)["schema"]
@tool
async def query_sql(
data_id: Param(int, "id of the data source"),
sql: Param(str, "parameterized SQL query to execute"),
output_name: Param(str, "name of the output data to save"),
params: Param(Optional[List[Union[str, int]]], "parameters to pass to the SQL query") = None,
) -> Dict[str, Union[int, str]]:
"""Query a data source using SQL
The SQL query should be parameterized with DuckDB's syntax. For example, to query a
DataFrame named `df` with a parameter `param`, the query should be `SELECT * FROM df WHERE ? = ?`.
The list of params should be in order of the parameters in the SQL query.
IMPORTANT: There should be no parameters in the query.
For example: `SELECT * FROM df WHERE name = ?` should be `SELECT * FROM df WHERE ? = ?`.
After the query, a new data source at a new id will be created with the results and
the schema of the data source will be returned.
Args:
data_id (int): The id of the data source to query.
sql (str): The parameterized SQL query to execute.
output_name (str): The name of the output data to save.
params (Optional[Dict[str, Any]]): Parameters to pass to the SQL query.
Returns:
str: The schema of the data source after executing the query.
"""
try:
# Retrieve the DataFrame and execute the SQL query using DuckDB
df = await get_df(data_id)
con = duckdb.connect(database=':memory:', read_only=False)
con.register('df_table', df)
if params:
result_df = con.execute(sql, params).fetchdf()
else:
result_df = con.execute(sql).fetchdf()
# Save the resulting DataFrame and create a new data source
result = await save_df(result_df, output_name)
result_id = result["id"]
# Retrieve and return the schema of the new data source
return get_df_info(result_df, data_id=result_id)
except Exception as e:
# Log the error and raise an exception
log_message = f"Failed to execute query: {str(e)}."
log_message += f" -- SQL: {sql}"
log_message += f" -- Parameters: {params}"
await log(log_message, level="ERROR")
raise RuntimeError(f"Query execution failed: {str(e)}")
def get_df_info(df: pd.DataFrame, data_id: Optional[int]=None) -> Dict[str, Union[int, str]]:
"""
Generate a compact string representation of a DataFrame including the count of columns,
rows, overall size, and details for each column such as name and datatype.
Parameters:
df (pd.DataFrame): The Pandas DataFrame to describe.
Returns:
Dict[str, Union[int, str]]: A dictionary containing the DataFrame details and data_id
"""
# Create an output stream to collect strings
output = io.StringIO()
# Write general information about the DataFrame
if data_id:
output.write(f"Result Data ID: {data_id}\n")
output.write("Table Name: df\n")
output.write(f"Columns: {len(df.columns)}\n")
output.write(f"Rows: {len(df.index)}\n")
output.write(f"Size: {df.memory_usage(deep=True).sum()} bytes\n")
# Iterate through each column to get details
for column in df.columns:
output.write("---\n")
output.write(f"Column: {column}\n")
output.write(f"type: {df[column].dtype}\n")
# put top 5 rows in the output if there are more than 5 rows.
if len(df.index) > 5:
output.write("---\n")
output.write("Top 5 rows:\n")
output.write(df.head().to_string())
# Get the complete string from the output stream
result = output.getvalue()
output.close()
info = {
"schema": result
}
if data_id:
info["data_id"] = data_id
return info

View file

@ -41,7 +41,7 @@ class ToolSchema(BaseModel):
class ToolCatalog:
def __init__(self, tools_dir: str = settings.TOOLS_DIR):
self.tools = self.read_tools(tools_dir)
self.tools.update(self.__get_builitin_tools())
#self.tools.update(self.__get_builitin_tools())
@staticmethod
def read_tools(directory: str) -> List[ToolSchema]:

View file

@ -50,7 +50,7 @@ class Settings(BaseSettings):
OPERA_LOG_ENCRYPT_SECRET_KEY: str # 密钥 os.urandom(32), 需使用 bytes.hex() 方法转换为 str
# FastAPI
API_V1_STR: str = '/api/v1'
API_V1_STR: str = '/v1'
API_ACTION_STR: str = '/tool'
TITLE: str = 'Arcade AI Toolserver'
VERSION: str = '0.1.0'

View file

@ -5,6 +5,7 @@ from toolserve.server.routes.data import router as data_router
from toolserve.server.routes.artifact import router as artifact_router
from toolserve.server.routes.log import router as log_router
from toolserve.server.routes.slack import router as slack_router
from toolserve.server.routes.chat import router as chat_router
v1 = APIRouter(prefix=settings.API_V1_STR)
v1.include_router(tool_router, prefix="/tools", tags=["Tool Catalog"])
@ -12,4 +13,4 @@ v1.include_router(data_router, prefix="/data", tags=["Data Management"])
v1.include_router(artifact_router, prefix="/artifact", tags=["Artifact Management"])
v1.include_router(log_router, prefix="/log", tags=["Tool Logging API"])
v1.include_router(slack_router, prefix="/slack", tags=["Slack"])
v1.include_router(chat_router, prefix="/chat", tags=["Chat"])

View file

@ -0,0 +1,124 @@
from typing import Annotated
from fastapi import APIRouter, Path, Query
from fastapi.responses import StreamingResponse
from toolserve.server.common.response import ResponseModel, response_base
from toolserve.server.common.serializers import select_as_dict
# to take out later
import openai
import json
from pydantic import BaseModel, Field
from typing import List, Optional, Union, Literal, Iterable
from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from openai.types.chat.chat_completion_stream_options_param import ChatCompletionStreamOptionsParam
from openai.types.chat.chat_completion_tool_choice_option_param import ChatCompletionToolChoiceOptionParam
from openai.types.chat.chat_completion_function_call_option_param import ChatCompletionFunctionCallOptionParam
from openai.types import shared_params
from openai.types.chat_model import ChatModel
from fastapi import Request, HTTPException, status, Depends
from toolserve.server.core.depends import get_catalog
from toolserve.utils.openai_tool import schema_to_openai_tool
router = APIRouter()
class FunctionCall(BaseModel):
type: Literal["none", "auto", "function"]
function: Optional[ChatCompletionFunctionCallOptionParam]
class Function(BaseModel):
name: str
description: Optional[str]
parameters: Optional[shared_params.FunctionParameters]
class ResponseFormat(BaseModel):
type: Literal["text", "json_object"]
class CompletionCreateParamsBase(BaseModel):
messages: List[ChatCompletionMessageParam]
model: Union[str, ChatModel]
frequency_penalty: Optional[float] = None
#function_call: Optional[FunctionCall] = None
#functions: Optional[List[Function]] = None
logit_bias: Optional[dict[str, int]] = None
logprobs: Optional[bool] = None
max_tokens: Optional[int] = None
n: Optional[int] = None
presence_penalty: Optional[float] = None
response_format: Optional[ResponseFormat] = None
seed: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
stream_options: Optional[ChatCompletionStreamOptionsParam] = None
temperature: Optional[float] = None
tool_choice: Optional[ChatCompletionToolChoiceOptionParam] = None
tools: Optional[Union[List[ChatCompletionToolParam], List[str]]] = None
top_logprobs: Optional[int] = None
top_p: Optional[float] = None
user: Optional[str] = None
class CompletionCreateParamsNonStreaming(CompletionCreateParamsBase):
stream: Literal[False]
class CompletionCreateParamsStreaming(CompletionCreateParamsBase):
stream: Literal[True]
CompletionCreateParams = Union[CompletionCreateParamsNonStreaming, CompletionCreateParamsStreaming]
def get_openai_key(request: Request) -> str:
"""
Extracts the API key from the Authorization header as a Bearer token.
Args:
request (Request): The request object from which the API key is extracted.
Returns:
str: The API key extracted from the Authorization header.
Raises:
HTTPException: If the Authorization header is missing or improperly formatted.
"""
auth_header = request.headers.get('Authorization')
if auth_header is None or not auth_header.startswith('Bearer '):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail='Authorization token is missing or improperly formatted'
)
api_key = auth_header.split(' ')[1]
return api_key
@router.post(
'/completions',
summary='Chat Completions Endpoints mimicking OpenAI'
)
async def create_chat_completion(
completion: CompletionCreateParams,
api_key: str = Depends(get_openai_key),
catalog=Depends(get_catalog)
):
"""
Create a chat completion
"""
try:
oai_client = openai.AsyncOpenAI(api_key=api_key)
if completion.tools:
if isinstance(completion.tools[0], str):
specs = []
for tool in completion.tools:
specs.append(json.loads(schema_to_openai_tool(catalog[tool])))
completion.tool_choice = "required"
completion.tools = specs
result = await oai_client.chat.completions.create(**completion.dict())
return result
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))

View file

@ -0,0 +1,27 @@
from openai import AsyncOpenAI
api_key = "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"
client = AsyncOpenAI(api_key=api_key, base_url="http://localhost:8000/v1")
# Using 'async' with 'await' for proper asynchronous call
async def get_chat_response():
response = await client.chat.completions.create(
messages=[
{"role": "system", "content": "You are a friendly assistant named Jarvis. Help with whatever you can."},
{"role": "user", "content": "Hey there! What's your name?"},
],
model="gpt-4-turbo",
tools=["ReadEmail"],
stream=False
)
return response
async def print_chat_responses():
response = await get_chat_response()
print(response.choices[0].message)
import asyncio
asyncio.run(print_chat_responses())