Slack Tools (#162)
implements additional tools for Slack related to retrieving conversations metadata, list of members, history of messages, as well as sending messages to private/public channels and DMs / multi-person DMs. --------- Co-authored-by: Eric Gustin <eric@arcade-ai.com> Co-authored-by: Renato Byrro <rmbyrro@gmail.com>
This commit is contained in:
parent
130858a958
commit
66e54d7cde
18 changed files with 4488 additions and 310 deletions
13
toolkits/slack/arcade_slack/constants.py
Normal file
13
toolkits/slack/arcade_slack/constants.py
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
import os
|
||||
|
||||
from arcade_slack.custom_types import PositiveInt
|
||||
|
||||
MAX_PAGINATION_SIZE_LIMIT = 200
|
||||
|
||||
MAX_PAGINATION_TIMEOUT_SECONDS = PositiveInt(
|
||||
os.environ.get(
|
||||
"MAX_PAGINATION_TIMEOUT_SECONDS",
|
||||
os.environ.get("MAX_SLACK_PAGINATION_TIMEOUT_SECONDS", 30),
|
||||
),
|
||||
name="MAX_PAGINATION_TIMEOUT_SECONDS or MAX_SLACK_PAGINATION_TIMEOUT_SECONDS",
|
||||
)
|
||||
34
toolkits/slack/arcade_slack/critics.py
Normal file
34
toolkits/slack/arcade_slack/critics.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
from typing import Any
|
||||
|
||||
from arcade.sdk.eval import BinaryCritic
|
||||
|
||||
|
||||
class RelativeTimeBinaryCritic(BinaryCritic):
|
||||
def evaluate(self, expected: Any, actual: Any) -> dict[str, float | bool]:
|
||||
"""
|
||||
Evaluates whether the expected and actual relative time strings are equivalent after
|
||||
casting.
|
||||
|
||||
Args:
|
||||
expected: The expected value.
|
||||
actual: The actual value to compare, cast to the type of expected.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary containing the match status and score.
|
||||
"""
|
||||
try:
|
||||
actual_casted = self.cast_actual(expected, actual)
|
||||
except TypeError:
|
||||
actual_casted = actual
|
||||
|
||||
expected_parts = tuple(map(int, expected.split(":")))
|
||||
actual_parts = tuple(map(int, actual_casted.split(":")))
|
||||
|
||||
if len(expected_parts) != 3 or len(actual_parts) != 3:
|
||||
return {"match": False, "score": 0.0}
|
||||
|
||||
exp_days, exp_hours, exp_minutes = expected_parts
|
||||
act_days, act_hours, act_minutes = actual_parts
|
||||
|
||||
match = exp_days == act_days and exp_hours == act_hours and exp_minutes == act_minutes
|
||||
return {"match": match, "score": self.weight if match else 0.0}
|
||||
26
toolkits/slack/arcade_slack/custom_types.py
Normal file
26
toolkits/slack/arcade_slack/custom_types.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
from typing import NewType, Optional, Union
|
||||
|
||||
|
||||
class PositiveInt(int):
|
||||
def __new__(cls, value: Union[str, int], name: str = "value") -> "PositiveInt":
|
||||
def validate(val: int) -> int:
|
||||
if val <= 0:
|
||||
raise ValueError(f"{name} must be positive, got {val}")
|
||||
return val
|
||||
|
||||
try:
|
||||
value = int(value)
|
||||
except ValueError:
|
||||
raise ValueError(f"{name} must be a valid integer, got {value!r}")
|
||||
|
||||
validated_value = validate(value)
|
||||
instance = super().__new__(cls, validated_value)
|
||||
return instance
|
||||
|
||||
|
||||
SlackOffsetSecondsFromUTC = NewType("SlackOffsetSecondsFromUTC", int) # observe it can be negative
|
||||
SlackPaginationNextCursor = Optional[str]
|
||||
SlackUserFieldId = NewType("SlackUserFieldId", str)
|
||||
SlackUserId = NewType("SlackUserId", str)
|
||||
SlackTeamId = NewType("SlackTeamId", str)
|
||||
SlackTimestampStr = NewType("SlackTimestampStr", str)
|
||||
14
toolkits/slack/arcade_slack/exceptions.py
Normal file
14
toolkits/slack/arcade_slack/exceptions.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
class SlackToolkitError(Exception):
|
||||
"""Base class for all Slack toolkit errors."""
|
||||
|
||||
|
||||
class PaginationTimeoutError(SlackToolkitError):
|
||||
"""Raised when a timeout occurs during pagination."""
|
||||
|
||||
def __init__(self, timeout_seconds: int):
|
||||
self.timeout_seconds = timeout_seconds
|
||||
super().__init__(f"The pagination process timed out after {timeout_seconds} seconds.")
|
||||
|
||||
|
||||
class ItemNotFoundError(SlackToolkitError):
|
||||
"""Raised when an item is not found."""
|
||||
208
toolkits/slack/arcade_slack/models.py
Normal file
208
toolkits/slack/arcade_slack/models.py
Normal file
|
|
@ -0,0 +1,208 @@
|
|||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from typing_extensions import Literal, NotRequired, TypedDict
|
||||
|
||||
from arcade_slack.custom_types import (
|
||||
SlackOffsetSecondsFromUTC,
|
||||
SlackPaginationNextCursor,
|
||||
SlackTeamId,
|
||||
SlackTimestampStr,
|
||||
SlackUserFieldId,
|
||||
SlackUserId,
|
||||
)
|
||||
|
||||
|
||||
class ConversationTypeSlackName(str, Enum):
|
||||
PUBLIC_CHANNEL = "public_channel" # Public channels are visible to all users in the workspace
|
||||
PRIVATE_CHANNEL = "private_channel" # Private channels are visible to only specific users
|
||||
MPIM = "mpim" # Multi-person direct message conversation
|
||||
IM = "im" # Two person direct message conversation
|
||||
|
||||
|
||||
class ConversationType(str, Enum):
|
||||
PUBLIC_CHANNEL = "public_channel"
|
||||
PRIVATE_CHANNEL = "private_channel"
|
||||
MULTI_PERSON_DIRECT_MESSAGE = "multi_person_direct_message"
|
||||
DIRECT_MESSAGE = "direct_message"
|
||||
|
||||
|
||||
"""
|
||||
About Slack dictionaries: Slack does not guarantee the presence of all fields for a given
|
||||
object. It will vary from endpoint to endpoint and even if the field is present, they say it may
|
||||
contain a None value or an empty string instead of the actual expected value.
|
||||
|
||||
See, for example, the 'Common Fields' section of the user type definition at:
|
||||
https://api.slack.com/types/user#fields (https://archive.is/RUZdL)
|
||||
|
||||
Because of that, our TypedDicts ended up having to be mostly total=False and most of the fields'
|
||||
type hints are Optional. Use Slack dictionary fields with caution. It's advisable to validate the
|
||||
value before using it and raise errors that are clear to understand, when appropriate.
|
||||
"""
|
||||
|
||||
|
||||
class SlackUserFieldData(TypedDict, total=False):
|
||||
"""Type definition for Slack user field data dictionary.
|
||||
|
||||
Slack type definition: https://api.slack.com/methods/users.profile.set#custom-profile
|
||||
"""
|
||||
|
||||
value: Optional[str]
|
||||
alt: Optional[bool]
|
||||
|
||||
|
||||
class SlackStatusEmojiDisplayInfo(TypedDict, total=False):
|
||||
"""Type definition for Slack status emoji display info dictionary."""
|
||||
|
||||
emoji_name: Optional[str]
|
||||
display_url: Optional[str]
|
||||
|
||||
|
||||
class SlackUserProfile(TypedDict, total=False):
|
||||
"""Type definition for Slack user profile dictionary.
|
||||
|
||||
Slack type definition: https://api.slack.com/types/user#profile (https://archive.is/RUZdL)
|
||||
"""
|
||||
|
||||
title: Optional[str]
|
||||
phone: Optional[str]
|
||||
skype: Optional[str]
|
||||
email: Optional[str]
|
||||
real_name: Optional[str]
|
||||
real_name_normalized: Optional[str]
|
||||
display_name: Optional[str]
|
||||
display_name_normalized: Optional[str]
|
||||
first_name: Optional[str]
|
||||
last_name: Optional[str]
|
||||
fields: Optional[list[dict[SlackUserFieldId, SlackUserFieldData]]]
|
||||
image_original: Optional[str]
|
||||
is_custom_image: Optional[bool]
|
||||
image_24: Optional[str]
|
||||
image_32: Optional[str]
|
||||
image_48: Optional[str]
|
||||
image_72: Optional[str]
|
||||
image_192: Optional[str]
|
||||
image_512: Optional[str]
|
||||
image_1024: Optional[str]
|
||||
status_emoji: Optional[str]
|
||||
status_emoji_display_info: Optional[list[SlackStatusEmojiDisplayInfo]]
|
||||
status_text: Optional[str]
|
||||
status_text_canonical: Optional[str]
|
||||
status_expiration: Optional[int]
|
||||
avatar_hash: Optional[str]
|
||||
start_date: Optional[str]
|
||||
pronouns: Optional[str]
|
||||
huddle_state: Optional[str]
|
||||
huddle_state_expiration: Optional[int]
|
||||
team: Optional[SlackTeamId]
|
||||
|
||||
|
||||
class SlackUser(TypedDict, total=False):
|
||||
"""Type definition for Slack user dictionary.
|
||||
|
||||
Slack type definition: https://api.slack.com/types/user (https://archive.is/RUZdL)
|
||||
"""
|
||||
|
||||
id: SlackUserId
|
||||
team_id: SlackTeamId
|
||||
name: Optional[str]
|
||||
deleted: Optional[bool]
|
||||
color: Optional[str]
|
||||
real_name: Optional[str]
|
||||
tz: Optional[str]
|
||||
tz_label: Optional[str]
|
||||
tz_offset: Optional[SlackOffsetSecondsFromUTC]
|
||||
profile: Optional[SlackUserProfile]
|
||||
is_admin: Optional[bool]
|
||||
is_owner: Optional[bool]
|
||||
is_primary_owner: Optional[bool]
|
||||
is_restricted: Optional[bool]
|
||||
is_ultra_restricted: Optional[bool]
|
||||
is_bot: Optional[bool]
|
||||
is_app_user: Optional[bool]
|
||||
is_email_confirmed: Optional[bool]
|
||||
who_can_share_contact_card: Optional[str]
|
||||
|
||||
|
||||
class SlackUserList(TypedDict, total=False):
|
||||
"""Type definition for the returned user list dictionary."""
|
||||
|
||||
members: list[SlackUser]
|
||||
|
||||
|
||||
class SlackConversationPurpose(TypedDict, total=False):
|
||||
"""Type definition for the Slack conversation purpose dictionary."""
|
||||
|
||||
value: Optional[str]
|
||||
|
||||
|
||||
class SlackConversation(TypedDict, total=False):
|
||||
"""Type definition for the Slack conversation dictionary."""
|
||||
|
||||
id: Optional[str]
|
||||
name: Optional[str]
|
||||
is_private: Optional[bool]
|
||||
is_archived: Optional[bool]
|
||||
is_member: Optional[bool]
|
||||
is_channel: Optional[bool]
|
||||
is_group: Optional[bool]
|
||||
is_im: Optional[bool]
|
||||
is_mpim: Optional[bool]
|
||||
purpose: Optional[SlackConversationPurpose]
|
||||
num_members: Optional[int]
|
||||
user: Optional[SlackUser]
|
||||
is_user_deleted: Optional[bool]
|
||||
|
||||
|
||||
class SlackMessage(TypedDict, total=True):
|
||||
"""Type definition for the Slack message dictionary."""
|
||||
|
||||
type: Literal["message"]
|
||||
user: SlackUser
|
||||
text: str
|
||||
ts: SlackTimestampStr # Slack timestamp as a string (e.g. "1234567890.123456")
|
||||
|
||||
|
||||
class Message(SlackMessage, total=False):
|
||||
"""Type definition for the message dictionary.
|
||||
|
||||
Having a human-readable datetime string is useful for LLMs when they need to display the
|
||||
date/time for the user. If not, they'll try to convert the unix timestamp to a human-readable
|
||||
date/time,which they don't usually do accurately.
|
||||
"""
|
||||
|
||||
datetime_timestamp: str # Human-readable datetime string (e.g. "2025-01-22 12:00:00")
|
||||
|
||||
|
||||
class ConversationMetadata(TypedDict, total=True):
|
||||
"""Type definition for the conversation metadata dictionary."""
|
||||
|
||||
id: Optional[str]
|
||||
name: Optional[str]
|
||||
conversation_type: Optional[str]
|
||||
is_private: Optional[bool]
|
||||
is_archived: Optional[bool]
|
||||
is_member: Optional[bool]
|
||||
purpose: Optional[str]
|
||||
num_members: NotRequired[Optional[int]]
|
||||
user: NotRequired[Optional[SlackUser]]
|
||||
is_user_deleted: NotRequired[Optional[bool]]
|
||||
|
||||
|
||||
class BasicUserInfo(TypedDict, total=False):
|
||||
"""Type definition for the returned basic user info dictionary."""
|
||||
|
||||
id: Optional[str]
|
||||
name: Optional[str]
|
||||
is_bot: Optional[bool]
|
||||
email: Optional[str]
|
||||
display_name: Optional[str]
|
||||
real_name: Optional[str]
|
||||
timezone: Optional[str]
|
||||
|
||||
|
||||
class SlackConversationsToolResponse(TypedDict, total=True):
|
||||
"""Type definition for the Slack conversations tool response dictionary."""
|
||||
|
||||
conversations: list[ConversationMetadata]
|
||||
next_cursor: SlackPaginationNextCursor | None
|
||||
|
|
@ -1,11 +1,27 @@
|
|||
from typing import Annotated
|
||||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Optional, cast
|
||||
|
||||
from arcade.sdk import ToolContext, tool
|
||||
from arcade.sdk.auth import Slack
|
||||
from arcade.sdk.errors import RetryableToolError
|
||||
from slack_sdk import WebClient
|
||||
from arcade.sdk.errors import RetryableToolError, ToolExecutionError
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
|
||||
from arcade_slack.tools.utils import format_channels, format_users
|
||||
from arcade_slack.constants import MAX_PAGINATION_TIMEOUT_SECONDS
|
||||
from arcade_slack.exceptions import ItemNotFoundError
|
||||
from arcade_slack.models import ConversationType, SlackUserList
|
||||
from arcade_slack.tools.users import get_user_info_by_id
|
||||
from arcade_slack.utils import (
|
||||
async_paginate,
|
||||
convert_conversation_type_to_slack_name,
|
||||
convert_datetime_to_unix_timestamp,
|
||||
convert_relative_datetime_to_unix_timestamp,
|
||||
enrich_message_datetime,
|
||||
extract_conversation_metadata,
|
||||
format_conversations_as_csv,
|
||||
format_users,
|
||||
)
|
||||
|
||||
|
||||
@tool(
|
||||
|
|
@ -18,49 +34,59 @@ from arcade_slack.tools.utils import format_channels, format_users
|
|||
],
|
||||
)
|
||||
)
|
||||
def send_dm_to_user(
|
||||
async def send_dm_to_user(
|
||||
context: ToolContext,
|
||||
user_name: Annotated[
|
||||
str,
|
||||
"The Slack username of the person you want to message. "
|
||||
"Slack usernames are ALWAYS lowercase.",
|
||||
(
|
||||
"The Slack username of the person you want to message. "
|
||||
"Slack usernames are ALWAYS lowercase."
|
||||
),
|
||||
],
|
||||
message: Annotated[str, "The message you want to send"],
|
||||
) -> Annotated[dict, "The response from the Slack API"]:
|
||||
"""Send a direct message to a user in Slack."""
|
||||
slackClient = WebClient(
|
||||
token=context.authorization.token
|
||||
if context.authorization and context.authorization.token
|
||||
else ""
|
||||
|
||||
token = (
|
||||
context.authorization.token if context.authorization and context.authorization.token else ""
|
||||
)
|
||||
slackClient = AsyncWebClient(token=token)
|
||||
|
||||
# Step 1: Retrieve the user's Slack ID based on their username
|
||||
userListResponse = slackClient.users_list()
|
||||
user_id = None
|
||||
for user in userListResponse["members"]:
|
||||
if user["name"].lower() == user_name.lower():
|
||||
user_id = user["id"]
|
||||
break
|
||||
try:
|
||||
# Step 1: Retrieve the user's Slack ID based on their username
|
||||
user_list_response = await slackClient.users_list()
|
||||
user_id = None
|
||||
for user in user_list_response["members"]:
|
||||
response_user_name = (
|
||||
"" if not isinstance(user.get("name"), str) else user["name"].lower()
|
||||
)
|
||||
if response_user_name == user_name.lower():
|
||||
user_id = user["id"]
|
||||
break
|
||||
|
||||
if not user_id:
|
||||
raise RetryableToolError(
|
||||
"User not found",
|
||||
developer_message=f"User with username '{user_name}' not found.",
|
||||
additional_prompt_content=format_users(userListResponse),
|
||||
retry_after_ms=500, # Play nice with Slack API rate limits
|
||||
if not user_id:
|
||||
raise RetryableToolError(
|
||||
"User not found",
|
||||
developer_message=f"User with username '{user_name}' not found.",
|
||||
additional_prompt_content=format_users(cast(SlackUserList, user_list_response)),
|
||||
retry_after_ms=500, # Play nice with Slack API rate limits
|
||||
)
|
||||
|
||||
# Step 2: Retrieve the DM channel ID with the user
|
||||
im_response = await slackClient.conversations_open(users=[user_id])
|
||||
dm_channel_id = im_response["channel"]["id"]
|
||||
|
||||
# Step 3: Send the message as if it's from you (because we're using a user token)
|
||||
response = await slackClient.chat_postMessage(channel=dm_channel_id, text=message)
|
||||
|
||||
except SlackApiError as e:
|
||||
error_message = e.response["error"] if "error" in e.response else str(e)
|
||||
raise ToolExecutionError(
|
||||
"Error sending message",
|
||||
developer_message=f"Slack API Error: {error_message}",
|
||||
)
|
||||
|
||||
# Step 2: Retrieve the DM channel ID with the user
|
||||
im_response = slackClient.conversations_open(users=[user_id])
|
||||
dm_channel_id = im_response["channel"]["id"]
|
||||
|
||||
# Step 3: Send the message as if it's from you (because we're using a user token)
|
||||
response = slackClient.chat_postMessage(channel=dm_channel_id, text=message)
|
||||
response.validate()
|
||||
|
||||
if isinstance(response.data, dict):
|
||||
return response.data
|
||||
return {}
|
||||
else:
|
||||
return {"response": response.data}
|
||||
|
||||
|
||||
@tool(
|
||||
|
|
@ -72,43 +98,567 @@ def send_dm_to_user(
|
|||
],
|
||||
)
|
||||
)
|
||||
def send_message_to_channel(
|
||||
async def send_message_to_channel(
|
||||
context: ToolContext,
|
||||
channel_name: Annotated[
|
||||
str,
|
||||
"The Slack channel name where you want to send the message. "
|
||||
"Slack channel names are ALWAYS lowercase.",
|
||||
],
|
||||
channel_name: Annotated[str, "The Slack channel name where you want to send the message. "],
|
||||
message: Annotated[str, "The message you want to send"],
|
||||
) -> Annotated[dict, "The response from the Slack API"]:
|
||||
"""Send a message to a channel in Slack."""
|
||||
|
||||
slackClient = WebClient(
|
||||
token=context.authorization.token
|
||||
if context.authorization and context.authorization.token
|
||||
else ""
|
||||
)
|
||||
|
||||
# Step 1: Retrieve the list of channels
|
||||
channels_response = slackClient.conversations_list()
|
||||
channel_id = None
|
||||
for channel in channels_response["channels"]:
|
||||
if channel["name"].lower() == channel_name.lower():
|
||||
channel_id = channel["id"]
|
||||
break
|
||||
|
||||
if not channel_id:
|
||||
raise RetryableToolError(
|
||||
"Channel not found",
|
||||
developer_message=f"Channel with name '{channel_name}' not found.",
|
||||
additional_prompt_content=format_channels(channels_response),
|
||||
retry_after_ms=500, # Play nice with Slack API rate limits
|
||||
try:
|
||||
slackClient = AsyncWebClient(
|
||||
token=context.authorization.token
|
||||
if context.authorization and context.authorization.token
|
||||
else ""
|
||||
)
|
||||
|
||||
# Step 2: Send the message to the channel
|
||||
response = slackClient.chat_postMessage(channel=channel_id, text=message)
|
||||
response.validate()
|
||||
# Step 1: Retrieve the list of channels
|
||||
channels_response = await slackClient.conversations_list()
|
||||
channel_id = None
|
||||
for channel in channels_response["channels"]:
|
||||
response_channel_name = (
|
||||
"" if not isinstance(channel.get("name"), str) else channel["name"].lower()
|
||||
)
|
||||
if response_channel_name == channel_name.lower():
|
||||
channel_id = channel["id"]
|
||||
break
|
||||
|
||||
if isinstance(response.data, dict):
|
||||
return response.data
|
||||
return {}
|
||||
if not channel_id:
|
||||
raise RetryableToolError(
|
||||
"Channel not found",
|
||||
developer_message=f"Channel with name '{channel_name}' not found.",
|
||||
additional_prompt_content=format_conversations_as_csv({
|
||||
"channels": channels_response["channels"],
|
||||
}),
|
||||
retry_after_ms=500, # Play nice with Slack API rate limits
|
||||
)
|
||||
|
||||
# Step 2: Send the message to the channel
|
||||
response = await slackClient.chat_postMessage(channel=channel_id, text=message)
|
||||
|
||||
except SlackApiError as e:
|
||||
error_message = e.response["error"] if "error" in e.response else str(e)
|
||||
raise ToolExecutionError(
|
||||
"Error sending message",
|
||||
developer_message=f"Slack API Error: {error_message}",
|
||||
)
|
||||
else:
|
||||
return {"response": response.data}
|
||||
|
||||
|
||||
@tool(
|
||||
requires_auth=Slack(
|
||||
scopes=["channels:read", "groups:read", "im:read", "mpim:read"],
|
||||
)
|
||||
)
|
||||
async def get_members_in_conversation_by_id(
|
||||
context: ToolContext,
|
||||
conversation_id: Annotated[str, "The ID of the conversation to get members for"],
|
||||
limit: Annotated[Optional[int], "The maximum number of members to return."] = None,
|
||||
next_cursor: Annotated[Optional[str], "The cursor to use for pagination."] = None,
|
||||
) -> Annotated[dict, "Information about each member in the conversation"]:
|
||||
"""Get the members of a conversation in Slack by the conversation's ID."""
|
||||
token = (
|
||||
context.authorization.token if context.authorization and context.authorization.token else ""
|
||||
)
|
||||
slackClient = AsyncWebClient(token=token)
|
||||
|
||||
try:
|
||||
member_ids, next_cursor = await async_paginate(
|
||||
slackClient.conversations_members,
|
||||
"members",
|
||||
limit=limit,
|
||||
next_cursor=next_cursor,
|
||||
channel=conversation_id,
|
||||
)
|
||||
except SlackApiError as e:
|
||||
if e.response["error"] == "channel_not_found":
|
||||
conversations = await list_conversations_metadata(context)
|
||||
available_conversations = ", ".join(
|
||||
f"{conversation['id']} ({conversation['name']})"
|
||||
for conversation in conversations["conversations"]
|
||||
)
|
||||
|
||||
raise RetryableToolError(
|
||||
"Conversation not found",
|
||||
developer_message=f"Conversation with ID '{conversation_id}' not found.",
|
||||
additional_prompt_content=f"Available conversations: {available_conversations}",
|
||||
retry_after_ms=500,
|
||||
)
|
||||
|
||||
# Get the members' info
|
||||
# TODO: This will probably hit rate limits. We should probably call list_users() and
|
||||
# then filter the results instead.
|
||||
members = await asyncio.gather(*[
|
||||
get_user_info_by_id(context, member_id) for member_id in member_ids
|
||||
])
|
||||
|
||||
return {
|
||||
"members": [member for member in members if not member.get("is_bot")],
|
||||
"next_cursor": next_cursor,
|
||||
}
|
||||
|
||||
|
||||
@tool(
|
||||
requires_auth=Slack(
|
||||
scopes=["channels:read", "groups:read", "im:read", "mpim:read"],
|
||||
)
|
||||
)
|
||||
async def get_members_in_conversation_by_name(
|
||||
context: ToolContext,
|
||||
conversation_name: Annotated[str, "The name of the conversation to get members for"],
|
||||
limit: Annotated[Optional[int], "The maximum number of members to return."] = None,
|
||||
next_cursor: Annotated[Optional[str], "The cursor to use for pagination."] = None,
|
||||
) -> Annotated[dict, "The conversation members' IDs and Names"]:
|
||||
"""Get the members of a conversation in Slack by the conversation's name."""
|
||||
conversation_metadata = await get_conversation_metadata_by_name(
|
||||
context=context, conversation_name=conversation_name, next_cursor=next_cursor
|
||||
)
|
||||
|
||||
return await get_members_in_conversation_by_id( # type: ignore[no-any-return]
|
||||
context=context,
|
||||
conversation_id=conversation_metadata["id"],
|
||||
limit=limit,
|
||||
next_cursor=next_cursor,
|
||||
)
|
||||
|
||||
|
||||
# TODO: make the function accept a current unix timestamp argument to allow testing without
|
||||
# mocking. Have to wait until arcade.core.annotations.Inferrable is implemented, so that we
|
||||
# can avoid exposing this arg to the LLM.
|
||||
@tool(
|
||||
requires_auth=Slack(
|
||||
scopes=["channels:history", "groups:history", "im:history", "mpim:history"],
|
||||
)
|
||||
)
|
||||
async def get_messages_in_conversation_by_id(
|
||||
context: ToolContext,
|
||||
conversation_id: Annotated[str, "The ID of the conversation to get history for"],
|
||||
oldest_relative: Annotated[
|
||||
Optional[str],
|
||||
(
|
||||
"The oldest message to include in the results, specified as a time offset from the "
|
||||
"current time in the format 'DD:HH:MM'"
|
||||
),
|
||||
] = None,
|
||||
latest_relative: Annotated[
|
||||
Optional[str],
|
||||
(
|
||||
"The latest message to include in the results, specified as a time offset from the "
|
||||
"current time in the format 'DD:HH:MM'"
|
||||
),
|
||||
] = None,
|
||||
oldest_datetime: Annotated[
|
||||
Optional[str],
|
||||
(
|
||||
"The oldest message to include in the results, specified as a datetime object in the "
|
||||
"format 'YYYY-MM-DD HH:MM:SS'"
|
||||
),
|
||||
] = None,
|
||||
latest_datetime: Annotated[
|
||||
Optional[str],
|
||||
(
|
||||
"The latest message to include in the results, specified as a datetime object in the "
|
||||
"format 'YYYY-MM-DD HH:MM:SS'"
|
||||
),
|
||||
] = None,
|
||||
limit: Annotated[Optional[int], "The maximum number of messages to return."] = None,
|
||||
next_cursor: Annotated[Optional[str], "The cursor to use for pagination."] = None,
|
||||
) -> Annotated[
|
||||
dict,
|
||||
(
|
||||
"The messages in a conversation and next cursor for paginating results (when "
|
||||
"there are additional messages to retrieve)."
|
||||
),
|
||||
]:
|
||||
"""Get the messages in a conversation by the conversation's ID.
|
||||
|
||||
A conversation can be a channel, a DM, or a group DM.
|
||||
|
||||
To filter by an absolute datetime, use 'oldest_datetime' and/or 'latest_datetime'. If
|
||||
only 'oldest_datetime' is provided, it returns messages from the oldest_datetime to the
|
||||
current time. If only 'latest_datetime' is provided, it returns messages since the
|
||||
beginning of the conversation to the latest_datetime.
|
||||
|
||||
To filter by a relative datetime (e.g. 3 days ago, 1 hour ago, etc.), use
|
||||
'oldest_relative' and/or 'latest_relative'. If only 'oldest_relative' is provided, it returns
|
||||
messages from the oldest_relative to the current time. If only 'latest_relative' is provided,
|
||||
it returns messages from the current time to the latest_relative.
|
||||
|
||||
Do not provide both 'oldest_datetime' and 'oldest_relative' or both 'latest_datetime' and
|
||||
'latest_relative'.
|
||||
|
||||
Leave all arguments with the default None to get messages without date/time filtering"""
|
||||
error_message = None
|
||||
if oldest_datetime and oldest_relative:
|
||||
error_message = "Cannot specify both 'oldest_datetime' and 'oldest_relative'."
|
||||
|
||||
if latest_datetime and latest_relative:
|
||||
error_message = "Cannot specify both 'latest_datetime' and 'latest_relative'."
|
||||
|
||||
if error_message:
|
||||
raise ToolExecutionError(error_message, developer_message=error_message)
|
||||
|
||||
current_unix_timestamp = int(datetime.now(timezone.utc).timestamp())
|
||||
|
||||
if latest_relative:
|
||||
latest_timestamp = convert_relative_datetime_to_unix_timestamp(
|
||||
latest_relative, current_unix_timestamp
|
||||
)
|
||||
elif latest_datetime:
|
||||
latest_timestamp = convert_datetime_to_unix_timestamp(latest_datetime)
|
||||
else:
|
||||
latest_timestamp = None
|
||||
|
||||
if oldest_relative:
|
||||
oldest_timestamp = convert_relative_datetime_to_unix_timestamp(
|
||||
oldest_relative, current_unix_timestamp
|
||||
)
|
||||
elif oldest_datetime:
|
||||
oldest_timestamp = convert_datetime_to_unix_timestamp(oldest_datetime)
|
||||
else:
|
||||
oldest_timestamp = None
|
||||
|
||||
token = (
|
||||
context.authorization.token if context.authorization and context.authorization.token else ""
|
||||
)
|
||||
slackClient = AsyncWebClient(token=token)
|
||||
|
||||
datetime_args = {}
|
||||
if oldest_timestamp:
|
||||
datetime_args["oldest"] = oldest_timestamp
|
||||
if latest_timestamp:
|
||||
datetime_args["latest"] = latest_timestamp
|
||||
|
||||
response, next_cursor = await async_paginate(
|
||||
slackClient.conversations_history,
|
||||
"messages",
|
||||
limit=limit,
|
||||
next_cursor=next_cursor,
|
||||
channel=conversation_id,
|
||||
include_all_metadata=True,
|
||||
inclusive=True, # Include messages at the start and end of the time range
|
||||
**datetime_args,
|
||||
)
|
||||
|
||||
messages = [enrich_message_datetime(message) for message in response]
|
||||
|
||||
return {"messages": messages, "next_cursor": next_cursor}
|
||||
|
||||
|
||||
# TODO: make the function accept a current unix timestamp argument to allow testing without
|
||||
# mocking. Have to wait until arcade.core.annotations.Inferrable is implemented, so that we
|
||||
# can avoid exposing this arg to the LLM.
|
||||
@tool(
|
||||
requires_auth=Slack(
|
||||
scopes=["channels:history", "groups:history", "im:history", "mpim:history"],
|
||||
)
|
||||
)
|
||||
async def get_messages_in_channel_by_name(
|
||||
context: ToolContext,
|
||||
channel_name: Annotated[str, "The name of the channel"],
|
||||
oldest_relative: Annotated[
|
||||
Optional[str],
|
||||
(
|
||||
"The oldest message to include in the results, specified as a time offset from the "
|
||||
"current time in the format 'DD:HH:MM'"
|
||||
),
|
||||
] = None,
|
||||
latest_relative: Annotated[
|
||||
Optional[str],
|
||||
(
|
||||
"The latest message to include in the results, specified as a time offset from the "
|
||||
"current time in the format 'DD:HH:MM'"
|
||||
),
|
||||
] = None,
|
||||
oldest_datetime: Annotated[
|
||||
Optional[str],
|
||||
(
|
||||
"The oldest message to include in the results, specified as a datetime object in the "
|
||||
"format 'YYYY-MM-DD HH:MM:SS'"
|
||||
),
|
||||
] = None,
|
||||
latest_datetime: Annotated[
|
||||
Optional[str],
|
||||
(
|
||||
"The latest message to include in the results, specified as a datetime object in the "
|
||||
"format 'YYYY-MM-DD HH:MM:SS'"
|
||||
),
|
||||
] = None,
|
||||
limit: Annotated[Optional[int], "The maximum number of messages to return."] = None,
|
||||
next_cursor: Annotated[Optional[str], "The cursor to use for pagination."] = None,
|
||||
) -> Annotated[
|
||||
dict,
|
||||
(
|
||||
"The messages in a channel and next cursor for paginating results (when "
|
||||
"there are additional messages to retrieve)."
|
||||
),
|
||||
]:
|
||||
"""Get the messages in a channel by the channel's name.
|
||||
|
||||
To filter messages by an absolute datetime, use 'oldest_datetime' and/or 'latest_datetime'. If
|
||||
only 'oldest_datetime' is provided, it will return messages from the oldest_datetime to the
|
||||
current time. If only 'latest_datetime' is provided, it will return messages since the
|
||||
beginning of the channel to the latest_datetime.
|
||||
|
||||
To filter messages by a relative datetime (e.g. 3 days ago, 1 hour ago, etc.), use
|
||||
'oldest_relative' and/or 'latest_relative'. If only 'oldest_relative' is provided, it will
|
||||
return messages from the oldest_relative to the current time. If only 'latest_relative' is
|
||||
provided, it will return messages from the current time to the latest_relative.
|
||||
|
||||
Do not provide both 'oldest_datetime' and 'oldest_relative' or both 'latest_datetime' and
|
||||
'latest_relative'.
|
||||
|
||||
Leave all arguments with the default None to get messages without date/time filtering"""
|
||||
conversation_metadata = await get_conversation_metadata_by_name(
|
||||
context=context, conversation_name=channel_name
|
||||
)
|
||||
return await get_messages_in_conversation_by_id( # type: ignore[no-any-return]
|
||||
context=context,
|
||||
conversation_id=conversation_metadata["id"],
|
||||
oldest_relative=oldest_relative,
|
||||
latest_relative=latest_relative,
|
||||
oldest_datetime=oldest_datetime,
|
||||
latest_datetime=latest_datetime,
|
||||
limit=limit,
|
||||
next_cursor=next_cursor,
|
||||
)
|
||||
|
||||
|
||||
@tool(
|
||||
requires_auth=Slack(
|
||||
scopes=["channels:read", "groups:read", "im:read", "mpim:read"],
|
||||
)
|
||||
)
|
||||
async def get_conversation_metadata_by_id(
|
||||
context: ToolContext,
|
||||
conversation_id: Annotated[str, "The ID of the conversation to get metadata for"],
|
||||
) -> Annotated[dict, "The conversation metadata"]:
|
||||
"""Get the metadata of a conversation in Slack searching by its ID."""
|
||||
token = (
|
||||
context.authorization.token if context.authorization and context.authorization.token else ""
|
||||
)
|
||||
slackClient = AsyncWebClient(token=token)
|
||||
|
||||
try:
|
||||
response = await slackClient.conversations_info(
|
||||
channel=conversation_id,
|
||||
include_locale=True,
|
||||
include_num_members=True,
|
||||
)
|
||||
|
||||
except SlackApiError as e:
|
||||
if e.response.get("error") == "channel_not_found":
|
||||
conversations = await list_conversations_metadata(context)
|
||||
available_conversations = ", ".join(
|
||||
f"{conversation['id']} ({conversation['name']})"
|
||||
for conversation in conversations["conversations"]
|
||||
)
|
||||
|
||||
raise RetryableToolError(
|
||||
"Conversation not found",
|
||||
developer_message=f"Conversation with ID '{conversation_id}' not found.",
|
||||
additional_prompt_content=f"Available conversations: {available_conversations}",
|
||||
retry_after_ms=500,
|
||||
)
|
||||
|
||||
raise
|
||||
|
||||
return dict(**extract_conversation_metadata(response["channel"]))
|
||||
|
||||
|
||||
@tool(
|
||||
requires_auth=Slack(
|
||||
scopes=["channels:read", "groups:read", "im:read", "mpim:read"],
|
||||
)
|
||||
)
|
||||
async def get_conversation_metadata_by_name(
|
||||
context: ToolContext,
|
||||
conversation_name: Annotated[str, "The name of the conversation to get metadata for"],
|
||||
next_cursor: Annotated[
|
||||
Optional[str],
|
||||
"The cursor to use for pagination, if continuing from a previous search.",
|
||||
] = None,
|
||||
) -> Annotated[dict, "The conversation metadata"]:
|
||||
"""Get the metadata of a conversation in Slack searching by its name."""
|
||||
conversation_names: list[str] = []
|
||||
|
||||
async def find_conversation() -> dict:
|
||||
nonlocal conversation_names, conversation_name, next_cursor
|
||||
should_continue = True
|
||||
|
||||
while should_continue:
|
||||
response = await list_conversations_metadata(context, next_cursor=next_cursor)
|
||||
next_cursor = response.get("response_metadata", {}).get("next_cursor")
|
||||
|
||||
for conversation in response["conversations"]:
|
||||
response_conversation_name = (
|
||||
""
|
||||
if not isinstance(conversation.get("name"), str)
|
||||
else conversation["name"].lower()
|
||||
)
|
||||
if response_conversation_name == conversation_name.lower():
|
||||
return conversation # type: ignore[no-any-return]
|
||||
conversation_names.append(conversation["name"])
|
||||
|
||||
if not next_cursor:
|
||||
should_continue = False
|
||||
|
||||
raise ItemNotFoundError()
|
||||
|
||||
try:
|
||||
return await asyncio.wait_for(find_conversation(), timeout=MAX_PAGINATION_TIMEOUT_SECONDS)
|
||||
except ItemNotFoundError:
|
||||
raise RetryableToolError(
|
||||
"Conversation not found",
|
||||
developer_message=f"Conversation with name '{conversation_name}' not found.",
|
||||
additional_prompt_content=f"Available conversation names: {conversation_names}",
|
||||
retry_after_ms=500,
|
||||
)
|
||||
except TimeoutError:
|
||||
raise RetryableToolError(
|
||||
"Conversation not found, search timed out.",
|
||||
developer_message=(
|
||||
f"Conversation with name '{conversation_name}' not found. "
|
||||
f"Search timed out after {MAX_PAGINATION_TIMEOUT_SECONDS} seconds."
|
||||
),
|
||||
additional_prompt_content=(
|
||||
f"Other conversation names found are: {conversation_names}. "
|
||||
"The list is potentially non-exhaustive, since the search process timed out. "
|
||||
f"Use the '{list_conversations_metadata.__name__}' tool to get a comprehensive "
|
||||
"list of conversations."
|
||||
),
|
||||
retry_after_ms=500,
|
||||
)
|
||||
|
||||
|
||||
@tool(
|
||||
requires_auth=Slack(
|
||||
scopes=["channels:read", "groups:read", "im:read", "mpim:read"],
|
||||
)
|
||||
)
|
||||
async def list_conversations_metadata(
|
||||
context: ToolContext,
|
||||
conversation_types: Annotated[
|
||||
Optional[list[ConversationType]],
|
||||
"The type(s) of conversations to list. Defaults to all types.",
|
||||
] = None,
|
||||
limit: Annotated[Optional[int], "The maximum number of conversations to list."] = None,
|
||||
next_cursor: Annotated[Optional[str], "The cursor to use for pagination."] = None,
|
||||
) -> Annotated[
|
||||
dict,
|
||||
(
|
||||
"The conversations metadata list and a pagination 'next_cursor', if there are more "
|
||||
"conversations to retrieve."
|
||||
),
|
||||
]:
|
||||
"""
|
||||
List metadata for Slack conversations (channels and/or direct messages) that the user
|
||||
is a member of.
|
||||
"""
|
||||
if isinstance(conversation_types, ConversationType):
|
||||
conversation_types = [conversation_types]
|
||||
|
||||
conversation_types_filter = ",".join(
|
||||
convert_conversation_type_to_slack_name(conv_type).value
|
||||
for conv_type in conversation_types or ConversationType
|
||||
)
|
||||
|
||||
token = (
|
||||
context.authorization.token if context.authorization and context.authorization.token else ""
|
||||
)
|
||||
slackClient = AsyncWebClient(token=token)
|
||||
|
||||
results, next_cursor = await async_paginate(
|
||||
slackClient.conversations_list,
|
||||
"channels",
|
||||
limit=limit,
|
||||
next_cursor=next_cursor,
|
||||
types=conversation_types_filter,
|
||||
exclude_archived=True,
|
||||
)
|
||||
|
||||
return {
|
||||
"conversations": [
|
||||
dict(**extract_conversation_metadata(conversation))
|
||||
for conversation in results
|
||||
if conversation.get("is_im") or conversation.get("is_member")
|
||||
],
|
||||
"next_cursor": next_cursor,
|
||||
}
|
||||
|
||||
|
||||
@tool(
|
||||
requires_auth=Slack(
|
||||
scopes=["channels:read"],
|
||||
)
|
||||
)
|
||||
async def list_public_channels_metadata(
|
||||
context: ToolContext,
|
||||
limit: Annotated[Optional[int], "The maximum number of channels to list."] = None,
|
||||
) -> Annotated[dict, "The public channels"]:
|
||||
"""List metadata for public channels in Slack that the user is a member of."""
|
||||
|
||||
return await list_conversations_metadata( # type: ignore[no-any-return]
|
||||
context,
|
||||
conversation_types=[ConversationType.PUBLIC_CHANNEL],
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
||||
@tool(
|
||||
requires_auth=Slack(
|
||||
scopes=["groups:read"],
|
||||
)
|
||||
)
|
||||
async def list_private_channels_metadata(
|
||||
context: ToolContext,
|
||||
limit: Annotated[Optional[int], "The maximum number of channels to list."] = None,
|
||||
) -> Annotated[dict, "The private channels"]:
|
||||
"""List metadata for private channels in Slack that the user is a member of."""
|
||||
|
||||
return await list_conversations_metadata( # type: ignore[no-any-return]
|
||||
context,
|
||||
conversation_types=[ConversationType.PRIVATE_CHANNEL],
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
||||
@tool(
|
||||
requires_auth=Slack(
|
||||
scopes=["mpim:read"],
|
||||
)
|
||||
)
|
||||
async def list_group_direct_message_conversations_metadata(
|
||||
context: ToolContext,
|
||||
limit: Annotated[Optional[int], "The maximum number of conversations to list."] = None,
|
||||
) -> Annotated[dict, "The group direct message conversations metadata"]:
|
||||
"""List metadata for group direct message conversations that the user is a member of."""
|
||||
|
||||
return await list_conversations_metadata( # type: ignore[no-any-return]
|
||||
context,
|
||||
conversation_types=[ConversationType.MULTI_PERSON_DIRECT_MESSAGE],
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
||||
# Note: Bots are included in the results.
|
||||
# Note: Direct messages with no conversation history are included in the results.
|
||||
@tool(
|
||||
requires_auth=Slack(
|
||||
scopes=["im:read"],
|
||||
)
|
||||
)
|
||||
async def list_direct_message_conversations_metadata(
|
||||
context: ToolContext,
|
||||
limit: Annotated[Optional[int], "The maximum number of conversations to list."] = None,
|
||||
) -> Annotated[dict, "The direct message conversations metadata"]:
|
||||
"""List metadata for direct message conversations in Slack that the user is a member of."""
|
||||
|
||||
response = await list_conversations_metadata(
|
||||
context,
|
||||
conversation_types=[ConversationType.DIRECT_MESSAGE],
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
return response # type: ignore[no-any-return]
|
||||
|
|
|
|||
87
toolkits/slack/arcade_slack/tools/users.py
Normal file
87
toolkits/slack/arcade_slack/tools/users.py
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
from typing import Annotated, Any, Optional, cast
|
||||
|
||||
from arcade.sdk import ToolContext, tool
|
||||
from arcade.sdk.auth import Slack
|
||||
from arcade.sdk.errors import RetryableToolError
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
|
||||
from arcade_slack.constants import MAX_PAGINATION_TIMEOUT_SECONDS
|
||||
from arcade_slack.models import SlackPaginationNextCursor, SlackUser
|
||||
from arcade_slack.utils import (
|
||||
async_paginate,
|
||||
extract_basic_user_info,
|
||||
is_user_a_bot,
|
||||
is_user_deleted,
|
||||
)
|
||||
|
||||
|
||||
@tool(
|
||||
requires_auth=Slack(
|
||||
scopes=["users:read", "users:read.email"],
|
||||
)
|
||||
)
|
||||
async def get_user_info_by_id(
|
||||
context: ToolContext,
|
||||
user_id: Annotated[str, "The ID of the user to get"],
|
||||
) -> Annotated[dict[str, Any], "The user's information"]:
|
||||
"""Get the information of a user in Slack."""
|
||||
|
||||
token = (
|
||||
context.authorization.token if context.authorization and context.authorization.token else ""
|
||||
)
|
||||
slackClient = AsyncWebClient(token=token)
|
||||
|
||||
try:
|
||||
response = await slackClient.users_info(user=user_id)
|
||||
except SlackApiError as e:
|
||||
if e.response.get("error") == "user_not_found":
|
||||
users = await list_users(context)
|
||||
available_users = ", ".join(f"{user['id']} ({user['name']})" for user in users["users"])
|
||||
|
||||
raise RetryableToolError(
|
||||
"User not found",
|
||||
developer_message=f"User with ID '{user_id}' not found.",
|
||||
additional_prompt_content=f"Available users: {available_users}",
|
||||
retry_after_ms=500,
|
||||
)
|
||||
|
||||
user_dict_raw: dict[str, Any] = response.get("user", {}) or {}
|
||||
user_dict = cast(SlackUser, user_dict_raw)
|
||||
user = SlackUser(**user_dict)
|
||||
return dict(**extract_basic_user_info(user))
|
||||
|
||||
|
||||
@tool(
|
||||
requires_auth=Slack(
|
||||
scopes=["users:read", "users:read.email"],
|
||||
)
|
||||
)
|
||||
async def list_users(
|
||||
context: ToolContext,
|
||||
exclude_bots: Annotated[Optional[bool], "Whether to exclude bots from the results"] = True,
|
||||
limit: Annotated[Optional[int], "The maximum number of users to return."] = None,
|
||||
next_cursor: Annotated[Optional[str], "The next cursor token to use for pagination."] = None,
|
||||
) -> Annotated[dict, "The users' info"]:
|
||||
"""List all users in the authenticated user's Slack team."""
|
||||
|
||||
token = (
|
||||
context.authorization.token if context.authorization and context.authorization.token else ""
|
||||
)
|
||||
slackClient = AsyncWebClient(token=token)
|
||||
|
||||
users, next_cursor = await async_paginate(
|
||||
func=slackClient.users_list,
|
||||
response_key="members",
|
||||
limit=limit,
|
||||
next_cursor=cast(SlackPaginationNextCursor, next_cursor),
|
||||
max_pagination_timeout_seconds=MAX_PAGINATION_TIMEOUT_SECONDS,
|
||||
)
|
||||
|
||||
users = [
|
||||
extract_basic_user_info(user)
|
||||
for user in users
|
||||
if not is_user_deleted(user) and (not exclude_bots or not is_user_a_bot(user))
|
||||
]
|
||||
|
||||
return {"users": users, "next_cursor": next_cursor}
|
||||
|
|
@ -1,20 +0,0 @@
|
|||
from slack_sdk.web import SlackResponse
|
||||
|
||||
|
||||
def format_channels(channels_response: SlackResponse) -> str:
|
||||
csv_string = "All active Slack channels:\n\nname\n"
|
||||
for channel in channels_response["channels"]:
|
||||
if not channel.get("is_archived", False):
|
||||
name = channel.get("name", "")
|
||||
csv_string += f"{name}\n"
|
||||
return csv_string.strip()
|
||||
|
||||
|
||||
def format_users(userListResponse: SlackResponse) -> str:
|
||||
csv_string = "All active Slack users:\n\nname,real_name\n"
|
||||
for user in userListResponse["members"]:
|
||||
if not user.get("deleted", False):
|
||||
name = user.get("name", "")
|
||||
real_name = user.get("profile", {}).get("real_name", "")
|
||||
csv_string += f"{name},{real_name}\n"
|
||||
return csv_string.strip()
|
||||
342
toolkits/slack/arcade_slack/utils.py
Normal file
342
toolkits/slack/arcade_slack/utils.py
Normal file
|
|
@ -0,0 +1,342 @@
|
|||
import asyncio
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
from arcade.sdk.errors import RetryableToolError
|
||||
|
||||
from arcade_slack.constants import MAX_PAGINATION_SIZE_LIMIT, MAX_PAGINATION_TIMEOUT_SECONDS
|
||||
from arcade_slack.custom_types import SlackPaginationNextCursor
|
||||
from arcade_slack.exceptions import PaginationTimeoutError
|
||||
from arcade_slack.models import (
|
||||
BasicUserInfo,
|
||||
ConversationMetadata,
|
||||
ConversationType,
|
||||
ConversationTypeSlackName,
|
||||
Message,
|
||||
SlackConversation,
|
||||
SlackConversationPurpose,
|
||||
SlackMessage,
|
||||
SlackUser,
|
||||
SlackUserList,
|
||||
)
|
||||
|
||||
|
||||
def format_users(user_list_response: SlackUserList) -> str:
|
||||
"""Format a list of Slack users into a CSV string.
|
||||
|
||||
Args:
|
||||
userListResponse: The response from the Slack API's users_list method.
|
||||
|
||||
Returns:
|
||||
A CSV string with two columns: the users' name and real name, each user in a new line.
|
||||
The first line is the header with column names 'name' and 'real_name'.
|
||||
"""
|
||||
csv_string = "name,real_name\n"
|
||||
for user in user_list_response["members"]:
|
||||
if not user.get("deleted", False):
|
||||
name = user.get("name", "")
|
||||
profile = user.get("profile", {})
|
||||
real_name = "" if not profile else profile.get("real_name", "")
|
||||
csv_string += f"{name},{real_name}\n"
|
||||
return csv_string.strip()
|
||||
|
||||
|
||||
def format_conversations_as_csv(conversations: dict) -> str:
|
||||
"""Format a list of Slack conversations into a CSV string.
|
||||
|
||||
Args:
|
||||
conversations: The response from the Slack API's conversations_list method.
|
||||
|
||||
Returns:
|
||||
A CSV string with the conversations' names.
|
||||
"""
|
||||
csv_string = "All active Slack conversations:\n\nname\n"
|
||||
for conversation in conversations["channels"]:
|
||||
if not conversation.get("is_archived", False):
|
||||
name = conversation.get("name", "")
|
||||
csv_string += f"{name}\n"
|
||||
return csv_string.strip()
|
||||
|
||||
|
||||
def remove_none_values(params: dict) -> dict:
|
||||
"""Remove key/value pairs from a dictionary where the value is None.
|
||||
|
||||
Args:
|
||||
params: The dictionary to remove None values from.
|
||||
|
||||
Returns:
|
||||
A dictionary with None values removed.
|
||||
"""
|
||||
return {k: v for k, v in params.items() if v is not None}
|
||||
|
||||
|
||||
def get_slack_conversation_type_as_str(channel: SlackConversation) -> str:
|
||||
"""Get the type of conversation from a Slack channel's dictionary.
|
||||
|
||||
Args:
|
||||
channel: The Slack channel's dictionary.
|
||||
|
||||
Returns:
|
||||
The type of conversation string in Slack naming standard.
|
||||
"""
|
||||
if channel.get("is_channel"):
|
||||
return ConversationTypeSlackName.PUBLIC_CHANNEL.value
|
||||
if channel.get("is_group"):
|
||||
return ConversationTypeSlackName.PRIVATE_CHANNEL.value
|
||||
if channel.get("is_im"):
|
||||
return ConversationTypeSlackName.IM.value
|
||||
if channel.get("is_mpim"):
|
||||
return ConversationTypeSlackName.MPIM.value
|
||||
raise ValueError(f"Invalid conversation type in channel {channel.get('name')}")
|
||||
|
||||
|
||||
def convert_conversation_type_to_slack_name(
|
||||
conversation_type: ConversationType,
|
||||
) -> ConversationTypeSlackName:
|
||||
"""Convert a conversation type to another using Slack naming standard.
|
||||
|
||||
Args:
|
||||
conversation_type: The conversation type enum value.
|
||||
|
||||
Returns:
|
||||
The corresponding conversation type enum value using Slack naming standard.
|
||||
"""
|
||||
mapping = {
|
||||
ConversationType.PUBLIC_CHANNEL: ConversationTypeSlackName.PUBLIC_CHANNEL,
|
||||
ConversationType.PRIVATE_CHANNEL: ConversationTypeSlackName.PRIVATE_CHANNEL,
|
||||
ConversationType.MULTI_PERSON_DIRECT_MESSAGE: ConversationTypeSlackName.MPIM,
|
||||
ConversationType.DIRECT_MESSAGE: ConversationTypeSlackName.IM,
|
||||
}
|
||||
return mapping[conversation_type]
|
||||
|
||||
|
||||
def extract_conversation_metadata(conversation: SlackConversation) -> ConversationMetadata:
|
||||
"""Extract conversation metadata from a Slack conversation object.
|
||||
|
||||
Args:
|
||||
conversation: The Slack conversation dictionary.
|
||||
|
||||
Returns:
|
||||
A dictionary with the conversation metadata.
|
||||
"""
|
||||
conversation_type = get_slack_conversation_type_as_str(conversation)
|
||||
|
||||
purpose: Optional[SlackConversationPurpose] = conversation.get("purpose")
|
||||
purpose_value = "" if not purpose else purpose.get("value", "")
|
||||
|
||||
metadata = ConversationMetadata(
|
||||
id=conversation.get("id"),
|
||||
name=conversation.get("name"),
|
||||
conversation_type=conversation_type,
|
||||
is_private=conversation.get("is_private", True),
|
||||
is_archived=conversation.get("is_archived", False),
|
||||
is_member=conversation.get("is_member", True),
|
||||
purpose=purpose_value,
|
||||
num_members=conversation.get("num_members", 0),
|
||||
)
|
||||
|
||||
if conversation_type == ConversationTypeSlackName.IM.value:
|
||||
metadata["num_members"] = 2
|
||||
metadata["user"] = conversation.get("user")
|
||||
metadata["is_user_deleted"] = conversation.get("is_user_deleted")
|
||||
elif conversation_type == ConversationTypeSlackName.MPIM.value:
|
||||
conversation_name = conversation.get("name", "")
|
||||
if conversation_name:
|
||||
metadata["num_members"] = len(conversation_name.split("--"))
|
||||
else:
|
||||
metadata["num_members"] = None
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
def extract_basic_user_info(user_info: SlackUser) -> BasicUserInfo:
|
||||
"""Extract a user's basic info from a Slack user dictionary.
|
||||
|
||||
Args:
|
||||
user_info: The Slack user dictionary.
|
||||
|
||||
Returns:
|
||||
A dictionary with the user's basic info.
|
||||
|
||||
See https://api.slack.com/types/user for the structure of the user object.
|
||||
"""
|
||||
profile = user_info.get("profile", {})
|
||||
display_name = None if not profile else profile.get("display_name")
|
||||
email = None if not profile else profile.get("email")
|
||||
return BasicUserInfo(
|
||||
id=user_info.get("id"),
|
||||
name=user_info.get("name"),
|
||||
is_bot=user_info.get("is_bot"),
|
||||
email=email,
|
||||
display_name=display_name,
|
||||
real_name=user_info.get("real_name"),
|
||||
timezone=user_info.get("tz"),
|
||||
)
|
||||
|
||||
|
||||
def is_user_a_bot(user: SlackUser) -> bool:
|
||||
"""Check if a Slack user represents a bot.
|
||||
|
||||
Args:
|
||||
user: The Slack user dictionary.
|
||||
|
||||
Returns:
|
||||
True if the user is a bot, False otherwise.
|
||||
|
||||
Bots are users with the "is_bot" flag set to true.
|
||||
USLACKBOT is the user object for the Slack bot itself and is a special case.
|
||||
|
||||
See https://api.slack.com/types/user for the structure of the user object.
|
||||
"""
|
||||
return user.get("is_bot") or user.get("id") == "USLACKBOT"
|
||||
|
||||
|
||||
def is_user_deleted(user: SlackUser) -> bool:
|
||||
"""Check if a Slack user represents a deleted user.
|
||||
|
||||
Args:
|
||||
user: The Slack user dictionary.
|
||||
|
||||
Returns:
|
||||
True if the user is deleted, False otherwise.
|
||||
|
||||
See https://api.slack.com/types/user for the structure of the user object.
|
||||
"""
|
||||
is_deleted = user.get("deleted")
|
||||
|
||||
return is_deleted if isinstance(is_deleted, bool) else False
|
||||
|
||||
|
||||
async def async_paginate(
|
||||
func: Callable,
|
||||
response_key: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
next_cursor: Optional[SlackPaginationNextCursor] = None,
|
||||
max_pagination_timeout_seconds: int = MAX_PAGINATION_TIMEOUT_SECONDS,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> tuple[list, Optional[SlackPaginationNextCursor]]:
|
||||
"""Paginate a Slack AsyncWebClient's method results.
|
||||
|
||||
The purpose is to abstract the pagination work and make it easier for the LLM to retrieve the
|
||||
amount of items requested by the user, regardless of limits imposed by the Slack API. We still
|
||||
return the next cursor, if needed to paginate further.
|
||||
|
||||
Args:
|
||||
func: The Slack AsyncWebClient's method to paginate.
|
||||
response_key: The key in the response dictionary to extract the items from (optional). If
|
||||
not provided, the entire response dictionary is used.
|
||||
limit: The maximum number of items to retrieve (defaults to Slack's suggested limit).
|
||||
next_cursor: The cursor to use for pagination (optional).
|
||||
*args: Positional arguments to pass to the Slack method.
|
||||
**kwargs: Keyword arguments to pass to the Slack method.
|
||||
|
||||
Returns:
|
||||
A tuple containing the list of items and the next cursor, if needed to paginate further.
|
||||
"""
|
||||
results: list[Any] = []
|
||||
|
||||
async def paginate_loop() -> list[Any]:
|
||||
nonlocal results, next_cursor
|
||||
should_continue = True
|
||||
|
||||
"""
|
||||
The slack_limit variable makes the Slack API return no more than the appropriate
|
||||
amount of items. The loop extends results with the items returned and continues
|
||||
iterating if it hasn't reached the limit, and Slack indicates there're more
|
||||
items to retrieve.
|
||||
"""
|
||||
|
||||
while should_continue:
|
||||
iteration_limit = limit - len(results) if limit else MAX_PAGINATION_SIZE_LIMIT
|
||||
slack_limit = min(iteration_limit, MAX_PAGINATION_SIZE_LIMIT)
|
||||
iteration_kwargs = {**kwargs, "limit": slack_limit, "cursor": next_cursor}
|
||||
response = await func(*args, **iteration_kwargs)
|
||||
|
||||
try:
|
||||
results.extend(dict(response.data) if not response_key else response[response_key])
|
||||
except KeyError:
|
||||
raise ValueError(f"Response key {response_key} not found in Slack response")
|
||||
|
||||
next_cursor = response.get("response_metadata", {}).get("next_cursor")
|
||||
|
||||
if (limit and len(results) >= limit) or not next_cursor:
|
||||
should_continue = False
|
||||
|
||||
return results
|
||||
|
||||
try:
|
||||
results = await asyncio.wait_for(paginate_loop(), timeout=max_pagination_timeout_seconds)
|
||||
except TimeoutError:
|
||||
raise PaginationTimeoutError(max_pagination_timeout_seconds)
|
||||
else:
|
||||
return results, next_cursor
|
||||
|
||||
|
||||
def enrich_message_datetime(message: SlackMessage) -> Message:
|
||||
"""Enrich message metadata with formatted datetime.
|
||||
|
||||
It helps LLMs when they need to display the date/time in human-readable format. Slack
|
||||
will only return a unix-formatted timestamp (it's not actually UTC Unix timestamp, but
|
||||
the Unix timestamp in the user's timezone - I know, odd, but it is what it is).
|
||||
|
||||
Args:
|
||||
message: The Slack message dictionary.
|
||||
|
||||
Returns:
|
||||
The enriched message dictionary.
|
||||
"""
|
||||
message = Message(**message)
|
||||
ts = message.get("ts")
|
||||
if isinstance(ts, str):
|
||||
try:
|
||||
timestamp = float(ts)
|
||||
message["datetime_timestamp"] = datetime.fromtimestamp(timestamp).strftime(
|
||||
"%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
return message
|
||||
|
||||
|
||||
def convert_datetime_to_unix_timestamp(datetime_str: str) -> int:
|
||||
"""Convert a datetime string to a unix timestamp.
|
||||
|
||||
Args:
|
||||
datetime_str: The datetime string ('YYYY-MM-DD HH:MM:SS') to convert to a unix timestamp.
|
||||
|
||||
Returns:
|
||||
The unix timestamp integer.
|
||||
"""
|
||||
try:
|
||||
dt = datetime.strptime(datetime_str, "%Y-%m-%d %H:%M:%S")
|
||||
return int(dt.timestamp())
|
||||
except ValueError:
|
||||
raise RetryableToolError(
|
||||
"Invalid datetime format",
|
||||
developer_message=f"The datetime '{datetime_str}' is invalid. "
|
||||
"Please provide a datetime string in the format 'YYYY-MM-DD HH:MM:SS'.",
|
||||
retry_after_ms=500,
|
||||
)
|
||||
|
||||
|
||||
def convert_relative_datetime_to_unix_timestamp(
|
||||
relative_datetime: str,
|
||||
current_unix_timestamp: Optional[int] = None,
|
||||
) -> int:
|
||||
"""Convert a relative datetime string in the format 'DD:HH:MM' to unix timestamp.
|
||||
|
||||
Args:
|
||||
relative_datetime: The relative datetime string ('DD:HH:MM') to convert to a unix timestamp.
|
||||
current_unix_timestamp: The current unix timestamp (optional). If not provided, the
|
||||
current unix timestamp from datetime.now is used.
|
||||
|
||||
Returns:
|
||||
The unix timestamp integer.
|
||||
"""
|
||||
if not current_unix_timestamp:
|
||||
current_unix_timestamp = int(datetime.now(timezone.utc).timestamp())
|
||||
|
||||
days, hours, minutes = map(int, relative_datetime.split(":"))
|
||||
seconds = days * 86400 + hours * 3600 + minutes * 60
|
||||
return int(current_unix_timestamp - seconds)
|
||||
8
toolkits/slack/conftest.py
Normal file
8
toolkits/slack/conftest.py
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
import pytest
|
||||
from arcade.sdk import ToolAuthorizationContext, ToolContext
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context():
|
||||
mock_auth = ToolAuthorizationContext(token="fake-token") # noqa: S106
|
||||
return ToolContext(authorization=mock_auth)
|
||||
987
toolkits/slack/evals/eval_chat.py
Normal file
987
toolkits/slack/evals/eval_chat.py
Normal file
|
|
@ -0,0 +1,987 @@
|
|||
import json
|
||||
from datetime import timedelta
|
||||
|
||||
from arcade.sdk import ToolCatalog
|
||||
from arcade.sdk.eval import (
|
||||
BinaryCritic,
|
||||
DatetimeCritic,
|
||||
EvalRubric,
|
||||
EvalSuite,
|
||||
ExpectedToolCall,
|
||||
SimilarityCritic,
|
||||
tool_eval,
|
||||
)
|
||||
|
||||
import arcade_slack
|
||||
from arcade_slack.critics import RelativeTimeBinaryCritic
|
||||
from arcade_slack.tools.chat import (
|
||||
get_conversation_metadata_by_id,
|
||||
get_conversation_metadata_by_name,
|
||||
get_members_in_conversation_by_id,
|
||||
get_members_in_conversation_by_name,
|
||||
get_messages_in_channel_by_name,
|
||||
get_messages_in_conversation_by_id,
|
||||
list_conversations_metadata,
|
||||
list_direct_message_conversations_metadata,
|
||||
list_group_direct_message_conversations_metadata,
|
||||
list_private_channels_metadata,
|
||||
list_public_channels_metadata,
|
||||
send_dm_to_user,
|
||||
send_message_to_channel,
|
||||
)
|
||||
|
||||
# Evaluation rubric
|
||||
rubric = EvalRubric(
|
||||
fail_threshold=0.8,
|
||||
warn_threshold=0.9,
|
||||
)
|
||||
|
||||
|
||||
catalog = ToolCatalog()
|
||||
# Register the Slack tools
|
||||
catalog.add_module(arcade_slack)
|
||||
|
||||
|
||||
@tool_eval()
|
||||
def send_message_eval_suite() -> EvalSuite:
|
||||
"""Create an evaluation suite for Slack messaging tools."""
|
||||
suite = EvalSuite(
|
||||
name="Slack Messaging Tools Evaluation",
|
||||
system_message="You are an AI assistant that can send direct messages and post messages to channels in Slack using the provided tools.",
|
||||
catalog=catalog,
|
||||
rubric=rubric,
|
||||
)
|
||||
|
||||
# Send DM to User Scenarios
|
||||
suite.add_case(
|
||||
name="Send DM to user with clear username",
|
||||
user_message="Send a direct message to johndoe saying 'Hello, can we meet at 3 PM?'",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=send_dm_to_user,
|
||||
args={
|
||||
"user_name": "johndoe",
|
||||
"message": "Hello, can we meet at 3 PM?",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="user_name", weight=0.5),
|
||||
SimilarityCritic(critic_field="message", weight=0.5, similarity_threshold=0.9),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Send DM with ambiguous username",
|
||||
user_message="ask him for an update on the project",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=send_dm_to_user,
|
||||
args={
|
||||
"user_name": "john",
|
||||
"message": "Hi John, could you please provide an update on the Acme project?",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="user_name", weight=0.75),
|
||||
SimilarityCritic(critic_field="message", weight=0.25, similarity_threshold=0.6),
|
||||
],
|
||||
additional_messages=[
|
||||
{"role": "user", "content": "Message John about the Acme project deadline"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "Slack_ListUsers",
|
||||
"arguments": '{"exclude_bots":true}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": json.dumps({
|
||||
"users": [
|
||||
{
|
||||
"display_name": "john",
|
||||
"email": "john@randomtech.com",
|
||||
"id": "abc123",
|
||||
"is_bot": False,
|
||||
"name": "john",
|
||||
"real_name": "John Doe",
|
||||
"timezone": "America/Los_Angeles",
|
||||
},
|
||||
{
|
||||
"display_name": "jack",
|
||||
"email": "jack@randomtech.com",
|
||||
"id": "def456",
|
||||
"is_bot": False,
|
||||
"name": "jack",
|
||||
"real_name": "Jack Doe",
|
||||
"timezone": "America/Los_Angeles",
|
||||
},
|
||||
]
|
||||
}),
|
||||
"tool_call_id": "call_1",
|
||||
"name": "Slack_ListUsers",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "What would you like to include in the message to John about the Acme project deadline?",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Send DM with username in different format",
|
||||
user_message="yes, send it",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=send_dm_to_user,
|
||||
args={
|
||||
"user_name": "jane.doe",
|
||||
"message": "Hi Jane, I need to reschedule our meeting. When are you available?",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="user_name", weight=0.75),
|
||||
SimilarityCritic(critic_field="message", weight=0.25, similarity_threshold=0.6),
|
||||
],
|
||||
additional_messages=[
|
||||
{"role": "user", "content": "Message Jane.Doe asking to reschedule our meeting"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "Slack_ListUsers",
|
||||
"arguments": '{"exclude_bots":true}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": json.dumps({
|
||||
"users": [
|
||||
{
|
||||
"display_name": "jane.doe",
|
||||
"email": "jane@randomtech.com",
|
||||
"id": "abc123",
|
||||
"is_bot": False,
|
||||
"name": "jane.doe",
|
||||
"real_name": "Jane Doe",
|
||||
"timezone": "America/Los_Angeles",
|
||||
},
|
||||
{
|
||||
"display_name": "jack",
|
||||
"email": "jack@randomtech.com",
|
||||
"id": "def456",
|
||||
"is_bot": False,
|
||||
"name": "jack",
|
||||
"real_name": "Jack Doe",
|
||||
"timezone": "America/Los_Angeles",
|
||||
},
|
||||
]
|
||||
}),
|
||||
"tool_call_id": "call_1",
|
||||
"name": "Slack_ListUsers",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I found a user with the name 'jane.doe'. Would you like to send a message to them?",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
# Send Message to Channel Scenarios
|
||||
suite.add_case(
|
||||
name="Send message to channel with clear name",
|
||||
user_message="Post 'The new feature is now live!' in the #announcements channel",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=send_message_to_channel,
|
||||
args={
|
||||
"channel_name": "announcements",
|
||||
"message": "The new feature is now live!",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="channel_name", weight=0.5),
|
||||
SimilarityCritic(critic_field="message", weight=0.5),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Send message to channel with ambiguous name",
|
||||
user_message="Inform the team in the general channel about the upcoming maintenance",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=send_message_to_channel,
|
||||
args={
|
||||
"channel_name": "general",
|
||||
"message": "Attention team: There will be upcoming maintenance. Please save your work and expect some downtime.",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
SimilarityCritic(critic_field="channel_name", weight=0.8),
|
||||
SimilarityCritic(critic_field="message", weight=0.2, similarity_threshold=0.6),
|
||||
],
|
||||
)
|
||||
|
||||
# Adversarial Scenarios
|
||||
suite.add_case(
|
||||
name="Ambiguous between DM and channel message",
|
||||
user_message="general",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=send_message_to_channel,
|
||||
args={
|
||||
"channel_name": "general",
|
||||
"message": "Great job on the presentation!",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
SimilarityCritic(critic_field="channel_name", weight=0.4),
|
||||
SimilarityCritic(critic_field="message", weight=0.6),
|
||||
],
|
||||
additional_messages=[
|
||||
{"role": "user", "content": "Send 'Great job on the presentation!' to the team"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "Slack_ListPublicChannelsMetadata",
|
||||
"arguments": '{"limit":20}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": json.dumps({
|
||||
"conversations": [
|
||||
{
|
||||
"conversation_type": "public_channel",
|
||||
"id": "channel1",
|
||||
"is_archived": False,
|
||||
"is_member": True,
|
||||
"is_private": False,
|
||||
"name": "random",
|
||||
"num_members": 999,
|
||||
"purpose": "Random stuff",
|
||||
},
|
||||
{
|
||||
"conversation_type": "public_channel",
|
||||
"id": "channel2",
|
||||
"is_archived": False,
|
||||
"is_member": True,
|
||||
"is_private": False,
|
||||
"name": "general",
|
||||
"num_members": 999,
|
||||
"purpose": "Just a general channel",
|
||||
},
|
||||
],
|
||||
"next_cursor": "",
|
||||
}),
|
||||
"tool_call_id": "call_1",
|
||||
"name": "Slack_ListPublicChannelsMetadata",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": 'To send the message "Great job on the presentation!" to the team, please let me know which Slack channel you\'d like to use:\n\n1. #random\n2. #general\n\nPlease let me know your choice!',
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
# Multiple recipients in DM request
|
||||
suite.add_case(
|
||||
name="Multiple recipients in DM request",
|
||||
user_message="Send DMs to the users 'alice' and 'bob' about pushing the meeting tomorrow. I have to much work to do.",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=send_dm_to_user,
|
||||
args={
|
||||
"user_name": "alice",
|
||||
"message": "Hi Alice, about our meeting tomorrow, let's reschedule? I am swamped with work.",
|
||||
},
|
||||
),
|
||||
ExpectedToolCall(
|
||||
func=send_dm_to_user,
|
||||
args={
|
||||
"user_name": "bob",
|
||||
"message": "Hi Bob, about our meeting tomorrow, let's reschedule? I am swamped with work.",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
SimilarityCritic(critic_field="user_name", weight=0.75),
|
||||
SimilarityCritic(critic_field="message", weight=0.25, similarity_threshold=0.5),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Channel name similar to username",
|
||||
user_message="Post 'sounds great!' in john-project channel",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=send_message_to_channel,
|
||||
args={
|
||||
"channel_name": "john-project",
|
||||
"message": "Sounds great!",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="channel_name", weight=0.5),
|
||||
SimilarityCritic(critic_field="message", weight=0.5),
|
||||
],
|
||||
)
|
||||
|
||||
return suite
|
||||
|
||||
|
||||
@tool_eval()
|
||||
def list_conversations_eval_suite() -> EvalSuite:
|
||||
"""Create an evaluation suite for tools listing conversations."""
|
||||
suite = EvalSuite(
|
||||
name="Slack Messaging Tools Evaluation",
|
||||
system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.",
|
||||
catalog=catalog,
|
||||
rubric=rubric,
|
||||
)
|
||||
|
||||
cases = [
|
||||
(
|
||||
"List my conversations",
|
||||
"List all conversations I am a member of",
|
||||
list_conversations_metadata,
|
||||
),
|
||||
(
|
||||
"List public channels",
|
||||
"List all public channels",
|
||||
list_public_channels_metadata,
|
||||
),
|
||||
(
|
||||
"List private channels",
|
||||
"List all private channels",
|
||||
list_private_channels_metadata,
|
||||
),
|
||||
(
|
||||
"List group direct message channels",
|
||||
"List all group direct message channels",
|
||||
list_group_direct_message_conversations_metadata,
|
||||
),
|
||||
(
|
||||
"List individual direct message channels",
|
||||
"List all individual direct message channels",
|
||||
list_direct_message_conversations_metadata,
|
||||
),
|
||||
(
|
||||
"List direct message channels",
|
||||
"List all direct message channels",
|
||||
list_direct_message_conversations_metadata,
|
||||
),
|
||||
(
|
||||
"List public and private channels",
|
||||
"List public and private channels I am a member of",
|
||||
list_public_channels_metadata,
|
||||
list_private_channels_metadata,
|
||||
),
|
||||
(
|
||||
"List public channels and direct message conversations",
|
||||
"List public channels and direct message conversations I am a member of",
|
||||
list_public_channels_metadata,
|
||||
list_direct_message_conversations_metadata,
|
||||
),
|
||||
]
|
||||
|
||||
for name, user_message, *expect_called_tool_functions in cases:
|
||||
suite.add_case(
|
||||
name=name,
|
||||
user_message=user_message,
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=tool_function,
|
||||
args={},
|
||||
)
|
||||
for tool_function in expect_called_tool_functions
|
||||
],
|
||||
)
|
||||
|
||||
return suite
|
||||
|
||||
|
||||
@tool_eval()
|
||||
def get_conversations_metadata_eval_suite() -> EvalSuite:
|
||||
"""Create an evaluation suite for tools getting conversations metadata."""
|
||||
suite = EvalSuite(
|
||||
name="Slack Tools Evaluation",
|
||||
system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.",
|
||||
catalog=catalog,
|
||||
rubric=rubric,
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation metadata by name",
|
||||
user_message="Get the metadata of the #general channel",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_conversation_metadata_by_name,
|
||||
args={
|
||||
"conversation_name": "general",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_name", weight=1.0),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation metadata by id",
|
||||
user_message="Get the metadata of the conversation with id '1234567890'",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_conversation_metadata_by_id,
|
||||
args={
|
||||
"conversation_id": "1234567890",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=1.0),
|
||||
],
|
||||
)
|
||||
|
||||
return suite
|
||||
|
||||
|
||||
@tool_eval()
|
||||
def get_conversations_members_eval_suite() -> EvalSuite:
|
||||
"""Create an evaluation suite for tools getting conversations members."""
|
||||
suite = EvalSuite(
|
||||
name="Slack Tools Evaluation",
|
||||
system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.",
|
||||
catalog=catalog,
|
||||
rubric=rubric,
|
||||
)
|
||||
|
||||
user_messages = [
|
||||
"Get the members of the #general channel",
|
||||
"Get the members of the general channel",
|
||||
"Get a list of people in the #general channel",
|
||||
"Get a list of people in the general channel",
|
||||
"Show me who's in the #general channel",
|
||||
"Show me who's in the general channel",
|
||||
"Who is in the #general channel?",
|
||||
"Who is in the general channel?",
|
||||
]
|
||||
|
||||
for user_message in user_messages:
|
||||
suite.add_case(
|
||||
name=f"Get conversation members by name: {user_message}",
|
||||
user_message=user_message,
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_members_in_conversation_by_name,
|
||||
args={
|
||||
"conversation_name": "general",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_name", weight=1.0),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation members by id",
|
||||
user_message="Get the members of the conversation with id '1234567890'",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_members_in_conversation_by_id,
|
||||
args={
|
||||
"conversation_id": "1234567890",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=1.0),
|
||||
],
|
||||
)
|
||||
|
||||
return suite
|
||||
|
||||
|
||||
@tool_eval()
|
||||
def get_conversation_history_eval_suite() -> EvalSuite:
|
||||
"""Create an evaluation suite for tools getting conversations history."""
|
||||
suite = EvalSuite(
|
||||
name="Slack Chat Tools Evaluation",
|
||||
system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.",
|
||||
catalog=catalog,
|
||||
rubric=rubric,
|
||||
)
|
||||
|
||||
no_arguments_user_messages_by_conversation_name = [
|
||||
"Get the history of the #general channel",
|
||||
"Get the history of the general channel",
|
||||
"list the messages in the #general channel",
|
||||
"list the messages in the general channel",
|
||||
]
|
||||
|
||||
for user_message in no_arguments_user_messages_by_conversation_name:
|
||||
suite.add_case(
|
||||
name=f"Get conversation history by name: '{user_message}'",
|
||||
user_message=user_message,
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages_in_channel_by_name,
|
||||
args={
|
||||
"conversation_name": "general",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_name", weight=1.0),
|
||||
],
|
||||
)
|
||||
|
||||
no_arguments_user_messages_by_conversation_id = [
|
||||
"Get the history of the conversation with id '1234567890'",
|
||||
"Get the history of the conversation with id '1234567890'",
|
||||
"list the messages in the conversation with id '1234567890'",
|
||||
"list the messages in the conversation with id '1234567890'",
|
||||
]
|
||||
|
||||
for user_message in no_arguments_user_messages_by_conversation_id:
|
||||
suite.add_case(
|
||||
name=f"Get conversation history by id: '{user_message}'",
|
||||
user_message=user_message,
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages_in_conversation_by_id,
|
||||
args={
|
||||
"conversation_id": "1234567890",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=1.0),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history with limit by name",
|
||||
user_message="Get the last 10 messages in the #general channel",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages_in_channel_by_name,
|
||||
args={
|
||||
"conversation_name": "general",
|
||||
"limit": 10,
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_name", weight=1.0),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history with limit by id",
|
||||
user_message="Get the last 25 messages in the conversation with id '1234567890'",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages_in_conversation_by_id,
|
||||
args={
|
||||
"conversation_id": "1234567890",
|
||||
"limit": 25,
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=1.0),
|
||||
],
|
||||
)
|
||||
|
||||
# Relative time eval cases by id
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history oldest relative by id (2 days ago)",
|
||||
user_message="Get the messages in the conversation with id '1234567890' starting 2 days ago",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages_in_conversation_by_id,
|
||||
args={
|
||||
"conversation_id": "1234567890",
|
||||
"oldest_relative": "02:00:00",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.5),
|
||||
RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.5),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history oldest and latest relative by id",
|
||||
user_message="Get the messages in the conversation with id '1234567890' from 2 days ago to 3 hours ago",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages_in_conversation_by_id,
|
||||
args={
|
||||
"conversation_id": "1234567890",
|
||||
"oldest_relative": "02:00:00",
|
||||
"latest_relative": "00:03:00",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=1 / 3),
|
||||
RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=1 / 3),
|
||||
RelativeTimeBinaryCritic(critic_field="latest_relative", weight=1 / 3),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history oldest relative by id (1 week ago)",
|
||||
user_message="Get the messages in the conversation with id '1234567890' starting 1 week ago",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages_in_conversation_by_id,
|
||||
args={
|
||||
"conversation_id": "1234567890",
|
||||
"oldest_relative": "07:00:00",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.5),
|
||||
RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.5),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history oldest relative by id (yesterday)",
|
||||
user_message="Get the messages in the conversation with id '1234567890' from yesterday",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages_in_conversation_by_id,
|
||||
args={
|
||||
"conversation_id": "1234567890",
|
||||
"oldest_relative": "01:00:00",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.5),
|
||||
RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.5),
|
||||
],
|
||||
)
|
||||
|
||||
# Relative time eval cases by name
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history oldest relative by name (2 days ago)",
|
||||
user_message="Get the messages in the #general channel starting 2 days ago",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages_in_channel_by_name,
|
||||
args={
|
||||
"conversation_name": "general",
|
||||
"oldest_relative": "02:00:00",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_name", weight=0.5),
|
||||
RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.5),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history oldest and latest relative by name",
|
||||
user_message="Get the messages in the #general channel from 2 days ago to 3 hours ago",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages_in_channel_by_name,
|
||||
args={
|
||||
"conversation_name": "general",
|
||||
"oldest_relative": "02:00:00",
|
||||
"latest_relative": "00:03:00",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_name", weight=1 / 3),
|
||||
RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=1 / 3),
|
||||
RelativeTimeBinaryCritic(critic_field="latest_relative", weight=1 / 3),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history oldest relative by name (yesterday)",
|
||||
user_message="Get the messages in the #general channel from yesterday",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages_in_channel_by_name,
|
||||
args={
|
||||
"conversation_name": "general",
|
||||
"oldest_relative": "01:00:00",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_name", weight=0.5),
|
||||
RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.5),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history oldest relative by name (last week)",
|
||||
user_message="Get the messages in the #general channel from last week",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages_in_channel_by_name,
|
||||
args={
|
||||
"conversation_name": "general",
|
||||
"oldest_relative": "07:00:00",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_name", weight=0.5),
|
||||
RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.5),
|
||||
],
|
||||
)
|
||||
|
||||
# Absolute time eval cases by id
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history oldest absolute by id (on a specific date)",
|
||||
user_message="Get the messages in the conversation with id '1234567890' from 2025-01-20",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages_in_conversation_by_id,
|
||||
args={
|
||||
"conversation_id": "1234567890",
|
||||
"oldest_datetime": "2025-01-20 00:00:00",
|
||||
"latest_datetime": "2025-01-20 23:59:59",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=1 / 3),
|
||||
DatetimeCritic(
|
||||
critic_field="oldest_datetime", weight=1 / 3, max_difference=timedelta(minutes=2)
|
||||
),
|
||||
DatetimeCritic(
|
||||
critic_field="latest_datetime", weight=1 / 3, max_difference=timedelta(minutes=2)
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history oldest absolute by id (between a date range)",
|
||||
user_message="Get the messages in the conversation with id '1234567890' from 2025-01-20 to 2025-01-25",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages_in_conversation_by_id,
|
||||
args={
|
||||
"conversation_id": "1234567890",
|
||||
"oldest_datetime": "2025-01-20 00:00:00",
|
||||
"latest_datetime": "2025-01-25 23:59:59",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=1 / 3),
|
||||
DatetimeCritic(
|
||||
critic_field="oldest_datetime", weight=1 / 3, max_difference=timedelta(minutes=2)
|
||||
),
|
||||
DatetimeCritic(
|
||||
critic_field="latest_datetime", weight=1 / 3, max_difference=timedelta(minutes=2)
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history oldest absolute by name (on a specific date)",
|
||||
user_message="Get the messages in the #general channel from 2025-01-20",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages_in_channel_by_name,
|
||||
args={
|
||||
"conversation_name": "general",
|
||||
"oldest_datetime": "2025-01-20 00:00:00",
|
||||
"latest_datetime": "2025-01-20 23:59:59",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_name", weight=1 / 3),
|
||||
# We use a timedelta of 10 seconds because sometimes the LLM will select the limit
|
||||
# date at 23:59:59, other times it'll select the next day at 00:00:00.
|
||||
DatetimeCritic(
|
||||
critic_field="oldest_datetime", weight=1 / 3, max_difference=timedelta(seconds=10)
|
||||
),
|
||||
DatetimeCritic(
|
||||
critic_field="latest_datetime", weight=1 / 3, max_difference=timedelta(seconds=10)
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history oldest absolute by name (between a date range)",
|
||||
user_message="Get the messages in the #general channel from 2025-01-20 to 2025-01-25",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages_in_channel_by_name,
|
||||
args={
|
||||
"conversation_name": "general",
|
||||
"oldest_datetime": "2025-01-20 00:00:00",
|
||||
"latest_datetime": "2025-01-25 23:59:59",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_name", weight=1 / 3),
|
||||
# We use a timedelta of 10 seconds because sometimes the LLM will select the limit
|
||||
# date at 23:59:59, other times it'll select the next day at 00:00:00.
|
||||
DatetimeCritic(
|
||||
critic_field="oldest_datetime", weight=1 / 3, max_difference=timedelta(seconds=10)
|
||||
),
|
||||
DatetimeCritic(
|
||||
critic_field="latest_datetime", weight=1 / 3, max_difference=timedelta(seconds=10)
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# Eval case for pagination
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history with pagination",
|
||||
user_message="get the next 5 messages",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages_in_channel_by_name,
|
||||
args={
|
||||
"conversation_name": "general",
|
||||
"limit": 5,
|
||||
"cursor": "dXNlcjpVsDjzOTZGVDlQRA==",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_name", weight=1 / 3),
|
||||
BinaryCritic(critic_field="cursor", weight=1 / 3),
|
||||
BinaryCritic(critic_field="limit", weight=1 / 3),
|
||||
],
|
||||
additional_messages=[
|
||||
{"role": "user", "content": "Get the last 2 messages on the general channel"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "Slack_GetConversationHistoryByName",
|
||||
"arguments": json.dumps({
|
||||
"conversation_name": "general",
|
||||
"limit": 2,
|
||||
}),
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": json.dumps({
|
||||
"messages": [
|
||||
{
|
||||
"blocks": [
|
||||
{
|
||||
"block_id": "abc123",
|
||||
"elements": [
|
||||
{
|
||||
"elements": [
|
||||
{
|
||||
"text": "Almost there, Boss, need to get some evals in!",
|
||||
"type": "text",
|
||||
}
|
||||
],
|
||||
"type": "rich_text_section",
|
||||
}
|
||||
],
|
||||
"type": "rich_text",
|
||||
}
|
||||
],
|
||||
"client_msg_id": "msg_id_0",
|
||||
"datetime_timestamp": "2025-01-21 16:59:55",
|
||||
"team": "617263616465207465616D20697320617420626F7373206C6576656C",
|
||||
"text": "Almost there, Boss, need to get some evals in!",
|
||||
"ts": "1737507595.598529",
|
||||
"type": "message",
|
||||
"user": "77686F2069732074686520626F73733F",
|
||||
},
|
||||
{
|
||||
"blocks": [
|
||||
{
|
||||
"block_id": "xyz456",
|
||||
"elements": [
|
||||
{
|
||||
"elements": [
|
||||
{
|
||||
"text": "hey, are the Slack Tools ready yet?",
|
||||
"type": "text",
|
||||
}
|
||||
],
|
||||
"type": "rich_text_section",
|
||||
}
|
||||
],
|
||||
"type": "rich_text",
|
||||
}
|
||||
],
|
||||
"client_msg_id": "msg_id_1",
|
||||
"datetime_timestamp": "2025-01-21 16:57:35",
|
||||
"team": "617263616465207465616D20697320617420626F7373206C6576656C",
|
||||
"text": "hey, are the Slack Tools ready yet?",
|
||||
"ts": "1737507595.598529",
|
||||
"type": "message",
|
||||
"user": "73616D2069732074686520626F7373",
|
||||
},
|
||||
],
|
||||
"next_cursor": "dXNlcjpVsDjzOTZGVDlQRA==",
|
||||
}),
|
||||
"tool_call_id": "call_1",
|
||||
"name": "Slack_GetConversationHistoryByName",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": 'Here are the last 2 messages from the general channel:\n\n1. **User:** 77686F2069732074686520626F73733F \n **Message:** "Almost there, Boss, need to get some evals in!" \n **Timestamp:** 2025-01-21 16:59:55\n\n2. **User:** 73616D2069732074686520626F7373 \n **Message:** "hey, are the Slack Tools ready yet?" \n **Timestamp:** 2025-01-21 16:57:35',
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
return suite
|
||||
1020
toolkits/slack/evals/eval_slack_chat.py
Normal file
1020
toolkits/slack/evals/eval_slack_chat.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -1,191 +0,0 @@
|
|||
from arcade.sdk import ToolCatalog
|
||||
from arcade.sdk.eval import (
|
||||
BinaryCritic,
|
||||
EvalRubric,
|
||||
EvalSuite,
|
||||
ExpectedToolCall,
|
||||
SimilarityCritic,
|
||||
tool_eval,
|
||||
)
|
||||
|
||||
import arcade_slack
|
||||
from arcade_slack.tools.chat import send_dm_to_user, send_message_to_channel
|
||||
|
||||
# Evaluation rubric
|
||||
rubric = EvalRubric(
|
||||
fail_threshold=0.8,
|
||||
warn_threshold=0.9,
|
||||
)
|
||||
|
||||
|
||||
catalog = ToolCatalog()
|
||||
# Register the Slack tools
|
||||
catalog.add_module(arcade_slack)
|
||||
|
||||
|
||||
@tool_eval()
|
||||
def slack_eval_suite() -> EvalSuite:
|
||||
"""Create an evaluation suite for Slack messaging tools."""
|
||||
suite = EvalSuite(
|
||||
name="Slack Messaging Tools Evaluation",
|
||||
system_message="You are an AI assistant that can send direct messages and post messages to channels in Slack using the provided tools.",
|
||||
catalog=catalog,
|
||||
rubric=rubric,
|
||||
)
|
||||
|
||||
# Send DM to User Scenarios
|
||||
suite.add_case(
|
||||
name="Send DM to user with clear username",
|
||||
user_message="Send a direct message to johndoe saying 'Hello, can we meet at 3 PM?'",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=send_dm_to_user,
|
||||
args={
|
||||
"user_name": "johndoe",
|
||||
"message": "Hello, can we meet at 3 PM?",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="user_name", weight=0.5),
|
||||
SimilarityCritic(critic_field="message", weight=0.5),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Send DM with ambiguous username",
|
||||
user_message="Message John about the project deadline",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=send_dm_to_user,
|
||||
args={
|
||||
"user_name": "john",
|
||||
"message": "Hi John, I wanted to check about the project deadline. Can you provide an update?",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="user_name", weight=0.6),
|
||||
SimilarityCritic(critic_field="message", weight=0.4),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Send DM with username in different format",
|
||||
user_message="DM Jane.Doe to reschedule our meeting",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=send_dm_to_user,
|
||||
args={
|
||||
"user_name": "jane.doe",
|
||||
"message": "Hi Jane, I need to reschedule our meeting. When are you available?",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="user_name", weight=0.5),
|
||||
SimilarityCritic(critic_field="message", weight=0.5),
|
||||
],
|
||||
)
|
||||
|
||||
# Send Message to Channel Scenarios
|
||||
suite.add_case(
|
||||
name="Send message to channel with clear name",
|
||||
user_message="Post 'The new feature is now live!' in the #announcements channel",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=send_message_to_channel,
|
||||
args={
|
||||
"channel_name": "announcements",
|
||||
"message": "The new feature is now live!",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="channel_name", weight=0.5),
|
||||
SimilarityCritic(critic_field="message", weight=0.5),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Send message to channel with ambiguous name",
|
||||
user_message="Inform the engineering team about the upcoming maintenance in the general channel",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=send_message_to_channel,
|
||||
args={
|
||||
"channel_name": "engineering",
|
||||
"message": "Attention team: There will be upcoming maintenance. Please save your work and expect some downtime.",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
SimilarityCritic(critic_field="channel_name", weight=0.4),
|
||||
SimilarityCritic(critic_field="message", weight=0.6),
|
||||
],
|
||||
)
|
||||
|
||||
# Adversarial Scenarios
|
||||
suite.add_case(
|
||||
name="Ambiguous between DM and channel message",
|
||||
user_message="Send 'Great job on the presentation!' to the team",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=send_message_to_channel,
|
||||
args={
|
||||
"channel_name": "general",
|
||||
"message": "Great job on the presentation!",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
SimilarityCritic(critic_field="channel_name", weight=0.4),
|
||||
SimilarityCritic(critic_field="message", weight=0.6),
|
||||
],
|
||||
)
|
||||
|
||||
# Multiple recipients in DM request
|
||||
suite.add_case(
|
||||
name="Multiple recipients in DM request",
|
||||
user_message="Send a DM to Alice and Bob about pushing the meeting tomorrow. I have to much work to do.",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=send_dm_to_user,
|
||||
args={
|
||||
"user_name": "alice",
|
||||
"message": "Hi Alice, about our meeting tomorrow, let's reschedule? I am swamped with work.",
|
||||
},
|
||||
),
|
||||
ExpectedToolCall(
|
||||
func=send_dm_to_user,
|
||||
args={
|
||||
"user_name": "bob",
|
||||
"message": "Hi Bob, about our meeting tomorrow, let's reschedule? I am swamped with work.",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
SimilarityCritic(critic_field="user_name", weight=0.7),
|
||||
SimilarityCritic(critic_field="message", weight=0.3, similarity_threshold=0.6),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Channel name similar to username",
|
||||
user_message="Post 'sounds great!' in john-project channel",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=send_message_to_channel,
|
||||
args={
|
||||
"channel_name": "john-project",
|
||||
"message": "Sounds great!",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="channel_name", weight=0.5),
|
||||
SimilarityCritic(critic_field="message", weight=0.5),
|
||||
],
|
||||
)
|
||||
|
||||
return suite
|
||||
168
toolkits/slack/evals/eval_users.py
Normal file
168
toolkits/slack/evals/eval_users.py
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
import json
|
||||
|
||||
from arcade.sdk import ToolCatalog
|
||||
from arcade.sdk.eval import (
|
||||
BinaryCritic,
|
||||
EvalRubric,
|
||||
EvalSuite,
|
||||
ExpectedToolCall,
|
||||
tool_eval,
|
||||
)
|
||||
|
||||
import arcade_slack
|
||||
from arcade_slack.tools.users import get_user_info_by_id, list_users
|
||||
|
||||
# Evaluation rubric
|
||||
rubric = EvalRubric(
|
||||
fail_threshold=0.8,
|
||||
warn_threshold=0.9,
|
||||
)
|
||||
|
||||
|
||||
catalog = ToolCatalog()
|
||||
# Register the Slack tools
|
||||
catalog.add_module(arcade_slack)
|
||||
|
||||
|
||||
@tool_eval()
|
||||
def get_user_info_by_id_eval_suite() -> EvalSuite:
|
||||
"""Create an evaluation suite for tools getting user info by id."""
|
||||
suite = EvalSuite(
|
||||
name="Slack Users Tools Evaluation",
|
||||
system_message="You are an AI assistant that can interact with Slack to get information about users.",
|
||||
catalog=catalog,
|
||||
rubric=rubric,
|
||||
)
|
||||
|
||||
expected_user_id = "U12345"
|
||||
|
||||
get_user_info_by_id_eval_cases = [
|
||||
(
|
||||
"get user info by id",
|
||||
f"What is the name of the user with id {expected_user_id}?",
|
||||
),
|
||||
(
|
||||
"get user info by id",
|
||||
f"get information about the user with id {expected_user_id}",
|
||||
),
|
||||
]
|
||||
|
||||
for name, user_message in get_user_info_by_id_eval_cases:
|
||||
suite.add_case(
|
||||
name=name,
|
||||
user_message=user_message,
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_user_info_by_id,
|
||||
args={"user_id": expected_user_id},
|
||||
)
|
||||
],
|
||||
critics=[BinaryCritic(critic_field="user_id", weight=1.0)],
|
||||
)
|
||||
|
||||
return suite
|
||||
|
||||
|
||||
@tool_eval()
|
||||
def list_users_eval_suite() -> EvalSuite:
|
||||
"""Create an evaluation suite for tools listing users."""
|
||||
suite = EvalSuite(
|
||||
name="Slack Users Tools Evaluation",
|
||||
system_message="You are an AI assistant that can interact with Slack to get information about users.",
|
||||
catalog=catalog,
|
||||
rubric=rubric,
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="list users",
|
||||
user_message="list all users on my slack workspace",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(func=list_users, args={}),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="list users without bots",
|
||||
user_message="list all users on my slack workspace, except bots",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(func=list_users, args={"exclude_bots": True}),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="exclude_bots", weight=1.0),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="list 10 users without bots",
|
||||
user_message="get a list of 10 users on my slack workspace, except bots",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(func=list_users, args={"exclude_bots": True, "limit": 10}),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="exclude_bots", weight=0.5),
|
||||
BinaryCritic(critic_field="limit", weight=0.5),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="test list users with pagination",
|
||||
user_message="get the next 5 users",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=list_users,
|
||||
args={"limit": 5, "next_cursor": "dXNlcjpVsDjzOTZGVDlQRA=="},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="limit", weight=0.5),
|
||||
BinaryCritic(critic_field="next_cursor", weight=0.5),
|
||||
],
|
||||
additional_messages=[
|
||||
{"role": "user", "content": "get a list of 2 users from my slack workspace"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {"name": "Slack_ListUsers", "arguments": '{"limit":2}'},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": json.dumps({
|
||||
"next_cursor": "dXNlcjpVsDjzOTZGVDlQRA==",
|
||||
"users": [
|
||||
{
|
||||
"display_name": "John Doe",
|
||||
"email": "john.doe@acme.com",
|
||||
"id": "U123",
|
||||
"is_bot": False,
|
||||
"name": "john.doe",
|
||||
"real_name": "John Doe",
|
||||
"timezone": "America/Los_Angeles",
|
||||
},
|
||||
{
|
||||
"display_name": "Jane Doe",
|
||||
"email": "jane.doe@acme.com",
|
||||
"id": "U124",
|
||||
"is_bot": False,
|
||||
"name": "jane.doe",
|
||||
"real_name": "Jane Doe",
|
||||
"timezone": "America/Los_Angeles",
|
||||
},
|
||||
],
|
||||
}),
|
||||
"tool_call_id": "call_1",
|
||||
"name": "Slack_ListUsers",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Here are two users from your Slack workspace:\n\n1. **John Doe**\n - Display Name: John Doe\n - Email: john.doe@acme.com\n - Timezone: America/Los_Angeles\n\n2. **Jane Doe**\n - Display Name: Jane Doe\n - Email: jane.doe@acme.com\n - Timezone: America/Los_Angeles\n\nIf you need more information or additional users, feel free to ask!",
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
return suite
|
||||
|
|
@ -6,8 +6,11 @@ authors = ["Arcade AI <dev@arcade-ai.com>"]
|
|||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.10"
|
||||
aiodns = "^1.0" # required by slack - not picked by poetry due to slack requirements txt quirk :/
|
||||
aiohttp = ">=3.7.3,<4" # same as aiodns, above comment
|
||||
arcade-ai = "0.1.*"
|
||||
slack-sdk = "^3.31.0"
|
||||
typing-extensions = "^4.0.0" # some typing annotations aren't available in older versions of python
|
||||
|
||||
[tool.poetry.dev-dependencies]
|
||||
pytest = "^8.3.0"
|
||||
|
|
|
|||
|
|
@ -1,44 +1,694 @@
|
|||
from unittest.mock import MagicMock, patch
|
||||
import copy
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import Mock, call, patch
|
||||
|
||||
import pytest
|
||||
from arcade.sdk import ToolAuthorizationContext, ToolContext
|
||||
from arcade.sdk.errors import RetryableToolError, ToolExecutionError
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.web.async_slack_response import AsyncSlackResponse
|
||||
|
||||
from arcade_slack.tools.chat import send_dm_to_user, send_message_to_channel
|
||||
from arcade_slack.constants import MAX_PAGINATION_SIZE_LIMIT
|
||||
from arcade_slack.models import ConversationType, ConversationTypeSlackName
|
||||
from arcade_slack.tools.chat import (
|
||||
get_conversation_metadata_by_id,
|
||||
get_conversation_metadata_by_name,
|
||||
get_members_in_conversation_by_id,
|
||||
get_members_in_conversation_by_name,
|
||||
get_messages_in_channel_by_name,
|
||||
get_messages_in_conversation_by_id,
|
||||
list_conversations_metadata,
|
||||
list_direct_message_conversations_metadata,
|
||||
list_group_direct_message_conversations_metadata,
|
||||
list_private_channels_metadata,
|
||||
list_public_channels_metadata,
|
||||
send_dm_to_user,
|
||||
send_message_to_channel,
|
||||
)
|
||||
from arcade_slack.utils import extract_basic_user_info, extract_conversation_metadata
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context():
|
||||
mock_auth = ToolAuthorizationContext(token="fake-token") # noqa: S106
|
||||
return ToolContext(authorization=mock_auth)
|
||||
def mock_list_conversations_metadata(mocker):
|
||||
return mocker.patch("arcade_slack.tools.chat.list_conversations_metadata", autospec=True)
|
||||
|
||||
|
||||
def test_send_dm_to_user(mock_context):
|
||||
with patch("arcade_slack.tools.chat.WebClient") as MockWebClient:
|
||||
mock_client = MockWebClient.return_value
|
||||
mock_client.users_list.return_value = {"members": [{"name": "testuser", "id": "U12345"}]}
|
||||
mock_client.conversations_open.return_value = {"channel": {"id": "D12345"}}
|
||||
mock_client.chat_postMessage.return_value = MagicMock(data={"ok": True})
|
||||
|
||||
response = send_dm_to_user(mock_context, "testuser", "Hello!")
|
||||
|
||||
assert response["ok"] is True
|
||||
mock_client.users_list.assert_called_once()
|
||||
mock_client.conversations_open.assert_called_once_with(users=["U12345"])
|
||||
mock_client.chat_postMessage.assert_called_once_with(channel="D12345", text="Hello!")
|
||||
@pytest.fixture
|
||||
def mock_channel_info() -> dict:
|
||||
return {"name": "general", "id": "C12345", "is_member": True, "is_channel": True}
|
||||
|
||||
|
||||
def test_send_message_to_channel(mock_context):
|
||||
with patch("arcade_slack.tools.chat.WebClient") as MockWebClient:
|
||||
mock_client = MockWebClient.return_value
|
||||
mock_client.conversations_list.return_value = {
|
||||
"channels": [{"name": "general", "id": "C12345"}]
|
||||
}
|
||||
mock_client.chat_postMessage.return_value = MagicMock(data={"ok": True})
|
||||
@pytest.fixture
|
||||
def mock_slack_client(mocker):
|
||||
mock_client = mocker.patch("arcade_slack.tools.chat.AsyncWebClient", autospec=True)
|
||||
return mock_client.return_value
|
||||
|
||||
response = send_message_to_channel(mock_context, "general", "Hello, channel!")
|
||||
|
||||
assert response["ok"] is True
|
||||
mock_client.conversations_list.assert_called_once()
|
||||
mock_client.chat_postMessage.assert_called_once_with(
|
||||
channel="C12345", text="Hello, channel!"
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_dm_to_user(mock_context, mock_slack_client):
|
||||
mock_slack_client.users_list.return_value = {
|
||||
"ok": True,
|
||||
"members": [{"name": "testuser", "id": "U12345"}],
|
||||
}
|
||||
mock_slack_client.conversations_open.return_value = {
|
||||
"ok": True,
|
||||
"channel": {"id": "D12345"},
|
||||
}
|
||||
mock_slack_response = Mock(spec=AsyncSlackResponse)
|
||||
mock_slack_response.data = {"ok": True}
|
||||
mock_slack_client.chat_postMessage.return_value = mock_slack_response
|
||||
|
||||
response = await send_dm_to_user(mock_context, "testuser", "Hello!")
|
||||
|
||||
assert response["response"]["ok"] is True
|
||||
mock_slack_client.users_list.assert_called_once()
|
||||
mock_slack_client.conversations_open.assert_called_once_with(users=["U12345"])
|
||||
mock_slack_client.chat_postMessage.assert_called_once_with(channel="D12345", text="Hello!")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_dm_to_inexistent_user(mock_context, mock_slack_client):
|
||||
mock_slack_client.users_list.return_value = {
|
||||
"ok": True,
|
||||
"members": [{"name": "testuser", "id": "U12345"}],
|
||||
}
|
||||
|
||||
with pytest.raises(RetryableToolError):
|
||||
await send_dm_to_user(mock_context, "inexistent_user", "Hello!")
|
||||
|
||||
mock_slack_client.users_list.assert_called_once()
|
||||
mock_slack_client.conversations_open.assert_not_called()
|
||||
mock_slack_client.chat_postMessage.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_to_channel(mock_context, mock_slack_client):
|
||||
mock_slack_client.conversations_list.return_value = {
|
||||
"ok": True,
|
||||
"channels": [{"id": "C12345", "name": "general"}],
|
||||
}
|
||||
mock_slack_response = Mock(spec=AsyncSlackResponse)
|
||||
mock_slack_response.data = {"ok": True}
|
||||
mock_slack_client.chat_postMessage.return_value = mock_slack_response
|
||||
|
||||
response = await send_message_to_channel(mock_context, "general", "Hello, channel!")
|
||||
|
||||
assert response["response"]["ok"] is True
|
||||
mock_slack_client.conversations_list.assert_called_once()
|
||||
mock_slack_client.chat_postMessage.assert_called_once_with(
|
||||
channel="C12345", text="Hello, channel!"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_to_inexistent_channel(mock_context, mock_slack_client):
|
||||
mock_slack_client.conversations_list.return_value = {
|
||||
"ok": True,
|
||||
"channels": [],
|
||||
}
|
||||
|
||||
with pytest.raises(RetryableToolError):
|
||||
await send_message_to_channel(mock_context, "inexistent_channel", "Hello!")
|
||||
|
||||
mock_slack_client.conversations_list.assert_called_once()
|
||||
mock_slack_client.chat_postMessage.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_conversations_metadata_with_default_args(
|
||||
mock_context, mock_slack_client, mock_channel_info
|
||||
):
|
||||
mock_slack_client.conversations_list.return_value = {
|
||||
"ok": True,
|
||||
"channels": [mock_channel_info],
|
||||
}
|
||||
|
||||
response = await list_conversations_metadata(mock_context)
|
||||
|
||||
assert response["conversations"] == [extract_conversation_metadata(mock_channel_info)]
|
||||
assert response["next_cursor"] is None
|
||||
|
||||
mock_slack_client.conversations_list.assert_called_once_with(
|
||||
types=",".join([conv_type.value for conv_type in ConversationTypeSlackName]),
|
||||
exclude_archived=True,
|
||||
limit=MAX_PAGINATION_SIZE_LIMIT,
|
||||
cursor=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_conversations_metadata_filtering_single_conversation_type(
|
||||
mock_context, mock_slack_client, mock_channel_info
|
||||
):
|
||||
mock_slack_client.conversations_list.return_value = {
|
||||
"ok": True,
|
||||
"channels": [mock_channel_info],
|
||||
}
|
||||
|
||||
response = await list_conversations_metadata(
|
||||
mock_context, conversation_types=ConversationType.PUBLIC_CHANNEL
|
||||
)
|
||||
|
||||
assert response["conversations"] == [extract_conversation_metadata(mock_channel_info)]
|
||||
assert response["next_cursor"] is None
|
||||
|
||||
mock_slack_client.conversations_list.assert_called_once_with(
|
||||
types=ConversationTypeSlackName.PUBLIC_CHANNEL.value,
|
||||
exclude_archived=True,
|
||||
limit=MAX_PAGINATION_SIZE_LIMIT,
|
||||
cursor=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_conversations_metadata_filtering_multiple_conversation_types(
|
||||
mock_context, mock_slack_client, mock_channel_info
|
||||
):
|
||||
mock_slack_client.conversations_list.return_value = {
|
||||
"ok": True,
|
||||
"channels": [mock_channel_info],
|
||||
}
|
||||
|
||||
response = await list_conversations_metadata(
|
||||
mock_context,
|
||||
conversation_types=[
|
||||
ConversationTypeSlackName.PUBLIC_CHANNEL,
|
||||
ConversationTypeSlackName.PRIVATE_CHANNEL,
|
||||
],
|
||||
)
|
||||
|
||||
assert response["conversations"] == [extract_conversation_metadata(mock_channel_info)]
|
||||
assert response["next_cursor"] is None
|
||||
|
||||
mock_slack_client.conversations_list.assert_called_once_with(
|
||||
types=f"{ConversationTypeSlackName.PUBLIC_CHANNEL.value},{ConversationTypeSlackName.PRIVATE_CHANNEL.value}",
|
||||
exclude_archived=True,
|
||||
limit=MAX_PAGINATION_SIZE_LIMIT,
|
||||
cursor=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_conversations_metadata_with_custom_pagination_args(
|
||||
mock_context, mock_slack_client, mock_channel_info
|
||||
):
|
||||
mock_slack_client.conversations_list.return_value = {
|
||||
"ok": True,
|
||||
"channels": [mock_channel_info] * 3,
|
||||
"response_metadata": {"next_cursor": "456"},
|
||||
}
|
||||
|
||||
response = await list_conversations_metadata(mock_context, limit=3, next_cursor="123")
|
||||
|
||||
assert response["conversations"] == [
|
||||
extract_conversation_metadata(mock_channel_info) for _ in range(3)
|
||||
]
|
||||
assert response["next_cursor"] == "456"
|
||||
|
||||
mock_slack_client.conversations_list.assert_called_once_with(
|
||||
types=",".join([conv_type.value for conv_type in ConversationTypeSlackName]),
|
||||
exclude_archived=True,
|
||||
limit=3,
|
||||
cursor="123",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"faulty_slack_function_name, tool_function, tool_args",
|
||||
[
|
||||
("users_list", send_dm_to_user, ("testuser", "Hello!")),
|
||||
("conversations_list", send_message_to_channel, ("general", "Hello!")),
|
||||
],
|
||||
)
|
||||
async def test_tools_with_slack_error(
|
||||
mock_context, mock_slack_client, faulty_slack_function_name, tool_function, tool_args
|
||||
):
|
||||
getattr(mock_slack_client, faulty_slack_function_name).side_effect = SlackApiError(
|
||||
message="test_slack_error",
|
||||
response={"ok": False, "error": "test_slack_error"},
|
||||
)
|
||||
|
||||
with pytest.raises(ToolExecutionError) as e:
|
||||
await tool_function(mock_context, *tool_args)
|
||||
assert "test_slack_error" in str(e.value)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"tool_function, conversation_type",
|
||||
[
|
||||
(list_public_channels_metadata, ConversationType.PUBLIC_CHANNEL),
|
||||
(list_private_channels_metadata, ConversationType.PRIVATE_CHANNEL),
|
||||
(
|
||||
list_group_direct_message_conversations_metadata,
|
||||
ConversationType.MULTI_PERSON_DIRECT_MESSAGE,
|
||||
),
|
||||
(list_direct_message_conversations_metadata, ConversationType.DIRECT_MESSAGE),
|
||||
],
|
||||
)
|
||||
async def test_list_channels_metadata(
|
||||
mock_context,
|
||||
mock_list_conversations_metadata,
|
||||
tool_function,
|
||||
conversation_type,
|
||||
):
|
||||
response = await tool_function(mock_context, limit=3)
|
||||
|
||||
mock_list_conversations_metadata.assert_called_once_with(
|
||||
mock_context, conversation_types=[conversation_type], limit=3
|
||||
)
|
||||
|
||||
assert response == mock_list_conversations_metadata.return_value
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_metadata_by_id(mock_context, mock_slack_client, mock_channel_info):
|
||||
mock_slack_client.conversations_info.return_value = {
|
||||
"ok": True,
|
||||
"channel": mock_channel_info,
|
||||
}
|
||||
|
||||
response = await get_conversation_metadata_by_id(mock_context, "C12345")
|
||||
|
||||
assert response == extract_conversation_metadata(mock_channel_info)
|
||||
mock_slack_client.conversations_info.assert_called_once_with(
|
||||
channel="C12345",
|
||||
include_locale=True,
|
||||
include_num_members=True,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("arcade_slack.tools.chat.list_conversations_metadata")
|
||||
async def test_get_conversation_metadata_by_id_slack_api_error(
|
||||
mock_list_conversations_metadata, mock_context, mock_slack_client, mock_channel_info
|
||||
):
|
||||
mock_channel_info["name"] = "whatever_conversation_should_be_present_in_additional_prompt"
|
||||
mock_list_conversations_metadata.return_value = {
|
||||
"conversations": [extract_conversation_metadata(mock_channel_info)],
|
||||
"response_metadata": {"next_cursor": None},
|
||||
}
|
||||
mock_slack_client.conversations_info.side_effect = SlackApiError(
|
||||
message="channel_not_found",
|
||||
response={"ok": False, "error": "channel_not_found"},
|
||||
)
|
||||
|
||||
with pytest.raises(RetryableToolError) as e:
|
||||
await get_conversation_metadata_by_id(mock_context, "C12345")
|
||||
|
||||
assert (
|
||||
"whatever_conversation_should_be_present_in_additional_prompt"
|
||||
in e.additional_prompt_content
|
||||
)
|
||||
|
||||
mock_slack_client.conversations_info.assert_called_once_with(
|
||||
channel="C12345",
|
||||
include_locale=True,
|
||||
include_num_members=True,
|
||||
)
|
||||
mock_list_conversations_metadata.assert_called_once_with(mock_context)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_metadata_by_name(
|
||||
mock_context, mock_list_conversations_metadata, mock_channel_info
|
||||
):
|
||||
sample_conversation = extract_conversation_metadata(mock_channel_info)
|
||||
mock_list_conversations_metadata.return_value = {
|
||||
"conversations": [sample_conversation],
|
||||
"response_metadata": {"next_cursor": None},
|
||||
}
|
||||
|
||||
response = await get_conversation_metadata_by_name(mock_context, sample_conversation["name"])
|
||||
|
||||
assert response == sample_conversation
|
||||
mock_list_conversations_metadata.assert_called_once_with(mock_context, next_cursor=None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_metadata_by_name_triggering_pagination(
|
||||
mock_context, mock_list_conversations_metadata, mock_channel_info
|
||||
):
|
||||
target_conversation = extract_conversation_metadata(mock_channel_info)
|
||||
another_conversation = extract_conversation_metadata(mock_channel_info)
|
||||
another_conversation["name"] = "another_conversation"
|
||||
|
||||
mock_list_conversations_metadata.side_effect = [
|
||||
{
|
||||
"conversations": [another_conversation],
|
||||
"response_metadata": {"next_cursor": "123"},
|
||||
},
|
||||
{
|
||||
"conversations": [target_conversation],
|
||||
"response_metadata": {"next_cursor": None},
|
||||
},
|
||||
]
|
||||
|
||||
response = await get_conversation_metadata_by_name(mock_context, target_conversation["name"])
|
||||
|
||||
assert response == target_conversation
|
||||
assert mock_list_conversations_metadata.call_count == 2
|
||||
mock_list_conversations_metadata.assert_has_calls([
|
||||
call(mock_context, next_cursor=None),
|
||||
call(mock_context, next_cursor="123"),
|
||||
])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_metadata_by_name_not_found(
|
||||
mock_context, mock_list_conversations_metadata, mock_channel_info
|
||||
):
|
||||
first_conversation = extract_conversation_metadata(mock_channel_info)
|
||||
second_conversation = extract_conversation_metadata(mock_channel_info)
|
||||
second_conversation["name"] = "second_conversation"
|
||||
|
||||
mock_list_conversations_metadata.side_effect = [
|
||||
{
|
||||
"conversations": [second_conversation],
|
||||
"response_metadata": {"next_cursor": "123"},
|
||||
},
|
||||
{
|
||||
"conversations": [first_conversation],
|
||||
"response_metadata": {"next_cursor": None},
|
||||
},
|
||||
]
|
||||
|
||||
with pytest.raises(RetryableToolError):
|
||||
await get_conversation_metadata_by_name(mock_context, "inexistent_conversation")
|
||||
|
||||
assert mock_list_conversations_metadata.call_count == 2
|
||||
mock_list_conversations_metadata.assert_has_calls([
|
||||
call(mock_context, next_cursor=None),
|
||||
call(mock_context, next_cursor="123"),
|
||||
])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("arcade_slack.tools.chat.async_paginate")
|
||||
@patch("arcade_slack.tools.chat.get_user_info_by_id")
|
||||
async def test_get_members_from_conversation_id(
|
||||
mock_get_user_info_by_id, mock_async_paginate, mock_context, mock_slack_client
|
||||
):
|
||||
member1 = {"id": "U123", "name": "testuser123"}
|
||||
member1_info = extract_basic_user_info(member1)
|
||||
member2 = {"id": "U456", "name": "testuser456"}
|
||||
member2_info = extract_basic_user_info(member2)
|
||||
|
||||
mock_async_paginate.return_value = [member1["id"], member2["id"]], "token123"
|
||||
mock_get_user_info_by_id.side_effect = [member1_info, member2_info]
|
||||
|
||||
response = await get_members_in_conversation_by_id(
|
||||
mock_context, conversation_id="C12345", limit=2
|
||||
)
|
||||
|
||||
assert response == {
|
||||
"members": [member1_info, member2_info],
|
||||
"next_cursor": "token123",
|
||||
}
|
||||
mock_async_paginate.assert_called_once_with(
|
||||
mock_slack_client.conversations_members,
|
||||
"members",
|
||||
limit=2,
|
||||
next_cursor=None,
|
||||
channel="C12345",
|
||||
)
|
||||
mock_get_user_info_by_id.assert_has_calls([
|
||||
call(mock_context, member1["id"]),
|
||||
call(mock_context, member2["id"]),
|
||||
])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("arcade_slack.tools.chat.async_paginate")
|
||||
@patch("arcade_slack.tools.chat.get_user_info_by_id")
|
||||
@patch("arcade_slack.tools.chat.list_conversations_metadata")
|
||||
async def test_get_members_from_conversation_id_channel_not_found(
|
||||
mock_list_conversations_metadata,
|
||||
mock_get_user_info_by_id,
|
||||
mock_async_paginate,
|
||||
mock_context,
|
||||
mock_slack_client,
|
||||
mock_channel_info,
|
||||
):
|
||||
conversations = [extract_conversation_metadata(mock_channel_info)] * 2
|
||||
mock_list_conversations_metadata.return_value = {
|
||||
"conversations": conversations,
|
||||
"next_cursor": None,
|
||||
}
|
||||
|
||||
member1 = {"id": "U123", "name": "testuser123"}
|
||||
member1_info = extract_basic_user_info(member1)
|
||||
member2 = {"id": "U456", "name": "testuser456"}
|
||||
member2_info = extract_basic_user_info(member2)
|
||||
|
||||
mock_async_paginate.side_effect = SlackApiError(
|
||||
message="channel_not_found",
|
||||
response={"ok": False, "error": "channel_not_found"},
|
||||
)
|
||||
mock_get_user_info_by_id.side_effect = [member1_info, member2_info]
|
||||
|
||||
with pytest.raises(RetryableToolError):
|
||||
await get_members_in_conversation_by_id(mock_context, conversation_id="C12345", limit=2)
|
||||
|
||||
mock_async_paginate.assert_called_once_with(
|
||||
mock_slack_client.conversations_members,
|
||||
"members",
|
||||
limit=2,
|
||||
next_cursor=None,
|
||||
channel="C12345",
|
||||
)
|
||||
mock_get_user_info_by_id.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("arcade_slack.tools.chat.list_conversations_metadata")
|
||||
@patch("arcade_slack.tools.chat.get_members_in_conversation_by_id")
|
||||
async def test_get_members_in_conversation_by_name(
|
||||
mock_get_members_in_conversation_by_id,
|
||||
mock_list_conversations_metadata,
|
||||
mock_context,
|
||||
mock_channel_info,
|
||||
):
|
||||
mock_list_conversations_metadata.return_value = {
|
||||
"conversations": [extract_conversation_metadata(mock_channel_info)],
|
||||
"next_cursor": None,
|
||||
}
|
||||
|
||||
response = await get_members_in_conversation_by_name(
|
||||
mock_context, mock_channel_info["name"], limit=2
|
||||
)
|
||||
|
||||
assert response == mock_get_members_in_conversation_by_id.return_value
|
||||
mock_list_conversations_metadata.assert_called_once_with(mock_context, next_cursor=None)
|
||||
mock_get_members_in_conversation_by_id.assert_called_once_with(
|
||||
context=mock_context,
|
||||
conversation_id="C12345",
|
||||
limit=2,
|
||||
next_cursor=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("arcade_slack.tools.chat.list_conversations_metadata")
|
||||
@patch("arcade_slack.tools.chat.get_members_in_conversation_by_id")
|
||||
async def test_get_members_in_conversation_by_name_triggering_pagination(
|
||||
mock_get_members_in_conversation_by_id,
|
||||
mock_list_conversations_metadata,
|
||||
mock_context,
|
||||
mock_channel_info,
|
||||
):
|
||||
conversation1 = copy.deepcopy(mock_channel_info)
|
||||
conversation1["name"] = "conversation1"
|
||||
conversation2 = copy.deepcopy(mock_channel_info)
|
||||
conversation2["name"] = "conversation2"
|
||||
|
||||
mock_list_conversations_metadata.side_effect = [
|
||||
{
|
||||
"conversations": [extract_conversation_metadata(conversation1)],
|
||||
"response_metadata": {"next_cursor": "123"},
|
||||
},
|
||||
{
|
||||
"conversations": [extract_conversation_metadata(conversation2)],
|
||||
"response_metadata": {"next_cursor": None},
|
||||
},
|
||||
]
|
||||
|
||||
response = await get_members_in_conversation_by_name(
|
||||
mock_context, conversation2["name"], limit=2
|
||||
)
|
||||
|
||||
assert response == mock_get_members_in_conversation_by_id.return_value
|
||||
mock_list_conversations_metadata.assert_has_calls([
|
||||
call(mock_context, next_cursor=None),
|
||||
call(mock_context, next_cursor="123"),
|
||||
])
|
||||
mock_get_members_in_conversation_by_id.assert_called_once_with(
|
||||
context=mock_context,
|
||||
conversation_id="C12345",
|
||||
limit=2,
|
||||
next_cursor=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_history_by_id(mock_context, mock_slack_client):
|
||||
mock_slack_client.conversations_history.return_value = {
|
||||
"ok": True,
|
||||
"messages": [{"text": "Hello, world!"}],
|
||||
}
|
||||
|
||||
response = await get_messages_in_conversation_by_id(mock_context, "C12345", limit=1)
|
||||
|
||||
assert response == {"messages": [{"text": "Hello, world!"}], "next_cursor": None}
|
||||
mock_slack_client.conversations_history.assert_called_once_with(
|
||||
channel="C12345",
|
||||
include_all_metadata=True,
|
||||
inclusive=True,
|
||||
limit=1,
|
||||
cursor=None,
|
||||
)
|
||||
|
||||
|
||||
# TODO: pass a current unix timestamp to the tool, instead of mocking the datetime
|
||||
# conversion. Have to wait until arcade.core.annotations.Inferrable is implemented.
|
||||
@pytest.mark.asyncio
|
||||
@patch("arcade_slack.tools.chat.convert_relative_datetime_to_unix_timestamp")
|
||||
@patch("arcade_slack.tools.chat.datetime")
|
||||
async def test_get_conversation_history_by_id_with_relative_datetime_args(
|
||||
mock_datetime, mock_convert_relative_datetime_to_unix_timestamp, mock_context, mock_slack_client
|
||||
):
|
||||
mock_slack_client.conversations_history.return_value = {
|
||||
"ok": True,
|
||||
"messages": [{"text": "Hello, world!"}],
|
||||
}
|
||||
|
||||
expected_oldest_timestamp = 1716489600
|
||||
expected_latest_timestamp = 1716403200
|
||||
|
||||
# Ideally we'd pass the current unix timestamp to the function, instead of mocking, but
|
||||
# currently there's no way to have a tool argument that is not exposed to the LLM. We
|
||||
# should have that soon, though.
|
||||
mock_datetime.now.return_value = datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc)
|
||||
expected_current_unix_timestamp = int(mock_datetime.now.return_value.timestamp())
|
||||
mock_convert_relative_datetime_to_unix_timestamp.side_effect = [
|
||||
expected_latest_timestamp,
|
||||
expected_oldest_timestamp,
|
||||
]
|
||||
|
||||
response = await get_messages_in_conversation_by_id(
|
||||
mock_context, "C12345", oldest_relative="02:00:00", latest_relative="01:00:00", limit=1
|
||||
)
|
||||
|
||||
assert response == {"messages": [{"text": "Hello, world!"}], "next_cursor": None}
|
||||
mock_convert_relative_datetime_to_unix_timestamp.assert_has_calls([
|
||||
call("01:00:00", expected_current_unix_timestamp),
|
||||
call("02:00:00", expected_current_unix_timestamp),
|
||||
])
|
||||
mock_slack_client.conversations_history.assert_called_once_with(
|
||||
channel="C12345",
|
||||
include_all_metadata=True,
|
||||
inclusive=True,
|
||||
limit=1,
|
||||
cursor=None,
|
||||
oldest=expected_oldest_timestamp,
|
||||
latest=expected_latest_timestamp,
|
||||
)
|
||||
|
||||
|
||||
# TODO: pass a current unix timestamp to the tool, instead of mocking the datetime
|
||||
# conversion. Have to wait until arcade.core.annotations.Inferrable is implemented.
|
||||
@pytest.mark.asyncio
|
||||
@patch("arcade_slack.tools.chat.convert_datetime_to_unix_timestamp")
|
||||
async def test_get_conversation_history_by_id_with_absolute_datetime_args(
|
||||
mock_convert_datetime_to_unix_timestamp, mock_context, mock_slack_client
|
||||
):
|
||||
mock_slack_client.conversations_history.return_value = {
|
||||
"ok": True,
|
||||
"messages": [{"text": "Hello, world!"}],
|
||||
}
|
||||
|
||||
expected_latest_timestamp = 1716403200
|
||||
expected_oldest_timestamp = 1716489600
|
||||
|
||||
# Ideally we'd pass the current unix timestamp to the function, instead of mocking, but
|
||||
# currently there's no way to have a tool argument that is not exposed to the LLM. We
|
||||
# should have that soon, though.
|
||||
mock_convert_datetime_to_unix_timestamp.side_effect = [
|
||||
expected_latest_timestamp,
|
||||
expected_oldest_timestamp,
|
||||
]
|
||||
|
||||
response = await get_messages_in_conversation_by_id(
|
||||
mock_context,
|
||||
"C12345",
|
||||
oldest_datetime="2025-01-01 00:00:00",
|
||||
latest_datetime="2025-01-02 00:00:00",
|
||||
limit=1,
|
||||
)
|
||||
|
||||
assert response == {"messages": [{"text": "Hello, world!"}], "next_cursor": None}
|
||||
mock_convert_datetime_to_unix_timestamp.assert_has_calls([
|
||||
call("2025-01-02 00:00:00"),
|
||||
call("2025-01-01 00:00:00"),
|
||||
])
|
||||
mock_slack_client.conversations_history.assert_called_once_with(
|
||||
channel="C12345",
|
||||
include_all_metadata=True,
|
||||
inclusive=True,
|
||||
limit=1,
|
||||
cursor=None,
|
||||
oldest=expected_oldest_timestamp,
|
||||
latest=expected_latest_timestamp,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_history_by_id_with_messed_oldest_args(
|
||||
mock_context, mock_slack_client
|
||||
):
|
||||
with pytest.raises(ToolExecutionError):
|
||||
await get_messages_in_conversation_by_id(
|
||||
mock_context,
|
||||
"C12345",
|
||||
oldest_datetime="2025-01-01 00:00:00",
|
||||
oldest_relative="01:00:00",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_history_by_id_with_messed_latest_args(
|
||||
mock_context, mock_slack_client
|
||||
):
|
||||
with pytest.raises(ToolExecutionError):
|
||||
await get_messages_in_conversation_by_id(
|
||||
mock_context,
|
||||
"C12345",
|
||||
latest_datetime="2025-01-01 00:00:00",
|
||||
latest_relative="01:00:00",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_conversation_history_by_name(mock_context, mock_slack_client):
|
||||
mock_slack_client.conversations_list.return_value = {
|
||||
"ok": True,
|
||||
"channels": [
|
||||
{
|
||||
"id": "C12345",
|
||||
"name": "general",
|
||||
"is_member": True,
|
||||
"is_channel": True,
|
||||
}
|
||||
],
|
||||
}
|
||||
mock_slack_client.conversations_history.return_value = {
|
||||
"ok": True,
|
||||
"messages": [{"text": "Hello, world!"}],
|
||||
}
|
||||
|
||||
response = await get_messages_in_channel_by_name(mock_context, "general", limit=1)
|
||||
|
||||
assert response == {"messages": [{"text": "Hello, world!"}], "next_cursor": None}
|
||||
mock_slack_client.conversations_history.assert_called_once_with(
|
||||
channel="C12345", include_all_metadata=True, inclusive=True, limit=1, cursor=None
|
||||
)
|
||||
|
|
|
|||
92
toolkits/slack/tests/test_users.py
Normal file
92
toolkits/slack/tests/test_users.py
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from arcade.sdk.errors import RetryableToolError
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from arcade_slack.tools.users import get_user_info_by_id, list_users
|
||||
from arcade_slack.utils import extract_basic_user_info
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_slack_client(mocker):
|
||||
mock_client = mocker.patch("arcade_slack.tools.users.AsyncWebClient", autospec=True)
|
||||
return mock_client.return_value
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_info_by_id_success(mock_context, mock_slack_client):
|
||||
# Mock the response from slackClient.users_info
|
||||
mock_user = {
|
||||
"id": "U12345",
|
||||
"name": "testuser",
|
||||
"real_name": "Test User",
|
||||
"profile": {"email": "testuser@example.com"},
|
||||
}
|
||||
mock_slack_client.users_info.return_value = {"ok": True, "user": mock_user}
|
||||
|
||||
# Call the function
|
||||
response = await get_user_info_by_id(mock_context, user_id="U12345")
|
||||
|
||||
# Verify that the correct Slack API method was called
|
||||
mock_slack_client.users_info.assert_called_once_with(user="U12345")
|
||||
|
||||
# Verify the response
|
||||
expected_response = extract_basic_user_info(mock_user)
|
||||
assert response == expected_response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("arcade_slack.tools.users.list_users")
|
||||
async def test_get_user_info_by_id_user_not_found(mock_list_users, mock_context, mock_slack_client):
|
||||
error_response = {"ok": False, "error": "user_not_found"}
|
||||
mock_slack_client.users_info.side_effect = SlackApiError(
|
||||
message="User not found",
|
||||
response=error_response,
|
||||
)
|
||||
|
||||
existing_user = {"id": "U12345", "name": "testuser"}
|
||||
mock_list_users.return_value = {"users": [existing_user]}
|
||||
|
||||
with pytest.raises(RetryableToolError) as e:
|
||||
await get_user_info_by_id(mock_context, user_id="U99999")
|
||||
|
||||
assert existing_user["id"] in e.value.additional_prompt_content
|
||||
assert existing_user["name"] in e.value.additional_prompt_content
|
||||
|
||||
mock_slack_client.users_info.assert_called_once_with(user="U99999")
|
||||
mock_list_users.assert_called_once_with(mock_context)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users_success(mock_context, mock_slack_client):
|
||||
mock_slack_client.users_list.return_value = {"ok": True, "members": [{"id": "U12345"}]}
|
||||
response = await list_users(mock_context)
|
||||
assert response == {
|
||||
"users": [extract_basic_user_info({"id": "U12345"})],
|
||||
"next_cursor": None,
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users_with_pagination_success(mock_context, mock_slack_client):
|
||||
mock_slack_client.users_list.side_effect = [
|
||||
{
|
||||
"ok": True,
|
||||
"members": [{"id": "U12345"}],
|
||||
"response_metadata": {"next_cursor": "cursor_xyz"},
|
||||
},
|
||||
{
|
||||
"ok": True,
|
||||
"members": [{"id": "U123456"}],
|
||||
"response_metadata": {"next_cursor": None},
|
||||
},
|
||||
]
|
||||
response = await list_users(mock_context, limit=3)
|
||||
assert response == {
|
||||
"users": [
|
||||
extract_basic_user_info({"id": "U12345"}),
|
||||
extract_basic_user_info({"id": "U123456"}),
|
||||
],
|
||||
"next_cursor": None,
|
||||
}
|
||||
187
toolkits/slack/tests/test_utils.py
Normal file
187
toolkits/slack/tests/test_utils.py
Normal file
|
|
@ -0,0 +1,187 @@
|
|||
import asyncio
|
||||
from unittest.mock import AsyncMock, call, patch
|
||||
|
||||
import pytest
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
|
||||
from arcade_slack.exceptions import PaginationTimeoutError
|
||||
from arcade_slack.utils import async_paginate
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_paginate():
|
||||
mock_slack_client = AsyncMock()
|
||||
mock_slack_client.conversations_list.return_value = {
|
||||
"ok": True,
|
||||
"channels": [{"id": "123"}],
|
||||
"response_metadata": {"next_cursor": None},
|
||||
}
|
||||
|
||||
results, next_cursor = await async_paginate(
|
||||
func=mock_slack_client.conversations_list,
|
||||
response_key="channels",
|
||||
)
|
||||
|
||||
assert results == [{"id": "123"}]
|
||||
assert next_cursor is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_paginate_with_response_error():
|
||||
mock_slack_client = AsyncMock()
|
||||
mock_slack_client.conversations_list.side_effect = SlackApiError(
|
||||
message="slack_error",
|
||||
response={"ok": False, "error": "slack_error"},
|
||||
)
|
||||
|
||||
with pytest.raises(SlackApiError) as e:
|
||||
await async_paginate(
|
||||
func=mock_slack_client.conversations_list,
|
||||
response_key="channels",
|
||||
)
|
||||
assert str(e.value) == "slack_error"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_paginate_with_custom_pagination_args():
|
||||
mock_slack_client = AsyncMock()
|
||||
mock_slack_client.conversations_list.return_value = {
|
||||
"ok": True,
|
||||
"channels": [{"id": "123"}],
|
||||
"response_metadata": {"next_cursor": "456"},
|
||||
}
|
||||
|
||||
results, next_cursor = await async_paginate(
|
||||
func=mock_slack_client.conversations_list,
|
||||
response_key="channels",
|
||||
limit=1,
|
||||
next_cursor="123",
|
||||
hello="world",
|
||||
)
|
||||
|
||||
assert results == [{"id": "123"}]
|
||||
assert next_cursor == "456"
|
||||
|
||||
mock_slack_client.conversations_list.assert_called_once_with(
|
||||
hello="world",
|
||||
limit=1,
|
||||
cursor="123",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"test_limit, last_next_cursor, last_expected_limit",
|
||||
[(5, "cursor3", 1), (None, None, 2)],
|
||||
)
|
||||
async def test_async_paginate_large_limit(test_limit, last_next_cursor, last_expected_limit):
|
||||
mock_slack_client = AsyncMock(spec=AsyncWebClient)
|
||||
mock_slack_client.conversations_list.side_effect = [
|
||||
{
|
||||
"ok": True,
|
||||
"channels": [{"id": "channel1"}, {"id": "channel2"}],
|
||||
"response_metadata": {"next_cursor": "cursor1"},
|
||||
},
|
||||
{
|
||||
"ok": True,
|
||||
"channels": [{"id": "channel3"}, {"id": "channel4"}],
|
||||
"response_metadata": {"next_cursor": "cursor2"},
|
||||
},
|
||||
{
|
||||
"ok": True,
|
||||
"channels": [{"id": "channel5"}],
|
||||
"response_metadata": {"next_cursor": last_next_cursor},
|
||||
},
|
||||
]
|
||||
|
||||
with patch("arcade_slack.utils.MAX_PAGINATION_SIZE_LIMIT", 2):
|
||||
results, next_cursor = await async_paginate(
|
||||
func=mock_slack_client.conversations_list,
|
||||
response_key="channels",
|
||||
limit=test_limit,
|
||||
hello="world",
|
||||
)
|
||||
|
||||
assert results == [
|
||||
{"id": "channel1"},
|
||||
{"id": "channel2"},
|
||||
{"id": "channel3"},
|
||||
{"id": "channel4"},
|
||||
{"id": "channel5"},
|
||||
]
|
||||
assert next_cursor == last_next_cursor
|
||||
assert mock_slack_client.conversations_list.call_count == 3
|
||||
mock_slack_client.conversations_list.assert_has_calls([
|
||||
call(hello="world", limit=2, cursor=None),
|
||||
call(hello="world", limit=2, cursor="cursor1"),
|
||||
call(hello="world", limit=last_expected_limit, cursor="cursor2"),
|
||||
])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_paginate_large_limit_with_response_error():
|
||||
mock_slack_client = AsyncMock()
|
||||
mock_slack_client.conversations_list.side_effect = [
|
||||
{
|
||||
"ok": True,
|
||||
"channels": [{"id": "channel1"}, {"id": "channel2"}],
|
||||
"response_metadata": {"next_cursor": "cursor1"},
|
||||
},
|
||||
SlackApiError(message="slack_error", response={"ok": False, "error": "slack_error"}),
|
||||
{
|
||||
"ok": True,
|
||||
"channels": [{"id": "channel5"}],
|
||||
"response_metadata": {"next_cursor": "cursor3"},
|
||||
},
|
||||
]
|
||||
|
||||
with (
|
||||
patch("arcade_slack.utils.MAX_PAGINATION_SIZE_LIMIT", 2),
|
||||
pytest.raises(SlackApiError) as e,
|
||||
):
|
||||
await async_paginate(
|
||||
func=mock_slack_client.conversations_list,
|
||||
response_key="channels",
|
||||
limit=5,
|
||||
hello="world",
|
||||
)
|
||||
assert str(e.value) == "slack_error"
|
||||
|
||||
assert mock_slack_client.conversations_list.call_count == 2
|
||||
mock_slack_client.conversations_list.assert_has_calls([
|
||||
call(hello="world", limit=2, cursor=None),
|
||||
call(hello="world", limit=2, cursor="cursor1"),
|
||||
])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_paginate_with_timeout():
|
||||
# Mock Slack client
|
||||
mock_slack_client = AsyncMock()
|
||||
|
||||
# Simulate a network delay by making the mock function sleep
|
||||
async def mock_conversations_list(*args, **kwargs):
|
||||
await asyncio.sleep(1) # Sleep for 1 second to simulate delay
|
||||
return {
|
||||
"ok": True,
|
||||
"channels": [{"id": "123"}],
|
||||
"response_metadata": {"next_cursor": None},
|
||||
}
|
||||
|
||||
mock_slack_client.conversations_list.side_effect = mock_conversations_list
|
||||
|
||||
# Set a low timeout to trigger the timeout error quickly during the test
|
||||
max_pagination_timeout_seconds = 0.1 # 100 milliseconds
|
||||
|
||||
with pytest.raises(PaginationTimeoutError) as exc_info:
|
||||
await async_paginate(
|
||||
func=mock_slack_client.conversations_list,
|
||||
response_key="channels",
|
||||
max_pagination_timeout_seconds=max_pagination_timeout_seconds,
|
||||
)
|
||||
|
||||
assert (
|
||||
str(exc_info.value)
|
||||
== f"The pagination process timed out after {max_pagination_timeout_seconds} seconds."
|
||||
)
|
||||
Loading…
Reference in a new issue