Retryable errors (#20)
Co-authored-by: Sterling Dreyer <sdreyer@ucsc.edu>
This commit is contained in:
parent
acba912816
commit
ab703b75ef
8 changed files with 145 additions and 11 deletions
|
|
@ -86,7 +86,10 @@ class BaseActor(Actor):
|
|||
else ToolCallOutput(value=f"Tool {tool_name} called successfully")
|
||||
)
|
||||
else:
|
||||
# TODO flatten this to just ToolCallError
|
||||
output = ToolCallOutput(error=ToolCallError(message=response.msg))
|
||||
if response.code == 425:
|
||||
output.error.additional_prompt_content = response.additional_prompt_content
|
||||
|
||||
end_time = time.time() # End time in seconds
|
||||
duration_ms = (end_time - start_time) * 1000 # Convert to milliseconds
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import asyncio
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import Any, Optional
|
||||
|
||||
import typer
|
||||
from openai.resources.chat.completions import ChatCompletionChunk, Stream
|
||||
|
|
@ -192,7 +192,7 @@ def run(
|
|||
def chat(
|
||||
model: str = typer.Option("gpt-4o", "-m", help="The model to use for prediction."),
|
||||
stream: bool = typer.Option(
|
||||
False, "-s", "--stream", is_flag=True, help="Stream the tool output."
|
||||
True, "-s", "--stream", is_flag=True, help="Stream the tool output."
|
||||
),
|
||||
) -> None:
|
||||
"""
|
||||
|
|
@ -236,7 +236,8 @@ def chat(
|
|||
tool_choice="generate",
|
||||
user=user,
|
||||
)
|
||||
display_streamed_markdown(stream_response)
|
||||
role, message = display_streamed_markdown(stream_response)
|
||||
messages.append({"role": role, "content": message})
|
||||
else:
|
||||
response = client.complete(
|
||||
model=model,
|
||||
|
|
@ -379,21 +380,27 @@ def display_config_as_table(config: Config) -> None:
|
|||
console.print(table)
|
||||
|
||||
|
||||
def display_streamed_markdown(stream: Stream[ChatCompletionChunk]) -> None:
|
||||
def display_streamed_markdown(stream: Stream[ChatCompletionChunk]) -> tuple[str, dict[str, Any]]:
|
||||
"""
|
||||
Display the streamed markdown chunks as a single line.
|
||||
"""
|
||||
from rich.live import Live
|
||||
|
||||
full_message = ""
|
||||
role = ""
|
||||
with Live(console=console, refresh_per_second=10) as live:
|
||||
for chunk in stream:
|
||||
choice = chunk.choices[0]
|
||||
chunk_message = choice.delta.content
|
||||
if role == "":
|
||||
role = choice.delta.role
|
||||
if role == "assistant":
|
||||
console.print("\n[bold blue]Assistant:[/bold blue] ")
|
||||
if chunk_message:
|
||||
full_message += chunk_message
|
||||
markdown_chunk = Markdown(full_message)
|
||||
live.update(markdown_chunk)
|
||||
return role, full_message
|
||||
|
||||
|
||||
def create_cli_catalog(
|
||||
|
|
|
|||
|
|
@ -1,3 +1,6 @@
|
|||
from typing import Optional
|
||||
|
||||
|
||||
class ToolError(Exception):
|
||||
"""
|
||||
Base class for all errors related to tools.
|
||||
|
|
@ -18,7 +21,9 @@ class ToolDefinitionError(ToolError):
|
|||
|
||||
|
||||
class ToolRuntimeError(RuntimeError):
|
||||
pass
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message)
|
||||
self.message = message
|
||||
|
||||
|
||||
class ToolExecutionError(ToolRuntimeError):
|
||||
|
|
@ -26,7 +31,25 @@ class ToolExecutionError(ToolRuntimeError):
|
|||
Raised when there is an error executing a tool.
|
||||
"""
|
||||
|
||||
pass
|
||||
def __init__(self, message: str, developer_message: Optional[str] = None):
|
||||
super().__init__(message)
|
||||
self.developer_message = developer_message
|
||||
|
||||
|
||||
class RetryableToolError(ToolExecutionError):
|
||||
"""
|
||||
Raised when a tool error is retryable.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
message: str,
|
||||
developer_message: Optional[str] = None,
|
||||
additional_prompt_content: Optional[str] = None,
|
||||
):
|
||||
super().__init__(message)
|
||||
self.developer_message = developer_message
|
||||
self.additional_prompt_content = additional_prompt_content
|
||||
|
||||
|
||||
class ToolSerializationError(ToolRuntimeError):
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from typing import Any, Callable
|
|||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from arcade.core.errors import (
|
||||
RetryableToolError,
|
||||
ToolExecutionError,
|
||||
ToolInputError,
|
||||
ToolOutputError,
|
||||
|
|
@ -50,6 +51,11 @@ class ToolExecutor:
|
|||
# return the output
|
||||
return tool_response.success(data=output)
|
||||
|
||||
except RetryableToolError as e:
|
||||
return tool_response.fail_retry(
|
||||
msg=str(e), additional_prompt_content=e.additional_prompt_content
|
||||
)
|
||||
|
||||
except ToolSerializationError as e:
|
||||
return tool_response.fail(msg=str(e))
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ class ToolResponse(BaseModel, Generic[T]):
|
|||
|
||||
code: int = CustomResponseCode.HTTP_200.code
|
||||
msg: str = CustomResponseCode.HTTP_200.msg
|
||||
additional_prompt_content: str | None = None
|
||||
|
||||
#
|
||||
data: T | None = None
|
||||
|
|
@ -67,5 +68,21 @@ class ToolResponseFactory:
|
|||
data=data,
|
||||
)
|
||||
|
||||
def fail_retry(
|
||||
self,
|
||||
*,
|
||||
res: CustomResponseCode | CustomResponse = CustomResponseCode.HTTP_425,
|
||||
msg: str = CustomResponseCode.HTTP_425.msg,
|
||||
data: Any = None,
|
||||
additional_prompt_content: str | None = None,
|
||||
) -> ToolResponse:
|
||||
res = self.__response(
|
||||
res=res,
|
||||
msg=msg,
|
||||
data=data,
|
||||
)
|
||||
res.additional_prompt_content = additional_prompt_content
|
||||
return res
|
||||
|
||||
|
||||
tool_response = ToolResponseFactory()
|
||||
|
|
|
|||
|
|
@ -168,6 +168,7 @@ class ToolCallError(BaseModel):
|
|||
"""The user-facing error message."""
|
||||
developer_message: str | None = None
|
||||
"""The developer-facing error details."""
|
||||
additional_prompt_content: str | None = None
|
||||
|
||||
|
||||
class ToolCallOutput(BaseModel):
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ actor.register_tool(repo.count_stargazers)
|
|||
actor.register_tool(repo.search_issues)
|
||||
actor.register_tool(user.set_starred)
|
||||
actor.register_tool(chat.send_dm_to_user)
|
||||
actor.register_tool(chat.send_message_to_channel)
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
|
|
@ -47,6 +48,7 @@ async def postChat(request: ChatRequest, tool_choice: str = "execute"):
|
|||
"SetStarred",
|
||||
"SearchIssues",
|
||||
"SendDmToUser",
|
||||
"SendMessageToChannel",
|
||||
],
|
||||
tool_choice=tool_choice,
|
||||
user="sam",
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
import time
|
||||
from typing import Annotated
|
||||
from arcade.core.errors import ToolExecutionError, RetryableToolError
|
||||
from arcade.core.schema import ToolContext
|
||||
from arcade.sdk import tool
|
||||
from arcade.sdk.auth import SlackUser
|
||||
|
|
@ -22,9 +24,9 @@ def send_dm_to_user(
|
|||
|
||||
try:
|
||||
# Step 1: Retrieve the user's Slack ID based on their username
|
||||
response = slackClient.users_list()
|
||||
userListResponse = slackClient.users_list()
|
||||
user_id = None
|
||||
for user in response["members"]:
|
||||
for user in userListResponse["members"]:
|
||||
if user["name"].lower() == user_name.lower():
|
||||
user_id = user["id"]
|
||||
break
|
||||
|
|
@ -33,7 +35,14 @@ def send_dm_to_user(
|
|||
# does this end up as a developerMessage?
|
||||
# does it end up in the LLM context?
|
||||
# provide the dev an Error type that controls what ends up in the LLM context
|
||||
raise ValueError(f"User with username '{user_name}' not found.")
|
||||
|
||||
# TODO make the sleep configurable and sent to the engine
|
||||
time.sleep(0.5) # Wait for half a second
|
||||
raise RetryableToolError(
|
||||
"User not found",
|
||||
developer_message=f"User with username '{user_name}' not found.",
|
||||
additional_prompt_content=format_users(userListResponse),
|
||||
)
|
||||
|
||||
# Step 2: Retrieve the DM channel ID with the user
|
||||
im_response = slackClient.conversations_open(users=[user_id])
|
||||
|
|
@ -43,5 +52,71 @@ def send_dm_to_user(
|
|||
slackClient.chat_postMessage(channel=dm_channel_id, text=message)
|
||||
|
||||
except SlackApiError as e:
|
||||
# this should be caught also, not printed
|
||||
print(f"Error sending message: {e.response['error']}")
|
||||
raise ToolExecutionError(
|
||||
f"Error sending message: {e.response['error']}",
|
||||
developer_message="Error sending message",
|
||||
)
|
||||
|
||||
|
||||
def format_users(userListResponse: dict) -> str:
|
||||
csv_string = "All active Slack users:\n\nid,name,real_name\n"
|
||||
for user in userListResponse["members"]:
|
||||
if not user.get("deleted", False):
|
||||
user_id = user.get("id", "")
|
||||
name = user.get("name", "")
|
||||
real_name = user.get("profile", {}).get("real_name", "")
|
||||
csv_string += f"{user_id},{name},{real_name}\n"
|
||||
return csv_string.strip()
|
||||
|
||||
|
||||
@tool(
|
||||
requires_auth=SlackUser(
|
||||
scope=["chat:write", "channels:read", "groups:read"],
|
||||
)
|
||||
)
|
||||
def send_message_to_channel(
|
||||
context: ToolContext,
|
||||
channel_name: Annotated[
|
||||
str, "The Slack channel name where you want to send the message"
|
||||
],
|
||||
message: Annotated[str, "The message you want to send"],
|
||||
):
|
||||
"""Send a message to a channel in Slack."""
|
||||
|
||||
slackClient = WebClient(token=context.authorization.token)
|
||||
|
||||
try:
|
||||
# Step 1: Retrieve the list of channels
|
||||
channels_response = slackClient.conversations_list()
|
||||
channel_id = None
|
||||
for channel in channels_response["channels"]:
|
||||
if channel["name"].lower() == channel_name.lower():
|
||||
channel_id = channel["id"]
|
||||
break
|
||||
|
||||
if not channel_id:
|
||||
time.sleep(0.5) # Wait for half a second
|
||||
raise RetryableToolError(
|
||||
"Channel not found",
|
||||
developer_message=f"Channel with name '{channel_name}' not found.",
|
||||
additional_prompt_content=format_channels(channels_response),
|
||||
)
|
||||
|
||||
# Step 2: Send the message to the channel
|
||||
slackClient.chat_postMessage(channel=channel_id, text=message)
|
||||
|
||||
except SlackApiError as e:
|
||||
raise ToolExecutionError(
|
||||
f"Error sending message: {e.response['error']}",
|
||||
developer_message="Error sending message",
|
||||
)
|
||||
|
||||
|
||||
def format_channels(channels_response: dict) -> str:
|
||||
csv_string = "All active Slack channels:\n\nid,name\n"
|
||||
for channel in channels_response["channels"]:
|
||||
if not channel.get("is_archived", False):
|
||||
channel_id = channel.get("id", "")
|
||||
name = channel.get("name", "")
|
||||
csv_string += f"{channel_id},{name}\n"
|
||||
return csv_string.strip()
|
||||
|
|
|
|||
Loading…
Reference in a new issue