multiple branch example
This commit is contained in:
parent
2e4542c260
commit
16c8156f98
5 changed files with 208 additions and 91 deletions
|
|
@ -13,4 +13,5 @@ ReadEmail = "gmailer.read_email@0.1.0"
|
|||
PlotDataframe = "gmailer.plot_dataframe@0.1.0"
|
||||
ReadSqlite = "read_sqlite.read_sqlite@0.1.0"
|
||||
Summarize = "chat.summarize@0.1.0"
|
||||
TranscribeText = "chat.transcribe_text@0.1.0"
|
||||
VectorSearch = "search.vector_search@0.1.0"
|
||||
|
|
|
|||
|
|
@ -1,10 +1,27 @@
|
|||
|
||||
|
||||
from typing import (
|
||||
IO,
|
||||
Union,
|
||||
List,
|
||||
Dict,
|
||||
Optional,
|
||||
Any,
|
||||
Type,
|
||||
)
|
||||
import io
|
||||
import requests
|
||||
from os import PathLike
|
||||
import base64
|
||||
|
||||
from toolserve.sdk import Param, tool, get_secret
|
||||
from toolserve.sdk.dataframe import get_df
|
||||
from typing import List
|
||||
import pandas as pd
|
||||
import openai
|
||||
|
||||
|
||||
|
||||
@tool
|
||||
async def summarize(
|
||||
text: Param(str, "Text to summarize"),
|
||||
|
|
@ -42,3 +59,62 @@ async def summarize(
|
|||
)
|
||||
summary = completion.choices[0].message.content
|
||||
return summary
|
||||
|
||||
|
||||
@tool
|
||||
async def transcribe_text(
|
||||
audio_file: Param(str, "Audio file bytes"),
|
||||
system_prompt: Param(str, "System prompt to use") = "Transcribe the following audio files",
|
||||
) -> Param(str, "Transcribed text"):
|
||||
"""Use OpenAI to translate audio to text using the Whisper model.
|
||||
|
||||
Args:
|
||||
audio_file_bytes (str): The bytes of the audio file to transcribe.
|
||||
system_prompt (str): The system prompt to use for guiding the transcription.
|
||||
|
||||
Returns:
|
||||
str: The transcribed text.
|
||||
"""
|
||||
api_key = get_secret("openai_api_key", None)
|
||||
model = get_secret("openai_model_whisper", "whisper-1")
|
||||
|
||||
if audio_file is None:
|
||||
raise ValueError("No audio file provided")
|
||||
|
||||
# Decode the base64 audio file
|
||||
audio_file_bytes = base64.b64decode(audio_file)
|
||||
file = io.BytesIO(audio_file_bytes)
|
||||
|
||||
# Prepare the headers
|
||||
headers = {
|
||||
'Authorization': f'Bearer {api_key}',
|
||||
}
|
||||
|
||||
# Prepare the files parameter
|
||||
files = {
|
||||
'file': ('audio.mp3', file, 'audio/mp3')
|
||||
}
|
||||
|
||||
# Prepare the data parameter
|
||||
data = {
|
||||
'model': model,
|
||||
'prompt': system_prompt,
|
||||
'response_format': 'text'
|
||||
}
|
||||
|
||||
# Send the request to the OpenAI Whisper API
|
||||
response = requests.post(
|
||||
'https://api.openai.com/v1/audio/transcriptions',
|
||||
headers=headers,
|
||||
files=files,
|
||||
data=data
|
||||
)
|
||||
|
||||
# Check if the request was successful
|
||||
if response.status_code == 200:
|
||||
# Return the plain text response directly
|
||||
return response.text
|
||||
else:
|
||||
# Handle errors
|
||||
raise Exception(f"Error: {response.status_code} - {response.text}")
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import httpx
|
|||
import json
|
||||
import time
|
||||
import openai
|
||||
import uuid
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
|
@ -41,6 +42,7 @@ def pydantic_to_openai_tool(model: Type[BaseModel]) -> str:
|
|||
class Edge(BaseModel):
|
||||
source: int = Field(..., description="The ID of the source node")
|
||||
target: int = Field(..., description="The ID of the target node")
|
||||
uuid: str = Field(default_factory=lambda: str(uuid.uuid4()), description="UUID for the data flow between nodes")
|
||||
|
||||
class ToolNode(BaseModel):
|
||||
node_id: int = Field(..., description="The ID of the node", ge=0)
|
||||
|
|
@ -50,7 +52,7 @@ class ToolNode(BaseModel):
|
|||
predict_args: bool = Field(True, description="Whether to predict the arguments for the tool")
|
||||
from_node: Optional[Dict[str, int]] = Field(None, description="The ID of the source node name of the argument to pass to the tool")
|
||||
args: Optional[Dict[str, Any]] = Field(None, description="The arguments to pass to the tool")
|
||||
|
||||
allow_extra: bool = Field(False, description="Whether to allow extra arguments to be passed to the tool")
|
||||
|
||||
class OutputType(Enum):
|
||||
DATA = "data"
|
||||
|
|
@ -64,10 +66,37 @@ class FlowSchema(BaseModel):
|
|||
edges: List[Edge] = Field([], description="The IDs of the adjacent nodes")
|
||||
output_type: OutputType = Field(OutputType.CHAT, description="The type of the output")
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
self.generate_uuids_for_edges()
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
use_enum_values = True
|
||||
|
||||
def generate_uuids_for_edges(self):
|
||||
edge_map = {}
|
||||
for edge in self.edges:
|
||||
edge_map[(edge.source, edge.target)] = edge.uuid
|
||||
for node in self.nodes:
|
||||
incoming_edges = [e.uuid for e in self.edges if e.target == node.node_id]
|
||||
outgoing_edges = [e.uuid for e in self.edges if e.source == node.node_id]
|
||||
if node.from_node:
|
||||
node.input_name = None
|
||||
node.output_name = None
|
||||
# Set the output of the source node and the input of the target node to None
|
||||
for edge in self.edges:
|
||||
if edge.target == node.node_id:
|
||||
source_node = next((n for n in self.nodes if n.node_id == edge.source), None)
|
||||
if source_node:
|
||||
source_node.output_name = None
|
||||
if edge.source == node.node_id:
|
||||
target_node = next((n for n in self.nodes if n.node_id == edge.target), None)
|
||||
if target_node:
|
||||
target_node.input_name = None
|
||||
else:
|
||||
node.input_name = incoming_edges[0] if incoming_edges else None
|
||||
node.output_name = outgoing_edges[0] if outgoing_edges else None
|
||||
|
||||
class ToolClient:
|
||||
|
||||
|
|
@ -259,20 +288,19 @@ class ToolRunner:
|
|||
if tool.predict_args:
|
||||
messages = self.__create_prompt(user_query, source, tool.output_name)
|
||||
tool_args = self.get_tool_args(tool.tool_name, messages, tool.output_name)
|
||||
elif tool.from_node:
|
||||
# todo change to list
|
||||
tool_args = kwargs.get("tool_args", {})
|
||||
else:
|
||||
tool_args = {}
|
||||
tool_args = kwargs.get("tool_args", {})
|
||||
|
||||
# TODO would something ever have an input_name and not need a data_id?
|
||||
if tool.input_name:
|
||||
tool_args["data_id"] = self._data_id
|
||||
|
||||
if tool.output_name:
|
||||
tool_args["output_name"] = tool.output_name
|
||||
|
||||
if tool.args:
|
||||
tool_args.update(tool.args)
|
||||
|
||||
|
||||
print("Calling tool with args:", tool_args)
|
||||
result = self._client.execute_tool(tool.tool_name, tool_args)
|
||||
return result
|
||||
|
|
@ -295,20 +323,19 @@ class ToolFlow:
|
|||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
prompt: str,
|
||||
base_url: str = "http://localhost:8000",
|
||||
model: str = "gpt-4-turbo",
|
||||
model_api_key: Optional[str] = None
|
||||
):
|
||||
self.name = name
|
||||
self.description = description
|
||||
self.prompt = prompt
|
||||
|
||||
self.runner = ToolRunner(base_url, model, model_api_key)
|
||||
self.model = model
|
||||
self.openai_client = openai.Client(api_key=model_api_key)
|
||||
|
||||
|
||||
def execute_flow(self, flow_schema: Dict[str, Any], user_query: str) -> Any:
|
||||
def execute_flow(self, flow_schema: Dict[str, Any], user_query: str, user_args: Dict[str, Any] = {}) -> Any:
|
||||
"""
|
||||
Executes the tool flow based on the provided schema. This method performs a breadth-first search (BFS)
|
||||
on the graph defined by the flow schema and executes each node according to the order determined by the BFS.
|
||||
|
|
@ -343,6 +370,8 @@ class ToolFlow:
|
|||
visited.add(node_id)
|
||||
|
||||
exec_start_time = time.time()
|
||||
|
||||
tool_args = {}
|
||||
# Execute the current node's operation using runner.run_tool
|
||||
current_tool = ToolNode(**current_node)
|
||||
if current_tool.from_node:
|
||||
|
|
@ -350,10 +379,9 @@ class ToolFlow:
|
|||
for arg_name, from_node_id in current_tool.from_node.items():
|
||||
from_node_result = results[from_node_id]["data"]["result"]
|
||||
tool_args[arg_name] = from_node_result
|
||||
|
||||
operation_result = self.runner.run_tool(current_tool, user_query, tool_args=tool_args)
|
||||
else:
|
||||
operation_result = self.runner.run_tool(current_tool, user_query)
|
||||
if current_tool.allow_extra:
|
||||
tool_args.update(user_args)
|
||||
operation_result = self.runner.run_tool(current_tool, user_query, tool_args=tool_args)
|
||||
|
||||
results[node_id] = operation_result
|
||||
exec_end_time = time.time()
|
||||
|
|
@ -388,48 +416,12 @@ class ToolFlow:
|
|||
return (data, results, sink_output_type, timings)
|
||||
|
||||
|
||||
|
||||
|
||||
plotting_flow = FlowSchema(
|
||||
nodes=[
|
||||
ToolNode(node_id=0, input_name="products", tool_name="query_sql", output_name="product_data"),
|
||||
ToolNode(node_id=1, input_name="product_data", tool_name="PlotDataframe", output_name=None),
|
||||
],
|
||||
edges=[
|
||||
Edge(source=0, target=1)
|
||||
],
|
||||
output_type=OutputType.ARTIFACT
|
||||
)
|
||||
|
||||
|
||||
email_flow_1 = FlowSchema(
|
||||
nodes=[
|
||||
ToolNode(node_id=0, input_name=None, tool_name="ReadEmail", output_name="email_data_1"),
|
||||
ToolNode(node_id=1, input_name="email_data_1", tool_name="Summarize", output_name=None),
|
||||
],
|
||||
edges=[
|
||||
Edge(source=0, target=1)
|
||||
],
|
||||
output_type=OutputType.CHAT
|
||||
)
|
||||
|
||||
email_flow = FlowSchema(
|
||||
nodes=[
|
||||
ToolNode(node_id=0, tool_name="ReadEmail"),
|
||||
ToolNode(node_id=1, tool_name="Summarize", from_node={"text": 2}, predict_args=False),
|
||||
],
|
||||
edges=[
|
||||
Edge(source=0, target=1)
|
||||
],
|
||||
output_type=OutputType.CHAT
|
||||
)
|
||||
|
||||
review_db = "/Users/spartee/Dropbox/Arcade/platform/toolserver/examples/data/food-reviews/database.sqlite"
|
||||
review_flow = FlowSchema(
|
||||
nodes=[
|
||||
ToolNode(node_id=0, tool_name="ReadSqlite", args={"table_name": "Reviews", "file_path": review_db, "output_name": "reviews"}, predict_args=False),
|
||||
ToolNode(node_id=1, input_name="reviews", tool_name="query_sql", output_name="review_data"),
|
||||
ToolNode(node_id=2, input_name="review_data", tool_name="search_text_columns"),
|
||||
ToolNode(node_id=0, tool_name="ReadSqlite", args={"table_name": "Reviews", "file_path": review_db}, predict_args=False),
|
||||
ToolNode(node_id=1, tool_name="query_sql"),
|
||||
ToolNode(node_id=2, tool_name="search_text_columns"),
|
||||
ToolNode(node_id=3, tool_name="Summarize", from_node={"text": 2}, predict_args=False),
|
||||
],
|
||||
edges=[
|
||||
|
|
@ -440,16 +432,42 @@ review_flow = FlowSchema(
|
|||
output_type=OutputType.CHAT
|
||||
)
|
||||
|
||||
plotting_flow = FlowSchema(
|
||||
nodes=[
|
||||
ToolNode(node_id=0, tool_name="ReadSqlite", args={"table_name": "Reviews", "file_path": review_db}, predict_args=False),
|
||||
ToolNode(node_id=1, tool_name="query_sql"),
|
||||
ToolNode(node_id=2, tool_name="PlotDataframe"),
|
||||
],
|
||||
edges=[
|
||||
Edge(source=0, target=1),
|
||||
Edge(source=1, target=2)
|
||||
],
|
||||
output_type=OutputType.ARTIFACT
|
||||
)
|
||||
|
||||
|
||||
email_flow = FlowSchema(
|
||||
nodes=[
|
||||
ToolNode(node_id=0, tool_name="ReadEmail"),
|
||||
ToolNode(node_id=1, tool_name="Summarize", from_node={"text": 0}, predict_args=False),
|
||||
],
|
||||
edges=[
|
||||
Edge(source=0, target=1)
|
||||
],
|
||||
output_type=OutputType.CHAT
|
||||
)
|
||||
|
||||
|
||||
|
||||
shopify_db = "/Users/spartee/Dropbox/Arcade/platform/toolserver/examples/data/olist.sqlite"
|
||||
customer_flow = FlowSchema(
|
||||
nodes=[
|
||||
ToolNode(node_id=0, tool_name="ReadSqlite", args={"table_name": "customers", "file_path": shopify_db, "output_name": "customers"}, predict_args=False),
|
||||
ToolNode(node_id=1, tool_name="ReadSqlite", args={"table_name": "orders", "file_path": shopify_db, "output_name": "all_customer_orders"}, predict_args=False),
|
||||
ToolNode(node_id=2, input_name="customers", tool_name="query_sql", output_name="customer_data"),
|
||||
ToolNode(node_id=3, input_name="all_customer_orders", tool_name="query_sql", output_name="customer_orders"),
|
||||
ToolNode(node_id=4, input_name="customer_data", tool_name="get"),
|
||||
ToolNode(node_id=5, input_name="customer_orders", tool_name="get"),
|
||||
ToolNode(node_id=0, tool_name="ReadSqlite", args={"table_name": "customers", "file_path": shopify_db}, predict_args=False),
|
||||
ToolNode(node_id=1, tool_name="ReadSqlite", args={"table_name": "orders", "file_path": shopify_db}, predict_args=False),
|
||||
ToolNode(node_id=2, tool_name="query_sql"),
|
||||
ToolNode(node_id=3, tool_name="query_sql"),
|
||||
ToolNode(node_id=4, tool_name="get"),
|
||||
ToolNode(node_id=5, tool_name="get"),
|
||||
ToolNode(node_id=6, tool_name="combine_results", from_node={"result_1": 4, "result_2": 5}, predict_args=False),
|
||||
ToolNode(node_id=7, tool_name="Summarize", from_node={"text": 6}, predict_args=False)
|
||||
],
|
||||
|
|
@ -466,6 +484,18 @@ customer_flow = FlowSchema(
|
|||
)
|
||||
|
||||
|
||||
audio_files = ["/Users/spartee/Desktop/notes.mp3"]
|
||||
notetaker = FlowSchema(
|
||||
nodes=[
|
||||
ToolNode(node_id=0, tool_name="TranscribeText", predict_args=False, allow_extra=True),
|
||||
ToolNode(node_id=1, tool_name="Summarize", from_node={"text": 0}, predict_args=False),
|
||||
],
|
||||
edges=[
|
||||
Edge(source=0, target=1)
|
||||
],
|
||||
output_type=OutputType.CHAT
|
||||
)
|
||||
|
||||
|
||||
def print_flow_as_yaml(data: Dict[str, Any]):
|
||||
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import openai
|
|||
|
||||
oai_key = "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"
|
||||
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import subprocess
|
||||
|
|
@ -10,7 +10,6 @@ import sys
|
|||
import time
|
||||
import traceback
|
||||
import os
|
||||
|
||||
from typing import Dict, Any
|
||||
import networkx as nx
|
||||
import matplotlib.pyplot as plt
|
||||
|
|
@ -21,28 +20,10 @@ from streamlit_chat import message
|
|||
import streamlit.components.v1 as components
|
||||
from textwrap import dedent
|
||||
import plotly.express as px
|
||||
from agent import ToolFlow, email_flow, plotting_flow, review_flow, customer_flow
|
||||
from agent import ToolFlow, email_flow, plotting_flow, review_flow, customer_flow, notetaker
|
||||
|
||||
|
||||
|
||||
PROMPT = dedent("""Given a user query, construct a graph based representation of functions (nodes), and their data flow (edges) such that
|
||||
the graph can be executed to supply the user query enough information to answer their query.
|
||||
|
||||
You must construct the graph with the following constraints:
|
||||
- There can only be 1 source node and 1 sink node.
|
||||
- There should be no leaf nodes besides the sink node.
|
||||
- The source and sink can be the same node.
|
||||
|
||||
Only use the available nodes and their output types as edges. Create unique ids for each node starting from 0.
|
||||
|
||||
The available nodes are:
|
||||
{nodes}
|
||||
|
||||
The available input names for the source are:
|
||||
{sources}
|
||||
""")
|
||||
|
||||
|
||||
def plot_flow(data: Dict[str, Any]):
|
||||
"""
|
||||
Plot the flow of data using a directed graph.
|
||||
|
|
@ -86,7 +67,6 @@ def get_agent():
|
|||
AnalysisTool = ToolFlow(
|
||||
name="data_analysis",
|
||||
description="A tool flow for data analysis",
|
||||
prompt=PROMPT,
|
||||
model_api_key=oai_key
|
||||
)
|
||||
return AnalysisTool
|
||||
|
|
@ -95,7 +75,7 @@ def get_agent():
|
|||
# From here down is all the StreamLit UI.
|
||||
st.set_page_config(page_title="Arcade AI Demo", page_icon=":robot:", layout="wide")
|
||||
|
||||
dropdown_options = ["Gmailer", "PlotBot", "ReviewChat", "CustomerService"]
|
||||
dropdown_options = ["Gmailer", "PlotBot", "ReviewChat", "CustomerService", "Notetaker"]
|
||||
selected_option = st.sidebar.selectbox("Select an App:", dropdown_options)
|
||||
st.sidebar.write(f"Selected App: {selected_option}")
|
||||
|
||||
|
|
@ -112,6 +92,10 @@ if "past" not in st.session_state:
|
|||
st.session_state["past"] = []
|
||||
if "generated" not in st.session_state:
|
||||
st.session_state["generated"] = []
|
||||
if "input" not in st.session_state:
|
||||
st.session_state["input"] = ""
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
@ -121,9 +105,8 @@ st.subheader("Arcade AI Agent Demo")
|
|||
chat_container = st.container()
|
||||
input_container = st.container()
|
||||
|
||||
def submit():
|
||||
submit_text = st.session_state["input"]
|
||||
st.session_state["input"] = ""
|
||||
def submit(data=None):
|
||||
|
||||
with st.spinner(text="Wait for Agent..."):
|
||||
try:
|
||||
agent = get_agent()
|
||||
|
|
@ -137,11 +120,15 @@ def submit():
|
|||
json_flow = review_flow.dict()
|
||||
elif selected_option == "CustomerService":
|
||||
json_flow = customer_flow.dict()
|
||||
elif selected_option == "Notetaker":
|
||||
json_flow = notetaker.dict()
|
||||
else:
|
||||
st.error("Invalid option selected")
|
||||
return
|
||||
|
||||
print(json_flow)
|
||||
plot_flow(json_flow)
|
||||
submit_text = st.session_state.input
|
||||
st.session_state.input = ""
|
||||
res = agent.execute_flow(json_flow, submit_text)
|
||||
except Exception:
|
||||
st.error("Error executing the flow:")
|
||||
|
|
@ -150,12 +137,35 @@ def submit():
|
|||
st.session_state.past.append(submit_text)
|
||||
st.session_state.generated.append(res)
|
||||
|
||||
def get_text():
|
||||
input_text = st.text_input("You: ", key="input", on_change=submit)
|
||||
return input_text
|
||||
|
||||
def run_notetaker():
|
||||
with st.spinner(text="Wait for Agent..."):
|
||||
try:
|
||||
agent = get_agent()
|
||||
json_flow = notetaker.dict()
|
||||
plot_flow(json_flow)
|
||||
audio_file = st.session_state.audio_file
|
||||
if audio_file is None:
|
||||
st.error("No audio file uploaded")
|
||||
return
|
||||
|
||||
audio_file_byte_str = base64.b64encode(audio_file.getvalue()).decode("utf-8")
|
||||
res = agent.execute_flow(json_flow, "placeholder", user_args={"audio_file": audio_file_byte_str})
|
||||
except Exception:
|
||||
st.error("Error executing the flow:")
|
||||
st.error(traceback.format_exc())
|
||||
return
|
||||
st.session_state.past.append("Audio File")
|
||||
st.session_state.generated.append(res)
|
||||
|
||||
|
||||
|
||||
with input_container:
|
||||
user_input = get_text()
|
||||
if selected_option != "Notetaker":
|
||||
st.text_input("You: ", key="input", on_change=submit)
|
||||
else:
|
||||
st.file_uploader("Upload an audio file", type=["mp3", "wav", "ogg"], key="audio_file", on_change=run_notetaker)
|
||||
|
||||
|
||||
if st.session_state["generated"]:
|
||||
with chat_container:
|
||||
|
|
|
|||
|
|
@ -4,12 +4,12 @@ from toolserve.server.routes.tool import router as tool_router
|
|||
from toolserve.server.routes.data import router as data_router
|
||||
from toolserve.server.routes.artifact import router as artifact_router
|
||||
from toolserve.server.routes.log import router as log_router
|
||||
|
||||
from toolserve.server.routes.slack import router as slack_router
|
||||
|
||||
v1 = APIRouter(prefix=settings.API_V1_STR)
|
||||
v1.include_router(tool_router, prefix="/tools", tags=["Tool Catalog"])
|
||||
v1.include_router(data_router, prefix="/data", tags=["Data Management"])
|
||||
v1.include_router(artifact_router, prefix="/artifact", tags=["Artifact Management"])
|
||||
v1.include_router(log_router, prefix="/log", tags=["Tool Logging API"])
|
||||
|
||||
v1.include_router(slack_router, prefix="/slack", tags=["Slack"])
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue