From ed89af4b4d6671a915de59d0b72a80283eb3ed89 Mon Sep 17 00:00:00 2001 From: Renato Byrro Date: Thu, 10 Jul 2025 13:50:28 -0300 Subject: [PATCH] Slack Toolkit Refactoring (#453) # Backwards-compatible refactoring of the Slack toolkit Several performance improvements, streamlined tool set, and easier to understand tool interfaces. All "old" tools were kept for backwards compatibility, with the same interfaces and response structure (but using the new and more performant tools under the hood). Full revision of unit-tests and evals. ## Streamlined tool set Multiple groups of tools were streamlined into a single one: * "get conversation metadata" from 5 tools to one; * "send message" from 2 tools to one; * "get users in conversation" from 3 tools to one; * "get messages" from 4 tools to one ## New capabilities * Messages retrieved are now populated with the users' names, apart from ID: makes it easier for LLMs to reference who sent a message, were mentioned, or reacted to a message * Retrieve users by username, email, and/or ID (before we only supported ID) * Retrieve multiple users in a single tool call ## Concurrency controls All operations issuing multiple requests concurrently now have a `Semaphore` to limit the concurrency level. The limit can be controlled through the `SLACK_MAX_CONCURRENT_REQUESTS` env var (defaults to 3). ## Networking performance improvement Various operations that used to make multiple API calls are now executed more efficiently: ### Find users by username * Before: a full scan of `users_list` was required (potentially multiple pages for large workspaces); * Now it stops as soon as we have all users needed (yes, it was dumb before) ### Get multiple users by their IDs * Before: for each user ID, we made one API call to the `users_info` endpoint * Now: we call `list_users` and scan the results to match the user IDs (an estimate of 99.5% of Slack workspaces have < 200 users; for large workspaces, we may need to paginate `list_users`) ### Get a conversation by its users * Before: * Call to `list_conversations` (potentially paginating) * For each conversation, one call to `conversations_members` (potentially paginating) * Then loop and find which conversation matches the users' IDs * Now: * A single call to `conversations_open` --- .../google/arcade_google/tools/file_picker.py | 2 +- toolkits/slack/arcade_slack/constants.py | 9 +- .../arcade_slack/conversation_retrieval.py | 74 + toolkits/slack/arcade_slack/custom_types.py | 8 +- toolkits/slack/arcade_slack/exceptions.py | 20 - .../slack/arcade_slack/message_retrieval.py | 76 + toolkits/slack/arcade_slack/models.py | 166 +- toolkits/slack/arcade_slack/tools/chat.py | 1249 ++++++++------- toolkits/slack/arcade_slack/tools/users.py | 106 +- toolkits/slack/arcade_slack/user_retrieval.py | 214 +++ toolkits/slack/arcade_slack/utils.py | 317 +++- toolkits/slack/conftest.py | 143 ++ .../slack/evals/chat/eval_get_metadata.py | 206 +++ .../chat/eval_get_users_in_conversation.py | 81 + .../evals/chat/eval_list_conversations.py | 175 +++ .../messages/eval_get_channel_messages.py | 622 ++++++++ .../chat/messages/eval_get_dm_messages.py | 191 +++ .../chat/messages/eval_get_mpim_messages.py | 165 ++ .../evals/chat/messages/eval_send_messages.py | 279 ++++ toolkits/slack/evals/eval_chat.py | 1168 -------------- toolkits/slack/evals/eval_users.py | 101 +- toolkits/slack/pyproject.toml | 2 +- toolkits/slack/tests/test_chat.py | 1398 ++++++++++------- toolkits/slack/tests/test_models.py | 107 ++ toolkits/slack/tests/test_user_retrieval.py | 313 ++++ toolkits/slack/tests/test_users.py | 344 +++- toolkits/slack/tests/test_utils.py | 745 ++++----- 27 files changed, 5385 insertions(+), 2896 deletions(-) create mode 100644 toolkits/slack/arcade_slack/conversation_retrieval.py create mode 100644 toolkits/slack/arcade_slack/message_retrieval.py create mode 100644 toolkits/slack/arcade_slack/user_retrieval.py create mode 100644 toolkits/slack/evals/chat/eval_get_metadata.py create mode 100644 toolkits/slack/evals/chat/eval_get_users_in_conversation.py create mode 100644 toolkits/slack/evals/chat/eval_list_conversations.py create mode 100644 toolkits/slack/evals/chat/messages/eval_get_channel_messages.py create mode 100644 toolkits/slack/evals/chat/messages/eval_get_dm_messages.py create mode 100644 toolkits/slack/evals/chat/messages/eval_get_mpim_messages.py create mode 100644 toolkits/slack/evals/chat/messages/eval_send_messages.py delete mode 100644 toolkits/slack/evals/eval_chat.py create mode 100644 toolkits/slack/tests/test_models.py create mode 100644 toolkits/slack/tests/test_user_retrieval.py diff --git a/toolkits/google/arcade_google/tools/file_picker.py b/toolkits/google/arcade_google/tools/file_picker.py index b509e6d4..25a578b0 100644 --- a/toolkits/google/arcade_google/tools/file_picker.py +++ b/toolkits/google/arcade_google/tools/file_picker.py @@ -49,6 +49,6 @@ def generate_google_file_picker_url( "url": url, "llm_instructions": ( "Instruct the user to click the following link to open the Google Drive File Picker. " - "This will allow them to select files and grant access permissions: {url}" + f"This will allow them to select files and grant access permissions: {url}" ), } diff --git a/toolkits/slack/arcade_slack/constants.py b/toolkits/slack/arcade_slack/constants.py index 10c6176d..bed2bb03 100644 --- a/toolkits/slack/arcade_slack/constants.py +++ b/toolkits/slack/arcade_slack/constants.py @@ -1,13 +1,18 @@ import os -from arcade_slack.custom_types import PositiveInt +from arcade_slack.custom_types import PositiveNonZeroInt MAX_PAGINATION_SIZE_LIMIT = 200 -MAX_PAGINATION_TIMEOUT_SECONDS = PositiveInt( +MAX_PAGINATION_TIMEOUT_SECONDS = PositiveNonZeroInt( os.environ.get( "MAX_PAGINATION_TIMEOUT_SECONDS", os.environ.get("MAX_SLACK_PAGINATION_TIMEOUT_SECONDS", 30), ), name="MAX_PAGINATION_TIMEOUT_SECONDS or MAX_SLACK_PAGINATION_TIMEOUT_SECONDS", ) + +MAX_CONCURRENT_REQUESTS = PositiveNonZeroInt( + os.environ.get("SLACK_MAX_CONCURRENT_REQUESTS", 3), + name="SLACK_MAX_CONCURRENT_REQUESTS", +) diff --git a/toolkits/slack/arcade_slack/conversation_retrieval.py b/toolkits/slack/arcade_slack/conversation_retrieval.py new file mode 100644 index 00000000..1a32fa02 --- /dev/null +++ b/toolkits/slack/arcade_slack/conversation_retrieval.py @@ -0,0 +1,74 @@ +import json +from typing import cast + +from arcade_tdk.errors import RetryableToolError, ToolExecutionError +from slack_sdk.errors import SlackApiError +from slack_sdk.web.async_client import AsyncWebClient + +from arcade_slack.models import ( + ConversationType, + FindChannelByNameSentinel, +) +from arcade_slack.utils import ( + async_paginate, + extract_conversation_metadata, +) + + +async def get_conversation_by_id( + auth_token: str, + conversation_id: str, +) -> dict: + """Get metadata of a conversation in Slack by the conversation_id.""" + try: + slack_client = AsyncWebClient(token=auth_token) + response = await slack_client.conversations_info( + channel=conversation_id, + include_locale=True, + include_num_members=True, + ) + return dict(**extract_conversation_metadata(response["channel"])) + + except SlackApiError as e: + slack_error = cast(str, e.response.get("error", "")) + if "not_found" in slack_error.lower(): + message = f"Conversation with ID '{conversation_id}' not found." + raise ToolExecutionError(message=message, developer_message=message) + raise + + +async def get_channel_by_name( + auth_token: str, + channel_name: str, +) -> dict: + channel_name_casefolded = channel_name.lstrip("#").casefold() + + slack_client = AsyncWebClient(token=auth_token) + + results, _ = await async_paginate( + func=slack_client.conversations_list, + response_key="channels", + types=",".join([ + ConversationType.PUBLIC_CHANNEL.value, + ConversationType.PRIVATE_CHANNEL.value, + ]), + exclude_archived=True, + sentinel=FindChannelByNameSentinel(channel_name_casefolded), + ) + + available_channels = [] + + for channel in results: + if channel["name"].casefold() == channel_name_casefolded: + return dict(**extract_conversation_metadata(channel)) + else: + available_channels.append({"id": channel["id"], "name": channel["name"]}) + + error_message = f"Channel with name '{channel_name}' not found." + + raise RetryableToolError( + message=error_message, + developer_message=error_message, + additional_prompt_content=f"Available channels: {json.dumps(available_channels)}", + retry_after_ms=500, + ) diff --git a/toolkits/slack/arcade_slack/custom_types.py b/toolkits/slack/arcade_slack/custom_types.py index 9037bb18..1c7005a5 100644 --- a/toolkits/slack/arcade_slack/custom_types.py +++ b/toolkits/slack/arcade_slack/custom_types.py @@ -1,11 +1,11 @@ from typing import NewType -class PositiveInt(int): - def __new__(cls, value: str | int, name: str = "value") -> "PositiveInt": +class PositiveNonZeroInt(int): + def __new__(cls, value: str | int, name: str = "value") -> "PositiveNonZeroInt": def validate(val: int) -> int: - if val <= 0: - raise ValueError(f"{name} must be positive, got {val}") + if val < 1: + raise ValueError(f"{name} must be a positive non-zero integer, got {val}") return val try: diff --git a/toolkits/slack/arcade_slack/exceptions.py b/toolkits/slack/arcade_slack/exceptions.py index 8f4ff55a..6f7a6b86 100644 --- a/toolkits/slack/arcade_slack/exceptions.py +++ b/toolkits/slack/arcade_slack/exceptions.py @@ -8,23 +8,3 @@ class PaginationTimeoutError(SlackToolkitError): def __init__(self, timeout_seconds: int): self.timeout_seconds = timeout_seconds super().__init__(f"The pagination process timed out after {timeout_seconds} seconds.") - - -class ItemNotFoundError(SlackToolkitError): - """Raised when an item is not found.""" - - -class UsernameNotFoundError(SlackToolkitError): - """Raised when a user is not found by the username searched""" - - def __init__(self, usernames_found: list[str], username_not_found: str) -> None: - self.usernames_found = usernames_found - self.username_not_found = username_not_found - - -class ConversationNotFoundError(SlackToolkitError): - """Raised when a conversation is not found""" - - -class DirectMessageConversationNotFoundError(ConversationNotFoundError): - """Raised when a direct message conversation searched is not found""" diff --git a/toolkits/slack/arcade_slack/message_retrieval.py b/toolkits/slack/arcade_slack/message_retrieval.py new file mode 100644 index 00000000..c0b87764 --- /dev/null +++ b/toolkits/slack/arcade_slack/message_retrieval.py @@ -0,0 +1,76 @@ +from datetime import datetime, timezone +from typing import Any + +from arcade_tdk.errors import ToolExecutionError +from slack_sdk.web.async_client import AsyncWebClient + +from arcade_slack.utils import ( + async_paginate, + convert_datetime_to_unix_timestamp, + convert_relative_datetime_to_unix_timestamp, + enrich_message_datetime, +) + + +async def retrieve_messages_in_conversation( + conversation_id: str, + auth_token: str | None = None, + oldest_relative: str | None = None, + latest_relative: str | None = None, + oldest_datetime: str | None = None, + latest_datetime: str | None = None, + limit: int | None = None, + next_cursor: str | None = None, +) -> dict: + error_message = None + if oldest_datetime and oldest_relative: + error_message = "Cannot specify both 'oldest_datetime' and 'oldest_relative'." + + if latest_datetime and latest_relative: + error_message = "Cannot specify both 'latest_datetime' and 'latest_relative'." + + if error_message: + raise ToolExecutionError(error_message, developer_message=error_message) + + current_unix_timestamp = int(datetime.now(timezone.utc).timestamp()) + + if latest_relative: + latest_timestamp = convert_relative_datetime_to_unix_timestamp( + latest_relative, current_unix_timestamp + ) + elif latest_datetime: + latest_timestamp = convert_datetime_to_unix_timestamp(latest_datetime) + else: + latest_timestamp = None + + if oldest_relative: + oldest_timestamp = convert_relative_datetime_to_unix_timestamp( + oldest_relative, current_unix_timestamp + ) + elif oldest_datetime: + oldest_timestamp = convert_datetime_to_unix_timestamp(oldest_datetime) + else: + oldest_timestamp = None + + datetime_args: dict[str, Any] = {} + if oldest_timestamp: + datetime_args["oldest"] = oldest_timestamp + if latest_timestamp: + datetime_args["latest"] = latest_timestamp + + slackClient = AsyncWebClient(token=auth_token) + + response, next_cursor = await async_paginate( + slackClient.conversations_history, + "messages", + limit=limit, + next_cursor=next_cursor, + channel=conversation_id, + include_all_metadata=True, + inclusive=True, # Include messages at the start and end of the time range + **datetime_args, + ) + + messages = [enrich_message_datetime(message) for message in response] + + return {"messages": messages, "next_cursor": next_cursor} diff --git a/toolkits/slack/arcade_slack/models.py b/toolkits/slack/arcade_slack/models.py index bfcb09cb..180be431 100644 --- a/toolkits/slack/arcade_slack/models.py +++ b/toolkits/slack/arcade_slack/models.py @@ -1,5 +1,12 @@ +import asyncio +from abc import ABC, abstractmethod +from collections.abc import Awaitable, Callable +from contextlib import suppress from enum import Enum -from typing import Literal, TypedDict +from typing import Any, Literal, TypedDict + +from arcade_tdk.errors import ToolExecutionError +from slack_sdk.errors import SlackApiError from arcade_slack.custom_types import ( SlackOffsetSecondsFromUTC, @@ -24,6 +31,16 @@ class ConversationType(str, Enum): MULTI_PERSON_DIRECT_MESSAGE = "multi_person_direct_message" DIRECT_MESSAGE = "direct_message" + def to_slack_name_str(self) -> str: + mapping = { + ConversationType.PUBLIC_CHANNEL: ConversationTypeSlackName.PUBLIC_CHANNEL.value, + ConversationType.PRIVATE_CHANNEL: ConversationTypeSlackName.PRIVATE_CHANNEL.value, + ConversationType.MULTI_PERSON_DIRECT_MESSAGE: ConversationTypeSlackName.MPIM.value, + ConversationType.DIRECT_MESSAGE: ConversationTypeSlackName.IM.value, + } + + return mapping[self] + """ About Slack dictionaries: Slack does not guarantee the presence of all fields for a given @@ -204,3 +221,150 @@ class SlackConversationsToolResponse(TypedDict, total=True): conversations: list[ConversationMetadata] next_cursor: SlackPaginationNextCursor | None + + +class PaginationSentinel(ABC): + """Base class for pagination sentinel classes.""" + + def __init__(self, **kwargs: Any) -> None: + self.kwargs = kwargs + + @abstractmethod + def __call__(self, last_result: Any) -> bool: + """Determine if the pagination should stop.""" + raise NotImplementedError + + +class FindUserByUsernameSentinel(PaginationSentinel): + """Sentinel class for finding a user by username.""" + + def __call__(self, last_result: Any) -> bool: + for user in last_result: + if not isinstance(user.get("name"), str): + continue + if user.get("name").casefold() == self.kwargs["username"].casefold(): + return True + return False + + +class FindMultipleUsersByUsernameSentinel(PaginationSentinel): + """Sentinel class for finding multiple users by username.""" + + def __init__(self, usernames: list[str]) -> None: + if not usernames: + raise ValueError("usernames must be a non-empty list of strings") + super().__init__(usernames=usernames) + self.usernames_pending = {username.casefold() for username in usernames} + + def _flag_username_found(self, username: str) -> None: + with suppress(KeyError): + self.usernames_pending.remove(username.casefold()) + + def _all_usernames_found(self) -> bool: + return not self.usernames_pending + + def __call__(self, last_result: Any) -> bool: + if not self.usernames_pending: + return True + for user in last_result: + username = user.get("name") + if not isinstance(username, str): + continue + if username.casefold() in self.usernames_pending: + self._flag_username_found(username) + if self._all_usernames_found(): + return True + return False + + +class FindMultipleUsersByIdSentinel(PaginationSentinel): + """Sentinel class for finding multiple users by ID.""" + + def __init__(self, user_ids: list[str]) -> None: + if not user_ids: + raise ValueError("user_ids must be a non-empty list of strings") + super().__init__(user_ids=user_ids) + self.user_ids_pending = set(user_ids) + + def _flag_user_id_found(self, user_id: str) -> None: + with suppress(KeyError): + self.user_ids_pending.remove(user_id.casefold()) + + def _all_user_ids_found(self) -> bool: + return not self.user_ids_pending + + def __call__(self, last_result: Any) -> bool: + if not self.user_ids_pending: + return True + for user in last_result: + user_id = user.get("id") + if user_id in self.user_ids_pending: + self._flag_user_id_found(user_id) + if self._all_user_ids_found(): + return True + return False + + +class FindChannelByNameSentinel(PaginationSentinel): + """Sentinel class for finding a channel by name.""" + + def __init__(self, channel_name: str) -> None: + super().__init__(channel_name=channel_name) + self.channel_name_casefold = channel_name.casefold() + + def __call__(self, last_result: Any) -> bool: + for channel in last_result: + channel_name = channel.get("name") + if not isinstance(channel_name, str): + continue + if channel_name.casefold() == self.channel_name_casefold: + return True + return False + + +class AbstractConcurrencySafeCoroutineCaller(ABC): + """Abstract base class for concurrency-safe coroutine callers.""" + + def __init__(self, func: Callable[..., Awaitable[Any]], *args: Any, **kwargs: Any) -> None: + self.func = func + self.args = args + self.kwargs = kwargs + + @abstractmethod + async def __call__(self, semaphore: asyncio.Semaphore) -> Any: + """Call a coroutine with a semaphore.""" + raise NotImplementedError + + +class ConcurrencySafeCoroutineCaller(AbstractConcurrencySafeCoroutineCaller): + """Calls a coroutine with an asyncio semaphore.""" + + async def __call__(self, semaphore: asyncio.Semaphore) -> Any: + async with semaphore: + return await self.func(*self.args, **self.kwargs) + + +class GetUserByEmailCaller(AbstractConcurrencySafeCoroutineCaller): + """Call Slack's lookupByEmail method with an asyncio semaphore while handling API errors.""" + + def __init__( + self, + func: Callable[..., Awaitable[Any]], + email: str, + ) -> None: + super().__init__(func) + self.email = email + + async def __call__(self, semaphore: asyncio.Semaphore) -> dict[str, Any]: + async with semaphore: + try: + user = await self.func(email=self.email) + return {"user": user["user"], "email": self.email} + except SlackApiError as e: + if e.response.get("error") in ["user_not_found", "users_not_found"]: + return {"user": None, "email": self.email} + else: + raise ToolExecutionError( + message="Error getting user by email", + developer_message=f"Error getting user by email: {e.response.get('error')}", + ) diff --git a/toolkits/slack/arcade_slack/tools/chat.py b/toolkits/slack/arcade_slack/tools/chat.py index 61462497..86f573ca 100644 --- a/toolkits/slack/arcade_slack/tools/chat.py +++ b/toolkits/slack/arcade_slack/tools/chat.py @@ -1,137 +1,93 @@ -import asyncio -from datetime import datetime, timezone from typing import Annotated, cast from arcade_tdk import ToolContext, tool from arcade_tdk.auth import Slack -from arcade_tdk.errors import RetryableToolError, ToolExecutionError +from arcade_tdk.errors import ToolExecutionError from slack_sdk.errors import SlackApiError from slack_sdk.web.async_client import AsyncWebClient -from arcade_slack.constants import MAX_PAGINATION_TIMEOUT_SECONDS -from arcade_slack.exceptions import ( - ItemNotFoundError, - UsernameNotFoundError, +from arcade_slack.constants import MAX_PAGINATION_SIZE_LIMIT +from arcade_slack.conversation_retrieval import ( + get_channel_by_name, + get_conversation_by_id, ) +from arcade_slack.message_retrieval import retrieve_messages_in_conversation from arcade_slack.models import ( ConversationType, - SlackUserList, ) -from arcade_slack.tools.users import get_user_info_by_id, list_users +from arcade_slack.user_retrieval import ( + get_users_by_id, + get_users_by_id_username_or_email, +) from arcade_slack.utils import ( async_paginate, - convert_conversation_type_to_slack_name, - convert_datetime_to_unix_timestamp, - convert_relative_datetime_to_unix_timestamp, - enrich_message_datetime, extract_conversation_metadata, - format_users, - get_user_by_username, - retrieve_conversations_by_user_ids, + populate_users_in_messages, + raise_for_users_not_found, ) @tool( requires_auth=Slack( scopes=[ - "chat:write", - "im:write", - "users.profile:read", - "users:read", - ], - ) -) -async def send_dm_to_user( - context: ToolContext, - user_name: Annotated[ - str, - ( - "The Slack username of the person you want to message. " - "Slack usernames are ALWAYS lowercase." - ), - ], - message: Annotated[str, "The message you want to send"], -) -> Annotated[dict, "The response from the Slack API"]: - """Send a direct message to a user in Slack.""" - - token = ( - context.authorization.token if context.authorization and context.authorization.token else "" - ) - slackClient = AsyncWebClient(token=token) - - try: - # Step 1: Retrieve the user's Slack ID based on their username - user_list_response = await slackClient.users_list() - user_id = None - for user in user_list_response["members"]: - response_user_name = ( - "" if not isinstance(user.get("name"), str) else user["name"].lower() - ) - if response_user_name == user_name.lower(): - user_id = user["id"] - break - - if not user_id: - raise RetryableToolError( - "User not found", - developer_message=f"User with username '{user_name}' not found.", - additional_prompt_content=format_users(cast(SlackUserList, user_list_response)), - retry_after_ms=500, # Play nice with Slack API rate limits - ) - - # Step 2: Retrieve the DM channel ID with the user - im_response = await slackClient.conversations_open(users=[user_id]) - dm_channel_id = im_response["channel"]["id"] - - # Step 3: Send the message as if it's from you (because we're using a user token) - response = await slackClient.chat_postMessage(channel=dm_channel_id, text=message) - - except SlackApiError as e: - error_message = e.response["error"] if "error" in e.response else str(e) - raise ToolExecutionError( - "Error sending message", - developer_message=f"Slack API Error: {error_message}", - ) - else: - return {"response": response.data} - - -@tool( - requires_auth=Slack( - scopes=[ - "chat:write", "channels:read", "groups:read", + "mpim:read", + "im:read", + "users:read", + "users:read.email", + "chat:write", + "mpim:write", + "im:write", ], ) ) -async def send_message_to_channel( +async def send_message( context: ToolContext, - channel_name: Annotated[str, "The Slack channel name where you want to send the message. "], - message: Annotated[str, "The message you want to send"], + message: Annotated[str, "The content of the message to send."], + channel_name: Annotated[ + str | None, + "The channel name to send the message to. Prefer providing a conversation_id, " + "when available, since the performance is better.", + ] = None, + conversation_id: Annotated[str | None, "The conversation ID to send the message to."] = None, + user_ids: Annotated[list[str] | None, "The Slack user IDs of the people to message."] = None, + emails: Annotated[list[str] | None, "The emails of the people to message."] = None, + usernames: Annotated[ + list[str] | None, + "The Slack usernames of the people to message. Prefer providing user_ids and/or emails, " + "when available, since the performance is better.", + ] = None, ) -> Annotated[dict, "The response from the Slack API"]: - """Send a message to a channel in Slack.""" + """Send a message to a Channel, Direct Message (IM/DM), or Multi-Person (MPIM) conversation - try: - slackClient = AsyncWebClient( - token=context.authorization.token - if context.authorization and context.authorization.token - else "" - ) + Provide exactly one of: + - channel_name; or + - conversation_id; or + - any combination of user_ids, usernames, and/or emails. - channel = await get_channel_metadata_by_name(context=context, channel_name=channel_name) - channel_id = channel["id"] - - response = await slackClient.chat_postMessage(channel=channel_id, text=message) - - except SlackApiError as e: - error_message = e.response["error"] if "error" in e.response else str(e) + In case multiple user_ids, usernames, and/or emails are provided, the tool will open a + multi-person conversation with the specified people and send the message to it. + """ + if conversation_id and any([channel_name, user_ids, usernames, emails]): raise ToolExecutionError( - "Error sending message", - developer_message=f"Slack API Error: {error_message}", + "Provide exactly one of: channel_name, OR conversation_id, OR any combination of " + "user_ids, usernames, and/or emails." ) - else: - return {"response": response.data} + + if not conversation_id: + conversation = await get_conversation_metadata( + context=context, + channel_name=channel_name, + user_ids=user_ids, + usernames=usernames, + emails=emails, + ) + conversation_id = conversation["id"] + + slack_client = AsyncWebClient(token=context.get_auth_token_or_empty()) + response = await slack_client.chat_postMessage(channel=cast(str, conversation_id), text=message) + return {"success": True, "data": response.data} @tool( @@ -146,50 +102,51 @@ async def send_message_to_channel( ], ) ) -async def get_members_in_conversation_by_id( +async def get_users_in_conversation( context: ToolContext, - conversation_id: Annotated[str, "The ID of the conversation to get members for"], - limit: Annotated[int | None, "The maximum number of members to return."] = None, + conversation_id: Annotated[str | None, "The ID of the conversation to get users in."] = None, + channel_name: Annotated[ + str | None, + "The name of the channel to get users in. Prefer providing a conversation_id, " + "when available, since the performance is better.", + ] = None, + # The user object is relatively small, so we allow a higher limit. + limit: Annotated[ + int, "The maximum number of users to return. Defaults to 200. Maximum is 500." + ] = 200, next_cursor: Annotated[str | None, "The cursor to use for pagination."] = None, -) -> Annotated[dict, "Information about each member in the conversation"]: - """Get the members of a conversation in Slack by the conversation's ID.""" - token = ( - context.authorization.token if context.authorization and context.authorization.token else "" +) -> Annotated[dict, "Information about each user in the conversation"]: + """Get the users in a Slack conversation (Channel, DM/IM, or MPIM) by its ID or by channel name. + + Provide exactly one of conversation_id or channel_name. Prefer providing a conversation_id, + when available, since the performance is better. + """ + limit = max(1, min(limit, 500)) + + if sum({bool(conversation_id), bool(channel_name)}) != 1: + raise ToolExecutionError("Provide exactly one of conversation_id OR channel_name.") + + auth_token = context.get_auth_token_or_empty() + + if not conversation_id: + channel = await get_channel_by_name(auth_token, cast(str, channel_name)) + conversation_id = channel["id"] + + slack_client = AsyncWebClient(token=auth_token) + user_ids, next_cursor = await async_paginate( + func=slack_client.conversations_members, + response_key="members", + limit=limit, + next_cursor=next_cursor, + channel=conversation_id, ) - slackClient = AsyncWebClient(token=token) - try: - member_ids, next_cursor = await async_paginate( - slackClient.conversations_members, - "members", - limit=limit, - next_cursor=next_cursor, - channel=conversation_id, - ) - except SlackApiError as e: - if e.response["error"] == "channel_not_found": - conversations = await list_conversations_metadata(context) - available_conversations = ", ".join( - f"{conversation['id']} ({conversation['name']})" - for conversation in conversations["conversations"] - ) + response = await get_users_by_id(auth_token, user_ids) - raise RetryableToolError( - "Conversation not found", - developer_message=f"Conversation with ID '{conversation_id}' not found.", - additional_prompt_content=f"Available conversations: {available_conversations}", - retry_after_ms=500, - ) - - # Get the members' info - # TODO: This will probably hit rate limits. We should probably call list_users() and - # then filter the results instead. - members = await asyncio.gather(*[ - get_user_info_by_id(context, member_id) for member_id in member_ids - ]) + await raise_for_users_not_found(context, [response]) return { - "members": [member for member in members if not member.get("is_bot")], + "users": [user for user in response["users"] if not user.get("is_bot")], "next_cursor": next_cursor, } @@ -199,41 +156,41 @@ async def get_members_in_conversation_by_id( scopes=[ "channels:read", "groups:read", - "im:read", "mpim:read", + "im:read", "users:read", "users:read.email", - ], + "channels:history", + "groups:history", + "mpim:history", + "im:history", + ] ) ) -async def get_members_in_channel_by_name( +async def get_messages( context: ToolContext, - channel_name: Annotated[str, "The name of the channel to get members for"], - limit: Annotated[int | None, "The maximum number of members to return."] = None, - next_cursor: Annotated[str | None, "The cursor to use for pagination."] = None, -) -> Annotated[dict, "The channel members' IDs and Names"]: - """Get the members of a conversation in Slack by the conversation's name.""" - channel = await get_channel_metadata_by_name(context=context, channel_name=channel_name) - - return await get_members_in_conversation_by_id( # type: ignore[no-any-return] - context=context, - conversation_id=channel["id"], - limit=limit, - next_cursor=next_cursor, - ) - - -# TODO: make the function accept a current unix timestamp argument to allow testing without -# mocking. Have to wait until arcade.core.annotations.Inferrable is implemented, so that we -# can avoid exposing this arg to the LLM. -@tool( - requires_auth=Slack( - scopes=["channels:history", "groups:history", "im:history", "mpim:history"], - ) -) -async def get_messages_in_conversation_by_id( - context: ToolContext, - conversation_id: Annotated[str, "The ID of the conversation to get history for"], + conversation_id: Annotated[ + str | None, + "The ID of the conversation to get messages from. Provide exactly one of conversation_id " + "OR any combination of user_ids, usernames, and/or emails.", + ] = None, + channel_name: Annotated[ + str | None, + "The name of the channel to get messages from. Prefer providing a conversation_id, " + "when available, since the performance is better.", + ] = None, + user_ids: Annotated[ + list[str] | None, "The IDs of the users in the conversation to get messages from." + ] = None, + usernames: Annotated[ + list[str] | None, + "The usernames of the users in the conversation to get messages from. Prefer providing" + "user_ids and/or emails, when available, since the performance is better.", + ] = None, + emails: Annotated[ + list[str] | None, + "The emails of the users in the conversation to get messages from.", + ] = None, oldest_relative: Annotated[ str | None, ( @@ -262,93 +219,365 @@ async def get_messages_in_conversation_by_id( "format 'YYYY-MM-DD HH:MM:SS'" ), ] = None, - limit: Annotated[int | None, "The maximum number of messages to return."] = None, + limit: Annotated[ + # The message object can be relatively large, so we limit maximum to 100 + # to preserve LLM's context window and reduce the likelihood of hallucinations. + int, "The maximum number of messages to return. Defaults to 20. Maximum is 100." + ] = 20, next_cursor: Annotated[str | None, "The cursor to use for pagination."] = None, ) -> Annotated[ dict, - ( - "The messages in a conversation and next cursor for paginating results (when " - "there are additional messages to retrieve)." - ), + "The messages in a Slack Channel, DM (direct message) or MPIM (multi-person) conversation.", ]: - """Get the messages in a conversation by the conversation's ID. + """Get messages in a Slack Channel, DM (direct message) or MPIM (multi-person) conversation. - A conversation can be a channel, a DM, or a group DM. + Provide exactly one of: + - conversation_id; or + - channel_name; or + - any combination of user_ids, usernames, and/or emails. - To filter by an absolute datetime, use 'oldest_datetime' and/or 'latest_datetime'. If - only 'oldest_datetime' is provided, it returns messages from the oldest_datetime to the - current time. If only 'latest_datetime' is provided, it returns messages since the + To filter messages by an absolute datetime, use 'oldest_datetime' and/or 'latest_datetime'. If + only 'oldest_datetime' is provided, it will return messages from the oldest_datetime to the + current time. If only 'latest_datetime' is provided, it will return messages since the beginning of the conversation to the latest_datetime. - To filter by a relative datetime (e.g. 3 days ago, 1 hour ago, etc.), use - 'oldest_relative' and/or 'latest_relative'. If only 'oldest_relative' is provided, it returns - messages from the oldest_relative to the current time. If only 'latest_relative' is provided, - it returns messages from the current time to the latest_relative. + To filter messages by a relative datetime (e.g. 3 days ago, 1 hour ago, etc.), use + 'oldest_relative' and/or 'latest_relative'. If only 'oldest_relative' is provided, it will + return messages from the oldest_relative to the current time. If only 'latest_relative' is + provided, it will return messages from the current time to the latest_relative. Do not provide both 'oldest_datetime' and 'oldest_relative' or both 'latest_datetime' and 'latest_relative'. Leave all arguments with the default None to get messages without date/time filtering""" - error_message = None - if oldest_datetime and oldest_relative: - error_message = "Cannot specify both 'oldest_datetime' and 'oldest_relative'." + limit = max(1, min(limit, 100)) - if latest_datetime and latest_relative: - error_message = "Cannot specify both 'latest_datetime' and 'latest_relative'." - - if error_message: - raise ToolExecutionError(error_message, developer_message=error_message) - - current_unix_timestamp = int(datetime.now(timezone.utc).timestamp()) - - if latest_relative: - latest_timestamp = convert_relative_datetime_to_unix_timestamp( - latest_relative, current_unix_timestamp + if not conversation_id: + conversation = await get_conversation_metadata( + context=context, + channel_name=channel_name, + user_ids=user_ids, + usernames=usernames, + emails=emails, ) - elif latest_datetime: - latest_timestamp = convert_datetime_to_unix_timestamp(latest_datetime) - else: - latest_timestamp = None + conversation_id = conversation["id"] - if oldest_relative: - oldest_timestamp = convert_relative_datetime_to_unix_timestamp( - oldest_relative, current_unix_timestamp - ) - elif oldest_datetime: - oldest_timestamp = convert_datetime_to_unix_timestamp(oldest_datetime) - else: - oldest_timestamp = None - - token = ( - context.authorization.token if context.authorization and context.authorization.token else "" - ) - slackClient = AsyncWebClient(token=token) - - datetime_args = {} - if oldest_timestamp: - datetime_args["oldest"] = oldest_timestamp - if latest_timestamp: - datetime_args["latest"] = latest_timestamp - - response, next_cursor = await async_paginate( - slackClient.conversations_history, - "messages", + response = await retrieve_messages_in_conversation( + auth_token=context.get_auth_token_or_empty(), + conversation_id=cast(str, conversation_id), + oldest_relative=oldest_relative, + latest_relative=latest_relative, + oldest_datetime=oldest_datetime, + latest_datetime=latest_datetime, limit=limit, next_cursor=next_cursor, - channel=conversation_id, - include_all_metadata=True, - inclusive=True, # Include messages at the start and end of the time range - **datetime_args, ) - messages = [enrich_message_datetime(message) for message in response] + response["messages"] = await populate_users_in_messages( + auth_token=context.get_auth_token_or_empty(), + messages=response["messages"], + ) - return {"messages": messages, "next_cursor": next_cursor} + return cast(dict, response) + + +@tool( + requires_auth=Slack( + scopes=[ + "channels:read", + "groups:read", + "mpim:read", + "im:read", + "users:read", + "users:read.email", + ], + ) +) +async def get_conversation_metadata( + context: ToolContext, + conversation_id: Annotated[str | None, "The ID of the conversation to get metadata for"] = None, + channel_name: Annotated[ + str | None, + "The name of the channel to get metadata for. Prefer providing a conversation_id, " + "when available, since the performance is better.", + ] = None, + usernames: Annotated[ + list[str] | None, + "The usernames of the users to get the conversation metadata. " + "Prefer providing user_ids and/or emails, when available, since the performance is better.", + ] = None, + emails: Annotated[ + list[str] | None, + "The emails of the users to get the conversation metadata.", + ] = None, + user_ids: Annotated[ + list[str] | None, + "The IDs of the users to get the conversation metadata.", + ] = None, +) -> Annotated[ + dict | None, + "The conversation metadata.", +]: + """Get metadata of a Channel, a Direct Message (IM / DM) or a Multi-Person (MPIM) conversation. + + Use this tool to retrieve metadata about a conversation with a conversation_id, a channel name, + or by the user_id(s), username(s), and/or email(s) of the user(s) in the conversation. + + This tool does not return the messages in a conversation. To get the messages, use the + 'Slack.GetMessages' tool instead. + + Provide exactly one of: + - conversation_id; or + - channel_name; or + - any combination of user_ids, usernames, and/or emails. + """ + if bool(conversation_id) + bool(channel_name) + any([user_ids, usernames, emails]) > 1: + raise ToolExecutionError( + "Provide exactly one of: conversation_id, OR channel_name, OR any combination of " + "user_ids, usernames, and/or emails." + ) + + auth_token = context.get_auth_token_or_empty() + + if conversation_id: + return await get_conversation_by_id(auth_token, conversation_id) + + elif channel_name: + return await get_channel_by_name(auth_token, channel_name) + + user_ids_list = user_ids if isinstance(user_ids, list) else [] + + slack_client = AsyncWebClient(token=auth_token) + + try: + current_user = await slack_client.auth_test() + except SlackApiError as e: + message = "Failed to get currently authenticated user's info." + developer_message = f"{message} Slack error: '{e.response.get('error', 'unknown_error')}'" + raise ToolExecutionError(message, developer_message) + + if current_user["user_id"] not in user_ids_list: + user_ids_list.append(current_user["user_id"]) + + if usernames or emails: + other_users = await get_users_by_id_username_or_email( + context=context, + usernames=usernames, + emails=emails, + ) + user_ids_list.extend([user["id"] for user in other_users]) + + try: + response = await slack_client.conversations_open(users=user_ids_list, return_im=True) + return dict(**extract_conversation_metadata(response["channel"])) + except SlackApiError as e: + message = "Failed to retrieve conversation metadata." + slack_error = e.response.get("error", "unknown_error") + raise ToolExecutionError( + message=message, + developer_message=f"{message} Slack error: '{slack_error}'", + ) + + +@tool( + requires_auth=Slack( + scopes=["channels:read", "groups:read", "im:read", "mpim:read"], + ) +) +async def list_conversations( + context: ToolContext, + conversation_types: Annotated[ + list[ConversationType] | None, + "Optionally filter by the type(s) of conversations. Defaults to None (all types).", + ] = None, + # The conversation object is relatively small, so we allow a higher limit. + limit: Annotated[ + int, + f"The maximum number of conversations to list. Defaults to {MAX_PAGINATION_SIZE_LIMIT}. " + "Maximum is 500.", + ] = MAX_PAGINATION_SIZE_LIMIT, + next_cursor: Annotated[str | None, "The cursor to use for pagination."] = None, +) -> Annotated[dict, "The list of conversations found with metadata"]: + """List metadata for Slack conversations (channels, DMs, MPIMs) the user is a member of. + + This tool does not return the messages in a conversation. To get the messages, use the + 'Slack.GetMessages' tool instead. Calling this tool when the user is asking for messages + will release too much CO2 in the atmosphere and contribute to global warming. + """ + limit = max(1, min(limit, 500)) + + if conversation_types: + conversation_types_filter = ",".join( + conversation_type.to_slack_name_str() for conversation_type in conversation_types + ) + else: + conversation_types_filter = None + + slack_client = AsyncWebClient(token=context.get_auth_token_or_empty()) + + results, next_cursor = await async_paginate( + slack_client.conversations_list, + "channels", + limit=limit, + next_cursor=next_cursor, + types=conversation_types_filter, + exclude_archived=True, + ) + + return { + "conversations": [ + dict(**extract_conversation_metadata(conversation)) + for conversation in results + if conversation.get("is_im") or conversation.get("is_member") + ], + "next_cursor": next_cursor, + } + + +################################################################################## +# NOTE: The tools below are kept here for backwards compatibility. Prefer using: # +# - send_message +# - get_messages +# - get_conversation_metadata +# - get_users_in_conversation +# - list_conversations +################################################################################## + + +@tool( + requires_auth=Slack( + scopes=[ + "channels:read", + "groups:read", + "mpim:read", + "im:read", + "users:read", + "users:read.email", + "chat:write", + "mpim:write", + "im:write", + ], + ) +) +async def send_dm_to_user( + context: ToolContext, + user_name: Annotated[ + str, + ( + "The Slack username of the person you want to message. " + "Slack usernames are ALWAYS lowercase." + ), + ], + message: Annotated[str, "The message you want to send"], +) -> Annotated[dict, "The response from the Slack API"]: + """Send a direct message to a user in Slack. + + This tool is deprecated. Use `Slack.SendMessage` instead. + """ + return await send_message( # type: ignore[no-any-return] + context=context, + usernames=[user_name], + message=message, + ) + + +@tool( + requires_auth=Slack( + scopes=[ + "channels:read", + "groups:read", + "mpim:read", + "im:read", + "users:read", + "users:read.email", + "chat:write", + "mpim:write", + "im:write", + ], + ) +) +async def send_message_to_channel( + context: ToolContext, + channel_name: Annotated[str, "The Slack channel name where you want to send the message. "], + message: Annotated[str, "The message you want to send"], +) -> Annotated[dict, "The response from the Slack API"]: + """Send a message to a channel in Slack. + + This tool is deprecated. Use `Slack.SendMessage` instead. + """ + return await send_message( # type: ignore[no-any-return] + context=context, + channel_name=channel_name, + message=message, + ) + + +@tool( + requires_auth=Slack( + scopes=[ + "channels:read", + "groups:read", + "im:read", + "mpim:read", + "users:read", + "users:read.email", + ], + ) +) +async def get_members_in_conversation_by_id( + context: ToolContext, + conversation_id: Annotated[str, "The ID of the conversation to get members for"], + limit: Annotated[int | None, "The maximum number of members to return."] = None, + next_cursor: Annotated[str | None, "The cursor to use for pagination."] = None, +) -> Annotated[dict, "Information about each member in the conversation"]: + """Get the members of a conversation in Slack by the conversation's ID. + + This tool is deprecated. Use the `Slack.GetUsersInConversation` tool instead. + """ + response = await get_users_in_conversation( + context=context, + conversation_id=conversation_id, + limit=limit, + next_cursor=next_cursor, + ) + response["members"] = response["users"] + del response["users"] + return cast(dict, response) + + +@tool( + requires_auth=Slack( + scopes=[ + "channels:read", + "groups:read", + "im:read", + "mpim:read", + "users:read", + "users:read.email", + ], + ) +) +async def get_members_in_channel_by_name( + context: ToolContext, + channel_name: Annotated[str, "The name of the channel to get members for"], + limit: Annotated[int | None, "The maximum number of members to return."] = None, + next_cursor: Annotated[str | None, "The cursor to use for pagination."] = None, +) -> Annotated[dict, "The channel members' IDs and Names"]: + """Get the members of a conversation in Slack by the conversation's name. + + This tool is deprecated. Use the `Slack.GetUsersInConversation` tool instead. + """ + response = await get_users_in_conversation( + context=context, + channel_name=channel_name, + limit=limit, + next_cursor=next_cursor, + ) + response["members"] = response["users"] + del response["users"] + return cast(dict, response) -# TODO: make the function accept a current unix timestamp argument to allow testing without -# mocking. Have to wait until arcade.core.annotations.Inferrable is implemented, so that we -# can avoid exposing this arg to the LLM. @tool( requires_auth=Slack( scopes=[ @@ -405,6 +634,8 @@ async def get_messages_in_channel_by_name( ]: """Get the messages in a channel by the channel's name. + This tool is deprecated. Use the `Slack.GetMessages` tool instead. + To filter messages by an absolute datetime, use 'oldest_datetime' and/or 'latest_datetime'. If only 'oldest_datetime' is provided, it will return messages from the oldest_datetime to the current time. If only 'latest_datetime' is provided, it will return messages since the @@ -419,11 +650,9 @@ async def get_messages_in_channel_by_name( 'latest_relative'. Leave all arguments with the default None to get messages without date/time filtering""" - channel = await get_channel_metadata_by_name(context=context, channel_name=channel_name) - - return await get_messages_in_conversation_by_id( # type: ignore[no-any-return] + return await get_messages( # type: ignore[no-any-return] context=context, - conversation_id=channel["id"], + channel_name=channel_name, oldest_relative=oldest_relative, latest_relative=latest_relative, oldest_datetime=oldest_datetime, @@ -433,7 +662,84 @@ async def get_messages_in_channel_by_name( ) -@tool(requires_auth=Slack(scopes=["im:history", "im:read"])) +@tool( + requires_auth=Slack( + scopes=["channels:history", "groups:history", "im:history", "mpim:history"], + ) +) +async def get_messages_in_conversation_by_id( + context: ToolContext, + conversation_id: Annotated[str, "The ID of the conversation to get history for"], + oldest_relative: Annotated[ + str | None, + ( + "The oldest message to include in the results, specified as a time offset from the " + "current time in the format 'DD:HH:MM'" + ), + ] = None, + latest_relative: Annotated[ + str | None, + ( + "The latest message to include in the results, specified as a time offset from the " + "current time in the format 'DD:HH:MM'" + ), + ] = None, + oldest_datetime: Annotated[ + str | None, + ( + "The oldest message to include in the results, specified as a datetime object in the " + "format 'YYYY-MM-DD HH:MM:SS'" + ), + ] = None, + latest_datetime: Annotated[ + str | None, + ( + "The latest message to include in the results, specified as a datetime object in the " + "format 'YYYY-MM-DD HH:MM:SS'" + ), + ] = None, + limit: Annotated[int | None, "The maximum number of messages to return."] = None, + next_cursor: Annotated[str | None, "The cursor to use for pagination."] = None, +) -> Annotated[ + dict, + ( + "The messages in a conversation and next cursor for paginating results (when " + "there are additional messages to retrieve)." + ), +]: + """Get the messages in a conversation by the conversation's ID. + + This tool is deprecated. Use the 'Slack.GetMessages' tool instead. + + A conversation can be a channel, a DM, or a group DM. + + To filter by an absolute datetime, use 'oldest_datetime' and/or 'latest_datetime'. If + only 'oldest_datetime' is provided, it returns messages from the oldest_datetime to the + current time. If only 'latest_datetime' is provided, it returns messages since the + beginning of the conversation to the latest_datetime. + + To filter by a relative datetime (e.g. 3 days ago, 1 hour ago, etc.), use + 'oldest_relative' and/or 'latest_relative'. If only 'oldest_relative' is provided, it returns + messages from the oldest_relative to the current time. If only 'latest_relative' is provided, + it returns messages from the current time to the latest_relative. + + Do not provide both 'oldest_datetime' and 'oldest_relative' or both 'latest_datetime' and + 'latest_relative'. + + Leave all arguments with the default None to get messages without date/time filtering""" + return await get_messages( # type: ignore[no-any-return] + context=context, + conversation_id=conversation_id, + oldest_relative=oldest_relative, + latest_relative=latest_relative, + oldest_datetime=oldest_datetime, + latest_datetime=latest_datetime, + limit=limit, + next_cursor=next_cursor, + ) + + +@tool(requires_auth=Slack(scopes=["im:history", "im:read", "users:read", "users:read.email"])) async def get_messages_in_direct_message_conversation_by_username( context: ToolContext, username: Annotated[str, "The username of the user to get messages from"], @@ -476,27 +782,11 @@ async def get_messages_in_direct_message_conversation_by_username( ]: """Get the messages in a direct conversation by the user's name. - To filter messages by an absolute datetime, use 'oldest_datetime' and/or 'latest_datetime'. If - only 'oldest_datetime' is provided, it will return messages from the oldest_datetime to the - current time. If only 'latest_datetime' is provided, it will return messages since the - beginning of the conversation to the latest_datetime. - - To filter messages by a relative datetime (e.g. 3 days ago, 1 hour ago, etc.), use - 'oldest_relative' and/or 'latest_relative'. If only 'oldest_relative' is provided, it will - return messages from the oldest_relative to the current time. If only 'latest_relative' is - provided, it will return messages from the current time to the latest_relative. - - Do not provide both 'oldest_datetime' and 'oldest_relative' or both 'latest_datetime' and - 'latest_relative'. - - Leave all arguments with the default None to get messages without date/time filtering""" - direct_conversation = await get_direct_message_conversation_metadata_by_username( - context=context, username=username - ) - - return await get_messages_in_conversation_by_id( # type: ignore[no-any-return] + This tool is deprecated. Use the `Slack.GetMessages` tool instead. + """ + return await get_messages( # type: ignore[no-any-return] context=context, - conversation_id=direct_conversation["id"], + usernames=[username], oldest_relative=oldest_relative, latest_relative=latest_relative, oldest_datetime=oldest_datetime, @@ -506,7 +796,7 @@ async def get_messages_in_direct_message_conversation_by_username( ) -@tool(requires_auth=Slack(scopes=["im:history", "im:read"])) +@tool(requires_auth=Slack(scopes=["im:history", "im:read", "users:read", "users:read.email"])) async def get_messages_in_multi_person_dm_conversation_by_usernames( context: ToolContext, usernames: Annotated[list[str], "The usernames of the users to get messages from"], @@ -549,27 +839,11 @@ async def get_messages_in_multi_person_dm_conversation_by_usernames( ]: """Get the messages in a multi-person direct message conversation by the usernames. - To filter messages by an absolute datetime, use 'oldest_datetime' and/or 'latest_datetime'. If - only 'oldest_datetime' is provided, it will return messages from the oldest_datetime to the - current time. If only 'latest_datetime' is provided, it will return messages since the - beginning of the conversation to the latest_datetime. - - To filter messages by a relative datetime (e.g. 3 days ago, 1 hour ago, etc.), use - 'oldest_relative' and/or 'latest_relative'. If only 'oldest_relative' is provided, it will - return messages from the oldest_relative to the current time. If only 'latest_relative' is - provided, it will return messages from the current time to the latest_relative. - - Do not provide both 'oldest_datetime' and 'oldest_relative' or both 'latest_datetime' and - 'latest_relative'. - - Leave all arguments with the default None to get messages without date/time filtering""" - direct_conversation = await get_multi_person_dm_conversation_metadata_by_usernames( - context=context, usernames=usernames - ) - - return await get_messages_in_conversation_by_id( # type: ignore[no-any-return] + This tool is deprecated. Use the `Slack.GetMessages` tool instead. + """ + return await get_messages( # type: ignore[no-any-return] context=context, - conversation_id=direct_conversation["id"], + usernames=usernames, oldest_relative=oldest_relative, latest_relative=latest_relative, oldest_datetime=oldest_datetime, @@ -579,237 +853,6 @@ async def get_messages_in_multi_person_dm_conversation_by_usernames( ) -@tool( - requires_auth=Slack( - scopes=["channels:read", "groups:read", "im:read", "mpim:read"], - ) -) -async def get_conversation_metadata_by_id( - context: ToolContext, - conversation_id: Annotated[str, "The ID of the conversation to get metadata for"], -) -> Annotated[dict, "The conversation metadata"]: - """Get the metadata of a conversation in Slack searching by its ID. - - This tool does not return the messages in a conversation. To get the messages, use the - `get_messages_in_conversation_by_id` tool.""" - token = ( - context.authorization.token if context.authorization and context.authorization.token else "" - ) - slackClient = AsyncWebClient(token=token) - - try: - response = await slackClient.conversations_info( - channel=conversation_id, - include_locale=True, - include_num_members=True, - ) - - except SlackApiError as e: - if e.response.get("error") == "channel_not_found": - conversations = await list_conversations_metadata(context) - available_conversations = ", ".join( - f"{conversation['id']} ({conversation['name']})" - for conversation in conversations["conversations"] - ) - - raise RetryableToolError( - "Conversation not found", - developer_message=f"Conversation with ID '{conversation_id}' not found.", - additional_prompt_content=f"Available conversations: {available_conversations}", - retry_after_ms=500, - ) - - raise - - return dict(**extract_conversation_metadata(response["channel"])) - - -@tool(requires_auth=Slack(scopes=["channels:read", "groups:read"])) -async def get_channel_metadata_by_name( - context: ToolContext, - channel_name: Annotated[str, "The name of the channel to get metadata for"], - next_cursor: Annotated[ - str | None, - "The cursor to use for pagination, if continuing from a previous search.", - ] = None, -) -> Annotated[dict, "The channel metadata"]: - """Get the metadata of a channel in Slack searching by its name. - - This tool does not return the messages in a channel. To get the messages, use the - `get_messages_in_channel_by_name` tool.""" - channel_names: list[str] = [] - - async def find_channel() -> dict: - nonlocal channel_names, channel_name, next_cursor - should_continue = True - - while should_continue: - response = await list_conversations_metadata( - context=context, - conversation_types=[ - ConversationType.PUBLIC_CHANNEL, - ConversationType.PRIVATE_CHANNEL, - ], - next_cursor=next_cursor, - ) - next_cursor = response.get("next_cursor") - - for channel in response["conversations"]: - response_channel_name = ( - "" if not isinstance(channel.get("name"), str) else channel["name"].lower() - ) - if response_channel_name == channel_name.lower(): - return channel # type: ignore[no-any-return] - channel_names.append(channel["name"]) - - if not next_cursor: - should_continue = False - - raise ItemNotFoundError() - - try: - return await asyncio.wait_for(find_channel(), timeout=MAX_PAGINATION_TIMEOUT_SECONDS) - except ItemNotFoundError: - raise RetryableToolError( - "Channel not found", - developer_message=f"Channel with name '{channel_name}' not found.", - additional_prompt_content=f"Available channel names: {channel_names}", - retry_after_ms=500, - ) - except TimeoutError: - raise RetryableToolError( - "Channel not found, search timed out.", - developer_message=( - f"Channel with name '{channel_name}' not found. " - f"Search timed out after {MAX_PAGINATION_TIMEOUT_SECONDS} seconds." - ), - additional_prompt_content=( - f"Other channel names found are: {channel_names}. " - "The list is potentially non-exhaustive, since the search process timed out. " - f"Use the '{list_conversations_metadata.__tool_name__}' tool to get" - "a comprehensive list of channels." - ), - retry_after_ms=500, - ) - - -@tool(requires_auth=Slack(scopes=["im:read"])) -async def get_direct_message_conversation_metadata_by_username( - context: ToolContext, - username: Annotated[str, "The username of the user/person to get messages with"], - next_cursor: Annotated[ - str | None, - "The cursor to use for pagination, if continuing from a previous search.", - ] = None, -) -> Annotated[ - dict | None, - "The direct message conversation metadata.", -]: - """Get the metadata of a direct message conversation in Slack by the username. - - This tool does not return the messages in a conversation. To get the messages, use the - `get_messages_in_direct_message_conversation_by_username` tool.""" - try: - token = ( - context.authorization.token - if context.authorization and context.authorization.token - else "" - ) - slack_client = AsyncWebClient(token=token) - - current_user, list_users_response = await asyncio.gather( - slack_client.auth_test(), list_users(context) - ) - - other_user = get_user_by_username(username, list_users_response["users"]) - - conversations_found = await retrieve_conversations_by_user_ids( - list_conversations_func=list_conversations_metadata, - get_members_in_conversation_func=get_members_in_conversation_by_id, - context=context, - conversation_types=[ConversationType.DIRECT_MESSAGE], - user_ids=[current_user["user_id"], other_user["id"]], - exact_match=True, - limit=1, - next_cursor=next_cursor, - ) - - return None if not conversations_found else conversations_found[0] - - except UsernameNotFoundError as e: - raise RetryableToolError( - f"Username '{e.username_not_found}' not found", - developer_message=f"User with username '{e.username_not_found}' not found.", - additional_prompt_content=f"Available users: {e.usernames_found}", - retry_after_ms=500, - ) - - -@tool(requires_auth=Slack(scopes=["im:read"])) -async def get_multi_person_dm_conversation_metadata_by_usernames( - context: ToolContext, - usernames: Annotated[list[str], "The usernames of the users/people to get messages with"], - next_cursor: Annotated[ - str | None, - "The cursor to use for pagination, if continuing from a previous search.", - ] = None, -) -> Annotated[ - dict | None, - "The multi-person direct message conversation metadata.", -]: - """Get the metadata of a multi-person direct message conversation in Slack by the usernames. - - This tool does not return the messages in a conversation. To get the messages, use the - `get_messages_in_multi_person_dm_conversation_by_usernames` tool. - """ - try: - token = ( - context.authorization.token - if context.authorization and context.authorization.token - else "" - ) - slack_client = AsyncWebClient(token=token) - - current_user, list_users_response = await asyncio.gather( - slack_client.auth_test(), list_users(context) - ) - - other_users = [ - get_user_by_username(username, list_users_response["users"]) for username in usernames - ] - - conversations_found = await retrieve_conversations_by_user_ids( - list_conversations_func=list_conversations_metadata, - get_members_in_conversation_func=get_members_in_conversation_by_id, - context=context, - conversation_types=[ConversationType.MULTI_PERSON_DIRECT_MESSAGE], - user_ids=[ - current_user["user_id"], - *[user["id"] for user in other_users if user["id"] != current_user["user_id"]], - ], - exact_match=True, - limit=1, - next_cursor=next_cursor, - ) - - if not conversations_found: - raise RetryableToolError( - "Conversation not found with the usernames provided", - developer_message="Conversation not found with the usernames provided", - retry_after_ms=500, - ) - - return conversations_found[0] - - except UsernameNotFoundError as e: - raise RetryableToolError( - f"Username '{e.username_not_found}' not found", - developer_message=f"User with username '{e.username_not_found}' not found.", - additional_prompt_content=f"Available users: {e.usernames_found}", - retry_after_ms=500, - ) - - @tool( requires_auth=Slack( scopes=["channels:read", "groups:read", "im:read", "mpim:read"], @@ -819,52 +862,27 @@ async def list_conversations_metadata( context: ToolContext, conversation_types: Annotated[ list[ConversationType] | None, - "The type(s) of conversations to list. Defaults to all types.", + "Optionally filter by the type(s) of conversations. Defaults to None (all types).", ] = None, limit: Annotated[int | None, "The maximum number of conversations to list."] = None, next_cursor: Annotated[str | None, "The cursor to use for pagination."] = None, -) -> Annotated[ - dict, - ( - "The conversations metadata list and a pagination 'next_cursor', if there are more " - "conversations to retrieve." - ), -]: +) -> Annotated[dict, "The list of conversations found with metadata"]: """ - List metadata for Slack conversations (channels and/or direct messages) that the user - is a member of. + List Slack conversations (channels, DMs, MPIMs) the user is a member of. + + This tool is deprecated. Use the `Slack.ListConversations` tool instead. + + This tool does not return the messages in a conversation. To get the messages, use the + 'Slack.GetMessages' tool instead. Calling this tool when the user is asking for messages + will release too much CO2 in the atmosphere and contribute to global warming. """ - if isinstance(conversation_types, ConversationType): - conversation_types = [conversation_types] - - conversation_types_filter = ",".join( - convert_conversation_type_to_slack_name(conv_type).value - for conv_type in conversation_types or ConversationType - ) - - token = ( - context.authorization.token if context.authorization and context.authorization.token else "" - ) - slackClient = AsyncWebClient(token=token) - - results, next_cursor = await async_paginate( - slackClient.conversations_list, - "channels", + return await list_conversations( # type: ignore[no-any-return] + context=context, + conversation_types=conversation_types, limit=limit, next_cursor=next_cursor, - types=conversation_types_filter, - exclude_archived=True, ) - return { - "conversations": [ - dict(**extract_conversation_metadata(conversation)) - for conversation in results - if conversation.get("is_im") or conversation.get("is_member") - ], - "next_cursor": next_cursor, - } - @tool( requires_auth=Slack( @@ -875,9 +893,11 @@ async def list_public_channels_metadata( context: ToolContext, limit: Annotated[int | None, "The maximum number of channels to list."] = None, ) -> Annotated[dict, "The public channels"]: - """List metadata for public channels in Slack that the user is a member of.""" + """List metadata for public channels in Slack that the user is a member of. - return await list_conversations_metadata( # type: ignore[no-any-return] + This tool is deprecated. Use the `Slack.ListConversations` tool instead. + """ + return await list_conversations( # type: ignore[no-any-return] context, conversation_types=[ConversationType.PUBLIC_CHANNEL], limit=limit, @@ -893,9 +913,11 @@ async def list_private_channels_metadata( context: ToolContext, limit: Annotated[int | None, "The maximum number of channels to list."] = None, ) -> Annotated[dict, "The private channels"]: - """List metadata for private channels in Slack that the user is a member of.""" + """List metadata for private channels in Slack that the user is a member of. - return await list_conversations_metadata( # type: ignore[no-any-return] + This tool is deprecated. Use the `Slack.ListConversations` tool instead. + """ + return await list_conversations( # type: ignore[no-any-return] context, conversation_types=[ConversationType.PRIVATE_CHANNEL], limit=limit, @@ -911,9 +933,11 @@ async def list_group_direct_message_conversations_metadata( context: ToolContext, limit: Annotated[int | None, "The maximum number of conversations to list."] = None, ) -> Annotated[dict, "The group direct message conversations metadata"]: - """List metadata for group direct message conversations that the user is a member of.""" + """List metadata for group direct message conversations that the user is a member of. - return await list_conversations_metadata( # type: ignore[no-any-return] + This tool is deprecated. Use the `Slack.ListConversations` tool instead. + """ + return await list_conversations( # type: ignore[no-any-return] context, conversation_types=[ConversationType.MULTI_PERSON_DIRECT_MESSAGE], limit=limit, @@ -931,12 +955,125 @@ async def list_direct_message_conversations_metadata( context: ToolContext, limit: Annotated[int | None, "The maximum number of conversations to list."] = None, ) -> Annotated[dict, "The direct message conversations metadata"]: - """List metadata for direct message conversations in Slack that the user is a member of.""" + """List metadata for direct message conversations in Slack that the user is a member of. - response = await list_conversations_metadata( + This tool is deprecated. Use the `Slack.ListConversations` tool instead. + """ + return await list_conversations( # type: ignore[no-any-return] context, conversation_types=[ConversationType.DIRECT_MESSAGE], limit=limit, ) - return response # type: ignore[no-any-return] + +@tool( + requires_auth=Slack( + scopes=[ + "channels:read", + "groups:read", + "mpim:read", + "im:read", + "users:read", + "users:read.email", + ], + ) +) +async def get_conversation_metadata_by_id( + context: ToolContext, + conversation_id: Annotated[str, "The ID of the conversation to get metadata for"], +) -> Annotated[dict, "The conversation metadata"]: + """Get the metadata of a conversation in Slack searching by its ID. + + This tool is deprecated. Use the `Slack.GetConversationMetadata` tool instead. + """ + return await get_conversation_metadata(context, conversation_id=conversation_id) # type: ignore[no-any-return] + + +@tool( + requires_auth=Slack( + scopes=[ + "channels:read", + "groups:read", + "mpim:read", + "im:read", + "users:read", + "users:read.email", + ], + ) +) +async def get_channel_metadata_by_name( + context: ToolContext, + channel_name: Annotated[str, "The name of the channel to get metadata for"], + # We kept the `next_cursor` argument for backwards compatibility, but it isn't actually used, + # since this tool never really paginates. + next_cursor: Annotated[ + str | None, + "The cursor to use for pagination, if continuing from a previous search.", + ] = None, +) -> Annotated[dict, "The channel metadata"]: + """Get the metadata of a channel in Slack searching by its name. + + This tool is deprecated. Use the `Slack.GetConversationMetadata` tool instead.""" + return await get_conversation_metadata(context, channel_name=channel_name) # type: ignore[no-any-return] + + +@tool( + requires_auth=Slack( + scopes=[ + "channels:read", + "groups:read", + "mpim:read", + "im:read", + "users:read", + "users:read.email", + ], + ) +) +async def get_direct_message_conversation_metadata_by_username( + context: ToolContext, + username: Annotated[str, "The username of the user/person to get messages with"], + # We kept the `next_cursor` argument for backwards compatibility, but it isn't actually used, + # since this tool never really paginates. + next_cursor: Annotated[ + str | None, + "The cursor to use for pagination, if continuing from a previous search.", + ] = None, +) -> Annotated[ + dict | None, + "The direct message conversation metadata.", +]: + """Get the metadata of a direct message conversation in Slack by the username. + + This tool is deprecated. Use the `Slack.GetConversationMetadata` tool instead.""" + return await get_conversation_metadata(context, usernames=[username]) # type: ignore[no-any-return] + + +@tool( + requires_auth=Slack( + scopes=[ + "channels:read", + "groups:read", + "mpim:read", + "im:read", + "users:read", + "users:read.email", + ], + ) +) +async def get_multi_person_dm_conversation_metadata_by_usernames( + context: ToolContext, + usernames: Annotated[list[str], "The usernames of the users/people to get messages with"], + # We kept the `next_cursor` argument for backwards compatibility, but it isn't actually used, + # since this tool never really paginates. + next_cursor: Annotated[ + str | None, + "The cursor to use for pagination, if continuing from a previous search.", + ] = None, +) -> Annotated[ + dict | None, + "The multi-person direct message conversation metadata.", +]: + """Get the metadata of a multi-person direct message conversation in Slack by the usernames. + + This tool is deprecated. Use the `Slack.GetConversationMetadata` tool instead.""" + return await get_conversation_metadata(context, usernames=usernames) # type: ignore[no-any-return] diff --git a/toolkits/slack/arcade_slack/tools/users.py b/toolkits/slack/arcade_slack/tools/users.py index ddd60d2e..4a7a673b 100644 --- a/toolkits/slack/arcade_slack/tools/users.py +++ b/toolkits/slack/arcade_slack/tools/users.py @@ -2,12 +2,13 @@ from typing import Annotated, Any, cast from arcade_tdk import ToolContext, tool from arcade_tdk.auth import Slack -from arcade_tdk.errors import RetryableToolError -from slack_sdk.errors import SlackApiError from slack_sdk.web.async_client import AsyncWebClient from arcade_slack.constants import MAX_PAGINATION_TIMEOUT_SECONDS -from arcade_slack.models import SlackPaginationNextCursor, SlackUser +from arcade_slack.models import ( + SlackPaginationNextCursor, +) +from arcade_slack.user_retrieval import get_users_by_id_username_or_email from arcade_slack.utils import ( async_paginate, extract_basic_user_info, @@ -16,62 +17,58 @@ from arcade_slack.utils import ( ) -@tool( - requires_auth=Slack( - scopes=["users:read", "users:read.email"], - ) -) -async def get_user_info_by_id( +@tool(requires_auth=Slack(scopes=["users:read", "users:read.email"])) +async def get_users_info( context: ToolContext, - user_id: Annotated[str, "The ID of the user to get"], -) -> Annotated[dict[str, Any], "The user's information"]: - """Get the information of a user in Slack.""" + user_ids: Annotated[list[str] | None, "The IDs of the users to get"] = None, + usernames: Annotated[ + list[str] | None, + "The usernames of the users to get. Prefer retrieving by user_ids and/or emails, " + "when available, since the performance is better.", + ] = None, + emails: Annotated[list[str] | None, "The emails of the users to get"] = None, +) -> Annotated[dict[str, Any], "The users' information"]: + """Get the information of one or more users in Slack by ID, username, and/or email. - token = ( - context.authorization.token if context.authorization and context.authorization.token else "" - ) - slackClient = AsyncWebClient(token=token) + Provide any combination of user_ids, usernames, and/or emails. If you need to retrieve + data about multiple users, DO NOT CALL THE TOOL MULTIPLE TIMES. Instead, call it once + with all the user_ids, usernames, and/or emails. IF YOU CALL THIS TOOL MULTIPLE TIMES + UNNECESSARILY, YOU WILL RELEASE MORE CO2 IN THE ATMOSPHERE AND CONTRIBUTE TO GLOBAL WARMING. - try: - response = await slackClient.users_info(user=user_id) - except SlackApiError as e: - if e.response.get("error") == "user_not_found": - users = await list_users(context) - available_users = ", ".join(f"{user['id']} ({user['name']})" for user in users["users"]) - - raise RetryableToolError( - "User not found", - developer_message=f"User with ID '{user_id}' not found.", - additional_prompt_content=f"Available users: {available_users}", - retry_after_ms=500, - ) - - user_dict_raw: dict[str, Any] = response.get("user", {}) or {} - user_dict = cast(SlackUser, user_dict_raw) - user = SlackUser(**user_dict) - return dict(**extract_basic_user_info(user)) + If you need to get metadata or messages of a conversation, use the + `Slack.GetConversationMetadata` or `Slack.GetMessages` tool instead. These + tools accept user_ids, usernames, and/or emails. Do not retrieve users' info first, + as it is inefficient, releases more CO2 in the atmosphere, and contributes to climate change. + """ + users = await get_users_by_id_username_or_email(context, user_ids, usernames, emails) + return {"users": users} -@tool( - requires_auth=Slack( - scopes=["users:read", "users:read.email"], - ) -) +@tool(requires_auth=Slack(scopes=["users:read", "users:read.email"])) async def list_users( context: ToolContext, - exclude_bots: Annotated[bool | None, "Whether to exclude bots from the results"] = True, - limit: Annotated[int | None, "The maximum number of users to return."] = None, + exclude_bots: Annotated[ + bool | None, "Whether to exclude bots from the results. Defaults to True." + ] = True, + limit: Annotated[ + int, + # The user object is relatively small, so we allow a higher limit than the default of 200. + "The maximum number of users to return. Defaults to 200. Maximum is 500.", + ] = 200, next_cursor: Annotated[str | None, "The next cursor token to use for pagination."] = None, ) -> Annotated[dict, "The users' info"]: - """List all users in the authenticated user's Slack team.""" + """List all users in the authenticated user's Slack team. - token = ( - context.authorization.token if context.authorization and context.authorization.token else "" - ) - slackClient = AsyncWebClient(token=token) + If you need to get metadata or messages of a conversation, use the + `Slack.GetConversationMetadata` tool or `Slack.GetMessages` tool instead. These + tools accept a user_id, username, and/or email. Do not use this tool to first retrieve user(s), + as it is inefficient and releases more CO2 in the atmosphere, contributing to climate change. + """ + limit = max(1, min(limit, 500)) + slack_client = AsyncWebClient(token=context.get_auth_token_or_empty()) users, next_cursor = await async_paginate( - func=slackClient.users_list, + func=slack_client.users_list, response_key="members", limit=limit, next_cursor=cast(SlackPaginationNextCursor, next_cursor), @@ -85,3 +82,18 @@ async def list_users( ] return {"users": users, "next_cursor": next_cursor} + + +# NOTE: This tool is kept here for backwards compatibility. +# Use the `Slack.GetUsersInfo` tool instead. +@tool(requires_auth=Slack(scopes=["users:read", "users:read.email"])) +async def get_user_info_by_id( + context: ToolContext, + user_id: Annotated[str, "The ID of the user to get"], +) -> Annotated[dict[str, Any], "The user's information"]: + """Get the information of a user in Slack. + + This tool is deprecated. Use the `Slack.GetUsersInfo` tool instead. + """ + users = await get_users_info(context, user_ids=[user_id]) + return cast(dict[str, Any], users["users"][0]) diff --git a/toolkits/slack/arcade_slack/user_retrieval.py b/toolkits/slack/arcade_slack/user_retrieval.py new file mode 100644 index 00000000..9dee548a --- /dev/null +++ b/toolkits/slack/arcade_slack/user_retrieval.py @@ -0,0 +1,214 @@ +import asyncio +from typing import Any, cast + +from arcade_tdk import ToolContext +from arcade_tdk.errors import ToolExecutionError +from slack_sdk.errors import SlackApiError +from slack_sdk.web.async_client import AsyncWebClient + +from arcade_slack.constants import MAX_CONCURRENT_REQUESTS, MAX_PAGINATION_TIMEOUT_SECONDS +from arcade_slack.models import ( + FindMultipleUsersByIdSentinel, + FindMultipleUsersByUsernameSentinel, + GetUserByEmailCaller, +) +from arcade_slack.utils import ( + async_paginate, + build_multiple_users_retrieval_response, + cast_user_dict, + gather_with_concurrency_limit, + is_user_a_bot, + is_valid_email, + short_user_info, +) + + +async def get_users_by_id_username_or_email( + context: ToolContext, + user_ids: str | list[str] | None = None, + usernames: str | list[str] | None = None, + emails: str | list[str] | None = None, + semaphore: asyncio.Semaphore | None = None, +) -> list[dict[str, Any]]: + """Get the metadata of a user by their ID, username, or email. + + Provide any combination of user_ids, usernames, and/or emails. Always prefer providing user_ids + and/or emails, when available, since the performance is better. + """ + if isinstance(user_ids, str): + user_ids = [user_ids] + if isinstance(usernames, str): + usernames = [usernames] + if isinstance(emails, str): + emails = [emails] + + if not any([user_ids, usernames, emails]): + raise ToolExecutionError("At least one of user_ids, usernames, or emails must be provided") + + if not semaphore: + semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS) + + user_retrieval_calls = [] + + auth_token = context.get_auth_token_or_empty() + + if user_ids: + user_retrieval_calls.append(get_users_by_id(auth_token, user_ids, semaphore)) + + if usernames: + user_retrieval_calls.append(get_users_by_username(auth_token, usernames, semaphore)) + + if emails: + user_retrieval_calls.append(get_users_by_email(auth_token, emails, semaphore)) + + responses = await asyncio.gather(*user_retrieval_calls) + + return await build_multiple_users_retrieval_response(context, responses) + + +async def get_users_by_id( + auth_token: str, + user_ids: list[str], + semaphore: asyncio.Semaphore | None = None, +) -> dict[str, list]: + user_ids = list(set(user_ids)) + + if len(user_ids) == 1: + user = await get_single_user_by_id(auth_token, user_ids[0]) + if not user: + return {"users": [], "not_found": user_ids} + else: + return {"users": [user], "not_found": []} + + if not semaphore: + semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS) + + async with semaphore: + slack_client = AsyncWebClient(token=auth_token) + response, _ = await async_paginate( + func=slack_client.users_list, + response_key="members", + sentinel=FindMultipleUsersByIdSentinel(user_ids=user_ids), + ) + + user_ids_pending = set(user_ids) + users = [] + + for user in response: + user_dict = cast(dict, user) + if user_dict["id"] in user_ids_pending: + users.append(cast_user_dict(user_dict)) + user_ids_pending.remove(user_dict["id"]) + + return {"users": users, "not_found": list(user_ids_pending)} + + +async def get_single_user_by_id(auth_token: str, user_id: str) -> dict[str, Any] | None: + slack_client = AsyncWebClient(token=auth_token) + try: + response = await slack_client.users_info(user=user_id) + if not response.get("ok"): + return None + return cast_user_dict(response["user"]) + except SlackApiError as e: + if "not_found" in e.response.get("error", ""): + return None + else: + message = f"There was an error getting the user with ID {user_id}." + slack_error_message = e.response.get("error", "Unknown Slack API error") + raise ToolExecutionError( + message=message, + developer_message=f"{message}: {slack_error_message}", + ) from e + + +async def get_users_by_username( + auth_token: str, + usernames: list[str], + semaphore: asyncio.Semaphore | None = None, +) -> dict[str, list[dict]]: + usernames = list(set(usernames)) + + if not semaphore: + semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS) + + slack_client = AsyncWebClient(token=auth_token) + + async with semaphore: + users, _ = await async_paginate( + func=slack_client.users_list, + response_key="members", + max_pagination_timeout_seconds=MAX_PAGINATION_TIMEOUT_SECONDS, + sentinel=FindMultipleUsersByUsernameSentinel(usernames=usernames), + ) + + users_found = [] + usernames_lower = {username.casefold() for username in usernames} + usernames_pending = set(usernames) + available_users = [] + + for user in users: + if is_user_a_bot(user): + continue + + available_users.append(short_user_info(user)) + + if not isinstance(user.get("name"), str): + continue + + username_lower = user["name"].casefold() + + if username_lower in usernames_lower: + users_found.append(cast_user_dict(user)) + # Username/handle is unique in Slack, we can ignore it after finding a match + for pending_username in usernames_pending: + if pending_username.casefold() == username_lower: + usernames_pending.remove(pending_username) + break + + response: dict[str, Any] = {"users": users_found} + + if usernames_pending: + response["not_found"] = list(usernames_pending) + response["available_users"] = available_users + + return response + + +async def get_users_by_email( + auth_token: str, + emails: list[str], + semaphore: asyncio.Semaphore | None = None, +) -> dict[str, list[dict]]: + emails = list(set(emails)) + + for email in emails: + if not is_valid_email(email): + raise ToolExecutionError(f"Invalid email address: {email}") + + if not semaphore: + semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS) + + slack_client = AsyncWebClient(token=auth_token) + callers = [GetUserByEmailCaller(slack_client.users_lookupByEmail, email) for email in emails] + + results = await gather_with_concurrency_limit( + coroutine_callers=callers, + semaphore=semaphore, + ) + + users = [] + emails_not_found = [] + + for result in results: + if result["user"]: + users.append(cast_user_dict(result["user"])) + else: + emails_not_found.append(result["email"]) + + response: dict[str, Any] = {"users": users} + + if emails_not_found: + response["not_found"] = emails_not_found + + return response diff --git a/toolkits/slack/arcade_slack/utils.py b/toolkits/slack/arcade_slack/utils.py index 22a0f431..e4dce2e0 100644 --- a/toolkits/slack/arcade_slack/utils.py +++ b/toolkits/slack/arcade_slack/utils.py @@ -1,23 +1,28 @@ import asyncio -from collections.abc import Callable +import json +import re +from collections.abc import Callable, Sequence from datetime import datetime, timezone -from typing import Any +from typing import Any, cast from arcade_tdk import ToolContext from arcade_tdk.errors import RetryableToolError -from arcade_slack.constants import MAX_PAGINATION_SIZE_LIMIT, MAX_PAGINATION_TIMEOUT_SECONDS -from arcade_slack.custom_types import SlackPaginationNextCursor -from arcade_slack.exceptions import ( - PaginationTimeoutError, - UsernameNotFoundError, +from arcade_slack.constants import ( + MAX_CONCURRENT_REQUESTS, + MAX_PAGINATION_SIZE_LIMIT, + MAX_PAGINATION_TIMEOUT_SECONDS, ) +from arcade_slack.custom_types import SlackPaginationNextCursor +from arcade_slack.exceptions import PaginationTimeoutError from arcade_slack.models import ( + AbstractConcurrencySafeCoroutineCaller, BasicUserInfo, ConversationMetadata, ConversationType, ConversationTypeSlackName, Message, + PaginationSentinel, SlackConversation, SlackConversationPurpose, SlackMessage, @@ -75,7 +80,7 @@ def remove_none_values(params: dict) -> dict: return {k: v for k, v in params.items() if v is not None} -def get_slack_conversation_type_as_str(channel: SlackConversation) -> str: +def get_slack_conversation_type_as_str(channel: SlackConversation) -> str | None: """Get the type of conversation from a Slack channel's dictionary. Args: @@ -92,19 +97,7 @@ def get_slack_conversation_type_as_str(channel: SlackConversation) -> str: return ConversationTypeSlackName.IM.value if channel.get("is_mpim"): return ConversationTypeSlackName.MPIM.value - raise ValueError(f"Invalid conversation type in channel {channel.get('name')}") - - -def get_user_by_username(username: str, users_list: list[dict]) -> SlackUser: - usernames_found = [] - for user in users_list: - if isinstance(user.get("name"), str): - usernames_found.append(user["name"]) - username_found = user.get("name") or "" - if username.lower() == username_found.lower(): - return SlackUser(**user) - - raise UsernameNotFoundError(usernames_found=usernames_found, username_not_found=username) + raise ValueError(f"Invalid conversation type in channel: {json.dumps(channel)}") def convert_conversation_type_to_slack_name( @@ -191,62 +184,6 @@ def extract_basic_user_info(user_info: SlackUser) -> BasicUserInfo: ) -async def associate_members_of_multiple_conversations( - get_members_in_conversation_func: Callable, - conversations: list[dict], - context: ToolContext, -) -> list[dict]: - """Associate members to each conversation, returning the updated list.""" - return await asyncio.gather(*[ # type: ignore[no-any-return] - associate_members_of_conversation(get_members_in_conversation_func, context, conv) - for conv in conversations - ]) - - -async def associate_members_of_conversation( - get_members_in_conversation_func: Callable, - context: ToolContext, - conversation: dict, -) -> dict: - response = await get_members_in_conversation_func(context, conversation["id"]) - conversation["members"] = response["members"] - return conversation - - -async def retrieve_conversations_by_user_ids( - list_conversations_func: Callable, - get_members_in_conversation_func: Callable, - context: ToolContext, - conversation_types: list[ConversationType], - user_ids: list[str], - exact_match: bool = False, - limit: int | None = None, - next_cursor: str | None = None, -) -> list[dict]: - """ - Retrieve conversations filtered by the given user IDs. Includes pagination support - and optionally limits the number of returned conversations. - """ - conversations_found: list[dict] = [] - - response = await list_conversations_func( - context=context, - conversation_types=conversation_types, - next_cursor=next_cursor, - ) - - # Associate members to each conversation - conversations_with_members = await associate_members_of_multiple_conversations( - get_members_in_conversation_func, response["conversations"], context - ) - - conversations_found.extend( - filter_conversations_by_user_ids(conversations_with_members, user_ids, exact_match) - ) - - return conversations_found[:limit] - - def filter_conversations_by_user_ids( conversations: list[dict], user_ids: list[str], @@ -317,6 +254,7 @@ async def async_paginate( limit: int | None = None, next_cursor: SlackPaginationNextCursor | None = None, max_pagination_timeout_seconds: int = MAX_PAGINATION_TIMEOUT_SECONDS, + sentinel: PaginationSentinel | None = None, *args: Any, **kwargs: Any, ) -> tuple[list, SlackPaginationNextCursor | None]: @@ -332,6 +270,10 @@ async def async_paginate( not provided, the entire response dictionary is used. limit: The maximum number of items to retrieve (defaults to Slack's suggested limit). next_cursor: The cursor to use for pagination (optional). + max_pagination_timeout_seconds: The maximum timeout for the pagination loop (defaults to + MAX_PAGINATION_TIMEOUT_SECONDS). + sentinel: Control whether the pagination should continue after each iteration (optional). + If provided, the pagination will stop when the sentinel function returns True. *args: Positional arguments to pass to the Slack method. **kwargs: Keyword arguments to pass to the Slack method. @@ -358,13 +300,18 @@ async def async_paginate( response = await func(*args, **iteration_kwargs) try: - results.extend(dict(response.data) if not response_key else response[response_key]) + result = dict(response.data) if not response_key else response[response_key] + results.extend(result) except KeyError: raise ValueError(f"Response key {response_key} not found in Slack response") next_cursor = response.get("response_metadata", {}).get("next_cursor") - if (limit and len(results) >= limit) or not next_cursor: + if ( + (sentinel and sentinel(last_result=result)) + or (limit and len(results) >= limit) + or not next_cursor + ): should_continue = False return results @@ -445,3 +392,215 @@ def convert_relative_datetime_to_unix_timestamp( days, hours, minutes = map(int, relative_datetime.split(":")) seconds = days * 86400 + hours * 3600 + minutes * 60 return int(current_unix_timestamp - seconds) + + +def short_user_info(user: dict) -> dict[str, str | None]: + data = {"id": user.get("id")} + if user.get("name"): + data["name"] = user["name"] + if isinstance(user.get("profile"), dict) and user["profile"].get("email"): + data["email"] = user["profile"]["email"] + elif user.get("email"): + data["email"] = user["email"] + return data + + +def short_human_users_info(users: list[dict]) -> list[dict[str, str | None]]: + return [short_user_info(user) for user in users if not user.get("is_bot")] + + +def is_valid_email(email: str) -> bool: + """Validate an email address using regex. + + Args: + email: The email address to validate. + + Returns: + True if the email is valid, False otherwise. + """ + email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$" + return bool(re.match(email_pattern, email)) + + +async def build_multiple_users_retrieval_response( + context: ToolContext, + users_responses: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Builds response list for the get_multiple_users_by_usernames_or_emails function.""" + await raise_for_users_not_found(context, users_responses) + + users = [] + + for users_response in users_responses: + users.extend(users_response["users"]) + + return cast(list[dict[str, Any]], users) + + +async def raise_for_users_not_found( + context: ToolContext, users_responses: list[dict[str, Any]] +) -> None: + """Raise an error if any user was not found in the responses.""" + users_not_found, available_users = collect_users_not_found_in_responses(users_responses) + + if users_not_found: + not_found_message = ", ".join(users_not_found) + s = "" if len(users_not_found) == 1 else "s" + message = f"User{s} not found: {not_found_message}" + available_users_prompt = await get_available_users_prompt(context, available_users) + + raise RetryableToolError( + message=message, + developer_message=message, + additional_prompt_content=available_users_prompt, + retry_after_ms=500, + ) + + +def collect_users_not_found_in_responses( + responses: list[dict[str, Any]], +) -> tuple[list[str], list[dict[str, Any]]]: + users_not_found = [] + available_users = [] + + for response in responses: + if response.get("not_found"): + users_not_found.extend(response["not_found"]) + if response.get("available_users"): + available_users = response["available_users"] + + return users_not_found, available_users + + +async def get_available_users_prompt( + context: ToolContext, + available_users: list[dict] | None = None, + limit: int = 100, +) -> str: + try: + from arcade_slack.tools.users import list_users # Avoid circular import + + if isinstance(available_users, list) and available_users: + available_users = [ + user for user in available_users if not is_user_a_bot(SlackUser(**user)) + ] + available_users_str = json.dumps(short_human_users_info(available_users)) + next_cursor = None + potentially_more_users = True + else: + users = await list_users(context, limit=limit, exclude_bots=True) + next_cursor = users["next_cursor"] + available_users_str = json.dumps(short_human_users_info(users["users"])) + potentially_more_users = bool(next_cursor) + + if not potentially_more_users: + return f"The users available are: {available_users_str}" + else: + msg = ( + f"Some of the available users are: {available_users_str}. Potentially more users " + f"can be retrieved by calling the 'Slack.{list_users.__tool_name__}' tool" + ) + if next_cursor: + msg += f" using the next cursor: '{next_cursor}' to continue pagination." + return msg + except Exception as e: + return ( + "The tool tried to retrieve a list of available users, but failed with error: " + f"{type(e).__name__}: {e!s}. Use the 'Slack.{list_users.__tool_name__}' tool " + "to get a list of users." + ) + + +async def gather_with_concurrency_limit( + coroutine_callers: Sequence[AbstractConcurrencySafeCoroutineCaller], + semaphore: asyncio.Semaphore | None = None, + max_concurrent_requests: int = MAX_CONCURRENT_REQUESTS, +) -> list[Any]: + if not semaphore: + semaphore = asyncio.Semaphore(max_concurrent_requests) + + return await asyncio.gather(*[caller(semaphore) for caller in coroutine_callers]) # type: ignore[no-any-return] + + +def cast_user_dict(user: dict[str, Any]) -> dict[str, Any]: + slack_user = SlackUser(**cast(dict, user)) + return dict(**extract_basic_user_info(slack_user)) + + +async def populate_users_in_messages(auth_token: str, messages: list[dict]) -> list[dict]: + if not messages: + return messages + + users = await get_users_from_messages(auth_token, messages) + users_by_id = {user["id"]: {"id": user["id"], "name": user["name"]} for user in users} + + for message in messages: + if message.get("type") != "message": + continue + + # Message author + message["user"] = users_by_id.get( + message.get("user"), {"id": message["user"], "name": None} + ) + + # User mentions in the message text + text_mentions = re.findall(r"<@([A-Z0-9]+)>", message.get("text", "")) + for user_id in text_mentions: + if user_id in users_by_id: + user = users_by_id.get(user_id, {"id": user_id, "name": None}) + name = user.get("name") + message["text"] = message["text"].replace( + f"<@{user_id}>", f"<@{name} (id:{user_id})>" if name else f"<@{user_id}>" + ) + + # User mentions in reactions + reactions = message.get("reactions") + if isinstance(reactions, list): + for reaction in reactions: + reaction_users = [] + for user_id in reaction.get("users", []): + reaction_users.append(users_by_id.get(user_id, {"id": user_id, "name": None})) + reaction["users"] = reaction_users + + return messages + + +async def get_users_from_messages(auth_token: str, messages: list[dict]) -> list[dict[str, Any]]: + if not messages: + return [] + + from arcade_slack.user_retrieval import get_users_by_id # Avoid circular import + + user_ids = get_user_ids_from_messages(messages) + response = await get_users_by_id(auth_token, user_ids) + print("\n\n\nresponse:", response, "\n\n\n") + return response["users"] + + +def get_user_ids_from_messages(messages: list[dict]) -> list[str]: + if not messages: + return [] + + user_ids = [] + + for message in messages: + if message.get("type") != "message": + continue + + # Message author + user = message.get("user") + if isinstance(user, str) and user: + user_ids.append(user) + + # User mentions in the message text + text = message.get("text") + if isinstance(text, str) and text: + user_ids.extend(re.findall(r"<@([A-Z0-9]+)>", text)) + + # User mentions in reactions + reactions = message.get("reactions") + if isinstance(reactions, list): + for reaction in reactions: + user_ids.extend(reaction.get("users", [])) + + return user_ids diff --git a/toolkits/slack/conftest.py b/toolkits/slack/conftest.py index f1534f3d..d8f95066 100644 --- a/toolkits/slack/conftest.py +++ b/toolkits/slack/conftest.py @@ -1,3 +1,7 @@ +import random +import string +from collections.abc import Callable + import pytest from arcade_tdk import ToolAuthorizationContext, ToolContext @@ -18,3 +22,142 @@ def mock_chat_slack_client(mocker): def mock_users_slack_client(mocker): mock_client = mocker.patch("arcade_slack.tools.users.AsyncWebClient", autospec=True) return mock_client.return_value + + +@pytest.fixture +def mock_user_retrieval_slack_client(mocker): + mock_client = mocker.patch("arcade_slack.user_retrieval.AsyncWebClient", autospec=True) + return mock_client.return_value + + +@pytest.fixture +def mock_conversation_retrieval_slack_client(mocker): + mock_client = mocker.patch("arcade_slack.conversation_retrieval.AsyncWebClient", autospec=True) + return mock_client.return_value + + +@pytest.fixture +def mock_message_retrieval_slack_client(mocker): + mock_client = mocker.patch("arcade_slack.message_retrieval.AsyncWebClient", autospec=True) + return mock_client.return_value + + +@pytest.fixture +def random_str_factory(): + def random_str_factory(length: int = 10): + return "".join(random.choices(string.ascii_letters + string.digits, k=length)) # noqa: S311 + + return random_str_factory + + +@pytest.fixture +def random_ts_factory(): + def random_ts_factory(): + return f"{random.uniform(1735689600.000000, 1751327999.999999)}" # noqa: S311 + + return random_ts_factory + + +@pytest.fixture +def dummy_channel_factory(random_str_factory: Callable[[int], str]): + def dummy_channel_factory( + id_: str | None = None, + name: str | None = None, + is_member: bool = True, + is_private: bool = False, + is_archived: bool = False, + is_channel: bool = False, + is_im: bool = False, + is_mpim: bool = False, + num_members: int | None = None, + user: str | None = None, + is_user_deleted: bool = False, + ): + channel = { + "id": id_ or f"channel_id_{random_str_factory()}", + "is_member": is_member, + "is_private": is_private, + "is_archived": is_archived, + } + + if name or is_channel or is_mpim: + channel["name"] = name or f"channel_name_{random_str_factory()}" + + if is_channel: + channel["is_channel"] = True + if is_im: + channel["is_im"] = True + if is_mpim: + channel["is_group"] = True + if num_members: + channel["num_members"] = num_members + if user or is_im: + channel["user"] = user or f"user_id_{random_str_factory()}" + if is_user_deleted: + channel["is_user_deleted"] = is_user_deleted + + return channel + + return dummy_channel_factory + + +@pytest.fixture +def dummy_user_factory(random_str_factory: Callable[[int], str]): + def dummy_user_factory( + id_: str | None = None, + name: str | None = None, + email: str | None = None, + is_bot: bool = False, + ): + return { + "id": id_ or random_str_factory(), + "name": name or random_str_factory(), + "profile": { + "email": email or f"{random_str_factory()}@{random_str_factory()}.com", + }, + "is_bot": is_bot, + } + + return dummy_user_factory + + +@pytest.fixture +def dummy_reaction_factory(random_str_factory): + def reaction_factory( + name: str | None = None, + user_ids: list[str] | None = None, + count: int | None = None, + ): + count = count or random.randint(1, 10) # noqa: S311 + if user_ids: + count = len(user_ids) + return { + "count": count, + "name": name or random_str_factory(), + "users": user_ids or [random_str_factory() for _ in range(count)], + } + + return reaction_factory + + +@pytest.fixture +def dummy_message_factory(random_str_factory, random_ts_factory): + def message_factory( + user_id: str | None = None, + text: str | None = None, + reactions: list[dict] | None = None, + type_: str = "message", + ts: float | None = None, + ): + message = { + "user": user_id or random_str_factory(), + "text": text or random_str_factory(), + "type": type_, + "ts": ts or random_ts_factory(), + } + + if reactions: + message["reactions"] = reactions + return message + + return message_factory diff --git a/toolkits/slack/evals/chat/eval_get_metadata.py b/toolkits/slack/evals/chat/eval_get_metadata.py new file mode 100644 index 00000000..33661274 --- /dev/null +++ b/toolkits/slack/evals/chat/eval_get_metadata.py @@ -0,0 +1,206 @@ +from arcade_evals import ( + BinaryCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + tool_eval, +) +from arcade_tdk import ToolCatalog + +import arcade_slack +from arcade_slack.tools.chat import get_conversation_metadata + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.8, + warn_threshold=0.9, +) + + +catalog = ToolCatalog() +# Register the Slack tools +catalog.add_module(arcade_slack) + + +@tool_eval() +def get_conversations_metadata_eval_suite() -> EvalSuite: + """Create an evaluation suite for tools getting conversations metadata.""" + suite = EvalSuite( + name="Slack Tools Evaluation", + system_message="You are an AI assistant that can interact with Slack to get information from conversations, users, etc.", + catalog=catalog, + rubric=rubric, + ) + + suite.add_case( + name="Get channel metadata by name", + user_message="Get the metadata of the #general channel", + expected_tool_calls=[ + ExpectedToolCall( + func=get_conversation_metadata, + args={ + "conversation_id": None, + "channel_name": "general", + "user_ids": None, + "usernames": None, + "emails": None, + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=1 / 5), + BinaryCritic(critic_field="channel_name", weight=1 / 5), + BinaryCritic(critic_field="user_ids", weight=1 / 5), + BinaryCritic(critic_field="usernames", weight=1 / 5), + BinaryCritic(critic_field="emails", weight=1 / 5), + ], + ) + + suite.add_case( + name="Get conversation metadata by id", + user_message="Get the metadata of the conversation with id '1234567890'", + expected_tool_calls=[ + ExpectedToolCall( + func=get_conversation_metadata, + args={ + "conversation_id": "1234567890", + "channel_name": None, + "user_ids": None, + "usernames": None, + "emails": None, + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=1 / 5), + BinaryCritic(critic_field="channel_name", weight=1 / 5), + BinaryCritic(critic_field="user_ids", weight=1 / 5), + BinaryCritic(critic_field="usernames", weight=1 / 5), + BinaryCritic(critic_field="emails", weight=1 / 5), + ], + ) + + suite.add_case( + name="Get conversation metadata by username mentioning DM", + user_message="get the metadata of the DM with janedoe", + expected_tool_calls=[ + ExpectedToolCall( + func=get_conversation_metadata, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "usernames": ["janedoe"], + "emails": None, + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=1 / 5), + BinaryCritic(critic_field="channel_name", weight=1 / 5), + BinaryCritic(critic_field="user_ids", weight=1 / 5), + BinaryCritic(critic_field="usernames", weight=1 / 5), + BinaryCritic(critic_field="emails", weight=1 / 5), + ], + ) + + suite.add_case( + name="Get conversation metadata by username mentioning IM", + user_message="get metadata about my IM with janedoe", + expected_tool_calls=[ + ExpectedToolCall( + func=get_conversation_metadata, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "usernames": ["janedoe"], + "emails": None, + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=1 / 5), + BinaryCritic(critic_field="channel_name", weight=1 / 5), + BinaryCritic(critic_field="user_ids", weight=1 / 5), + BinaryCritic(critic_field="usernames", weight=1 / 5), + BinaryCritic(critic_field="emails", weight=1 / 5), + ], + ) + + suite.add_case( + name="Get conversation metadata by email mentioning DM", + user_message="get the metadata of the DM with jane.doe@acme.com", + expected_tool_calls=[ + ExpectedToolCall( + func=get_conversation_metadata, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "usernames": None, + "emails": ["jane.doe@acme.com"], + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=1 / 5), + BinaryCritic(critic_field="channel_name", weight=1 / 5), + BinaryCritic(critic_field="user_ids", weight=1 / 5), + BinaryCritic(critic_field="usernames", weight=1 / 5), + BinaryCritic(critic_field="emails", weight=1 / 5), + ], + ) + + suite.add_case( + name="Get conversation metadata by email mentioning IM", + user_message="get the metadata of the IM with jane.doe@acme.com", + expected_tool_calls=[ + ExpectedToolCall( + func=get_conversation_metadata, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "usernames": None, + "emails": ["jane.doe@acme.com"], + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=1 / 5), + BinaryCritic(critic_field="channel_name", weight=1 / 5), + BinaryCritic(critic_field="user_ids", weight=1 / 5), + BinaryCritic(critic_field="usernames", weight=1 / 5), + BinaryCritic(critic_field="emails", weight=1 / 5), + ], + ) + + suite.add_case( + name="Get conversation metadata by mixed user ID, email, and username", + user_message=( + "get the metadata of the multi-person conversation I have with these users together: " + "janedoe, john@acme.com, and U0123456789" + ), + expected_tool_calls=[ + ExpectedToolCall( + func=get_conversation_metadata, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": ["U0123456789"], + "usernames": ["janedoe"], + "emails": ["john@acme.com"], + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=1 / 5), + BinaryCritic(critic_field="channel_name", weight=1 / 5), + BinaryCritic(critic_field="user_ids", weight=1 / 5), + BinaryCritic(critic_field="usernames", weight=1 / 5), + BinaryCritic(critic_field="emails", weight=1 / 5), + ], + ) + + return suite diff --git a/toolkits/slack/evals/chat/eval_get_users_in_conversation.py b/toolkits/slack/evals/chat/eval_get_users_in_conversation.py new file mode 100644 index 00000000..434454d2 --- /dev/null +++ b/toolkits/slack/evals/chat/eval_get_users_in_conversation.py @@ -0,0 +1,81 @@ +from arcade_evals import ( + BinaryCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + tool_eval, +) +from arcade_tdk import ToolCatalog + +import arcade_slack +from arcade_slack.tools.chat import get_users_in_conversation + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.8, + warn_threshold=0.9, +) + + +catalog = ToolCatalog() +# Register the Slack tools +catalog.add_module(arcade_slack) + + +@tool_eval() +def get_users_in_conversation_eval_suite() -> EvalSuite: + """Create an evaluation suite for tools getting conversations members.""" + suite = EvalSuite( + name="Slack Tools Evaluation", + system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.", + catalog=catalog, + rubric=rubric, + ) + + user_messages = [ + "Get the members of the #general channel", + "Get the users in the #general channel", + "Get a list of people in the #general channel", + "Get a list of people in the general channel", + "Show me who's in the #general channel", + "Who is in the general channel?", + ] + + for user_message in user_messages: + suite.add_case( + name=f"Get users in channel by channel name: {user_message}", + user_message=user_message, + expected_tool_calls=[ + ExpectedToolCall( + func=get_users_in_conversation, + args={ + "conversation_id": None, + "channel_name": "general", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.4), + BinaryCritic(critic_field="channel_name", weight=0.6), + ], + ) + + suite.add_case( + name="Get users in conversation by conversation id", + user_message="Get the users in the conversation with id '1234567890'", + expected_tool_calls=[ + ExpectedToolCall( + func=get_users_in_conversation, + args={ + "conversation_id": "1234567890", + "channel_name": None, + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.6), + BinaryCritic(critic_field="channel_name", weight=0.4), + ], + ) + + return suite diff --git a/toolkits/slack/evals/chat/eval_list_conversations.py b/toolkits/slack/evals/chat/eval_list_conversations.py new file mode 100644 index 00000000..fcb80c66 --- /dev/null +++ b/toolkits/slack/evals/chat/eval_list_conversations.py @@ -0,0 +1,175 @@ +from arcade_evals import ( + BinaryCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + tool_eval, +) +from arcade_tdk import ToolCatalog + +import arcade_slack +from arcade_slack.models import ConversationType +from arcade_slack.tools.chat import list_conversations + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.8, + warn_threshold=0.9, +) + + +catalog = ToolCatalog() +# Register the Slack tools +catalog.add_module(arcade_slack) + + +@tool_eval() +def list_conversations_eval_suite() -> EvalSuite: + """Create an evaluation suite for tools listing conversations.""" + suite = EvalSuite( + name="Slack Messaging Tools Evaluation", + system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.", + catalog=catalog, + rubric=rubric, + ) + + suite.add_case( + name="List all conversations I am a member of", + user_message="List all conversations I am a member of", + expected_tool_calls=[ + ExpectedToolCall( + func=list_conversations, + args={ + "conversation_types": None, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="conversation_types", weight=1.0), + ], + ) + + suite.add_case( + name="List 10 conversations I am a member of", + user_message="List 10 conversations I am a member of", + expected_tool_calls=[ + ExpectedToolCall( + func=list_conversations, + args={ + "conversation_types": None, + "limit": 10, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="conversation_types", weight=0.5), + BinaryCritic(critic_field="limit", weight=0.5), + ], + ) + + suite.add_case( + name="List all public channels", + user_message="List all public channels", + expected_tool_calls=[ + ExpectedToolCall( + func=list_conversations, + args={ + "conversation_types": [ConversationType.PUBLIC_CHANNEL.value], + }, + ) + ], + critics=[ + BinaryCritic(critic_field="conversation_types", weight=1.0), + ], + ) + + suite.add_case( + name="List all private channels", + user_message="List all private channels", + expected_tool_calls=[ + ExpectedToolCall( + func=list_conversations, + args={ + "conversation_types": [ConversationType.PRIVATE_CHANNEL.value], + }, + ) + ], + critics=[ + BinaryCritic(critic_field="conversation_types", weight=1.0), + ], + ) + + suite.add_case( + name="List all public and private channels", + user_message="List all public and private channels", + expected_tool_calls=[ + ExpectedToolCall( + func=list_conversations, + args={ + "conversation_types": [ + ConversationType.PUBLIC_CHANNEL.value, + ConversationType.PRIVATE_CHANNEL.value, + ], + }, + ) + ], + critics=[ + BinaryCritic(critic_field="conversation_types", weight=1.0), + ], + ) + + suite.add_case( + name="List direct message channels", + user_message="List direct message channels", + expected_tool_calls=[ + ExpectedToolCall( + func=list_conversations, + args={ + "conversation_types": [ + ConversationType.DIRECT_MESSAGE.value, + ], + }, + ) + ], + critics=[ + BinaryCritic(critic_field="conversation_types", weight=1.0), + ], + ) + + suite.add_case( + name="List group direct message channels", + user_message="List group direct message channels", + expected_tool_calls=[ + ExpectedToolCall( + func=list_conversations, + args={ + "conversation_types": [ + ConversationType.MULTI_PERSON_DIRECT_MESSAGE.value, + ], + }, + ) + ], + critics=[ + BinaryCritic(critic_field="conversation_types", weight=1.0), + ], + ) + + suite.add_case( + name="List my multi-person conversations", + user_message="List my multi-person conversations", + expected_tool_calls=[ + ExpectedToolCall( + func=list_conversations, + args={ + "conversation_types": [ + ConversationType.MULTI_PERSON_DIRECT_MESSAGE.value, + ], + }, + ) + ], + critics=[ + BinaryCritic(critic_field="conversation_types", weight=1.0), + ], + ) + + return suite diff --git a/toolkits/slack/evals/chat/messages/eval_get_channel_messages.py b/toolkits/slack/evals/chat/messages/eval_get_channel_messages.py new file mode 100644 index 00000000..ffa6b036 --- /dev/null +++ b/toolkits/slack/evals/chat/messages/eval_get_channel_messages.py @@ -0,0 +1,622 @@ +import json +from datetime import timedelta + +from arcade_evals import ( + BinaryCritic, + DatetimeCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + tool_eval, +) +from arcade_tdk import ToolCatalog + +import arcade_slack +from arcade_slack.critics import RelativeTimeBinaryCritic +from arcade_slack.tools.chat import get_messages + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.8, + warn_threshold=0.9, +) + + +catalog = ToolCatalog() +# Register the Slack tools +catalog.add_module(arcade_slack) + + +@tool_eval() +def get_messages_in_channel_eval_suite() -> EvalSuite: + """Create an evaluation suite for tools getting messages in channels.""" + suite = EvalSuite( + name="Slack Chat Tools Evaluation", + system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.", + catalog=catalog, + rubric=rubric, + ) + + no_arguments_user_messages_by_channel_name = [ + "what are the latest messages in the #general channel", + "show me the messages in the general channel", + "list the messages in the #general channel", + "list the messages in the general channel", + ] + + for i, user_message in enumerate(no_arguments_user_messages_by_channel_name): + suite.add_case( + name=f"Get messages in conversation by name {i}: '{user_message}'", + user_message=user_message, + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": "general", + "user_ids": None, + "usernames": None, + "emails": None, + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.6), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + ], + ) + + no_arguments_user_messages_by_conversation_id = [ + "Get the history of the conversation with id '1234567890'", + "Get the history of the conversation with id '1234567890'", + "list the messages in the conversation with id '1234567890'", + "list the messages in the conversation with id '1234567890'", + ] + + for user_message in no_arguments_user_messages_by_conversation_id: + suite.add_case( + name=f"Get conversation history by id: '{user_message}'", + user_message=user_message, + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": "1234567890", + "channel_name": None, + "user_ids": None, + "usernames": None, + "emails": None, + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.6), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + ], + ) + + suite.add_case( + name="Get conversation history with limit by name", + user_message="Get the last 10 messages in the #general channel", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": "general", + "user_ids": None, + "usernames": None, + "emails": None, + "limit": 10, + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.3), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + BinaryCritic(critic_field="limit", weight=0.3), + ], + ) + + suite.add_case( + name="Get conversation history with limit by id", + user_message="Get the last 25 messages in the conversation with id '1234567890'", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": "1234567890", + "channel_name": None, + "user_ids": None, + "usernames": None, + "emails": None, + "limit": 25, + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.3), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + BinaryCritic(critic_field="limit", weight=0.3), + ], + ) + + # Relative time eval cases by id + + suite.add_case( + name="Get conversation history oldest relative by id (2 days ago)", + user_message="Get the messages in the conversation with id '1234567890' starting 2 days ago", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": "1234567890", + "channel_name": None, + "user_ids": None, + "usernames": None, + "emails": None, + "oldest_relative": "02:00:00", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.3), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.3), + ], + ) + + suite.add_case( + name="Get conversation history oldest and latest relative by id", + user_message="Get the messages in the conversation with id '1234567890' from 2 days ago to 3 hours ago", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": "1234567890", + "channel_name": None, + "user_ids": None, + "usernames": None, + "emails": None, + "oldest_relative": "02:00:00", + "latest_relative": "00:03:00", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.2), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.2), + RelativeTimeBinaryCritic(critic_field="latest_relative", weight=0.2), + ], + ) + + suite.add_case( + name="Get conversation history oldest relative by id (1 week ago)", + user_message="Get the messages in the conversation with id '1234567890' starting 1 week ago", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": "1234567890", + "channel_name": None, + "user_ids": None, + "usernames": None, + "emails": None, + "oldest_relative": "07:00:00", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.3), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.3), + ], + ) + + suite.add_case( + name="Get conversation history oldest relative by id (yesterday)", + user_message="Get the messages in the conversation with id '1234567890' from yesterday", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": "1234567890", + "channel_name": None, + "user_ids": None, + "usernames": None, + "emails": None, + "oldest_relative": "01:00:00", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.3), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.3), + ], + ) + + # Relative time eval cases by name + + suite.add_case( + name="Get conversation history oldest relative by name (2 days ago)", + user_message="Get the messages in the #general channel starting 2 days ago", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": "general", + "user_ids": None, + "usernames": None, + "emails": None, + "oldest_relative": "02:00:00", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.3), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.3), + ], + ) + + suite.add_case( + name="Get conversation history oldest and latest relative by name", + user_message="Get the messages in the #general channel from 2 days ago to 3 hours ago", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": "general", + "user_ids": None, + "usernames": None, + "emails": None, + "oldest_relative": "02:00:00", + "latest_relative": "00:03:00", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.2), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.2), + RelativeTimeBinaryCritic(critic_field="latest_relative", weight=0.2), + ], + ) + + suite.add_case( + name="Get conversation history oldest relative by name (yesterday)", + user_message="Get the messages in the #general channel from yesterday", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": "general", + "user_ids": None, + "usernames": None, + "emails": None, + "oldest_relative": "01:00:00", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.3), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.3), + ], + ) + + suite.add_case( + name="Get conversation history oldest relative by name (last week)", + user_message="Get the messages in the #general channel from last week", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": "general", + "user_ids": None, + "usernames": None, + "emails": None, + "oldest_relative": "07:00:00", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.3), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.3), + ], + ) + + # Absolute time eval cases by id + + suite.add_case( + name="Get conversation history oldest absolute by id (on a specific date)", + user_message="Get the messages in the conversation with id '1234567890' from 2025-01-20", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": "1234567890", + "channel_name": None, + "user_ids": None, + "usernames": None, + "emails": None, + "oldest_datetime": "2025-01-20 00:00:00", + "latest_datetime": "2025-01-20 23:59:59", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.2), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + DatetimeCritic( + critic_field="oldest_datetime", weight=0.2, max_difference=timedelta(minutes=2) + ), + DatetimeCritic( + critic_field="latest_datetime", weight=0.2, max_difference=timedelta(minutes=2) + ), + ], + ) + + suite.add_case( + name="Get conversation history oldest absolute by id (between a date range)", + user_message="Get the messages in the conversation with id '1234567890' from 2025-01-20 to 2025-01-25", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": "1234567890", + "channel_name": None, + "user_ids": None, + "usernames": None, + "emails": None, + "oldest_datetime": "2025-01-20 00:00:00", + "latest_datetime": "2025-01-25 23:59:59", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.2), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + DatetimeCritic( + critic_field="oldest_datetime", weight=0.2, max_difference=timedelta(minutes=2) + ), + DatetimeCritic( + critic_field="latest_datetime", weight=0.2, max_difference=timedelta(minutes=2) + ), + ], + ) + + suite.add_case( + name="Get conversation history oldest absolute by name (on a specific date)", + user_message="Get the messages in the #general channel from 2025-01-20", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": "general", + "user_ids": None, + "usernames": None, + "emails": None, + "oldest_datetime": "2025-01-20 00:00:00", + "latest_datetime": "2025-01-20 23:59:59", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.2), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + # We use a timedelta of 10 seconds because sometimes the LLM will select the limit + # date at 23:59:59, other times it'll select the next day at 00:00:00. + DatetimeCritic( + critic_field="oldest_datetime", weight=0.2, max_difference=timedelta(seconds=10) + ), + DatetimeCritic( + critic_field="latest_datetime", weight=0.2, max_difference=timedelta(seconds=10) + ), + ], + ) + + suite.add_case( + name="Get conversation history oldest absolute by name (between a date range)", + user_message="Get the messages in the #general channel from 2025-01-20 to 2025-01-25", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": "general", + "user_ids": None, + "usernames": None, + "emails": None, + "oldest_datetime": "2025-01-20 00:00:00", + "latest_datetime": "2025-01-25 23:59:59", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.2), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + # We use a timedelta of 10 seconds because sometimes the LLM will select the limit + # date at 23:59:59, other times it'll select the next day at 00:00:00. + DatetimeCritic( + critic_field="oldest_datetime", weight=0.2, max_difference=timedelta(seconds=10) + ), + DatetimeCritic( + critic_field="latest_datetime", weight=0.2, max_difference=timedelta(seconds=10) + ), + ], + ) + + # Eval case for pagination + + suite.add_case( + name="Get conversation history with pagination", + user_message="get the next 5 messages", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": "general", + "user_ids": None, + "usernames": None, + "emails": None, + "limit": 5, + "next_cursor": "dXNlcjpVsDjzOTZGVDlQRA==", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.2), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + BinaryCritic(critic_field="next_cursor", weight=0.2), + BinaryCritic(critic_field="limit", weight=0.2), + ], + additional_messages=[ + {"role": "user", "content": "Get the last 2 messages on the general channel"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "Slack_GetConversationHistoryByName", + "arguments": json.dumps({ + "conversation_name": "general", + "limit": 2, + }), + }, + } + ], + }, + { + "role": "tool", + "content": json.dumps({ + "messages": [ + { + "blocks": [ + { + "block_id": "abc123", + "elements": [ + { + "elements": [ + { + "text": "Almost there, Boss, need to get some evals in!", + "type": "text", + } + ], + "type": "rich_text_section", + } + ], + "type": "rich_text", + } + ], + "client_msg_id": "msg_id_0", + "datetime_timestamp": "2025-01-21 16:59:55", + "team": "617263616465207465616D20697320617420626F7373206C6576656C", + "text": "Almost there, Boss, need to get some evals in!", + "ts": "1737507595.598529", + "type": "message", + "user": "77686F2069732074686520626F73733F", + }, + { + "blocks": [ + { + "block_id": "xyz456", + "elements": [ + { + "elements": [ + { + "text": "hey, are the Slack Tools ready yet?", + "type": "text", + } + ], + "type": "rich_text_section", + } + ], + "type": "rich_text", + } + ], + "client_msg_id": "msg_id_1", + "datetime_timestamp": "2025-01-21 16:57:35", + "team": "617263616465207465616D20697320617420626F7373206C6576656C", + "text": "hey, are the Slack Tools ready yet?", + "ts": "1737507595.598529", + "type": "message", + "user": "73616D2069732074686520626F7373", + }, + ], + "next_cursor": "dXNlcjpVsDjzOTZGVDlQRA==", + }), + "tool_call_id": "call_1", + "name": "Slack_GetConversationHistoryByName", + }, + { + "role": "assistant", + "content": 'Here are the last 2 messages from the general channel:\n\n1. **User:** 77686F2069732074686520626F73733F \n **Message:** "Almost there, Boss, need to get some evals in!" \n **Timestamp:** 2025-01-21 16:59:55\n\n2. **User:** 73616D2069732074686520626F7373 \n **Message:** "hey, are the Slack Tools ready yet?" \n **Timestamp:** 2025-01-21 16:57:35', + }, + ], + ) + + return suite diff --git a/toolkits/slack/evals/chat/messages/eval_get_dm_messages.py b/toolkits/slack/evals/chat/messages/eval_get_dm_messages.py new file mode 100644 index 00000000..88156592 --- /dev/null +++ b/toolkits/slack/evals/chat/messages/eval_get_dm_messages.py @@ -0,0 +1,191 @@ +from datetime import timedelta + +from arcade_evals import ( + BinaryCritic, + DatetimeCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + tool_eval, +) +from arcade_tdk import ToolCatalog + +import arcade_slack +from arcade_slack.critics import RelativeTimeBinaryCritic +from arcade_slack.tools.chat import get_messages + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.8, + warn_threshold=0.9, +) + + +catalog = ToolCatalog() +# Register the Slack tools +catalog.add_module(arcade_slack) + + +@tool_eval() +def get_messages_in_direct_message_eval_suite() -> EvalSuite: + """Create an evaluation suite for tools getting messages in direct messages.""" + suite = EvalSuite( + name="Slack Chat Tools Evaluation", + system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.", + catalog=catalog, + rubric=rubric, + ) + + no_arguments_user_messages_by_username = [ + "what are the latest messages I exchanged with janedoe", + "show my messages with janedoe on Slack", + "get the messages I exchanged with janedoe", + "get the message history with janedoe", + ] + + for i, user_message in enumerate(no_arguments_user_messages_by_username): + suite.add_case( + name=f"{user_message} [{i}]", + user_message=user_message, + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "usernames": ["janedoe"], + "emails": None, + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.6), + BinaryCritic(critic_field="emails", weight=0.1), + ], + ) + + no_arguments_user_messages_by_email = [ + "what are the latest messages I exchanged with jane.doe@acme.com", + "show my messages with jane.doe@acme.com on Slack", + ] + + for i, user_message in enumerate(no_arguments_user_messages_by_email): + suite.add_case( + name=f"{user_message} [{i}]", + user_message=user_message, + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "usernames": None, + "emails": ["jane.doe@acme.com"], + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.6), + ], + ) + + suite.add_case( + name="get messages in direct conversation by username (on a specific date)", + user_message="get the messages I exchanged with janedoe on 2025-01-31", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "usernames": ["janedoe"], + "emails": None, + "oldest_datetime": "2025-01-31 00:00:00", + "latest_datetime": "2025-01-31 23:59:59", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.2), + BinaryCritic(critic_field="emails", weight=0.1), + DatetimeCritic( + critic_field="oldest_datetime", weight=0.2, max_difference=timedelta(minutes=2) + ), + DatetimeCritic( + critic_field="latest_datetime", weight=0.2, max_difference=timedelta(minutes=2) + ), + ], + ) + + suite.add_case( + name="get messages in direct conversation by email (on a specific date)", + user_message="get the messages I exchanged with jane.doe@acme.com on 2025-01-31", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "usernames": None, + "emails": ["jane.doe@acme.com"], + "oldest_datetime": "2025-01-31 00:00:00", + "latest_datetime": "2025-01-31 23:59:59", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.2), + DatetimeCritic( + critic_field="oldest_datetime", weight=0.2, max_difference=timedelta(minutes=2) + ), + DatetimeCritic( + critic_field="latest_datetime", weight=0.2, max_difference=timedelta(minutes=2) + ), + ], + ) + + suite.add_case( + name="Get conversation history oldest relative by username (2 days ago)", + user_message="Get the messages I exchanged with janedoe starting 2 days ago", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "usernames": ["janedoe"], + "emails": None, + "oldest_relative": "02:00:00", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.3), + BinaryCritic(critic_field="emails", weight=0.1), + RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.3), + ], + ) + + return suite diff --git a/toolkits/slack/evals/chat/messages/eval_get_mpim_messages.py b/toolkits/slack/evals/chat/messages/eval_get_mpim_messages.py new file mode 100644 index 00000000..028e4ad1 --- /dev/null +++ b/toolkits/slack/evals/chat/messages/eval_get_mpim_messages.py @@ -0,0 +1,165 @@ +from datetime import timedelta + +from arcade_evals import ( + BinaryCritic, + DatetimeCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + tool_eval, +) +from arcade_tdk import ToolCatalog + +import arcade_slack +from arcade_slack.critics import RelativeTimeBinaryCritic +from arcade_slack.tools.chat import get_messages + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.8, + warn_threshold=0.9, +) + + +catalog = ToolCatalog() +# Register the Slack tools +catalog.add_module(arcade_slack) + + +@tool_eval() +def get_messages_in_multi_person_direct_message_eval_suite() -> EvalSuite: + """Create an evaluation suite for tools getting messages in multi-person direct messages.""" + suite = EvalSuite( + name="Slack Chat Tools Evaluation", + system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.", + catalog=catalog, + rubric=rubric, + ) + + no_arguments_user_messages_by_username = [ + ( + "what are the latest messages I exchanged in the MPIM " + "with the usernames john, ryan, and jennifer" + ), + ("show the messages in the MPIM with the usernames john, ryan, and jennifer on Slack"), + ("list the messages I exchanged in the MPIM with the usernames john, ryan, and jennifer"), + ("list the message history in the MPIM with the usernames john, ryan, and jennifer"), + ] + + for i, user_message in enumerate(no_arguments_user_messages_by_username): + suite.add_case( + name=f"{user_message} [{i}]", + user_message=user_message, + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "usernames": ["john", "ryan", "jennifer"], + "emails": None, + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.6), + BinaryCritic(critic_field="emails", weight=0.1), + ], + ) + + suite.add_case( + name="get messages in multi person direct conversation with mixed usernames and emails", + user_message=( + "get the messages I exchanged in the mpim with " + "the usernames john, ryan, and jennifer@acme.com" + ), + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "usernames": ["john", "ryan"], + "emails": ["jennifer@acme.com"], + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.35), + BinaryCritic(critic_field="emails", weight=0.35), + ], + ) + + suite.add_case( + name="get messages in direct conversation by username (on a specific date)", + user_message=( + "get the messages I exchanged in the mpim with " + "the usernames john, ryan, and jennifer on 2025-01-31" + ), + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "usernames": ["john", "ryan", "jennifer"], + "emails": None, + "oldest_datetime": "2025-01-31 00:00:00", + "latest_datetime": "2025-01-31 23:59:59", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.2), + BinaryCritic(critic_field="emails", weight=0.1), + DatetimeCritic( + critic_field="oldest_datetime", weight=0.2, max_difference=timedelta(minutes=2) + ), + DatetimeCritic( + critic_field="latest_datetime", weight=0.2, max_difference=timedelta(minutes=2) + ), + ], + ) + + suite.add_case( + name="Get conversation history oldest relative by username (2 days ago)", + user_message=( + "Get the messages I exchanged in the MPIM with " + "the usernames john, ryan, and jennifer starting 2 days ago" + ), + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "usernames": ["john", "ryan", "jennifer"], + "emails": None, + "oldest_relative": "02:00:00", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.3), + BinaryCritic(critic_field="emails", weight=0.1), + RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.3), + ], + ) + + return suite diff --git a/toolkits/slack/evals/chat/messages/eval_send_messages.py b/toolkits/slack/evals/chat/messages/eval_send_messages.py new file mode 100644 index 00000000..70a40688 --- /dev/null +++ b/toolkits/slack/evals/chat/messages/eval_send_messages.py @@ -0,0 +1,279 @@ +import json + +from arcade_evals import ( + BinaryCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + SimilarityCritic, + tool_eval, +) +from arcade_tdk import ToolCatalog + +import arcade_slack +from arcade_slack.tools.chat import send_message + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.8, + warn_threshold=0.9, +) + + +catalog = ToolCatalog() +# Register the Slack tools +catalog.add_module(arcade_slack) + + +@tool_eval() +def send_message_eval_suite() -> EvalSuite: + """Create an evaluation suite for Slack messaging tools.""" + suite = EvalSuite( + name="Slack Messaging Tools Evaluation", + system_message="You are an AI assistant that can send direct messages and post messages to channels in Slack using the provided tools.", + catalog=catalog, + rubric=rubric, + ) + + # Send DM to User Scenarios + suite.add_case( + name="Send DM to user with clear username", + user_message="Send a direct message to johndoe saying 'Hello, can we meet at 3 PM?'", + expected_tool_calls=[ + ExpectedToolCall( + func=send_message, + args={ + "message": "Hello, can we meet at 3 PM?", + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "emails": None, + "usernames": ["johndoe"], + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="message", weight=0.3, similarity_threshold=0.9), + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.3), + ], + ) + + suite.add_case( + name="Send message to channel with clear name", + user_message="Post 'The new feature is now live!' in the #announcements channel", + expected_tool_calls=[ + ExpectedToolCall( + func=send_message, + args={ + "message": "The new feature is now live!", + "conversation_id": None, + "channel_name": "announcements", + "user_ids": None, + "emails": None, + "usernames": None, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="message", weight=0.3, similarity_threshold=0.9), + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.3), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + ], + ) + + suite.add_case( + name="Send message to channel with ambiguous name", + user_message="Inform the team in the general channel about the upcoming maintenance", + expected_tool_calls=[ + ExpectedToolCall( + func=send_message, + args={ + "message": "Attention team: There will be upcoming maintenance. Please save your work and expect some downtime.", + "conversation_id": None, + "channel_name": "general", + "user_ids": None, + "emails": None, + "usernames": None, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="message", weight=0.3, similarity_threshold=0.9), + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.3), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + ], + ) + + # Adversarial Scenarios + suite.add_case( + name="Ambiguous between DM and channel message", + user_message="general", + expected_tool_calls=[ + ExpectedToolCall( + func=send_message, + args={ + "message": "Great job on the presentation!", + "conversation_id": None, + "channel_name": "general", + "user_ids": None, + "emails": None, + "usernames": None, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="message", weight=0.3, similarity_threshold=0.9), + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.3), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + ], + additional_messages=[ + {"role": "user", "content": "Send 'Great job on the presentation!' to the team"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "Slack_ListConversationsMetadata", + "arguments": '{"limit":20, "conversation_types":["public_channel"]}', + }, + } + ], + }, + { + "role": "tool", + "content": json.dumps({ + "conversations": [ + { + "conversation_type": "public_channel", + "id": "channel1", + "is_archived": False, + "is_member": True, + "is_private": False, + "name": "random", + "num_members": 999, + "purpose": "Random stuff", + }, + { + "conversation_type": "public_channel", + "id": "channel2", + "is_archived": False, + "is_member": True, + "is_private": False, + "name": "general", + "num_members": 999, + "purpose": "Just a general channel", + }, + ], + "next_cursor": "", + }), + "tool_call_id": "call_1", + "name": "Slack_ListPublicChannelsMetadata", + }, + { + "role": "assistant", + "content": 'To send the message "Great job on the presentation!" to the team, please let me know which Slack channel you\'d like to use:\n\n1. #random\n2. #general\n\nPlease let me know your choice!', + }, + ], + ) + + suite.add_case( + name="Multiple recipients in DM request", + user_message="Send DMs to the users 'alice' and 'bob' about rescheduling our meeting tomorrow. I have too much work to do.", + expected_tool_calls=[ + ExpectedToolCall( + func=send_message, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "emails": None, + "usernames": ["alice"], + }, + ), + ExpectedToolCall( + func=send_message, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "emails": None, + "usernames": ["bob"], + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.6), + ], + ) + + suite.add_case( + name="Multiple recipients in MPIM request", + user_message="Send a message to the users 'alice' and 'bob' about rescheduling our meeting tomorrow. I have too much work to do.", + expected_tool_calls=[ + ExpectedToolCall( + func=send_message, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "emails": None, + "usernames": ["alice", "bob"], + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.6), + ], + ) + + suite.add_case( + name="Channel name similar to username", + user_message="Post 'sounds great!' in john-project channel", + expected_tool_calls=[ + ExpectedToolCall( + func=send_message, + args={ + "conversation_id": None, + "channel_name": "john-project", + "user_ids": None, + "emails": None, + "usernames": None, + "message": "Sounds great!", + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="message", weight=0.3, similarity_threshold=0.9), + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.3), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + ], + ) + + return suite diff --git a/toolkits/slack/evals/eval_chat.py b/toolkits/slack/evals/eval_chat.py deleted file mode 100644 index 122aca95..00000000 --- a/toolkits/slack/evals/eval_chat.py +++ /dev/null @@ -1,1168 +0,0 @@ -import json -from datetime import timedelta - -from arcade_evals import ( - BinaryCritic, - DatetimeCritic, - EvalRubric, - EvalSuite, - ExpectedToolCall, - SimilarityCritic, - tool_eval, -) -from arcade_tdk import ToolCatalog - -import arcade_slack -from arcade_slack.critics import RelativeTimeBinaryCritic -from arcade_slack.tools.chat import ( - get_channel_metadata_by_name, - get_conversation_metadata_by_id, - get_direct_message_conversation_metadata_by_username, - get_members_in_channel_by_name, - get_members_in_conversation_by_id, - get_messages_in_channel_by_name, - get_messages_in_conversation_by_id, - get_messages_in_direct_message_conversation_by_username, - get_messages_in_multi_person_dm_conversation_by_usernames, - list_conversations_metadata, - list_direct_message_conversations_metadata, - list_group_direct_message_conversations_metadata, - list_private_channels_metadata, - list_public_channels_metadata, - send_dm_to_user, - send_message_to_channel, -) - -# Evaluation rubric -rubric = EvalRubric( - fail_threshold=0.8, - warn_threshold=0.9, -) - - -catalog = ToolCatalog() -# Register the Slack tools -catalog.add_module(arcade_slack) - - -@tool_eval() -def send_message_eval_suite() -> EvalSuite: - """Create an evaluation suite for Slack messaging tools.""" - suite = EvalSuite( - name="Slack Messaging Tools Evaluation", - system_message="You are an AI assistant that can send direct messages and post messages to channels in Slack using the provided tools.", - catalog=catalog, - rubric=rubric, - ) - - # Send DM to User Scenarios - suite.add_case( - name="Send DM to user with clear username", - user_message="Send a direct message to johndoe saying 'Hello, can we meet at 3 PM?'", - expected_tool_calls=[ - ExpectedToolCall( - func=send_dm_to_user, - args={ - "user_name": "johndoe", - "message": "Hello, can we meet at 3 PM?", - }, - ) - ], - critics=[ - BinaryCritic(critic_field="user_name", weight=0.5), - SimilarityCritic(critic_field="message", weight=0.5, similarity_threshold=0.9), - ], - ) - - suite.add_case( - name="Send DM with ambiguous username", - user_message="ask him for an update on the project", - expected_tool_calls=[ - ExpectedToolCall( - func=send_dm_to_user, - args={ - "user_name": "john", - "message": "Hi John, could you please provide an update on the Acme project?", - }, - ) - ], - critics=[ - BinaryCritic(critic_field="user_name", weight=0.75), - SimilarityCritic(critic_field="message", weight=0.25, similarity_threshold=0.6), - ], - additional_messages=[ - {"role": "user", "content": "Message John about the Acme project deadline"}, - { - "role": "assistant", - "content": "", - "tool_calls": [ - { - "id": "call_1", - "type": "function", - "function": { - "name": "Slack_ListUsers", - "arguments": '{"exclude_bots":true}', - }, - } - ], - }, - { - "role": "tool", - "content": json.dumps({ - "users": [ - { - "display_name": "john", - "email": "john@randomtech.com", - "id": "abc123", - "is_bot": False, - "name": "john", - "real_name": "John Doe", - "timezone": "America/Los_Angeles", - }, - { - "display_name": "jack", - "email": "jack@randomtech.com", - "id": "def456", - "is_bot": False, - "name": "jack", - "real_name": "Jack Doe", - "timezone": "America/Los_Angeles", - }, - ] - }), - "tool_call_id": "call_1", - "name": "Slack_ListUsers", - }, - { - "role": "assistant", - "content": "What would you like to include in the message to John about the Acme project deadline?", - }, - ], - ) - - suite.add_case( - name="Send DM with username in different format", - user_message="yes, send it", - expected_tool_calls=[ - ExpectedToolCall( - func=send_dm_to_user, - args={ - "user_name": "jane.doe", - "message": "Hi Jane, I need to reschedule our meeting. When are you available?", - }, - ) - ], - critics=[ - BinaryCritic(critic_field="user_name", weight=0.75), - SimilarityCritic(critic_field="message", weight=0.25, similarity_threshold=0.6), - ], - additional_messages=[ - {"role": "user", "content": "Message Jane.Doe asking to reschedule our meeting"}, - { - "role": "assistant", - "content": "", - "tool_calls": [ - { - "id": "call_1", - "type": "function", - "function": { - "name": "Slack_ListUsers", - "arguments": '{"exclude_bots":true}', - }, - } - ], - }, - { - "role": "tool", - "content": json.dumps({ - "users": [ - { - "display_name": "jane.doe", - "email": "jane@randomtech.com", - "id": "abc123", - "is_bot": False, - "name": "jane.doe", - "real_name": "Jane Doe", - "timezone": "America/Los_Angeles", - }, - { - "display_name": "jack", - "email": "jack@randomtech.com", - "id": "def456", - "is_bot": False, - "name": "jack", - "real_name": "Jack Doe", - "timezone": "America/Los_Angeles", - }, - ] - }), - "tool_call_id": "call_1", - "name": "Slack_ListUsers", - }, - { - "role": "assistant", - "content": "I found a user with the name 'jane.doe'. Would you like to send a message to them?", - }, - ], - ) - - # Send Message to Channel Scenarios - suite.add_case( - name="Send message to channel with clear name", - user_message="Post 'The new feature is now live!' in the #announcements channel", - expected_tool_calls=[ - ExpectedToolCall( - func=send_message_to_channel, - args={ - "channel_name": "announcements", - "message": "The new feature is now live!", - }, - ) - ], - critics=[ - BinaryCritic(critic_field="channel_name", weight=0.5), - SimilarityCritic(critic_field="message", weight=0.5), - ], - ) - - suite.add_case( - name="Send message to channel with ambiguous name", - user_message="Inform the team in the general channel about the upcoming maintenance", - expected_tool_calls=[ - ExpectedToolCall( - func=send_message_to_channel, - args={ - "channel_name": "general", - "message": "Attention team: There will be upcoming maintenance. Please save your work and expect some downtime.", - }, - ) - ], - critics=[ - SimilarityCritic(critic_field="channel_name", weight=0.8), - SimilarityCritic(critic_field="message", weight=0.2, similarity_threshold=0.6), - ], - ) - - # Adversarial Scenarios - suite.add_case( - name="Ambiguous between DM and channel message", - user_message="general", - expected_tool_calls=[ - ExpectedToolCall( - func=send_message_to_channel, - args={ - "channel_name": "general", - "message": "Great job on the presentation!", - }, - ) - ], - critics=[ - SimilarityCritic(critic_field="channel_name", weight=0.4), - SimilarityCritic(critic_field="message", weight=0.6), - ], - additional_messages=[ - {"role": "user", "content": "Send 'Great job on the presentation!' to the team"}, - { - "role": "assistant", - "content": "", - "tool_calls": [ - { - "id": "call_1", - "type": "function", - "function": { - "name": "Slack_ListPublicChannelsMetadata", - "arguments": '{"limit":20}', - }, - } - ], - }, - { - "role": "tool", - "content": json.dumps({ - "conversations": [ - { - "conversation_type": "public_channel", - "id": "channel1", - "is_archived": False, - "is_member": True, - "is_private": False, - "name": "random", - "num_members": 999, - "purpose": "Random stuff", - }, - { - "conversation_type": "public_channel", - "id": "channel2", - "is_archived": False, - "is_member": True, - "is_private": False, - "name": "general", - "num_members": 999, - "purpose": "Just a general channel", - }, - ], - "next_cursor": "", - }), - "tool_call_id": "call_1", - "name": "Slack_ListPublicChannelsMetadata", - }, - { - "role": "assistant", - "content": 'To send the message "Great job on the presentation!" to the team, please let me know which Slack channel you\'d like to use:\n\n1. #random\n2. #general\n\nPlease let me know your choice!', - }, - ], - ) - - # Multiple recipients in DM request - suite.add_case( - name="Multiple recipients in DM request", - user_message="Send DMs to the users 'alice' and 'bob' about pushing the meeting tomorrow. I have to much work to do.", - expected_tool_calls=[ - ExpectedToolCall( - func=send_dm_to_user, - args={ - "user_name": "alice", - "message": "Hi Alice, about our meeting tomorrow, let's reschedule? I am swamped with work.", - }, - ), - ExpectedToolCall( - func=send_dm_to_user, - args={ - "user_name": "bob", - "message": "Hi Bob, about our meeting tomorrow, let's reschedule? I am swamped with work.", - }, - ), - ], - critics=[ - SimilarityCritic(critic_field="user_name", weight=0.75), - SimilarityCritic(critic_field="message", weight=0.25, similarity_threshold=0.5), - ], - ) - - suite.add_case( - name="Channel name similar to username", - user_message="Post 'sounds great!' in john-project channel", - expected_tool_calls=[ - ExpectedToolCall( - func=send_message_to_channel, - args={ - "channel_name": "john-project", - "message": "Sounds great!", - }, - ) - ], - critics=[ - BinaryCritic(critic_field="channel_name", weight=0.5), - SimilarityCritic(critic_field="message", weight=0.5), - ], - ) - - return suite - - -@tool_eval() -def list_conversations_eval_suite() -> EvalSuite: - """Create an evaluation suite for tools listing conversations.""" - suite = EvalSuite( - name="Slack Messaging Tools Evaluation", - system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.", - catalog=catalog, - rubric=rubric, - ) - - cases = [ - ( - "List my conversations", - "List all conversations I am a member of", - list_conversations_metadata, - ), - ( - "List public channels", - "List all public channels", - list_public_channels_metadata, - ), - ( - "List private channels", - "List all private channels", - list_private_channels_metadata, - ), - ( - "List group direct message channels", - "List all group direct message channels", - list_group_direct_message_conversations_metadata, - ), - ( - "List individual direct message channels", - "List all individual direct message channels", - list_direct_message_conversations_metadata, - ), - ( - "List direct message channels", - "List all direct message channels", - list_direct_message_conversations_metadata, - ), - ( - "List public and private channels", - "List public and private channels I am a member of", - list_public_channels_metadata, - list_private_channels_metadata, - ), - ( - "List public channels and direct message conversations", - "List public channels and direct message conversations I am a member of", - list_public_channels_metadata, - list_direct_message_conversations_metadata, - ), - ] - - for name, user_message, *expect_called_tool_functions in cases: - suite.add_case( - name=name, - user_message=user_message, - expected_tool_calls=[ - ExpectedToolCall( - func=tool_function, - args={}, - ) - for tool_function in expect_called_tool_functions - ], - ) - - return suite - - -@tool_eval() -def get_conversations_metadata_eval_suite() -> EvalSuite: - """Create an evaluation suite for tools getting conversations metadata.""" - suite = EvalSuite( - name="Slack Tools Evaluation", - system_message="You are an AI assistant that can interact with Slack to get information from conversations, users, etc.", - catalog=catalog, - rubric=rubric, - ) - - suite.add_case( - name="Get channel metadata by name", - user_message="Get the metadata of the #general channel", - expected_tool_calls=[ - ExpectedToolCall( - func=get_channel_metadata_by_name, - args={ - "channel_name": "general", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_name", weight=1.0), - ], - ) - - suite.add_case( - name="Get conversation metadata by id", - user_message="Get the metadata of the conversation with id '1234567890'", - expected_tool_calls=[ - ExpectedToolCall( - func=get_conversation_metadata_by_id, - args={ - "conversation_id": "1234567890", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_id", weight=1.0), - ], - ) - - get_metadata_by_username_user_messages = [ - "get the metadata of the direct message conversation with the user 'jane.doe'" - "get data about my private conversation with the user 'jane.doe'", - "get data about my IM conversation with the 'jane.doe'", - ] - - for i, user_message in enumerate(get_metadata_by_username_user_messages): - suite.add_case( - name=f"Get direct message conversation metadata by username {i}", - user_message=user_message, - expected_tool_calls=[ - ExpectedToolCall( - func=get_direct_message_conversation_metadata_by_username, - args={ - "username": "jane.doe", - }, - ), - ], - ) - - return suite - - -@tool_eval() -def get_conversations_members_eval_suite() -> EvalSuite: - """Create an evaluation suite for tools getting conversations members.""" - suite = EvalSuite( - name="Slack Tools Evaluation", - system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.", - catalog=catalog, - rubric=rubric, - ) - - user_messages = [ - "Get the members of the #general channel", - "Get the members of the general channel", - "Get a list of people in the #general channel", - "Get a list of people in the general channel", - "Show me who's in the #general channel", - "Show me who's in the general channel", - "Who is in the #general channel?", - "Who is in the general channel?", - ] - - for user_message in user_messages: - suite.add_case( - name=f"Get channel members by name: {user_message}", - user_message=user_message, - expected_tool_calls=[ - ExpectedToolCall( - func=get_members_in_channel_by_name, - args={ - "channel_name": "general", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_name", weight=1.0), - ], - ) - - suite.add_case( - name="Get conversation members by id", - user_message="Get the members of the conversation with id '1234567890'", - expected_tool_calls=[ - ExpectedToolCall( - func=get_members_in_conversation_by_id, - args={ - "conversation_id": "1234567890", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_id", weight=1.0), - ], - ) - - return suite - - -@tool_eval() -def get_messages_in_channel_eval_suite() -> EvalSuite: - """Create an evaluation suite for tools getting messages in channels.""" - suite = EvalSuite( - name="Slack Chat Tools Evaluation", - system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.", - catalog=catalog, - rubric=rubric, - ) - - no_arguments_user_messages_by_channel_name = [ - "what are the latest messages in the #general channel", - "show me the messages in the general channel", - "list the messages in the #general channel", - "list the messages in the general channel", - ] - - for i, user_message in enumerate(no_arguments_user_messages_by_channel_name): - suite.add_case( - name=f"Get messages in conversation by name {i}: '{user_message}'", - user_message=user_message, - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_channel_by_name, - args={ - "conversation_name": "general", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_name", weight=1.0), - ], - ) - - no_arguments_user_messages_by_conversation_id = [ - "Get the history of the conversation with id '1234567890'", - "Get the history of the conversation with id '1234567890'", - "list the messages in the conversation with id '1234567890'", - "list the messages in the conversation with id '1234567890'", - ] - - for user_message in no_arguments_user_messages_by_conversation_id: - suite.add_case( - name=f"Get conversation history by id: '{user_message}'", - user_message=user_message, - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_conversation_by_id, - args={ - "conversation_id": "1234567890", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_id", weight=1.0), - ], - ) - - suite.add_case( - name="Get conversation history with limit by name", - user_message="Get the last 10 messages in the #general channel", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_channel_by_name, - args={ - "conversation_name": "general", - "limit": 10, - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_name", weight=1.0), - ], - ) - - suite.add_case( - name="Get conversation history with limit by id", - user_message="Get the last 25 messages in the conversation with id '1234567890'", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_conversation_by_id, - args={ - "conversation_id": "1234567890", - "limit": 25, - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_id", weight=1.0), - ], - ) - - # Relative time eval cases by id - - suite.add_case( - name="Get conversation history oldest relative by id (2 days ago)", - user_message="Get the messages in the conversation with id '1234567890' starting 2 days ago", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_conversation_by_id, - args={ - "conversation_id": "1234567890", - "oldest_relative": "02:00:00", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_id", weight=0.5), - RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.5), - ], - ) - - suite.add_case( - name="Get conversation history oldest and latest relative by id", - user_message="Get the messages in the conversation with id '1234567890' from 2 days ago to 3 hours ago", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_conversation_by_id, - args={ - "conversation_id": "1234567890", - "oldest_relative": "02:00:00", - "latest_relative": "00:03:00", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_id", weight=1 / 3), - RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=1 / 3), - RelativeTimeBinaryCritic(critic_field="latest_relative", weight=1 / 3), - ], - ) - - suite.add_case( - name="Get conversation history oldest relative by id (1 week ago)", - user_message="Get the messages in the conversation with id '1234567890' starting 1 week ago", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_conversation_by_id, - args={ - "conversation_id": "1234567890", - "oldest_relative": "07:00:00", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_id", weight=0.5), - RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.5), - ], - ) - - suite.add_case( - name="Get conversation history oldest relative by id (yesterday)", - user_message="Get the messages in the conversation with id '1234567890' from yesterday", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_conversation_by_id, - args={ - "conversation_id": "1234567890", - "oldest_relative": "01:00:00", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_id", weight=0.5), - RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.5), - ], - ) - - # Relative time eval cases by name - - suite.add_case( - name="Get conversation history oldest relative by name (2 days ago)", - user_message="Get the messages in the #general channel starting 2 days ago", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_channel_by_name, - args={ - "conversation_name": "general", - "oldest_relative": "02:00:00", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_name", weight=0.5), - RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.5), - ], - ) - - suite.add_case( - name="Get conversation history oldest and latest relative by name", - user_message="Get the messages in the #general channel from 2 days ago to 3 hours ago", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_channel_by_name, - args={ - "conversation_name": "general", - "oldest_relative": "02:00:00", - "latest_relative": "00:03:00", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_name", weight=1 / 3), - RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=1 / 3), - RelativeTimeBinaryCritic(critic_field="latest_relative", weight=1 / 3), - ], - ) - - suite.add_case( - name="Get conversation history oldest relative by name (yesterday)", - user_message="Get the messages in the #general channel from yesterday", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_channel_by_name, - args={ - "conversation_name": "general", - "oldest_relative": "01:00:00", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_name", weight=0.5), - RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.5), - ], - ) - - suite.add_case( - name="Get conversation history oldest relative by name (last week)", - user_message="Get the messages in the #general channel from last week", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_channel_by_name, - args={ - "conversation_name": "general", - "oldest_relative": "07:00:00", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_name", weight=0.5), - RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.5), - ], - ) - - # Absolute time eval cases by id - - suite.add_case( - name="Get conversation history oldest absolute by id (on a specific date)", - user_message="Get the messages in the conversation with id '1234567890' from 2025-01-20", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_conversation_by_id, - args={ - "conversation_id": "1234567890", - "oldest_datetime": "2025-01-20 00:00:00", - "latest_datetime": "2025-01-20 23:59:59", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_id", weight=1 / 3), - DatetimeCritic( - critic_field="oldest_datetime", weight=1 / 3, max_difference=timedelta(minutes=2) - ), - DatetimeCritic( - critic_field="latest_datetime", weight=1 / 3, max_difference=timedelta(minutes=2) - ), - ], - ) - - suite.add_case( - name="Get conversation history oldest absolute by id (between a date range)", - user_message="Get the messages in the conversation with id '1234567890' from 2025-01-20 to 2025-01-25", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_conversation_by_id, - args={ - "conversation_id": "1234567890", - "oldest_datetime": "2025-01-20 00:00:00", - "latest_datetime": "2025-01-25 23:59:59", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_id", weight=1 / 3), - DatetimeCritic( - critic_field="oldest_datetime", weight=1 / 3, max_difference=timedelta(minutes=2) - ), - DatetimeCritic( - critic_field="latest_datetime", weight=1 / 3, max_difference=timedelta(minutes=2) - ), - ], - ) - - suite.add_case( - name="Get conversation history oldest absolute by name (on a specific date)", - user_message="Get the messages in the #general channel from 2025-01-20", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_channel_by_name, - args={ - "conversation_name": "general", - "oldest_datetime": "2025-01-20 00:00:00", - "latest_datetime": "2025-01-20 23:59:59", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_name", weight=1 / 3), - # We use a timedelta of 10 seconds because sometimes the LLM will select the limit - # date at 23:59:59, other times it'll select the next day at 00:00:00. - DatetimeCritic( - critic_field="oldest_datetime", weight=1 / 3, max_difference=timedelta(seconds=10) - ), - DatetimeCritic( - critic_field="latest_datetime", weight=1 / 3, max_difference=timedelta(seconds=10) - ), - ], - ) - - suite.add_case( - name="Get conversation history oldest absolute by name (between a date range)", - user_message="Get the messages in the #general channel from 2025-01-20 to 2025-01-25", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_channel_by_name, - args={ - "conversation_name": "general", - "oldest_datetime": "2025-01-20 00:00:00", - "latest_datetime": "2025-01-25 23:59:59", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_name", weight=1 / 3), - # We use a timedelta of 10 seconds because sometimes the LLM will select the limit - # date at 23:59:59, other times it'll select the next day at 00:00:00. - DatetimeCritic( - critic_field="oldest_datetime", weight=1 / 3, max_difference=timedelta(seconds=10) - ), - DatetimeCritic( - critic_field="latest_datetime", weight=1 / 3, max_difference=timedelta(seconds=10) - ), - ], - ) - - # Eval case for pagination - - suite.add_case( - name="Get conversation history with pagination", - user_message="get the next 5 messages", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_channel_by_name, - args={ - "conversation_name": "general", - "limit": 5, - "cursor": "dXNlcjpVsDjzOTZGVDlQRA==", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_name", weight=1 / 3), - BinaryCritic(critic_field="cursor", weight=1 / 3), - BinaryCritic(critic_field="limit", weight=1 / 3), - ], - additional_messages=[ - {"role": "user", "content": "Get the last 2 messages on the general channel"}, - { - "role": "assistant", - "content": "", - "tool_calls": [ - { - "id": "call_1", - "type": "function", - "function": { - "name": "Slack_GetConversationHistoryByName", - "arguments": json.dumps({ - "conversation_name": "general", - "limit": 2, - }), - }, - } - ], - }, - { - "role": "tool", - "content": json.dumps({ - "messages": [ - { - "blocks": [ - { - "block_id": "abc123", - "elements": [ - { - "elements": [ - { - "text": "Almost there, Boss, need to get some evals in!", - "type": "text", - } - ], - "type": "rich_text_section", - } - ], - "type": "rich_text", - } - ], - "client_msg_id": "msg_id_0", - "datetime_timestamp": "2025-01-21 16:59:55", - "team": "617263616465207465616D20697320617420626F7373206C6576656C", - "text": "Almost there, Boss, need to get some evals in!", - "ts": "1737507595.598529", - "type": "message", - "user": "77686F2069732074686520626F73733F", - }, - { - "blocks": [ - { - "block_id": "xyz456", - "elements": [ - { - "elements": [ - { - "text": "hey, are the Slack Tools ready yet?", - "type": "text", - } - ], - "type": "rich_text_section", - } - ], - "type": "rich_text", - } - ], - "client_msg_id": "msg_id_1", - "datetime_timestamp": "2025-01-21 16:57:35", - "team": "617263616465207465616D20697320617420626F7373206C6576656C", - "text": "hey, are the Slack Tools ready yet?", - "ts": "1737507595.598529", - "type": "message", - "user": "73616D2069732074686520626F7373", - }, - ], - "next_cursor": "dXNlcjpVsDjzOTZGVDlQRA==", - }), - "tool_call_id": "call_1", - "name": "Slack_GetConversationHistoryByName", - }, - { - "role": "assistant", - "content": 'Here are the last 2 messages from the general channel:\n\n1. **User:** 77686F2069732074686520626F73733F \n **Message:** "Almost there, Boss, need to get some evals in!" \n **Timestamp:** 2025-01-21 16:59:55\n\n2. **User:** 73616D2069732074686520626F7373 \n **Message:** "hey, are the Slack Tools ready yet?" \n **Timestamp:** 2025-01-21 16:57:35', - }, - ], - ) - - return suite - - -@tool_eval() -def get_messages_in_direct_message_eval_suite() -> EvalSuite: - """Create an evaluation suite for tools getting messages in direct messages.""" - suite = EvalSuite( - name="Slack Chat Tools Evaluation", - system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.", - catalog=catalog, - rubric=rubric, - ) - - no_arguments_user_messages_by_username = [ - "what are the latest messages I exchanged with jane.doe", - "show my messages with jane.doe on Slack", - "list the messages I exchanged with jane.doe", - "list the message history with jane.doe", - ] - - for i, user_message in enumerate(no_arguments_user_messages_by_username): - suite.add_case( - name=f"{user_message} [{i}]", - user_message=user_message, - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_direct_message_conversation_by_username, - args={ - "username": "jane.doe", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="username", weight=1.0), - ], - ) - - suite.add_case( - name="get messages in direct conversation by username (on a specific date)", - user_message="get the messages I exchanged with jane.doe on 2025-01-31", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_direct_message_conversation_by_username, - args={ - "username": "jane.doe", - "oldest_datetime": "2025-01-31 00:00:00", - "latest_datetime": "2025-01-31 23:59:59", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="username", weight=1 / 3), - DatetimeCritic( - critic_field="oldest_datetime", weight=1 / 3, max_difference=timedelta(minutes=2) - ), - DatetimeCritic( - critic_field="latest_datetime", weight=1 / 3, max_difference=timedelta(minutes=2) - ), - ], - ) - - suite.add_case( - name="Get conversation history oldest relative by username (2 days ago)", - user_message="Get the messages I exchanged with jane.doe starting 2 days ago", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_direct_message_conversation_by_username, - args={ - "username": "jane.doe", - "oldest_relative": "02:00:00", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="username", weight=0.5), - RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.5), - ], - ) - - return suite - - -@tool_eval() -def get_messages_in_multi_person_direct_message_eval_suite() -> EvalSuite: - """Create an evaluation suite for tools getting messages in multi-person direct messages.""" - suite = EvalSuite( - name="Slack Chat Tools Evaluation", - system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.", - catalog=catalog, - rubric=rubric, - ) - - no_arguments_user_messages_by_username = [ - "what are the latest messages I exchanged together with the usernames john, ryan, and jennifer", - "show the messages in the multi person dm with the usernames john, ryan, and jennifer on Slack", - "list the messages I exchanged together with the usernames john, ryan, and jennifer", - "list the message history together with the usernames john, ryan, and jennifer", - ] - - for i, user_message in enumerate(no_arguments_user_messages_by_username): - suite.add_case( - name=f"{user_message} [{i}]", - user_message=user_message, - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_multi_person_dm_conversation_by_usernames, - args={ - "usernames": ["john", "ryan", "jennifer"], - }, - ), - ], - critics=[ - BinaryCritic(critic_field="usernames", weight=1.0), - ], - ) - - suite.add_case( - name="get messages in direct conversation by username (on a specific date)", - user_message="get the messages I exchanged together with the usernames john, ryan, and jennifer on 2025-01-31", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_multi_person_dm_conversation_by_usernames, - args={ - "usernames": ["john", "ryan", "jennifer"], - "oldest_datetime": "2025-01-31 00:00:00", - "latest_datetime": "2025-01-31 23:59:59", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="usernames", weight=1 / 3), - DatetimeCritic( - critic_field="oldest_datetime", weight=1 / 3, max_difference=timedelta(minutes=2) - ), - DatetimeCritic( - critic_field="latest_datetime", weight=1 / 3, max_difference=timedelta(minutes=2) - ), - ], - ) - - suite.add_case( - name="Get conversation history oldest relative by username (2 days ago)", - user_message="Get the messages I exchanged together with the usernames john, ryan, and jennifer starting 2 days ago", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_multi_person_dm_conversation_by_usernames, - args={ - "usernames": ["john", "ryan", "jennifer"], - "oldest_relative": "02:00:00", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="usernames", weight=0.5), - RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.5), - ], - ) - - return suite diff --git a/toolkits/slack/evals/eval_users.py b/toolkits/slack/evals/eval_users.py index b1f3caa1..9b724145 100644 --- a/toolkits/slack/evals/eval_users.py +++ b/toolkits/slack/evals/eval_users.py @@ -10,7 +10,10 @@ from arcade_evals import ( from arcade_tdk import ToolCatalog import arcade_slack -from arcade_slack.tools.users import get_user_info_by_id, list_users +from arcade_slack.tools.users import ( + get_users_info, + list_users, +) # Evaluation rubric rubric = EvalRubric( @@ -25,7 +28,7 @@ catalog.add_module(arcade_slack) @tool_eval() -def get_user_info_by_id_eval_suite() -> EvalSuite: +def get_user_info_eval_suite() -> EvalSuite: """Create an evaluation suite for tools getting user info by id.""" suite = EvalSuite( name="Slack Users Tools Evaluation", @@ -53,13 +56,101 @@ def get_user_info_by_id_eval_suite() -> EvalSuite: user_message=user_message, expected_tool_calls=[ ExpectedToolCall( - func=get_user_info_by_id, - args={"user_id": expected_user_id}, + func=get_users_info, + args={ + "user_ids": [expected_user_id], + "usernames": None, + "emails": None, + }, ) ], - critics=[BinaryCritic(critic_field="user_id", weight=1.0)], + critics=[ + BinaryCritic(critic_field="user_ids", weight=1 / 3), + BinaryCritic(critic_field="usernames", weight=1 / 3), + BinaryCritic(critic_field="emails", weight=1 / 3), + ], ) + suite.add_case( + name="get user by username", + user_message="get the user 'johndoe'", + expected_tool_calls=[ + ExpectedToolCall( + func=get_users_info, + args={ + "usernames": ["johndoe"], + "user_ids": None, + "emails": None, + }, + ), + ], + critics=[ + BinaryCritic(critic_field="usernames", weight=1 / 3), + BinaryCritic(critic_field="user_ids", weight=1 / 3), + BinaryCritic(critic_field="emails", weight=1 / 3), + ], + ) + + suite.add_case( + name="get user by email", + user_message="get the user 'john.doe@acme.com'", + expected_tool_calls=[ + ExpectedToolCall( + func=get_users_info, + args={ + "usernames": None, + "user_ids": None, + "emails": ["john.doe@acme.com"], + }, + ), + ], + critics=[ + BinaryCritic(critic_field="usernames", weight=1 / 3), + BinaryCritic(critic_field="user_ids", weight=1 / 3), + BinaryCritic(critic_field="emails", weight=1 / 3), + ], + ) + + suite.add_case( + name="get multiple users by username", + user_message="get the users with the usernames 'johndoe' and 'foobar'", + expected_tool_calls=[ + ExpectedToolCall( + func=get_users_info, + args={ + "usernames": ["johndoe", "foobar"], + "user_ids": None, + "emails": None, + }, + ), + ], + critics=[ + BinaryCritic(critic_field="usernames", weight=1 / 3), + BinaryCritic(critic_field="user_ids", weight=1 / 3), + BinaryCritic(critic_field="emails", weight=1 / 3), + ], + ) + + suite.add_case( + name="get multiple users by email", + user_message="get the users with the emails 'john.doe@acme.com' and 'jane.doe@acme.com'", + expected_tool_calls=[ + ExpectedToolCall( + func=get_users_info, + args={ + "usernames": None, + "user_ids": None, + "emails": ["john.doe@acme.com", "jane.doe@acme.com"], + }, + ), + ], + critics=[ + BinaryCritic(critic_field="usernames", weight=1 / 3), + BinaryCritic(critic_field="user_ids", weight=1 / 3), + BinaryCritic(critic_field="emails", weight=1 / 3), + ], + ) + return suite diff --git a/toolkits/slack/pyproject.toml b/toolkits/slack/pyproject.toml index e2231396..9a630b44 100644 --- a/toolkits/slack/pyproject.toml +++ b/toolkits/slack/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "arcade_slack" -version = "0.4.6" +version = "0.5.0" description = "Arcade.dev LLM tools for Slack" requires-python = ">=3.10" dependencies = [ "aiodns>=1.0,<2.0.0", "typing; python_version < '3.7'", "aiohttp>=3.7.3,<4.0.0", "arcade-tdk>=2.0.0,<3.0.0", "slack-sdk>=3.31.0,<4.0.0",] diff --git a/toolkits/slack/tests/test_chat.py b/toolkits/slack/tests/test_chat.py index f10d00d2..05558d3a 100644 --- a/toolkits/slack/tests/test_chat.py +++ b/toolkits/slack/tests/test_chat.py @@ -1,4 +1,4 @@ -import copy +import json from datetime import datetime, timezone from unittest.mock import Mock, call, patch @@ -10,30 +10,18 @@ from slack_sdk.web.async_slack_response import AsyncSlackResponse from arcade_slack.constants import MAX_PAGINATION_SIZE_LIMIT from arcade_slack.models import ConversationType, ConversationTypeSlackName from arcade_slack.tools.chat import ( - get_channel_metadata_by_name, - get_conversation_metadata_by_id, - get_direct_message_conversation_metadata_by_username, - get_members_in_channel_by_name, - get_members_in_conversation_by_id, - get_messages_in_channel_by_name, - get_messages_in_conversation_by_id, - get_messages_in_direct_message_conversation_by_username, - get_messages_in_multi_person_dm_conversation_by_usernames, - get_multi_person_dm_conversation_metadata_by_usernames, - list_conversations_metadata, - list_direct_message_conversations_metadata, - list_group_direct_message_conversations_metadata, - list_private_channels_metadata, - list_public_channels_metadata, - send_dm_to_user, - send_message_to_channel, + get_conversation_metadata, + get_messages, + get_users_in_conversation, + list_conversations, + send_message, ) -from arcade_slack.utils import extract_basic_user_info, extract_conversation_metadata +from arcade_slack.utils import cast_user_dict, extract_conversation_metadata @pytest.fixture -def mock_list_conversations_metadata(mocker): - return mocker.patch("arcade_slack.tools.chat.list_conversations_metadata", autospec=True) +def mock_list_conversations(mocker): + return mocker.patch("arcade_slack.tools.chat.list_conversations", autospec=True) @pytest.fixture @@ -42,72 +30,147 @@ def mock_channel_info() -> dict: @pytest.mark.asyncio -async def test_send_dm_to_user(mock_context, mock_chat_slack_client): - mock_chat_slack_client.users_list.return_value = { - "ok": True, - "members": [{"name": "testuser", "id": "U12345"}], - } +async def test_send_message_to_conversation_id( + mock_context, + mock_chat_slack_client, +): + mock_slack_response = Mock(spec=AsyncSlackResponse) + mock_slack_response.data = {"ok": True} + mock_chat_slack_client.chat_postMessage.return_value = mock_slack_response + + response = await send_message(mock_context, conversation_id="abc123", message="Hello!") + + assert response["success"] is True + assert response["data"]["ok"] is True + mock_chat_slack_client.chat_postMessage.assert_called_once_with(channel="abc123", text="Hello!") + + +@pytest.mark.asyncio +async def test_send_message_to_username( + mock_context, + mock_chat_slack_client, + mock_user_retrieval_slack_client, +): + mock_chat_slack_client.auth_test.return_value = {"ok": True, "user_id": "current_user_id"} + mock_user_retrieval_slack_client.users_list.side_effect = [ + { + "ok": True, + "members": [{"name": "foo", "id": "bar"}], + "response_metadata": {"next_cursor": "123"}, + }, + { + "ok": True, + "members": [{"name": "foobar", "id": "foobar_user_id"}], + }, + ] mock_chat_slack_client.conversations_open.return_value = { "ok": True, - "channel": {"id": "D12345"}, + "channel": { + "id": "conversation_id", + "is_im": True, + }, } mock_slack_response = Mock(spec=AsyncSlackResponse) mock_slack_response.data = {"ok": True} mock_chat_slack_client.chat_postMessage.return_value = mock_slack_response - response = await send_dm_to_user(mock_context, "testuser", "Hello!") + response = await send_message( + context=mock_context, + usernames=["foobar"], + message="Hello, world!", + ) - assert response["response"]["ok"] is True - mock_chat_slack_client.users_list.assert_called_once() - mock_chat_slack_client.conversations_open.assert_called_once_with(users=["U12345"]) - mock_chat_slack_client.chat_postMessage.assert_called_once_with(channel="D12345", text="Hello!") + assert response["success"] is True + assert response["data"]["ok"] is True + + mock_chat_slack_client.auth_test.assert_called_once() + assert mock_user_retrieval_slack_client.users_list.call_count == 2 + mock_chat_slack_client.conversations_open.assert_called_once_with( + users=[ + "current_user_id", + "foobar_user_id", + ], + return_im=True, + ) + mock_chat_slack_client.chat_postMessage.assert_called_once_with( + channel="conversation_id", + text="Hello, world!", + ) @pytest.mark.asyncio -async def test_send_dm_to_inexistent_user(mock_context, mock_chat_slack_client): - mock_chat_slack_client.users_list.return_value = { +async def test_send_dm_to_inexistent_user( + mock_context, + mock_chat_slack_client, + mock_user_retrieval_slack_client, +): + mock_chat_slack_client.auth_test.return_value = {"ok": True, "user_id": "current_user_id"} + mock_user_retrieval_slack_client.users_list.return_value = { "ok": True, - "members": [{"name": "testuser", "id": "U12345"}], + "members": [{"name": "foo", "id": "bar"}], } - with pytest.raises(RetryableToolError): - await send_dm_to_user(mock_context, "inexistent_user", "Hello!") + with pytest.raises(RetryableToolError) as error: + await send_message(mock_context, usernames=["inexistent_user"], message="Hello!") - mock_chat_slack_client.users_list.assert_called_once() + assert "inexistent_user" in error.value.message + assert "foo" in error.value.additional_prompt_content + assert "bar" in error.value.additional_prompt_content + mock_user_retrieval_slack_client.users_list.assert_called_once() mock_chat_slack_client.conversations_open.assert_not_called() mock_chat_slack_client.chat_postMessage.assert_not_called() @pytest.mark.asyncio -async def test_send_message_to_channel(mock_context, mock_chat_slack_client): - mock_chat_slack_client.conversations_list.return_value = { +async def test_send_message_to_channel_success( + mock_context, + mock_chat_slack_client, + mock_conversation_retrieval_slack_client, +): + mock_conversation_retrieval_slack_client.conversations_list.return_value = { "ok": True, - "channels": [{"id": "C12345", "name": "general", "is_member": True, "is_group": True}], + "channels": [{"id": "channel_id", "name": "general", "is_member": True, "is_group": True}], } mock_slack_response = Mock(spec=AsyncSlackResponse) mock_slack_response.data = {"ok": True} mock_chat_slack_client.chat_postMessage.return_value = mock_slack_response - response = await send_message_to_channel(mock_context, "general", "Hello, channel!") + response = await send_message(mock_context, channel_name="general", message="Hello, channel!") - assert response["response"]["ok"] is True - mock_chat_slack_client.conversations_list.assert_called_once() + assert response["success"] is True + assert response["data"]["ok"] is True + mock_conversation_retrieval_slack_client.conversations_list.assert_called_once() mock_chat_slack_client.chat_postMessage.assert_called_once_with( - channel="C12345", text="Hello, channel!" + channel="channel_id", text="Hello, channel!" ) @pytest.mark.asyncio -async def test_send_message_to_inexistent_channel(mock_context, mock_chat_slack_client): - mock_chat_slack_client.conversations_list.return_value = { +async def test_send_message_to_inexistent_channel( + mock_context, + mock_chat_slack_client, + mock_conversation_retrieval_slack_client, +): + mock_conversation_retrieval_slack_client.conversations_list.return_value = { "ok": True, - "channels": [], + "channels": [ + { + "id": "another_channel_id", + "name": "another_channel", + "is_member": True, + "is_group": True, + } + ], } - with pytest.raises(RetryableToolError): - await send_message_to_channel(mock_context, "inexistent_channel", "Hello!") + with pytest.raises(RetryableToolError) as error: + await send_message(mock_context, channel_name="inexistent_channel", message="Hello!") - mock_chat_slack_client.conversations_list.assert_called_once() + assert "inexistent_channel" in error.value.message + assert "another_channel" in error.value.additional_prompt_content + assert "another_channel_id" in error.value.additional_prompt_content + + mock_conversation_retrieval_slack_client.conversations_list.assert_called_once() mock_chat_slack_client.chat_postMessage.assert_not_called() @@ -120,19 +183,44 @@ async def test_list_conversations_metadata_with_default_args( "channels": [mock_channel_info], } - response = await list_conversations_metadata(mock_context) + response = await list_conversations(mock_context) assert response["conversations"] == [extract_conversation_metadata(mock_channel_info)] assert response["next_cursor"] is None mock_chat_slack_client.conversations_list.assert_called_once_with( - types=",".join([conv_type.value for conv_type in ConversationTypeSlackName]), + types=None, exclude_archived=True, limit=MAX_PAGINATION_SIZE_LIMIT, cursor=None, ) +@pytest.mark.asyncio +async def test_list_conversations_metadata_with_more_pages( + mock_context, mock_chat_slack_client, dummy_channel_factory, random_str_factory +): + channel1 = dummy_channel_factory(is_channel=True) + channel2 = dummy_channel_factory(is_im=True) + channel3 = dummy_channel_factory(is_mpim=True) + next_cursor = random_str_factory() + + mock_chat_slack_client.conversations_list.return_value = { + "ok": True, + "channels": [channel1, channel2, channel3], + "response_metadata": {"next_cursor": next_cursor}, + } + + response = await list_conversations(mock_context, limit=3) + + assert response["conversations"] == [ + extract_conversation_metadata(channel1), + extract_conversation_metadata(channel2), + extract_conversation_metadata(channel3), + ] + assert response["next_cursor"] == next_cursor + + @pytest.mark.asyncio async def test_list_conversations_metadata_filtering_single_conversation_type( mock_context, mock_chat_slack_client, mock_channel_info @@ -142,8 +230,8 @@ async def test_list_conversations_metadata_filtering_single_conversation_type( "channels": [mock_channel_info], } - response = await list_conversations_metadata( - mock_context, conversation_types=ConversationType.PUBLIC_CHANNEL + response = await list_conversations( + mock_context, conversation_types=[ConversationType.PUBLIC_CHANNEL] ) assert response["conversations"] == [extract_conversation_metadata(mock_channel_info)] @@ -166,11 +254,11 @@ async def test_list_conversations_metadata_filtering_multiple_conversation_types "channels": [mock_channel_info], } - response = await list_conversations_metadata( + response = await list_conversations( mock_context, conversation_types=[ - ConversationTypeSlackName.PUBLIC_CHANNEL, - ConversationTypeSlackName.PRIVATE_CHANNEL, + ConversationType.PUBLIC_CHANNEL, + ConversationType.PRIVATE_CHANNEL, ], ) @@ -195,7 +283,7 @@ async def test_list_conversations_metadata_with_custom_pagination_args( "response_metadata": {"next_cursor": "456"}, } - response = await list_conversations_metadata(mock_context, limit=3, next_cursor="123") + response = await list_conversations(mock_context, limit=3, next_cursor="123") assert response["conversations"] == [ extract_conversation_metadata(mock_channel_info) for _ in range(3) @@ -203,7 +291,7 @@ async def test_list_conversations_metadata_with_custom_pagination_args( assert response["next_cursor"] == "456" mock_chat_slack_client.conversations_list.assert_called_once_with( - types=",".join([conv_type.value for conv_type in ConversationTypeSlackName]), + types=None, exclude_archived=True, limit=3, cursor="123", @@ -214,64 +302,37 @@ async def test_list_conversations_metadata_with_custom_pagination_args( @pytest.mark.parametrize( "faulty_slack_function_name, tool_function, tool_args", [ - ("users_list", send_dm_to_user, ("testuser", "Hello!")), - ("conversations_list", send_message_to_channel, ("general", "Hello!")), + ("users_list", send_message, {"usernames": ["testuser"], "message": "Hello!"}), + ("conversations_list", send_message, {"channel_name": "general", "message": "Hello!"}), ], ) async def test_tools_with_slack_error( mock_context, mock_chat_slack_client, faulty_slack_function_name, tool_function, tool_args ): + mock_chat_slack_client.auth_test.return_value = {"ok": True, "user_id": "current_user_id"} getattr(mock_chat_slack_client, faulty_slack_function_name).side_effect = SlackApiError( message="test_slack_error", response={"ok": False, "error": "test_slack_error"}, ) with pytest.raises(ToolExecutionError) as e: - await tool_function(mock_context, *tool_args) + await tool_function(mock_context, **tool_args) assert "test_slack_error" in str(e.value) -@pytest.mark.asyncio -@pytest.mark.parametrize( - "tool_function, conversation_type", - [ - (list_public_channels_metadata, ConversationType.PUBLIC_CHANNEL), - (list_private_channels_metadata, ConversationType.PRIVATE_CHANNEL), - ( - list_group_direct_message_conversations_metadata, - ConversationType.MULTI_PERSON_DIRECT_MESSAGE, - ), - (list_direct_message_conversations_metadata, ConversationType.DIRECT_MESSAGE), - ], -) -async def test_list_channels_metadata( - mock_context, - mock_list_conversations_metadata, - tool_function, - conversation_type, -): - response = await tool_function(mock_context, limit=3) - - mock_list_conversations_metadata.assert_called_once_with( - mock_context, conversation_types=[conversation_type], limit=3 - ) - - assert response == mock_list_conversations_metadata.return_value - - @pytest.mark.asyncio async def test_get_conversation_metadata_by_id( - mock_context, mock_chat_slack_client, mock_channel_info + mock_context, mock_conversation_retrieval_slack_client, mock_channel_info ): - mock_chat_slack_client.conversations_info.return_value = { + mock_conversation_retrieval_slack_client.conversations_info.return_value = { "ok": True, "channel": mock_channel_info, } - response = await get_conversation_metadata_by_id(mock_context, "C12345") + response = await get_conversation_metadata(mock_context, conversation_id="C12345") assert response == extract_conversation_metadata(mock_channel_info) - mock_chat_slack_client.conversations_info.assert_called_once_with( + mock_conversation_retrieval_slack_client.conversations_info.assert_called_once_with( channel="C12345", include_locale=True, include_num_members=True, @@ -279,304 +340,618 @@ async def test_get_conversation_metadata_by_id( @pytest.mark.asyncio -@patch("arcade_slack.tools.chat.list_conversations_metadata") async def test_get_conversation_metadata_by_id_slack_api_error( - mock_list_conversations_metadata, mock_context, mock_chat_slack_client, mock_channel_info + mock_context, + mock_conversation_retrieval_slack_client, + mock_channel_info, ): - mock_channel_info["name"] = "whatever_conversation_should_be_present_in_additional_prompt" - mock_list_conversations_metadata.return_value = { - "conversations": [extract_conversation_metadata(mock_channel_info)], - "response_metadata": {"next_cursor": None}, - } - mock_chat_slack_client.conversations_info.side_effect = SlackApiError( + mock_conversation_retrieval_slack_client.conversations_info.side_effect = SlackApiError( message="channel_not_found", response={"ok": False, "error": "channel_not_found"}, ) - with pytest.raises(RetryableToolError) as e: - await get_conversation_metadata_by_id(mock_context, "C12345") + with pytest.raises(ToolExecutionError) as e: + await get_conversation_metadata(mock_context, conversation_id="C12345") - assert ( - "whatever_conversation_should_be_present_in_additional_prompt" - in e.additional_prompt_content - ) - - mock_chat_slack_client.conversations_info.assert_called_once_with( - channel="C12345", - include_locale=True, - include_num_members=True, - ) - mock_list_conversations_metadata.assert_called_once_with(mock_context) + assert "C12345" in e.value.message + assert "not found" in e.value.message @pytest.mark.asyncio -async def test_get_conversation_metadata_by_name( - mock_context, mock_list_conversations_metadata, mock_channel_info +async def test_get_conversation_metadata_by_channel_name( + mock_context, + mock_conversation_retrieval_slack_client, + dummy_channel_factory, + random_str_factory, ): - sample_conversation = extract_conversation_metadata(mock_channel_info) - mock_list_conversations_metadata.return_value = { - "conversations": [sample_conversation], - "next_cursor": None, + channel_name = random_str_factory() + channel1 = dummy_channel_factory(is_channel=True, name=f"{channel_name}_another_channel") + channel2 = dummy_channel_factory(is_channel=True, name=channel_name) + + mock_conversation_retrieval_slack_client.conversations_list.return_value = { + "ok": True, + "channels": [channel1, channel2], } - response = await get_channel_metadata_by_name(mock_context, sample_conversation["name"]) + response = await get_conversation_metadata(mock_context, channel_name=channel_name) - assert response == sample_conversation - mock_list_conversations_metadata.assert_called_once_with( - context=mock_context, - conversation_types=[ - ConversationType.PUBLIC_CHANNEL, - ConversationType.PRIVATE_CHANNEL, - ], - next_cursor=None, + assert response == extract_conversation_metadata(channel2) + mock_conversation_retrieval_slack_client.conversations_list.assert_called_once_with( + types=f"{ConversationTypeSlackName.PUBLIC_CHANNEL.value},{ConversationTypeSlackName.PRIVATE_CHANNEL.value}", + exclude_archived=True, + limit=MAX_PAGINATION_SIZE_LIMIT, + cursor=None, ) @pytest.mark.asyncio -async def test_get_channel_metadata_by_name_triggering_pagination( - mock_context, mock_list_conversations_metadata, mock_channel_info +async def test_get_conversation_metadata_by_channel_name_triggering_pagination( + mock_context, + mock_conversation_retrieval_slack_client, + dummy_channel_factory, + random_str_factory, ): - target_channel = extract_conversation_metadata(mock_channel_info) - another_channel = extract_conversation_metadata(mock_channel_info) - another_channel["name"] = "another_channel" + target_channel_name = random_str_factory() + target_channel = dummy_channel_factory(is_channel=True, name=target_channel_name) + another_channel = dummy_channel_factory( + is_channel=True, name=f"{target_channel_name}_another_channel" + ) - mock_list_conversations_metadata.side_effect = [ + mock_conversation_retrieval_slack_client.conversations_list.side_effect = [ { - "conversations": [another_channel], - "next_cursor": "123", + "ok": True, + "channels": [another_channel], + "response_metadata": {"next_cursor": "123"}, }, { - "conversations": [target_channel], - "next_cursor": None, + "ok": True, + "channels": [target_channel], + "response_metadata": {"next_cursor": None}, }, ] - response = await get_channel_metadata_by_name(mock_context, target_channel["name"]) + response = await get_conversation_metadata(mock_context, channel_name=target_channel_name) - assert response == target_channel - assert mock_list_conversations_metadata.call_count == 2 - mock_list_conversations_metadata.assert_has_calls([ + assert response == extract_conversation_metadata(target_channel) + assert mock_conversation_retrieval_slack_client.conversations_list.call_count == 2 + mock_conversation_retrieval_slack_client.conversations_list.assert_has_calls([ call( - context=mock_context, - conversation_types=[ConversationType.PUBLIC_CHANNEL, ConversationType.PRIVATE_CHANNEL], - next_cursor=None, + types=f"{ConversationTypeSlackName.PUBLIC_CHANNEL.value},{ConversationTypeSlackName.PRIVATE_CHANNEL.value}", + exclude_archived=True, + limit=MAX_PAGINATION_SIZE_LIMIT, + cursor=None, ), call( - context=mock_context, - conversation_types=[ConversationType.PUBLIC_CHANNEL, ConversationType.PRIVATE_CHANNEL], - next_cursor="123", + types=f"{ConversationTypeSlackName.PUBLIC_CHANNEL.value},{ConversationTypeSlackName.PRIVATE_CHANNEL.value}", + exclude_archived=True, + limit=MAX_PAGINATION_SIZE_LIMIT, + cursor="123", ), ]) @pytest.mark.asyncio -async def test_get_channel_metadata_by_name_not_found( - mock_context, mock_list_conversations_metadata, mock_channel_info +async def test_get_conversation_metadata_by_channel_name_not_found( + mock_context, + mock_conversation_retrieval_slack_client, + dummy_channel_factory, + random_str_factory, ): - first_channel = extract_conversation_metadata(mock_channel_info) - second_channel = extract_conversation_metadata(mock_channel_info) - second_channel["name"] = "second_channel" + not_found_name = random_str_factory() + channel1 = dummy_channel_factory(is_channel=True, name=f"{not_found_name}_first") + channel2 = dummy_channel_factory(is_channel=True, name=f"{not_found_name}_second") - mock_list_conversations_metadata.side_effect = [ + mock_conversation_retrieval_slack_client.conversations_list.side_effect = [ { - "conversations": [second_channel], - "next_cursor": "123", + "ok": True, + "channels": [channel1], + "response_metadata": {"next_cursor": "123"}, }, { - "conversations": [first_channel], - "next_cursor": None, + "ok": True, + "channels": [channel2], + "response_metadata": {"next_cursor": None}, }, ] - with pytest.raises(RetryableToolError): - await get_channel_metadata_by_name(mock_context, "inexistent_channel") + with pytest.raises(RetryableToolError) as error: + await get_conversation_metadata(mock_context, channel_name=not_found_name) - assert mock_list_conversations_metadata.call_count == 2 - mock_list_conversations_metadata.assert_has_calls([ + assert "not found" in error.value.message + assert not_found_name in error.value.message + assert ( + json.dumps([ + {"id": channel1["id"], "name": channel1["name"]}, + {"id": channel2["id"], "name": channel2["name"]}, + ]) + in error.value.additional_prompt_content + ) + + assert mock_conversation_retrieval_slack_client.conversations_list.call_count == 2 + mock_conversation_retrieval_slack_client.conversations_list.assert_has_calls([ call( - context=mock_context, - conversation_types=[ConversationType.PUBLIC_CHANNEL, ConversationType.PRIVATE_CHANNEL], - next_cursor=None, + types=f"{ConversationTypeSlackName.PUBLIC_CHANNEL.value},{ConversationTypeSlackName.PRIVATE_CHANNEL.value}", + exclude_archived=True, + limit=MAX_PAGINATION_SIZE_LIMIT, + cursor=None, ), call( - context=mock_context, - conversation_types=[ConversationType.PUBLIC_CHANNEL, ConversationType.PRIVATE_CHANNEL], - next_cursor="123", + types=f"{ConversationTypeSlackName.PUBLIC_CHANNEL.value},{ConversationTypeSlackName.PRIVATE_CHANNEL.value}", + exclude_archived=True, + limit=MAX_PAGINATION_SIZE_LIMIT, + cursor="123", ), ]) @pytest.mark.asyncio -@patch("arcade_slack.tools.chat.async_paginate") -@patch("arcade_slack.tools.chat.get_user_info_by_id") -async def test_get_members_from_conversation_id( - mock_get_user_info_by_id, mock_async_paginate, mock_context, mock_chat_slack_client +async def test_get_conversation_metadata_by_username( + mock_context, + mock_chat_slack_client, + mock_user_retrieval_slack_client, + dummy_user_factory, + dummy_channel_factory, ): - member1 = {"id": "U123", "name": "testuser123"} - member1_info = extract_basic_user_info(member1) - member2 = {"id": "U456", "name": "testuser456"} - member2_info = extract_basic_user_info(member2) + current_user = dummy_user_factory(id_="U1", name="current_user") + other_user = dummy_user_factory(id_="U2", name="other_user") + conversation = dummy_channel_factory(is_im=True) - mock_async_paginate.return_value = [member1["id"], member2["id"]], "token123" - mock_get_user_info_by_id.side_effect = [member1_info, member2_info] + mock_chat_slack_client.auth_test.return_value = { + "ok": True, + "user_id": current_user["id"], + } - response = await get_members_in_conversation_by_id( - mock_context, conversation_id="C12345", limit=2 + mock_user_retrieval_slack_client.users_list.return_value = { + "ok": True, + "members": [current_user, other_user], + "response_metadata": {"next_cursor": None}, + } + + mock_chat_slack_client.conversations_open.return_value = { + "ok": True, + "channel": conversation, + } + + response = await get_conversation_metadata(mock_context, usernames=[other_user["name"]]) + + assert response == extract_conversation_metadata(conversation) + + +@pytest.mark.asyncio +async def test_get_dm_conversation_metadata_by_username_not_found( + mock_context, + mock_chat_slack_client, + mock_user_retrieval_slack_client, + dummy_user_factory, + dummy_channel_factory, + random_str_factory, +): + current_user = dummy_user_factory(id_="U1", name="current_user") + other_user = dummy_user_factory(id_="U2", name="other_user") + conversation = dummy_channel_factory(is_im=True) + not_found_user_name = random_str_factory() + + mock_chat_slack_client.auth_test.return_value = { + "ok": True, + "user_id": current_user["id"], + } + + mock_user_retrieval_slack_client.users_list.return_value = { + "ok": True, + "members": [current_user, other_user], + "response_metadata": {"next_cursor": None}, + } + + mock_chat_slack_client.conversations_open.return_value = { + "ok": True, + "channel": conversation, + } + + with pytest.raises(RetryableToolError) as error: + await get_conversation_metadata(mock_context, usernames=[not_found_user_name]) + + assert "not found" in error.value.message + assert not_found_user_name in error.value.message + assert other_user["id"] in error.value.additional_prompt_content + assert other_user["name"] in error.value.additional_prompt_content + + mock_chat_slack_client.conversations_open.assert_not_called() + + +@pytest.mark.asyncio +async def test_get_mpim_conversation_metadata_by_usernames( + mock_context, + mock_chat_slack_client, + mock_user_retrieval_slack_client, + dummy_user_factory, + dummy_channel_factory, +): + current_user = dummy_user_factory(id_="U1", name="current_user") + other_user1 = dummy_user_factory(id_="U2", name="other_user1") + other_user2 = dummy_user_factory(id_="U3", name="other_user2") + other_user3 = dummy_user_factory(id_="U4", name="other_user3") + other_user4 = dummy_user_factory(id_="U5", name="other_user4") + + conversation = dummy_channel_factory(is_mpim=True) + + mock_chat_slack_client.auth_test.return_value = { + "ok": True, + "user_id": current_user["id"], + } + + mock_user_retrieval_slack_client.users_list.side_effect = [ + { + "ok": True, + "members": [current_user, other_user1, other_user3], + "response_metadata": {"next_cursor": "users_list_cursor1"}, + }, + { + "ok": True, + "members": [current_user, other_user2, other_user4], + "response_metadata": {"next_cursor": None}, + }, + ] + + mock_chat_slack_client.conversations_open.return_value = { + "ok": True, + "channel": conversation, + } + + response = await get_conversation_metadata( + mock_context, + usernames=[other_user1["name"], other_user2["name"]], + ) + + assert response == extract_conversation_metadata(conversation) + + mock_chat_slack_client.conversations_open.assert_called_once_with( + users=[current_user["id"], other_user1["id"], other_user2["id"]], + return_im=True, + ) + + +@pytest.mark.asyncio +async def test_get_mpim_conversation_metadata_by_user_ids_and_usernames( + mock_context, + mock_chat_slack_client, + mock_user_retrieval_slack_client, + dummy_user_factory, + dummy_channel_factory, +): + current_user = dummy_user_factory(id_="U1", name="current_user") + other_user1 = dummy_user_factory(id_="U2", name="other_user1") + other_user2 = dummy_user_factory(id_="U3", name="other_user2") + other_user3 = dummy_user_factory(id_="U4", name="other_user3") + other_user4 = dummy_user_factory(id_="U5", name="other_user4") + + conversation = dummy_channel_factory(is_mpim=True) + + mock_chat_slack_client.auth_test.return_value = { + "ok": True, + "user_id": current_user["id"], + } + + mock_user_retrieval_slack_client.users_list.side_effect = [ + { + "ok": True, + "members": [current_user, other_user1, other_user3], + "response_metadata": {"next_cursor": "users_list_cursor1"}, + }, + { + "ok": True, + "members": [current_user, other_user2, other_user4], + "response_metadata": {"next_cursor": None}, + }, + ] + + mock_chat_slack_client.conversations_open.return_value = { + "ok": True, + "channel": conversation, + } + + response = await get_conversation_metadata( + mock_context, + user_ids=[other_user3["id"]], + usernames=[other_user1["name"], other_user2["name"]], + ) + + assert response == extract_conversation_metadata(conversation) + + mock_chat_slack_client.conversations_open.assert_called_once_with( + users=[other_user3["id"], current_user["id"], other_user1["id"], other_user2["id"]], + return_im=True, + ) + + +@pytest.mark.asyncio +async def test_get_mpim_conversation_metadata_by_user_ids_usernames_and_emails( + mock_context, + mock_chat_slack_client, + mock_user_retrieval_slack_client, + dummy_user_factory, + dummy_channel_factory, +): + current_user = dummy_user_factory(id_="U1", name="current_user") + other_user1 = dummy_user_factory(id_="U2", name="other_user1") + other_user2 = dummy_user_factory(id_="U3", name="other_user2") + other_user3 = dummy_user_factory(id_="U4", name="other_user3") + other_user4 = dummy_user_factory(id_="U5", name="other_user4") + other_user5 = dummy_user_factory(id_="U6", name="other_user5") + + conversation = dummy_channel_factory(is_mpim=True) + + mock_chat_slack_client.auth_test.return_value = { + "ok": True, + "user_id": current_user["id"], + } + + mock_user_retrieval_slack_client.users_list.side_effect = [ + { + "ok": True, + "members": [current_user, other_user1, other_user3], + "response_metadata": {"next_cursor": "users_list_cursor1"}, + }, + { + "ok": True, + "members": [current_user, other_user2, other_user4], + "response_metadata": {"next_cursor": None}, + }, + ] + + mock_user_retrieval_slack_client.users_lookupByEmail.side_effect = [ + { + "ok": True, + "user": other_user5, + }, + ] + + mock_chat_slack_client.conversations_open.return_value = { + "ok": True, + "channel": conversation, + } + + response = await get_conversation_metadata( + mock_context, + user_ids=[other_user3["id"]], + usernames=[other_user1["name"], other_user2["name"]], + emails=[other_user5["profile"]["email"]], + ) + + assert response == extract_conversation_metadata(conversation) + + mock_chat_slack_client.conversations_open.assert_called_once_with( + users=[ + other_user3["id"], + current_user["id"], + other_user1["id"], + other_user2["id"], + other_user5["id"], + ], + return_im=True, + ) + + +@pytest.mark.asyncio +async def test_get_users_in_conversation_by_id_with_conversation_and_user_paginations( + mock_context, + mock_chat_slack_client, + mock_user_retrieval_slack_client, + dummy_user_factory, + random_str_factory, +): + conversation_id = random_str_factory() + user1 = dummy_user_factory(id_="1") + user2 = dummy_user_factory(id_="2") + user3 = dummy_user_factory(id_="3") + + mock_chat_slack_client.conversations_members.side_effect = [ + { + "ok": True, + "members": [user1["id"], user2["id"]], + "response_metadata": {"next_cursor": "conversations_members_cursor1"}, + }, + { + "ok": True, + "members": [user3["id"]], + "response_metadata": {"next_cursor": "conversations_members_cursor2"}, + }, + ] + + mock_user_retrieval_slack_client.users_list.side_effect = [ + { + "ok": True, + "members": [user1, user2], + "response_metadata": {"next_cursor": "users_list_cursor"}, + }, + { + "ok": True, + "members": [user3], + "response_metadata": {"next_cursor": None}, + }, + ] + + response = await get_users_in_conversation( + context=mock_context, + conversation_id=conversation_id, + limit=3, ) assert response == { - "members": [member1_info, member2_info], - "next_cursor": "token123", + "users": [ + cast_user_dict(user1), + cast_user_dict(user2), + cast_user_dict(user3), + ], + "next_cursor": "conversations_members_cursor2", } - mock_async_paginate.assert_called_once_with( - mock_chat_slack_client.conversations_members, - "members", - limit=2, - next_cursor=None, - channel="C12345", - ) - mock_get_user_info_by_id.assert_has_calls([ - call(mock_context, member1["id"]), - call(mock_context, member2["id"]), + + mock_chat_slack_client.conversations_members.assert_has_calls([ + call( + channel=conversation_id, + limit=3, + cursor=None, + ), + call( + channel=conversation_id, + limit=1, + cursor="conversations_members_cursor1", + ), + ]) + + mock_user_retrieval_slack_client.users_list.assert_has_calls([ + call( + limit=MAX_PAGINATION_SIZE_LIMIT, + cursor=None, + ), + call( + limit=MAX_PAGINATION_SIZE_LIMIT, + cursor="users_list_cursor", + ), ]) @pytest.mark.asyncio -@patch("arcade_slack.tools.chat.async_paginate") -@patch("arcade_slack.tools.chat.get_user_info_by_id") -@patch("arcade_slack.tools.chat.list_conversations_metadata") -async def test_get_members_from_conversation_id_channel_not_found( - mock_list_conversations_metadata, - mock_get_user_info_by_id, - mock_async_paginate, +async def test_get_users_in_conversation_by_channel_name( mock_context, mock_chat_slack_client, - mock_channel_info, + mock_conversation_retrieval_slack_client, + mock_user_retrieval_slack_client, + dummy_channel_factory, + dummy_user_factory, + random_str_factory, ): - conversations = [extract_conversation_metadata(mock_channel_info)] * 2 - mock_list_conversations_metadata.return_value = { - "conversations": conversations, - "next_cursor": None, - } + channel_name = random_str_factory() + channel1 = dummy_channel_factory(is_channel=True, name=f"{channel_name}_another_channel") + channel2 = dummy_channel_factory(is_channel=True, name=channel_name) - member1 = {"id": "U123", "name": "testuser123"} - member1_info = extract_basic_user_info(member1) - member2 = {"id": "U456", "name": "testuser456"} - member2_info = extract_basic_user_info(member2) - - mock_async_paginate.side_effect = SlackApiError( - message="channel_not_found", - response={"ok": False, "error": "channel_not_found"}, - ) - mock_get_user_info_by_id.side_effect = [member1_info, member2_info] - - with pytest.raises(RetryableToolError): - await get_members_in_conversation_by_id(mock_context, conversation_id="C12345", limit=2) - - mock_async_paginate.assert_called_once_with( - mock_chat_slack_client.conversations_members, - "members", - limit=2, - next_cursor=None, - channel="C12345", - ) - mock_get_user_info_by_id.assert_not_called() - - -@pytest.mark.asyncio -@patch("arcade_slack.tools.chat.list_conversations_metadata") -@patch("arcade_slack.tools.chat.get_members_in_conversation_by_id") -async def test_get_members_in_channel_by_name( - mock_get_members_in_conversation_by_id, - mock_list_conversations_metadata, - mock_context, - mock_channel_info, -): - mock_list_conversations_metadata.return_value = { - "conversations": [extract_conversation_metadata(mock_channel_info)], - "next_cursor": None, - } - - response = await get_members_in_channel_by_name( - mock_context, mock_channel_info["name"], limit=2 - ) - - assert response == mock_get_members_in_conversation_by_id.return_value - mock_list_conversations_metadata.assert_called_once_with( - context=mock_context, - conversation_types=[ - ConversationType.PUBLIC_CHANNEL, - ConversationType.PRIVATE_CHANNEL, - ], - next_cursor=None, - ) - mock_get_members_in_conversation_by_id.assert_called_once_with( - context=mock_context, - conversation_id="C12345", - limit=2, - next_cursor=None, - ) - - -@pytest.mark.asyncio -@patch("arcade_slack.tools.chat.list_conversations_metadata") -@patch("arcade_slack.tools.chat.get_members_in_conversation_by_id") -async def test_get_members_in_channel_by_name_triggering_pagination( - mock_get_members_in_conversation_by_id, - mock_list_conversations_metadata, - mock_context, - mock_channel_info, -): - conversation1 = copy.deepcopy(mock_channel_info) - conversation1["name"] = "conversation1" - conversation2 = copy.deepcopy(mock_channel_info) - conversation2["name"] = "conversation2" - - mock_list_conversations_metadata.side_effect = [ + mock_conversation_retrieval_slack_client.conversations_list.side_effect = [ { - "conversations": [extract_conversation_metadata(conversation1)], - "next_cursor": "123", + "ok": True, + "channels": [channel1], + "response_metadata": {"next_cursor": "123"}, }, { - "conversations": [extract_conversation_metadata(conversation2)], - "next_cursor": None, + "ok": True, + "channels": [channel2], + "response_metadata": {"next_cursor": None}, }, ] - response = await get_members_in_channel_by_name(mock_context, conversation2["name"], limit=2) + user1 = dummy_user_factory(id_="1") + user2 = dummy_user_factory(id_="2") - assert response == mock_get_members_in_conversation_by_id.return_value - mock_list_conversations_metadata.assert_has_calls([ - call( - context=mock_context, - conversation_types=[ConversationType.PUBLIC_CHANNEL, ConversationType.PRIVATE_CHANNEL], - next_cursor=None, - ), - call( - context=mock_context, - conversation_types=[ConversationType.PUBLIC_CHANNEL, ConversationType.PRIVATE_CHANNEL], - next_cursor="123", - ), - ]) - mock_get_members_in_conversation_by_id.assert_called_once_with( - context=mock_context, - conversation_id="C12345", - limit=2, - next_cursor=None, - ) + mock_chat_slack_client.conversations_members.side_effect = [ + { + "ok": True, + "members": [user1["id"], user2["id"]], + "response_metadata": {"next_cursor": None}, + }, + ] + + mock_user_retrieval_slack_client.users_list.side_effect = [ + { + "ok": True, + "members": [user1, user2], + "response_metadata": {"next_cursor": None}, + }, + ] + + response = await get_users_in_conversation(mock_context, channel_name=channel_name) + + assert response == { + "users": [ + cast_user_dict(user1), + cast_user_dict(user2), + ], + "next_cursor": None, + } @pytest.mark.asyncio -async def test_get_conversation_history_by_id(mock_context, mock_chat_slack_client): - mock_chat_slack_client.conversations_history.return_value = { +async def test_get_users_in_conversation_by_channel_name_not_found( + mock_context, + mock_conversation_retrieval_slack_client, + dummy_channel_factory, + random_str_factory, +): + not_found_channel_name = random_str_factory() + channel1 = dummy_channel_factory(is_channel=True, name=f"{not_found_channel_name}_first") + channel2 = dummy_channel_factory(is_channel=True, name=f"{not_found_channel_name}_second") + + mock_conversation_retrieval_slack_client.conversations_list.side_effect = [ + { + "ok": True, + "channels": [channel1], + "response_metadata": {"next_cursor": "123"}, + }, + { + "ok": True, + "channels": [channel2], + "response_metadata": {"next_cursor": None}, + }, + ] + + with pytest.raises(RetryableToolError) as error: + await get_users_in_conversation(mock_context, channel_name=not_found_channel_name) + + assert "not found" in error.value.message + assert not_found_channel_name in error.value.message + assert ( + json.dumps([ + {"id": channel1["id"], "name": channel1["name"]}, + {"id": channel2["id"], "name": channel2["name"]}, + ]) + in error.value.additional_prompt_content + ) + + assert mock_conversation_retrieval_slack_client.conversations_list.call_count == 2 + mock_conversation_retrieval_slack_client.conversations_list.assert_has_calls([ + call( + types=f"{ConversationTypeSlackName.PUBLIC_CHANNEL.value},{ConversationTypeSlackName.PRIVATE_CHANNEL.value}", + exclude_archived=True, + limit=MAX_PAGINATION_SIZE_LIMIT, + cursor=None, + ), + call( + types=f"{ConversationTypeSlackName.PUBLIC_CHANNEL.value},{ConversationTypeSlackName.PRIVATE_CHANNEL.value}", + exclude_archived=True, + limit=MAX_PAGINATION_SIZE_LIMIT, + cursor="123", + ), + ]) + + +@pytest.mark.asyncio +async def test_get_messages_by_conversation_id( + mock_context, + mock_message_retrieval_slack_client, + mock_user_retrieval_slack_client, + dummy_user_factory, + dummy_message_factory, +): + user = dummy_user_factory() + message = dummy_message_factory(user_id=user["id"]) + + mock_message_retrieval_slack_client.conversations_history.return_value = { "ok": True, - "messages": [{"text": "Hello, world!"}], + "messages": [message], + "response_metadata": {"next_cursor": "cursor"}, } - response = await get_messages_in_conversation_by_id(mock_context, "C12345", limit=1) + mock_user_retrieval_slack_client.users_info.return_value = { + "ok": True, + "user": user, + } - assert response == {"messages": [{"text": "Hello, world!"}], "next_cursor": None} - mock_chat_slack_client.conversations_history.assert_called_once_with( + response = await get_messages(mock_context, "C12345", limit=1) + + assert response["next_cursor"] == "cursor" + assert len(response["messages"]) == 1 + returned_message = response["messages"][0] + assert returned_message["user"] == {"id": user["id"], "name": user["name"]} + assert returned_message["text"] == message["text"] + + mock_message_retrieval_slack_client.conversations_history.assert_called_once_with( channel="C12345", include_all_metadata=True, inclusive=True, @@ -585,28 +960,34 @@ async def test_get_conversation_history_by_id(mock_context, mock_chat_slack_clie ) -# TODO: pass a current unix timestamp to the tool, instead of mocking the datetime -# conversion. Have to wait until arcade.core.annotations.Inferrable is implemented. @pytest.mark.asyncio -@patch("arcade_slack.tools.chat.convert_relative_datetime_to_unix_timestamp") -@patch("arcade_slack.tools.chat.datetime") -async def test_get_conversation_history_by_id_with_relative_datetime_args( +@patch("arcade_slack.message_retrieval.convert_relative_datetime_to_unix_timestamp") +@patch("arcade_slack.message_retrieval.datetime") +async def test_get_messages_by_conversation_id_with_relative_datetime_args( mock_datetime, mock_convert_relative_datetime_to_unix_timestamp, mock_context, - mock_chat_slack_client, + mock_message_retrieval_slack_client, + mock_user_retrieval_slack_client, + dummy_user_factory, + dummy_message_factory, ): - mock_chat_slack_client.conversations_history.return_value = { + user = dummy_user_factory() + message = dummy_message_factory(user_id=user["id"]) + + mock_message_retrieval_slack_client.conversations_history.return_value = { "ok": True, - "messages": [{"text": "Hello, world!"}], + "messages": [message], + } + + mock_user_retrieval_slack_client.users_info.return_value = { + "ok": True, + "user": user, } expected_oldest_timestamp = 1716489600 expected_latest_timestamp = 1716403200 - # Ideally we'd pass the current unix timestamp to the function, instead of mocking, but - # currently there's no way to have a tool argument that is not exposed to the LLM. We - # should have that soon, though. mock_datetime.now.return_value = datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc) expected_current_unix_timestamp = int(mock_datetime.now.return_value.timestamp()) mock_convert_relative_datetime_to_unix_timestamp.side_effect = [ @@ -614,16 +995,25 @@ async def test_get_conversation_history_by_id_with_relative_datetime_args( expected_oldest_timestamp, ] - response = await get_messages_in_conversation_by_id( - mock_context, "C12345", oldest_relative="02:00:00", latest_relative="01:00:00", limit=1 + response = await get_messages( + context=mock_context, + conversation_id="C12345", + oldest_relative="02:00:00", + latest_relative="01:00:00", + limit=1, ) - assert response == {"messages": [{"text": "Hello, world!"}], "next_cursor": None} + assert response["next_cursor"] is None + assert len(response["messages"]) == 1 + returned_message = response["messages"][0] + assert returned_message["user"] == {"id": user["id"], "name": user["name"]} + assert returned_message["text"] == message["text"] + mock_convert_relative_datetime_to_unix_timestamp.assert_has_calls([ call("01:00:00", expected_current_unix_timestamp), call("02:00:00", expected_current_unix_timestamp), ]) - mock_chat_slack_client.conversations_history.assert_called_once_with( + mock_message_retrieval_slack_client.conversations_history.assert_called_once_with( channel="C12345", include_all_metadata=True, inclusive=True, @@ -634,43 +1024,56 @@ async def test_get_conversation_history_by_id_with_relative_datetime_args( ) -# TODO: pass a current unix timestamp to the tool, instead of mocking the datetime -# conversion. Have to wait until arcade.core.annotations.Inferrable is implemented. @pytest.mark.asyncio -@patch("arcade_slack.tools.chat.convert_datetime_to_unix_timestamp") -async def test_get_conversation_history_by_id_with_absolute_datetime_args( - mock_convert_datetime_to_unix_timestamp, mock_context, mock_chat_slack_client +@patch("arcade_slack.message_retrieval.convert_datetime_to_unix_timestamp") +async def test_get_messages_by_conversation_id_with_absolute_datetime_args( + mock_convert_datetime_to_unix_timestamp, + mock_context, + mock_message_retrieval_slack_client, + mock_user_retrieval_slack_client, + dummy_user_factory, + dummy_message_factory, ): - mock_chat_slack_client.conversations_history.return_value = { + user = dummy_user_factory() + message = dummy_message_factory(user_id=user["id"]) + + mock_message_retrieval_slack_client.conversations_history.return_value = { "ok": True, - "messages": [{"text": "Hello, world!"}], + "messages": [message], + } + + mock_user_retrieval_slack_client.users_info.return_value = { + "ok": True, + "user": user, } expected_latest_timestamp = 1716403200 expected_oldest_timestamp = 1716489600 - # Ideally we'd pass the current unix timestamp to the function, instead of mocking, but - # currently there's no way to have a tool argument that is not exposed to the LLM. We - # should have that soon, though. mock_convert_datetime_to_unix_timestamp.side_effect = [ expected_latest_timestamp, expected_oldest_timestamp, ] - response = await get_messages_in_conversation_by_id( - mock_context, - "C12345", + response = await get_messages( + context=mock_context, + conversation_id="C12345", oldest_datetime="2025-01-01 00:00:00", latest_datetime="2025-01-02 00:00:00", limit=1, ) - assert response == {"messages": [{"text": "Hello, world!"}], "next_cursor": None} + assert response["next_cursor"] is None + assert len(response["messages"]) == 1 + returned_message = response["messages"][0] + assert returned_message["user"] == {"id": user["id"], "name": user["name"]} + assert returned_message["text"] == message["text"] + mock_convert_datetime_to_unix_timestamp.assert_has_calls([ call("2025-01-02 00:00:00"), call("2025-01-01 00:00:00"), ]) - mock_chat_slack_client.conversations_history.assert_called_once_with( + mock_message_retrieval_slack_client.conversations_history.assert_called_once_with( channel="C12345", include_all_metadata=True, inclusive=True, @@ -682,34 +1085,45 @@ async def test_get_conversation_history_by_id_with_absolute_datetime_args( @pytest.mark.asyncio -async def test_get_conversation_history_by_id_with_messed_oldest_args( - mock_context, mock_chat_slack_client +async def test_get_messages_by_conversation_id_with_messed_oldest_args( + mock_context, mock_message_retrieval_slack_client ): with pytest.raises(ToolExecutionError): - await get_messages_in_conversation_by_id( - mock_context, - "C12345", + await get_messages( + context=mock_context, + conversation_id="C12345", oldest_datetime="2025-01-01 00:00:00", oldest_relative="01:00:00", ) + mock_message_retrieval_slack_client.conversations_history.assert_not_called() + @pytest.mark.asyncio -async def test_get_conversation_history_by_id_with_messed_latest_args( - mock_context, mock_chat_slack_client +async def test_get_messages_by_conversation_id_with_messed_latest_args( + mock_context, mock_message_retrieval_slack_client ): with pytest.raises(ToolExecutionError): - await get_messages_in_conversation_by_id( - mock_context, - "C12345", + await get_messages( + context=mock_context, + conversation_id="C12345", latest_datetime="2025-01-01 00:00:00", latest_relative="01:00:00", ) + mock_message_retrieval_slack_client.conversations_history.assert_not_called() + @pytest.mark.asyncio -async def test_get_conversation_history_by_name(mock_context, mock_chat_slack_client): - mock_chat_slack_client.conversations_list.return_value = { +async def test_get_messages_by_channel_name( + mock_context, + mock_message_retrieval_slack_client, + mock_conversation_retrieval_slack_client, + mock_user_retrieval_slack_client, + dummy_message_factory, + dummy_user_factory, +): + mock_conversation_retrieval_slack_client.conversations_list.return_value = { "ok": True, "channels": [ { @@ -720,273 +1134,35 @@ async def test_get_conversation_history_by_name(mock_context, mock_chat_slack_cl } ], } - mock_chat_slack_client.conversations_history.return_value = { + + user = dummy_user_factory() + message = dummy_message_factory(user_id=user["id"]) + mock_message_retrieval_slack_client.conversations_history.return_value = { "ok": True, - "messages": [{"text": "Hello, world!"}], + "messages": [message], } - response = await get_messages_in_channel_by_name(mock_context, "general", limit=1) - - assert response == {"messages": [{"text": "Hello, world!"}], "next_cursor": None} - mock_chat_slack_client.conversations_history.assert_called_once_with( - channel="C12345", include_all_metadata=True, inclusive=True, limit=1, cursor=None - ) - - -@pytest.mark.asyncio -@patch("arcade_slack.tools.chat.retrieve_conversations_by_user_ids") -async def test_get_direct_message_conversation_metadata_by_username( - mock_retrieve_conversations_by_user_ids, - mock_context, - mock_chat_slack_client, - mock_users_slack_client, -): - mock_chat_slack_client.auth_test.return_value = { + mock_user_retrieval_slack_client.users_info.return_value = { "ok": True, - "user_id": "U1", - "team_id": "T1", - "user": "user1", + "user": user, } - mock_users_slack_client.users_list.return_value = { - "ok": True, - "members": [ - {"id": "U1", "name": "user1"}, - {"id": "U2", "name": "user2"}, - ], - "response_metadata": {"next_cursor": None}, - } - - conversation = { - "id": "C12345", - "type": ConversationTypeSlackName.IM.value, - "is_im": True, - "members": ["U1", "U2"], - } - - mock_retrieve_conversations_by_user_ids.return_value = [conversation] - - response = await get_direct_message_conversation_metadata_by_username( - context=mock_context, username="user2" - ) - - assert response == conversation - mock_retrieve_conversations_by_user_ids.assert_called_once_with( - list_conversations_func=list_conversations_metadata, - get_members_in_conversation_func=get_members_in_conversation_by_id, + response = await get_messages( context=mock_context, - conversation_types=[ConversationType.DIRECT_MESSAGE], - user_ids=["U1", "U2"], - exact_match=True, + channel_name="general", limit=1, - next_cursor=None, ) + assert response["next_cursor"] is None + assert len(response["messages"]) == 1 + returned_message = response["messages"][0] + assert returned_message["user"] == {"id": user["id"], "name": user["name"]} + assert returned_message["text"] == message["text"] -@pytest.mark.asyncio -@patch("arcade_slack.tools.chat.retrieve_conversations_by_user_ids") -async def test_get_direct_message_conversation_metadata_by_username_username_not_found( - mock_retrieve_conversations_by_user_ids, - mock_context, - mock_chat_slack_client, - mock_users_slack_client, -): - mock_chat_slack_client.users_identity.return_value = { - "ok": True, - "user": {"id": "U1", "name": "user1"}, - "team": {"id": "T1", "name": "team1"}, - } - - mock_users_slack_client.users_list.return_value = { - "ok": True, - "members": [ - {"id": "U1", "name": "user1"}, - {"id": "U2", "name": "user2"}, - ], - "response_metadata": {"next_cursor": None}, - } - - mock_retrieve_conversations_by_user_ids.side_effect = TimeoutError() - - with pytest.raises(RetryableToolError): - await get_direct_message_conversation_metadata_by_username( - context=mock_context, username="user999" - ) - - -@pytest.mark.asyncio -@patch("arcade_slack.tools.chat.get_messages_in_conversation_by_id") -@patch("arcade_slack.tools.chat.get_direct_message_conversation_metadata_by_username") -async def test_get_messages_in_direct_conversation_by_username( - mock_get_direct_message_conversation_metadata_by_username, - mock_get_messages_in_conversation_by_id, - mock_context, -): - mock_get_direct_message_conversation_metadata_by_username.return_value = { - "id": "C12345", - } - - response = await get_messages_in_direct_message_conversation_by_username( - context=mock_context, username="user2" - ) - - assert response == mock_get_messages_in_conversation_by_id.return_value - mock_get_direct_message_conversation_metadata_by_username.assert_called_once_with( - context=mock_context, username="user2" - ) - mock_get_messages_in_conversation_by_id.assert_called_once_with( - context=mock_context, - conversation_id="C12345", - oldest_relative=None, - latest_relative=None, - oldest_datetime=None, - latest_datetime=None, - limit=None, - next_cursor=None, - ) - - -@pytest.mark.asyncio -@patch("arcade_slack.tools.chat.get_direct_message_conversation_metadata_by_username") -async def test_get_messages_in_direct_conversation_by_username_not_found( - mock_get_direct_message_conversation_metadata_by_username, - mock_context, -): - mock_get_direct_message_conversation_metadata_by_username.return_value = None - - with pytest.raises(ToolExecutionError): - await get_messages_in_direct_message_conversation_by_username( - context=mock_context, username="user2" - ) - - -@pytest.mark.asyncio -@patch("arcade_slack.tools.chat.retrieve_conversations_by_user_ids") -async def test_get_multi_person_direct_message_conversation_metadata_by_username( - mock_retrieve_conversations_by_user_ids, - mock_context, - mock_chat_slack_client, - mock_users_slack_client, -): - mock_chat_slack_client.auth_test.return_value = { - "ok": True, - "user_id": "U1", - "team_id": "T1", - "user": "user1", - } - - mock_users_slack_client.users_list.return_value = { - "ok": True, - "members": [ - {"id": "U1", "name": "user1"}, - {"id": "U2", "name": "user2"}, - {"id": "U3", "name": "user3"}, - {"id": "U4", "name": "user4"}, - {"id": "U5", "name": "user5"}, - ], - "response_metadata": {"next_cursor": None}, - } - - conversation = { - "id": "C12345", - "type": ConversationTypeSlackName.MPIM.value, - "is_mpim": True, - "members": ["U1", "U4", "U5"], - } - - mock_retrieve_conversations_by_user_ids.return_value = [conversation] - - response = await get_multi_person_dm_conversation_metadata_by_usernames( - context=mock_context, usernames=["user1", "user4", "user5"] - ) - - assert response == conversation - mock_retrieve_conversations_by_user_ids.assert_called_once_with( - list_conversations_func=list_conversations_metadata, - get_members_in_conversation_func=get_members_in_conversation_by_id, - context=mock_context, - conversation_types=[ConversationType.MULTI_PERSON_DIRECT_MESSAGE], - user_ids=["U1", "U4", "U5"], - exact_match=True, + mock_message_retrieval_slack_client.conversations_history.assert_called_once_with( + channel="C12345", + include_all_metadata=True, + inclusive=True, limit=1, - next_cursor=None, + cursor=None, ) - - -@pytest.mark.asyncio -@patch("arcade_slack.tools.chat.retrieve_conversations_by_user_ids") -async def test_get_multi_person_direct_message_conversation_metadata_by_username_username_not_found( - mock_retrieve_conversations_by_user_ids, - mock_context, - mock_chat_slack_client, - mock_users_slack_client, -): - mock_chat_slack_client.users_identity.return_value = { - "ok": True, - "user": {"id": "U1", "name": "user1"}, - "team": {"id": "T1", "name": "team1"}, - } - - mock_users_slack_client.users_list.return_value = { - "ok": True, - "members": [ - {"id": "U1", "name": "user1"}, - {"id": "U2", "name": "user2"}, - ], - "response_metadata": {"next_cursor": None}, - } - - mock_retrieve_conversations_by_user_ids.side_effect = TimeoutError() - - with pytest.raises(RetryableToolError): - await get_multi_person_dm_conversation_metadata_by_usernames( - context=mock_context, usernames=["user999", "user1", "user2"] - ) - - -@pytest.mark.asyncio -@patch("arcade_slack.tools.chat.get_messages_in_conversation_by_id") -@patch("arcade_slack.tools.chat.get_multi_person_dm_conversation_metadata_by_usernames") -async def test_get_messages_in_multi_person_dm_conversation_by_usernames( - mock_get_multi_person_dm_conversation_metadata_by_usernames, - mock_get_messages_in_conversation_by_id, - mock_context, -): - mock_get_multi_person_dm_conversation_metadata_by_usernames.return_value = { - "id": "C12345", - } - - response = await get_messages_in_multi_person_dm_conversation_by_usernames( - context=mock_context, usernames=["user1", "user4", "user5"] - ) - - assert response == mock_get_messages_in_conversation_by_id.return_value - - mock_get_multi_person_dm_conversation_metadata_by_usernames.assert_called_once_with( - context=mock_context, usernames=["user1", "user4", "user5"] - ) - - mock_get_messages_in_conversation_by_id.assert_called_once_with( - context=mock_context, - conversation_id="C12345", - oldest_relative=None, - latest_relative=None, - oldest_datetime=None, - latest_datetime=None, - limit=None, - next_cursor=None, - ) - - -@pytest.mark.asyncio -@patch("arcade_slack.tools.chat.get_multi_person_dm_conversation_metadata_by_usernames") -async def test_get_messages_in_multi_person_dm_conversation_by_usernames_not_found( - mock_get_multi_person_dm_conversation_metadata_by_usernames, - mock_context, -): - mock_get_multi_person_dm_conversation_metadata_by_usernames.return_value = None - - with pytest.raises(ToolExecutionError): - await get_messages_in_direct_message_conversation_by_username( - context=mock_context, username="user2" - ) diff --git a/toolkits/slack/tests/test_models.py b/toolkits/slack/tests/test_models.py new file mode 100644 index 00000000..208473a0 --- /dev/null +++ b/toolkits/slack/tests/test_models.py @@ -0,0 +1,107 @@ +import asyncio +from unittest.mock import AsyncMock + +import pytest +from arcade_tdk.errors import ToolExecutionError +from slack_sdk.errors import SlackApiError + +from arcade_slack.models import ( + ConcurrencySafeCoroutineCaller, + FindChannelByNameSentinel, + FindMultipleUsersByUsernameSentinel, + FindUserByUsernameSentinel, + GetUserByEmailCaller, +) + + +def test_find_user_by_username_sentinel(): + sentinel = FindUserByUsernameSentinel(username="jenifer") + assert sentinel(last_result=[{"name": "jack"}]) is False + assert sentinel(last_result=[{"name": "john"}, {"name": "jack"}]) is False + assert sentinel(last_result=[{"name": "hello"}, {"name": "jenifer"}]) is True + assert sentinel(last_result=[{"name": "JENIFER"}]) is True + + +def test_find_multiple_users_by_username_sentinel(): + sentinel = FindMultipleUsersByUsernameSentinel(usernames=["jenifer", "jack"]) + assert sentinel(last_result=[{"name": "jack"}]) is False + assert sentinel(last_result=[{"name": "john"}, {"name": "jack"}]) is False + assert sentinel(last_result=[{"name": "hello"}, {"name": "JENIFER"}]) is True + assert sentinel(last_result=[{"name": "world"}]) is True + + +def test_find_channel_by_name_sentinel(): + sentinel = FindChannelByNameSentinel(channel_name="foobar") + assert sentinel(last_result=[{"name": "foo"}]) is False + assert sentinel(last_result=[{"name": "foo"}, {"name": "bar"}]) is False + assert sentinel(last_result=[{"name": "foo"}, {"name": "foobar"}]) is True + assert sentinel(last_result=[{"name": "FOObar"}]) is True + + +@pytest.mark.asyncio +async def test_concurrency_safe_coroutine_caller(): + mock_func = AsyncMock() + mock_semaphore = AsyncMock(spec=asyncio.Semaphore) + + caller = ConcurrencySafeCoroutineCaller(mock_func, "arg1", "arg2", kwarg1="kwarg1") + response = await caller(semaphore=mock_semaphore) + + assert response == mock_func.return_value + mock_func.assert_awaited_once_with("arg1", "arg2", kwarg1="kwarg1") + mock_semaphore.__aenter__.assert_awaited_once() + mock_semaphore.__aexit__.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_user_by_email_caller_success(): + mock_func = AsyncMock() + mock_func.return_value = {"user": {"id": "U1234567890", "name": "John Doe"}} + mock_semaphore = AsyncMock(spec=asyncio.Semaphore) + + caller = GetUserByEmailCaller(mock_func, "test@example.com") + response = await caller(semaphore=mock_semaphore) + + assert response == { + "user": {"id": "U1234567890", "name": "John Doe"}, + "email": "test@example.com", + } + mock_func.assert_awaited_once_with(email="test@example.com") + mock_semaphore.__aenter__.assert_awaited_once() + mock_semaphore.__aexit__.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_user_by_email_caller_not_found_error(): + mock_func = AsyncMock( + side_effect=SlackApiError(message="User not found", response={"error": "user_not_found"}) + ) + mock_semaphore = AsyncMock(spec=asyncio.Semaphore) + + caller = GetUserByEmailCaller(mock_func, "test@example.com") + response = await caller(semaphore=mock_semaphore) + + assert response == { + "user": None, + "email": "test@example.com", + } + mock_func.assert_awaited_once_with(email="test@example.com") + mock_semaphore.__aenter__.assert_awaited_once() + mock_semaphore.__aexit__.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_user_by_email_caller_unknown_slack_api_error(): + mock_func = AsyncMock( + side_effect=SlackApiError(message="Unknown error", response={"error": "unknown_error"}) + ) + mock_semaphore = AsyncMock(spec=asyncio.Semaphore) + + caller = GetUserByEmailCaller(mock_func, "test@example.com") + with pytest.raises(ToolExecutionError) as error: + await caller(semaphore=mock_semaphore) + + assert error.value.message == "Error getting user by email" + assert error.value.developer_message == "Error getting user by email: unknown_error" + mock_func.assert_awaited_once_with(email="test@example.com") + mock_semaphore.__aenter__.assert_awaited_once() + mock_semaphore.__aexit__.assert_awaited_once() diff --git a/toolkits/slack/tests/test_user_retrieval.py b/toolkits/slack/tests/test_user_retrieval.py new file mode 100644 index 00000000..c584dd4b --- /dev/null +++ b/toolkits/slack/tests/test_user_retrieval.py @@ -0,0 +1,313 @@ +import json + +import pytest +from arcade_tdk.errors import RetryableToolError, ToolExecutionError +from slack_sdk.errors import SlackApiError + +from arcade_slack.user_retrieval import ( + get_single_user_by_id, + get_users_by_id, + get_users_by_id_username_or_email, +) +from arcade_slack.utils import ( + cast_user_dict, + extract_basic_user_info, + short_user_info, +) + + +@pytest.mark.asyncio +async def test_get_multiple_users_by_emails_success( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user1 = dummy_user_factory() + user2 = dummy_user_factory() + + emails = [user1["profile"]["email"], user2["profile"]["email"]] + + mock_user_retrieval_slack_client.users_lookupByEmail.side_effect = [ + {"ok": True, "user": user1}, + {"ok": True, "user": user2}, + ] + + response = await get_users_by_id_username_or_email(context=mock_context, emails=emails) + + assert response == [extract_basic_user_info(user1), extract_basic_user_info(user2)] + + +@pytest.mark.asyncio +async def test_get_multiple_users_by_usernames_or_emails_with_emails_not_found( + mock_context, + mock_user_retrieval_slack_client, + mock_users_slack_client, + dummy_user_factory, +): + user1 = dummy_user_factory(email="user1@example.com") + + emails = [user1["profile"]["email"], "not_found@example.com"] + + async def lookup_by_email_side_effect(*, email): + if email == "user1@example.com": + return {"ok": True, "user": user1} + raise SlackApiError( + message="User not found", + response={"error": "user_not_found"}, + ) + + mock_user_retrieval_slack_client.users_lookupByEmail.side_effect = lookup_by_email_side_effect + + mock_users_slack_client.users_list.return_value = { + "ok": True, + "members": [user1], + } + + with pytest.raises(RetryableToolError) as error: + await get_users_by_id_username_or_email(context=mock_context, emails=emails) + + assert "not_found@example.com" in error.value.message + assert json.dumps(short_user_info(user1)) in error.value.additional_prompt_content + + +@pytest.mark.asyncio +async def test_get_multiple_users_by_usernames_or_emails_with_usernames_success( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user1 = dummy_user_factory() + user2 = dummy_user_factory() + + usernames = [user1["name"], user2["name"]] + + mock_user_retrieval_slack_client.users_list.return_value = { + "ok": True, + "members": [user1, user2], + } + + response = await get_users_by_id_username_or_email(context=mock_context, usernames=usernames) + + assert response == [extract_basic_user_info(user1), extract_basic_user_info(user2)] + + +@pytest.mark.asyncio +async def test_get_multiple_users_by_usernames_or_emails_with_usernames_not_found( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user1 = dummy_user_factory() + user2 = dummy_user_factory() + user3 = dummy_user_factory() + + usernames = [user1["name"], "username_not_found"] + + mock_user_retrieval_slack_client.users_list.return_value = { + "ok": True, + "members": [user1, user2, user3], + } + + with pytest.raises(RetryableToolError) as error: + await get_users_by_id_username_or_email(context=mock_context, usernames=usernames) + + assert "username_not_found" in error.value.message + assert json.dumps(short_user_info(user1)) in error.value.additional_prompt_content + assert json.dumps(short_user_info(user2)) in error.value.additional_prompt_content + assert json.dumps(short_user_info(user3)) in error.value.additional_prompt_content + + +@pytest.mark.asyncio +async def test_get_multiple_users_by_mixed_usernames_and_emails_success( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user1 = dummy_user_factory() + user2 = dummy_user_factory() + user3 = dummy_user_factory() + user4 = dummy_user_factory() + + mock_user_retrieval_slack_client.users_list.return_value = { + "ok": True, + "members": [user1, user2], + } + mock_user_retrieval_slack_client.users_lookupByEmail.side_effect = [ + {"ok": True, "user": user3}, + {"ok": True, "user": user4}, + ] + + response = await get_users_by_id_username_or_email( + context=mock_context, + usernames=[user1["name"], user2["name"]], + emails=[user3["profile"]["email"], user4["profile"]["email"]], + ) + + assert response == [ + extract_basic_user_info(user1), + extract_basic_user_info(user2), + extract_basic_user_info(user3), + extract_basic_user_info(user4), + ] + + +@pytest.mark.asyncio +async def test_get_single_user_by_id_success( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user = dummy_user_factory() + + mock_user_retrieval_slack_client.users_info.return_value = {"ok": True, "user": user} + + response = await get_single_user_by_id( + auth_token=mock_context.get_auth_token_or_empty(), + user_id=user["id"], + ) + + assert response == cast_user_dict(user) + + +@pytest.mark.asyncio +async def test_get_single_user_by_id_not_found( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user = dummy_user_factory() + + mock_user_retrieval_slack_client.users_info.side_effect = SlackApiError( + message="User not found", + response={"error": "user_not_found"}, + ) + + response = await get_single_user_by_id( + auth_token=mock_context.get_auth_token_or_empty(), + user_id=user["id"], + ) + + assert response is None + + +@pytest.mark.asyncio +async def test_get_single_user_by_id_not_ok( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user = dummy_user_factory() + + mock_user_retrieval_slack_client.users_info.return_value = {"ok": False, "error": "not_ok"} + + response = await get_single_user_by_id( + auth_token=mock_context.get_auth_token_or_empty(), + user_id=user["id"], + ) + + assert response is None + + +@pytest.mark.asyncio +async def test_get_single_user_by_id_unknown_error( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user = dummy_user_factory() + + mock_user_retrieval_slack_client.users_info.side_effect = SlackApiError( + message="Unknown error", + response={"error": "unknown_error_string"}, + ) + + with pytest.raises(ToolExecutionError) as error: + await get_single_user_by_id( + auth_token=mock_context.get_auth_token_or_empty(), + user_id=user["id"], + ) + + assert user["id"] in error.value.message + assert "unknown_error_string" in error.value.developer_message + + +@pytest.mark.asyncio +async def test_get_users_by_id_with_one_user_id_success( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user1 = dummy_user_factory() + + mock_user_retrieval_slack_client.users_info.return_value = {"ok": True, "user": user1} + + response = await get_users_by_id( + auth_token=mock_context.get_auth_token_or_empty(), + user_ids=[user1["id"]], + ) + + assert response == {"users": [cast_user_dict(user1)], "not_found": []} + + mock_user_retrieval_slack_client.users_list.assert_not_called() + + +@pytest.mark.asyncio +async def test_get_users_by_id_with_one_user_id_not_found( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user1 = dummy_user_factory() + + mock_user_retrieval_slack_client.users_info.side_effect = SlackApiError( + message="User not found", + response={"error": "user_not_found"}, + ) + + response = await get_users_by_id( + auth_token=mock_context.get_auth_token_or_empty(), + user_ids=[user1["id"]], + ) + + assert response == {"users": [], "not_found": [user1["id"]]} + + mock_user_retrieval_slack_client.users_list.assert_not_called() + + +@pytest.mark.asyncio +async def test_get_users_by_id_with_multiple_user_ids_success( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user1 = dummy_user_factory() + user2 = dummy_user_factory() + user3 = dummy_user_factory() + user4 = dummy_user_factory() + + mock_user_retrieval_slack_client.users_list.side_effect = [ + { + "ok": True, + "members": [user1, user2], + "response_metadata": {"next_cursor": "next_cursor"}, + }, + {"ok": True, "members": [user3, user4]}, + ] + + response = await get_users_by_id( + auth_token=mock_context.get_auth_token_or_empty(), + user_ids=[user1["id"], user4["id"]], + ) + + assert response == {"users": [cast_user_dict(user1), cast_user_dict(user4)], "not_found": []} + + +@pytest.mark.asyncio +async def test_get_users_by_id_with_multiple_user_ids_some_not_found( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user1 = dummy_user_factory(id_="U1") + user2 = dummy_user_factory(id_="U2") + user3 = dummy_user_factory(id_="U3") + user4 = dummy_user_factory(id_="U4") + + mock_user_retrieval_slack_client.users_list.side_effect = [ + { + "ok": True, + "members": [user1, user2], + "response_metadata": {"next_cursor": "next_cursor"}, + }, + { + "ok": True, + "members": [user3, user4], + "response_metadata": {"next_cursor": None}, + }, + ] + + response = await get_users_by_id( + auth_token=mock_context.get_auth_token_or_empty(), + user_ids=[user1["id"], user4["id"], "user_not_exists"], + ) + + assert response == { + "users": [cast_user_dict(user1), cast_user_dict(user4)], + "not_found": ["user_not_exists"], + } diff --git a/toolkits/slack/tests/test_users.py b/toolkits/slack/tests/test_users.py index a9641e72..d2849bb7 100644 --- a/toolkits/slack/tests/test_users.py +++ b/toolkits/slack/tests/test_users.py @@ -1,21 +1,19 @@ +import json from unittest.mock import patch import pytest -from arcade_tdk.errors import RetryableToolError +from arcade_tdk.errors import RetryableToolError, ToolExecutionError from slack_sdk.errors import SlackApiError -from arcade_slack.tools.users import get_user_info_by_id, list_users -from arcade_slack.utils import extract_basic_user_info - - -@pytest.fixture -def mock_slack_client(mocker): - mock_client = mocker.patch("arcade_slack.tools.users.AsyncWebClient", autospec=True) - return mock_client.return_value +from arcade_slack.tools.users import ( + get_users_info, + list_users, +) +from arcade_slack.utils import extract_basic_user_info, short_user_info @pytest.mark.asyncio -async def test_get_user_info_by_id_success(mock_context, mock_slack_client): +async def test_get_user_info_by_id_success(mock_context, mock_user_retrieval_slack_client): # Mock the response from slackClient.users_info mock_user = { "id": "U12345", @@ -23,44 +21,46 @@ async def test_get_user_info_by_id_success(mock_context, mock_slack_client): "real_name": "Test User", "profile": {"email": "testuser@example.com"}, } - mock_slack_client.users_info.return_value = {"ok": True, "user": mock_user} + mock_user_retrieval_slack_client.users_info.return_value = {"ok": True, "user": mock_user} - # Call the function - response = await get_user_info_by_id(mock_context, user_id="U12345") + response = await get_users_info(mock_context, user_ids=["U12345"]) - # Verify that the correct Slack API method was called - mock_slack_client.users_info.assert_called_once_with(user="U12345") + mock_user_retrieval_slack_client.users_info.assert_called_once_with(user="U12345") - # Verify the response expected_response = extract_basic_user_info(mock_user) - assert response == expected_response + assert response == {"users": [expected_response]} @pytest.mark.asyncio @patch("arcade_slack.tools.users.list_users") -async def test_get_user_info_by_id_user_not_found(mock_list_users, mock_context, mock_slack_client): +async def test_get_user_info_by_id_user_not_found( + mock_list_users, mock_context, mock_user_retrieval_slack_client +): error_response = {"ok": False, "error": "user_not_found"} - mock_slack_client.users_info.side_effect = SlackApiError( + mock_user_retrieval_slack_client.users_info.side_effect = SlackApiError( message="User not found", response=error_response, ) existing_user = {"id": "U12345", "name": "testuser"} mock_list_users.return_value = {"users": [existing_user]} + mock_list_users.__tool_name__ = list_users.__tool_name__ with pytest.raises(RetryableToolError) as e: - await get_user_info_by_id(mock_context, user_id="U99999") + await get_users_info(mock_context, user_ids=["U99999"]) assert existing_user["id"] in e.value.additional_prompt_content - assert existing_user["name"] in e.value.additional_prompt_content - mock_slack_client.users_info.assert_called_once_with(user="U99999") - mock_list_users.assert_called_once_with(mock_context) + mock_user_retrieval_slack_client.users_info.assert_called_once_with(user="U99999") + mock_list_users.assert_called_once_with(mock_context, limit=100, exclude_bots=True) @pytest.mark.asyncio -async def test_list_users_success(mock_context, mock_slack_client): - mock_slack_client.users_list.return_value = {"ok": True, "members": [{"id": "U12345"}]} +async def test_list_users_success(mock_context, mock_users_slack_client): + mock_users_slack_client.users_list.return_value = { + "ok": True, + "members": [{"id": "U12345"}], + } response = await list_users(mock_context) assert response == { "users": [extract_basic_user_info({"id": "U12345"})], @@ -69,24 +69,306 @@ async def test_list_users_success(mock_context, mock_slack_client): @pytest.mark.asyncio -async def test_list_users_with_pagination_success(mock_context, mock_slack_client): - mock_slack_client.users_list.side_effect = [ +async def test_list_users_with_pagination_success( + dummy_user_factory, mock_context, mock_users_slack_client +): + user1 = dummy_user_factory() + user2 = dummy_user_factory() + user3 = dummy_user_factory() + user4 = dummy_user_factory() + + mock_users_slack_client.users_list.side_effect = [ { "ok": True, - "members": [{"id": "U12345"}], + "members": [user1, user2], "response_metadata": {"next_cursor": "cursor_xyz"}, }, { "ok": True, - "members": [{"id": "U123456"}], + "members": [user3, user4], "response_metadata": {"next_cursor": None}, }, ] response = await list_users(mock_context, limit=3) assert response == { "users": [ - extract_basic_user_info({"id": "U12345"}), - extract_basic_user_info({"id": "U123456"}), + extract_basic_user_info(user1), + extract_basic_user_info(user2), + extract_basic_user_info(user3), + extract_basic_user_info(user4), ], "next_cursor": None, } + + assert mock_users_slack_client.users_list.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_user_by_username_success( + mock_context, + mock_user_retrieval_slack_client, + dummy_user_factory, +): + user1 = dummy_user_factory() + user2 = dummy_user_factory() + + mock_user_retrieval_slack_client.users_list.return_value = { + "ok": True, + "members": [user1, user2], + } + + response = await get_users_info(mock_context, usernames=[user1["name"]]) + + assert response == {"users": [extract_basic_user_info(user1)]} + + +@pytest.mark.asyncio +async def test_get_user_by_username_with_pagination_success( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user1 = dummy_user_factory() + user2 = dummy_user_factory() + user3 = dummy_user_factory() + user4 = dummy_user_factory() + user5 = dummy_user_factory() + + mock_user_retrieval_slack_client.users_list.side_effect = [ + { + "ok": True, + "members": [user1, user2], + "response_metadata": {"next_cursor": "cursor1"}, + }, + { + "ok": True, + "members": [user3, user4], + "response_metadata": {"next_cursor": "cursor2"}, + }, + { + "ok": True, + "members": [user5], + "response_metadata": {"next_cursor": None}, + }, + ] + + response = await get_users_info(mock_context, usernames=[user3["name"]]) + + assert response == {"users": [extract_basic_user_info(user3)]} + + assert mock_user_retrieval_slack_client.users_list.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_user_by_username_not_found( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user1 = dummy_user_factory() + user2 = dummy_user_factory() + user3 = dummy_user_factory(is_bot=True) + + mock_user_retrieval_slack_client.users_list.return_value = { + "ok": True, + "members": [user1, user2, user3], + } + + with pytest.raises(RetryableToolError) as e: + await get_users_info(mock_context, usernames=[user1["name"] + "not_found"]) + + # Check that the error message contains the available users + assert user1["id"] in e.value.additional_prompt_content + assert user2["id"] in e.value.additional_prompt_content + assert user3["id"] not in e.value.additional_prompt_content + + +@pytest.mark.asyncio +async def test_get_multiple_users_by_username_success( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user1 = dummy_user_factory() + user2 = dummy_user_factory() + user3 = dummy_user_factory() + + mock_user_retrieval_slack_client.users_list.return_value = { + "ok": True, + "members": [user1, user2, user3], + } + + response = await get_users_info( + mock_context, + usernames=[user1["name"], user2["name"]], + ) + + assert response == { + "users": [ + extract_basic_user_info(user1), + extract_basic_user_info(user2), + ] + } + + +@pytest.mark.asyncio +async def test_get_multiple_users_by_username_with_pagination_success( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user1 = dummy_user_factory() + user2 = dummy_user_factory() + user3 = dummy_user_factory() + user4 = dummy_user_factory() + user5 = dummy_user_factory() + + mock_user_retrieval_slack_client.users_list.side_effect = [ + {"ok": True, "members": [user1, user2], "response_metadata": {"next_cursor": "cursor1"}}, + {"ok": True, "members": [user3, user4], "response_metadata": {"next_cursor": "cursor2"}}, + {"ok": True, "members": [user5], "response_metadata": {"next_cursor": None}}, + ] + + response = await get_users_info(mock_context, usernames=[user1["name"], user3["name"]]) + + assert response == {"users": [extract_basic_user_info(user1), extract_basic_user_info(user3)]} + assert mock_user_retrieval_slack_client.users_list.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_multiple_users_by_username_not_found( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user1 = dummy_user_factory() + user2 = dummy_user_factory() + user3 = dummy_user_factory(is_bot=True) + + mock_user_retrieval_slack_client.users_list.return_value = { + "ok": True, + "members": [user1, user2, user3], + } + + not_found_username = f"{user2['name']} not_found" + + with pytest.raises(RetryableToolError) as e: + await get_users_info(mock_context, usernames=[user1["name"], not_found_username]) + + assert user1["id"] in e.value.additional_prompt_content + assert user2["id"] in e.value.additional_prompt_content + assert user3["id"] not in e.value.additional_prompt_content + + +@pytest.mark.asyncio +async def test_get_user_by_email_success( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user = dummy_user_factory() + mock_user_retrieval_slack_client.users_lookupByEmail.return_value = { + "ok": True, + "user": user, + } + response = await get_users_info(mock_context, emails=[user["profile"]["email"]]) + assert response == {"users": [extract_basic_user_info(user)]} + + +@pytest.mark.asyncio +async def test_get_user_by_email_not_found( + mock_context, mock_users_slack_client, mock_user_retrieval_slack_client, dummy_user_factory +): + additional_user = dummy_user_factory(email="additional_user@example.com") + + async def lookup_by_email_side_effect(*, email): + if email == "additional_user@example.com": + return {"ok": True, "user": additional_user} + raise SlackApiError( + message="User not found", + response={"ok": False, "error": "user_not_found"}, + ) + + mock_user_retrieval_slack_client.users_lookupByEmail.side_effect = lookup_by_email_side_effect + mock_users_slack_client.users_list.return_value = { + "ok": True, + "members": [additional_user], + } + + with pytest.raises(RetryableToolError) as e: + await get_users_info(mock_context, emails=["not_found@example.com"]) + + assert "not_found@example.com" in e.value.message + assert json.dumps(short_user_info(additional_user)) in e.value.additional_prompt_content + + +@pytest.mark.asyncio +async def test_get_multiple_users_by_email_success( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user1 = dummy_user_factory() + user2 = dummy_user_factory() + user3 = dummy_user_factory() + + mock_user_retrieval_slack_client.users_lookupByEmail.side_effect = [ + {"ok": True, "user": user1}, + {"ok": True, "user": user2}, + {"ok": True, "user": user3}, + ] + + response = await get_users_info( + mock_context, + emails=[ + user1["profile"]["email"], + user2["profile"]["email"], + user3["profile"]["email"], + ], + ) + + assert response == { + "users": [ + extract_basic_user_info(user1), + extract_basic_user_info(user2), + extract_basic_user_info(user3), + ] + } + + +@pytest.mark.asyncio +async def test_get_multiple_users_by_email_with_invalid_address( + mock_context, mock_user_retrieval_slack_client +): + with pytest.raises(ToolExecutionError) as e: + await get_users_info( + mock_context, + emails=["amyra@jades.com", "invalid_address"], + ) + + mock_user_retrieval_slack_client.assert_not_called() + assert e.value.message == "Invalid email address: invalid_address" + + +@pytest.mark.asyncio +async def test_get_multiple_users_by_email_not_found( + mock_context, mock_user_retrieval_slack_client, mock_users_slack_client, dummy_user_factory +): + user1 = dummy_user_factory() + user2 = dummy_user_factory() + + async def lookup_by_email_side_effect(*, email): + if email == user1["profile"]["email"]: + return {"ok": True, "user": user1} + if email == user2["profile"]["email"]: + return {"ok": True, "user": user2} + raise SlackApiError( + message="User not found", + response={"ok": False, "error": "user_not_found"}, + ) + + mock_user_retrieval_slack_client.users_lookupByEmail.side_effect = lookup_by_email_side_effect + mock_users_slack_client.users_list.return_value = { + "ok": True, + "members": [user1, user2], + } + + with pytest.raises(RetryableToolError) as e: + await get_users_info( + mock_context, + emails=[ + user1["profile"]["email"], + user2["profile"]["email"], + "not_found@example.com", + ], + ) + + assert "not_found@example.com" in e.value.message + assert json.dumps(short_user_info(user1)) in e.value.additional_prompt_content + assert json.dumps(short_user_info(user2)) in e.value.additional_prompt_content diff --git a/toolkits/slack/tests/test_utils.py b/toolkits/slack/tests/test_utils.py index c095f7b8..dc4aec85 100644 --- a/toolkits/slack/tests/test_utils.py +++ b/toolkits/slack/tests/test_utils.py @@ -1,25 +1,26 @@ import asyncio -from unittest.mock import AsyncMock, MagicMock, call, patch +import copy +import json +from unittest.mock import AsyncMock, call, patch import pytest -from arcade_tdk import ToolContext +from arcade_tdk.errors import RetryableToolError from slack_sdk.errors import SlackApiError from slack_sdk.web.async_client import AsyncWebClient from arcade_slack.exceptions import PaginationTimeoutError from arcade_slack.models import ( - ConversationType, - ConversationTypeSlackName, -) -from arcade_slack.tools.chat import ( - get_members_in_conversation_by_id, - list_conversations_metadata, + ConcurrencySafeCoroutineCaller, + FindMultipleUsersByUsernameSentinel, + FindUserByUsernameSentinel, ) from arcade_slack.utils import ( async_paginate, - convert_conversation_type_to_slack_name, + build_multiple_users_retrieval_response, filter_conversations_by_user_ids, - retrieve_conversations_by_user_ids, + gather_with_concurrency_limit, + is_valid_email, + populate_users_in_messages, ) @@ -41,6 +42,170 @@ async def test_async_paginate(): assert next_cursor is None +@pytest.mark.asyncio +async def test_async_paginate_with_find_user_sentinel(): + mock_slack_client = AsyncMock() + mock_slack_client.users_list.side_effect = [ + { + "ok": True, + "members": [ + {"id": "123", "name": "Jack"}, + {"id": "456", "name": "John"}, + ], + "response_metadata": {"next_cursor": "cursor1"}, + }, + { + "ok": True, + "members": [{"id": "789", "name": "Jenifer"}], + "response_metadata": {"next_cursor": "cursor2"}, + }, + { + "ok": True, + "members": [{"id": "007", "name": "James"}], + "response_metadata": {"next_cursor": None}, + }, + ] + + results, next_cursor = await async_paginate( + func=mock_slack_client.users_list, + response_key="members", + sentinel=FindUserByUsernameSentinel(username="jenifer"), + ) + + assert results == [ + {"id": "123", "name": "Jack"}, + {"id": "456", "name": "John"}, + {"id": "789", "name": "Jenifer"}, + ] + assert next_cursor == "cursor2" + + +@pytest.mark.asyncio +async def test_async_paginate_with_find_user_sentinel_not_found(): + mock_slack_client = AsyncMock() + mock_slack_client.users_list.side_effect = [ + { + "ok": True, + "members": [ + {"id": "123", "name": "Jack"}, + {"id": "456", "name": "John"}, + ], + "response_metadata": {"next_cursor": "cursor1"}, + }, + { + "ok": True, + "members": [{"id": "789", "name": "Jenifer"}], + "response_metadata": {"next_cursor": "cursor2"}, + }, + { + "ok": True, + "members": [{"id": "007", "name": "James"}], + "response_metadata": {"next_cursor": None}, + }, + ] + + results, next_cursor = await async_paginate( + func=mock_slack_client.users_list, + response_key="members", + sentinel=FindUserByUsernameSentinel(username="Do not find me"), + ) + + assert results == [ + {"id": "123", "name": "Jack"}, + {"id": "456", "name": "John"}, + {"id": "789", "name": "Jenifer"}, + {"id": "007", "name": "James"}, + ] + assert next_cursor is None + + +@pytest.mark.asyncio +async def test_async_paginate_with_find_multiple_users_sentinel(): + mock_slack_client = AsyncMock() + mock_slack_client.users_list.side_effect = [ + { + "ok": True, + "members": [ + {"id": "123", "name": "Jack"}, + {"id": "456", "name": "John"}, + ], + "response_metadata": {"next_cursor": "cursor1"}, + }, + { + "ok": True, + "members": [ + {"id": "789", "name": "Jenifer"}, + {"id": "101", "name": "Janis"}, + ], + "response_metadata": {"next_cursor": "cursor2"}, + }, + { + "ok": True, + "members": [{"id": "007", "name": "James"}], + "response_metadata": {"next_cursor": None}, + }, + ] + + results, next_cursor = await async_paginate( + func=mock_slack_client.users_list, + response_key="members", + sentinel=FindMultipleUsersByUsernameSentinel(usernames=["jenifer", "jack"]), + ) + + assert results == [ + {"id": "123", "name": "Jack"}, + {"id": "456", "name": "John"}, + {"id": "789", "name": "Jenifer"}, + {"id": "101", "name": "Janis"}, + ] + assert next_cursor == "cursor2" + + +@pytest.mark.asyncio +async def test_async_paginate_with_find_multiple_users_sentinel_not_found(): + mock_slack_client = AsyncMock() + mock_slack_client.users_list.side_effect = [ + { + "ok": True, + "members": [ + {"id": "123", "name": "Jack"}, + {"id": "456", "name": "John"}, + ], + "response_metadata": {"next_cursor": "cursor1"}, + }, + { + "ok": True, + "members": [ + {"id": "789", "name": "Jenifer"}, + {"id": "101", "name": "Janis"}, + ], + "response_metadata": {"next_cursor": "cursor2"}, + }, + { + "ok": True, + "members": [{"id": "007", "name": "James"}], + "response_metadata": {"next_cursor": None}, + }, + ] + + results, next_cursor = await async_paginate( + func=mock_slack_client.users_list, + response_key="members", + sentinel=FindMultipleUsersByUsernameSentinel( + usernames=["jenifer", "jack", "do not find me"] + ), + ) + + assert results == [ + {"id": "123", "name": "Jack"}, + {"id": "456", "name": "John"}, + {"id": "789", "name": "Jenifer"}, + {"id": "101", "name": "Janis"}, + {"id": "007", "name": "James"}, + ] + assert next_cursor is None + + @pytest.mark.asyncio async def test_async_paginate_with_response_error(): mock_slack_client = AsyncMock() @@ -278,407 +443,247 @@ def test_filter_conversations_by_user_ids_exact_match_empty_response(): assert response == [] -@pytest.mark.asyncio @pytest.mark.parametrize( - "search_user_ids, conversation_types, exact_match, limit, expected_conversation_ids", + "users_by_email, users_by_username, expected_response", [ - (["U1", "U2"], [ConversationType.DIRECT_MESSAGE], False, 1, ["C1"]), - (["U1", "U2"], [ConversationType.DIRECT_MESSAGE], True, 1, ["C1"]), - (["U1", "U2", "U3"], [ConversationType.DIRECT_MESSAGE], False, 1, []), ( - ["U1", "U2"], - [ConversationType.DIRECT_MESSAGE, ConversationType.PUBLIC_CHANNEL], - False, - 10, - ["C1", "C3", "C4"], + {"users": [{"id": "U1", "name": "user1"}]}, + {"users": [{"id": "U2", "name": "user2"}]}, + [{"id": "U1", "name": "user1"}, {"id": "U2", "name": "user2"}], ), ( - ["U1", "U2"], - [ConversationType.DIRECT_MESSAGE, ConversationType.PUBLIC_CHANNEL], - True, - 10, - ["C1", "C3"], + {"users": [{"id": "U1", "name": "user1"}]}, + {"users": []}, + [{"id": "U1", "name": "user1"}], + ), + ( + {"users": []}, + {"users": [{"id": "U2", "name": "user2"}]}, + [{"id": "U2", "name": "user2"}], + ), + ( + {"users": []}, + {"users": []}, + [], ), ], ) -async def test_retrieve_conversations_by_user_ids( - mock_chat_slack_client, - mock_users_slack_client, - search_user_ids, - conversation_types, - exact_match, - limit, - expected_conversation_ids, +@pytest.mark.asyncio +async def test_build_multiple_users_retrieval_response_success( + users_by_email, + users_by_username, + expected_response, + mock_context, ): - context = MagicMock(spec=ToolContext) - context.authorization = MagicMock() - context.authorization.token = MagicMock() + response = await build_multiple_users_retrieval_response( + context=mock_context, + users_responses=[users_by_email, users_by_username], + ) + assert response == expected_response - conversation_types_slack_name_str = [ - convert_conversation_type_to_slack_name(conv_type).value - for conv_type in conversation_types or ConversationType - ] - conversations = [ - { - "conversation": { - "id": "C1", - "type": ConversationTypeSlackName.IM.value, - "name": "im-1", - "is_channel": False, - "is_im": True, - "is_member": True, +@pytest.mark.parametrize( + "users_by_email, users_by_username", + [ + # Both emails and usernames not found + ( + { + "users": [{"id": "U1", "name": "user1"}], + "not_found": ["email_not_found"], }, - "members": { - "ok": True, - "members": ["U1", "U2"], - "response_metadata": {"next_cursor": None}, + { + "users": [{"id": "U2", "name": "user2"}], + "not_found": ["username_not_found"], + "available_users": [{"id": "U3", "name": "user3"}], }, - "users": [ - {"ok": True, "user": {"id": "U1", "team_id": "T123", "name": "user1"}}, - {"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}}, - ], - }, - { - "conversation": { - "id": "C2", - "type": ConversationTypeSlackName.IM.value, - "name": "im-2", - "is_channel": False, - "is_im": True, - "is_member": True, + ), + # Email not found, usernames found + ( + { + "users": [{"id": "U1", "name": "user1"}], + "not_found": ["email_not_found"], }, - "members": { - "ok": True, - "members": ["U2", "U3"], - "response_metadata": {"next_cursor": None}, + { + "users": [{"id": "U2", "name": "user2"}], + "not_found": [], }, - "users": [ - {"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}}, - {"ok": True, "user": {"id": "U3", "team_id": "T123", "name": "user3"}}, - ], - }, - { - "conversation": { - "id": "C3", - "type": ConversationTypeSlackName.PUBLIC_CHANNEL.value, - "name": "general", - "is_channel": True, - "is_im": False, - "is_member": True, + ), + # Email found, username not found + ( + { + "users": [{"id": "U1", "name": "user1"}], + "not_found": [], }, - "members": { - "ok": True, - "members": ["U1", "U2"], - "response_metadata": {"next_cursor": None}, + { + "users": [{"id": "U2", "name": "user2"}], + "not_found": ["username_not_found"], + "available_users": [{"id": "U3", "name": "user3"}], }, - "users": [ - {"ok": True, "user": {"id": "U1", "team_id": "T123", "name": "user1"}}, - {"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}}, - ], - }, - { - "conversation": { - "id": "C4", - "type": ConversationTypeSlackName.PUBLIC_CHANNEL.value, - "name": "random", - "is_channel": True, - "is_im": False, - "is_member": True, - }, - "members": { - "ok": True, - "members": ["U1", "U2", "U3", "U4"], - "response_metadata": {"next_cursor": None}, - }, - "users": [ - {"ok": True, "user": {"id": "U1", "team_id": "T123", "name": "user1"}}, - {"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}}, - {"ok": True, "user": {"id": "U3", "team_id": "T123", "name": "user3"}}, - {"ok": True, "user": {"id": "U4", "team_id": "T123", "name": "user4"}}, - ], - }, - ] + ), + ], +) +@pytest.mark.asyncio +async def test_build_multiple_users_retrieval_response_not_found( + users_by_email, + users_by_username, + mock_context, +): + with pytest.raises(RetryableToolError) as error: + await build_multiple_users_retrieval_response( + context=mock_context, + users_responses=[users_by_email, users_by_username], + ) - conversations_listed = [ - conversation - for conversation in conversations - if conversation["conversation"]["type"] in conversation_types_slack_name_str - ] + emails_not_found = users_by_email.get("not_found", []) + usernames_not_found = users_by_username.get("not_found", []) + available_users = users_by_username.get("available_users", []) - mock_chat_slack_client.conversations_list.return_value = { - "ok": True, - "channels": [conversation["conversation"] for conversation in conversations_listed], - "response_metadata": {"next_cursor": None}, - } + for email in emails_not_found: + assert email in error.value.message + for username in usernames_not_found: + assert username in error.value.message + for user in available_users: + assert json.dumps(user) in error.value.additional_prompt_content - mock_chat_slack_client.conversations_members.side_effect = [ - conversation["members"] for conversation in conversations_listed - ] - mock_users_slack_client.users_info.side_effect = [ - user for conversation in conversations_listed for user in conversation["users"] - ] +def test_is_valid_email(): + assert is_valid_email("test@example.com") + assert is_valid_email("test+123@example.com") + assert is_valid_email("test-123@example.com") + assert is_valid_email("test_123@example.com") + assert is_valid_email("test.123@example.com") + assert is_valid_email("test123@example.com") + assert is_valid_email("test@example.co") + assert is_valid_email("test@example.com.co") + assert not is_valid_email("test123@example") + assert not is_valid_email("test@example") + assert not is_valid_email("test@example.") + assert not is_valid_email("test@.com") + assert not is_valid_email("test@example.c") + assert not is_valid_email("test@example.com.") + assert not is_valid_email("test@example.com.c") - conversations_found = await retrieve_conversations_by_user_ids( - list_conversations_func=list_conversations_metadata, - get_members_in_conversation_func=get_members_in_conversation_by_id, - context=context, - conversation_types=conversation_types, - user_ids=search_user_ids, - exact_match=exact_match, - limit=limit, - next_cursor=None, + +@pytest.mark.asyncio +async def test_gather_with_concurrency_limit(): + mock_func1 = AsyncMock() + mock_func2 = AsyncMock() + + caller1 = ConcurrencySafeCoroutineCaller(mock_func1, "arg1", "arg2", kwarg1="kwarg1") + caller2 = ConcurrencySafeCoroutineCaller(mock_func2, "arg1", "arg2", kwarg1="kwarg1") + + mock_semaphore = AsyncMock(spec=asyncio.Semaphore) + + response = await gather_with_concurrency_limit( + coroutine_callers=[caller1, caller2], + semaphore=mock_semaphore, ) - assert [conversation["id"] for conversation in conversations_found] == expected_conversation_ids + response = tuple(response) + + assert len(response) == 2 + assert response[0] == mock_func1.return_value + assert response[1] == mock_func2.return_value + + mock_func1.assert_awaited_once_with("arg1", "arg2", kwarg1="kwarg1") + mock_func2.assert_awaited_once_with("arg1", "arg2", kwarg1="kwarg1") + + assert mock_semaphore.__aenter__.await_count == 2 + assert mock_semaphore.__aexit__.await_count == 2 @pytest.mark.asyncio -@pytest.mark.parametrize( - ( - "search_user_ids, conversation_types, exact_match, limit, " - "expected_conversation_ids, expected_conversations_list_calls" - ), - [ - ( - ["U1", "U2", "U3"], - [ConversationType.MULTI_PERSON_DIRECT_MESSAGE], - False, - None, - ["C1", "C3"], - 2, - ), - ( - ["U1", "U2", "U3"], - [ConversationType.MULTI_PERSON_DIRECT_MESSAGE], - True, - None, - ["C1"], - 2, - ), - (["U1", "U2", "U99"], [ConversationType.MULTI_PERSON_DIRECT_MESSAGE], False, None, [], 2), - ( - ["U1", "U2"], - [ConversationType.MULTI_PERSON_DIRECT_MESSAGE, ConversationType.PUBLIC_CHANNEL], - False, - None, - ["C1", "C3", "C4", "C6"], - 2, - ), - ( - ["U1", "U2"], - [ConversationType.MULTI_PERSON_DIRECT_MESSAGE, ConversationType.PUBLIC_CHANNEL], - False, - 1, - ["C1"], - 2, - ), - ( - ["U1", "U2"], - [ConversationType.MULTI_PERSON_DIRECT_MESSAGE, ConversationType.PUBLIC_CHANNEL], - False, - 3, - ["C1", "C3", "C4"], - 2, - ), - ( - ["U1", "U2"], - [ConversationType.MULTI_PERSON_DIRECT_MESSAGE, ConversationType.PUBLIC_CHANNEL], - True, - None, - ["C4"], - 2, - ), - ], -) -async def test_retrieve_conversations_by_user_ids_with_pagination( - mock_chat_slack_client, - mock_users_slack_client, - search_user_ids, - conversation_types, - exact_match, - limit, - expected_conversation_ids, - expected_conversations_list_calls, +async def test_populate_users_in_messages( + mock_context, + mock_user_retrieval_slack_client, + dummy_message_factory, + dummy_reaction_factory, + dummy_user_factory, ): - context = MagicMock(spec=ToolContext) - context.authorization = MagicMock() - context.authorization.token = MagicMock() + user1 = dummy_user_factory(id_="U1", name="user1") + user2 = dummy_user_factory(id_="U2", name="user2") + user3 = dummy_user_factory(id_="U3", name="user3") + user4 = dummy_user_factory(id_="U4", name="user4") + user5 = dummy_user_factory(id_="U5", name="user5") - conversation_types_slack_name_str = [ - convert_conversation_type_to_slack_name(conv_type).value - for conv_type in conversation_types or ConversationType + user1_short = {"id": user1["id"], "name": user1["name"]} + user2_short = {"id": user2["id"], "name": user2["name"]} + user3_short = {"id": user3["id"], "name": user3["name"]} + user4_short = {"id": user4["id"], "name": user4["name"]} + + user2_mention = f"<@{user2['name']} (id:{user2['id']})>" + user5_mention = f"<@{user5['name']} (id:{user5['id']})>" + + reactions = [ + dummy_reaction_factory(name="thumbsup", user_ids=[user1["id"], user2["id"]]), + dummy_reaction_factory(name="partyparrot", user_ids=[user3["id"], user4["id"]]), ] - conversations = [ - { - "conversation": { - "id": "C1", - "type": ConversationTypeSlackName.MPIM.value, - "name": "mpim-1", - "is_channel": False, - "is_im": False, - "is_mpim": True, - "is_member": True, - }, - "members": { - "ok": True, - "members": ["U1", "U2", "U3"], - "response_metadata": {"next_cursor": None}, - }, - "users": [ - {"ok": True, "user": {"id": "U1", "team_id": "T123", "name": "user1"}}, - {"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}}, - {"ok": True, "user": {"id": "U3", "team_id": "T123", "name": "user3"}}, - ], - }, - { - "conversation": { - "id": "C2", - "type": ConversationTypeSlackName.MPIM.value, - "name": "mpim-2", - "is_channel": False, - "is_im": False, - "is_mpim": True, - "is_member": True, - }, - "members": { - "ok": True, - "members": ["U2", "U3"], - "response_metadata": {"next_cursor": None}, - }, - "users": [ - {"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}}, - {"ok": True, "user": {"id": "U3", "team_id": "T123", "name": "user3"}}, - ], - }, - { - "conversation": { - "id": "C3", - "type": ConversationTypeSlackName.MPIM.value, - "name": "mpim-3", - "is_channel": False, - "is_im": False, - "is_mpim": True, - "is_member": True, - }, - "members": { - "ok": True, - "members": ["U1", "U2", "U3", "U4"], - "response_metadata": {"next_cursor": None}, - }, - "users": [ - {"ok": True, "user": {"id": "U1", "team_id": "T123", "name": "user1"}}, - {"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}}, - {"ok": True, "user": {"id": "U3", "team_id": "T123", "name": "user3"}}, - {"ok": True, "user": {"id": "U4", "team_id": "T123", "name": "user4"}}, - ], - }, - { - "conversation": { - "id": "C4", - "type": ConversationTypeSlackName.PUBLIC_CHANNEL.value, - "name": "channel-4", - "is_channel": True, - "is_im": False, - "is_member": True, - }, - "members": { - "ok": True, - "members": ["U1", "U2"], - "response_metadata": {"next_cursor": None}, - }, - "users": [ - {"ok": True, "user": {"id": "U1", "team_id": "T123", "name": "user1"}}, - {"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}}, - ], - }, - { - "conversation": { - "id": "C5", - "type": ConversationTypeSlackName.PUBLIC_CHANNEL.value, - "name": "channel-5", - "is_channel": True, - "is_im": False, - "is_member": True, - }, - "members": { - "ok": True, - "members": ["U2", "U3", "U4"], - "response_metadata": {"next_cursor": None}, - }, - "users": [ - {"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}}, - {"ok": True, "user": {"id": "U3", "team_id": "T123", "name": "user3"}}, - {"ok": True, "user": {"id": "U4", "team_id": "T123", "name": "user4"}}, - ], - }, - { - "conversation": { - "id": "C6", - "type": ConversationTypeSlackName.PUBLIC_CHANNEL.value, - "name": "channel-6", - "is_channel": True, - "is_im": False, - "is_member": True, - }, - "members": { - "ok": True, - "members": ["U1", "U2", "U3", "U4"], - "response_metadata": {"next_cursor": None}, - }, - "users": [ - {"ok": True, "user": {"id": "U1", "team_id": "T123", "name": "user1"}}, - {"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}}, - {"ok": True, "user": {"id": "U3", "team_id": "T123", "name": "user3"}}, - {"ok": True, "user": {"id": "U4", "team_id": "T123", "name": "user4"}}, - ], - }, + messages = [ + dummy_message_factory( + user_id=user1["id"], + text=f"Hello <@{user2['id']}>", + ), + dummy_message_factory( + user_id=user2["id"], + text="foobar", + reactions=copy.deepcopy(reactions[:1]), + ), + dummy_message_factory( + user_id=user3["id"], + text=f"Is this @{user5['id']} a user mention?", + ), + dummy_message_factory( + user_id=user4["id"], + text="hello", + reactions=copy.deepcopy(reactions), + ), ] - conversations_listed = [ - conversation - for conversation in conversations - if conversation["conversation"]["type"] in conversation_types_slack_name_str - ] - - split_size = len(conversations_listed) // 2 - - conversations_listed_1 = conversations_listed[:split_size] - conversations_listed_2 = conversations_listed[split_size:] - - mock_chat_slack_client.conversations_list.side_effect = [ + mock_user_retrieval_slack_client.users_list.side_effect = [ { "ok": True, - "channels": [conversation["conversation"] for conversation in conversations_listed_1], - "response_metadata": {"next_cursor": "cursor_1"}, + "members": [user1, user2, user3], + "response_metadata": {"next_cursor": "cursor1"}, }, { "ok": True, - "channels": [conversation["conversation"] for conversation in conversations_listed_2], + "members": [user4, user5], "response_metadata": {"next_cursor": None}, }, ] - mock_chat_slack_client.conversations_members.side_effect = [ - conversation["members"] for conversation in conversations_listed - ] - - mock_users_slack_client.users_info.side_effect = [ - user for conversation in conversations_listed for user in conversation["users"] - ] - - conversations_found = await retrieve_conversations_by_user_ids( - list_conversations_func=list_conversations_metadata, - get_members_in_conversation_func=get_members_in_conversation_by_id, - context=context, - conversation_types=conversation_types, - user_ids=search_user_ids, - exact_match=exact_match, - limit=limit, - next_cursor=None, + response = await populate_users_in_messages( + auth_token=mock_context.get_auth_token_or_empty(), + messages=messages, ) - assert [conversation["id"] for conversation in conversations_found] == expected_conversation_ids - assert mock_chat_slack_client.conversations_list.call_count == expected_conversations_list_calls + msg1 = response[0] + msg2 = response[1] + msg3 = response[2] + msg4 = response[3] + + assert msg1["user"] == user1_short + assert msg1["text"] == f"Hello {user2_mention}" + assert "reactions" not in msg1 + + assert msg2["user"] == user2_short + assert msg2["text"] == "foobar" + assert "reactions" in msg2 + assert len(msg2["reactions"]) == 1 + assert msg2["reactions"][0]["name"] == "thumbsup" + assert msg2["reactions"][0]["users"] == [user1_short, user2_short] + + assert msg3["user"] == user3_short + assert msg3["text"] == f"Is this @{user5['id']} a user mention?" + assert "reactions" not in msg3 + assert user5_mention not in msg3["text"] + + assert msg4["user"] == user4_short + assert msg4["text"] == "hello" + assert "reactions" in msg4 + assert len(msg4["reactions"]) == 2 + assert msg4["reactions"][0]["name"] == "thumbsup" + assert msg4["reactions"][0]["users"] == [user1_short, user2_short] + assert msg4["reactions"][1]["name"] == "partyparrot" + assert msg4["reactions"][1]["users"] == [user3_short, user4_short]