From ab703b75ef80edd111bc0a60b84ecba715b69329 Mon Sep 17 00:00:00 2001 From: Nate Barbettini Date: Thu, 22 Aug 2024 16:17:15 -0700 Subject: [PATCH] Retryable errors (#20) Co-authored-by: Sterling Dreyer --- arcade/arcade/actor/core/base.py | 3 + arcade/arcade/cli/main.py | 15 +++- arcade/arcade/core/errors.py | 27 +++++- arcade/arcade/core/executor.py | 6 ++ arcade/arcade/core/response.py | 17 ++++ arcade/arcade/core/schema.py | 1 + .../fastapi/arcade_example_fastapi/main.py | 2 + toolkits/slack/arcade_slack/tools/chat.py | 85 +++++++++++++++++-- 8 files changed, 145 insertions(+), 11 deletions(-) diff --git a/arcade/arcade/actor/core/base.py b/arcade/arcade/actor/core/base.py index ce97841a..82e7fbf9 100644 --- a/arcade/arcade/actor/core/base.py +++ b/arcade/arcade/actor/core/base.py @@ -86,7 +86,10 @@ class BaseActor(Actor): else ToolCallOutput(value=f"Tool {tool_name} called successfully") ) else: + # TODO flatten this to just ToolCallError output = ToolCallOutput(error=ToolCallError(message=response.msg)) + if response.code == 425: + output.error.additional_prompt_content = response.additional_prompt_content end_time = time.time() # End time in seconds duration_ms = (end_time - start_time) * 1000 # Convert to milliseconds diff --git a/arcade/arcade/cli/main.py b/arcade/arcade/cli/main.py index 2c890d31..8ae8163e 100644 --- a/arcade/arcade/cli/main.py +++ b/arcade/arcade/cli/main.py @@ -1,6 +1,6 @@ import asyncio import os -from typing import Optional +from typing import Any, Optional import typer from openai.resources.chat.completions import ChatCompletionChunk, Stream @@ -192,7 +192,7 @@ def run( def chat( model: str = typer.Option("gpt-4o", "-m", help="The model to use for prediction."), stream: bool = typer.Option( - False, "-s", "--stream", is_flag=True, help="Stream the tool output." + True, "-s", "--stream", is_flag=True, help="Stream the tool output." ), ) -> None: """ @@ -236,7 +236,8 @@ def chat( tool_choice="generate", user=user, ) - display_streamed_markdown(stream_response) + role, message = display_streamed_markdown(stream_response) + messages.append({"role": role, "content": message}) else: response = client.complete( model=model, @@ -379,21 +380,27 @@ def display_config_as_table(config: Config) -> None: console.print(table) -def display_streamed_markdown(stream: Stream[ChatCompletionChunk]) -> None: +def display_streamed_markdown(stream: Stream[ChatCompletionChunk]) -> tuple[str, dict[str, Any]]: """ Display the streamed markdown chunks as a single line. """ from rich.live import Live full_message = "" + role = "" with Live(console=console, refresh_per_second=10) as live: for chunk in stream: choice = chunk.choices[0] chunk_message = choice.delta.content + if role == "": + role = choice.delta.role + if role == "assistant": + console.print("\n[bold blue]Assistant:[/bold blue] ") if chunk_message: full_message += chunk_message markdown_chunk = Markdown(full_message) live.update(markdown_chunk) + return role, full_message def create_cli_catalog( diff --git a/arcade/arcade/core/errors.py b/arcade/arcade/core/errors.py index 20f5677f..0f8f4808 100644 --- a/arcade/arcade/core/errors.py +++ b/arcade/arcade/core/errors.py @@ -1,3 +1,6 @@ +from typing import Optional + + class ToolError(Exception): """ Base class for all errors related to tools. @@ -18,7 +21,9 @@ class ToolDefinitionError(ToolError): class ToolRuntimeError(RuntimeError): - pass + def __init__(self, message: str): + super().__init__(message) + self.message = message class ToolExecutionError(ToolRuntimeError): @@ -26,7 +31,25 @@ class ToolExecutionError(ToolRuntimeError): Raised when there is an error executing a tool. """ - pass + def __init__(self, message: str, developer_message: Optional[str] = None): + super().__init__(message) + self.developer_message = developer_message + + +class RetryableToolError(ToolExecutionError): + """ + Raised when a tool error is retryable. + """ + + def __init__( + self, + message: str, + developer_message: Optional[str] = None, + additional_prompt_content: Optional[str] = None, + ): + super().__init__(message) + self.developer_message = developer_message + self.additional_prompt_content = additional_prompt_content class ToolSerializationError(ToolRuntimeError): diff --git a/arcade/arcade/core/executor.py b/arcade/arcade/core/executor.py index 069a5281..d89d058a 100644 --- a/arcade/arcade/core/executor.py +++ b/arcade/arcade/core/executor.py @@ -4,6 +4,7 @@ from typing import Any, Callable from pydantic import BaseModel, ValidationError from arcade.core.errors import ( + RetryableToolError, ToolExecutionError, ToolInputError, ToolOutputError, @@ -50,6 +51,11 @@ class ToolExecutor: # return the output return tool_response.success(data=output) + except RetryableToolError as e: + return tool_response.fail_retry( + msg=str(e), additional_prompt_content=e.additional_prompt_content + ) + except ToolSerializationError as e: return tool_response.fail(msg=str(e)) diff --git a/arcade/arcade/core/response.py b/arcade/arcade/core/response.py index 0615ccf7..138a5956 100644 --- a/arcade/arcade/core/response.py +++ b/arcade/arcade/core/response.py @@ -22,6 +22,7 @@ class ToolResponse(BaseModel, Generic[T]): code: int = CustomResponseCode.HTTP_200.code msg: str = CustomResponseCode.HTTP_200.msg + additional_prompt_content: str | None = None # data: T | None = None @@ -67,5 +68,21 @@ class ToolResponseFactory: data=data, ) + def fail_retry( + self, + *, + res: CustomResponseCode | CustomResponse = CustomResponseCode.HTTP_425, + msg: str = CustomResponseCode.HTTP_425.msg, + data: Any = None, + additional_prompt_content: str | None = None, + ) -> ToolResponse: + res = self.__response( + res=res, + msg=msg, + data=data, + ) + res.additional_prompt_content = additional_prompt_content + return res + tool_response = ToolResponseFactory() diff --git a/arcade/arcade/core/schema.py b/arcade/arcade/core/schema.py index 62135b50..1c567cff 100644 --- a/arcade/arcade/core/schema.py +++ b/arcade/arcade/core/schema.py @@ -168,6 +168,7 @@ class ToolCallError(BaseModel): """The user-facing error message.""" developer_message: str | None = None """The developer-facing error details.""" + additional_prompt_content: str | None = None class ToolCallOutput(BaseModel): diff --git a/examples/fastapi/arcade_example_fastapi/main.py b/examples/fastapi/arcade_example_fastapi/main.py index 313df338..610473c9 100644 --- a/examples/fastapi/arcade_example_fastapi/main.py +++ b/examples/fastapi/arcade_example_fastapi/main.py @@ -23,6 +23,7 @@ actor.register_tool(repo.count_stargazers) actor.register_tool(repo.search_issues) actor.register_tool(user.set_starred) actor.register_tool(chat.send_dm_to_user) +actor.register_tool(chat.send_message_to_channel) class ChatRequest(BaseModel): @@ -47,6 +48,7 @@ async def postChat(request: ChatRequest, tool_choice: str = "execute"): "SetStarred", "SearchIssues", "SendDmToUser", + "SendMessageToChannel", ], tool_choice=tool_choice, user="sam", diff --git a/toolkits/slack/arcade_slack/tools/chat.py b/toolkits/slack/arcade_slack/tools/chat.py index 42f7f222..c3cead33 100644 --- a/toolkits/slack/arcade_slack/tools/chat.py +++ b/toolkits/slack/arcade_slack/tools/chat.py @@ -1,4 +1,6 @@ +import time from typing import Annotated +from arcade.core.errors import ToolExecutionError, RetryableToolError from arcade.core.schema import ToolContext from arcade.sdk import tool from arcade.sdk.auth import SlackUser @@ -22,9 +24,9 @@ def send_dm_to_user( try: # Step 1: Retrieve the user's Slack ID based on their username - response = slackClient.users_list() + userListResponse = slackClient.users_list() user_id = None - for user in response["members"]: + for user in userListResponse["members"]: if user["name"].lower() == user_name.lower(): user_id = user["id"] break @@ -33,7 +35,14 @@ def send_dm_to_user( # does this end up as a developerMessage? # does it end up in the LLM context? # provide the dev an Error type that controls what ends up in the LLM context - raise ValueError(f"User with username '{user_name}' not found.") + + # TODO make the sleep configurable and sent to the engine + time.sleep(0.5) # Wait for half a second + raise RetryableToolError( + "User not found", + developer_message=f"User with username '{user_name}' not found.", + additional_prompt_content=format_users(userListResponse), + ) # Step 2: Retrieve the DM channel ID with the user im_response = slackClient.conversations_open(users=[user_id]) @@ -43,5 +52,71 @@ def send_dm_to_user( slackClient.chat_postMessage(channel=dm_channel_id, text=message) except SlackApiError as e: - # this should be caught also, not printed - print(f"Error sending message: {e.response['error']}") + raise ToolExecutionError( + f"Error sending message: {e.response['error']}", + developer_message="Error sending message", + ) + + +def format_users(userListResponse: dict) -> str: + csv_string = "All active Slack users:\n\nid,name,real_name\n" + for user in userListResponse["members"]: + if not user.get("deleted", False): + user_id = user.get("id", "") + name = user.get("name", "") + real_name = user.get("profile", {}).get("real_name", "") + csv_string += f"{user_id},{name},{real_name}\n" + return csv_string.strip() + + +@tool( + requires_auth=SlackUser( + scope=["chat:write", "channels:read", "groups:read"], + ) +) +def send_message_to_channel( + context: ToolContext, + channel_name: Annotated[ + str, "The Slack channel name where you want to send the message" + ], + message: Annotated[str, "The message you want to send"], +): + """Send a message to a channel in Slack.""" + + slackClient = WebClient(token=context.authorization.token) + + try: + # Step 1: Retrieve the list of channels + channels_response = slackClient.conversations_list() + channel_id = None + for channel in channels_response["channels"]: + if channel["name"].lower() == channel_name.lower(): + channel_id = channel["id"] + break + + if not channel_id: + time.sleep(0.5) # Wait for half a second + raise RetryableToolError( + "Channel not found", + developer_message=f"Channel with name '{channel_name}' not found.", + additional_prompt_content=format_channels(channels_response), + ) + + # Step 2: Send the message to the channel + slackClient.chat_postMessage(channel=channel_id, text=message) + + except SlackApiError as e: + raise ToolExecutionError( + f"Error sending message: {e.response['error']}", + developer_message="Error sending message", + ) + + +def format_channels(channels_response: dict) -> str: + csv_string = "All active Slack channels:\n\nid,name\n" + for channel in channels_response["channels"]: + if not channel.get("is_archived", False): + channel_id = channel.get("id", "") + name = channel.get("name", "") + csv_string += f"{channel_id},{name}\n" + return csv_string.strip()