From d37303de6a54805c6551228175a62d3d72f4b522 Mon Sep 17 00:00:00 2001 From: Nate Barbettini Date: Tue, 27 Aug 2024 16:19:22 -0700 Subject: [PATCH] Clean up retryable errors (#21) Clean up logic of retries --- arcade/arcade/actor/core/base.py | 17 +--- arcade/arcade/cli/main.py | 40 +++++---- arcade/arcade/core/errors.py | 20 +++-- arcade/arcade/core/executor.py | 38 ++++---- arcade/arcade/core/output.py | 48 ++++++++++ arcade/arcade/core/response.py | 88 ------------------- arcade/arcade/core/schema.py | 5 ++ .../preview/invoke_tool_response.schema.jsonc | 13 +++ schemas/preview/tool_definition.schema.jsonc | 3 +- toolkits/slack/arcade_slack/tools/chat.py | 52 ++++++----- 10 files changed, 157 insertions(+), 167 deletions(-) create mode 100644 arcade/arcade/core/output.py delete mode 100644 arcade/arcade/core/response.py diff --git a/arcade/arcade/actor/core/base.py b/arcade/arcade/actor/core/base.py index 82e7fbf9..68c07be6 100644 --- a/arcade/arcade/actor/core/base.py +++ b/arcade/arcade/actor/core/base.py @@ -12,8 +12,6 @@ from arcade.actor.core.components import ( from arcade.core.catalog import ToolCatalog, Toolkit from arcade.core.executor import ToolExecutor from arcade.core.schema import ( - ToolCallError, - ToolCallOutput, ToolCallRequest, ToolCallResponse, ToolDefinition, @@ -71,7 +69,7 @@ class BaseActor(Actor): start_time = time.time() - response = await ToolExecutor.run( + output = await ToolExecutor.run( func=materialized_tool.tool, definition=materialized_tool.definition, input_model=materialized_tool.input_model, @@ -79,17 +77,6 @@ class BaseActor(Actor): context=tool_request.context, **tool_request.inputs or {}, ) - if response.code == 200 and response.data is not None: - output = ( - ToolCallOutput(value=response.data.result) - if hasattr(response.data, "result") and response.data.result - else ToolCallOutput(value=f"Tool {tool_name} called successfully") - ) - else: - # TODO flatten this to just ToolCallError - output = ToolCallOutput(error=ToolCallError(message=response.msg)) - if response.code == 425: - output.error.additional_prompt_content = response.additional_prompt_content end_time = time.time() # End time in seconds duration_ms = (end_time - start_time) * 1000 # Convert to milliseconds @@ -98,7 +85,7 @@ class BaseActor(Actor): invocation_id=tool_request.invocation_id, duration=duration_ms, finished_at=datetime.now().isoformat(), - success=response.code == 200, + success=not output.error, output=output, ) diff --git a/arcade/arcade/cli/main.py b/arcade/arcade/cli/main.py index 8ae8163e..ca9824df 100644 --- a/arcade/arcade/cli/main.py +++ b/arcade/arcade/cli/main.py @@ -15,7 +15,7 @@ from typer.models import Context from arcade.core.catalog import ToolCatalog from arcade.core.client import EngineClient from arcade.core.config import Config -from arcade.core.schema import ToolContext +from arcade.core.schema import ToolCallOutput, ToolContext from arcade.core.toolkit import Toolkit @@ -145,7 +145,7 @@ def run( console.print(f"Calling tool: {tool_name} with params: {parameters}", style="bold blue") # TODO async.gather instead of loop. - output = asyncio.run( + output: ToolCallOutput = asyncio.run( ToolExecutor.run( called_tool.tool, called_tool.definition, @@ -155,22 +155,20 @@ def run( **parameters, ) ) - if output.code != 200: - console.print(output.msg, style="bold red") - if output.data: - console.print(output.data.result, style="bold red") - typer.Exit(code=1) + if output.error: + console.print(output.error.message, style="bold red") + typer.Exit(code=1) else: messages += [ { "role": "assistant", # TODO: escape the output and ensure serialization works - "content": f"Results of Tool {tool_name}: {output.data.result!s}", # type: ignore[union-attr] + "content": f"Results of Tool {tool_name}: {output.value!s}", }, ] if choice == "execute": - console.print(output.data.result, style="green") # type: ignore[union-attr] + console.print(output.value, style="green") raise typer.Exit(0) else: if stream: @@ -206,9 +204,18 @@ def chat( client = EngineClient(base_url=config.engine_url) + if config.user and config.user.email: + user_email = config.user.email + user_attribution = f"({user_email})" + else: + console.print( + "❌ User email not found in configuration. Please run `arcade login`.", style="bold red" + ) + typer.Exit(code=1) + try: # start messages conversation - messages = [] + messages: list[dict[str, Any]] = [] chat_header = Text.assemble( "\n", @@ -220,12 +227,9 @@ def chat( ) console.print(chat_header) - user = config.user.email if config.user and config.user.email else None - user_attribution = f" ({user})" if user else "" - while True: user_input = console.input( - f"\n[magenta][bold]User[/bold]{user_attribution}:[/magenta] " + f"\n[magenta][bold]User[/bold] {user_attribution}:[/magenta] " ) messages.append({"role": "user", "content": user_input}) @@ -234,7 +238,7 @@ def chat( model=model, messages=messages, tool_choice="generate", - user=user, + user=user_email, ) role, message = display_streamed_markdown(stream_response) messages.append({"role": role, "content": message}) @@ -243,7 +247,7 @@ def chat( model=model, messages=messages, tool_choice="generate", - user=user, + user=user_email, ) message_content = response.choices[0].message.content or "" role = response.choices[0].message.role @@ -380,7 +384,7 @@ def display_config_as_table(config: Config) -> None: console.print(table) -def display_streamed_markdown(stream: Stream[ChatCompletionChunk]) -> tuple[str, dict[str, Any]]: +def display_streamed_markdown(stream: Stream[ChatCompletionChunk]) -> tuple[str, str]: """ Display the streamed markdown chunks as a single line. """ @@ -393,7 +397,7 @@ def display_streamed_markdown(stream: Stream[ChatCompletionChunk]) -> tuple[str, choice = chunk.choices[0] chunk_message = choice.delta.content if role == "": - role = choice.delta.role + role = choice.delta.role or "" if role == "assistant": console.print("\n[bold blue]Assistant:[/bold blue] ") if chunk_message: diff --git a/arcade/arcade/core/errors.py b/arcade/arcade/core/errors.py index 0f8f4808..0d9dee9e 100644 --- a/arcade/arcade/core/errors.py +++ b/arcade/arcade/core/errors.py @@ -21,9 +21,10 @@ class ToolDefinitionError(ToolError): class ToolRuntimeError(RuntimeError): - def __init__(self, message: str): + def __init__(self, message: str, developer_message: Optional[str] = None): super().__init__(message) self.message = message + self.developer_message = developer_message class ToolExecutionError(ToolRuntimeError): @@ -32,8 +33,7 @@ class ToolExecutionError(ToolRuntimeError): """ def __init__(self, message: str, developer_message: Optional[str] = None): - super().__init__(message) - self.developer_message = developer_message + super().__init__(message, developer_message) class RetryableToolError(ToolExecutionError): @@ -46,10 +46,11 @@ class RetryableToolError(ToolExecutionError): message: str, developer_message: Optional[str] = None, additional_prompt_content: Optional[str] = None, + retry_after_ms: Optional[int] = None, ): - super().__init__(message) - self.developer_message = developer_message + super().__init__(message, developer_message) self.additional_prompt_content = additional_prompt_content + self.retry_after_ms = retry_after_ms class ToolSerializationError(ToolRuntimeError): @@ -57,7 +58,8 @@ class ToolSerializationError(ToolRuntimeError): Raised when there is an error executing a tool. """ - pass + def __init__(self, message: str, developer_message: Optional[str] = None): + super().__init__(message, developer_message) class ToolInputError(ToolSerializationError): @@ -65,7 +67,8 @@ class ToolInputError(ToolSerializationError): Raised when there is an error in the input to a tool. """ - pass + def __init__(self, message: str, developer_message: Optional[str] = None): + super().__init__(message, developer_message) class ToolOutputError(ToolSerializationError): @@ -73,4 +76,5 @@ class ToolOutputError(ToolSerializationError): Raised when there is an error in the output of a tool. """ - pass + def __init__(self, message: str, developer_message: Optional[str] = None): + super().__init__(message, developer_message) diff --git a/arcade/arcade/core/executor.py b/arcade/arcade/core/executor.py index d89d058a..bf9d9882 100644 --- a/arcade/arcade/core/executor.py +++ b/arcade/arcade/core/executor.py @@ -5,13 +5,12 @@ from pydantic import BaseModel, ValidationError from arcade.core.errors import ( RetryableToolError, - ToolExecutionError, ToolInputError, ToolOutputError, - ToolSerializationError, + ToolRuntimeError, ) -from arcade.core.response import ToolResponse, tool_response -from arcade.core.schema import ToolContext, ToolDefinition +from arcade.core.output import output_factory +from arcade.core.schema import ToolCallOutput, ToolContext, ToolDefinition class ToolExecutor: @@ -24,7 +23,7 @@ class ToolExecutor: context: ToolContext, *args: Any, **kwargs: Any, - ) -> ToolResponse: + ) -> ToolCallOutput: """ Execute a callable function with validated inputs and outputs via Pydantic models. """ @@ -49,23 +48,30 @@ class ToolExecutor: output = await ToolExecutor._serialize_output(output_model, results) # return the output - return tool_response.success(data=output) + return output_factory.success(data=output) except RetryableToolError as e: - return tool_response.fail_retry( - msg=str(e), additional_prompt_content=e.additional_prompt_content + return output_factory.fail_retry( + message=e.message, + developer_message=e.developer_message, + additional_prompt_content=e.additional_prompt_content, + retry_after_ms=e.retry_after_ms, ) - except ToolSerializationError as e: - return tool_response.fail(msg=str(e)) + except ToolInputError as e: + return output_factory.fail(message=e.message, developer_message=e.developer_message) - except ToolExecutionError as e: - return tool_response.fail(msg=str(e)) + except ToolOutputError as e: + return output_factory.fail(message=e.message, developer_message=e.developer_message) + + except ToolRuntimeError as e: # Catch any remaining tool-related errors + return output_factory.fail( + message=f"Error in execution: {e.message}", developer_message=e.developer_message + ) # if we get here we're in trouble - # TODO: Debate if this is necessary except Exception as e: - return tool_response.fail(msg=str(e)) + return output_factory.fail(message="Error in execution", developer_message=str(e)) @staticmethod async def _serialize_input(input_model: type[BaseModel], **kwargs: Any) -> BaseModel: @@ -79,7 +85,7 @@ class ToolExecutor: inputs = input_model(**kwargs) except ValidationError as e: - raise ToolInputError from e + raise ToolInputError(message="Error in input", developer_message=str(e)) from e return inputs @@ -97,6 +103,6 @@ class ToolExecutor: output = output_model(**{"result": results}) except ValidationError as e: - raise ToolOutputError from e + raise ToolOutputError(message="Error in output", developer_message=str(e)) from e return output diff --git a/arcade/arcade/core/output.py b/arcade/arcade/core/output.py new file mode 100644 index 00000000..f4c7d61c --- /dev/null +++ b/arcade/arcade/core/output.py @@ -0,0 +1,48 @@ +from typing import TypeVar + +from arcade.core.schema import ToolCallError, ToolCallOutput + +T = TypeVar("T") + + +class ToolOutputFactory: + """ + Singleton pattern for unified return method from tools. + """ + + def success( + self, + *, + data: T | None = None, + ) -> ToolCallOutput: + value = data.result if data and hasattr(data, "result") and data.result else "" + + return ToolCallOutput(value=value) + + def fail(self, *, message: str, developer_message: str | None = None) -> ToolCallOutput: + return ToolCallOutput( + error=ToolCallError( + message=message, developer_message=developer_message, can_retry=False + ) + ) + + def fail_retry( + self, + *, + message: str, + developer_message: str | None = None, + additional_prompt_content: str | None = None, + retry_after_ms: int | None = None, + ) -> ToolCallOutput: + return ToolCallOutput( + error=ToolCallError( + message=message, + developer_message=developer_message, + can_retry=True, + additional_prompt_content=additional_prompt_content, + retry_after_ms=retry_after_ms, + ) + ) + + +output_factory = ToolOutputFactory() diff --git a/arcade/arcade/core/response.py b/arcade/arcade/core/response.py deleted file mode 100644 index 138a5956..00000000 --- a/arcade/arcade/core/response.py +++ /dev/null @@ -1,88 +0,0 @@ -from typing import Any, Generic, TypeVar - -from pydantic import BaseModel - -from arcade.core.response_code import ( - CustomResponse, - CustomResponseCode, -) - -_ExcludeData = set[int | str] | dict[int | str, Any] -T = TypeVar("T") - - -# TODO: Mapping of tool response actions to http codes? - - -class ToolResponse(BaseModel, Generic[T]): - """ - Generic unified return model for Tools - - """ - - code: int = CustomResponseCode.HTTP_200.code - msg: str = CustomResponseCode.HTTP_200.msg - additional_prompt_content: str | None = None - - # - data: T | None = None - - -class ToolResponseFactory: - """ - Singleton pattern for unified return method from tools. - """ - - @staticmethod - def __response( - *, - msg: str | None = None, - res: CustomResponseCode | CustomResponse = CustomResponseCode.HTTP_200, - data: T | None = None, - ) -> ToolResponse: - """ - General method for successful response - """ - if msg: - return ToolResponse(code=res.code, msg=msg, data=data) - return ToolResponse(code=res.code, msg=res.msg, data=data) - - def success( - self, - *, - res: CustomResponseCode | CustomResponse = CustomResponseCode.HTTP_200, - data: T | None = None, - ) -> ToolResponse: - return self.__response(res=res, data=data) - - def fail( - self, - *, - res: CustomResponseCode | CustomResponse = CustomResponseCode.HTTP_400, - msg: str = CustomResponseCode.HTTP_400.msg, - data: Any = None, - ) -> ToolResponse: - return self.__response( - res=res, - msg=msg, # TODO this needs to map to developer_message in output.error - data=data, - ) - - def fail_retry( - self, - *, - res: CustomResponseCode | CustomResponse = CustomResponseCode.HTTP_425, - msg: str = CustomResponseCode.HTTP_425.msg, - data: Any = None, - additional_prompt_content: str | None = None, - ) -> ToolResponse: - res = self.__response( - res=res, - msg=msg, - data=data, - ) - res.additional_prompt_content = additional_prompt_content - return res - - -tool_response = ToolResponseFactory() diff --git a/arcade/arcade/core/schema.py b/arcade/arcade/core/schema.py index 1c567cff..6699f0e6 100644 --- a/arcade/arcade/core/schema.py +++ b/arcade/arcade/core/schema.py @@ -168,7 +168,12 @@ class ToolCallError(BaseModel): """The user-facing error message.""" developer_message: str | None = None """The developer-facing error details.""" + can_retry: bool = False + """Whether the tool call can be retried.""" additional_prompt_content: str | None = None + """Additional content to be included in the retry prompt.""" + retry_after_ms: int | None = None + """The number of milliseconds (if any) to wait before retrying the tool call.""" class ToolCallOutput(BaseModel): diff --git a/schemas/preview/invoke_tool_response.schema.jsonc b/schemas/preview/invoke_tool_response.schema.jsonc index 1e752dcf..46c0bba5 100644 --- a/schemas/preview/invoke_tool_response.schema.jsonc +++ b/schemas/preview/invoke_tool_response.schema.jsonc @@ -54,6 +54,19 @@ "developer_message": { "type": "string", "description": "An internal message that will be logged but will not be shown to the user or the AI model" + }, + "can_retry": { + "type": "boolean", + "description": "Whether the tool call can be retried", + "default": false + }, + "additional_prompt_content": { + "type": "string", + "description": "Additional content to be included in the retry prompt" + }, + "retry_after_ms": { + "type": "integer", + "description": "The number of milliseconds (if any) to wait before retrying the tool call" } }, "required": ["message"], diff --git a/schemas/preview/tool_definition.schema.jsonc b/schemas/preview/tool_definition.schema.jsonc index bfd4e365..f34b52a6 100644 --- a/schemas/preview/tool_definition.schema.jsonc +++ b/schemas/preview/tool_definition.schema.jsonc @@ -122,8 +122,7 @@ "type": "object", "properties": { "provider": { - "type": "string", - "enum": ["oauth2", "github_app"] + "type": "string" }, "oauth2": { "type": "object", diff --git a/toolkits/slack/arcade_slack/tools/chat.py b/toolkits/slack/arcade_slack/tools/chat.py index c3cead33..9a94d819 100644 --- a/toolkits/slack/arcade_slack/tools/chat.py +++ b/toolkits/slack/arcade_slack/tools/chat.py @@ -1,4 +1,3 @@ -import time from typing import Annotated from arcade.core.errors import ToolExecutionError, RetryableToolError from arcade.core.schema import ToolContext @@ -10,7 +9,16 @@ from slack_sdk.errors import SlackApiError @tool( requires_auth=SlackUser( - scope=["chat:write", "im:write", "users.profile:read", "users:read"], + # TODO reduce this to chat:write, im:write, users.profile:read, users:read + # when incremental auth works + scope=[ + "chat:write", + "im:write", + "users.profile:read", + "users:read", + "channels:read", + "groups:read", + ], ) ) def send_dm_to_user( @@ -32,16 +40,11 @@ def send_dm_to_user( break if not user_id: - # does this end up as a developerMessage? - # does it end up in the LLM context? - # provide the dev an Error type that controls what ends up in the LLM context - - # TODO make the sleep configurable and sent to the engine - time.sleep(0.5) # Wait for half a second raise RetryableToolError( "User not found", developer_message=f"User with username '{user_name}' not found.", additional_prompt_content=format_users(userListResponse), + retry_after_ms=500, # Play nice with Slack API rate limits ) # Step 2: Retrieve the DM channel ID with the user @@ -52,26 +55,35 @@ def send_dm_to_user( slackClient.chat_postMessage(channel=dm_channel_id, text=message) except SlackApiError as e: + error_message = e.response["error"] if "error" in e.response else str(e) raise ToolExecutionError( - f"Error sending message: {e.response['error']}", - developer_message="Error sending message", + "Error sending message", + developer_message=f"Slack API Error: {error_message}", ) def format_users(userListResponse: dict) -> str: - csv_string = "All active Slack users:\n\nid,name,real_name\n" + csv_string = "All active Slack users:\n\nname,real_name\n" for user in userListResponse["members"]: if not user.get("deleted", False): - user_id = user.get("id", "") name = user.get("name", "") real_name = user.get("profile", {}).get("real_name", "") - csv_string += f"{user_id},{name},{real_name}\n" + csv_string += f"{name},{real_name}\n" return csv_string.strip() @tool( requires_auth=SlackUser( - scope=["chat:write", "channels:read", "groups:read"], + # TODO reduce this to chat:write, channels:read, groups:read + # when incremental auth works + scope=[ + "chat:write", + "im:write", + "users.profile:read", + "users:read", + "channels:read", + "groups:read", + ], ) ) def send_message_to_channel( @@ -95,28 +107,28 @@ def send_message_to_channel( break if not channel_id: - time.sleep(0.5) # Wait for half a second raise RetryableToolError( "Channel not found", developer_message=f"Channel with name '{channel_name}' not found.", additional_prompt_content=format_channels(channels_response), + retry_after_ms=500, # Play nice with Slack API rate limits ) # Step 2: Send the message to the channel slackClient.chat_postMessage(channel=channel_id, text=message) except SlackApiError as e: + error_message = e.response["error"] if "error" in e.response else str(e) raise ToolExecutionError( - f"Error sending message: {e.response['error']}", - developer_message="Error sending message", + "Error sending message", + developer_message=f"Slack API Error: {error_message}", ) def format_channels(channels_response: dict) -> str: - csv_string = "All active Slack channels:\n\nid,name\n" + csv_string = "All active Slack channels:\n\nname\n" for channel in channels_response["channels"]: if not channel.get("is_archived", False): - channel_id = channel.get("id", "") name = channel.get("name", "") - csv_string += f"{channel_id},{name}\n" + csv_string += f"{name}\n" return csv_string.strip()