Retryable errors (#20)

Co-authored-by: Sterling Dreyer <sdreyer@ucsc.edu>
This commit is contained in:
Nate Barbettini 2024-08-22 16:17:15 -07:00 committed by GitHub
parent acba912816
commit ab703b75ef
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 145 additions and 11 deletions

View file

@ -86,7 +86,10 @@ class BaseActor(Actor):
else ToolCallOutput(value=f"Tool {tool_name} called successfully")
)
else:
# TODO flatten this to just ToolCallError
output = ToolCallOutput(error=ToolCallError(message=response.msg))
if response.code == 425:
output.error.additional_prompt_content = response.additional_prompt_content
end_time = time.time() # End time in seconds
duration_ms = (end_time - start_time) * 1000 # Convert to milliseconds

View file

@ -1,6 +1,6 @@
import asyncio
import os
from typing import Optional
from typing import Any, Optional
import typer
from openai.resources.chat.completions import ChatCompletionChunk, Stream
@ -192,7 +192,7 @@ def run(
def chat(
model: str = typer.Option("gpt-4o", "-m", help="The model to use for prediction."),
stream: bool = typer.Option(
False, "-s", "--stream", is_flag=True, help="Stream the tool output."
True, "-s", "--stream", is_flag=True, help="Stream the tool output."
),
) -> None:
"""
@ -236,7 +236,8 @@ def chat(
tool_choice="generate",
user=user,
)
display_streamed_markdown(stream_response)
role, message = display_streamed_markdown(stream_response)
messages.append({"role": role, "content": message})
else:
response = client.complete(
model=model,
@ -379,21 +380,27 @@ def display_config_as_table(config: Config) -> None:
console.print(table)
def display_streamed_markdown(stream: Stream[ChatCompletionChunk]) -> tuple[str, str]:
    """
    Render a streamed chat completion as live-updating markdown.

    Each content delta is appended to the text received so far, and the
    accumulated text is re-rendered in place as markdown.

    Args:
        stream: The stream of chat completion chunks to display.

    Returns:
        A ``(role, message)`` tuple: the first non-empty role reported by the
        stream (``""`` if none was reported) and the full concatenated
        message text.
    """
    from rich.live import Live

    full_message = ""
    role = ""
    with Live(console=console, refresh_per_second=10) as live:
        for chunk in stream:
            # Some chunks carry no choices at all; indexing [0] would raise.
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta

            # Only latch the role once, and only from a delta that actually
            # carries one, so a role-less first chunk cannot leave it as None.
            if not role and delta.role:
                role = delta.role
                if role == "assistant":
                    console.print("\n[bold blue]Assistant:[/bold blue] ")

            if delta.content:
                full_message += delta.content
                live.update(Markdown(full_message))
    return role, full_message
def create_cli_catalog(

View file

@ -1,3 +1,6 @@
from typing import Optional
class ToolError(Exception):
"""
Base class for all errors related to tools.
@ -18,7 +21,9 @@ class ToolDefinitionError(ToolError):
class ToolRuntimeError(RuntimeError):
    """
    A RuntimeError that also exposes its message as an attribute.
    """

    def __init__(self, message: str):
        # Keep the text on the instance so callers can read `.message`
        # instead of going through `str(exc)` or `exc.args`.
        self.message = message
        super().__init__(message)
class ToolExecutionError(ToolRuntimeError):
@ -26,7 +31,25 @@ class ToolExecutionError(ToolRuntimeError):
Raised when there is an error executing a tool.
"""
pass
def __init__(self, message: str, developer_message: Optional[str] = None):
super().__init__(message)
self.developer_message = developer_message
class RetryableToolError(ToolExecutionError):
    """
    Raised when a tool error is retryable.

    In addition to the user-facing message and the optional developer
    message, it carries optional ``additional_prompt_content``: extra text
    supplied by the tool alongside the error.
    """

    def __init__(
        self,
        message: str,
        developer_message: Optional[str] = None,
        additional_prompt_content: Optional[str] = None,
    ):
        # The parent constructor already stores developer_message.
        super().__init__(message, developer_message=developer_message)
        self.additional_prompt_content = additional_prompt_content
class ToolSerializationError(ToolRuntimeError):

View file

@ -4,6 +4,7 @@ from typing import Any, Callable
from pydantic import BaseModel, ValidationError
from arcade.core.errors import (
RetryableToolError,
ToolExecutionError,
ToolInputError,
ToolOutputError,
@ -50,6 +51,11 @@ class ToolExecutor:
# return the output
return tool_response.success(data=output)
except RetryableToolError as e:
return tool_response.fail_retry(
msg=str(e), additional_prompt_content=e.additional_prompt_content
)
except ToolSerializationError as e:
return tool_response.fail(msg=str(e))

View file

@ -22,6 +22,7 @@ class ToolResponse(BaseModel, Generic[T]):
code: int = CustomResponseCode.HTTP_200.code
msg: str = CustomResponseCode.HTTP_200.msg
additional_prompt_content: str | None = None
#
data: T | None = None
@ -67,5 +68,21 @@ class ToolResponseFactory:
data=data,
)
def fail_retry(
    self,
    *,
    res: CustomResponseCode | CustomResponse = CustomResponseCode.HTTP_425,
    msg: str = CustomResponseCode.HTTP_425.msg,
    data: Any = None,
    additional_prompt_content: str | None = None,
) -> ToolResponse:
    """
    Build a retryable-failure response.

    All arguments are keyword-only.

    :param res: Response code object for the failure; defaults to HTTP_425.
    :param msg: Message for the response; defaults to the HTTP_425 message.
    :param data: Optional payload passed through to the response.
    :param additional_prompt_content: Optional extra text attached to the
        response after it is built.
    :return: The response with ``additional_prompt_content`` set.
    """
    retry_response = self.__response(res=res, msg=msg, data=data)
    # The shared builder does not take this field, so set it afterwards.
    retry_response.additional_prompt_content = additional_prompt_content
    return retry_response
tool_response = ToolResponseFactory()

View file

@ -168,6 +168,7 @@ class ToolCallError(BaseModel):
"""The user-facing error message."""
developer_message: str | None = None
"""The developer-facing error details."""
additional_prompt_content: str | None = None
class ToolCallOutput(BaseModel):

View file

@ -23,6 +23,7 @@ actor.register_tool(repo.count_stargazers)
actor.register_tool(repo.search_issues)
actor.register_tool(user.set_starred)
actor.register_tool(chat.send_dm_to_user)
actor.register_tool(chat.send_message_to_channel)
class ChatRequest(BaseModel):
@ -47,6 +48,7 @@ async def postChat(request: ChatRequest, tool_choice: str = "execute"):
"SetStarred",
"SearchIssues",
"SendDmToUser",
"SendMessageToChannel",
],
tool_choice=tool_choice,
user="sam",

View file

@ -1,4 +1,6 @@
import time
from typing import Annotated
from arcade.core.errors import ToolExecutionError, RetryableToolError
from arcade.core.schema import ToolContext
from arcade.sdk import tool
from arcade.sdk.auth import SlackUser
@ -22,9 +24,9 @@ def send_dm_to_user(
try:
# Step 1: Retrieve the user's Slack ID based on their username
response = slackClient.users_list()
userListResponse = slackClient.users_list()
user_id = None
for user in response["members"]:
for user in userListResponse["members"]:
if user["name"].lower() == user_name.lower():
user_id = user["id"]
break
@ -33,7 +35,14 @@ def send_dm_to_user(
# does this end up as a developerMessage?
# does it end up in the LLM context?
# provide the dev an Error type that controls what ends up in the LLM context
raise ValueError(f"User with username '{user_name}' not found.")
# TODO make the sleep configurable and sent to the engine
time.sleep(0.5) # Wait for half a second
raise RetryableToolError(
"User not found",
developer_message=f"User with username '{user_name}' not found.",
additional_prompt_content=format_users(userListResponse),
)
# Step 2: Retrieve the DM channel ID with the user
im_response = slackClient.conversations_open(users=[user_id])
@ -43,5 +52,71 @@ def send_dm_to_user(
slackClient.chat_postMessage(channel=dm_channel_id, text=message)
except SlackApiError as e:
# this should be caught also, not printed
print(f"Error sending message: {e.response['error']}")
raise ToolExecutionError(
f"Error sending message: {e.response['error']}",
developer_message="Error sending message",
)
def format_users(userListResponse: dict) -> str:
    """
    Format the non-deleted users of a user-list response as CSV-style text.

    Args:
        userListResponse: Mapping with a ``"members"`` entry that iterates
            over per-user mappings. The keys read from each user are
            ``deleted``, ``id``, ``name`` and ``profile.real_name``; any
            that are missing fall back to an empty string (``deleted``
            falls back to False).

    Returns:
        A heading and an ``id,name,real_name`` header line, followed by one
        line per non-deleted user, with surrounding whitespace stripped.
    """
    header = "All active Slack users:\n\nid,name,real_name\n"
    rows = []
    for user in userListResponse["members"]:
        if user.get("deleted", False):
            continue
        # `profile` may be present but None; treat that like a missing profile.
        profile = user.get("profile") or {}
        rows.append(
            f"{user.get('id', '')},{user.get('name', '')},{profile.get('real_name', '')}\n"
        )
    # Join once instead of growing the string with += on every iteration.
    return (header + "".join(rows)).strip()
@tool(
    requires_auth=SlackUser(
        scope=["chat:write", "channels:read", "groups:read"],
    )
)
def send_message_to_channel(
    context: ToolContext,
    channel_name: Annotated[
        str, "The Slack channel name where you want to send the message"
    ],
    message: Annotated[str, "The message you want to send"],
):
    """Send a message to a channel in Slack."""
    # Raises RetryableToolError (with the list of non-archived channels as
    # additional prompt content) when no channel matches `channel_name`
    # case-insensitively, and ToolExecutionError when the Slack client
    # raises SlackApiError.
    slack_client = WebClient(token=context.authorization.token)

    try:
        # Step 1: Retrieve the list of channels
        # NOTE(review): only a single conversations_list() call is made; if
        # the API paginates or filters by type, some channels may be missed.
        channels_response = slack_client.conversations_list()
        wanted_name = channel_name.lower()
        channel_id = next(
            (
                channel["id"]
                for channel in channels_response["channels"]
                if channel["name"].lower() == wanted_name
            ),
            None,
        )

        if not channel_id:
            time.sleep(0.5)  # Wait for half a second
            raise RetryableToolError(
                "Channel not found",
                developer_message=f"Channel with name '{channel_name}' not found.",
                additional_prompt_content=format_channels(channels_response),
            )

        # Step 2: Send the message to the channel
        slack_client.chat_postMessage(channel=channel_id, text=message)

    except SlackApiError as e:
        # Chain the original exception so the Slack error and its traceback
        # stay attached to the ToolExecutionError.
        raise ToolExecutionError(
            f"Error sending message: {e.response['error']}",
            developer_message="Error sending message",
        ) from e
def format_channels(channels_response: dict) -> str:
    """
    Format the non-archived channels of a channel-list response as CSV-style text.

    Args:
        channels_response: Mapping with a ``"channels"`` entry that iterates
            over per-channel mappings. The keys read from each channel are
            ``is_archived``, ``id`` and ``name``; a missing ``id`` or
            ``name`` falls back to an empty string and a missing
            ``is_archived`` to False.

    Returns:
        A heading and an ``id,name`` header line, followed by one line per
        non-archived channel, with surrounding whitespace stripped.
    """
    header = "All active Slack channels:\n\nid,name\n"
    # Build all rows and join once instead of growing the string with +=.
    rows = "".join(
        f"{channel.get('id', '')},{channel.get('name', '')}\n"
        for channel in channels_response["channels"]
        if not channel.get("is_archived", False)
    )
    return (header + rows).strip()