multiple branch example

This commit is contained in:
Sam Partee 2024-05-14 21:47:10 -07:00
parent 2e4542c260
commit 16c8156f98
5 changed files with 208 additions and 91 deletions

View file

@ -13,4 +13,5 @@ ReadEmail = "gmailer.read_email@0.1.0"
PlotDataframe = "gmailer.plot_dataframe@0.1.0"
ReadSqlite = "read_sqlite.read_sqlite@0.1.0"
Summarize = "chat.summarize@0.1.0"
TranscribeText = "chat.transcribe_text@0.1.0"
VectorSearch = "search.vector_search@0.1.0"

View file

@ -1,10 +1,27 @@
from typing import (
IO,
Union,
List,
Dict,
Optional,
Any,
Type,
)
import io
import requests
from os import PathLike
import base64
from toolserve.sdk import Param, tool, get_secret
from toolserve.sdk.dataframe import get_df
from typing import List
import pandas as pd
import openai
@tool
async def summarize(
text: Param(str, "Text to summarize"),
@ -42,3 +59,62 @@ async def summarize(
)
summary = completion.choices[0].message.content
return summary
@tool
async def transcribe_text(
audio_file: Param(str, "Audio file bytes"),
system_prompt: Param(str, "System prompt to use") = "Transcribe the following audio files",
) -> Param(str, "Transcribed text"):
"""Use OpenAI to translate audio to text using the Whisper model.
Args:
audio_file_bytes (str): The bytes of the audio file to transcribe.
system_prompt (str): The system prompt to use for guiding the transcription.
Returns:
str: The transcribed text.
"""
api_key = get_secret("openai_api_key", None)
model = get_secret("openai_model_whisper", "whisper-1")
if audio_file is None:
raise ValueError("No audio file provided")
# Decode the base64 audio file
audio_file_bytes = base64.b64decode(audio_file)
file = io.BytesIO(audio_file_bytes)
# Prepare the headers
headers = {
'Authorization': f'Bearer {api_key}',
}
# Prepare the files parameter
files = {
'file': ('audio.mp3', file, 'audio/mp3')
}
# Prepare the data parameter
data = {
'model': model,
'prompt': system_prompt,
'response_format': 'text'
}
# Send the request to the OpenAI Whisper API
response = requests.post(
'https://api.openai.com/v1/audio/transcriptions',
headers=headers,
files=files,
data=data
)
# Check if the request was successful
if response.status_code == 200:
# Return the plain text response directly
return response.text
else:
# Handle errors
raise Exception(f"Error: {response.status_code} - {response.text}")

View file

@ -2,6 +2,7 @@ import httpx
import json
import time
import openai
import uuid
from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel
@ -41,6 +42,7 @@ def pydantic_to_openai_tool(model: Type[BaseModel]) -> str:
class Edge(BaseModel):
source: int = Field(..., description="The ID of the source node")
target: int = Field(..., description="The ID of the target node")
uuid: str = Field(default_factory=lambda: str(uuid.uuid4()), description="UUID for the data flow between nodes")
class ToolNode(BaseModel):
node_id: int = Field(..., description="The ID of the node", ge=0)
@ -50,7 +52,7 @@ class ToolNode(BaseModel):
predict_args: bool = Field(True, description="Whether to predict the arguments for the tool")
from_node: Optional[Dict[str, int]] = Field(None, description="The ID of the source node name of the argument to pass to the tool")
args: Optional[Dict[str, Any]] = Field(None, description="The arguments to pass to the tool")
allow_extra: bool = Field(False, description="Whether to allow extra arguments to be passed to the tool")
class OutputType(Enum):
DATA = "data"
@ -64,10 +66,37 @@ class FlowSchema(BaseModel):
edges: List[Edge] = Field([], description="The IDs of the adjacent nodes")
output_type: OutputType = Field(OutputType.CHAT, description="The type of the output")
def __init__(self, **data):
super().__init__(**data)
self.generate_uuids_for_edges()
class Config:
arbitrary_types_allowed = True
use_enum_values = True
def generate_uuids_for_edges(self):
edge_map = {}
for edge in self.edges:
edge_map[(edge.source, edge.target)] = edge.uuid
for node in self.nodes:
incoming_edges = [e.uuid for e in self.edges if e.target == node.node_id]
outgoing_edges = [e.uuid for e in self.edges if e.source == node.node_id]
if node.from_node:
node.input_name = None
node.output_name = None
# Set the output of the source node and the input of the target node to None
for edge in self.edges:
if edge.target == node.node_id:
source_node = next((n for n in self.nodes if n.node_id == edge.source), None)
if source_node:
source_node.output_name = None
if edge.source == node.node_id:
target_node = next((n for n in self.nodes if n.node_id == edge.target), None)
if target_node:
target_node.input_name = None
else:
node.input_name = incoming_edges[0] if incoming_edges else None
node.output_name = outgoing_edges[0] if outgoing_edges else None
class ToolClient:
@ -259,20 +288,19 @@ class ToolRunner:
if tool.predict_args:
messages = self.__create_prompt(user_query, source, tool.output_name)
tool_args = self.get_tool_args(tool.tool_name, messages, tool.output_name)
elif tool.from_node:
# todo change to list
tool_args = kwargs.get("tool_args", {})
else:
tool_args = {}
tool_args = kwargs.get("tool_args", {})
# TODO would something ever have an input_name and not need a data_id?
if tool.input_name:
tool_args["data_id"] = self._data_id
if tool.output_name:
tool_args["output_name"] = tool.output_name
if tool.args:
tool_args.update(tool.args)
print("Calling tool with args:", tool_args)
result = self._client.execute_tool(tool.tool_name, tool_args)
return result
@ -295,20 +323,19 @@ class ToolFlow:
self,
name: str,
description: str,
prompt: str,
base_url: str = "http://localhost:8000",
model: str = "gpt-4-turbo",
model_api_key: Optional[str] = None
):
self.name = name
self.description = description
self.prompt = prompt
self.runner = ToolRunner(base_url, model, model_api_key)
self.model = model
self.openai_client = openai.Client(api_key=model_api_key)
def execute_flow(self, flow_schema: Dict[str, Any], user_query: str) -> Any:
def execute_flow(self, flow_schema: Dict[str, Any], user_query: str, user_args: Dict[str, Any] = {}) -> Any:
"""
Executes the tool flow based on the provided schema. This method performs a breadth-first search (BFS)
on the graph defined by the flow schema and executes each node according to the order determined by the BFS.
@ -343,6 +370,8 @@ class ToolFlow:
visited.add(node_id)
exec_start_time = time.time()
tool_args = {}
# Execute the current node's operation using runner.run_tool
current_tool = ToolNode(**current_node)
if current_tool.from_node:
@ -350,10 +379,9 @@ class ToolFlow:
for arg_name, from_node_id in current_tool.from_node.items():
from_node_result = results[from_node_id]["data"]["result"]
tool_args[arg_name] = from_node_result
operation_result = self.runner.run_tool(current_tool, user_query, tool_args=tool_args)
else:
operation_result = self.runner.run_tool(current_tool, user_query)
if current_tool.allow_extra:
tool_args.update(user_args)
operation_result = self.runner.run_tool(current_tool, user_query, tool_args=tool_args)
results[node_id] = operation_result
exec_end_time = time.time()
@ -388,48 +416,12 @@ class ToolFlow:
return (data, results, sink_output_type, timings)
plotting_flow = FlowSchema(
nodes=[
ToolNode(node_id=0, input_name="products", tool_name="query_sql", output_name="product_data"),
ToolNode(node_id=1, input_name="product_data", tool_name="PlotDataframe", output_name=None),
],
edges=[
Edge(source=0, target=1)
],
output_type=OutputType.ARTIFACT
)
email_flow_1 = FlowSchema(
nodes=[
ToolNode(node_id=0, input_name=None, tool_name="ReadEmail", output_name="email_data_1"),
ToolNode(node_id=1, input_name="email_data_1", tool_name="Summarize", output_name=None),
],
edges=[
Edge(source=0, target=1)
],
output_type=OutputType.CHAT
)
email_flow = FlowSchema(
nodes=[
ToolNode(node_id=0, tool_name="ReadEmail"),
ToolNode(node_id=1, tool_name="Summarize", from_node={"text": 2}, predict_args=False),
],
edges=[
Edge(source=0, target=1)
],
output_type=OutputType.CHAT
)
review_db = "/Users/spartee/Dropbox/Arcade/platform/toolserver/examples/data/food-reviews/database.sqlite"
review_flow = FlowSchema(
nodes=[
ToolNode(node_id=0, tool_name="ReadSqlite", args={"table_name": "Reviews", "file_path": review_db, "output_name": "reviews"}, predict_args=False),
ToolNode(node_id=1, input_name="reviews", tool_name="query_sql", output_name="review_data"),
ToolNode(node_id=2, input_name="review_data", tool_name="search_text_columns"),
ToolNode(node_id=0, tool_name="ReadSqlite", args={"table_name": "Reviews", "file_path": review_db}, predict_args=False),
ToolNode(node_id=1, tool_name="query_sql"),
ToolNode(node_id=2, tool_name="search_text_columns"),
ToolNode(node_id=3, tool_name="Summarize", from_node={"text": 2}, predict_args=False),
],
edges=[
@ -440,16 +432,42 @@ review_flow = FlowSchema(
output_type=OutputType.CHAT
)
plotting_flow = FlowSchema(
nodes=[
ToolNode(node_id=0, tool_name="ReadSqlite", args={"table_name": "Reviews", "file_path": review_db}, predict_args=False),
ToolNode(node_id=1, tool_name="query_sql"),
ToolNode(node_id=2, tool_name="PlotDataframe"),
],
edges=[
Edge(source=0, target=1),
Edge(source=1, target=2)
],
output_type=OutputType.ARTIFACT
)
email_flow = FlowSchema(
nodes=[
ToolNode(node_id=0, tool_name="ReadEmail"),
ToolNode(node_id=1, tool_name="Summarize", from_node={"text": 0}, predict_args=False),
],
edges=[
Edge(source=0, target=1)
],
output_type=OutputType.CHAT
)
shopify_db = "/Users/spartee/Dropbox/Arcade/platform/toolserver/examples/data/olist.sqlite"
customer_flow = FlowSchema(
nodes=[
ToolNode(node_id=0, tool_name="ReadSqlite", args={"table_name": "customers", "file_path": shopify_db, "output_name": "customers"}, predict_args=False),
ToolNode(node_id=1, tool_name="ReadSqlite", args={"table_name": "orders", "file_path": shopify_db, "output_name": "all_customer_orders"}, predict_args=False),
ToolNode(node_id=2, input_name="customers", tool_name="query_sql", output_name="customer_data"),
ToolNode(node_id=3, input_name="all_customer_orders", tool_name="query_sql", output_name="customer_orders"),
ToolNode(node_id=4, input_name="customer_data", tool_name="get"),
ToolNode(node_id=5, input_name="customer_orders", tool_name="get"),
ToolNode(node_id=0, tool_name="ReadSqlite", args={"table_name": "customers", "file_path": shopify_db}, predict_args=False),
ToolNode(node_id=1, tool_name="ReadSqlite", args={"table_name": "orders", "file_path": shopify_db}, predict_args=False),
ToolNode(node_id=2, tool_name="query_sql"),
ToolNode(node_id=3, tool_name="query_sql"),
ToolNode(node_id=4, tool_name="get"),
ToolNode(node_id=5, tool_name="get"),
ToolNode(node_id=6, tool_name="combine_results", from_node={"result_1": 4, "result_2": 5}, predict_args=False),
ToolNode(node_id=7, tool_name="Summarize", from_node={"text": 6}, predict_args=False)
],
@ -466,6 +484,18 @@ customer_flow = FlowSchema(
)
audio_files = ["/Users/spartee/Desktop/notes.mp3"]
notetaker = FlowSchema(
nodes=[
ToolNode(node_id=0, tool_name="TranscribeText", predict_args=False, allow_extra=True),
ToolNode(node_id=1, tool_name="Summarize", from_node={"text": 0}, predict_args=False),
],
edges=[
Edge(source=0, target=1)
],
output_type=OutputType.CHAT
)
def print_flow_as_yaml(data: Dict[str, Any]):

View file

@ -2,7 +2,7 @@ import openai
oai_key = "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"
import base64
import json
import logging
import subprocess
@ -10,7 +10,6 @@ import sys
import time
import traceback
import os
from typing import Dict, Any
import networkx as nx
import matplotlib.pyplot as plt
@ -21,28 +20,10 @@ from streamlit_chat import message
import streamlit.components.v1 as components
from textwrap import dedent
import plotly.express as px
from agent import ToolFlow, email_flow, plotting_flow, review_flow, customer_flow
from agent import ToolFlow, email_flow, plotting_flow, review_flow, customer_flow, notetaker
PROMPT = dedent("""Given a user query, construct a graph based representation of functions (nodes), and their data flow (edges) such that
the graph can be executed to supply the user query enough information to answer their query.
You must construct the graph with the following constraints:
- There can only be 1 source node and 1 sink node.
- There should be no leaf nodes besides the sink node.
- The source and sink can be the same node.
Only use the available nodes and their output types as edges. Create unique ids for each node starting from 0.
The available nodes are:
{nodes}
The available input names for the source are:
{sources}
""")
def plot_flow(data: Dict[str, Any]):
"""
Plot the flow of data using a directed graph.
@ -86,7 +67,6 @@ def get_agent():
AnalysisTool = ToolFlow(
name="data_analysis",
description="A tool flow for data analysis",
prompt=PROMPT,
model_api_key=oai_key
)
return AnalysisTool
@ -95,7 +75,7 @@ def get_agent():
# From here down is all the StreamLit UI.
st.set_page_config(page_title="Arcade AI Demo", page_icon=":robot:", layout="wide")
dropdown_options = ["Gmailer", "PlotBot", "ReviewChat", "CustomerService"]
dropdown_options = ["Gmailer", "PlotBot", "ReviewChat", "CustomerService", "Notetaker"]
selected_option = st.sidebar.selectbox("Select an App:", dropdown_options)
st.sidebar.write(f"Selected App: {selected_option}")
@ -112,6 +92,10 @@ if "past" not in st.session_state:
st.session_state["past"] = []
if "generated" not in st.session_state:
st.session_state["generated"] = []
if "input" not in st.session_state:
st.session_state["input"] = ""
@ -121,9 +105,8 @@ st.subheader("Arcade AI Agent Demo")
chat_container = st.container()
input_container = st.container()
def submit():
submit_text = st.session_state["input"]
st.session_state["input"] = ""
def submit(data=None):
with st.spinner(text="Wait for Agent..."):
try:
agent = get_agent()
@ -137,11 +120,15 @@ def submit():
json_flow = review_flow.dict()
elif selected_option == "CustomerService":
json_flow = customer_flow.dict()
elif selected_option == "Notetaker":
json_flow = notetaker.dict()
else:
st.error("Invalid option selected")
return
print(json_flow)
plot_flow(json_flow)
submit_text = st.session_state.input
st.session_state.input = ""
res = agent.execute_flow(json_flow, submit_text)
except Exception:
st.error("Error executing the flow:")
@ -150,12 +137,35 @@ def submit():
st.session_state.past.append(submit_text)
st.session_state.generated.append(res)
def get_text():
input_text = st.text_input("You: ", key="input", on_change=submit)
return input_text
def run_notetaker():
with st.spinner(text="Wait for Agent..."):
try:
agent = get_agent()
json_flow = notetaker.dict()
plot_flow(json_flow)
audio_file = st.session_state.audio_file
if audio_file is None:
st.error("No audio file uploaded")
return
audio_file_byte_str = base64.b64encode(audio_file.getvalue()).decode("utf-8")
res = agent.execute_flow(json_flow, "placeholder", user_args={"audio_file": audio_file_byte_str})
except Exception:
st.error("Error executing the flow:")
st.error(traceback.format_exc())
return
st.session_state.past.append("Audio File")
st.session_state.generated.append(res)
with input_container:
user_input = get_text()
if selected_option != "Notetaker":
st.text_input("You: ", key="input", on_change=submit)
else:
st.file_uploader("Upload an audio file", type=["mp3", "wav", "ogg"], key="audio_file", on_change=run_notetaker)
if st.session_state["generated"]:
with chat_container:

View file

@ -4,12 +4,12 @@ from toolserve.server.routes.tool import router as tool_router
from toolserve.server.routes.data import router as data_router
from toolserve.server.routes.artifact import router as artifact_router
from toolserve.server.routes.log import router as log_router
from toolserve.server.routes.slack import router as slack_router
v1 = APIRouter(prefix=settings.API_V1_STR)
v1.include_router(tool_router, prefix="/tools", tags=["Tool Catalog"])
v1.include_router(data_router, prefix="/data", tags=["Data Management"])
v1.include_router(artifact_router, prefix="/artifact", tags=["Artifact Management"])
v1.include_router(log_router, prefix="/log", tags=["Tool Logging API"])
v1.include_router(slack_router, prefix="/slack", tags=["Slack"])