Clean up retryable errors (#21)

Clean up logic of retries
This commit is contained in:
Nate Barbettini 2024-08-27 16:19:22 -07:00 committed by GitHub
parent e7ccbe0efa
commit d37303de6a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 157 additions and 167 deletions

View file

@ -12,8 +12,6 @@ from arcade.actor.core.components import (
from arcade.core.catalog import ToolCatalog, Toolkit
from arcade.core.executor import ToolExecutor
from arcade.core.schema import (
ToolCallError,
ToolCallOutput,
ToolCallRequest,
ToolCallResponse,
ToolDefinition,
@ -71,7 +69,7 @@ class BaseActor(Actor):
start_time = time.time()
response = await ToolExecutor.run(
output = await ToolExecutor.run(
func=materialized_tool.tool,
definition=materialized_tool.definition,
input_model=materialized_tool.input_model,
@ -79,17 +77,6 @@ class BaseActor(Actor):
context=tool_request.context,
**tool_request.inputs or {},
)
if response.code == 200 and response.data is not None:
output = (
ToolCallOutput(value=response.data.result)
if hasattr(response.data, "result") and response.data.result
else ToolCallOutput(value=f"Tool {tool_name} called successfully")
)
else:
# TODO flatten this to just ToolCallError
output = ToolCallOutput(error=ToolCallError(message=response.msg))
if response.code == 425:
output.error.additional_prompt_content = response.additional_prompt_content
end_time = time.time() # End time in seconds
duration_ms = (end_time - start_time) * 1000 # Convert to milliseconds
@ -98,7 +85,7 @@ class BaseActor(Actor):
invocation_id=tool_request.invocation_id,
duration=duration_ms,
finished_at=datetime.now().isoformat(),
success=response.code == 200,
success=not output.error,
output=output,
)

View file

@ -15,7 +15,7 @@ from typer.models import Context
from arcade.core.catalog import ToolCatalog
from arcade.core.client import EngineClient
from arcade.core.config import Config
from arcade.core.schema import ToolContext
from arcade.core.schema import ToolCallOutput, ToolContext
from arcade.core.toolkit import Toolkit
@ -145,7 +145,7 @@ def run(
console.print(f"Calling tool: {tool_name} with params: {parameters}", style="bold blue")
# TODO async.gather instead of loop.
output = asyncio.run(
output: ToolCallOutput = asyncio.run(
ToolExecutor.run(
called_tool.tool,
called_tool.definition,
@ -155,22 +155,20 @@ def run(
**parameters,
)
)
if output.code != 200:
console.print(output.msg, style="bold red")
if output.data:
console.print(output.data.result, style="bold red")
typer.Exit(code=1)
if output.error:
console.print(output.error.message, style="bold red")
typer.Exit(code=1)
else:
messages += [
{
"role": "assistant",
# TODO: escape the output and ensure serialization works
"content": f"Results of Tool {tool_name}: {output.data.result!s}", # type: ignore[union-attr]
"content": f"Results of Tool {tool_name}: {output.value!s}",
},
]
if choice == "execute":
console.print(output.data.result, style="green") # type: ignore[union-attr]
console.print(output.value, style="green")
raise typer.Exit(0)
else:
if stream:
@ -206,9 +204,18 @@ def chat(
client = EngineClient(base_url=config.engine_url)
if config.user and config.user.email:
user_email = config.user.email
user_attribution = f"({user_email})"
else:
console.print(
"❌ User email not found in configuration. Please run `arcade login`.", style="bold red"
)
typer.Exit(code=1)
try:
# start messages conversation
messages = []
messages: list[dict[str, Any]] = []
chat_header = Text.assemble(
"\n",
@ -220,12 +227,9 @@ def chat(
)
console.print(chat_header)
user = config.user.email if config.user and config.user.email else None
user_attribution = f" ({user})" if user else ""
while True:
user_input = console.input(
f"\n[magenta][bold]User[/bold]{user_attribution}:[/magenta] "
f"\n[magenta][bold]User[/bold] {user_attribution}:[/magenta] "
)
messages.append({"role": "user", "content": user_input})
@ -234,7 +238,7 @@ def chat(
model=model,
messages=messages,
tool_choice="generate",
user=user,
user=user_email,
)
role, message = display_streamed_markdown(stream_response)
messages.append({"role": role, "content": message})
@ -243,7 +247,7 @@ def chat(
model=model,
messages=messages,
tool_choice="generate",
user=user,
user=user_email,
)
message_content = response.choices[0].message.content or ""
role = response.choices[0].message.role
@ -380,7 +384,7 @@ def display_config_as_table(config: Config) -> None:
console.print(table)
def display_streamed_markdown(stream: Stream[ChatCompletionChunk]) -> tuple[str, dict[str, Any]]:
def display_streamed_markdown(stream: Stream[ChatCompletionChunk]) -> tuple[str, str]:
"""
Display the streamed markdown chunks as a single line.
"""
@ -393,7 +397,7 @@ def display_streamed_markdown(stream: Stream[ChatCompletionChunk]) -> tuple[str,
choice = chunk.choices[0]
chunk_message = choice.delta.content
if role == "":
role = choice.delta.role
role = choice.delta.role or ""
if role == "assistant":
console.print("\n[bold blue]Assistant:[/bold blue] ")
if chunk_message:

View file

@ -21,9 +21,10 @@ class ToolDefinitionError(ToolError):
class ToolRuntimeError(RuntimeError):
def __init__(self, message: str):
def __init__(self, message: str, developer_message: Optional[str] = None):
super().__init__(message)
self.message = message
self.developer_message = developer_message
class ToolExecutionError(ToolRuntimeError):
@ -32,8 +33,7 @@ class ToolExecutionError(ToolRuntimeError):
"""
def __init__(self, message: str, developer_message: Optional[str] = None):
super().__init__(message)
self.developer_message = developer_message
super().__init__(message, developer_message)
class RetryableToolError(ToolExecutionError):
@ -46,10 +46,11 @@ class RetryableToolError(ToolExecutionError):
message: str,
developer_message: Optional[str] = None,
additional_prompt_content: Optional[str] = None,
retry_after_ms: Optional[int] = None,
):
super().__init__(message)
self.developer_message = developer_message
super().__init__(message, developer_message)
self.additional_prompt_content = additional_prompt_content
self.retry_after_ms = retry_after_ms
class ToolSerializationError(ToolRuntimeError):
@ -57,7 +58,8 @@ class ToolSerializationError(ToolRuntimeError):
Raised when there is an error executing a tool.
"""
pass
def __init__(self, message: str, developer_message: Optional[str] = None):
super().__init__(message, developer_message)
class ToolInputError(ToolSerializationError):
@ -65,7 +67,8 @@ class ToolInputError(ToolSerializationError):
Raised when there is an error in the input to a tool.
"""
pass
def __init__(self, message: str, developer_message: Optional[str] = None):
super().__init__(message, developer_message)
class ToolOutputError(ToolSerializationError):
@ -73,4 +76,5 @@ class ToolOutputError(ToolSerializationError):
Raised when there is an error in the output of a tool.
"""
pass
def __init__(self, message: str, developer_message: Optional[str] = None):
super().__init__(message, developer_message)

View file

@ -5,13 +5,12 @@ from pydantic import BaseModel, ValidationError
from arcade.core.errors import (
RetryableToolError,
ToolExecutionError,
ToolInputError,
ToolOutputError,
ToolSerializationError,
ToolRuntimeError,
)
from arcade.core.response import ToolResponse, tool_response
from arcade.core.schema import ToolContext, ToolDefinition
from arcade.core.output import output_factory
from arcade.core.schema import ToolCallOutput, ToolContext, ToolDefinition
class ToolExecutor:
@ -24,7 +23,7 @@ class ToolExecutor:
context: ToolContext,
*args: Any,
**kwargs: Any,
) -> ToolResponse:
) -> ToolCallOutput:
"""
Execute a callable function with validated inputs and outputs via Pydantic models.
"""
@ -49,23 +48,30 @@ class ToolExecutor:
output = await ToolExecutor._serialize_output(output_model, results)
# return the output
return tool_response.success(data=output)
return output_factory.success(data=output)
except RetryableToolError as e:
return tool_response.fail_retry(
msg=str(e), additional_prompt_content=e.additional_prompt_content
return output_factory.fail_retry(
message=e.message,
developer_message=e.developer_message,
additional_prompt_content=e.additional_prompt_content,
retry_after_ms=e.retry_after_ms,
)
except ToolSerializationError as e:
return tool_response.fail(msg=str(e))
except ToolInputError as e:
return output_factory.fail(message=e.message, developer_message=e.developer_message)
except ToolExecutionError as e:
return tool_response.fail(msg=str(e))
except ToolOutputError as e:
return output_factory.fail(message=e.message, developer_message=e.developer_message)
except ToolRuntimeError as e: # Catch any remaining tool-related errors
return output_factory.fail(
message=f"Error in execution: {e.message}", developer_message=e.developer_message
)
# if we get here we're in trouble
# TODO: Debate if this is necessary
except Exception as e:
return tool_response.fail(msg=str(e))
return output_factory.fail(message="Error in execution", developer_message=str(e))
@staticmethod
async def _serialize_input(input_model: type[BaseModel], **kwargs: Any) -> BaseModel:
@ -79,7 +85,7 @@ class ToolExecutor:
inputs = input_model(**kwargs)
except ValidationError as e:
raise ToolInputError from e
raise ToolInputError(message="Error in input", developer_message=str(e)) from e
return inputs
@ -97,6 +103,6 @@ class ToolExecutor:
output = output_model(**{"result": results})
except ValidationError as e:
raise ToolOutputError from e
raise ToolOutputError(message="Error in output", developer_message=str(e)) from e
return output

View file

@ -0,0 +1,48 @@
from typing import TypeVar
from arcade.core.schema import ToolCallError, ToolCallOutput
T = TypeVar("T")
class ToolOutputFactory:
    """
    Factory for the ToolCallOutput objects that represent the outcome of a tool call.

    Every method is keyword-only and returns a ToolCallOutput carrying either
    a `value` (success) or an `error` (failure).
    """

    def success(
        self,
        *,
        data: T | None = None,
    ) -> ToolCallOutput:
        """
        Wrap the result of a successful tool call.

        The output value is `data.result`. An empty string is used instead when
        `data` is None, `data` has no `result` attribute, or `data.result` is None.
        """
        # Compare against None rather than relying on truthiness so that
        # legitimate falsy results (0, False, 0.0, empty containers) are
        # reported as-is instead of being replaced by an empty string.
        result = getattr(data, "result", None)
        return ToolCallOutput(value="" if result is None else result)

    def fail(self, *, message: str, developer_message: str | None = None) -> ToolCallOutput:
        """
        Wrap a failure that is not marked as retryable.

        `message` and `developer_message` are copied onto the ToolCallError,
        and `can_retry` is set to False.
        """
        return ToolCallOutput(
            error=ToolCallError(
                message=message, developer_message=developer_message, can_retry=False
            )
        )

    def fail_retry(
        self,
        *,
        message: str,
        developer_message: str | None = None,
        additional_prompt_content: str | None = None,
        retry_after_ms: int | None = None,
    ) -> ToolCallOutput:
        """
        Wrap a failure that is marked as retryable.

        All arguments are copied onto the ToolCallError, and `can_retry` is
        set to True.
        """
        return ToolCallOutput(
            error=ToolCallError(
                message=message,
                developer_message=developer_message,
                can_retry=True,
                additional_prompt_content=additional_prompt_content,
                retry_after_ms=retry_after_ms,
            )
        )
output_factory = ToolOutputFactory()

View file

@ -1,88 +0,0 @@
from typing import Any, Generic, TypeVar
from pydantic import BaseModel
from arcade.core.response_code import (
CustomResponse,
CustomResponseCode,
)
_ExcludeData = set[int | str] | dict[int | str, Any]
T = TypeVar("T")
# TODO: Mapping of tool response actions to http codes?
class ToolResponse(BaseModel, Generic[T]):
    """
    Unified, generic return model for Tools.

    The type parameter T is the type of the `data` payload.
    """

    # Response code; defaults to the code of CustomResponseCode.HTTP_200.
    code: int = CustomResponseCode.HTTP_200.code
    # Response message; defaults to the message of CustomResponseCode.HTTP_200.
    msg: str = CustomResponseCode.HTTP_200.msg
    # Optional extra prompt content; ToolResponseFactory.fail_retry sets it.
    additional_prompt_content: str | None = None
    # Optional payload of type T.
    data: T | None = None
class ToolResponseFactory:
    """
    Builds the ToolResponse objects returned from tools for the success,
    failure, and retryable-failure outcomes.
    """

    @staticmethod
    def __response(
        *,
        msg: str | None = None,
        res: CustomResponseCode | CustomResponse = CustomResponseCode.HTTP_200,
        data: T | None = None,
    ) -> ToolResponse:
        """
        Build a ToolResponse with the code of `res` and the given `data`.

        A non-empty `msg` is used as the response message; otherwise the
        message of `res` is used.
        """
        return ToolResponse(code=res.code, msg=msg or res.msg, data=data)

    def success(
        self,
        *,
        res: CustomResponseCode | CustomResponse = CustomResponseCode.HTTP_200,
        data: T | None = None,
    ) -> ToolResponse:
        """
        Build a response from `res` (HTTP_200 by default) carrying `data`.
        """
        return self.__response(res=res, data=data)

    def fail(
        self,
        *,
        res: CustomResponseCode | CustomResponse = CustomResponseCode.HTTP_400,
        msg: str = CustomResponseCode.HTTP_400.msg,
        data: Any = None,
    ) -> ToolResponse:
        """
        Build a response from `res` (HTTP_400 by default) with message `msg`.
        """
        # TODO this needs to map to developer_message in output.error
        return self.__response(res=res, msg=msg, data=data)

    def fail_retry(
        self,
        *,
        res: CustomResponseCode | CustomResponse = CustomResponseCode.HTTP_425,
        msg: str = CustomResponseCode.HTTP_425.msg,
        data: Any = None,
        additional_prompt_content: str | None = None,
    ) -> ToolResponse:
        """
        Build a response from `res` (HTTP_425 by default) with message `msg`,
        then attach `additional_prompt_content` to it.
        """
        response = self.__response(res=res, msg=msg, data=data)
        response.additional_prompt_content = additional_prompt_content
        return response
tool_response = ToolResponseFactory()

View file

@ -168,7 +168,12 @@ class ToolCallError(BaseModel):
"""The user-facing error message."""
developer_message: str | None = None
"""The developer-facing error details."""
can_retry: bool = False
"""Whether the tool call can be retried."""
additional_prompt_content: str | None = None
"""Additional content to be included in the retry prompt."""
retry_after_ms: int | None = None
"""The number of milliseconds (if any) to wait before retrying the tool call."""
class ToolCallOutput(BaseModel):

View file

@ -54,6 +54,19 @@
"developer_message": {
"type": "string",
"description": "An internal message that will be logged but will not be shown to the user or the AI model"
},
"can_retry": {
"type": "boolean",
"description": "Whether the tool call can be retried",
"default": false
},
"additional_prompt_content": {
"type": "string",
"description": "Additional content to be included in the retry prompt"
},
"retry_after_ms": {
"type": "integer",
"description": "The number of milliseconds (if any) to wait before retrying the tool call"
}
},
"required": ["message"],

View file

@ -122,8 +122,7 @@
"type": "object",
"properties": {
"provider": {
"type": "string",
"enum": ["oauth2", "github_app"]
"type": "string"
},
"oauth2": {
"type": "object",

View file

@ -1,4 +1,3 @@
import time
from typing import Annotated
from arcade.core.errors import ToolExecutionError, RetryableToolError
from arcade.core.schema import ToolContext
@ -10,7 +9,16 @@ from slack_sdk.errors import SlackApiError
@tool(
requires_auth=SlackUser(
scope=["chat:write", "im:write", "users.profile:read", "users:read"],
# TODO reduce this to chat:write, im:write, users.profile:read, users:read
# when incremental auth works
scope=[
"chat:write",
"im:write",
"users.profile:read",
"users:read",
"channels:read",
"groups:read",
],
)
)
def send_dm_to_user(
@ -32,16 +40,11 @@ def send_dm_to_user(
break
if not user_id:
# does this end up as a developerMessage?
# does it end up in the LLM context?
# provide the dev an Error type that controls what ends up in the LLM context
# TODO make the sleep configurable and sent to the engine
time.sleep(0.5) # Wait for half a second
raise RetryableToolError(
"User not found",
developer_message=f"User with username '{user_name}' not found.",
additional_prompt_content=format_users(userListResponse),
retry_after_ms=500, # Play nice with Slack API rate limits
)
# Step 2: Retrieve the DM channel ID with the user
@ -52,26 +55,35 @@ def send_dm_to_user(
slackClient.chat_postMessage(channel=dm_channel_id, text=message)
except SlackApiError as e:
error_message = e.response["error"] if "error" in e.response else str(e)
raise ToolExecutionError(
f"Error sending message: {e.response['error']}",
developer_message="Error sending message",
"Error sending message",
developer_message=f"Slack API Error: {error_message}",
)
def format_users(userListResponse: dict) -> str:
csv_string = "All active Slack users:\n\nid,name,real_name\n"
csv_string = "All active Slack users:\n\nname,real_name\n"
for user in userListResponse["members"]:
if not user.get("deleted", False):
user_id = user.get("id", "")
name = user.get("name", "")
real_name = user.get("profile", {}).get("real_name", "")
csv_string += f"{user_id},{name},{real_name}\n"
csv_string += f"{name},{real_name}\n"
return csv_string.strip()
@tool(
requires_auth=SlackUser(
scope=["chat:write", "channels:read", "groups:read"],
# TODO reduce this to chat:write, channels:read, groups:read
# when incremental auth works
scope=[
"chat:write",
"im:write",
"users.profile:read",
"users:read",
"channels:read",
"groups:read",
],
)
)
def send_message_to_channel(
@ -95,28 +107,28 @@ def send_message_to_channel(
break
if not channel_id:
time.sleep(0.5) # Wait for half a second
raise RetryableToolError(
"Channel not found",
developer_message=f"Channel with name '{channel_name}' not found.",
additional_prompt_content=format_channels(channels_response),
retry_after_ms=500, # Play nice with Slack API rate limits
)
# Step 2: Send the message to the channel
slackClient.chat_postMessage(channel=channel_id, text=message)
except SlackApiError as e:
error_message = e.response["error"] if "error" in e.response else str(e)
raise ToolExecutionError(
f"Error sending message: {e.response['error']}",
developer_message="Error sending message",
"Error sending message",
developer_message=f"Slack API Error: {error_message}",
)
def format_channels(channels_response: dict) -> str:
csv_string = "All active Slack channels:\n\nid,name\n"
csv_string = "All active Slack channels:\n\nname\n"
for channel in channels_response["channels"]:
if not channel.get("is_archived", False):
channel_id = channel.get("id", "")
name = channel.get("name", "")
csv_string += f"{channel_id},{name}\n"
csv_string += f"{name}\n"
return csv_string.strip()