diff --git a/toolkits/google/arcade_google/tools/file_picker.py b/toolkits/google/arcade_google/tools/file_picker.py index b509e6d4..25a578b0 100644 --- a/toolkits/google/arcade_google/tools/file_picker.py +++ b/toolkits/google/arcade_google/tools/file_picker.py @@ -49,6 +49,6 @@ def generate_google_file_picker_url( "url": url, "llm_instructions": ( "Instruct the user to click the following link to open the Google Drive File Picker. " - "This will allow them to select files and grant access permissions: {url}" + f"This will allow them to select files and grant access permissions: {url}" ), } diff --git a/toolkits/slack/arcade_slack/constants.py b/toolkits/slack/arcade_slack/constants.py index 10c6176d..bed2bb03 100644 --- a/toolkits/slack/arcade_slack/constants.py +++ b/toolkits/slack/arcade_slack/constants.py @@ -1,13 +1,18 @@ import os -from arcade_slack.custom_types import PositiveInt +from arcade_slack.custom_types import PositiveNonZeroInt MAX_PAGINATION_SIZE_LIMIT = 200 -MAX_PAGINATION_TIMEOUT_SECONDS = PositiveInt( +MAX_PAGINATION_TIMEOUT_SECONDS = PositiveNonZeroInt( os.environ.get( "MAX_PAGINATION_TIMEOUT_SECONDS", os.environ.get("MAX_SLACK_PAGINATION_TIMEOUT_SECONDS", 30), ), name="MAX_PAGINATION_TIMEOUT_SECONDS or MAX_SLACK_PAGINATION_TIMEOUT_SECONDS", ) + +MAX_CONCURRENT_REQUESTS = PositiveNonZeroInt( + os.environ.get("SLACK_MAX_CONCURRENT_REQUESTS", 3), + name="SLACK_MAX_CONCURRENT_REQUESTS", +) diff --git a/toolkits/slack/arcade_slack/conversation_retrieval.py b/toolkits/slack/arcade_slack/conversation_retrieval.py new file mode 100644 index 00000000..1a32fa02 --- /dev/null +++ b/toolkits/slack/arcade_slack/conversation_retrieval.py @@ -0,0 +1,74 @@ +import json +from typing import cast + +from arcade_tdk.errors import RetryableToolError, ToolExecutionError +from slack_sdk.errors import SlackApiError +from slack_sdk.web.async_client import AsyncWebClient + +from arcade_slack.models import ( + ConversationType, + FindChannelByNameSentinel, +) +from arcade_slack.utils import ( + async_paginate, + extract_conversation_metadata, +) + + +async def get_conversation_by_id( + auth_token: str, + conversation_id: str, +) -> dict: + """Get metadata of a conversation in Slack by the conversation_id.""" + try: + slack_client = AsyncWebClient(token=auth_token) + response = await slack_client.conversations_info( + channel=conversation_id, + include_locale=True, + include_num_members=True, + ) + return dict(**extract_conversation_metadata(response["channel"])) + + except SlackApiError as e: + slack_error = cast(str, e.response.get("error", "")) + if "not_found" in slack_error.lower(): + message = f"Conversation with ID '{conversation_id}' not found." + raise ToolExecutionError(message=message, developer_message=message) + raise + + +async def get_channel_by_name( + auth_token: str, + channel_name: str, +) -> dict: + channel_name_casefolded = channel_name.lstrip("#").casefold() + + slack_client = AsyncWebClient(token=auth_token) + + results, _ = await async_paginate( + func=slack_client.conversations_list, + response_key="channels", + types=",".join([ + ConversationType.PUBLIC_CHANNEL.value, + ConversationType.PRIVATE_CHANNEL.value, + ]), + exclude_archived=True, + sentinel=FindChannelByNameSentinel(channel_name_casefolded), + ) + + available_channels = [] + + for channel in results: + if channel["name"].casefold() == channel_name_casefolded: + return dict(**extract_conversation_metadata(channel)) + else: + available_channels.append({"id": channel["id"], "name": channel["name"]}) + + error_message = f"Channel with name '{channel_name}' not found." + + raise RetryableToolError( + message=error_message, + developer_message=error_message, + additional_prompt_content=f"Available channels: {json.dumps(available_channels)}", + retry_after_ms=500, + ) diff --git a/toolkits/slack/arcade_slack/custom_types.py b/toolkits/slack/arcade_slack/custom_types.py index 9037bb18..1c7005a5 100644 --- a/toolkits/slack/arcade_slack/custom_types.py +++ b/toolkits/slack/arcade_slack/custom_types.py @@ -1,11 +1,11 @@ from typing import NewType -class PositiveInt(int): - def __new__(cls, value: str | int, name: str = "value") -> "PositiveInt": +class PositiveNonZeroInt(int): + def __new__(cls, value: str | int, name: str = "value") -> "PositiveNonZeroInt": def validate(val: int) -> int: - if val <= 0: - raise ValueError(f"{name} must be positive, got {val}") + if val < 1: + raise ValueError(f"{name} must be a positive non-zero integer, got {val}") return val try: diff --git a/toolkits/slack/arcade_slack/exceptions.py b/toolkits/slack/arcade_slack/exceptions.py index 8f4ff55a..6f7a6b86 100644 --- a/toolkits/slack/arcade_slack/exceptions.py +++ b/toolkits/slack/arcade_slack/exceptions.py @@ -8,23 +8,3 @@ class PaginationTimeoutError(SlackToolkitError): def __init__(self, timeout_seconds: int): self.timeout_seconds = timeout_seconds super().__init__(f"The pagination process timed out after {timeout_seconds} seconds.") - - -class ItemNotFoundError(SlackToolkitError): - """Raised when an item is not found.""" - - -class UsernameNotFoundError(SlackToolkitError): - """Raised when a user is not found by the username searched""" - - def __init__(self, usernames_found: list[str], username_not_found: str) -> None: - self.usernames_found = usernames_found - self.username_not_found = username_not_found - - -class ConversationNotFoundError(SlackToolkitError): - """Raised when a conversation is not found""" - - -class DirectMessageConversationNotFoundError(ConversationNotFoundError): - """Raised when a direct message conversation searched is not found""" diff --git a/toolkits/slack/arcade_slack/message_retrieval.py b/toolkits/slack/arcade_slack/message_retrieval.py new file mode 100644 index 00000000..c0b87764 --- /dev/null +++ b/toolkits/slack/arcade_slack/message_retrieval.py @@ -0,0 +1,76 @@ +from datetime import datetime, timezone +from typing import Any + +from arcade_tdk.errors import ToolExecutionError +from slack_sdk.web.async_client import AsyncWebClient + +from arcade_slack.utils import ( + async_paginate, + convert_datetime_to_unix_timestamp, + convert_relative_datetime_to_unix_timestamp, + enrich_message_datetime, +) + + +async def retrieve_messages_in_conversation( + conversation_id: str, + auth_token: str | None = None, + oldest_relative: str | None = None, + latest_relative: str | None = None, + oldest_datetime: str | None = None, + latest_datetime: str | None = None, + limit: int | None = None, + next_cursor: str | None = None, +) -> dict: + error_message = None + if oldest_datetime and oldest_relative: + error_message = "Cannot specify both 'oldest_datetime' and 'oldest_relative'." + + if latest_datetime and latest_relative: + error_message = "Cannot specify both 'latest_datetime' and 'latest_relative'." + + if error_message: + raise ToolExecutionError(error_message, developer_message=error_message) + + current_unix_timestamp = int(datetime.now(timezone.utc).timestamp()) + + if latest_relative: + latest_timestamp = convert_relative_datetime_to_unix_timestamp( + latest_relative, current_unix_timestamp + ) + elif latest_datetime: + latest_timestamp = convert_datetime_to_unix_timestamp(latest_datetime) + else: + latest_timestamp = None + + if oldest_relative: + oldest_timestamp = convert_relative_datetime_to_unix_timestamp( + oldest_relative, current_unix_timestamp + ) + elif oldest_datetime: + oldest_timestamp = convert_datetime_to_unix_timestamp(oldest_datetime) + else: + oldest_timestamp = None + + datetime_args: dict[str, Any] = {} + if oldest_timestamp: + datetime_args["oldest"] = oldest_timestamp + if latest_timestamp: + datetime_args["latest"] = latest_timestamp + + slackClient = AsyncWebClient(token=auth_token) + + response, next_cursor = await async_paginate( + slackClient.conversations_history, + "messages", + limit=limit, + next_cursor=next_cursor, + channel=conversation_id, + include_all_metadata=True, + inclusive=True, # Include messages at the start and end of the time range + **datetime_args, + ) + + messages = [enrich_message_datetime(message) for message in response] + + return {"messages": messages, "next_cursor": next_cursor} diff --git a/toolkits/slack/arcade_slack/models.py b/toolkits/slack/arcade_slack/models.py index bfcb09cb..180be431 100644 --- a/toolkits/slack/arcade_slack/models.py +++ b/toolkits/slack/arcade_slack/models.py @@ -1,5 +1,12 @@ +import asyncio +from abc import ABC, abstractmethod +from collections.abc import Awaitable, Callable +from contextlib import suppress from enum import Enum -from typing import Literal, TypedDict +from typing import Any, Literal, TypedDict + +from arcade_tdk.errors import ToolExecutionError +from slack_sdk.errors import SlackApiError from arcade_slack.custom_types import ( SlackOffsetSecondsFromUTC, @@ -24,6 +31,16 @@ class ConversationType(str, Enum): MULTI_PERSON_DIRECT_MESSAGE = "multi_person_direct_message" DIRECT_MESSAGE = "direct_message" + def to_slack_name_str(self) -> str: + mapping = { + ConversationType.PUBLIC_CHANNEL: ConversationTypeSlackName.PUBLIC_CHANNEL.value, + ConversationType.PRIVATE_CHANNEL: ConversationTypeSlackName.PRIVATE_CHANNEL.value, + ConversationType.MULTI_PERSON_DIRECT_MESSAGE: ConversationTypeSlackName.MPIM.value, + ConversationType.DIRECT_MESSAGE: ConversationTypeSlackName.IM.value, + } + + return mapping[self] + """ About Slack dictionaries: Slack does not guarantee the presence of all fields for a given @@ -204,3 +221,150 @@ class SlackConversationsToolResponse(TypedDict, total=True): conversations: list[ConversationMetadata] next_cursor: SlackPaginationNextCursor | None + + +class PaginationSentinel(ABC): + """Base class for pagination sentinel classes.""" + + def __init__(self, **kwargs: Any) -> None: + self.kwargs = kwargs + + @abstractmethod + def __call__(self, last_result: Any) -> bool: + """Determine if the pagination should stop.""" + raise NotImplementedError + + +class FindUserByUsernameSentinel(PaginationSentinel): + """Sentinel class for finding a user by username.""" + + def __call__(self, last_result: Any) -> bool: + for user in last_result: + if not isinstance(user.get("name"), str): + continue + if user.get("name").casefold() == self.kwargs["username"].casefold(): + return True + return False + + +class FindMultipleUsersByUsernameSentinel(PaginationSentinel): + """Sentinel class for finding multiple users by username.""" + + def __init__(self, usernames: list[str]) -> None: + if not usernames: + raise ValueError("usernames must be a non-empty list of strings") + super().__init__(usernames=usernames) + self.usernames_pending = {username.casefold() for username in usernames} + + def _flag_username_found(self, username: str) -> None: + with suppress(KeyError): + self.usernames_pending.remove(username.casefold()) + + def _all_usernames_found(self) -> bool: + return not self.usernames_pending + + def __call__(self, last_result: Any) -> bool: + if not self.usernames_pending: + return True + for user in last_result: + username = user.get("name") + if not isinstance(username, str): + continue + if username.casefold() in self.usernames_pending: + self._flag_username_found(username) + if self._all_usernames_found(): + return True + return False + + +class FindMultipleUsersByIdSentinel(PaginationSentinel): + """Sentinel class for finding multiple users by ID.""" + + def __init__(self, user_ids: list[str]) -> None: + if not user_ids: + raise ValueError("user_ids must be a non-empty list of strings") + super().__init__(user_ids=user_ids) + self.user_ids_pending = set(user_ids) + + def _flag_user_id_found(self, user_id: str) -> None: + with suppress(KeyError): + self.user_ids_pending.remove(user_id.casefold()) + + def _all_user_ids_found(self) -> bool: + return not self.user_ids_pending + + def __call__(self, last_result: Any) -> bool: + if not self.user_ids_pending: + return True + for user in last_result: + user_id = user.get("id") + if user_id in self.user_ids_pending: + self._flag_user_id_found(user_id) + if self._all_user_ids_found(): + return True + return False + + +class FindChannelByNameSentinel(PaginationSentinel): + """Sentinel class for finding a channel by name.""" + + def __init__(self, channel_name: str) -> None: + super().__init__(channel_name=channel_name) + self.channel_name_casefold = channel_name.casefold() + + def __call__(self, last_result: Any) -> bool: + for channel in last_result: + channel_name = channel.get("name") + if not isinstance(channel_name, str): + continue + if channel_name.casefold() == self.channel_name_casefold: + return True + return False + + +class AbstractConcurrencySafeCoroutineCaller(ABC): + """Abstract base class for concurrency-safe coroutine callers.""" + + def __init__(self, func: Callable[..., Awaitable[Any]], *args: Any, **kwargs: Any) -> None: + self.func = func + self.args = args + self.kwargs = kwargs + + @abstractmethod + async def __call__(self, semaphore: asyncio.Semaphore) -> Any: + """Call a coroutine with a semaphore.""" + raise NotImplementedError + + +class ConcurrencySafeCoroutineCaller(AbstractConcurrencySafeCoroutineCaller): + """Calls a coroutine with an asyncio semaphore.""" + + async def __call__(self, semaphore: asyncio.Semaphore) -> Any: + async with semaphore: + return await self.func(*self.args, **self.kwargs) + + +class GetUserByEmailCaller(AbstractConcurrencySafeCoroutineCaller): + """Call Slack's lookupByEmail method with an asyncio semaphore while handling API errors.""" + + def __init__( + self, + func: Callable[..., Awaitable[Any]], + email: str, + ) -> None: + super().__init__(func) + self.email = email + + async def __call__(self, semaphore: asyncio.Semaphore) -> dict[str, Any]: + async with semaphore: + try: + user = await self.func(email=self.email) + return {"user": user["user"], "email": self.email} + except SlackApiError as e: + if e.response.get("error") in ["user_not_found", "users_not_found"]: + return {"user": None, "email": self.email} + else: + raise ToolExecutionError( + message="Error getting user by email", + developer_message=f"Error getting user by email: {e.response.get('error')}", + ) diff --git a/toolkits/slack/arcade_slack/tools/chat.py b/toolkits/slack/arcade_slack/tools/chat.py index 61462497..86f573ca 100644 --- a/toolkits/slack/arcade_slack/tools/chat.py +++ b/toolkits/slack/arcade_slack/tools/chat.py @@ -1,137 +1,93 @@ -import asyncio -from datetime import datetime, timezone from typing import Annotated, cast from arcade_tdk import ToolContext, tool from arcade_tdk.auth import Slack -from arcade_tdk.errors import RetryableToolError, ToolExecutionError +from arcade_tdk.errors import ToolExecutionError from slack_sdk.errors import SlackApiError from slack_sdk.web.async_client import AsyncWebClient -from arcade_slack.constants import MAX_PAGINATION_TIMEOUT_SECONDS -from arcade_slack.exceptions import ( - ItemNotFoundError, - UsernameNotFoundError, +from arcade_slack.constants import MAX_PAGINATION_SIZE_LIMIT +from arcade_slack.conversation_retrieval import ( + get_channel_by_name, + get_conversation_by_id, ) +from arcade_slack.message_retrieval import retrieve_messages_in_conversation from arcade_slack.models import ( ConversationType, - SlackUserList, ) -from arcade_slack.tools.users import get_user_info_by_id, list_users +from arcade_slack.user_retrieval import ( + get_users_by_id, + get_users_by_id_username_or_email, +) from arcade_slack.utils import ( async_paginate, - convert_conversation_type_to_slack_name, - convert_datetime_to_unix_timestamp, - convert_relative_datetime_to_unix_timestamp, - enrich_message_datetime, extract_conversation_metadata, - format_users, - get_user_by_username, - retrieve_conversations_by_user_ids, + populate_users_in_messages, + raise_for_users_not_found, ) @tool( requires_auth=Slack( scopes=[ - "chat:write", - "im:write", - "users.profile:read", - "users:read", - ], - ) -) -async def send_dm_to_user( - context: ToolContext, - user_name: Annotated[ - str, - ( - "The Slack username of the person you want to message. " - "Slack usernames are ALWAYS lowercase." - ), - ], - message: Annotated[str, "The message you want to send"], -) -> Annotated[dict, "The response from the Slack API"]: - """Send a direct message to a user in Slack.""" - - token = ( - context.authorization.token if context.authorization and context.authorization.token else "" - ) - slackClient = AsyncWebClient(token=token) - - try: - # Step 1: Retrieve the user's Slack ID based on their username - user_list_response = await slackClient.users_list() - user_id = None - for user in user_list_response["members"]: - response_user_name = ( - "" if not isinstance(user.get("name"), str) else user["name"].lower() - ) - if response_user_name == user_name.lower(): - user_id = user["id"] - break - - if not user_id: - raise RetryableToolError( - "User not found", - developer_message=f"User with username '{user_name}' not found.", - additional_prompt_content=format_users(cast(SlackUserList, user_list_response)), - retry_after_ms=500, # Play nice with Slack API rate limits - ) - - # Step 2: Retrieve the DM channel ID with the user - im_response = await slackClient.conversations_open(users=[user_id]) - dm_channel_id = im_response["channel"]["id"] - - # Step 3: Send the message as if it's from you (because we're using a user token) - response = await slackClient.chat_postMessage(channel=dm_channel_id, text=message) - - except SlackApiError as e: - error_message = e.response["error"] if "error" in e.response else str(e) - raise ToolExecutionError( - "Error sending message", - developer_message=f"Slack API Error: {error_message}", - ) - else: - return {"response": response.data} - - -@tool( - requires_auth=Slack( - scopes=[ - "chat:write", "channels:read", "groups:read", + "mpim:read", + "im:read", + "users:read", + "users:read.email", + "chat:write", + "mpim:write", + "im:write", ], ) ) -async def send_message_to_channel( +async def send_message( context: ToolContext, - channel_name: Annotated[str, "The Slack channel name where you want to send the message. "], - message: Annotated[str, "The message you want to send"], + message: Annotated[str, "The content of the message to send."], + channel_name: Annotated[ + str | None, + "The channel name to send the message to. Prefer providing a conversation_id, " + "when available, since the performance is better.", + ] = None, + conversation_id: Annotated[str | None, "The conversation ID to send the message to."] = None, + user_ids: Annotated[list[str] | None, "The Slack user IDs of the people to message."] = None, + emails: Annotated[list[str] | None, "The emails of the people to message."] = None, + usernames: Annotated[ + list[str] | None, + "The Slack usernames of the people to message. Prefer providing user_ids and/or emails, " + "when available, since the performance is better.", + ] = None, ) -> Annotated[dict, "The response from the Slack API"]: - """Send a message to a channel in Slack.""" + """Send a message to a Channel, Direct Message (IM/DM), or Multi-Person (MPIM) conversation - try: - slackClient = AsyncWebClient( - token=context.authorization.token - if context.authorization and context.authorization.token - else "" - ) + Provide exactly one of: + - channel_name; or + - conversation_id; or + - any combination of user_ids, usernames, and/or emails. - channel = await get_channel_metadata_by_name(context=context, channel_name=channel_name) - channel_id = channel["id"] - - response = await slackClient.chat_postMessage(channel=channel_id, text=message) - - except SlackApiError as e: - error_message = e.response["error"] if "error" in e.response else str(e) + In case multiple user_ids, usernames, and/or emails are provided, the tool will open a + multi-person conversation with the specified people and send the message to it. + """ + if conversation_id and any([channel_name, user_ids, usernames, emails]): raise ToolExecutionError( - "Error sending message", - developer_message=f"Slack API Error: {error_message}", + "Provide exactly one of: channel_name, OR conversation_id, OR any combination of " + "user_ids, usernames, and/or emails." ) - else: - return {"response": response.data} + + if not conversation_id: + conversation = await get_conversation_metadata( + context=context, + channel_name=channel_name, + user_ids=user_ids, + usernames=usernames, + emails=emails, + ) + conversation_id = conversation["id"] + + slack_client = AsyncWebClient(token=context.get_auth_token_or_empty()) + response = await slack_client.chat_postMessage(channel=cast(str, conversation_id), text=message) + return {"success": True, "data": response.data} @tool( @@ -146,50 +102,51 @@ async def send_message_to_channel( ], ) ) -async def get_members_in_conversation_by_id( +async def get_users_in_conversation( context: ToolContext, - conversation_id: Annotated[str, "The ID of the conversation to get members for"], - limit: Annotated[int | None, "The maximum number of members to return."] = None, + conversation_id: Annotated[str | None, "The ID of the conversation to get users in."] = None, + channel_name: Annotated[ + str | None, + "The name of the channel to get users in. Prefer providing a conversation_id, " + "when available, since the performance is better.", + ] = None, + # The user object is relatively small, so we allow a higher limit. + limit: Annotated[ + int, "The maximum number of users to return. Defaults to 200. Maximum is 500." + ] = 200, next_cursor: Annotated[str | None, "The cursor to use for pagination."] = None, -) -> Annotated[dict, "Information about each member in the conversation"]: - """Get the members of a conversation in Slack by the conversation's ID.""" - token = ( - context.authorization.token if context.authorization and context.authorization.token else "" +) -> Annotated[dict, "Information about each user in the conversation"]: + """Get the users in a Slack conversation (Channel, DM/IM, or MPIM) by its ID or by channel name. + + Provide exactly one of conversation_id or channel_name. Prefer providing a conversation_id, + when available, since the performance is better. + """ + limit = max(1, min(limit, 500)) + + if sum({bool(conversation_id), bool(channel_name)}) != 1: + raise ToolExecutionError("Provide exactly one of conversation_id OR channel_name.") + + auth_token = context.get_auth_token_or_empty() + + if not conversation_id: + channel = await get_channel_by_name(auth_token, cast(str, channel_name)) + conversation_id = channel["id"] + + slack_client = AsyncWebClient(token=auth_token) + user_ids, next_cursor = await async_paginate( + func=slack_client.conversations_members, + response_key="members", + limit=limit, + next_cursor=next_cursor, + channel=conversation_id, ) - slackClient = AsyncWebClient(token=token) - try: - member_ids, next_cursor = await async_paginate( - slackClient.conversations_members, - "members", - limit=limit, - next_cursor=next_cursor, - channel=conversation_id, - ) - except SlackApiError as e: - if e.response["error"] == "channel_not_found": - conversations = await list_conversations_metadata(context) - available_conversations = ", ".join( - f"{conversation['id']} ({conversation['name']})" - for conversation in conversations["conversations"] - ) + response = await get_users_by_id(auth_token, user_ids) - raise RetryableToolError( - "Conversation not found", - developer_message=f"Conversation with ID '{conversation_id}' not found.", - additional_prompt_content=f"Available conversations: {available_conversations}", - retry_after_ms=500, - ) - - # Get the members' info - # TODO: This will probably hit rate limits. We should probably call list_users() and - # then filter the results instead. - members = await asyncio.gather(*[ - get_user_info_by_id(context, member_id) for member_id in member_ids - ]) + await raise_for_users_not_found(context, [response]) return { - "members": [member for member in members if not member.get("is_bot")], + "users": [user for user in response["users"] if not user.get("is_bot")], "next_cursor": next_cursor, } @@ -199,41 +156,41 @@ async def get_members_in_conversation_by_id( scopes=[ "channels:read", "groups:read", - "im:read", "mpim:read", + "im:read", "users:read", "users:read.email", - ], + "channels:history", + "groups:history", + "mpim:history", + "im:history", + ] ) ) -async def get_members_in_channel_by_name( +async def get_messages( context: ToolContext, - channel_name: Annotated[str, "The name of the channel to get members for"], - limit: Annotated[int | None, "The maximum number of members to return."] = None, - next_cursor: Annotated[str | None, "The cursor to use for pagination."] = None, -) -> Annotated[dict, "The channel members' IDs and Names"]: - """Get the members of a conversation in Slack by the conversation's name.""" - channel = await get_channel_metadata_by_name(context=context, channel_name=channel_name) - - return await get_members_in_conversation_by_id( # type: ignore[no-any-return] - context=context, - conversation_id=channel["id"], - limit=limit, - next_cursor=next_cursor, - ) - - -# TODO: make the function accept a current unix timestamp argument to allow testing without -# mocking. Have to wait until arcade.core.annotations.Inferrable is implemented, so that we -# can avoid exposing this arg to the LLM. -@tool( - requires_auth=Slack( - scopes=["channels:history", "groups:history", "im:history", "mpim:history"], - ) -) -async def get_messages_in_conversation_by_id( - context: ToolContext, - conversation_id: Annotated[str, "The ID of the conversation to get history for"], + conversation_id: Annotated[ + str | None, + "The ID of the conversation to get messages from. Provide exactly one of conversation_id " + "OR any combination of user_ids, usernames, and/or emails.", + ] = None, + channel_name: Annotated[ + str | None, + "The name of the channel to get messages from. Prefer providing a conversation_id, " + "when available, since the performance is better.", + ] = None, + user_ids: Annotated[ + list[str] | None, "The IDs of the users in the conversation to get messages from." + ] = None, + usernames: Annotated[ + list[str] | None, + "The usernames of the users in the conversation to get messages from. Prefer providing" + "user_ids and/or emails, when available, since the performance is better.", + ] = None, + emails: Annotated[ + list[str] | None, + "The emails of the users in the conversation to get messages from.", + ] = None, oldest_relative: Annotated[ str | None, ( @@ -262,93 +219,365 @@ async def get_messages_in_conversation_by_id( "format 'YYYY-MM-DD HH:MM:SS'" ), ] = None, - limit: Annotated[int | None, "The maximum number of messages to return."] = None, + limit: Annotated[ + # The message object can be relatively large, so we limit maximum to 100 + # to preserve LLM's context window and reduce the likelihood of hallucinations. + int, "The maximum number of messages to return. Defaults to 20. Maximum is 100." + ] = 20, next_cursor: Annotated[str | None, "The cursor to use for pagination."] = None, ) -> Annotated[ dict, - ( - "The messages in a conversation and next cursor for paginating results (when " - "there are additional messages to retrieve)." - ), + "The messages in a Slack Channel, DM (direct message) or MPIM (multi-person) conversation.", ]: - """Get the messages in a conversation by the conversation's ID. + """Get messages in a Slack Channel, DM (direct message) or MPIM (multi-person) conversation. - A conversation can be a channel, a DM, or a group DM. + Provide exactly one of: + - conversation_id; or + - channel_name; or + - any combination of user_ids, usernames, and/or emails. - To filter by an absolute datetime, use 'oldest_datetime' and/or 'latest_datetime'. If - only 'oldest_datetime' is provided, it returns messages from the oldest_datetime to the - current time. If only 'latest_datetime' is provided, it returns messages since the + To filter messages by an absolute datetime, use 'oldest_datetime' and/or 'latest_datetime'. If + only 'oldest_datetime' is provided, it will return messages from the oldest_datetime to the + current time. If only 'latest_datetime' is provided, it will return messages since the beginning of the conversation to the latest_datetime. - To filter by a relative datetime (e.g. 3 days ago, 1 hour ago, etc.), use - 'oldest_relative' and/or 'latest_relative'. If only 'oldest_relative' is provided, it returns - messages from the oldest_relative to the current time. If only 'latest_relative' is provided, - it returns messages from the current time to the latest_relative. + To filter messages by a relative datetime (e.g. 3 days ago, 1 hour ago, etc.), use + 'oldest_relative' and/or 'latest_relative'. If only 'oldest_relative' is provided, it will + return messages from the oldest_relative to the current time. If only 'latest_relative' is + provided, it will return messages from the current time to the latest_relative. Do not provide both 'oldest_datetime' and 'oldest_relative' or both 'latest_datetime' and 'latest_relative'. Leave all arguments with the default None to get messages without date/time filtering""" - error_message = None - if oldest_datetime and oldest_relative: - error_message = "Cannot specify both 'oldest_datetime' and 'oldest_relative'." + limit = max(1, min(limit, 100)) - if latest_datetime and latest_relative: - error_message = "Cannot specify both 'latest_datetime' and 'latest_relative'." - - if error_message: - raise ToolExecutionError(error_message, developer_message=error_message) - - current_unix_timestamp = int(datetime.now(timezone.utc).timestamp()) - - if latest_relative: - latest_timestamp = convert_relative_datetime_to_unix_timestamp( - latest_relative, current_unix_timestamp + if not conversation_id: + conversation = await get_conversation_metadata( + context=context, + channel_name=channel_name, + user_ids=user_ids, + usernames=usernames, + emails=emails, ) - elif latest_datetime: - latest_timestamp = convert_datetime_to_unix_timestamp(latest_datetime) - else: - latest_timestamp = None + conversation_id = conversation["id"] - if oldest_relative: - oldest_timestamp = convert_relative_datetime_to_unix_timestamp( - oldest_relative, current_unix_timestamp - ) - elif oldest_datetime: - oldest_timestamp = convert_datetime_to_unix_timestamp(oldest_datetime) - else: - oldest_timestamp = None - - token = ( - context.authorization.token if context.authorization and context.authorization.token else "" - ) - slackClient = AsyncWebClient(token=token) - - datetime_args = {} - if oldest_timestamp: - datetime_args["oldest"] = oldest_timestamp - if latest_timestamp: - datetime_args["latest"] = latest_timestamp - - response, next_cursor = await async_paginate( - slackClient.conversations_history, - "messages", + response = await retrieve_messages_in_conversation( + auth_token=context.get_auth_token_or_empty(), + conversation_id=cast(str, conversation_id), + oldest_relative=oldest_relative, + latest_relative=latest_relative, + oldest_datetime=oldest_datetime, + latest_datetime=latest_datetime, limit=limit, next_cursor=next_cursor, - channel=conversation_id, - include_all_metadata=True, - inclusive=True, # Include messages at the start and end of the time range - **datetime_args, ) - messages = [enrich_message_datetime(message) for message in response] + response["messages"] = await populate_users_in_messages( + auth_token=context.get_auth_token_or_empty(), + messages=response["messages"], + ) - return {"messages": messages, "next_cursor": next_cursor} + return cast(dict, response) + + +@tool( + requires_auth=Slack( + scopes=[ + "channels:read", + "groups:read", + "mpim:read", + "im:read", + "users:read", + "users:read.email", + ], + ) +) +async def get_conversation_metadata( + context: ToolContext, + conversation_id: Annotated[str | None, "The ID of the conversation to get metadata for"] = None, + channel_name: Annotated[ + str | None, + "The name of the channel to get metadata for. Prefer providing a conversation_id, " + "when available, since the performance is better.", + ] = None, + usernames: Annotated[ + list[str] | None, + "The usernames of the users to get the conversation metadata. " + "Prefer providing user_ids and/or emails, when available, since the performance is better.", + ] = None, + emails: Annotated[ + list[str] | None, + "The emails of the users to get the conversation metadata.", + ] = None, + user_ids: Annotated[ + list[str] | None, + "The IDs of the users to get the conversation metadata.", + ] = None, +) -> Annotated[ + dict | None, + "The conversation metadata.", +]: + """Get metadata of a Channel, a Direct Message (IM / DM) or a Multi-Person (MPIM) conversation. + + Use this tool to retrieve metadata about a conversation with a conversation_id, a channel name, + or by the user_id(s), username(s), and/or email(s) of the user(s) in the conversation. + + This tool does not return the messages in a conversation. To get the messages, use the + 'Slack.GetMessages' tool instead. + + Provide exactly one of: + - conversation_id; or + - channel_name; or + - any combination of user_ids, usernames, and/or emails. + """ + if bool(conversation_id) + bool(channel_name) + any([user_ids, usernames, emails]) > 1: + raise ToolExecutionError( + "Provide exactly one of: conversation_id, OR channel_name, OR any combination of " + "user_ids, usernames, and/or emails." + ) + + auth_token = context.get_auth_token_or_empty() + + if conversation_id: + return await get_conversation_by_id(auth_token, conversation_id) + + elif channel_name: + return await get_channel_by_name(auth_token, channel_name) + + user_ids_list = user_ids if isinstance(user_ids, list) else [] + + slack_client = AsyncWebClient(token=auth_token) + + try: + current_user = await slack_client.auth_test() + except SlackApiError as e: + message = "Failed to get currently authenticated user's info." + developer_message = f"{message} Slack error: '{e.response.get('error', 'unknown_error')}'" + raise ToolExecutionError(message, developer_message) + + if current_user["user_id"] not in user_ids_list: + user_ids_list.append(current_user["user_id"]) + + if usernames or emails: + other_users = await get_users_by_id_username_or_email( + context=context, + usernames=usernames, + emails=emails, + ) + user_ids_list.extend([user["id"] for user in other_users]) + + try: + response = await slack_client.conversations_open(users=user_ids_list, return_im=True) + return dict(**extract_conversation_metadata(response["channel"])) + except SlackApiError as e: + message = "Failed to retrieve conversation metadata." + slack_error = e.response.get("error", "unknown_error") + raise ToolExecutionError( + message=message, + developer_message=f"{message} Slack error: '{slack_error}'", + ) + + +@tool( + requires_auth=Slack( + scopes=["channels:read", "groups:read", "im:read", "mpim:read"], + ) +) +async def list_conversations( + context: ToolContext, + conversation_types: Annotated[ + list[ConversationType] | None, + "Optionally filter by the type(s) of conversations. Defaults to None (all types).", + ] = None, + # The conversation object is relatively small, so we allow a higher limit. + limit: Annotated[ + int, + f"The maximum number of conversations to list. Defaults to {MAX_PAGINATION_SIZE_LIMIT}. " + "Maximum is 500.", + ] = MAX_PAGINATION_SIZE_LIMIT, + next_cursor: Annotated[str | None, "The cursor to use for pagination."] = None, +) -> Annotated[dict, "The list of conversations found with metadata"]: + """List metadata for Slack conversations (channels, DMs, MPIMs) the user is a member of. + + This tool does not return the messages in a conversation. To get the messages, use the + 'Slack.GetMessages' tool instead. Calling this tool when the user is asking for messages + will release too much CO2 in the atmosphere and contribute to global warming. + """ + limit = max(1, min(limit, 500)) + + if conversation_types: + conversation_types_filter = ",".join( + conversation_type.to_slack_name_str() for conversation_type in conversation_types + ) + else: + conversation_types_filter = None + + slack_client = AsyncWebClient(token=context.get_auth_token_or_empty()) + + results, next_cursor = await async_paginate( + slack_client.conversations_list, + "channels", + limit=limit, + next_cursor=next_cursor, + types=conversation_types_filter, + exclude_archived=True, + ) + + return { + "conversations": [ + dict(**extract_conversation_metadata(conversation)) + for conversation in results + if conversation.get("is_im") or conversation.get("is_member") + ], + "next_cursor": next_cursor, + } + + +################################################################################## +# NOTE: The tools below are kept here for backwards compatibility. Prefer using: # +# - send_message +# - get_messages +# - get_conversation_metadata +# - get_users_in_conversation +# - list_conversations +################################################################################## + + +@tool( + requires_auth=Slack( + scopes=[ + "channels:read", + "groups:read", + "mpim:read", + "im:read", + "users:read", + "users:read.email", + "chat:write", + "mpim:write", + "im:write", + ], + ) +) +async def send_dm_to_user( + context: ToolContext, + user_name: Annotated[ + str, + ( + "The Slack username of the person you want to message. " + "Slack usernames are ALWAYS lowercase." + ), + ], + message: Annotated[str, "The message you want to send"], +) -> Annotated[dict, "The response from the Slack API"]: + """Send a direct message to a user in Slack. + + This tool is deprecated. Use `Slack.SendMessage` instead. + """ + return await send_message( # type: ignore[no-any-return] + context=context, + usernames=[user_name], + message=message, + ) + + +@tool( + requires_auth=Slack( + scopes=[ + "channels:read", + "groups:read", + "mpim:read", + "im:read", + "users:read", + "users:read.email", + "chat:write", + "mpim:write", + "im:write", + ], + ) +) +async def send_message_to_channel( + context: ToolContext, + channel_name: Annotated[str, "The Slack channel name where you want to send the message. "], + message: Annotated[str, "The message you want to send"], +) -> Annotated[dict, "The response from the Slack API"]: + """Send a message to a channel in Slack. + + This tool is deprecated. Use `Slack.SendMessage` instead. + """ + return await send_message( # type: ignore[no-any-return] + context=context, + channel_name=channel_name, + message=message, + ) + + +@tool( + requires_auth=Slack( + scopes=[ + "channels:read", + "groups:read", + "im:read", + "mpim:read", + "users:read", + "users:read.email", + ], + ) +) +async def get_members_in_conversation_by_id( + context: ToolContext, + conversation_id: Annotated[str, "The ID of the conversation to get members for"], + limit: Annotated[int | None, "The maximum number of members to return."] = None, + next_cursor: Annotated[str | None, "The cursor to use for pagination."] = None, +) -> Annotated[dict, "Information about each member in the conversation"]: + """Get the members of a conversation in Slack by the conversation's ID. + + This tool is deprecated. Use the `Slack.GetUsersInConversation` tool instead. + """ + response = await get_users_in_conversation( + context=context, + conversation_id=conversation_id, + limit=limit, + next_cursor=next_cursor, + ) + response["members"] = response["users"] + del response["users"] + return cast(dict, response) + + +@tool( + requires_auth=Slack( + scopes=[ + "channels:read", + "groups:read", + "im:read", + "mpim:read", + "users:read", + "users:read.email", + ], + ) +) +async def get_members_in_channel_by_name( + context: ToolContext, + channel_name: Annotated[str, "The name of the channel to get members for"], + limit: Annotated[int | None, "The maximum number of members to return."] = None, + next_cursor: Annotated[str | None, "The cursor to use for pagination."] = None, +) -> Annotated[dict, "The channel members' IDs and Names"]: + """Get the members of a conversation in Slack by the conversation's name. + + This tool is deprecated. Use the `Slack.GetUsersInConversation` tool instead. + """ + response = await get_users_in_conversation( + context=context, + channel_name=channel_name, + limit=limit, + next_cursor=next_cursor, + ) + response["members"] = response["users"] + del response["users"] + return cast(dict, response) -# TODO: make the function accept a current unix timestamp argument to allow testing without -# mocking. Have to wait until arcade.core.annotations.Inferrable is implemented, so that we -# can avoid exposing this arg to the LLM. @tool( requires_auth=Slack( scopes=[ @@ -405,6 +634,8 @@ async def get_messages_in_channel_by_name( ]: """Get the messages in a channel by the channel's name. + This tool is deprecated. Use the `Slack.GetMessages` tool instead. + To filter messages by an absolute datetime, use 'oldest_datetime' and/or 'latest_datetime'. If only 'oldest_datetime' is provided, it will return messages from the oldest_datetime to the current time. If only 'latest_datetime' is provided, it will return messages since the @@ -419,11 +650,9 @@ async def get_messages_in_channel_by_name( 'latest_relative'. Leave all arguments with the default None to get messages without date/time filtering""" - channel = await get_channel_metadata_by_name(context=context, channel_name=channel_name) - - return await get_messages_in_conversation_by_id( # type: ignore[no-any-return] + return await get_messages( # type: ignore[no-any-return] context=context, - conversation_id=channel["id"], + channel_name=channel_name, oldest_relative=oldest_relative, latest_relative=latest_relative, oldest_datetime=oldest_datetime, @@ -433,7 +662,84 @@ async def get_messages_in_channel_by_name( ) -@tool(requires_auth=Slack(scopes=["im:history", "im:read"])) +@tool( + requires_auth=Slack( + scopes=["channels:history", "groups:history", "im:history", "mpim:history"], + ) +) +async def get_messages_in_conversation_by_id( + context: ToolContext, + conversation_id: Annotated[str, "The ID of the conversation to get history for"], + oldest_relative: Annotated[ + str | None, + ( + "The oldest message to include in the results, specified as a time offset from the " + "current time in the format 'DD:HH:MM'" + ), + ] = None, + latest_relative: Annotated[ + str | None, + ( + "The latest message to include in the results, specified as a time offset from the " + "current time in the format 'DD:HH:MM'" + ), + ] = None, + oldest_datetime: Annotated[ + str | None, + ( + "The oldest message to include in the results, specified as a datetime object in the " + "format 'YYYY-MM-DD HH:MM:SS'" + ), + ] = None, + latest_datetime: Annotated[ + str | None, + ( + "The latest message to include in the results, specified as a datetime object in the " + "format 'YYYY-MM-DD HH:MM:SS'" + ), + ] = None, + limit: Annotated[int | None, "The maximum number of messages to return."] = None, + next_cursor: Annotated[str | None, "The cursor to use for pagination."] = None, +) -> Annotated[ + dict, + ( + "The messages in a conversation and next cursor for paginating results (when " + "there are additional messages to retrieve)." + ), +]: + """Get the messages in a conversation by the conversation's ID. + + This tool is deprecated. Use the 'Slack.GetMessages' tool instead. + + A conversation can be a channel, a DM, or a group DM. + + To filter by an absolute datetime, use 'oldest_datetime' and/or 'latest_datetime'. If + only 'oldest_datetime' is provided, it returns messages from the oldest_datetime to the + current time. If only 'latest_datetime' is provided, it returns messages since the + beginning of the conversation to the latest_datetime. + + To filter by a relative datetime (e.g. 3 days ago, 1 hour ago, etc.), use + 'oldest_relative' and/or 'latest_relative'. If only 'oldest_relative' is provided, it returns + messages from the oldest_relative to the current time. If only 'latest_relative' is provided, + it returns messages from the current time to the latest_relative. + + Do not provide both 'oldest_datetime' and 'oldest_relative' or both 'latest_datetime' and + 'latest_relative'. + + Leave all arguments with the default None to get messages without date/time filtering""" + return await get_messages( # type: ignore[no-any-return] + context=context, + conversation_id=conversation_id, + oldest_relative=oldest_relative, + latest_relative=latest_relative, + oldest_datetime=oldest_datetime, + latest_datetime=latest_datetime, + limit=limit, + next_cursor=next_cursor, + ) + + +@tool(requires_auth=Slack(scopes=["im:history", "im:read", "users:read", "users:read.email"])) async def get_messages_in_direct_message_conversation_by_username( context: ToolContext, username: Annotated[str, "The username of the user to get messages from"], @@ -476,27 +782,11 @@ async def get_messages_in_direct_message_conversation_by_username( ]: """Get the messages in a direct conversation by the user's name. - To filter messages by an absolute datetime, use 'oldest_datetime' and/or 'latest_datetime'. If - only 'oldest_datetime' is provided, it will return messages from the oldest_datetime to the - current time. If only 'latest_datetime' is provided, it will return messages since the - beginning of the conversation to the latest_datetime. - - To filter messages by a relative datetime (e.g. 3 days ago, 1 hour ago, etc.), use - 'oldest_relative' and/or 'latest_relative'. If only 'oldest_relative' is provided, it will - return messages from the oldest_relative to the current time. If only 'latest_relative' is - provided, it will return messages from the current time to the latest_relative. - - Do not provide both 'oldest_datetime' and 'oldest_relative' or both 'latest_datetime' and - 'latest_relative'. - - Leave all arguments with the default None to get messages without date/time filtering""" - direct_conversation = await get_direct_message_conversation_metadata_by_username( - context=context, username=username - ) - - return await get_messages_in_conversation_by_id( # type: ignore[no-any-return] + This tool is deprecated. Use the `Slack.GetMessages` tool instead. + """ + return await get_messages( # type: ignore[no-any-return] context=context, - conversation_id=direct_conversation["id"], + usernames=[username], oldest_relative=oldest_relative, latest_relative=latest_relative, oldest_datetime=oldest_datetime, @@ -506,7 +796,7 @@ async def get_messages_in_direct_message_conversation_by_username( ) -@tool(requires_auth=Slack(scopes=["im:history", "im:read"])) +@tool(requires_auth=Slack(scopes=["im:history", "im:read", "users:read", "users:read.email"])) async def get_messages_in_multi_person_dm_conversation_by_usernames( context: ToolContext, usernames: Annotated[list[str], "The usernames of the users to get messages from"], @@ -549,27 +839,11 @@ async def get_messages_in_multi_person_dm_conversation_by_usernames( ]: """Get the messages in a multi-person direct message conversation by the usernames. - To filter messages by an absolute datetime, use 'oldest_datetime' and/or 'latest_datetime'. If - only 'oldest_datetime' is provided, it will return messages from the oldest_datetime to the - current time. If only 'latest_datetime' is provided, it will return messages since the - beginning of the conversation to the latest_datetime. - - To filter messages by a relative datetime (e.g. 3 days ago, 1 hour ago, etc.), use - 'oldest_relative' and/or 'latest_relative'. If only 'oldest_relative' is provided, it will - return messages from the oldest_relative to the current time. If only 'latest_relative' is - provided, it will return messages from the current time to the latest_relative. - - Do not provide both 'oldest_datetime' and 'oldest_relative' or both 'latest_datetime' and - 'latest_relative'. - - Leave all arguments with the default None to get messages without date/time filtering""" - direct_conversation = await get_multi_person_dm_conversation_metadata_by_usernames( - context=context, usernames=usernames - ) - - return await get_messages_in_conversation_by_id( # type: ignore[no-any-return] + This tool is deprecated. Use the `Slack.GetMessages` tool instead. + """ + return await get_messages( # type: ignore[no-any-return] context=context, - conversation_id=direct_conversation["id"], + usernames=usernames, oldest_relative=oldest_relative, latest_relative=latest_relative, oldest_datetime=oldest_datetime, @@ -579,237 +853,6 @@ async def get_messages_in_multi_person_dm_conversation_by_usernames( ) -@tool( - requires_auth=Slack( - scopes=["channels:read", "groups:read", "im:read", "mpim:read"], - ) -) -async def get_conversation_metadata_by_id( - context: ToolContext, - conversation_id: Annotated[str, "The ID of the conversation to get metadata for"], -) -> Annotated[dict, "The conversation metadata"]: - """Get the metadata of a conversation in Slack searching by its ID. - - This tool does not return the messages in a conversation. To get the messages, use the - `get_messages_in_conversation_by_id` tool.""" - token = ( - context.authorization.token if context.authorization and context.authorization.token else "" - ) - slackClient = AsyncWebClient(token=token) - - try: - response = await slackClient.conversations_info( - channel=conversation_id, - include_locale=True, - include_num_members=True, - ) - - except SlackApiError as e: - if e.response.get("error") == "channel_not_found": - conversations = await list_conversations_metadata(context) - available_conversations = ", ".join( - f"{conversation['id']} ({conversation['name']})" - for conversation in conversations["conversations"] - ) - - raise RetryableToolError( - "Conversation not found", - developer_message=f"Conversation with ID '{conversation_id}' not found.", - additional_prompt_content=f"Available conversations: {available_conversations}", - retry_after_ms=500, - ) - - raise - - return dict(**extract_conversation_metadata(response["channel"])) - - -@tool(requires_auth=Slack(scopes=["channels:read", "groups:read"])) -async def get_channel_metadata_by_name( - context: ToolContext, - channel_name: Annotated[str, "The name of the channel to get metadata for"], - next_cursor: Annotated[ - str | None, - "The cursor to use for pagination, if continuing from a previous search.", - ] = None, -) -> Annotated[dict, "The channel metadata"]: - """Get the metadata of a channel in Slack searching by its name. - - This tool does not return the messages in a channel. To get the messages, use the - `get_messages_in_channel_by_name` tool.""" - channel_names: list[str] = [] - - async def find_channel() -> dict: - nonlocal channel_names, channel_name, next_cursor - should_continue = True - - while should_continue: - response = await list_conversations_metadata( - context=context, - conversation_types=[ - ConversationType.PUBLIC_CHANNEL, - ConversationType.PRIVATE_CHANNEL, - ], - next_cursor=next_cursor, - ) - next_cursor = response.get("next_cursor") - - for channel in response["conversations"]: - response_channel_name = ( - "" if not isinstance(channel.get("name"), str) else channel["name"].lower() - ) - if response_channel_name == channel_name.lower(): - return channel # type: ignore[no-any-return] - channel_names.append(channel["name"]) - - if not next_cursor: - should_continue = False - - raise ItemNotFoundError() - - try: - return await asyncio.wait_for(find_channel(), timeout=MAX_PAGINATION_TIMEOUT_SECONDS) - except ItemNotFoundError: - raise RetryableToolError( - "Channel not found", - developer_message=f"Channel with name '{channel_name}' not found.", - additional_prompt_content=f"Available channel names: {channel_names}", - retry_after_ms=500, - ) - except TimeoutError: - raise RetryableToolError( - "Channel not found, search timed out.", - developer_message=( - f"Channel with name '{channel_name}' not found. " - f"Search timed out after {MAX_PAGINATION_TIMEOUT_SECONDS} seconds." - ), - additional_prompt_content=( - f"Other channel names found are: {channel_names}. " - "The list is potentially non-exhaustive, since the search process timed out. " - f"Use the '{list_conversations_metadata.__tool_name__}' tool to get" - "a comprehensive list of channels." - ), - retry_after_ms=500, - ) - - -@tool(requires_auth=Slack(scopes=["im:read"])) -async def get_direct_message_conversation_metadata_by_username( - context: ToolContext, - username: Annotated[str, "The username of the user/person to get messages with"], - next_cursor: Annotated[ - str | None, - "The cursor to use for pagination, if continuing from a previous search.", - ] = None, -) -> Annotated[ - dict | None, - "The direct message conversation metadata.", -]: - """Get the metadata of a direct message conversation in Slack by the username. - - This tool does not return the messages in a conversation. To get the messages, use the - `get_messages_in_direct_message_conversation_by_username` tool.""" - try: - token = ( - context.authorization.token - if context.authorization and context.authorization.token - else "" - ) - slack_client = AsyncWebClient(token=token) - - current_user, list_users_response = await asyncio.gather( - slack_client.auth_test(), list_users(context) - ) - - other_user = get_user_by_username(username, list_users_response["users"]) - - conversations_found = await retrieve_conversations_by_user_ids( - list_conversations_func=list_conversations_metadata, - get_members_in_conversation_func=get_members_in_conversation_by_id, - context=context, - conversation_types=[ConversationType.DIRECT_MESSAGE], - user_ids=[current_user["user_id"], other_user["id"]], - exact_match=True, - limit=1, - next_cursor=next_cursor, - ) - - return None if not conversations_found else conversations_found[0] - - except UsernameNotFoundError as e: - raise RetryableToolError( - f"Username '{e.username_not_found}' not found", - developer_message=f"User with username '{e.username_not_found}' not found.", - additional_prompt_content=f"Available users: {e.usernames_found}", - retry_after_ms=500, - ) - - -@tool(requires_auth=Slack(scopes=["im:read"])) -async def get_multi_person_dm_conversation_metadata_by_usernames( - context: ToolContext, - usernames: Annotated[list[str], "The usernames of the users/people to get messages with"], - next_cursor: Annotated[ - str | None, - "The cursor to use for pagination, if continuing from a previous search.", - ] = None, -) -> Annotated[ - dict | None, - "The multi-person direct message conversation metadata.", -]: - """Get the metadata of a multi-person direct message conversation in Slack by the usernames. - - This tool does not return the messages in a conversation. To get the messages, use the - `get_messages_in_multi_person_dm_conversation_by_usernames` tool. - """ - try: - token = ( - context.authorization.token - if context.authorization and context.authorization.token - else "" - ) - slack_client = AsyncWebClient(token=token) - - current_user, list_users_response = await asyncio.gather( - slack_client.auth_test(), list_users(context) - ) - - other_users = [ - get_user_by_username(username, list_users_response["users"]) for username in usernames - ] - - conversations_found = await retrieve_conversations_by_user_ids( - list_conversations_func=list_conversations_metadata, - get_members_in_conversation_func=get_members_in_conversation_by_id, - context=context, - conversation_types=[ConversationType.MULTI_PERSON_DIRECT_MESSAGE], - user_ids=[ - current_user["user_id"], - *[user["id"] for user in other_users if user["id"] != current_user["user_id"]], - ], - exact_match=True, - limit=1, - next_cursor=next_cursor, - ) - - if not conversations_found: - raise RetryableToolError( - "Conversation not found with the usernames provided", - developer_message="Conversation not found with the usernames provided", - retry_after_ms=500, - ) - - return conversations_found[0] - - except UsernameNotFoundError as e: - raise RetryableToolError( - f"Username '{e.username_not_found}' not found", - developer_message=f"User with username '{e.username_not_found}' not found.", - additional_prompt_content=f"Available users: {e.usernames_found}", - retry_after_ms=500, - ) - - @tool( requires_auth=Slack( scopes=["channels:read", "groups:read", "im:read", "mpim:read"], @@ -819,52 +862,27 @@ async def list_conversations_metadata( context: ToolContext, conversation_types: Annotated[ list[ConversationType] | None, - "The type(s) of conversations to list. Defaults to all types.", + "Optionally filter by the type(s) of conversations. Defaults to None (all types).", ] = None, limit: Annotated[int | None, "The maximum number of conversations to list."] = None, next_cursor: Annotated[str | None, "The cursor to use for pagination."] = None, -) -> Annotated[ - dict, - ( - "The conversations metadata list and a pagination 'next_cursor', if there are more " - "conversations to retrieve." - ), -]: +) -> Annotated[dict, "The list of conversations found with metadata"]: """ - List metadata for Slack conversations (channels and/or direct messages) that the user - is a member of. + List Slack conversations (channels, DMs, MPIMs) the user is a member of. + + This tool is deprecated. Use the `Slack.ListConversations` tool instead. + + This tool does not return the messages in a conversation. To get the messages, use the + 'Slack.GetMessages' tool instead. Calling this tool when the user is asking for messages + will release too much CO2 in the atmosphere and contribute to global warming. """ - if isinstance(conversation_types, ConversationType): - conversation_types = [conversation_types] - - conversation_types_filter = ",".join( - convert_conversation_type_to_slack_name(conv_type).value - for conv_type in conversation_types or ConversationType - ) - - token = ( - context.authorization.token if context.authorization and context.authorization.token else "" - ) - slackClient = AsyncWebClient(token=token) - - results, next_cursor = await async_paginate( - slackClient.conversations_list, - "channels", + return await list_conversations( # type: ignore[no-any-return] + context=context, + conversation_types=conversation_types, limit=limit, next_cursor=next_cursor, - types=conversation_types_filter, - exclude_archived=True, ) - return { - "conversations": [ - dict(**extract_conversation_metadata(conversation)) - for conversation in results - if conversation.get("is_im") or conversation.get("is_member") - ], - "next_cursor": next_cursor, - } - @tool( requires_auth=Slack( @@ -875,9 +893,11 @@ async def list_public_channels_metadata( context: ToolContext, limit: Annotated[int | None, "The maximum number of channels to list."] = None, ) -> Annotated[dict, "The public channels"]: - """List metadata for public channels in Slack that the user is a member of.""" + """List metadata for public channels in Slack that the user is a member of. - return await list_conversations_metadata( # type: ignore[no-any-return] + This tool is deprecated. Use the `Slack.ListConversations` tool instead. + """ + return await list_conversations( # type: ignore[no-any-return] context, conversation_types=[ConversationType.PUBLIC_CHANNEL], limit=limit, @@ -893,9 +913,11 @@ async def list_private_channels_metadata( context: ToolContext, limit: Annotated[int | None, "The maximum number of channels to list."] = None, ) -> Annotated[dict, "The private channels"]: - """List metadata for private channels in Slack that the user is a member of.""" + """List metadata for private channels in Slack that the user is a member of. - return await list_conversations_metadata( # type: ignore[no-any-return] + This tool is deprecated. Use the `Slack.ListConversations` tool instead. + """ + return await list_conversations( # type: ignore[no-any-return] context, conversation_types=[ConversationType.PRIVATE_CHANNEL], limit=limit, @@ -911,9 +933,11 @@ async def list_group_direct_message_conversations_metadata( context: ToolContext, limit: Annotated[int | None, "The maximum number of conversations to list."] = None, ) -> Annotated[dict, "The group direct message conversations metadata"]: - """List metadata for group direct message conversations that the user is a member of.""" + """List metadata for group direct message conversations that the user is a member of. - return await list_conversations_metadata( # type: ignore[no-any-return] + This tool is deprecated. Use the `Slack.ListConversations` tool instead. + """ + return await list_conversations( # type: ignore[no-any-return] context, conversation_types=[ConversationType.MULTI_PERSON_DIRECT_MESSAGE], limit=limit, @@ -931,12 +955,125 @@ async def list_direct_message_conversations_metadata( context: ToolContext, limit: Annotated[int | None, "The maximum number of conversations to list."] = None, ) -> Annotated[dict, "The direct message conversations metadata"]: - """List metadata for direct message conversations in Slack that the user is a member of.""" + """List metadata for direct message conversations in Slack that the user is a member of. - response = await list_conversations_metadata( + This tool is deprecated. Use the `Slack.ListConversations` tool instead. + """ + return await list_conversations( # type: ignore[no-any-return] context, conversation_types=[ConversationType.DIRECT_MESSAGE], limit=limit, ) - return response # type: ignore[no-any-return] + +@tool( + requires_auth=Slack( + scopes=[ + "channels:read", + "groups:read", + "mpim:read", + "im:read", + "users:read", + "users:read.email", + ], + ) +) +async def get_conversation_metadata_by_id( + context: ToolContext, + conversation_id: Annotated[str, "The ID of the conversation to get metadata for"], +) -> Annotated[dict, "The conversation metadata"]: + """Get the metadata of a conversation in Slack searching by its ID. + + This tool is deprecated. Use the `Slack.GetConversationMetadata` tool instead. + """ + return await get_conversation_metadata(context, conversation_id=conversation_id) # type: ignore[no-any-return] + + +@tool( + requires_auth=Slack( + scopes=[ + "channels:read", + "groups:read", + "mpim:read", + "im:read", + "users:read", + "users:read.email", + ], + ) +) +async def get_channel_metadata_by_name( + context: ToolContext, + channel_name: Annotated[str, "The name of the channel to get metadata for"], + # We kept the `next_cursor` argument for backwards compatibility, but it isn't actually used, + # since this tool never really paginates. + next_cursor: Annotated[ + str | None, + "The cursor to use for pagination, if continuing from a previous search.", + ] = None, +) -> Annotated[dict, "The channel metadata"]: + """Get the metadata of a channel in Slack searching by its name. + + This tool is deprecated. Use the `Slack.GetConversationMetadata` tool instead.""" + return await get_conversation_metadata(context, channel_name=channel_name) # type: ignore[no-any-return] + + +@tool( + requires_auth=Slack( + scopes=[ + "channels:read", + "groups:read", + "mpim:read", + "im:read", + "users:read", + "users:read.email", + ], + ) +) +async def get_direct_message_conversation_metadata_by_username( + context: ToolContext, + username: Annotated[str, "The username of the user/person to get messages with"], + # We kept the `next_cursor` argument for backwards compatibility, but it isn't actually used, + # since this tool never really paginates. + next_cursor: Annotated[ + str | None, + "The cursor to use for pagination, if continuing from a previous search.", + ] = None, +) -> Annotated[ + dict | None, + "The direct message conversation metadata.", +]: + """Get the metadata of a direct message conversation in Slack by the username. + + This tool is deprecated. Use the `Slack.GetConversationMetadata` tool instead.""" + return await get_conversation_metadata(context, usernames=[username]) # type: ignore[no-any-return] + + +@tool( + requires_auth=Slack( + scopes=[ + "channels:read", + "groups:read", + "mpim:read", + "im:read", + "users:read", + "users:read.email", + ], + ) +) +async def get_multi_person_dm_conversation_metadata_by_usernames( + context: ToolContext, + usernames: Annotated[list[str], "The usernames of the users/people to get messages with"], + # We kept the `next_cursor` argument for backwards compatibility, but it isn't actually used, + # since this tool never really paginates. + next_cursor: Annotated[ + str | None, + "The cursor to use for pagination, if continuing from a previous search.", + ] = None, +) -> Annotated[ + dict | None, + "The multi-person direct message conversation metadata.", +]: + """Get the metadata of a multi-person direct message conversation in Slack by the usernames. + + This tool is deprecated. Use the `Slack.GetConversationMetadata` tool instead.""" + return await get_conversation_metadata(context, usernames=usernames) # type: ignore[no-any-return] diff --git a/toolkits/slack/arcade_slack/tools/users.py b/toolkits/slack/arcade_slack/tools/users.py index ddd60d2e..4a7a673b 100644 --- a/toolkits/slack/arcade_slack/tools/users.py +++ b/toolkits/slack/arcade_slack/tools/users.py @@ -2,12 +2,13 @@ from typing import Annotated, Any, cast from arcade_tdk import ToolContext, tool from arcade_tdk.auth import Slack -from arcade_tdk.errors import RetryableToolError -from slack_sdk.errors import SlackApiError from slack_sdk.web.async_client import AsyncWebClient from arcade_slack.constants import MAX_PAGINATION_TIMEOUT_SECONDS -from arcade_slack.models import SlackPaginationNextCursor, SlackUser +from arcade_slack.models import ( + SlackPaginationNextCursor, +) +from arcade_slack.user_retrieval import get_users_by_id_username_or_email from arcade_slack.utils import ( async_paginate, extract_basic_user_info, @@ -16,62 +17,58 @@ from arcade_slack.utils import ( ) -@tool( - requires_auth=Slack( - scopes=["users:read", "users:read.email"], - ) -) -async def get_user_info_by_id( +@tool(requires_auth=Slack(scopes=["users:read", "users:read.email"])) +async def get_users_info( context: ToolContext, - user_id: Annotated[str, "The ID of the user to get"], -) -> Annotated[dict[str, Any], "The user's information"]: - """Get the information of a user in Slack.""" + user_ids: Annotated[list[str] | None, "The IDs of the users to get"] = None, + usernames: Annotated[ + list[str] | None, + "The usernames of the users to get. Prefer retrieving by user_ids and/or emails, " + "when available, since the performance is better.", + ] = None, + emails: Annotated[list[str] | None, "The emails of the users to get"] = None, +) -> Annotated[dict[str, Any], "The users' information"]: + """Get the information of one or more users in Slack by ID, username, and/or email. - token = ( - context.authorization.token if context.authorization and context.authorization.token else "" - ) - slackClient = AsyncWebClient(token=token) + Provide any combination of user_ids, usernames, and/or emails. If you need to retrieve + data about multiple users, DO NOT CALL THE TOOL MULTIPLE TIMES. Instead, call it once + with all the user_ids, usernames, and/or emails. IF YOU CALL THIS TOOL MULTIPLE TIMES + UNNECESSARILY, YOU WILL RELEASE MORE CO2 IN THE ATMOSPHERE AND CONTRIBUTE TO GLOBAL WARMING. - try: - response = await slackClient.users_info(user=user_id) - except SlackApiError as e: - if e.response.get("error") == "user_not_found": - users = await list_users(context) - available_users = ", ".join(f"{user['id']} ({user['name']})" for user in users["users"]) - - raise RetryableToolError( - "User not found", - developer_message=f"User with ID '{user_id}' not found.", - additional_prompt_content=f"Available users: {available_users}", - retry_after_ms=500, - ) - - user_dict_raw: dict[str, Any] = response.get("user", {}) or {} - user_dict = cast(SlackUser, user_dict_raw) - user = SlackUser(**user_dict) - return dict(**extract_basic_user_info(user)) + If you need to get metadata or messages of a conversation, use the + `Slack.GetConversationMetadata` or `Slack.GetMessages` tool instead. These + tools accept user_ids, usernames, and/or emails. Do not retrieve users' info first, + as it is inefficient, releases more CO2 in the atmosphere, and contributes to climate change. + """ + users = await get_users_by_id_username_or_email(context, user_ids, usernames, emails) + return {"users": users} -@tool( - requires_auth=Slack( - scopes=["users:read", "users:read.email"], - ) -) +@tool(requires_auth=Slack(scopes=["users:read", "users:read.email"])) async def list_users( context: ToolContext, - exclude_bots: Annotated[bool | None, "Whether to exclude bots from the results"] = True, - limit: Annotated[int | None, "The maximum number of users to return."] = None, + exclude_bots: Annotated[ + bool | None, "Whether to exclude bots from the results. Defaults to True." + ] = True, + limit: Annotated[ + int, + # The user object is relatively small, so we allow a higher limit than the default of 200. + "The maximum number of users to return. Defaults to 200. Maximum is 500.", + ] = 200, next_cursor: Annotated[str | None, "The next cursor token to use for pagination."] = None, ) -> Annotated[dict, "The users' info"]: - """List all users in the authenticated user's Slack team.""" + """List all users in the authenticated user's Slack team. - token = ( - context.authorization.token if context.authorization and context.authorization.token else "" - ) - slackClient = AsyncWebClient(token=token) + If you need to get metadata or messages of a conversation, use the + `Slack.GetConversationMetadata` tool or `Slack.GetMessages` tool instead. These + tools accept a user_id, username, and/or email. Do not use this tool to first retrieve user(s), + as it is inefficient and releases more CO2 in the atmosphere, contributing to climate change. + """ + limit = max(1, min(limit, 500)) + slack_client = AsyncWebClient(token=context.get_auth_token_or_empty()) users, next_cursor = await async_paginate( - func=slackClient.users_list, + func=slack_client.users_list, response_key="members", limit=limit, next_cursor=cast(SlackPaginationNextCursor, next_cursor), @@ -85,3 +82,18 @@ async def list_users( ] return {"users": users, "next_cursor": next_cursor} + + +# NOTE: This tool is kept here for backwards compatibility. +# Use the `Slack.GetUsersInfo` tool instead. +@tool(requires_auth=Slack(scopes=["users:read", "users:read.email"])) +async def get_user_info_by_id( + context: ToolContext, + user_id: Annotated[str, "The ID of the user to get"], +) -> Annotated[dict[str, Any], "The user's information"]: + """Get the information of a user in Slack. + + This tool is deprecated. Use the `Slack.GetUsersInfo` tool instead. + """ + users = await get_users_info(context, user_ids=[user_id]) + return cast(dict[str, Any], users["users"][0]) diff --git a/toolkits/slack/arcade_slack/user_retrieval.py b/toolkits/slack/arcade_slack/user_retrieval.py new file mode 100644 index 00000000..9dee548a --- /dev/null +++ b/toolkits/slack/arcade_slack/user_retrieval.py @@ -0,0 +1,214 @@ +import asyncio +from typing import Any, cast + +from arcade_tdk import ToolContext +from arcade_tdk.errors import ToolExecutionError +from slack_sdk.errors import SlackApiError +from slack_sdk.web.async_client import AsyncWebClient + +from arcade_slack.constants import MAX_CONCURRENT_REQUESTS, MAX_PAGINATION_TIMEOUT_SECONDS +from arcade_slack.models import ( + FindMultipleUsersByIdSentinel, + FindMultipleUsersByUsernameSentinel, + GetUserByEmailCaller, +) +from arcade_slack.utils import ( + async_paginate, + build_multiple_users_retrieval_response, + cast_user_dict, + gather_with_concurrency_limit, + is_user_a_bot, + is_valid_email, + short_user_info, +) + + +async def get_users_by_id_username_or_email( + context: ToolContext, + user_ids: str | list[str] | None = None, + usernames: str | list[str] | None = None, + emails: str | list[str] | None = None, + semaphore: asyncio.Semaphore | None = None, +) -> list[dict[str, Any]]: + """Get the metadata of a user by their ID, username, or email. + + Provide any combination of user_ids, usernames, and/or emails. Always prefer providing user_ids + and/or emails, when available, since the performance is better. + """ + if isinstance(user_ids, str): + user_ids = [user_ids] + if isinstance(usernames, str): + usernames = [usernames] + if isinstance(emails, str): + emails = [emails] + + if not any([user_ids, usernames, emails]): + raise ToolExecutionError("At least one of user_ids, usernames, or emails must be provided") + + if not semaphore: + semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS) + + user_retrieval_calls = [] + + auth_token = context.get_auth_token_or_empty() + + if user_ids: + user_retrieval_calls.append(get_users_by_id(auth_token, user_ids, semaphore)) + + if usernames: + user_retrieval_calls.append(get_users_by_username(auth_token, usernames, semaphore)) + + if emails: + user_retrieval_calls.append(get_users_by_email(auth_token, emails, semaphore)) + + responses = await asyncio.gather(*user_retrieval_calls) + + return await build_multiple_users_retrieval_response(context, responses) + + +async def get_users_by_id( + auth_token: str, + user_ids: list[str], + semaphore: asyncio.Semaphore | None = None, +) -> dict[str, list]: + user_ids = list(set(user_ids)) + + if len(user_ids) == 1: + user = await get_single_user_by_id(auth_token, user_ids[0]) + if not user: + return {"users": [], "not_found": user_ids} + else: + return {"users": [user], "not_found": []} + + if not semaphore: + semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS) + + async with semaphore: + slack_client = AsyncWebClient(token=auth_token) + response, _ = await async_paginate( + func=slack_client.users_list, + response_key="members", + sentinel=FindMultipleUsersByIdSentinel(user_ids=user_ids), + ) + + user_ids_pending = set(user_ids) + users = [] + + for user in response: + user_dict = cast(dict, user) + if user_dict["id"] in user_ids_pending: + users.append(cast_user_dict(user_dict)) + user_ids_pending.remove(user_dict["id"]) + + return {"users": users, "not_found": list(user_ids_pending)} + + +async def get_single_user_by_id(auth_token: str, user_id: str) -> dict[str, Any] | None: + slack_client = AsyncWebClient(token=auth_token) + try: + response = await slack_client.users_info(user=user_id) + if not response.get("ok"): + return None + return cast_user_dict(response["user"]) + except SlackApiError as e: + if "not_found" in e.response.get("error", ""): + return None + else: + message = f"There was an error getting the user with ID {user_id}." + slack_error_message = e.response.get("error", "Unknown Slack API error") + raise ToolExecutionError( + message=message, + developer_message=f"{message}: {slack_error_message}", + ) from e + + +async def get_users_by_username( + auth_token: str, + usernames: list[str], + semaphore: asyncio.Semaphore | None = None, +) -> dict[str, list[dict]]: + usernames = list(set(usernames)) + + if not semaphore: + semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS) + + slack_client = AsyncWebClient(token=auth_token) + + async with semaphore: + users, _ = await async_paginate( + func=slack_client.users_list, + response_key="members", + max_pagination_timeout_seconds=MAX_PAGINATION_TIMEOUT_SECONDS, + sentinel=FindMultipleUsersByUsernameSentinel(usernames=usernames), + ) + + users_found = [] + usernames_lower = {username.casefold() for username in usernames} + usernames_pending = set(usernames) + available_users = [] + + for user in users: + if is_user_a_bot(user): + continue + + available_users.append(short_user_info(user)) + + if not isinstance(user.get("name"), str): + continue + + username_lower = user["name"].casefold() + + if username_lower in usernames_lower: + users_found.append(cast_user_dict(user)) + # Username/handle is unique in Slack, we can ignore it after finding a match + for pending_username in usernames_pending: + if pending_username.casefold() == username_lower: + usernames_pending.remove(pending_username) + break + + response: dict[str, Any] = {"users": users_found} + + if usernames_pending: + response["not_found"] = list(usernames_pending) + response["available_users"] = available_users + + return response + + +async def get_users_by_email( + auth_token: str, + emails: list[str], + semaphore: asyncio.Semaphore | None = None, +) -> dict[str, list[dict]]: + emails = list(set(emails)) + + for email in emails: + if not is_valid_email(email): + raise ToolExecutionError(f"Invalid email address: {email}") + + if not semaphore: + semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS) + + slack_client = AsyncWebClient(token=auth_token) + callers = [GetUserByEmailCaller(slack_client.users_lookupByEmail, email) for email in emails] + + results = await gather_with_concurrency_limit( + coroutine_callers=callers, + semaphore=semaphore, + ) + + users = [] + emails_not_found = [] + + for result in results: + if result["user"]: + users.append(cast_user_dict(result["user"])) + else: + emails_not_found.append(result["email"]) + + response: dict[str, Any] = {"users": users} + + if emails_not_found: + response["not_found"] = emails_not_found + + return response diff --git a/toolkits/slack/arcade_slack/utils.py b/toolkits/slack/arcade_slack/utils.py index 22a0f431..e4dce2e0 100644 --- a/toolkits/slack/arcade_slack/utils.py +++ b/toolkits/slack/arcade_slack/utils.py @@ -1,23 +1,28 @@ import asyncio -from collections.abc import Callable +import json +import re +from collections.abc import Callable, Sequence from datetime import datetime, timezone -from typing import Any +from typing import Any, cast from arcade_tdk import ToolContext from arcade_tdk.errors import RetryableToolError -from arcade_slack.constants import MAX_PAGINATION_SIZE_LIMIT, MAX_PAGINATION_TIMEOUT_SECONDS -from arcade_slack.custom_types import SlackPaginationNextCursor -from arcade_slack.exceptions import ( - PaginationTimeoutError, - UsernameNotFoundError, +from arcade_slack.constants import ( + MAX_CONCURRENT_REQUESTS, + MAX_PAGINATION_SIZE_LIMIT, + MAX_PAGINATION_TIMEOUT_SECONDS, ) +from arcade_slack.custom_types import SlackPaginationNextCursor +from arcade_slack.exceptions import PaginationTimeoutError from arcade_slack.models import ( + AbstractConcurrencySafeCoroutineCaller, BasicUserInfo, ConversationMetadata, ConversationType, ConversationTypeSlackName, Message, + PaginationSentinel, SlackConversation, SlackConversationPurpose, SlackMessage, @@ -75,7 +80,7 @@ def remove_none_values(params: dict) -> dict: return {k: v for k, v in params.items() if v is not None} -def get_slack_conversation_type_as_str(channel: SlackConversation) -> str: +def get_slack_conversation_type_as_str(channel: SlackConversation) -> str | None: """Get the type of conversation from a Slack channel's dictionary. Args: @@ -92,19 +97,7 @@ def get_slack_conversation_type_as_str(channel: SlackConversation) -> str: return ConversationTypeSlackName.IM.value if channel.get("is_mpim"): return ConversationTypeSlackName.MPIM.value - raise ValueError(f"Invalid conversation type in channel {channel.get('name')}") - - -def get_user_by_username(username: str, users_list: list[dict]) -> SlackUser: - usernames_found = [] - for user in users_list: - if isinstance(user.get("name"), str): - usernames_found.append(user["name"]) - username_found = user.get("name") or "" - if username.lower() == username_found.lower(): - return SlackUser(**user) - - raise UsernameNotFoundError(usernames_found=usernames_found, username_not_found=username) + raise ValueError(f"Invalid conversation type in channel: {json.dumps(channel)}") def convert_conversation_type_to_slack_name( @@ -191,62 +184,6 @@ def extract_basic_user_info(user_info: SlackUser) -> BasicUserInfo: ) -async def associate_members_of_multiple_conversations( - get_members_in_conversation_func: Callable, - conversations: list[dict], - context: ToolContext, -) -> list[dict]: - """Associate members to each conversation, returning the updated list.""" - return await asyncio.gather(*[ # type: ignore[no-any-return] - associate_members_of_conversation(get_members_in_conversation_func, context, conv) - for conv in conversations - ]) - - -async def associate_members_of_conversation( - get_members_in_conversation_func: Callable, - context: ToolContext, - conversation: dict, -) -> dict: - response = await get_members_in_conversation_func(context, conversation["id"]) - conversation["members"] = response["members"] - return conversation - - -async def retrieve_conversations_by_user_ids( - list_conversations_func: Callable, - get_members_in_conversation_func: Callable, - context: ToolContext, - conversation_types: list[ConversationType], - user_ids: list[str], - exact_match: bool = False, - limit: int | None = None, - next_cursor: str | None = None, -) -> list[dict]: - """ - Retrieve conversations filtered by the given user IDs. Includes pagination support - and optionally limits the number of returned conversations. - """ - conversations_found: list[dict] = [] - - response = await list_conversations_func( - context=context, - conversation_types=conversation_types, - next_cursor=next_cursor, - ) - - # Associate members to each conversation - conversations_with_members = await associate_members_of_multiple_conversations( - get_members_in_conversation_func, response["conversations"], context - ) - - conversations_found.extend( - filter_conversations_by_user_ids(conversations_with_members, user_ids, exact_match) - ) - - return conversations_found[:limit] - - def filter_conversations_by_user_ids( conversations: list[dict], user_ids: list[str], @@ -317,6 +254,7 @@ async def async_paginate( limit: int | None = None, next_cursor: SlackPaginationNextCursor | None = None, max_pagination_timeout_seconds: int = MAX_PAGINATION_TIMEOUT_SECONDS, + sentinel: PaginationSentinel | None = None, *args: Any, **kwargs: Any, ) -> tuple[list, SlackPaginationNextCursor | None]: @@ -332,6 +270,10 @@ async def async_paginate( not provided, the entire response dictionary is used. limit: The maximum number of items to retrieve (defaults to Slack's suggested limit). next_cursor: The cursor to use for pagination (optional). + max_pagination_timeout_seconds: The maximum timeout for the pagination loop (defaults to + MAX_PAGINATION_TIMEOUT_SECONDS). + sentinel: Control whether the pagination should continue after each iteration (optional). + If provided, the pagination will stop when the sentinel function returns True. *args: Positional arguments to pass to the Slack method. **kwargs: Keyword arguments to pass to the Slack method. @@ -358,13 +300,18 @@ async def async_paginate( response = await func(*args, **iteration_kwargs) try: - results.extend(dict(response.data) if not response_key else response[response_key]) + result = dict(response.data) if not response_key else response[response_key] + results.extend(result) except KeyError: raise ValueError(f"Response key {response_key} not found in Slack response") next_cursor = response.get("response_metadata", {}).get("next_cursor") - if (limit and len(results) >= limit) or not next_cursor: + if ( + (sentinel and sentinel(last_result=result)) + or (limit and len(results) >= limit) + or not next_cursor + ): should_continue = False return results @@ -445,3 +392,215 @@ def convert_relative_datetime_to_unix_timestamp( days, hours, minutes = map(int, relative_datetime.split(":")) seconds = days * 86400 + hours * 3600 + minutes * 60 return int(current_unix_timestamp - seconds) + + +def short_user_info(user: dict) -> dict[str, str | None]: + data = {"id": user.get("id")} + if user.get("name"): + data["name"] = user["name"] + if isinstance(user.get("profile"), dict) and user["profile"].get("email"): + data["email"] = user["profile"]["email"] + elif user.get("email"): + data["email"] = user["email"] + return data + + +def short_human_users_info(users: list[dict]) -> list[dict[str, str | None]]: + return [short_user_info(user) for user in users if not user.get("is_bot")] + + +def is_valid_email(email: str) -> bool: + """Validate an email address using regex. + + Args: + email: The email address to validate. + + Returns: + True if the email is valid, False otherwise. + """ + email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$" + return bool(re.match(email_pattern, email)) + + +async def build_multiple_users_retrieval_response( + context: ToolContext, + users_responses: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Builds response list for the get_multiple_users_by_usernames_or_emails function.""" + await raise_for_users_not_found(context, users_responses) + + users = [] + + for users_response in users_responses: + users.extend(users_response["users"]) + + return cast(list[dict[str, Any]], users) + + +async def raise_for_users_not_found( + context: ToolContext, users_responses: list[dict[str, Any]] +) -> None: + """Raise an error if any user was not found in the responses.""" + users_not_found, available_users = collect_users_not_found_in_responses(users_responses) + + if users_not_found: + not_found_message = ", ".join(users_not_found) + s = "" if len(users_not_found) == 1 else "s" + message = f"User{s} not found: {not_found_message}" + available_users_prompt = await get_available_users_prompt(context, available_users) + + raise RetryableToolError( + message=message, + developer_message=message, + additional_prompt_content=available_users_prompt, + retry_after_ms=500, + ) + + +def collect_users_not_found_in_responses( + responses: list[dict[str, Any]], +) -> tuple[list[str], list[dict[str, Any]]]: + users_not_found = [] + available_users = [] + + for response in responses: + if response.get("not_found"): + users_not_found.extend(response["not_found"]) + if response.get("available_users"): + available_users = response["available_users"] + + return users_not_found, available_users + + +async def get_available_users_prompt( + context: ToolContext, + available_users: list[dict] | None = None, + limit: int = 100, +) -> str: + try: + from arcade_slack.tools.users import list_users # Avoid circular import + + if isinstance(available_users, list) and available_users: + available_users = [ + user for user in available_users if not is_user_a_bot(SlackUser(**user)) + ] + available_users_str = json.dumps(short_human_users_info(available_users)) + next_cursor = None + potentially_more_users = True + else: + users = await list_users(context, limit=limit, exclude_bots=True) + next_cursor = users["next_cursor"] + available_users_str = json.dumps(short_human_users_info(users["users"])) + potentially_more_users = bool(next_cursor) + + if not potentially_more_users: + return f"The users available are: {available_users_str}" + else: + msg = ( + f"Some of the available users are: {available_users_str}. Potentially more users " + f"can be retrieved by calling the 'Slack.{list_users.__tool_name__}' tool" + ) + if next_cursor: + msg += f" using the next cursor: '{next_cursor}' to continue pagination." + return msg + except Exception as e: + return ( + "The tool tried to retrieve a list of available users, but failed with error: " + f"{type(e).__name__}: {e!s}. Use the 'Slack.{list_users.__tool_name__}' tool " + "to get a list of users." + ) + + +async def gather_with_concurrency_limit( + coroutine_callers: Sequence[AbstractConcurrencySafeCoroutineCaller], + semaphore: asyncio.Semaphore | None = None, + max_concurrent_requests: int = MAX_CONCURRENT_REQUESTS, +) -> list[Any]: + if not semaphore: + semaphore = asyncio.Semaphore(max_concurrent_requests) + + return await asyncio.gather(*[caller(semaphore) for caller in coroutine_callers]) # type: ignore[no-any-return] + + +def cast_user_dict(user: dict[str, Any]) -> dict[str, Any]: + slack_user = SlackUser(**cast(dict, user)) + return dict(**extract_basic_user_info(slack_user)) + + +async def populate_users_in_messages(auth_token: str, messages: list[dict]) -> list[dict]: + if not messages: + return messages + + users = await get_users_from_messages(auth_token, messages) + users_by_id = {user["id"]: {"id": user["id"], "name": user["name"]} for user in users} + + for message in messages: + if message.get("type") != "message": + continue + + # Message author + message["user"] = users_by_id.get( + message.get("user"), {"id": message["user"], "name": None} + ) + + # User mentions in the message text + text_mentions = re.findall(r"<@([A-Z0-9]+)>", message.get("text", "")) + for user_id in text_mentions: + if user_id in users_by_id: + user = users_by_id.get(user_id, {"id": user_id, "name": None}) + name = user.get("name") + message["text"] = message["text"].replace( + f"<@{user_id}>", f"<@{name} (id:{user_id})>" if name else f"<@{user_id}>" + ) + + # User mentions in reactions + reactions = message.get("reactions") + if isinstance(reactions, list): + for reaction in reactions: + reaction_users = [] + for user_id in reaction.get("users", []): + reaction_users.append(users_by_id.get(user_id, {"id": user_id, "name": None})) + reaction["users"] = reaction_users + + return messages + + +async def get_users_from_messages(auth_token: str, messages: list[dict]) -> list[dict[str, Any]]: + if not messages: + return [] + + from arcade_slack.user_retrieval import get_users_by_id # Avoid circular import + + user_ids = get_user_ids_from_messages(messages) + response = await get_users_by_id(auth_token, user_ids) + print("\n\n\nresponse:", response, "\n\n\n") + return response["users"] + + +def get_user_ids_from_messages(messages: list[dict]) -> list[str]: + if not messages: + return [] + + user_ids = [] + + for message in messages: + if message.get("type") != "message": + continue + + # Message author + user = message.get("user") + if isinstance(user, str) and user: + user_ids.append(user) + + # User mentions in the message text + text = message.get("text") + if isinstance(text, str) and text: + user_ids.extend(re.findall(r"<@([A-Z0-9]+)>", text)) + + # User mentions in reactions + reactions = message.get("reactions") + if isinstance(reactions, list): + for reaction in reactions: + user_ids.extend(reaction.get("users", [])) + + return user_ids diff --git a/toolkits/slack/conftest.py b/toolkits/slack/conftest.py index f1534f3d..d8f95066 100644 --- a/toolkits/slack/conftest.py +++ b/toolkits/slack/conftest.py @@ -1,3 +1,7 @@ +import random +import string +from collections.abc import Callable + import pytest from arcade_tdk import ToolAuthorizationContext, ToolContext @@ -18,3 +22,142 @@ def mock_chat_slack_client(mocker): def mock_users_slack_client(mocker): mock_client = mocker.patch("arcade_slack.tools.users.AsyncWebClient", autospec=True) return mock_client.return_value + + +@pytest.fixture +def mock_user_retrieval_slack_client(mocker): + mock_client = mocker.patch("arcade_slack.user_retrieval.AsyncWebClient", autospec=True) + return mock_client.return_value + + +@pytest.fixture +def mock_conversation_retrieval_slack_client(mocker): + mock_client = mocker.patch("arcade_slack.conversation_retrieval.AsyncWebClient", autospec=True) + return mock_client.return_value + + +@pytest.fixture +def mock_message_retrieval_slack_client(mocker): + mock_client = mocker.patch("arcade_slack.message_retrieval.AsyncWebClient", autospec=True) + return mock_client.return_value + + +@pytest.fixture +def random_str_factory(): + def random_str_factory(length: int = 10): + return "".join(random.choices(string.ascii_letters + string.digits, k=length)) # noqa: S311 + + return random_str_factory + + +@pytest.fixture +def random_ts_factory(): + def random_ts_factory(): + return f"{random.uniform(1735689600.000000, 1751327999.999999)}" # noqa: S311 + + return random_ts_factory + + +@pytest.fixture +def dummy_channel_factory(random_str_factory: Callable[[int], str]): + def dummy_channel_factory( + id_: str | None = None, + name: str | None = None, + is_member: bool = True, + is_private: bool = False, + is_archived: bool = False, + is_channel: bool = False, + is_im: bool = False, + is_mpim: bool = False, + num_members: int | None = None, + user: str | None = None, + is_user_deleted: bool = False, + ): + channel = { + "id": id_ or f"channel_id_{random_str_factory()}", + "is_member": is_member, + "is_private": is_private, + "is_archived": is_archived, + } + + if name or is_channel or is_mpim: + channel["name"] = name or f"channel_name_{random_str_factory()}" + + if is_channel: + channel["is_channel"] = True + if is_im: + channel["is_im"] = True + if is_mpim: + channel["is_group"] = True + if num_members: + channel["num_members"] = num_members + if user or is_im: + channel["user"] = user or f"user_id_{random_str_factory()}" + if is_user_deleted: + channel["is_user_deleted"] = is_user_deleted + + return channel + + return dummy_channel_factory + + +@pytest.fixture +def dummy_user_factory(random_str_factory: Callable[[int], str]): + def dummy_user_factory( + id_: str | None = None, + name: str | None = None, + email: str | None = None, + is_bot: bool = False, + ): + return { + "id": id_ or random_str_factory(), + "name": name or random_str_factory(), + "profile": { + "email": email or f"{random_str_factory()}@{random_str_factory()}.com", + }, + "is_bot": is_bot, + } + + return dummy_user_factory + + +@pytest.fixture +def dummy_reaction_factory(random_str_factory): + def reaction_factory( + name: str | None = None, + user_ids: list[str] | None = None, + count: int | None = None, + ): + count = count or random.randint(1, 10) # noqa: S311 + if user_ids: + count = len(user_ids) + return { + "count": count, + "name": name or random_str_factory(), + "users": user_ids or [random_str_factory() for _ in range(count)], + } + + return reaction_factory + + +@pytest.fixture +def dummy_message_factory(random_str_factory, random_ts_factory): + def message_factory( + user_id: str | None = None, + text: str | None = None, + reactions: list[dict] | None = None, + type_: str = "message", + ts: float | None = None, + ): + message = { + "user": user_id or random_str_factory(), + "text": text or random_str_factory(), + "type": type_, + "ts": ts or random_ts_factory(), + } + + if reactions: + message["reactions"] = reactions + return message + + return message_factory diff --git a/toolkits/slack/evals/chat/eval_get_metadata.py b/toolkits/slack/evals/chat/eval_get_metadata.py new file mode 100644 index 00000000..33661274 --- /dev/null +++ b/toolkits/slack/evals/chat/eval_get_metadata.py @@ -0,0 +1,206 @@ +from arcade_evals import ( + BinaryCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + tool_eval, +) +from arcade_tdk import ToolCatalog + +import arcade_slack +from arcade_slack.tools.chat import get_conversation_metadata + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.8, + warn_threshold=0.9, +) + + +catalog = ToolCatalog() +# Register the Slack tools +catalog.add_module(arcade_slack) + + +@tool_eval() +def get_conversations_metadata_eval_suite() -> EvalSuite: + """Create an evaluation suite for tools getting conversations metadata.""" + suite = EvalSuite( + name="Slack Tools Evaluation", + system_message="You are an AI assistant that can interact with Slack to get information from conversations, users, etc.", + catalog=catalog, + rubric=rubric, + ) + + suite.add_case( + name="Get channel metadata by name", + user_message="Get the metadata of the #general channel", + expected_tool_calls=[ + ExpectedToolCall( + func=get_conversation_metadata, + args={ + "conversation_id": None, + "channel_name": "general", + "user_ids": None, + "usernames": None, + "emails": None, + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=1 / 5), + BinaryCritic(critic_field="channel_name", weight=1 / 5), + BinaryCritic(critic_field="user_ids", weight=1 / 5), + BinaryCritic(critic_field="usernames", weight=1 / 5), + BinaryCritic(critic_field="emails", weight=1 / 5), + ], + ) + + suite.add_case( + name="Get conversation metadata by id", + user_message="Get the metadata of the conversation with id '1234567890'", + expected_tool_calls=[ + ExpectedToolCall( + func=get_conversation_metadata, + args={ + "conversation_id": "1234567890", + "channel_name": None, + "user_ids": None, + "usernames": None, + "emails": None, + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=1 / 5), + BinaryCritic(critic_field="channel_name", weight=1 / 5), + BinaryCritic(critic_field="user_ids", weight=1 / 5), + BinaryCritic(critic_field="usernames", weight=1 / 5), + BinaryCritic(critic_field="emails", weight=1 / 5), + ], + ) + + suite.add_case( + name="Get conversation metadata by username mentioning DM", + user_message="get the metadata of the DM with janedoe", + expected_tool_calls=[ + ExpectedToolCall( + func=get_conversation_metadata, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "usernames": ["janedoe"], + "emails": None, + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=1 / 5), + BinaryCritic(critic_field="channel_name", weight=1 / 5), + BinaryCritic(critic_field="user_ids", weight=1 / 5), + BinaryCritic(critic_field="usernames", weight=1 / 5), + BinaryCritic(critic_field="emails", weight=1 / 5), + ], + ) + + suite.add_case( + name="Get conversation metadata by username mentioning IM", + user_message="get metadata about my IM with janedoe", + expected_tool_calls=[ + ExpectedToolCall( + func=get_conversation_metadata, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "usernames": ["janedoe"], + "emails": None, + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=1 / 5), + BinaryCritic(critic_field="channel_name", weight=1 / 5), + BinaryCritic(critic_field="user_ids", weight=1 / 5), + BinaryCritic(critic_field="usernames", weight=1 / 5), + BinaryCritic(critic_field="emails", weight=1 / 5), + ], + ) + + suite.add_case( + name="Get conversation metadata by email mentioning DM", + user_message="get the metadata of the DM with jane.doe@acme.com", + expected_tool_calls=[ + ExpectedToolCall( + func=get_conversation_metadata, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "usernames": None, + "emails": ["jane.doe@acme.com"], + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=1 / 5), + BinaryCritic(critic_field="channel_name", weight=1 / 5), + BinaryCritic(critic_field="user_ids", weight=1 / 5), + BinaryCritic(critic_field="usernames", weight=1 / 5), + BinaryCritic(critic_field="emails", weight=1 / 5), + ], + ) + + suite.add_case( + name="Get conversation metadata by email mentioning IM", + user_message="get the metadata of the IM with jane.doe@acme.com", + expected_tool_calls=[ + ExpectedToolCall( + func=get_conversation_metadata, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "usernames": None, + "emails": ["jane.doe@acme.com"], + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=1 / 5), + BinaryCritic(critic_field="channel_name", weight=1 / 5), + BinaryCritic(critic_field="user_ids", weight=1 / 5), + BinaryCritic(critic_field="usernames", weight=1 / 5), + BinaryCritic(critic_field="emails", weight=1 / 5), + ], + ) + + suite.add_case( + name="Get conversation metadata by mixed user ID, email, and username", + user_message=( + "get the metadata of the multi-person conversation I have with these users together: " + "janedoe, john@acme.com, and U0123456789" + ), + expected_tool_calls=[ + ExpectedToolCall( + func=get_conversation_metadata, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": ["U0123456789"], + "usernames": ["janedoe"], + "emails": ["john@acme.com"], + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=1 / 5), + BinaryCritic(critic_field="channel_name", weight=1 / 5), + BinaryCritic(critic_field="user_ids", weight=1 / 5), + BinaryCritic(critic_field="usernames", weight=1 / 5), + BinaryCritic(critic_field="emails", weight=1 / 5), + ], + ) + + return suite diff --git a/toolkits/slack/evals/chat/eval_get_users_in_conversation.py b/toolkits/slack/evals/chat/eval_get_users_in_conversation.py new file mode 100644 index 00000000..434454d2 --- /dev/null +++ b/toolkits/slack/evals/chat/eval_get_users_in_conversation.py @@ -0,0 +1,81 @@ +from arcade_evals import ( + BinaryCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + tool_eval, +) +from arcade_tdk import ToolCatalog + +import arcade_slack +from arcade_slack.tools.chat import get_users_in_conversation + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.8, + warn_threshold=0.9, +) + + +catalog = ToolCatalog() +# Register the Slack tools +catalog.add_module(arcade_slack) + + +@tool_eval() +def get_users_in_conversation_eval_suite() -> EvalSuite: + """Create an evaluation suite for tools getting conversations members.""" + suite = EvalSuite( + name="Slack Tools Evaluation", + system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.", + catalog=catalog, + rubric=rubric, + ) + + user_messages = [ + "Get the members of the #general channel", + "Get the users in the #general channel", + "Get a list of people in the #general channel", + "Get a list of people in the general channel", + "Show me who's in the #general channel", + "Who is in the general channel?", + ] + + for user_message in user_messages: + suite.add_case( + name=f"Get users in channel by channel name: {user_message}", + user_message=user_message, + expected_tool_calls=[ + ExpectedToolCall( + func=get_users_in_conversation, + args={ + "conversation_id": None, + "channel_name": "general", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.4), + BinaryCritic(critic_field="channel_name", weight=0.6), + ], + ) + + suite.add_case( + name="Get users in conversation by conversation id", + user_message="Get the users in the conversation with id '1234567890'", + expected_tool_calls=[ + ExpectedToolCall( + func=get_users_in_conversation, + args={ + "conversation_id": "1234567890", + "channel_name": None, + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.6), + BinaryCritic(critic_field="channel_name", weight=0.4), + ], + ) + + return suite diff --git a/toolkits/slack/evals/chat/eval_list_conversations.py b/toolkits/slack/evals/chat/eval_list_conversations.py new file mode 100644 index 00000000..fcb80c66 --- /dev/null +++ b/toolkits/slack/evals/chat/eval_list_conversations.py @@ -0,0 +1,175 @@ +from arcade_evals import ( + BinaryCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + tool_eval, +) +from arcade_tdk import ToolCatalog + +import arcade_slack +from arcade_slack.models import ConversationType +from arcade_slack.tools.chat import list_conversations + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.8, + warn_threshold=0.9, +) + + +catalog = ToolCatalog() +# Register the Slack tools +catalog.add_module(arcade_slack) + + +@tool_eval() +def list_conversations_eval_suite() -> EvalSuite: + """Create an evaluation suite for tools listing conversations.""" + suite = EvalSuite( + name="Slack Messaging Tools Evaluation", + system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.", + catalog=catalog, + rubric=rubric, + ) + + suite.add_case( + name="List all conversations I am a member of", + user_message="List all conversations I am a member of", + expected_tool_calls=[ + ExpectedToolCall( + func=list_conversations, + args={ + "conversation_types": None, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="conversation_types", weight=1.0), + ], + ) + + suite.add_case( + name="List 10 conversations I am a member of", + user_message="List 10 conversations I am a member of", + expected_tool_calls=[ + ExpectedToolCall( + func=list_conversations, + args={ + "conversation_types": None, + "limit": 10, + }, + ) + ], + critics=[ + BinaryCritic(critic_field="conversation_types", weight=0.5), + BinaryCritic(critic_field="limit", weight=0.5), + ], + ) + + suite.add_case( + name="List all public channels", + user_message="List all public channels", + expected_tool_calls=[ + ExpectedToolCall( + func=list_conversations, + args={ + "conversation_types": [ConversationType.PUBLIC_CHANNEL.value], + }, + ) + ], + critics=[ + BinaryCritic(critic_field="conversation_types", weight=1.0), + ], + ) + + suite.add_case( + name="List all private channels", + user_message="List all private channels", + expected_tool_calls=[ + ExpectedToolCall( + func=list_conversations, + args={ + "conversation_types": [ConversationType.PRIVATE_CHANNEL.value], + }, + ) + ], + critics=[ + BinaryCritic(critic_field="conversation_types", weight=1.0), + ], + ) + + suite.add_case( + name="List all public and private channels", + user_message="List all public and private channels", + expected_tool_calls=[ + ExpectedToolCall( + func=list_conversations, + args={ + "conversation_types": [ + ConversationType.PUBLIC_CHANNEL.value, + ConversationType.PRIVATE_CHANNEL.value, + ], + }, + ) + ], + critics=[ + BinaryCritic(critic_field="conversation_types", weight=1.0), + ], + ) + + suite.add_case( + name="List direct message channels", + user_message="List direct message channels", + expected_tool_calls=[ + ExpectedToolCall( + func=list_conversations, + args={ + "conversation_types": [ + ConversationType.DIRECT_MESSAGE.value, + ], + }, + ) + ], + critics=[ + BinaryCritic(critic_field="conversation_types", weight=1.0), + ], + ) + + suite.add_case( + name="List group direct message channels", + user_message="List group direct message channels", + expected_tool_calls=[ + ExpectedToolCall( + func=list_conversations, + args={ + "conversation_types": [ + ConversationType.MULTI_PERSON_DIRECT_MESSAGE.value, + ], + }, + ) + ], + critics=[ + BinaryCritic(critic_field="conversation_types", weight=1.0), + ], + ) + + suite.add_case( + name="List my multi-person conversations", + user_message="List my multi-person conversations", + expected_tool_calls=[ + ExpectedToolCall( + func=list_conversations, + args={ + "conversation_types": [ + ConversationType.MULTI_PERSON_DIRECT_MESSAGE.value, + ], + }, + ) + ], + critics=[ + BinaryCritic(critic_field="conversation_types", weight=1.0), + ], + ) + + return suite diff --git a/toolkits/slack/evals/chat/messages/eval_get_channel_messages.py b/toolkits/slack/evals/chat/messages/eval_get_channel_messages.py new file mode 100644 index 00000000..ffa6b036 --- /dev/null +++ b/toolkits/slack/evals/chat/messages/eval_get_channel_messages.py @@ -0,0 +1,622 @@ +import json +from datetime import timedelta + +from arcade_evals import ( + BinaryCritic, + DatetimeCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + tool_eval, +) +from arcade_tdk import ToolCatalog + +import arcade_slack +from arcade_slack.critics import RelativeTimeBinaryCritic +from arcade_slack.tools.chat import get_messages + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.8, + warn_threshold=0.9, +) + + +catalog = ToolCatalog() +# Register the Slack tools +catalog.add_module(arcade_slack) + + +@tool_eval() +def get_messages_in_channel_eval_suite() -> EvalSuite: + """Create an evaluation suite for tools getting messages in channels.""" + suite = EvalSuite( + name="Slack Chat Tools Evaluation", + system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.", + catalog=catalog, + rubric=rubric, + ) + + no_arguments_user_messages_by_channel_name = [ + "what are the latest messages in the #general channel", + "show me the messages in the general channel", + "list the messages in the #general channel", + "list the messages in the general channel", + ] + + for i, user_message in enumerate(no_arguments_user_messages_by_channel_name): + suite.add_case( + name=f"Get messages in conversation by name {i}: '{user_message}'", + user_message=user_message, + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": "general", + "user_ids": None, + "usernames": None, + "emails": None, + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.6), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + ], + ) + + no_arguments_user_messages_by_conversation_id = [ + "Get the history of the conversation with id '1234567890'", + "Get the history of the conversation with id '1234567890'", + "list the messages in the conversation with id '1234567890'", + "list the messages in the conversation with id '1234567890'", + ] + + for user_message in no_arguments_user_messages_by_conversation_id: + suite.add_case( + name=f"Get conversation history by id: '{user_message}'", + user_message=user_message, + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": "1234567890", + "channel_name": None, + "user_ids": None, + "usernames": None, + "emails": None, + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.6), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + ], + ) + + suite.add_case( + name="Get conversation history with limit by name", + user_message="Get the last 10 messages in the #general channel", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": "general", + "user_ids": None, + "usernames": None, + "emails": None, + "limit": 10, + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.3), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + BinaryCritic(critic_field="limit", weight=0.3), + ], + ) + + suite.add_case( + name="Get conversation history with limit by id", + user_message="Get the last 25 messages in the conversation with id '1234567890'", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": "1234567890", + "channel_name": None, + "user_ids": None, + "usernames": None, + "emails": None, + "limit": 25, + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.3), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + BinaryCritic(critic_field="limit", weight=0.3), + ], + ) + + # Relative time eval cases by id + + suite.add_case( + name="Get conversation history oldest relative by id (2 days ago)", + user_message="Get the messages in the conversation with id '1234567890' starting 2 days ago", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": "1234567890", + "channel_name": None, + "user_ids": None, + "usernames": None, + "emails": None, + "oldest_relative": "02:00:00", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.3), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.3), + ], + ) + + suite.add_case( + name="Get conversation history oldest and latest relative by id", + user_message="Get the messages in the conversation with id '1234567890' from 2 days ago to 3 hours ago", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": "1234567890", + "channel_name": None, + "user_ids": None, + "usernames": None, + "emails": None, + "oldest_relative": "02:00:00", + "latest_relative": "00:03:00", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.2), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.2), + RelativeTimeBinaryCritic(critic_field="latest_relative", weight=0.2), + ], + ) + + suite.add_case( + name="Get conversation history oldest relative by id (1 week ago)", + user_message="Get the messages in the conversation with id '1234567890' starting 1 week ago", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": "1234567890", + "channel_name": None, + "user_ids": None, + "usernames": None, + "emails": None, + "oldest_relative": "07:00:00", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.3), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.3), + ], + ) + + suite.add_case( + name="Get conversation history oldest relative by id (yesterday)", + user_message="Get the messages in the conversation with id '1234567890' from yesterday", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": "1234567890", + "channel_name": None, + "user_ids": None, + "usernames": None, + "emails": None, + "oldest_relative": "01:00:00", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.3), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.3), + ], + ) + + # Relative time eval cases by name + + suite.add_case( + name="Get conversation history oldest relative by name (2 days ago)", + user_message="Get the messages in the #general channel starting 2 days ago", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": "general", + "user_ids": None, + "usernames": None, + "emails": None, + "oldest_relative": "02:00:00", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.3), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.3), + ], + ) + + suite.add_case( + name="Get conversation history oldest and latest relative by name", + user_message="Get the messages in the #general channel from 2 days ago to 3 hours ago", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": "general", + "user_ids": None, + "usernames": None, + "emails": None, + "oldest_relative": "02:00:00", + "latest_relative": "00:03:00", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.2), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.2), + RelativeTimeBinaryCritic(critic_field="latest_relative", weight=0.2), + ], + ) + + suite.add_case( + name="Get conversation history oldest relative by name (yesterday)", + user_message="Get the messages in the #general channel from yesterday", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": "general", + "user_ids": None, + "usernames": None, + "emails": None, + "oldest_relative": "01:00:00", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.3), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.3), + ], + ) + + suite.add_case( + name="Get conversation history oldest relative by name (last week)", + user_message="Get the messages in the #general channel from last week", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": "general", + "user_ids": None, + "usernames": None, + "emails": None, + "oldest_relative": "07:00:00", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.3), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.3), + ], + ) + + # Absolute time eval cases by id + + suite.add_case( + name="Get conversation history oldest absolute by id (on a specific date)", + user_message="Get the messages in the conversation with id '1234567890' from 2025-01-20", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": "1234567890", + "channel_name": None, + "user_ids": None, + "usernames": None, + "emails": None, + "oldest_datetime": "2025-01-20 00:00:00", + "latest_datetime": "2025-01-20 23:59:59", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.2), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + DatetimeCritic( + critic_field="oldest_datetime", weight=0.2, max_difference=timedelta(minutes=2) + ), + DatetimeCritic( + critic_field="latest_datetime", weight=0.2, max_difference=timedelta(minutes=2) + ), + ], + ) + + suite.add_case( + name="Get conversation history oldest absolute by id (between a date range)", + user_message="Get the messages in the conversation with id '1234567890' from 2025-01-20 to 2025-01-25", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": "1234567890", + "channel_name": None, + "user_ids": None, + "usernames": None, + "emails": None, + "oldest_datetime": "2025-01-20 00:00:00", + "latest_datetime": "2025-01-25 23:59:59", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.2), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + DatetimeCritic( + critic_field="oldest_datetime", weight=0.2, max_difference=timedelta(minutes=2) + ), + DatetimeCritic( + critic_field="latest_datetime", weight=0.2, max_difference=timedelta(minutes=2) + ), + ], + ) + + suite.add_case( + name="Get conversation history oldest absolute by name (on a specific date)", + user_message="Get the messages in the #general channel from 2025-01-20", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": "general", + "user_ids": None, + "usernames": None, + "emails": None, + "oldest_datetime": "2025-01-20 00:00:00", + "latest_datetime": "2025-01-20 23:59:59", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.2), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + # We use a timedelta of 10 seconds because sometimes the LLM will select the limit + # date at 23:59:59, other times it'll select the next day at 00:00:00. + DatetimeCritic( + critic_field="oldest_datetime", weight=0.2, max_difference=timedelta(seconds=10) + ), + DatetimeCritic( + critic_field="latest_datetime", weight=0.2, max_difference=timedelta(seconds=10) + ), + ], + ) + + suite.add_case( + name="Get conversation history oldest absolute by name (between a date range)", + user_message="Get the messages in the #general channel from 2025-01-20 to 2025-01-25", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": "general", + "user_ids": None, + "usernames": None, + "emails": None, + "oldest_datetime": "2025-01-20 00:00:00", + "latest_datetime": "2025-01-25 23:59:59", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.2), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + # We use a timedelta of 10 seconds because sometimes the LLM will select the limit + # date at 23:59:59, other times it'll select the next day at 00:00:00. + DatetimeCritic( + critic_field="oldest_datetime", weight=0.2, max_difference=timedelta(seconds=10) + ), + DatetimeCritic( + critic_field="latest_datetime", weight=0.2, max_difference=timedelta(seconds=10) + ), + ], + ) + + # Eval case for pagination + + suite.add_case( + name="Get conversation history with pagination", + user_message="get the next 5 messages", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": "general", + "user_ids": None, + "usernames": None, + "emails": None, + "limit": 5, + "next_cursor": "dXNlcjpVsDjzOTZGVDlQRA==", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.2), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + BinaryCritic(critic_field="next_cursor", weight=0.2), + BinaryCritic(critic_field="limit", weight=0.2), + ], + additional_messages=[ + {"role": "user", "content": "Get the last 2 messages on the general channel"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "Slack_GetConversationHistoryByName", + "arguments": json.dumps({ + "conversation_name": "general", + "limit": 2, + }), + }, + } + ], + }, + { + "role": "tool", + "content": json.dumps({ + "messages": [ + { + "blocks": [ + { + "block_id": "abc123", + "elements": [ + { + "elements": [ + { + "text": "Almost there, Boss, need to get some evals in!", + "type": "text", + } + ], + "type": "rich_text_section", + } + ], + "type": "rich_text", + } + ], + "client_msg_id": "msg_id_0", + "datetime_timestamp": "2025-01-21 16:59:55", + "team": "617263616465207465616D20697320617420626F7373206C6576656C", + "text": "Almost there, Boss, need to get some evals in!", + "ts": "1737507595.598529", + "type": "message", + "user": "77686F2069732074686520626F73733F", + }, + { + "blocks": [ + { + "block_id": "xyz456", + "elements": [ + { + "elements": [ + { + "text": "hey, are the Slack Tools ready yet?", + "type": "text", + } + ], + "type": "rich_text_section", + } + ], + "type": "rich_text", + } + ], + "client_msg_id": "msg_id_1", + "datetime_timestamp": "2025-01-21 16:57:35", + "team": "617263616465207465616D20697320617420626F7373206C6576656C", + "text": "hey, are the Slack Tools ready yet?", + "ts": "1737507595.598529", + "type": "message", + "user": "73616D2069732074686520626F7373", + }, + ], + "next_cursor": "dXNlcjpVsDjzOTZGVDlQRA==", + }), + "tool_call_id": "call_1", + "name": "Slack_GetConversationHistoryByName", + }, + { + "role": "assistant", + "content": 'Here are the last 2 messages from the general channel:\n\n1. **User:** 77686F2069732074686520626F73733F \n **Message:** "Almost there, Boss, need to get some evals in!" \n **Timestamp:** 2025-01-21 16:59:55\n\n2. **User:** 73616D2069732074686520626F7373 \n **Message:** "hey, are the Slack Tools ready yet?" \n **Timestamp:** 2025-01-21 16:57:35', + }, + ], + ) + + return suite diff --git a/toolkits/slack/evals/chat/messages/eval_get_dm_messages.py b/toolkits/slack/evals/chat/messages/eval_get_dm_messages.py new file mode 100644 index 00000000..88156592 --- /dev/null +++ b/toolkits/slack/evals/chat/messages/eval_get_dm_messages.py @@ -0,0 +1,191 @@ +from datetime import timedelta + +from arcade_evals import ( + BinaryCritic, + DatetimeCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + tool_eval, +) +from arcade_tdk import ToolCatalog + +import arcade_slack +from arcade_slack.critics import RelativeTimeBinaryCritic +from arcade_slack.tools.chat import get_messages + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.8, + warn_threshold=0.9, +) + + +catalog = ToolCatalog() +# Register the Slack tools +catalog.add_module(arcade_slack) + + +@tool_eval() +def get_messages_in_direct_message_eval_suite() -> EvalSuite: + """Create an evaluation suite for tools getting messages in direct messages.""" + suite = EvalSuite( + name="Slack Chat Tools Evaluation", + system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.", + catalog=catalog, + rubric=rubric, + ) + + no_arguments_user_messages_by_username = [ + "what are the latest messages I exchanged with janedoe", + "show my messages with janedoe on Slack", + "get the messages I exchanged with janedoe", + "get the message history with janedoe", + ] + + for i, user_message in enumerate(no_arguments_user_messages_by_username): + suite.add_case( + name=f"{user_message} [{i}]", + user_message=user_message, + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "usernames": ["janedoe"], + "emails": None, + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.6), + BinaryCritic(critic_field="emails", weight=0.1), + ], + ) + + no_arguments_user_messages_by_email = [ + "what are the latest messages I exchanged with jane.doe@acme.com", + "show my messages with jane.doe@acme.com on Slack", + ] + + for i, user_message in enumerate(no_arguments_user_messages_by_email): + suite.add_case( + name=f"{user_message} [{i}]", + user_message=user_message, + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "usernames": None, + "emails": ["jane.doe@acme.com"], + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.6), + ], + ) + + suite.add_case( + name="get messages in direct conversation by username (on a specific date)", + user_message="get the messages I exchanged with janedoe on 2025-01-31", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "usernames": ["janedoe"], + "emails": None, + "oldest_datetime": "2025-01-31 00:00:00", + "latest_datetime": "2025-01-31 23:59:59", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.2), + BinaryCritic(critic_field="emails", weight=0.1), + DatetimeCritic( + critic_field="oldest_datetime", weight=0.2, max_difference=timedelta(minutes=2) + ), + DatetimeCritic( + critic_field="latest_datetime", weight=0.2, max_difference=timedelta(minutes=2) + ), + ], + ) + + suite.add_case( + name="get messages in direct conversation by email (on a specific date)", + user_message="get the messages I exchanged with jane.doe@acme.com on 2025-01-31", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "usernames": None, + "emails": ["jane.doe@acme.com"], + "oldest_datetime": "2025-01-31 00:00:00", + "latest_datetime": "2025-01-31 23:59:59", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.2), + DatetimeCritic( + critic_field="oldest_datetime", weight=0.2, max_difference=timedelta(minutes=2) + ), + DatetimeCritic( + critic_field="latest_datetime", weight=0.2, max_difference=timedelta(minutes=2) + ), + ], + ) + + suite.add_case( + name="Get conversation history oldest relative by username (2 days ago)", + user_message="Get the messages I exchanged with janedoe starting 2 days ago", + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "usernames": ["janedoe"], + "emails": None, + "oldest_relative": "02:00:00", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.3), + BinaryCritic(critic_field="emails", weight=0.1), + RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.3), + ], + ) + + return suite diff --git a/toolkits/slack/evals/chat/messages/eval_get_mpim_messages.py b/toolkits/slack/evals/chat/messages/eval_get_mpim_messages.py new file mode 100644 index 00000000..028e4ad1 --- /dev/null +++ b/toolkits/slack/evals/chat/messages/eval_get_mpim_messages.py @@ -0,0 +1,165 @@ +from datetime import timedelta + +from arcade_evals import ( + BinaryCritic, + DatetimeCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + tool_eval, +) +from arcade_tdk import ToolCatalog + +import arcade_slack +from arcade_slack.critics import RelativeTimeBinaryCritic +from arcade_slack.tools.chat import get_messages + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.8, + warn_threshold=0.9, +) + + +catalog = ToolCatalog() +# Register the Slack tools +catalog.add_module(arcade_slack) + + +@tool_eval() +def get_messages_in_multi_person_direct_message_eval_suite() -> EvalSuite: + """Create an evaluation suite for tools getting messages in multi-person direct messages.""" + suite = EvalSuite( + name="Slack Chat Tools Evaluation", + system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.", + catalog=catalog, + rubric=rubric, + ) + + no_arguments_user_messages_by_username = [ + ( + "what are the latest messages I exchanged in the MPIM " + "with the usernames john, ryan, and jennifer" + ), + ("show the messages in the MPIM with the usernames john, ryan, and jennifer on Slack"), + ("list the messages I exchanged in the MPIM with the usernames john, ryan, and jennifer"), + ("list the message history in the MPIM with the usernames john, ryan, and jennifer"), + ] + + for i, user_message in enumerate(no_arguments_user_messages_by_username): + suite.add_case( + name=f"{user_message} [{i}]", + user_message=user_message, + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "usernames": ["john", "ryan", "jennifer"], + "emails": None, + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.6), + BinaryCritic(critic_field="emails", weight=0.1), + ], + ) + + suite.add_case( + name="get messages in multi person direct conversation with mixed usernames and emails", + user_message=( + "get the messages I exchanged in the mpim with " + "the usernames john, ryan, and jennifer@acme.com" + ), + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "usernames": ["john", "ryan"], + "emails": ["jennifer@acme.com"], + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.35), + BinaryCritic(critic_field="emails", weight=0.35), + ], + ) + + suite.add_case( + name="get messages in direct conversation by username (on a specific date)", + user_message=( + "get the messages I exchanged in the mpim with " + "the usernames john, ryan, and jennifer on 2025-01-31" + ), + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "usernames": ["john", "ryan", "jennifer"], + "emails": None, + "oldest_datetime": "2025-01-31 00:00:00", + "latest_datetime": "2025-01-31 23:59:59", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.2), + BinaryCritic(critic_field="emails", weight=0.1), + DatetimeCritic( + critic_field="oldest_datetime", weight=0.2, max_difference=timedelta(minutes=2) + ), + DatetimeCritic( + critic_field="latest_datetime", weight=0.2, max_difference=timedelta(minutes=2) + ), + ], + ) + + suite.add_case( + name="Get conversation history oldest relative by username (2 days ago)", + user_message=( + "Get the messages I exchanged in the MPIM with " + "the usernames john, ryan, and jennifer starting 2 days ago" + ), + expected_tool_calls=[ + ExpectedToolCall( + func=get_messages, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "usernames": ["john", "ryan", "jennifer"], + "emails": None, + "oldest_relative": "02:00:00", + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.3), + BinaryCritic(critic_field="emails", weight=0.1), + RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.3), + ], + ) + + return suite diff --git a/toolkits/slack/evals/chat/messages/eval_send_messages.py b/toolkits/slack/evals/chat/messages/eval_send_messages.py new file mode 100644 index 00000000..70a40688 --- /dev/null +++ b/toolkits/slack/evals/chat/messages/eval_send_messages.py @@ -0,0 +1,279 @@ +import json + +from arcade_evals import ( + BinaryCritic, + EvalRubric, + EvalSuite, + ExpectedToolCall, + SimilarityCritic, + tool_eval, +) +from arcade_tdk import ToolCatalog + +import arcade_slack +from arcade_slack.tools.chat import send_message + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.8, + warn_threshold=0.9, +) + + +catalog = ToolCatalog() +# Register the Slack tools +catalog.add_module(arcade_slack) + + +@tool_eval() +def send_message_eval_suite() -> EvalSuite: + """Create an evaluation suite for Slack messaging tools.""" + suite = EvalSuite( + name="Slack Messaging Tools Evaluation", + system_message="You are an AI assistant that can send direct messages and post messages to channels in Slack using the provided tools.", + catalog=catalog, + rubric=rubric, + ) + + # Send DM to User Scenarios + suite.add_case( + name="Send DM to user with clear username", + user_message="Send a direct message to johndoe saying 'Hello, can we meet at 3 PM?'", + expected_tool_calls=[ + ExpectedToolCall( + func=send_message, + args={ + "message": "Hello, can we meet at 3 PM?", + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "emails": None, + "usernames": ["johndoe"], + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="message", weight=0.3, similarity_threshold=0.9), + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.3), + ], + ) + + suite.add_case( + name="Send message to channel with clear name", + user_message="Post 'The new feature is now live!' in the #announcements channel", + expected_tool_calls=[ + ExpectedToolCall( + func=send_message, + args={ + "message": "The new feature is now live!", + "conversation_id": None, + "channel_name": "announcements", + "user_ids": None, + "emails": None, + "usernames": None, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="message", weight=0.3, similarity_threshold=0.9), + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.3), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + ], + ) + + suite.add_case( + name="Send message to channel with ambiguous name", + user_message="Inform the team in the general channel about the upcoming maintenance", + expected_tool_calls=[ + ExpectedToolCall( + func=send_message, + args={ + "message": "Attention team: There will be upcoming maintenance. Please save your work and expect some downtime.", + "conversation_id": None, + "channel_name": "general", + "user_ids": None, + "emails": None, + "usernames": None, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="message", weight=0.3, similarity_threshold=0.9), + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.3), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + ], + ) + + # Adversarial Scenarios + suite.add_case( + name="Ambiguous between DM and channel message", + user_message="general", + expected_tool_calls=[ + ExpectedToolCall( + func=send_message, + args={ + "message": "Great job on the presentation!", + "conversation_id": None, + "channel_name": "general", + "user_ids": None, + "emails": None, + "usernames": None, + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="message", weight=0.3, similarity_threshold=0.9), + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.3), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + ], + additional_messages=[ + {"role": "user", "content": "Send 'Great job on the presentation!' to the team"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "Slack_ListConversationsMetadata", + "arguments": '{"limit":20, "conversation_types":["public_channel"]}', + }, + } + ], + }, + { + "role": "tool", + "content": json.dumps({ + "conversations": [ + { + "conversation_type": "public_channel", + "id": "channel1", + "is_archived": False, + "is_member": True, + "is_private": False, + "name": "random", + "num_members": 999, + "purpose": "Random stuff", + }, + { + "conversation_type": "public_channel", + "id": "channel2", + "is_archived": False, + "is_member": True, + "is_private": False, + "name": "general", + "num_members": 999, + "purpose": "Just a general channel", + }, + ], + "next_cursor": "", + }), + "tool_call_id": "call_1", + "name": "Slack_ListPublicChannelsMetadata", + }, + { + "role": "assistant", + "content": 'To send the message "Great job on the presentation!" to the team, please let me know which Slack channel you\'d like to use:\n\n1. #random\n2. #general\n\nPlease let me know your choice!', + }, + ], + ) + + suite.add_case( + name="Multiple recipients in DM request", + user_message="Send DMs to the users 'alice' and 'bob' about rescheduling our meeting tomorrow. I have too much work to do.", + expected_tool_calls=[ + ExpectedToolCall( + func=send_message, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "emails": None, + "usernames": ["alice"], + }, + ), + ExpectedToolCall( + func=send_message, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "emails": None, + "usernames": ["bob"], + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.6), + ], + ) + + suite.add_case( + name="Multiple recipients in MPIM request", + user_message="Send a message to the users 'alice' and 'bob' about rescheduling our meeting tomorrow. I have too much work to do.", + expected_tool_calls=[ + ExpectedToolCall( + func=send_message, + args={ + "conversation_id": None, + "channel_name": None, + "user_ids": None, + "emails": None, + "usernames": ["alice", "bob"], + }, + ), + ], + critics=[ + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.1), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.6), + ], + ) + + suite.add_case( + name="Channel name similar to username", + user_message="Post 'sounds great!' in john-project channel", + expected_tool_calls=[ + ExpectedToolCall( + func=send_message, + args={ + "conversation_id": None, + "channel_name": "john-project", + "user_ids": None, + "emails": None, + "usernames": None, + "message": "Sounds great!", + }, + ) + ], + critics=[ + SimilarityCritic(critic_field="message", weight=0.3, similarity_threshold=0.9), + BinaryCritic(critic_field="conversation_id", weight=0.1), + BinaryCritic(critic_field="channel_name", weight=0.3), + BinaryCritic(critic_field="user_ids", weight=0.1), + BinaryCritic(critic_field="emails", weight=0.1), + BinaryCritic(critic_field="usernames", weight=0.1), + ], + ) + + return suite diff --git a/toolkits/slack/evals/eval_chat.py b/toolkits/slack/evals/eval_chat.py deleted file mode 100644 index 122aca95..00000000 --- a/toolkits/slack/evals/eval_chat.py +++ /dev/null @@ -1,1168 +0,0 @@ -import json -from datetime import timedelta - -from arcade_evals import ( - BinaryCritic, - DatetimeCritic, - EvalRubric, - EvalSuite, - ExpectedToolCall, - SimilarityCritic, - tool_eval, -) -from arcade_tdk import ToolCatalog - -import arcade_slack -from arcade_slack.critics import RelativeTimeBinaryCritic -from arcade_slack.tools.chat import ( - get_channel_metadata_by_name, - get_conversation_metadata_by_id, - get_direct_message_conversation_metadata_by_username, - get_members_in_channel_by_name, - get_members_in_conversation_by_id, - get_messages_in_channel_by_name, - get_messages_in_conversation_by_id, - get_messages_in_direct_message_conversation_by_username, - get_messages_in_multi_person_dm_conversation_by_usernames, - list_conversations_metadata, - list_direct_message_conversations_metadata, - list_group_direct_message_conversations_metadata, - list_private_channels_metadata, - list_public_channels_metadata, - send_dm_to_user, - send_message_to_channel, -) - -# Evaluation rubric -rubric = EvalRubric( - fail_threshold=0.8, - warn_threshold=0.9, -) - - -catalog = ToolCatalog() -# Register the Slack tools -catalog.add_module(arcade_slack) - - -@tool_eval() -def send_message_eval_suite() -> EvalSuite: - """Create an evaluation suite for Slack messaging tools.""" - suite = EvalSuite( - name="Slack Messaging Tools Evaluation", - system_message="You are an AI assistant that can send direct messages and post messages to channels in Slack using the provided tools.", - catalog=catalog, - rubric=rubric, - ) - - # Send DM to User Scenarios - suite.add_case( - name="Send DM to user with clear username", - user_message="Send a direct message to johndoe saying 'Hello, can we meet at 3 PM?'", - expected_tool_calls=[ - ExpectedToolCall( - func=send_dm_to_user, - args={ - "user_name": "johndoe", - "message": "Hello, can we meet at 3 PM?", - }, - ) - ], - critics=[ - BinaryCritic(critic_field="user_name", weight=0.5), - SimilarityCritic(critic_field="message", weight=0.5, similarity_threshold=0.9), - ], - ) - - suite.add_case( - name="Send DM with ambiguous username", - user_message="ask him for an update on the project", - expected_tool_calls=[ - ExpectedToolCall( - func=send_dm_to_user, - args={ - "user_name": "john", - "message": "Hi John, could you please provide an update on the Acme project?", - }, - ) - ], - critics=[ - BinaryCritic(critic_field="user_name", weight=0.75), - SimilarityCritic(critic_field="message", weight=0.25, similarity_threshold=0.6), - ], - additional_messages=[ - {"role": "user", "content": "Message John about the Acme project deadline"}, - { - "role": "assistant", - "content": "", - "tool_calls": [ - { - "id": "call_1", - "type": "function", - "function": { - "name": "Slack_ListUsers", - "arguments": '{"exclude_bots":true}', - }, - } - ], - }, - { - "role": "tool", - "content": json.dumps({ - "users": [ - { - "display_name": "john", - "email": "john@randomtech.com", - "id": "abc123", - "is_bot": False, - "name": "john", - "real_name": "John Doe", - "timezone": "America/Los_Angeles", - }, - { - "display_name": "jack", - "email": "jack@randomtech.com", - "id": "def456", - "is_bot": False, - "name": "jack", - "real_name": "Jack Doe", - "timezone": "America/Los_Angeles", - }, - ] - }), - "tool_call_id": "call_1", - "name": "Slack_ListUsers", - }, - { - "role": "assistant", - "content": "What would you like to include in the message to John about the Acme project deadline?", - }, - ], - ) - - suite.add_case( - name="Send DM with username in different format", - user_message="yes, send it", - expected_tool_calls=[ - ExpectedToolCall( - func=send_dm_to_user, - args={ - "user_name": "jane.doe", - "message": "Hi Jane, I need to reschedule our meeting. When are you available?", - }, - ) - ], - critics=[ - BinaryCritic(critic_field="user_name", weight=0.75), - SimilarityCritic(critic_field="message", weight=0.25, similarity_threshold=0.6), - ], - additional_messages=[ - {"role": "user", "content": "Message Jane.Doe asking to reschedule our meeting"}, - { - "role": "assistant", - "content": "", - "tool_calls": [ - { - "id": "call_1", - "type": "function", - "function": { - "name": "Slack_ListUsers", - "arguments": '{"exclude_bots":true}', - }, - } - ], - }, - { - "role": "tool", - "content": json.dumps({ - "users": [ - { - "display_name": "jane.doe", - "email": "jane@randomtech.com", - "id": "abc123", - "is_bot": False, - "name": "jane.doe", - "real_name": "Jane Doe", - "timezone": "America/Los_Angeles", - }, - { - "display_name": "jack", - "email": "jack@randomtech.com", - "id": "def456", - "is_bot": False, - "name": "jack", - "real_name": "Jack Doe", - "timezone": "America/Los_Angeles", - }, - ] - }), - "tool_call_id": "call_1", - "name": "Slack_ListUsers", - }, - { - "role": "assistant", - "content": "I found a user with the name 'jane.doe'. Would you like to send a message to them?", - }, - ], - ) - - # Send Message to Channel Scenarios - suite.add_case( - name="Send message to channel with clear name", - user_message="Post 'The new feature is now live!' in the #announcements channel", - expected_tool_calls=[ - ExpectedToolCall( - func=send_message_to_channel, - args={ - "channel_name": "announcements", - "message": "The new feature is now live!", - }, - ) - ], - critics=[ - BinaryCritic(critic_field="channel_name", weight=0.5), - SimilarityCritic(critic_field="message", weight=0.5), - ], - ) - - suite.add_case( - name="Send message to channel with ambiguous name", - user_message="Inform the team in the general channel about the upcoming maintenance", - expected_tool_calls=[ - ExpectedToolCall( - func=send_message_to_channel, - args={ - "channel_name": "general", - "message": "Attention team: There will be upcoming maintenance. Please save your work and expect some downtime.", - }, - ) - ], - critics=[ - SimilarityCritic(critic_field="channel_name", weight=0.8), - SimilarityCritic(critic_field="message", weight=0.2, similarity_threshold=0.6), - ], - ) - - # Adversarial Scenarios - suite.add_case( - name="Ambiguous between DM and channel message", - user_message="general", - expected_tool_calls=[ - ExpectedToolCall( - func=send_message_to_channel, - args={ - "channel_name": "general", - "message": "Great job on the presentation!", - }, - ) - ], - critics=[ - SimilarityCritic(critic_field="channel_name", weight=0.4), - SimilarityCritic(critic_field="message", weight=0.6), - ], - additional_messages=[ - {"role": "user", "content": "Send 'Great job on the presentation!' to the team"}, - { - "role": "assistant", - "content": "", - "tool_calls": [ - { - "id": "call_1", - "type": "function", - "function": { - "name": "Slack_ListPublicChannelsMetadata", - "arguments": '{"limit":20}', - }, - } - ], - }, - { - "role": "tool", - "content": json.dumps({ - "conversations": [ - { - "conversation_type": "public_channel", - "id": "channel1", - "is_archived": False, - "is_member": True, - "is_private": False, - "name": "random", - "num_members": 999, - "purpose": "Random stuff", - }, - { - "conversation_type": "public_channel", - "id": "channel2", - "is_archived": False, - "is_member": True, - "is_private": False, - "name": "general", - "num_members": 999, - "purpose": "Just a general channel", - }, - ], - "next_cursor": "", - }), - "tool_call_id": "call_1", - "name": "Slack_ListPublicChannelsMetadata", - }, - { - "role": "assistant", - "content": 'To send the message "Great job on the presentation!" to the team, please let me know which Slack channel you\'d like to use:\n\n1. #random\n2. #general\n\nPlease let me know your choice!', - }, - ], - ) - - # Multiple recipients in DM request - suite.add_case( - name="Multiple recipients in DM request", - user_message="Send DMs to the users 'alice' and 'bob' about pushing the meeting tomorrow. I have to much work to do.", - expected_tool_calls=[ - ExpectedToolCall( - func=send_dm_to_user, - args={ - "user_name": "alice", - "message": "Hi Alice, about our meeting tomorrow, let's reschedule? I am swamped with work.", - }, - ), - ExpectedToolCall( - func=send_dm_to_user, - args={ - "user_name": "bob", - "message": "Hi Bob, about our meeting tomorrow, let's reschedule? I am swamped with work.", - }, - ), - ], - critics=[ - SimilarityCritic(critic_field="user_name", weight=0.75), - SimilarityCritic(critic_field="message", weight=0.25, similarity_threshold=0.5), - ], - ) - - suite.add_case( - name="Channel name similar to username", - user_message="Post 'sounds great!' in john-project channel", - expected_tool_calls=[ - ExpectedToolCall( - func=send_message_to_channel, - args={ - "channel_name": "john-project", - "message": "Sounds great!", - }, - ) - ], - critics=[ - BinaryCritic(critic_field="channel_name", weight=0.5), - SimilarityCritic(critic_field="message", weight=0.5), - ], - ) - - return suite - - -@tool_eval() -def list_conversations_eval_suite() -> EvalSuite: - """Create an evaluation suite for tools listing conversations.""" - suite = EvalSuite( - name="Slack Messaging Tools Evaluation", - system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.", - catalog=catalog, - rubric=rubric, - ) - - cases = [ - ( - "List my conversations", - "List all conversations I am a member of", - list_conversations_metadata, - ), - ( - "List public channels", - "List all public channels", - list_public_channels_metadata, - ), - ( - "List private channels", - "List all private channels", - list_private_channels_metadata, - ), - ( - "List group direct message channels", - "List all group direct message channels", - list_group_direct_message_conversations_metadata, - ), - ( - "List individual direct message channels", - "List all individual direct message channels", - list_direct_message_conversations_metadata, - ), - ( - "List direct message channels", - "List all direct message channels", - list_direct_message_conversations_metadata, - ), - ( - "List public and private channels", - "List public and private channels I am a member of", - list_public_channels_metadata, - list_private_channels_metadata, - ), - ( - "List public channels and direct message conversations", - "List public channels and direct message conversations I am a member of", - list_public_channels_metadata, - list_direct_message_conversations_metadata, - ), - ] - - for name, user_message, *expect_called_tool_functions in cases: - suite.add_case( - name=name, - user_message=user_message, - expected_tool_calls=[ - ExpectedToolCall( - func=tool_function, - args={}, - ) - for tool_function in expect_called_tool_functions - ], - ) - - return suite - - -@tool_eval() -def get_conversations_metadata_eval_suite() -> EvalSuite: - """Create an evaluation suite for tools getting conversations metadata.""" - suite = EvalSuite( - name="Slack Tools Evaluation", - system_message="You are an AI assistant that can interact with Slack to get information from conversations, users, etc.", - catalog=catalog, - rubric=rubric, - ) - - suite.add_case( - name="Get channel metadata by name", - user_message="Get the metadata of the #general channel", - expected_tool_calls=[ - ExpectedToolCall( - func=get_channel_metadata_by_name, - args={ - "channel_name": "general", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_name", weight=1.0), - ], - ) - - suite.add_case( - name="Get conversation metadata by id", - user_message="Get the metadata of the conversation with id '1234567890'", - expected_tool_calls=[ - ExpectedToolCall( - func=get_conversation_metadata_by_id, - args={ - "conversation_id": "1234567890", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_id", weight=1.0), - ], - ) - - get_metadata_by_username_user_messages = [ - "get the metadata of the direct message conversation with the user 'jane.doe'" - "get data about my private conversation with the user 'jane.doe'", - "get data about my IM conversation with the 'jane.doe'", - ] - - for i, user_message in enumerate(get_metadata_by_username_user_messages): - suite.add_case( - name=f"Get direct message conversation metadata by username {i}", - user_message=user_message, - expected_tool_calls=[ - ExpectedToolCall( - func=get_direct_message_conversation_metadata_by_username, - args={ - "username": "jane.doe", - }, - ), - ], - ) - - return suite - - -@tool_eval() -def get_conversations_members_eval_suite() -> EvalSuite: - """Create an evaluation suite for tools getting conversations members.""" - suite = EvalSuite( - name="Slack Tools Evaluation", - system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.", - catalog=catalog, - rubric=rubric, - ) - - user_messages = [ - "Get the members of the #general channel", - "Get the members of the general channel", - "Get a list of people in the #general channel", - "Get a list of people in the general channel", - "Show me who's in the #general channel", - "Show me who's in the general channel", - "Who is in the #general channel?", - "Who is in the general channel?", - ] - - for user_message in user_messages: - suite.add_case( - name=f"Get channel members by name: {user_message}", - user_message=user_message, - expected_tool_calls=[ - ExpectedToolCall( - func=get_members_in_channel_by_name, - args={ - "channel_name": "general", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_name", weight=1.0), - ], - ) - - suite.add_case( - name="Get conversation members by id", - user_message="Get the members of the conversation with id '1234567890'", - expected_tool_calls=[ - ExpectedToolCall( - func=get_members_in_conversation_by_id, - args={ - "conversation_id": "1234567890", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_id", weight=1.0), - ], - ) - - return suite - - -@tool_eval() -def get_messages_in_channel_eval_suite() -> EvalSuite: - """Create an evaluation suite for tools getting messages in channels.""" - suite = EvalSuite( - name="Slack Chat Tools Evaluation", - system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.", - catalog=catalog, - rubric=rubric, - ) - - no_arguments_user_messages_by_channel_name = [ - "what are the latest messages in the #general channel", - "show me the messages in the general channel", - "list the messages in the #general channel", - "list the messages in the general channel", - ] - - for i, user_message in enumerate(no_arguments_user_messages_by_channel_name): - suite.add_case( - name=f"Get messages in conversation by name {i}: '{user_message}'", - user_message=user_message, - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_channel_by_name, - args={ - "conversation_name": "general", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_name", weight=1.0), - ], - ) - - no_arguments_user_messages_by_conversation_id = [ - "Get the history of the conversation with id '1234567890'", - "Get the history of the conversation with id '1234567890'", - "list the messages in the conversation with id '1234567890'", - "list the messages in the conversation with id '1234567890'", - ] - - for user_message in no_arguments_user_messages_by_conversation_id: - suite.add_case( - name=f"Get conversation history by id: '{user_message}'", - user_message=user_message, - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_conversation_by_id, - args={ - "conversation_id": "1234567890", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_id", weight=1.0), - ], - ) - - suite.add_case( - name="Get conversation history with limit by name", - user_message="Get the last 10 messages in the #general channel", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_channel_by_name, - args={ - "conversation_name": "general", - "limit": 10, - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_name", weight=1.0), - ], - ) - - suite.add_case( - name="Get conversation history with limit by id", - user_message="Get the last 25 messages in the conversation with id '1234567890'", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_conversation_by_id, - args={ - "conversation_id": "1234567890", - "limit": 25, - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_id", weight=1.0), - ], - ) - - # Relative time eval cases by id - - suite.add_case( - name="Get conversation history oldest relative by id (2 days ago)", - user_message="Get the messages in the conversation with id '1234567890' starting 2 days ago", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_conversation_by_id, - args={ - "conversation_id": "1234567890", - "oldest_relative": "02:00:00", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_id", weight=0.5), - RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.5), - ], - ) - - suite.add_case( - name="Get conversation history oldest and latest relative by id", - user_message="Get the messages in the conversation with id '1234567890' from 2 days ago to 3 hours ago", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_conversation_by_id, - args={ - "conversation_id": "1234567890", - "oldest_relative": "02:00:00", - "latest_relative": "00:03:00", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_id", weight=1 / 3), - RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=1 / 3), - RelativeTimeBinaryCritic(critic_field="latest_relative", weight=1 / 3), - ], - ) - - suite.add_case( - name="Get conversation history oldest relative by id (1 week ago)", - user_message="Get the messages in the conversation with id '1234567890' starting 1 week ago", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_conversation_by_id, - args={ - "conversation_id": "1234567890", - "oldest_relative": "07:00:00", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_id", weight=0.5), - RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.5), - ], - ) - - suite.add_case( - name="Get conversation history oldest relative by id (yesterday)", - user_message="Get the messages in the conversation with id '1234567890' from yesterday", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_conversation_by_id, - args={ - "conversation_id": "1234567890", - "oldest_relative": "01:00:00", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_id", weight=0.5), - RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.5), - ], - ) - - # Relative time eval cases by name - - suite.add_case( - name="Get conversation history oldest relative by name (2 days ago)", - user_message="Get the messages in the #general channel starting 2 days ago", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_channel_by_name, - args={ - "conversation_name": "general", - "oldest_relative": "02:00:00", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_name", weight=0.5), - RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.5), - ], - ) - - suite.add_case( - name="Get conversation history oldest and latest relative by name", - user_message="Get the messages in the #general channel from 2 days ago to 3 hours ago", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_channel_by_name, - args={ - "conversation_name": "general", - "oldest_relative": "02:00:00", - "latest_relative": "00:03:00", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_name", weight=1 / 3), - RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=1 / 3), - RelativeTimeBinaryCritic(critic_field="latest_relative", weight=1 / 3), - ], - ) - - suite.add_case( - name="Get conversation history oldest relative by name (yesterday)", - user_message="Get the messages in the #general channel from yesterday", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_channel_by_name, - args={ - "conversation_name": "general", - "oldest_relative": "01:00:00", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_name", weight=0.5), - RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.5), - ], - ) - - suite.add_case( - name="Get conversation history oldest relative by name (last week)", - user_message="Get the messages in the #general channel from last week", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_channel_by_name, - args={ - "conversation_name": "general", - "oldest_relative": "07:00:00", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_name", weight=0.5), - RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.5), - ], - ) - - # Absolute time eval cases by id - - suite.add_case( - name="Get conversation history oldest absolute by id (on a specific date)", - user_message="Get the messages in the conversation with id '1234567890' from 2025-01-20", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_conversation_by_id, - args={ - "conversation_id": "1234567890", - "oldest_datetime": "2025-01-20 00:00:00", - "latest_datetime": "2025-01-20 23:59:59", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_id", weight=1 / 3), - DatetimeCritic( - critic_field="oldest_datetime", weight=1 / 3, max_difference=timedelta(minutes=2) - ), - DatetimeCritic( - critic_field="latest_datetime", weight=1 / 3, max_difference=timedelta(minutes=2) - ), - ], - ) - - suite.add_case( - name="Get conversation history oldest absolute by id (between a date range)", - user_message="Get the messages in the conversation with id '1234567890' from 2025-01-20 to 2025-01-25", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_conversation_by_id, - args={ - "conversation_id": "1234567890", - "oldest_datetime": "2025-01-20 00:00:00", - "latest_datetime": "2025-01-25 23:59:59", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_id", weight=1 / 3), - DatetimeCritic( - critic_field="oldest_datetime", weight=1 / 3, max_difference=timedelta(minutes=2) - ), - DatetimeCritic( - critic_field="latest_datetime", weight=1 / 3, max_difference=timedelta(minutes=2) - ), - ], - ) - - suite.add_case( - name="Get conversation history oldest absolute by name (on a specific date)", - user_message="Get the messages in the #general channel from 2025-01-20", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_channel_by_name, - args={ - "conversation_name": "general", - "oldest_datetime": "2025-01-20 00:00:00", - "latest_datetime": "2025-01-20 23:59:59", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_name", weight=1 / 3), - # We use a timedelta of 10 seconds because sometimes the LLM will select the limit - # date at 23:59:59, other times it'll select the next day at 00:00:00. - DatetimeCritic( - critic_field="oldest_datetime", weight=1 / 3, max_difference=timedelta(seconds=10) - ), - DatetimeCritic( - critic_field="latest_datetime", weight=1 / 3, max_difference=timedelta(seconds=10) - ), - ], - ) - - suite.add_case( - name="Get conversation history oldest absolute by name (between a date range)", - user_message="Get the messages in the #general channel from 2025-01-20 to 2025-01-25", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_channel_by_name, - args={ - "conversation_name": "general", - "oldest_datetime": "2025-01-20 00:00:00", - "latest_datetime": "2025-01-25 23:59:59", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_name", weight=1 / 3), - # We use a timedelta of 10 seconds because sometimes the LLM will select the limit - # date at 23:59:59, other times it'll select the next day at 00:00:00. - DatetimeCritic( - critic_field="oldest_datetime", weight=1 / 3, max_difference=timedelta(seconds=10) - ), - DatetimeCritic( - critic_field="latest_datetime", weight=1 / 3, max_difference=timedelta(seconds=10) - ), - ], - ) - - # Eval case for pagination - - suite.add_case( - name="Get conversation history with pagination", - user_message="get the next 5 messages", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_channel_by_name, - args={ - "conversation_name": "general", - "limit": 5, - "cursor": "dXNlcjpVsDjzOTZGVDlQRA==", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="conversation_name", weight=1 / 3), - BinaryCritic(critic_field="cursor", weight=1 / 3), - BinaryCritic(critic_field="limit", weight=1 / 3), - ], - additional_messages=[ - {"role": "user", "content": "Get the last 2 messages on the general channel"}, - { - "role": "assistant", - "content": "", - "tool_calls": [ - { - "id": "call_1", - "type": "function", - "function": { - "name": "Slack_GetConversationHistoryByName", - "arguments": json.dumps({ - "conversation_name": "general", - "limit": 2, - }), - }, - } - ], - }, - { - "role": "tool", - "content": json.dumps({ - "messages": [ - { - "blocks": [ - { - "block_id": "abc123", - "elements": [ - { - "elements": [ - { - "text": "Almost there, Boss, need to get some evals in!", - "type": "text", - } - ], - "type": "rich_text_section", - } - ], - "type": "rich_text", - } - ], - "client_msg_id": "msg_id_0", - "datetime_timestamp": "2025-01-21 16:59:55", - "team": "617263616465207465616D20697320617420626F7373206C6576656C", - "text": "Almost there, Boss, need to get some evals in!", - "ts": "1737507595.598529", - "type": "message", - "user": "77686F2069732074686520626F73733F", - }, - { - "blocks": [ - { - "block_id": "xyz456", - "elements": [ - { - "elements": [ - { - "text": "hey, are the Slack Tools ready yet?", - "type": "text", - } - ], - "type": "rich_text_section", - } - ], - "type": "rich_text", - } - ], - "client_msg_id": "msg_id_1", - "datetime_timestamp": "2025-01-21 16:57:35", - "team": "617263616465207465616D20697320617420626F7373206C6576656C", - "text": "hey, are the Slack Tools ready yet?", - "ts": "1737507595.598529", - "type": "message", - "user": "73616D2069732074686520626F7373", - }, - ], - "next_cursor": "dXNlcjpVsDjzOTZGVDlQRA==", - }), - "tool_call_id": "call_1", - "name": "Slack_GetConversationHistoryByName", - }, - { - "role": "assistant", - "content": 'Here are the last 2 messages from the general channel:\n\n1. **User:** 77686F2069732074686520626F73733F \n **Message:** "Almost there, Boss, need to get some evals in!" \n **Timestamp:** 2025-01-21 16:59:55\n\n2. **User:** 73616D2069732074686520626F7373 \n **Message:** "hey, are the Slack Tools ready yet?" \n **Timestamp:** 2025-01-21 16:57:35', - }, - ], - ) - - return suite - - -@tool_eval() -def get_messages_in_direct_message_eval_suite() -> EvalSuite: - """Create an evaluation suite for tools getting messages in direct messages.""" - suite = EvalSuite( - name="Slack Chat Tools Evaluation", - system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.", - catalog=catalog, - rubric=rubric, - ) - - no_arguments_user_messages_by_username = [ - "what are the latest messages I exchanged with jane.doe", - "show my messages with jane.doe on Slack", - "list the messages I exchanged with jane.doe", - "list the message history with jane.doe", - ] - - for i, user_message in enumerate(no_arguments_user_messages_by_username): - suite.add_case( - name=f"{user_message} [{i}]", - user_message=user_message, - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_direct_message_conversation_by_username, - args={ - "username": "jane.doe", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="username", weight=1.0), - ], - ) - - suite.add_case( - name="get messages in direct conversation by username (on a specific date)", - user_message="get the messages I exchanged with jane.doe on 2025-01-31", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_direct_message_conversation_by_username, - args={ - "username": "jane.doe", - "oldest_datetime": "2025-01-31 00:00:00", - "latest_datetime": "2025-01-31 23:59:59", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="username", weight=1 / 3), - DatetimeCritic( - critic_field="oldest_datetime", weight=1 / 3, max_difference=timedelta(minutes=2) - ), - DatetimeCritic( - critic_field="latest_datetime", weight=1 / 3, max_difference=timedelta(minutes=2) - ), - ], - ) - - suite.add_case( - name="Get conversation history oldest relative by username (2 days ago)", - user_message="Get the messages I exchanged with jane.doe starting 2 days ago", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_direct_message_conversation_by_username, - args={ - "username": "jane.doe", - "oldest_relative": "02:00:00", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="username", weight=0.5), - RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.5), - ], - ) - - return suite - - -@tool_eval() -def get_messages_in_multi_person_direct_message_eval_suite() -> EvalSuite: - """Create an evaluation suite for tools getting messages in multi-person direct messages.""" - suite = EvalSuite( - name="Slack Chat Tools Evaluation", - system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.", - catalog=catalog, - rubric=rubric, - ) - - no_arguments_user_messages_by_username = [ - "what are the latest messages I exchanged together with the usernames john, ryan, and jennifer", - "show the messages in the multi person dm with the usernames john, ryan, and jennifer on Slack", - "list the messages I exchanged together with the usernames john, ryan, and jennifer", - "list the message history together with the usernames john, ryan, and jennifer", - ] - - for i, user_message in enumerate(no_arguments_user_messages_by_username): - suite.add_case( - name=f"{user_message} [{i}]", - user_message=user_message, - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_multi_person_dm_conversation_by_usernames, - args={ - "usernames": ["john", "ryan", "jennifer"], - }, - ), - ], - critics=[ - BinaryCritic(critic_field="usernames", weight=1.0), - ], - ) - - suite.add_case( - name="get messages in direct conversation by username (on a specific date)", - user_message="get the messages I exchanged together with the usernames john, ryan, and jennifer on 2025-01-31", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_multi_person_dm_conversation_by_usernames, - args={ - "usernames": ["john", "ryan", "jennifer"], - "oldest_datetime": "2025-01-31 00:00:00", - "latest_datetime": "2025-01-31 23:59:59", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="usernames", weight=1 / 3), - DatetimeCritic( - critic_field="oldest_datetime", weight=1 / 3, max_difference=timedelta(minutes=2) - ), - DatetimeCritic( - critic_field="latest_datetime", weight=1 / 3, max_difference=timedelta(minutes=2) - ), - ], - ) - - suite.add_case( - name="Get conversation history oldest relative by username (2 days ago)", - user_message="Get the messages I exchanged together with the usernames john, ryan, and jennifer starting 2 days ago", - expected_tool_calls=[ - ExpectedToolCall( - func=get_messages_in_multi_person_dm_conversation_by_usernames, - args={ - "usernames": ["john", "ryan", "jennifer"], - "oldest_relative": "02:00:00", - }, - ), - ], - critics=[ - BinaryCritic(critic_field="usernames", weight=0.5), - RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.5), - ], - ) - - return suite diff --git a/toolkits/slack/evals/eval_users.py b/toolkits/slack/evals/eval_users.py index b1f3caa1..9b724145 100644 --- a/toolkits/slack/evals/eval_users.py +++ b/toolkits/slack/evals/eval_users.py @@ -10,7 +10,10 @@ from arcade_evals import ( from arcade_tdk import ToolCatalog import arcade_slack -from arcade_slack.tools.users import get_user_info_by_id, list_users +from arcade_slack.tools.users import ( + get_users_info, + list_users, +) # Evaluation rubric rubric = EvalRubric( @@ -25,7 +28,7 @@ catalog.add_module(arcade_slack) @tool_eval() -def get_user_info_by_id_eval_suite() -> EvalSuite: +def get_user_info_eval_suite() -> EvalSuite: """Create an evaluation suite for tools getting user info by id.""" suite = EvalSuite( name="Slack Users Tools Evaluation", @@ -53,13 +56,101 @@ def get_user_info_by_id_eval_suite() -> EvalSuite: user_message=user_message, expected_tool_calls=[ ExpectedToolCall( - func=get_user_info_by_id, - args={"user_id": expected_user_id}, + func=get_users_info, + args={ + "user_ids": [expected_user_id], + "usernames": None, + "emails": None, + }, ) ], - critics=[BinaryCritic(critic_field="user_id", weight=1.0)], + critics=[ + BinaryCritic(critic_field="user_ids", weight=1 / 3), + BinaryCritic(critic_field="usernames", weight=1 / 3), + BinaryCritic(critic_field="emails", weight=1 / 3), + ], ) + suite.add_case( + name="get user by username", + user_message="get the user 'johndoe'", + expected_tool_calls=[ + ExpectedToolCall( + func=get_users_info, + args={ + "usernames": ["johndoe"], + "user_ids": None, + "emails": None, + }, + ), + ], + critics=[ + BinaryCritic(critic_field="usernames", weight=1 / 3), + BinaryCritic(critic_field="user_ids", weight=1 / 3), + BinaryCritic(critic_field="emails", weight=1 / 3), + ], + ) + + suite.add_case( + name="get user by email", + user_message="get the user 'john.doe@acme.com'", + expected_tool_calls=[ + ExpectedToolCall( + func=get_users_info, + args={ + "usernames": None, + "user_ids": None, + "emails": ["john.doe@acme.com"], + }, + ), + ], + critics=[ + BinaryCritic(critic_field="usernames", weight=1 / 3), + BinaryCritic(critic_field="user_ids", weight=1 / 3), + BinaryCritic(critic_field="emails", weight=1 / 3), + ], + ) + + suite.add_case( + name="get multiple users by username", + user_message="get the users with the usernames 'johndoe' and 'foobar'", + expected_tool_calls=[ + ExpectedToolCall( + func=get_users_info, + args={ + "usernames": ["johndoe", "foobar"], + "user_ids": None, + "emails": None, + }, + ), + ], + critics=[ + BinaryCritic(critic_field="usernames", weight=1 / 3), + BinaryCritic(critic_field="user_ids", weight=1 / 3), + BinaryCritic(critic_field="emails", weight=1 / 3), + ], + ) + + suite.add_case( + name="get multiple users by email", + user_message="get the users with the emails 'john.doe@acme.com' and 'jane.doe@acme.com'", + expected_tool_calls=[ + ExpectedToolCall( + func=get_users_info, + args={ + "usernames": None, + "user_ids": None, + "emails": ["john.doe@acme.com", "jane.doe@acme.com"], + }, + ), + ], + critics=[ + BinaryCritic(critic_field="usernames", weight=1 / 3), + BinaryCritic(critic_field="user_ids", weight=1 / 3), + BinaryCritic(critic_field="emails", weight=1 / 3), + ], + ) + return suite diff --git a/toolkits/slack/pyproject.toml b/toolkits/slack/pyproject.toml index e2231396..9a630b44 100644 --- a/toolkits/slack/pyproject.toml +++ b/toolkits/slack/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "arcade_slack" -version = "0.4.6" +version = "0.5.0" description = "Arcade.dev LLM tools for Slack" requires-python = ">=3.10" dependencies = [ "aiodns>=1.0,<2.0.0", "typing; python_version < '3.7'", "aiohttp>=3.7.3,<4.0.0", "arcade-tdk>=2.0.0,<3.0.0", "slack-sdk>=3.31.0,<4.0.0",] diff --git a/toolkits/slack/tests/test_chat.py b/toolkits/slack/tests/test_chat.py index f10d00d2..05558d3a 100644 --- a/toolkits/slack/tests/test_chat.py +++ b/toolkits/slack/tests/test_chat.py @@ -1,4 +1,4 @@ -import copy +import json from datetime import datetime, timezone from unittest.mock import Mock, call, patch @@ -10,30 +10,18 @@ from slack_sdk.web.async_slack_response import AsyncSlackResponse from arcade_slack.constants import MAX_PAGINATION_SIZE_LIMIT from arcade_slack.models import ConversationType, ConversationTypeSlackName from arcade_slack.tools.chat import ( - get_channel_metadata_by_name, - get_conversation_metadata_by_id, - get_direct_message_conversation_metadata_by_username, - get_members_in_channel_by_name, - get_members_in_conversation_by_id, - get_messages_in_channel_by_name, - get_messages_in_conversation_by_id, - get_messages_in_direct_message_conversation_by_username, - get_messages_in_multi_person_dm_conversation_by_usernames, - get_multi_person_dm_conversation_metadata_by_usernames, - list_conversations_metadata, - list_direct_message_conversations_metadata, - list_group_direct_message_conversations_metadata, - list_private_channels_metadata, - list_public_channels_metadata, - send_dm_to_user, - send_message_to_channel, + get_conversation_metadata, + get_messages, + get_users_in_conversation, + list_conversations, + send_message, ) -from arcade_slack.utils import extract_basic_user_info, extract_conversation_metadata +from arcade_slack.utils import cast_user_dict, extract_conversation_metadata @pytest.fixture -def mock_list_conversations_metadata(mocker): - return mocker.patch("arcade_slack.tools.chat.list_conversations_metadata", autospec=True) +def mock_list_conversations(mocker): + return mocker.patch("arcade_slack.tools.chat.list_conversations", autospec=True) @pytest.fixture @@ -42,72 +30,147 @@ def mock_channel_info() -> dict: @pytest.mark.asyncio -async def test_send_dm_to_user(mock_context, mock_chat_slack_client): - mock_chat_slack_client.users_list.return_value = { - "ok": True, - "members": [{"name": "testuser", "id": "U12345"}], - } +async def test_send_message_to_conversation_id( + mock_context, + mock_chat_slack_client, +): + mock_slack_response = Mock(spec=AsyncSlackResponse) + mock_slack_response.data = {"ok": True} + mock_chat_slack_client.chat_postMessage.return_value = mock_slack_response + + response = await send_message(mock_context, conversation_id="abc123", message="Hello!") + + assert response["success"] is True + assert response["data"]["ok"] is True + mock_chat_slack_client.chat_postMessage.assert_called_once_with(channel="abc123", text="Hello!") + + +@pytest.mark.asyncio +async def test_send_message_to_username( + mock_context, + mock_chat_slack_client, + mock_user_retrieval_slack_client, +): + mock_chat_slack_client.auth_test.return_value = {"ok": True, "user_id": "current_user_id"} + mock_user_retrieval_slack_client.users_list.side_effect = [ + { + "ok": True, + "members": [{"name": "foo", "id": "bar"}], + "response_metadata": {"next_cursor": "123"}, + }, + { + "ok": True, + "members": [{"name": "foobar", "id": "foobar_user_id"}], + }, + ] mock_chat_slack_client.conversations_open.return_value = { "ok": True, - "channel": {"id": "D12345"}, + "channel": { + "id": "conversation_id", + "is_im": True, + }, } mock_slack_response = Mock(spec=AsyncSlackResponse) mock_slack_response.data = {"ok": True} mock_chat_slack_client.chat_postMessage.return_value = mock_slack_response - response = await send_dm_to_user(mock_context, "testuser", "Hello!") + response = await send_message( + context=mock_context, + usernames=["foobar"], + message="Hello, world!", + ) - assert response["response"]["ok"] is True - mock_chat_slack_client.users_list.assert_called_once() - mock_chat_slack_client.conversations_open.assert_called_once_with(users=["U12345"]) - mock_chat_slack_client.chat_postMessage.assert_called_once_with(channel="D12345", text="Hello!") + assert response["success"] is True + assert response["data"]["ok"] is True + + mock_chat_slack_client.auth_test.assert_called_once() + assert mock_user_retrieval_slack_client.users_list.call_count == 2 + mock_chat_slack_client.conversations_open.assert_called_once_with( + users=[ + "current_user_id", + "foobar_user_id", + ], + return_im=True, + ) + mock_chat_slack_client.chat_postMessage.assert_called_once_with( + channel="conversation_id", + text="Hello, world!", + ) @pytest.mark.asyncio -async def test_send_dm_to_inexistent_user(mock_context, mock_chat_slack_client): - mock_chat_slack_client.users_list.return_value = { +async def test_send_dm_to_inexistent_user( + mock_context, + mock_chat_slack_client, + mock_user_retrieval_slack_client, +): + mock_chat_slack_client.auth_test.return_value = {"ok": True, "user_id": "current_user_id"} + mock_user_retrieval_slack_client.users_list.return_value = { "ok": True, - "members": [{"name": "testuser", "id": "U12345"}], + "members": [{"name": "foo", "id": "bar"}], } - with pytest.raises(RetryableToolError): - await send_dm_to_user(mock_context, "inexistent_user", "Hello!") + with pytest.raises(RetryableToolError) as error: + await send_message(mock_context, usernames=["inexistent_user"], message="Hello!") - mock_chat_slack_client.users_list.assert_called_once() + assert "inexistent_user" in error.value.message + assert "foo" in error.value.additional_prompt_content + assert "bar" in error.value.additional_prompt_content + mock_user_retrieval_slack_client.users_list.assert_called_once() mock_chat_slack_client.conversations_open.assert_not_called() mock_chat_slack_client.chat_postMessage.assert_not_called() @pytest.mark.asyncio -async def test_send_message_to_channel(mock_context, mock_chat_slack_client): - mock_chat_slack_client.conversations_list.return_value = { +async def test_send_message_to_channel_success( + mock_context, + mock_chat_slack_client, + mock_conversation_retrieval_slack_client, +): + mock_conversation_retrieval_slack_client.conversations_list.return_value = { "ok": True, - "channels": [{"id": "C12345", "name": "general", "is_member": True, "is_group": True}], + "channels": [{"id": "channel_id", "name": "general", "is_member": True, "is_group": True}], } mock_slack_response = Mock(spec=AsyncSlackResponse) mock_slack_response.data = {"ok": True} mock_chat_slack_client.chat_postMessage.return_value = mock_slack_response - response = await send_message_to_channel(mock_context, "general", "Hello, channel!") + response = await send_message(mock_context, channel_name="general", message="Hello, channel!") - assert response["response"]["ok"] is True - mock_chat_slack_client.conversations_list.assert_called_once() + assert response["success"] is True + assert response["data"]["ok"] is True + mock_conversation_retrieval_slack_client.conversations_list.assert_called_once() mock_chat_slack_client.chat_postMessage.assert_called_once_with( - channel="C12345", text="Hello, channel!" + channel="channel_id", text="Hello, channel!" ) @pytest.mark.asyncio -async def test_send_message_to_inexistent_channel(mock_context, mock_chat_slack_client): - mock_chat_slack_client.conversations_list.return_value = { +async def test_send_message_to_inexistent_channel( + mock_context, + mock_chat_slack_client, + mock_conversation_retrieval_slack_client, +): + mock_conversation_retrieval_slack_client.conversations_list.return_value = { "ok": True, - "channels": [], + "channels": [ + { + "id": "another_channel_id", + "name": "another_channel", + "is_member": True, + "is_group": True, + } + ], } - with pytest.raises(RetryableToolError): - await send_message_to_channel(mock_context, "inexistent_channel", "Hello!") + with pytest.raises(RetryableToolError) as error: + await send_message(mock_context, channel_name="inexistent_channel", message="Hello!") - mock_chat_slack_client.conversations_list.assert_called_once() + assert "inexistent_channel" in error.value.message + assert "another_channel" in error.value.additional_prompt_content + assert "another_channel_id" in error.value.additional_prompt_content + + mock_conversation_retrieval_slack_client.conversations_list.assert_called_once() mock_chat_slack_client.chat_postMessage.assert_not_called() @@ -120,19 +183,44 @@ async def test_list_conversations_metadata_with_default_args( "channels": [mock_channel_info], } - response = await list_conversations_metadata(mock_context) + response = await list_conversations(mock_context) assert response["conversations"] == [extract_conversation_metadata(mock_channel_info)] assert response["next_cursor"] is None mock_chat_slack_client.conversations_list.assert_called_once_with( - types=",".join([conv_type.value for conv_type in ConversationTypeSlackName]), + types=None, exclude_archived=True, limit=MAX_PAGINATION_SIZE_LIMIT, cursor=None, ) +@pytest.mark.asyncio +async def test_list_conversations_metadata_with_more_pages( + mock_context, mock_chat_slack_client, dummy_channel_factory, random_str_factory +): + channel1 = dummy_channel_factory(is_channel=True) + channel2 = dummy_channel_factory(is_im=True) + channel3 = dummy_channel_factory(is_mpim=True) + next_cursor = random_str_factory() + + mock_chat_slack_client.conversations_list.return_value = { + "ok": True, + "channels": [channel1, channel2, channel3], + "response_metadata": {"next_cursor": next_cursor}, + } + + response = await list_conversations(mock_context, limit=3) + + assert response["conversations"] == [ + extract_conversation_metadata(channel1), + extract_conversation_metadata(channel2), + extract_conversation_metadata(channel3), + ] + assert response["next_cursor"] == next_cursor + + @pytest.mark.asyncio async def test_list_conversations_metadata_filtering_single_conversation_type( mock_context, mock_chat_slack_client, mock_channel_info @@ -142,8 +230,8 @@ async def test_list_conversations_metadata_filtering_single_conversation_type( "channels": [mock_channel_info], } - response = await list_conversations_metadata( - mock_context, conversation_types=ConversationType.PUBLIC_CHANNEL + response = await list_conversations( + mock_context, conversation_types=[ConversationType.PUBLIC_CHANNEL] ) assert response["conversations"] == [extract_conversation_metadata(mock_channel_info)] @@ -166,11 +254,11 @@ async def test_list_conversations_metadata_filtering_multiple_conversation_types "channels": [mock_channel_info], } - response = await list_conversations_metadata( + response = await list_conversations( mock_context, conversation_types=[ - ConversationTypeSlackName.PUBLIC_CHANNEL, - ConversationTypeSlackName.PRIVATE_CHANNEL, + ConversationType.PUBLIC_CHANNEL, + ConversationType.PRIVATE_CHANNEL, ], ) @@ -195,7 +283,7 @@ async def test_list_conversations_metadata_with_custom_pagination_args( "response_metadata": {"next_cursor": "456"}, } - response = await list_conversations_metadata(mock_context, limit=3, next_cursor="123") + response = await list_conversations(mock_context, limit=3, next_cursor="123") assert response["conversations"] == [ extract_conversation_metadata(mock_channel_info) for _ in range(3) @@ -203,7 +291,7 @@ async def test_list_conversations_metadata_with_custom_pagination_args( assert response["next_cursor"] == "456" mock_chat_slack_client.conversations_list.assert_called_once_with( - types=",".join([conv_type.value for conv_type in ConversationTypeSlackName]), + types=None, exclude_archived=True, limit=3, cursor="123", @@ -214,64 +302,37 @@ async def test_list_conversations_metadata_with_custom_pagination_args( @pytest.mark.parametrize( "faulty_slack_function_name, tool_function, tool_args", [ - ("users_list", send_dm_to_user, ("testuser", "Hello!")), - ("conversations_list", send_message_to_channel, ("general", "Hello!")), + ("users_list", send_message, {"usernames": ["testuser"], "message": "Hello!"}), + ("conversations_list", send_message, {"channel_name": "general", "message": "Hello!"}), ], ) async def test_tools_with_slack_error( mock_context, mock_chat_slack_client, faulty_slack_function_name, tool_function, tool_args ): + mock_chat_slack_client.auth_test.return_value = {"ok": True, "user_id": "current_user_id"} getattr(mock_chat_slack_client, faulty_slack_function_name).side_effect = SlackApiError( message="test_slack_error", response={"ok": False, "error": "test_slack_error"}, ) with pytest.raises(ToolExecutionError) as e: - await tool_function(mock_context, *tool_args) + await tool_function(mock_context, **tool_args) assert "test_slack_error" in str(e.value) -@pytest.mark.asyncio -@pytest.mark.parametrize( - "tool_function, conversation_type", - [ - (list_public_channels_metadata, ConversationType.PUBLIC_CHANNEL), - (list_private_channels_metadata, ConversationType.PRIVATE_CHANNEL), - ( - list_group_direct_message_conversations_metadata, - ConversationType.MULTI_PERSON_DIRECT_MESSAGE, - ), - (list_direct_message_conversations_metadata, ConversationType.DIRECT_MESSAGE), - ], -) -async def test_list_channels_metadata( - mock_context, - mock_list_conversations_metadata, - tool_function, - conversation_type, -): - response = await tool_function(mock_context, limit=3) - - mock_list_conversations_metadata.assert_called_once_with( - mock_context, conversation_types=[conversation_type], limit=3 - ) - - assert response == mock_list_conversations_metadata.return_value - - @pytest.mark.asyncio async def test_get_conversation_metadata_by_id( - mock_context, mock_chat_slack_client, mock_channel_info + mock_context, mock_conversation_retrieval_slack_client, mock_channel_info ): - mock_chat_slack_client.conversations_info.return_value = { + mock_conversation_retrieval_slack_client.conversations_info.return_value = { "ok": True, "channel": mock_channel_info, } - response = await get_conversation_metadata_by_id(mock_context, "C12345") + response = await get_conversation_metadata(mock_context, conversation_id="C12345") assert response == extract_conversation_metadata(mock_channel_info) - mock_chat_slack_client.conversations_info.assert_called_once_with( + mock_conversation_retrieval_slack_client.conversations_info.assert_called_once_with( channel="C12345", include_locale=True, include_num_members=True, @@ -279,304 +340,618 @@ async def test_get_conversation_metadata_by_id( @pytest.mark.asyncio -@patch("arcade_slack.tools.chat.list_conversations_metadata") async def test_get_conversation_metadata_by_id_slack_api_error( - mock_list_conversations_metadata, mock_context, mock_chat_slack_client, mock_channel_info + mock_context, + mock_conversation_retrieval_slack_client, + mock_channel_info, ): - mock_channel_info["name"] = "whatever_conversation_should_be_present_in_additional_prompt" - mock_list_conversations_metadata.return_value = { - "conversations": [extract_conversation_metadata(mock_channel_info)], - "response_metadata": {"next_cursor": None}, - } - mock_chat_slack_client.conversations_info.side_effect = SlackApiError( + mock_conversation_retrieval_slack_client.conversations_info.side_effect = SlackApiError( message="channel_not_found", response={"ok": False, "error": "channel_not_found"}, ) - with pytest.raises(RetryableToolError) as e: - await get_conversation_metadata_by_id(mock_context, "C12345") + with pytest.raises(ToolExecutionError) as e: + await get_conversation_metadata(mock_context, conversation_id="C12345") - assert ( - "whatever_conversation_should_be_present_in_additional_prompt" - in e.additional_prompt_content - ) - - mock_chat_slack_client.conversations_info.assert_called_once_with( - channel="C12345", - include_locale=True, - include_num_members=True, - ) - mock_list_conversations_metadata.assert_called_once_with(mock_context) + assert "C12345" in e.value.message + assert "not found" in e.value.message @pytest.mark.asyncio -async def test_get_conversation_metadata_by_name( - mock_context, mock_list_conversations_metadata, mock_channel_info +async def test_get_conversation_metadata_by_channel_name( + mock_context, + mock_conversation_retrieval_slack_client, + dummy_channel_factory, + random_str_factory, ): - sample_conversation = extract_conversation_metadata(mock_channel_info) - mock_list_conversations_metadata.return_value = { - "conversations": [sample_conversation], - "next_cursor": None, + channel_name = random_str_factory() + channel1 = dummy_channel_factory(is_channel=True, name=f"{channel_name}_another_channel") + channel2 = dummy_channel_factory(is_channel=True, name=channel_name) + + mock_conversation_retrieval_slack_client.conversations_list.return_value = { + "ok": True, + "channels": [channel1, channel2], } - response = await get_channel_metadata_by_name(mock_context, sample_conversation["name"]) + response = await get_conversation_metadata(mock_context, channel_name=channel_name) - assert response == sample_conversation - mock_list_conversations_metadata.assert_called_once_with( - context=mock_context, - conversation_types=[ - ConversationType.PUBLIC_CHANNEL, - ConversationType.PRIVATE_CHANNEL, - ], - next_cursor=None, + assert response == extract_conversation_metadata(channel2) + mock_conversation_retrieval_slack_client.conversations_list.assert_called_once_with( + types=f"{ConversationTypeSlackName.PUBLIC_CHANNEL.value},{ConversationTypeSlackName.PRIVATE_CHANNEL.value}", + exclude_archived=True, + limit=MAX_PAGINATION_SIZE_LIMIT, + cursor=None, ) @pytest.mark.asyncio -async def test_get_channel_metadata_by_name_triggering_pagination( - mock_context, mock_list_conversations_metadata, mock_channel_info +async def test_get_conversation_metadata_by_channel_name_triggering_pagination( + mock_context, + mock_conversation_retrieval_slack_client, + dummy_channel_factory, + random_str_factory, ): - target_channel = extract_conversation_metadata(mock_channel_info) - another_channel = extract_conversation_metadata(mock_channel_info) - another_channel["name"] = "another_channel" + target_channel_name = random_str_factory() + target_channel = dummy_channel_factory(is_channel=True, name=target_channel_name) + another_channel = dummy_channel_factory( + is_channel=True, name=f"{target_channel_name}_another_channel" + ) - mock_list_conversations_metadata.side_effect = [ + mock_conversation_retrieval_slack_client.conversations_list.side_effect = [ { - "conversations": [another_channel], - "next_cursor": "123", + "ok": True, + "channels": [another_channel], + "response_metadata": {"next_cursor": "123"}, }, { - "conversations": [target_channel], - "next_cursor": None, + "ok": True, + "channels": [target_channel], + "response_metadata": {"next_cursor": None}, }, ] - response = await get_channel_metadata_by_name(mock_context, target_channel["name"]) + response = await get_conversation_metadata(mock_context, channel_name=target_channel_name) - assert response == target_channel - assert mock_list_conversations_metadata.call_count == 2 - mock_list_conversations_metadata.assert_has_calls([ + assert response == extract_conversation_metadata(target_channel) + assert mock_conversation_retrieval_slack_client.conversations_list.call_count == 2 + mock_conversation_retrieval_slack_client.conversations_list.assert_has_calls([ call( - context=mock_context, - conversation_types=[ConversationType.PUBLIC_CHANNEL, ConversationType.PRIVATE_CHANNEL], - next_cursor=None, + types=f"{ConversationTypeSlackName.PUBLIC_CHANNEL.value},{ConversationTypeSlackName.PRIVATE_CHANNEL.value}", + exclude_archived=True, + limit=MAX_PAGINATION_SIZE_LIMIT, + cursor=None, ), call( - context=mock_context, - conversation_types=[ConversationType.PUBLIC_CHANNEL, ConversationType.PRIVATE_CHANNEL], - next_cursor="123", + types=f"{ConversationTypeSlackName.PUBLIC_CHANNEL.value},{ConversationTypeSlackName.PRIVATE_CHANNEL.value}", + exclude_archived=True, + limit=MAX_PAGINATION_SIZE_LIMIT, + cursor="123", ), ]) @pytest.mark.asyncio -async def test_get_channel_metadata_by_name_not_found( - mock_context, mock_list_conversations_metadata, mock_channel_info +async def test_get_conversation_metadata_by_channel_name_not_found( + mock_context, + mock_conversation_retrieval_slack_client, + dummy_channel_factory, + random_str_factory, ): - first_channel = extract_conversation_metadata(mock_channel_info) - second_channel = extract_conversation_metadata(mock_channel_info) - second_channel["name"] = "second_channel" + not_found_name = random_str_factory() + channel1 = dummy_channel_factory(is_channel=True, name=f"{not_found_name}_first") + channel2 = dummy_channel_factory(is_channel=True, name=f"{not_found_name}_second") - mock_list_conversations_metadata.side_effect = [ + mock_conversation_retrieval_slack_client.conversations_list.side_effect = [ { - "conversations": [second_channel], - "next_cursor": "123", + "ok": True, + "channels": [channel1], + "response_metadata": {"next_cursor": "123"}, }, { - "conversations": [first_channel], - "next_cursor": None, + "ok": True, + "channels": [channel2], + "response_metadata": {"next_cursor": None}, }, ] - with pytest.raises(RetryableToolError): - await get_channel_metadata_by_name(mock_context, "inexistent_channel") + with pytest.raises(RetryableToolError) as error: + await get_conversation_metadata(mock_context, channel_name=not_found_name) - assert mock_list_conversations_metadata.call_count == 2 - mock_list_conversations_metadata.assert_has_calls([ + assert "not found" in error.value.message + assert not_found_name in error.value.message + assert ( + json.dumps([ + {"id": channel1["id"], "name": channel1["name"]}, + {"id": channel2["id"], "name": channel2["name"]}, + ]) + in error.value.additional_prompt_content + ) + + assert mock_conversation_retrieval_slack_client.conversations_list.call_count == 2 + mock_conversation_retrieval_slack_client.conversations_list.assert_has_calls([ call( - context=mock_context, - conversation_types=[ConversationType.PUBLIC_CHANNEL, ConversationType.PRIVATE_CHANNEL], - next_cursor=None, + types=f"{ConversationTypeSlackName.PUBLIC_CHANNEL.value},{ConversationTypeSlackName.PRIVATE_CHANNEL.value}", + exclude_archived=True, + limit=MAX_PAGINATION_SIZE_LIMIT, + cursor=None, ), call( - context=mock_context, - conversation_types=[ConversationType.PUBLIC_CHANNEL, ConversationType.PRIVATE_CHANNEL], - next_cursor="123", + types=f"{ConversationTypeSlackName.PUBLIC_CHANNEL.value},{ConversationTypeSlackName.PRIVATE_CHANNEL.value}", + exclude_archived=True, + limit=MAX_PAGINATION_SIZE_LIMIT, + cursor="123", ), ]) @pytest.mark.asyncio -@patch("arcade_slack.tools.chat.async_paginate") -@patch("arcade_slack.tools.chat.get_user_info_by_id") -async def test_get_members_from_conversation_id( - mock_get_user_info_by_id, mock_async_paginate, mock_context, mock_chat_slack_client +async def test_get_conversation_metadata_by_username( + mock_context, + mock_chat_slack_client, + mock_user_retrieval_slack_client, + dummy_user_factory, + dummy_channel_factory, ): - member1 = {"id": "U123", "name": "testuser123"} - member1_info = extract_basic_user_info(member1) - member2 = {"id": "U456", "name": "testuser456"} - member2_info = extract_basic_user_info(member2) + current_user = dummy_user_factory(id_="U1", name="current_user") + other_user = dummy_user_factory(id_="U2", name="other_user") + conversation = dummy_channel_factory(is_im=True) - mock_async_paginate.return_value = [member1["id"], member2["id"]], "token123" - mock_get_user_info_by_id.side_effect = [member1_info, member2_info] + mock_chat_slack_client.auth_test.return_value = { + "ok": True, + "user_id": current_user["id"], + } - response = await get_members_in_conversation_by_id( - mock_context, conversation_id="C12345", limit=2 + mock_user_retrieval_slack_client.users_list.return_value = { + "ok": True, + "members": [current_user, other_user], + "response_metadata": {"next_cursor": None}, + } + + mock_chat_slack_client.conversations_open.return_value = { + "ok": True, + "channel": conversation, + } + + response = await get_conversation_metadata(mock_context, usernames=[other_user["name"]]) + + assert response == extract_conversation_metadata(conversation) + + +@pytest.mark.asyncio +async def test_get_dm_conversation_metadata_by_username_not_found( + mock_context, + mock_chat_slack_client, + mock_user_retrieval_slack_client, + dummy_user_factory, + dummy_channel_factory, + random_str_factory, +): + current_user = dummy_user_factory(id_="U1", name="current_user") + other_user = dummy_user_factory(id_="U2", name="other_user") + conversation = dummy_channel_factory(is_im=True) + not_found_user_name = random_str_factory() + + mock_chat_slack_client.auth_test.return_value = { + "ok": True, + "user_id": current_user["id"], + } + + mock_user_retrieval_slack_client.users_list.return_value = { + "ok": True, + "members": [current_user, other_user], + "response_metadata": {"next_cursor": None}, + } + + mock_chat_slack_client.conversations_open.return_value = { + "ok": True, + "channel": conversation, + } + + with pytest.raises(RetryableToolError) as error: + await get_conversation_metadata(mock_context, usernames=[not_found_user_name]) + + assert "not found" in error.value.message + assert not_found_user_name in error.value.message + assert other_user["id"] in error.value.additional_prompt_content + assert other_user["name"] in error.value.additional_prompt_content + + mock_chat_slack_client.conversations_open.assert_not_called() + + +@pytest.mark.asyncio +async def test_get_mpim_conversation_metadata_by_usernames( + mock_context, + mock_chat_slack_client, + mock_user_retrieval_slack_client, + dummy_user_factory, + dummy_channel_factory, +): + current_user = dummy_user_factory(id_="U1", name="current_user") + other_user1 = dummy_user_factory(id_="U2", name="other_user1") + other_user2 = dummy_user_factory(id_="U3", name="other_user2") + other_user3 = dummy_user_factory(id_="U4", name="other_user3") + other_user4 = dummy_user_factory(id_="U5", name="other_user4") + + conversation = dummy_channel_factory(is_mpim=True) + + mock_chat_slack_client.auth_test.return_value = { + "ok": True, + "user_id": current_user["id"], + } + + mock_user_retrieval_slack_client.users_list.side_effect = [ + { + "ok": True, + "members": [current_user, other_user1, other_user3], + "response_metadata": {"next_cursor": "users_list_cursor1"}, + }, + { + "ok": True, + "members": [current_user, other_user2, other_user4], + "response_metadata": {"next_cursor": None}, + }, + ] + + mock_chat_slack_client.conversations_open.return_value = { + "ok": True, + "channel": conversation, + } + + response = await get_conversation_metadata( + mock_context, + usernames=[other_user1["name"], other_user2["name"]], + ) + + assert response == extract_conversation_metadata(conversation) + + mock_chat_slack_client.conversations_open.assert_called_once_with( + users=[current_user["id"], other_user1["id"], other_user2["id"]], + return_im=True, + ) + + +@pytest.mark.asyncio +async def test_get_mpim_conversation_metadata_by_user_ids_and_usernames( + mock_context, + mock_chat_slack_client, + mock_user_retrieval_slack_client, + dummy_user_factory, + dummy_channel_factory, +): + current_user = dummy_user_factory(id_="U1", name="current_user") + other_user1 = dummy_user_factory(id_="U2", name="other_user1") + other_user2 = dummy_user_factory(id_="U3", name="other_user2") + other_user3 = dummy_user_factory(id_="U4", name="other_user3") + other_user4 = dummy_user_factory(id_="U5", name="other_user4") + + conversation = dummy_channel_factory(is_mpim=True) + + mock_chat_slack_client.auth_test.return_value = { + "ok": True, + "user_id": current_user["id"], + } + + mock_user_retrieval_slack_client.users_list.side_effect = [ + { + "ok": True, + "members": [current_user, other_user1, other_user3], + "response_metadata": {"next_cursor": "users_list_cursor1"}, + }, + { + "ok": True, + "members": [current_user, other_user2, other_user4], + "response_metadata": {"next_cursor": None}, + }, + ] + + mock_chat_slack_client.conversations_open.return_value = { + "ok": True, + "channel": conversation, + } + + response = await get_conversation_metadata( + mock_context, + user_ids=[other_user3["id"]], + usernames=[other_user1["name"], other_user2["name"]], + ) + + assert response == extract_conversation_metadata(conversation) + + mock_chat_slack_client.conversations_open.assert_called_once_with( + users=[other_user3["id"], current_user["id"], other_user1["id"], other_user2["id"]], + return_im=True, + ) + + +@pytest.mark.asyncio +async def test_get_mpim_conversation_metadata_by_user_ids_usernames_and_emails( + mock_context, + mock_chat_slack_client, + mock_user_retrieval_slack_client, + dummy_user_factory, + dummy_channel_factory, +): + current_user = dummy_user_factory(id_="U1", name="current_user") + other_user1 = dummy_user_factory(id_="U2", name="other_user1") + other_user2 = dummy_user_factory(id_="U3", name="other_user2") + other_user3 = dummy_user_factory(id_="U4", name="other_user3") + other_user4 = dummy_user_factory(id_="U5", name="other_user4") + other_user5 = dummy_user_factory(id_="U6", name="other_user5") + + conversation = dummy_channel_factory(is_mpim=True) + + mock_chat_slack_client.auth_test.return_value = { + "ok": True, + "user_id": current_user["id"], + } + + mock_user_retrieval_slack_client.users_list.side_effect = [ + { + "ok": True, + "members": [current_user, other_user1, other_user3], + "response_metadata": {"next_cursor": "users_list_cursor1"}, + }, + { + "ok": True, + "members": [current_user, other_user2, other_user4], + "response_metadata": {"next_cursor": None}, + }, + ] + + mock_user_retrieval_slack_client.users_lookupByEmail.side_effect = [ + { + "ok": True, + "user": other_user5, + }, + ] + + mock_chat_slack_client.conversations_open.return_value = { + "ok": True, + "channel": conversation, + } + + response = await get_conversation_metadata( + mock_context, + user_ids=[other_user3["id"]], + usernames=[other_user1["name"], other_user2["name"]], + emails=[other_user5["profile"]["email"]], + ) + + assert response == extract_conversation_metadata(conversation) + + mock_chat_slack_client.conversations_open.assert_called_once_with( + users=[ + other_user3["id"], + current_user["id"], + other_user1["id"], + other_user2["id"], + other_user5["id"], + ], + return_im=True, + ) + + +@pytest.mark.asyncio +async def test_get_users_in_conversation_by_id_with_conversation_and_user_paginations( + mock_context, + mock_chat_slack_client, + mock_user_retrieval_slack_client, + dummy_user_factory, + random_str_factory, +): + conversation_id = random_str_factory() + user1 = dummy_user_factory(id_="1") + user2 = dummy_user_factory(id_="2") + user3 = dummy_user_factory(id_="3") + + mock_chat_slack_client.conversations_members.side_effect = [ + { + "ok": True, + "members": [user1["id"], user2["id"]], + "response_metadata": {"next_cursor": "conversations_members_cursor1"}, + }, + { + "ok": True, + "members": [user3["id"]], + "response_metadata": {"next_cursor": "conversations_members_cursor2"}, + }, + ] + + mock_user_retrieval_slack_client.users_list.side_effect = [ + { + "ok": True, + "members": [user1, user2], + "response_metadata": {"next_cursor": "users_list_cursor"}, + }, + { + "ok": True, + "members": [user3], + "response_metadata": {"next_cursor": None}, + }, + ] + + response = await get_users_in_conversation( + context=mock_context, + conversation_id=conversation_id, + limit=3, ) assert response == { - "members": [member1_info, member2_info], - "next_cursor": "token123", + "users": [ + cast_user_dict(user1), + cast_user_dict(user2), + cast_user_dict(user3), + ], + "next_cursor": "conversations_members_cursor2", } - mock_async_paginate.assert_called_once_with( - mock_chat_slack_client.conversations_members, - "members", - limit=2, - next_cursor=None, - channel="C12345", - ) - mock_get_user_info_by_id.assert_has_calls([ - call(mock_context, member1["id"]), - call(mock_context, member2["id"]), + + mock_chat_slack_client.conversations_members.assert_has_calls([ + call( + channel=conversation_id, + limit=3, + cursor=None, + ), + call( + channel=conversation_id, + limit=1, + cursor="conversations_members_cursor1", + ), + ]) + + mock_user_retrieval_slack_client.users_list.assert_has_calls([ + call( + limit=MAX_PAGINATION_SIZE_LIMIT, + cursor=None, + ), + call( + limit=MAX_PAGINATION_SIZE_LIMIT, + cursor="users_list_cursor", + ), ]) @pytest.mark.asyncio -@patch("arcade_slack.tools.chat.async_paginate") -@patch("arcade_slack.tools.chat.get_user_info_by_id") -@patch("arcade_slack.tools.chat.list_conversations_metadata") -async def test_get_members_from_conversation_id_channel_not_found( - mock_list_conversations_metadata, - mock_get_user_info_by_id, - mock_async_paginate, +async def test_get_users_in_conversation_by_channel_name( mock_context, mock_chat_slack_client, - mock_channel_info, + mock_conversation_retrieval_slack_client, + mock_user_retrieval_slack_client, + dummy_channel_factory, + dummy_user_factory, + random_str_factory, ): - conversations = [extract_conversation_metadata(mock_channel_info)] * 2 - mock_list_conversations_metadata.return_value = { - "conversations": conversations, - "next_cursor": None, - } + channel_name = random_str_factory() + channel1 = dummy_channel_factory(is_channel=True, name=f"{channel_name}_another_channel") + channel2 = dummy_channel_factory(is_channel=True, name=channel_name) - member1 = {"id": "U123", "name": "testuser123"} - member1_info = extract_basic_user_info(member1) - member2 = {"id": "U456", "name": "testuser456"} - member2_info = extract_basic_user_info(member2) - - mock_async_paginate.side_effect = SlackApiError( - message="channel_not_found", - response={"ok": False, "error": "channel_not_found"}, - ) - mock_get_user_info_by_id.side_effect = [member1_info, member2_info] - - with pytest.raises(RetryableToolError): - await get_members_in_conversation_by_id(mock_context, conversation_id="C12345", limit=2) - - mock_async_paginate.assert_called_once_with( - mock_chat_slack_client.conversations_members, - "members", - limit=2, - next_cursor=None, - channel="C12345", - ) - mock_get_user_info_by_id.assert_not_called() - - -@pytest.mark.asyncio -@patch("arcade_slack.tools.chat.list_conversations_metadata") -@patch("arcade_slack.tools.chat.get_members_in_conversation_by_id") -async def test_get_members_in_channel_by_name( - mock_get_members_in_conversation_by_id, - mock_list_conversations_metadata, - mock_context, - mock_channel_info, -): - mock_list_conversations_metadata.return_value = { - "conversations": [extract_conversation_metadata(mock_channel_info)], - "next_cursor": None, - } - - response = await get_members_in_channel_by_name( - mock_context, mock_channel_info["name"], limit=2 - ) - - assert response == mock_get_members_in_conversation_by_id.return_value - mock_list_conversations_metadata.assert_called_once_with( - context=mock_context, - conversation_types=[ - ConversationType.PUBLIC_CHANNEL, - ConversationType.PRIVATE_CHANNEL, - ], - next_cursor=None, - ) - mock_get_members_in_conversation_by_id.assert_called_once_with( - context=mock_context, - conversation_id="C12345", - limit=2, - next_cursor=None, - ) - - -@pytest.mark.asyncio -@patch("arcade_slack.tools.chat.list_conversations_metadata") -@patch("arcade_slack.tools.chat.get_members_in_conversation_by_id") -async def test_get_members_in_channel_by_name_triggering_pagination( - mock_get_members_in_conversation_by_id, - mock_list_conversations_metadata, - mock_context, - mock_channel_info, -): - conversation1 = copy.deepcopy(mock_channel_info) - conversation1["name"] = "conversation1" - conversation2 = copy.deepcopy(mock_channel_info) - conversation2["name"] = "conversation2" - - mock_list_conversations_metadata.side_effect = [ + mock_conversation_retrieval_slack_client.conversations_list.side_effect = [ { - "conversations": [extract_conversation_metadata(conversation1)], - "next_cursor": "123", + "ok": True, + "channels": [channel1], + "response_metadata": {"next_cursor": "123"}, }, { - "conversations": [extract_conversation_metadata(conversation2)], - "next_cursor": None, + "ok": True, + "channels": [channel2], + "response_metadata": {"next_cursor": None}, }, ] - response = await get_members_in_channel_by_name(mock_context, conversation2["name"], limit=2) + user1 = dummy_user_factory(id_="1") + user2 = dummy_user_factory(id_="2") - assert response == mock_get_members_in_conversation_by_id.return_value - mock_list_conversations_metadata.assert_has_calls([ - call( - context=mock_context, - conversation_types=[ConversationType.PUBLIC_CHANNEL, ConversationType.PRIVATE_CHANNEL], - next_cursor=None, - ), - call( - context=mock_context, - conversation_types=[ConversationType.PUBLIC_CHANNEL, ConversationType.PRIVATE_CHANNEL], - next_cursor="123", - ), - ]) - mock_get_members_in_conversation_by_id.assert_called_once_with( - context=mock_context, - conversation_id="C12345", - limit=2, - next_cursor=None, - ) + mock_chat_slack_client.conversations_members.side_effect = [ + { + "ok": True, + "members": [user1["id"], user2["id"]], + "response_metadata": {"next_cursor": None}, + }, + ] + + mock_user_retrieval_slack_client.users_list.side_effect = [ + { + "ok": True, + "members": [user1, user2], + "response_metadata": {"next_cursor": None}, + }, + ] + + response = await get_users_in_conversation(mock_context, channel_name=channel_name) + + assert response == { + "users": [ + cast_user_dict(user1), + cast_user_dict(user2), + ], + "next_cursor": None, + } @pytest.mark.asyncio -async def test_get_conversation_history_by_id(mock_context, mock_chat_slack_client): - mock_chat_slack_client.conversations_history.return_value = { +async def test_get_users_in_conversation_by_channel_name_not_found( + mock_context, + mock_conversation_retrieval_slack_client, + dummy_channel_factory, + random_str_factory, +): + not_found_channel_name = random_str_factory() + channel1 = dummy_channel_factory(is_channel=True, name=f"{not_found_channel_name}_first") + channel2 = dummy_channel_factory(is_channel=True, name=f"{not_found_channel_name}_second") + + mock_conversation_retrieval_slack_client.conversations_list.side_effect = [ + { + "ok": True, + "channels": [channel1], + "response_metadata": {"next_cursor": "123"}, + }, + { + "ok": True, + "channels": [channel2], + "response_metadata": {"next_cursor": None}, + }, + ] + + with pytest.raises(RetryableToolError) as error: + await get_users_in_conversation(mock_context, channel_name=not_found_channel_name) + + assert "not found" in error.value.message + assert not_found_channel_name in error.value.message + assert ( + json.dumps([ + {"id": channel1["id"], "name": channel1["name"]}, + {"id": channel2["id"], "name": channel2["name"]}, + ]) + in error.value.additional_prompt_content + ) + + assert mock_conversation_retrieval_slack_client.conversations_list.call_count == 2 + mock_conversation_retrieval_slack_client.conversations_list.assert_has_calls([ + call( + types=f"{ConversationTypeSlackName.PUBLIC_CHANNEL.value},{ConversationTypeSlackName.PRIVATE_CHANNEL.value}", + exclude_archived=True, + limit=MAX_PAGINATION_SIZE_LIMIT, + cursor=None, + ), + call( + types=f"{ConversationTypeSlackName.PUBLIC_CHANNEL.value},{ConversationTypeSlackName.PRIVATE_CHANNEL.value}", + exclude_archived=True, + limit=MAX_PAGINATION_SIZE_LIMIT, + cursor="123", + ), + ]) + + +@pytest.mark.asyncio +async def test_get_messages_by_conversation_id( + mock_context, + mock_message_retrieval_slack_client, + mock_user_retrieval_slack_client, + dummy_user_factory, + dummy_message_factory, +): + user = dummy_user_factory() + message = dummy_message_factory(user_id=user["id"]) + + mock_message_retrieval_slack_client.conversations_history.return_value = { "ok": True, - "messages": [{"text": "Hello, world!"}], + "messages": [message], + "response_metadata": {"next_cursor": "cursor"}, } - response = await get_messages_in_conversation_by_id(mock_context, "C12345", limit=1) + mock_user_retrieval_slack_client.users_info.return_value = { + "ok": True, + "user": user, + } - assert response == {"messages": [{"text": "Hello, world!"}], "next_cursor": None} - mock_chat_slack_client.conversations_history.assert_called_once_with( + response = await get_messages(mock_context, "C12345", limit=1) + + assert response["next_cursor"] == "cursor" + assert len(response["messages"]) == 1 + returned_message = response["messages"][0] + assert returned_message["user"] == {"id": user["id"], "name": user["name"]} + assert returned_message["text"] == message["text"] + + mock_message_retrieval_slack_client.conversations_history.assert_called_once_with( channel="C12345", include_all_metadata=True, inclusive=True, @@ -585,28 +960,34 @@ async def test_get_conversation_history_by_id(mock_context, mock_chat_slack_clie ) -# TODO: pass a current unix timestamp to the tool, instead of mocking the datetime -# conversion. Have to wait until arcade.core.annotations.Inferrable is implemented. @pytest.mark.asyncio -@patch("arcade_slack.tools.chat.convert_relative_datetime_to_unix_timestamp") -@patch("arcade_slack.tools.chat.datetime") -async def test_get_conversation_history_by_id_with_relative_datetime_args( +@patch("arcade_slack.message_retrieval.convert_relative_datetime_to_unix_timestamp") +@patch("arcade_slack.message_retrieval.datetime") +async def test_get_messages_by_conversation_id_with_relative_datetime_args( mock_datetime, mock_convert_relative_datetime_to_unix_timestamp, mock_context, - mock_chat_slack_client, + mock_message_retrieval_slack_client, + mock_user_retrieval_slack_client, + dummy_user_factory, + dummy_message_factory, ): - mock_chat_slack_client.conversations_history.return_value = { + user = dummy_user_factory() + message = dummy_message_factory(user_id=user["id"]) + + mock_message_retrieval_slack_client.conversations_history.return_value = { "ok": True, - "messages": [{"text": "Hello, world!"}], + "messages": [message], + } + + mock_user_retrieval_slack_client.users_info.return_value = { + "ok": True, + "user": user, } expected_oldest_timestamp = 1716489600 expected_latest_timestamp = 1716403200 - # Ideally we'd pass the current unix timestamp to the function, instead of mocking, but - # currently there's no way to have a tool argument that is not exposed to the LLM. We - # should have that soon, though. mock_datetime.now.return_value = datetime(2025, 1, 1, 0, 0, 0, tzinfo=timezone.utc) expected_current_unix_timestamp = int(mock_datetime.now.return_value.timestamp()) mock_convert_relative_datetime_to_unix_timestamp.side_effect = [ @@ -614,16 +995,25 @@ async def test_get_conversation_history_by_id_with_relative_datetime_args( expected_oldest_timestamp, ] - response = await get_messages_in_conversation_by_id( - mock_context, "C12345", oldest_relative="02:00:00", latest_relative="01:00:00", limit=1 + response = await get_messages( + context=mock_context, + conversation_id="C12345", + oldest_relative="02:00:00", + latest_relative="01:00:00", + limit=1, ) - assert response == {"messages": [{"text": "Hello, world!"}], "next_cursor": None} + assert response["next_cursor"] is None + assert len(response["messages"]) == 1 + returned_message = response["messages"][0] + assert returned_message["user"] == {"id": user["id"], "name": user["name"]} + assert returned_message["text"] == message["text"] + mock_convert_relative_datetime_to_unix_timestamp.assert_has_calls([ call("01:00:00", expected_current_unix_timestamp), call("02:00:00", expected_current_unix_timestamp), ]) - mock_chat_slack_client.conversations_history.assert_called_once_with( + mock_message_retrieval_slack_client.conversations_history.assert_called_once_with( channel="C12345", include_all_metadata=True, inclusive=True, @@ -634,43 +1024,56 @@ async def test_get_conversation_history_by_id_with_relative_datetime_args( ) -# TODO: pass a current unix timestamp to the tool, instead of mocking the datetime -# conversion. Have to wait until arcade.core.annotations.Inferrable is implemented. @pytest.mark.asyncio -@patch("arcade_slack.tools.chat.convert_datetime_to_unix_timestamp") -async def test_get_conversation_history_by_id_with_absolute_datetime_args( - mock_convert_datetime_to_unix_timestamp, mock_context, mock_chat_slack_client +@patch("arcade_slack.message_retrieval.convert_datetime_to_unix_timestamp") +async def test_get_messages_by_conversation_id_with_absolute_datetime_args( + mock_convert_datetime_to_unix_timestamp, + mock_context, + mock_message_retrieval_slack_client, + mock_user_retrieval_slack_client, + dummy_user_factory, + dummy_message_factory, ): - mock_chat_slack_client.conversations_history.return_value = { + user = dummy_user_factory() + message = dummy_message_factory(user_id=user["id"]) + + mock_message_retrieval_slack_client.conversations_history.return_value = { "ok": True, - "messages": [{"text": "Hello, world!"}], + "messages": [message], + } + + mock_user_retrieval_slack_client.users_info.return_value = { + "ok": True, + "user": user, } expected_latest_timestamp = 1716403200 expected_oldest_timestamp = 1716489600 - # Ideally we'd pass the current unix timestamp to the function, instead of mocking, but - # currently there's no way to have a tool argument that is not exposed to the LLM. We - # should have that soon, though. mock_convert_datetime_to_unix_timestamp.side_effect = [ expected_latest_timestamp, expected_oldest_timestamp, ] - response = await get_messages_in_conversation_by_id( - mock_context, - "C12345", + response = await get_messages( + context=mock_context, + conversation_id="C12345", oldest_datetime="2025-01-01 00:00:00", latest_datetime="2025-01-02 00:00:00", limit=1, ) - assert response == {"messages": [{"text": "Hello, world!"}], "next_cursor": None} + assert response["next_cursor"] is None + assert len(response["messages"]) == 1 + returned_message = response["messages"][0] + assert returned_message["user"] == {"id": user["id"], "name": user["name"]} + assert returned_message["text"] == message["text"] + mock_convert_datetime_to_unix_timestamp.assert_has_calls([ call("2025-01-02 00:00:00"), call("2025-01-01 00:00:00"), ]) - mock_chat_slack_client.conversations_history.assert_called_once_with( + mock_message_retrieval_slack_client.conversations_history.assert_called_once_with( channel="C12345", include_all_metadata=True, inclusive=True, @@ -682,34 +1085,45 @@ async def test_get_conversation_history_by_id_with_absolute_datetime_args( @pytest.mark.asyncio -async def test_get_conversation_history_by_id_with_messed_oldest_args( - mock_context, mock_chat_slack_client +async def test_get_messages_by_conversation_id_with_messed_oldest_args( + mock_context, mock_message_retrieval_slack_client ): with pytest.raises(ToolExecutionError): - await get_messages_in_conversation_by_id( - mock_context, - "C12345", + await get_messages( + context=mock_context, + conversation_id="C12345", oldest_datetime="2025-01-01 00:00:00", oldest_relative="01:00:00", ) + mock_message_retrieval_slack_client.conversations_history.assert_not_called() + @pytest.mark.asyncio -async def test_get_conversation_history_by_id_with_messed_latest_args( - mock_context, mock_chat_slack_client +async def test_get_messages_by_conversation_id_with_messed_latest_args( + mock_context, mock_message_retrieval_slack_client ): with pytest.raises(ToolExecutionError): - await get_messages_in_conversation_by_id( - mock_context, - "C12345", + await get_messages( + context=mock_context, + conversation_id="C12345", latest_datetime="2025-01-01 00:00:00", latest_relative="01:00:00", ) + mock_message_retrieval_slack_client.conversations_history.assert_not_called() + @pytest.mark.asyncio -async def test_get_conversation_history_by_name(mock_context, mock_chat_slack_client): - mock_chat_slack_client.conversations_list.return_value = { +async def test_get_messages_by_channel_name( + mock_context, + mock_message_retrieval_slack_client, + mock_conversation_retrieval_slack_client, + mock_user_retrieval_slack_client, + dummy_message_factory, + dummy_user_factory, +): + mock_conversation_retrieval_slack_client.conversations_list.return_value = { "ok": True, "channels": [ { @@ -720,273 +1134,35 @@ async def test_get_conversation_history_by_name(mock_context, mock_chat_slack_cl } ], } - mock_chat_slack_client.conversations_history.return_value = { + + user = dummy_user_factory() + message = dummy_message_factory(user_id=user["id"]) + mock_message_retrieval_slack_client.conversations_history.return_value = { "ok": True, - "messages": [{"text": "Hello, world!"}], + "messages": [message], } - response = await get_messages_in_channel_by_name(mock_context, "general", limit=1) - - assert response == {"messages": [{"text": "Hello, world!"}], "next_cursor": None} - mock_chat_slack_client.conversations_history.assert_called_once_with( - channel="C12345", include_all_metadata=True, inclusive=True, limit=1, cursor=None - ) - - -@pytest.mark.asyncio -@patch("arcade_slack.tools.chat.retrieve_conversations_by_user_ids") -async def test_get_direct_message_conversation_metadata_by_username( - mock_retrieve_conversations_by_user_ids, - mock_context, - mock_chat_slack_client, - mock_users_slack_client, -): - mock_chat_slack_client.auth_test.return_value = { + mock_user_retrieval_slack_client.users_info.return_value = { "ok": True, - "user_id": "U1", - "team_id": "T1", - "user": "user1", + "user": user, } - mock_users_slack_client.users_list.return_value = { - "ok": True, - "members": [ - {"id": "U1", "name": "user1"}, - {"id": "U2", "name": "user2"}, - ], - "response_metadata": {"next_cursor": None}, - } - - conversation = { - "id": "C12345", - "type": ConversationTypeSlackName.IM.value, - "is_im": True, - "members": ["U1", "U2"], - } - - mock_retrieve_conversations_by_user_ids.return_value = [conversation] - - response = await get_direct_message_conversation_metadata_by_username( - context=mock_context, username="user2" - ) - - assert response == conversation - mock_retrieve_conversations_by_user_ids.assert_called_once_with( - list_conversations_func=list_conversations_metadata, - get_members_in_conversation_func=get_members_in_conversation_by_id, + response = await get_messages( context=mock_context, - conversation_types=[ConversationType.DIRECT_MESSAGE], - user_ids=["U1", "U2"], - exact_match=True, + channel_name="general", limit=1, - next_cursor=None, ) + assert response["next_cursor"] is None + assert len(response["messages"]) == 1 + returned_message = response["messages"][0] + assert returned_message["user"] == {"id": user["id"], "name": user["name"]} + assert returned_message["text"] == message["text"] -@pytest.mark.asyncio -@patch("arcade_slack.tools.chat.retrieve_conversations_by_user_ids") -async def test_get_direct_message_conversation_metadata_by_username_username_not_found( - mock_retrieve_conversations_by_user_ids, - mock_context, - mock_chat_slack_client, - mock_users_slack_client, -): - mock_chat_slack_client.users_identity.return_value = { - "ok": True, - "user": {"id": "U1", "name": "user1"}, - "team": {"id": "T1", "name": "team1"}, - } - - mock_users_slack_client.users_list.return_value = { - "ok": True, - "members": [ - {"id": "U1", "name": "user1"}, - {"id": "U2", "name": "user2"}, - ], - "response_metadata": {"next_cursor": None}, - } - - mock_retrieve_conversations_by_user_ids.side_effect = TimeoutError() - - with pytest.raises(RetryableToolError): - await get_direct_message_conversation_metadata_by_username( - context=mock_context, username="user999" - ) - - -@pytest.mark.asyncio -@patch("arcade_slack.tools.chat.get_messages_in_conversation_by_id") -@patch("arcade_slack.tools.chat.get_direct_message_conversation_metadata_by_username") -async def test_get_messages_in_direct_conversation_by_username( - mock_get_direct_message_conversation_metadata_by_username, - mock_get_messages_in_conversation_by_id, - mock_context, -): - mock_get_direct_message_conversation_metadata_by_username.return_value = { - "id": "C12345", - } - - response = await get_messages_in_direct_message_conversation_by_username( - context=mock_context, username="user2" - ) - - assert response == mock_get_messages_in_conversation_by_id.return_value - mock_get_direct_message_conversation_metadata_by_username.assert_called_once_with( - context=mock_context, username="user2" - ) - mock_get_messages_in_conversation_by_id.assert_called_once_with( - context=mock_context, - conversation_id="C12345", - oldest_relative=None, - latest_relative=None, - oldest_datetime=None, - latest_datetime=None, - limit=None, - next_cursor=None, - ) - - -@pytest.mark.asyncio -@patch("arcade_slack.tools.chat.get_direct_message_conversation_metadata_by_username") -async def test_get_messages_in_direct_conversation_by_username_not_found( - mock_get_direct_message_conversation_metadata_by_username, - mock_context, -): - mock_get_direct_message_conversation_metadata_by_username.return_value = None - - with pytest.raises(ToolExecutionError): - await get_messages_in_direct_message_conversation_by_username( - context=mock_context, username="user2" - ) - - -@pytest.mark.asyncio -@patch("arcade_slack.tools.chat.retrieve_conversations_by_user_ids") -async def test_get_multi_person_direct_message_conversation_metadata_by_username( - mock_retrieve_conversations_by_user_ids, - mock_context, - mock_chat_slack_client, - mock_users_slack_client, -): - mock_chat_slack_client.auth_test.return_value = { - "ok": True, - "user_id": "U1", - "team_id": "T1", - "user": "user1", - } - - mock_users_slack_client.users_list.return_value = { - "ok": True, - "members": [ - {"id": "U1", "name": "user1"}, - {"id": "U2", "name": "user2"}, - {"id": "U3", "name": "user3"}, - {"id": "U4", "name": "user4"}, - {"id": "U5", "name": "user5"}, - ], - "response_metadata": {"next_cursor": None}, - } - - conversation = { - "id": "C12345", - "type": ConversationTypeSlackName.MPIM.value, - "is_mpim": True, - "members": ["U1", "U4", "U5"], - } - - mock_retrieve_conversations_by_user_ids.return_value = [conversation] - - response = await get_multi_person_dm_conversation_metadata_by_usernames( - context=mock_context, usernames=["user1", "user4", "user5"] - ) - - assert response == conversation - mock_retrieve_conversations_by_user_ids.assert_called_once_with( - list_conversations_func=list_conversations_metadata, - get_members_in_conversation_func=get_members_in_conversation_by_id, - context=mock_context, - conversation_types=[ConversationType.MULTI_PERSON_DIRECT_MESSAGE], - user_ids=["U1", "U4", "U5"], - exact_match=True, + mock_message_retrieval_slack_client.conversations_history.assert_called_once_with( + channel="C12345", + include_all_metadata=True, + inclusive=True, limit=1, - next_cursor=None, + cursor=None, ) - - -@pytest.mark.asyncio -@patch("arcade_slack.tools.chat.retrieve_conversations_by_user_ids") -async def test_get_multi_person_direct_message_conversation_metadata_by_username_username_not_found( - mock_retrieve_conversations_by_user_ids, - mock_context, - mock_chat_slack_client, - mock_users_slack_client, -): - mock_chat_slack_client.users_identity.return_value = { - "ok": True, - "user": {"id": "U1", "name": "user1"}, - "team": {"id": "T1", "name": "team1"}, - } - - mock_users_slack_client.users_list.return_value = { - "ok": True, - "members": [ - {"id": "U1", "name": "user1"}, - {"id": "U2", "name": "user2"}, - ], - "response_metadata": {"next_cursor": None}, - } - - mock_retrieve_conversations_by_user_ids.side_effect = TimeoutError() - - with pytest.raises(RetryableToolError): - await get_multi_person_dm_conversation_metadata_by_usernames( - context=mock_context, usernames=["user999", "user1", "user2"] - ) - - -@pytest.mark.asyncio -@patch("arcade_slack.tools.chat.get_messages_in_conversation_by_id") -@patch("arcade_slack.tools.chat.get_multi_person_dm_conversation_metadata_by_usernames") -async def test_get_messages_in_multi_person_dm_conversation_by_usernames( - mock_get_multi_person_dm_conversation_metadata_by_usernames, - mock_get_messages_in_conversation_by_id, - mock_context, -): - mock_get_multi_person_dm_conversation_metadata_by_usernames.return_value = { - "id": "C12345", - } - - response = await get_messages_in_multi_person_dm_conversation_by_usernames( - context=mock_context, usernames=["user1", "user4", "user5"] - ) - - assert response == mock_get_messages_in_conversation_by_id.return_value - - mock_get_multi_person_dm_conversation_metadata_by_usernames.assert_called_once_with( - context=mock_context, usernames=["user1", "user4", "user5"] - ) - - mock_get_messages_in_conversation_by_id.assert_called_once_with( - context=mock_context, - conversation_id="C12345", - oldest_relative=None, - latest_relative=None, - oldest_datetime=None, - latest_datetime=None, - limit=None, - next_cursor=None, - ) - - -@pytest.mark.asyncio -@patch("arcade_slack.tools.chat.get_multi_person_dm_conversation_metadata_by_usernames") -async def test_get_messages_in_multi_person_dm_conversation_by_usernames_not_found( - mock_get_multi_person_dm_conversation_metadata_by_usernames, - mock_context, -): - mock_get_multi_person_dm_conversation_metadata_by_usernames.return_value = None - - with pytest.raises(ToolExecutionError): - await get_messages_in_direct_message_conversation_by_username( - context=mock_context, username="user2" - ) diff --git a/toolkits/slack/tests/test_models.py b/toolkits/slack/tests/test_models.py new file mode 100644 index 00000000..208473a0 --- /dev/null +++ b/toolkits/slack/tests/test_models.py @@ -0,0 +1,107 @@ +import asyncio +from unittest.mock import AsyncMock + +import pytest +from arcade_tdk.errors import ToolExecutionError +from slack_sdk.errors import SlackApiError + +from arcade_slack.models import ( + ConcurrencySafeCoroutineCaller, + FindChannelByNameSentinel, + FindMultipleUsersByUsernameSentinel, + FindUserByUsernameSentinel, + GetUserByEmailCaller, +) + + +def test_find_user_by_username_sentinel(): + sentinel = FindUserByUsernameSentinel(username="jenifer") + assert sentinel(last_result=[{"name": "jack"}]) is False + assert sentinel(last_result=[{"name": "john"}, {"name": "jack"}]) is False + assert sentinel(last_result=[{"name": "hello"}, {"name": "jenifer"}]) is True + assert sentinel(last_result=[{"name": "JENIFER"}]) is True + + +def test_find_multiple_users_by_username_sentinel(): + sentinel = FindMultipleUsersByUsernameSentinel(usernames=["jenifer", "jack"]) + assert sentinel(last_result=[{"name": "jack"}]) is False + assert sentinel(last_result=[{"name": "john"}, {"name": "jack"}]) is False + assert sentinel(last_result=[{"name": "hello"}, {"name": "JENIFER"}]) is True + assert sentinel(last_result=[{"name": "world"}]) is True + + +def test_find_channel_by_name_sentinel(): + sentinel = FindChannelByNameSentinel(channel_name="foobar") + assert sentinel(last_result=[{"name": "foo"}]) is False + assert sentinel(last_result=[{"name": "foo"}, {"name": "bar"}]) is False + assert sentinel(last_result=[{"name": "foo"}, {"name": "foobar"}]) is True + assert sentinel(last_result=[{"name": "FOObar"}]) is True + + +@pytest.mark.asyncio +async def test_concurrency_safe_coroutine_caller(): + mock_func = AsyncMock() + mock_semaphore = AsyncMock(spec=asyncio.Semaphore) + + caller = ConcurrencySafeCoroutineCaller(mock_func, "arg1", "arg2", kwarg1="kwarg1") + response = await caller(semaphore=mock_semaphore) + + assert response == mock_func.return_value + mock_func.assert_awaited_once_with("arg1", "arg2", kwarg1="kwarg1") + mock_semaphore.__aenter__.assert_awaited_once() + mock_semaphore.__aexit__.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_user_by_email_caller_success(): + mock_func = AsyncMock() + mock_func.return_value = {"user": {"id": "U1234567890", "name": "John Doe"}} + mock_semaphore = AsyncMock(spec=asyncio.Semaphore) + + caller = GetUserByEmailCaller(mock_func, "test@example.com") + response = await caller(semaphore=mock_semaphore) + + assert response == { + "user": {"id": "U1234567890", "name": "John Doe"}, + "email": "test@example.com", + } + mock_func.assert_awaited_once_with(email="test@example.com") + mock_semaphore.__aenter__.assert_awaited_once() + mock_semaphore.__aexit__.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_user_by_email_caller_not_found_error(): + mock_func = AsyncMock( + side_effect=SlackApiError(message="User not found", response={"error": "user_not_found"}) + ) + mock_semaphore = AsyncMock(spec=asyncio.Semaphore) + + caller = GetUserByEmailCaller(mock_func, "test@example.com") + response = await caller(semaphore=mock_semaphore) + + assert response == { + "user": None, + "email": "test@example.com", + } + mock_func.assert_awaited_once_with(email="test@example.com") + mock_semaphore.__aenter__.assert_awaited_once() + mock_semaphore.__aexit__.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_get_user_by_email_caller_unknown_slack_api_error(): + mock_func = AsyncMock( + side_effect=SlackApiError(message="Unknown error", response={"error": "unknown_error"}) + ) + mock_semaphore = AsyncMock(spec=asyncio.Semaphore) + + caller = GetUserByEmailCaller(mock_func, "test@example.com") + with pytest.raises(ToolExecutionError) as error: + await caller(semaphore=mock_semaphore) + + assert error.value.message == "Error getting user by email" + assert error.value.developer_message == "Error getting user by email: unknown_error" + mock_func.assert_awaited_once_with(email="test@example.com") + mock_semaphore.__aenter__.assert_awaited_once() + mock_semaphore.__aexit__.assert_awaited_once() diff --git a/toolkits/slack/tests/test_user_retrieval.py b/toolkits/slack/tests/test_user_retrieval.py new file mode 100644 index 00000000..c584dd4b --- /dev/null +++ b/toolkits/slack/tests/test_user_retrieval.py @@ -0,0 +1,313 @@ +import json + +import pytest +from arcade_tdk.errors import RetryableToolError, ToolExecutionError +from slack_sdk.errors import SlackApiError + +from arcade_slack.user_retrieval import ( + get_single_user_by_id, + get_users_by_id, + get_users_by_id_username_or_email, +) +from arcade_slack.utils import ( + cast_user_dict, + extract_basic_user_info, + short_user_info, +) + + +@pytest.mark.asyncio +async def test_get_multiple_users_by_emails_success( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user1 = dummy_user_factory() + user2 = dummy_user_factory() + + emails = [user1["profile"]["email"], user2["profile"]["email"]] + + mock_user_retrieval_slack_client.users_lookupByEmail.side_effect = [ + {"ok": True, "user": user1}, + {"ok": True, "user": user2}, + ] + + response = await get_users_by_id_username_or_email(context=mock_context, emails=emails) + + assert response == [extract_basic_user_info(user1), extract_basic_user_info(user2)] + + +@pytest.mark.asyncio +async def test_get_multiple_users_by_usernames_or_emails_with_emails_not_found( + mock_context, + mock_user_retrieval_slack_client, + mock_users_slack_client, + dummy_user_factory, +): + user1 = dummy_user_factory(email="user1@example.com") + + emails = [user1["profile"]["email"], "not_found@example.com"] + + async def lookup_by_email_side_effect(*, email): + if email == "user1@example.com": + return {"ok": True, "user": user1} + raise SlackApiError( + message="User not found", + response={"error": "user_not_found"}, + ) + + mock_user_retrieval_slack_client.users_lookupByEmail.side_effect = lookup_by_email_side_effect + + mock_users_slack_client.users_list.return_value = { + "ok": True, + "members": [user1], + } + + with pytest.raises(RetryableToolError) as error: + await get_users_by_id_username_or_email(context=mock_context, emails=emails) + + assert "not_found@example.com" in error.value.message + assert json.dumps(short_user_info(user1)) in error.value.additional_prompt_content + + +@pytest.mark.asyncio +async def test_get_multiple_users_by_usernames_or_emails_with_usernames_success( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user1 = dummy_user_factory() + user2 = dummy_user_factory() + + usernames = [user1["name"], user2["name"]] + + mock_user_retrieval_slack_client.users_list.return_value = { + "ok": True, + "members": [user1, user2], + } + + response = await get_users_by_id_username_or_email(context=mock_context, usernames=usernames) + + assert response == [extract_basic_user_info(user1), extract_basic_user_info(user2)] + + +@pytest.mark.asyncio +async def test_get_multiple_users_by_usernames_or_emails_with_usernames_not_found( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user1 = dummy_user_factory() + user2 = dummy_user_factory() + user3 = dummy_user_factory() + + usernames = [user1["name"], "username_not_found"] + + mock_user_retrieval_slack_client.users_list.return_value = { + "ok": True, + "members": [user1, user2, user3], + } + + with pytest.raises(RetryableToolError) as error: + await get_users_by_id_username_or_email(context=mock_context, usernames=usernames) + + assert "username_not_found" in error.value.message + assert json.dumps(short_user_info(user1)) in error.value.additional_prompt_content + assert json.dumps(short_user_info(user2)) in error.value.additional_prompt_content + assert json.dumps(short_user_info(user3)) in error.value.additional_prompt_content + + +@pytest.mark.asyncio +async def test_get_multiple_users_by_mixed_usernames_and_emails_success( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user1 = dummy_user_factory() + user2 = dummy_user_factory() + user3 = dummy_user_factory() + user4 = dummy_user_factory() + + mock_user_retrieval_slack_client.users_list.return_value = { + "ok": True, + "members": [user1, user2], + } + mock_user_retrieval_slack_client.users_lookupByEmail.side_effect = [ + {"ok": True, "user": user3}, + {"ok": True, "user": user4}, + ] + + response = await get_users_by_id_username_or_email( + context=mock_context, + usernames=[user1["name"], user2["name"]], + emails=[user3["profile"]["email"], user4["profile"]["email"]], + ) + + assert response == [ + extract_basic_user_info(user1), + extract_basic_user_info(user2), + extract_basic_user_info(user3), + extract_basic_user_info(user4), + ] + + +@pytest.mark.asyncio +async def test_get_single_user_by_id_success( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user = dummy_user_factory() + + mock_user_retrieval_slack_client.users_info.return_value = {"ok": True, "user": user} + + response = await get_single_user_by_id( + auth_token=mock_context.get_auth_token_or_empty(), + user_id=user["id"], + ) + + assert response == cast_user_dict(user) + + +@pytest.mark.asyncio +async def test_get_single_user_by_id_not_found( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user = dummy_user_factory() + + mock_user_retrieval_slack_client.users_info.side_effect = SlackApiError( + message="User not found", + response={"error": "user_not_found"}, + ) + + response = await get_single_user_by_id( + auth_token=mock_context.get_auth_token_or_empty(), + user_id=user["id"], + ) + + assert response is None + + +@pytest.mark.asyncio +async def test_get_single_user_by_id_not_ok( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user = dummy_user_factory() + + mock_user_retrieval_slack_client.users_info.return_value = {"ok": False, "error": "not_ok"} + + response = await get_single_user_by_id( + auth_token=mock_context.get_auth_token_or_empty(), + user_id=user["id"], + ) + + assert response is None + + +@pytest.mark.asyncio +async def test_get_single_user_by_id_unknown_error( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user = dummy_user_factory() + + mock_user_retrieval_slack_client.users_info.side_effect = SlackApiError( + message="Unknown error", + response={"error": "unknown_error_string"}, + ) + + with pytest.raises(ToolExecutionError) as error: + await get_single_user_by_id( + auth_token=mock_context.get_auth_token_or_empty(), + user_id=user["id"], + ) + + assert user["id"] in error.value.message + assert "unknown_error_string" in error.value.developer_message + + +@pytest.mark.asyncio +async def test_get_users_by_id_with_one_user_id_success( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user1 = dummy_user_factory() + + mock_user_retrieval_slack_client.users_info.return_value = {"ok": True, "user": user1} + + response = await get_users_by_id( + auth_token=mock_context.get_auth_token_or_empty(), + user_ids=[user1["id"]], + ) + + assert response == {"users": [cast_user_dict(user1)], "not_found": []} + + mock_user_retrieval_slack_client.users_list.assert_not_called() + + +@pytest.mark.asyncio +async def test_get_users_by_id_with_one_user_id_not_found( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user1 = dummy_user_factory() + + mock_user_retrieval_slack_client.users_info.side_effect = SlackApiError( + message="User not found", + response={"error": "user_not_found"}, + ) + + response = await get_users_by_id( + auth_token=mock_context.get_auth_token_or_empty(), + user_ids=[user1["id"]], + ) + + assert response == {"users": [], "not_found": [user1["id"]]} + + mock_user_retrieval_slack_client.users_list.assert_not_called() + + +@pytest.mark.asyncio +async def test_get_users_by_id_with_multiple_user_ids_success( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user1 = dummy_user_factory() + user2 = dummy_user_factory() + user3 = dummy_user_factory() + user4 = dummy_user_factory() + + mock_user_retrieval_slack_client.users_list.side_effect = [ + { + "ok": True, + "members": [user1, user2], + "response_metadata": {"next_cursor": "next_cursor"}, + }, + {"ok": True, "members": [user3, user4]}, + ] + + response = await get_users_by_id( + auth_token=mock_context.get_auth_token_or_empty(), + user_ids=[user1["id"], user4["id"]], + ) + + assert response == {"users": [cast_user_dict(user1), cast_user_dict(user4)], "not_found": []} + + +@pytest.mark.asyncio +async def test_get_users_by_id_with_multiple_user_ids_some_not_found( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user1 = dummy_user_factory(id_="U1") + user2 = dummy_user_factory(id_="U2") + user3 = dummy_user_factory(id_="U3") + user4 = dummy_user_factory(id_="U4") + + mock_user_retrieval_slack_client.users_list.side_effect = [ + { + "ok": True, + "members": [user1, user2], + "response_metadata": {"next_cursor": "next_cursor"}, + }, + { + "ok": True, + "members": [user3, user4], + "response_metadata": {"next_cursor": None}, + }, + ] + + response = await get_users_by_id( + auth_token=mock_context.get_auth_token_or_empty(), + user_ids=[user1["id"], user4["id"], "user_not_exists"], + ) + + assert response == { + "users": [cast_user_dict(user1), cast_user_dict(user4)], + "not_found": ["user_not_exists"], + } diff --git a/toolkits/slack/tests/test_users.py b/toolkits/slack/tests/test_users.py index a9641e72..d2849bb7 100644 --- a/toolkits/slack/tests/test_users.py +++ b/toolkits/slack/tests/test_users.py @@ -1,21 +1,19 @@ +import json from unittest.mock import patch import pytest -from arcade_tdk.errors import RetryableToolError +from arcade_tdk.errors import RetryableToolError, ToolExecutionError from slack_sdk.errors import SlackApiError -from arcade_slack.tools.users import get_user_info_by_id, list_users -from arcade_slack.utils import extract_basic_user_info - - -@pytest.fixture -def mock_slack_client(mocker): - mock_client = mocker.patch("arcade_slack.tools.users.AsyncWebClient", autospec=True) - return mock_client.return_value +from arcade_slack.tools.users import ( + get_users_info, + list_users, +) +from arcade_slack.utils import extract_basic_user_info, short_user_info @pytest.mark.asyncio -async def test_get_user_info_by_id_success(mock_context, mock_slack_client): +async def test_get_user_info_by_id_success(mock_context, mock_user_retrieval_slack_client): # Mock the response from slackClient.users_info mock_user = { "id": "U12345", @@ -23,44 +21,46 @@ async def test_get_user_info_by_id_success(mock_context, mock_slack_client): "real_name": "Test User", "profile": {"email": "testuser@example.com"}, } - mock_slack_client.users_info.return_value = {"ok": True, "user": mock_user} + mock_user_retrieval_slack_client.users_info.return_value = {"ok": True, "user": mock_user} - # Call the function - response = await get_user_info_by_id(mock_context, user_id="U12345") + response = await get_users_info(mock_context, user_ids=["U12345"]) - # Verify that the correct Slack API method was called - mock_slack_client.users_info.assert_called_once_with(user="U12345") + mock_user_retrieval_slack_client.users_info.assert_called_once_with(user="U12345") - # Verify the response expected_response = extract_basic_user_info(mock_user) - assert response == expected_response + assert response == {"users": [expected_response]} @pytest.mark.asyncio @patch("arcade_slack.tools.users.list_users") -async def test_get_user_info_by_id_user_not_found(mock_list_users, mock_context, mock_slack_client): +async def test_get_user_info_by_id_user_not_found( + mock_list_users, mock_context, mock_user_retrieval_slack_client +): error_response = {"ok": False, "error": "user_not_found"} - mock_slack_client.users_info.side_effect = SlackApiError( + mock_user_retrieval_slack_client.users_info.side_effect = SlackApiError( message="User not found", response=error_response, ) existing_user = {"id": "U12345", "name": "testuser"} mock_list_users.return_value = {"users": [existing_user]} + mock_list_users.__tool_name__ = list_users.__tool_name__ with pytest.raises(RetryableToolError) as e: - await get_user_info_by_id(mock_context, user_id="U99999") + await get_users_info(mock_context, user_ids=["U99999"]) assert existing_user["id"] in e.value.additional_prompt_content - assert existing_user["name"] in e.value.additional_prompt_content - mock_slack_client.users_info.assert_called_once_with(user="U99999") - mock_list_users.assert_called_once_with(mock_context) + mock_user_retrieval_slack_client.users_info.assert_called_once_with(user="U99999") + mock_list_users.assert_called_once_with(mock_context, limit=100, exclude_bots=True) @pytest.mark.asyncio -async def test_list_users_success(mock_context, mock_slack_client): - mock_slack_client.users_list.return_value = {"ok": True, "members": [{"id": "U12345"}]} +async def test_list_users_success(mock_context, mock_users_slack_client): + mock_users_slack_client.users_list.return_value = { + "ok": True, + "members": [{"id": "U12345"}], + } response = await list_users(mock_context) assert response == { "users": [extract_basic_user_info({"id": "U12345"})], @@ -69,24 +69,306 @@ async def test_list_users_success(mock_context, mock_slack_client): @pytest.mark.asyncio -async def test_list_users_with_pagination_success(mock_context, mock_slack_client): - mock_slack_client.users_list.side_effect = [ +async def test_list_users_with_pagination_success( + dummy_user_factory, mock_context, mock_users_slack_client +): + user1 = dummy_user_factory() + user2 = dummy_user_factory() + user3 = dummy_user_factory() + user4 = dummy_user_factory() + + mock_users_slack_client.users_list.side_effect = [ { "ok": True, - "members": [{"id": "U12345"}], + "members": [user1, user2], "response_metadata": {"next_cursor": "cursor_xyz"}, }, { "ok": True, - "members": [{"id": "U123456"}], + "members": [user3, user4], "response_metadata": {"next_cursor": None}, }, ] response = await list_users(mock_context, limit=3) assert response == { "users": [ - extract_basic_user_info({"id": "U12345"}), - extract_basic_user_info({"id": "U123456"}), + extract_basic_user_info(user1), + extract_basic_user_info(user2), + extract_basic_user_info(user3), + extract_basic_user_info(user4), ], "next_cursor": None, } + + assert mock_users_slack_client.users_list.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_user_by_username_success( + mock_context, + mock_user_retrieval_slack_client, + dummy_user_factory, +): + user1 = dummy_user_factory() + user2 = dummy_user_factory() + + mock_user_retrieval_slack_client.users_list.return_value = { + "ok": True, + "members": [user1, user2], + } + + response = await get_users_info(mock_context, usernames=[user1["name"]]) + + assert response == {"users": [extract_basic_user_info(user1)]} + + +@pytest.mark.asyncio +async def test_get_user_by_username_with_pagination_success( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user1 = dummy_user_factory() + user2 = dummy_user_factory() + user3 = dummy_user_factory() + user4 = dummy_user_factory() + user5 = dummy_user_factory() + + mock_user_retrieval_slack_client.users_list.side_effect = [ + { + "ok": True, + "members": [user1, user2], + "response_metadata": {"next_cursor": "cursor1"}, + }, + { + "ok": True, + "members": [user3, user4], + "response_metadata": {"next_cursor": "cursor2"}, + }, + { + "ok": True, + "members": [user5], + "response_metadata": {"next_cursor": None}, + }, + ] + + response = await get_users_info(mock_context, usernames=[user3["name"]]) + + assert response == {"users": [extract_basic_user_info(user3)]} + + assert mock_user_retrieval_slack_client.users_list.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_user_by_username_not_found( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user1 = dummy_user_factory() + user2 = dummy_user_factory() + user3 = dummy_user_factory(is_bot=True) + + mock_user_retrieval_slack_client.users_list.return_value = { + "ok": True, + "members": [user1, user2, user3], + } + + with pytest.raises(RetryableToolError) as e: + await get_users_info(mock_context, usernames=[user1["name"] + "not_found"]) + + # Check that the error message contains the available users + assert user1["id"] in e.value.additional_prompt_content + assert user2["id"] in e.value.additional_prompt_content + assert user3["id"] not in e.value.additional_prompt_content + + +@pytest.mark.asyncio +async def test_get_multiple_users_by_username_success( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user1 = dummy_user_factory() + user2 = dummy_user_factory() + user3 = dummy_user_factory() + + mock_user_retrieval_slack_client.users_list.return_value = { + "ok": True, + "members": [user1, user2, user3], + } + + response = await get_users_info( + mock_context, + usernames=[user1["name"], user2["name"]], + ) + + assert response == { + "users": [ + extract_basic_user_info(user1), + extract_basic_user_info(user2), + ] + } + + +@pytest.mark.asyncio +async def test_get_multiple_users_by_username_with_pagination_success( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user1 = dummy_user_factory() + user2 = dummy_user_factory() + user3 = dummy_user_factory() + user4 = dummy_user_factory() + user5 = dummy_user_factory() + + mock_user_retrieval_slack_client.users_list.side_effect = [ + {"ok": True, "members": [user1, user2], "response_metadata": {"next_cursor": "cursor1"}}, + {"ok": True, "members": [user3, user4], "response_metadata": {"next_cursor": "cursor2"}}, + {"ok": True, "members": [user5], "response_metadata": {"next_cursor": None}}, + ] + + response = await get_users_info(mock_context, usernames=[user1["name"], user3["name"]]) + + assert response == {"users": [extract_basic_user_info(user1), extract_basic_user_info(user3)]} + assert mock_user_retrieval_slack_client.users_list.call_count == 2 + + +@pytest.mark.asyncio +async def test_get_multiple_users_by_username_not_found( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user1 = dummy_user_factory() + user2 = dummy_user_factory() + user3 = dummy_user_factory(is_bot=True) + + mock_user_retrieval_slack_client.users_list.return_value = { + "ok": True, + "members": [user1, user2, user3], + } + + not_found_username = f"{user2['name']} not_found" + + with pytest.raises(RetryableToolError) as e: + await get_users_info(mock_context, usernames=[user1["name"], not_found_username]) + + assert user1["id"] in e.value.additional_prompt_content + assert user2["id"] in e.value.additional_prompt_content + assert user3["id"] not in e.value.additional_prompt_content + + +@pytest.mark.asyncio +async def test_get_user_by_email_success( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user = dummy_user_factory() + mock_user_retrieval_slack_client.users_lookupByEmail.return_value = { + "ok": True, + "user": user, + } + response = await get_users_info(mock_context, emails=[user["profile"]["email"]]) + assert response == {"users": [extract_basic_user_info(user)]} + + +@pytest.mark.asyncio +async def test_get_user_by_email_not_found( + mock_context, mock_users_slack_client, mock_user_retrieval_slack_client, dummy_user_factory +): + additional_user = dummy_user_factory(email="additional_user@example.com") + + async def lookup_by_email_side_effect(*, email): + if email == "additional_user@example.com": + return {"ok": True, "user": additional_user} + raise SlackApiError( + message="User not found", + response={"ok": False, "error": "user_not_found"}, + ) + + mock_user_retrieval_slack_client.users_lookupByEmail.side_effect = lookup_by_email_side_effect + mock_users_slack_client.users_list.return_value = { + "ok": True, + "members": [additional_user], + } + + with pytest.raises(RetryableToolError) as e: + await get_users_info(mock_context, emails=["not_found@example.com"]) + + assert "not_found@example.com" in e.value.message + assert json.dumps(short_user_info(additional_user)) in e.value.additional_prompt_content + + +@pytest.mark.asyncio +async def test_get_multiple_users_by_email_success( + mock_context, mock_user_retrieval_slack_client, dummy_user_factory +): + user1 = dummy_user_factory() + user2 = dummy_user_factory() + user3 = dummy_user_factory() + + mock_user_retrieval_slack_client.users_lookupByEmail.side_effect = [ + {"ok": True, "user": user1}, + {"ok": True, "user": user2}, + {"ok": True, "user": user3}, + ] + + response = await get_users_info( + mock_context, + emails=[ + user1["profile"]["email"], + user2["profile"]["email"], + user3["profile"]["email"], + ], + ) + + assert response == { + "users": [ + extract_basic_user_info(user1), + extract_basic_user_info(user2), + extract_basic_user_info(user3), + ] + } + + +@pytest.mark.asyncio +async def test_get_multiple_users_by_email_with_invalid_address( + mock_context, mock_user_retrieval_slack_client +): + with pytest.raises(ToolExecutionError) as e: + await get_users_info( + mock_context, + emails=["amyra@jades.com", "invalid_address"], + ) + + mock_user_retrieval_slack_client.assert_not_called() + assert e.value.message == "Invalid email address: invalid_address" + + +@pytest.mark.asyncio +async def test_get_multiple_users_by_email_not_found( + mock_context, mock_user_retrieval_slack_client, mock_users_slack_client, dummy_user_factory +): + user1 = dummy_user_factory() + user2 = dummy_user_factory() + + async def lookup_by_email_side_effect(*, email): + if email == user1["profile"]["email"]: + return {"ok": True, "user": user1} + if email == user2["profile"]["email"]: + return {"ok": True, "user": user2} + raise SlackApiError( + message="User not found", + response={"ok": False, "error": "user_not_found"}, + ) + + mock_user_retrieval_slack_client.users_lookupByEmail.side_effect = lookup_by_email_side_effect + mock_users_slack_client.users_list.return_value = { + "ok": True, + "members": [user1, user2], + } + + with pytest.raises(RetryableToolError) as e: + await get_users_info( + mock_context, + emails=[ + user1["profile"]["email"], + user2["profile"]["email"], + "not_found@example.com", + ], + ) + + assert "not_found@example.com" in e.value.message + assert json.dumps(short_user_info(user1)) in e.value.additional_prompt_content + assert json.dumps(short_user_info(user2)) in e.value.additional_prompt_content diff --git a/toolkits/slack/tests/test_utils.py b/toolkits/slack/tests/test_utils.py index c095f7b8..dc4aec85 100644 --- a/toolkits/slack/tests/test_utils.py +++ b/toolkits/slack/tests/test_utils.py @@ -1,25 +1,26 @@ import asyncio -from unittest.mock import AsyncMock, MagicMock, call, patch +import copy +import json +from unittest.mock import AsyncMock, call, patch import pytest -from arcade_tdk import ToolContext +from arcade_tdk.errors import RetryableToolError from slack_sdk.errors import SlackApiError from slack_sdk.web.async_client import AsyncWebClient from arcade_slack.exceptions import PaginationTimeoutError from arcade_slack.models import ( - ConversationType, - ConversationTypeSlackName, -) -from arcade_slack.tools.chat import ( - get_members_in_conversation_by_id, - list_conversations_metadata, + ConcurrencySafeCoroutineCaller, + FindMultipleUsersByUsernameSentinel, + FindUserByUsernameSentinel, ) from arcade_slack.utils import ( async_paginate, - convert_conversation_type_to_slack_name, + build_multiple_users_retrieval_response, filter_conversations_by_user_ids, - retrieve_conversations_by_user_ids, + gather_with_concurrency_limit, + is_valid_email, + populate_users_in_messages, ) @@ -41,6 +42,170 @@ async def test_async_paginate(): assert next_cursor is None +@pytest.mark.asyncio +async def test_async_paginate_with_find_user_sentinel(): + mock_slack_client = AsyncMock() + mock_slack_client.users_list.side_effect = [ + { + "ok": True, + "members": [ + {"id": "123", "name": "Jack"}, + {"id": "456", "name": "John"}, + ], + "response_metadata": {"next_cursor": "cursor1"}, + }, + { + "ok": True, + "members": [{"id": "789", "name": "Jenifer"}], + "response_metadata": {"next_cursor": "cursor2"}, + }, + { + "ok": True, + "members": [{"id": "007", "name": "James"}], + "response_metadata": {"next_cursor": None}, + }, + ] + + results, next_cursor = await async_paginate( + func=mock_slack_client.users_list, + response_key="members", + sentinel=FindUserByUsernameSentinel(username="jenifer"), + ) + + assert results == [ + {"id": "123", "name": "Jack"}, + {"id": "456", "name": "John"}, + {"id": "789", "name": "Jenifer"}, + ] + assert next_cursor == "cursor2" + + +@pytest.mark.asyncio +async def test_async_paginate_with_find_user_sentinel_not_found(): + mock_slack_client = AsyncMock() + mock_slack_client.users_list.side_effect = [ + { + "ok": True, + "members": [ + {"id": "123", "name": "Jack"}, + {"id": "456", "name": "John"}, + ], + "response_metadata": {"next_cursor": "cursor1"}, + }, + { + "ok": True, + "members": [{"id": "789", "name": "Jenifer"}], + "response_metadata": {"next_cursor": "cursor2"}, + }, + { + "ok": True, + "members": [{"id": "007", "name": "James"}], + "response_metadata": {"next_cursor": None}, + }, + ] + + results, next_cursor = await async_paginate( + func=mock_slack_client.users_list, + response_key="members", + sentinel=FindUserByUsernameSentinel(username="Do not find me"), + ) + + assert results == [ + {"id": "123", "name": "Jack"}, + {"id": "456", "name": "John"}, + {"id": "789", "name": "Jenifer"}, + {"id": "007", "name": "James"}, + ] + assert next_cursor is None + + +@pytest.mark.asyncio +async def test_async_paginate_with_find_multiple_users_sentinel(): + mock_slack_client = AsyncMock() + mock_slack_client.users_list.side_effect = [ + { + "ok": True, + "members": [ + {"id": "123", "name": "Jack"}, + {"id": "456", "name": "John"}, + ], + "response_metadata": {"next_cursor": "cursor1"}, + }, + { + "ok": True, + "members": [ + {"id": "789", "name": "Jenifer"}, + {"id": "101", "name": "Janis"}, + ], + "response_metadata": {"next_cursor": "cursor2"}, + }, + { + "ok": True, + "members": [{"id": "007", "name": "James"}], + "response_metadata": {"next_cursor": None}, + }, + ] + + results, next_cursor = await async_paginate( + func=mock_slack_client.users_list, + response_key="members", + sentinel=FindMultipleUsersByUsernameSentinel(usernames=["jenifer", "jack"]), + ) + + assert results == [ + {"id": "123", "name": "Jack"}, + {"id": "456", "name": "John"}, + {"id": "789", "name": "Jenifer"}, + {"id": "101", "name": "Janis"}, + ] + assert next_cursor == "cursor2" + + +@pytest.mark.asyncio +async def test_async_paginate_with_find_multiple_users_sentinel_not_found(): + mock_slack_client = AsyncMock() + mock_slack_client.users_list.side_effect = [ + { + "ok": True, + "members": [ + {"id": "123", "name": "Jack"}, + {"id": "456", "name": "John"}, + ], + "response_metadata": {"next_cursor": "cursor1"}, + }, + { + "ok": True, + "members": [ + {"id": "789", "name": "Jenifer"}, + {"id": "101", "name": "Janis"}, + ], + "response_metadata": {"next_cursor": "cursor2"}, + }, + { + "ok": True, + "members": [{"id": "007", "name": "James"}], + "response_metadata": {"next_cursor": None}, + }, + ] + + results, next_cursor = await async_paginate( + func=mock_slack_client.users_list, + response_key="members", + sentinel=FindMultipleUsersByUsernameSentinel( + usernames=["jenifer", "jack", "do not find me"] + ), + ) + + assert results == [ + {"id": "123", "name": "Jack"}, + {"id": "456", "name": "John"}, + {"id": "789", "name": "Jenifer"}, + {"id": "101", "name": "Janis"}, + {"id": "007", "name": "James"}, + ] + assert next_cursor is None + + @pytest.mark.asyncio async def test_async_paginate_with_response_error(): mock_slack_client = AsyncMock() @@ -278,407 +443,247 @@ def test_filter_conversations_by_user_ids_exact_match_empty_response(): assert response == [] -@pytest.mark.asyncio @pytest.mark.parametrize( - "search_user_ids, conversation_types, exact_match, limit, expected_conversation_ids", + "users_by_email, users_by_username, expected_response", [ - (["U1", "U2"], [ConversationType.DIRECT_MESSAGE], False, 1, ["C1"]), - (["U1", "U2"], [ConversationType.DIRECT_MESSAGE], True, 1, ["C1"]), - (["U1", "U2", "U3"], [ConversationType.DIRECT_MESSAGE], False, 1, []), ( - ["U1", "U2"], - [ConversationType.DIRECT_MESSAGE, ConversationType.PUBLIC_CHANNEL], - False, - 10, - ["C1", "C3", "C4"], + {"users": [{"id": "U1", "name": "user1"}]}, + {"users": [{"id": "U2", "name": "user2"}]}, + [{"id": "U1", "name": "user1"}, {"id": "U2", "name": "user2"}], ), ( - ["U1", "U2"], - [ConversationType.DIRECT_MESSAGE, ConversationType.PUBLIC_CHANNEL], - True, - 10, - ["C1", "C3"], + {"users": [{"id": "U1", "name": "user1"}]}, + {"users": []}, + [{"id": "U1", "name": "user1"}], + ), + ( + {"users": []}, + {"users": [{"id": "U2", "name": "user2"}]}, + [{"id": "U2", "name": "user2"}], + ), + ( + {"users": []}, + {"users": []}, + [], ), ], ) -async def test_retrieve_conversations_by_user_ids( - mock_chat_slack_client, - mock_users_slack_client, - search_user_ids, - conversation_types, - exact_match, - limit, - expected_conversation_ids, +@pytest.mark.asyncio +async def test_build_multiple_users_retrieval_response_success( + users_by_email, + users_by_username, + expected_response, + mock_context, ): - context = MagicMock(spec=ToolContext) - context.authorization = MagicMock() - context.authorization.token = MagicMock() + response = await build_multiple_users_retrieval_response( + context=mock_context, + users_responses=[users_by_email, users_by_username], + ) + assert response == expected_response - conversation_types_slack_name_str = [ - convert_conversation_type_to_slack_name(conv_type).value - for conv_type in conversation_types or ConversationType - ] - conversations = [ - { - "conversation": { - "id": "C1", - "type": ConversationTypeSlackName.IM.value, - "name": "im-1", - "is_channel": False, - "is_im": True, - "is_member": True, +@pytest.mark.parametrize( + "users_by_email, users_by_username", + [ + # Both emails and usernames not found + ( + { + "users": [{"id": "U1", "name": "user1"}], + "not_found": ["email_not_found"], }, - "members": { - "ok": True, - "members": ["U1", "U2"], - "response_metadata": {"next_cursor": None}, + { + "users": [{"id": "U2", "name": "user2"}], + "not_found": ["username_not_found"], + "available_users": [{"id": "U3", "name": "user3"}], }, - "users": [ - {"ok": True, "user": {"id": "U1", "team_id": "T123", "name": "user1"}}, - {"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}}, - ], - }, - { - "conversation": { - "id": "C2", - "type": ConversationTypeSlackName.IM.value, - "name": "im-2", - "is_channel": False, - "is_im": True, - "is_member": True, + ), + # Email not found, usernames found + ( + { + "users": [{"id": "U1", "name": "user1"}], + "not_found": ["email_not_found"], }, - "members": { - "ok": True, - "members": ["U2", "U3"], - "response_metadata": {"next_cursor": None}, + { + "users": [{"id": "U2", "name": "user2"}], + "not_found": [], }, - "users": [ - {"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}}, - {"ok": True, "user": {"id": "U3", "team_id": "T123", "name": "user3"}}, - ], - }, - { - "conversation": { - "id": "C3", - "type": ConversationTypeSlackName.PUBLIC_CHANNEL.value, - "name": "general", - "is_channel": True, - "is_im": False, - "is_member": True, + ), + # Email found, username not found + ( + { + "users": [{"id": "U1", "name": "user1"}], + "not_found": [], }, - "members": { - "ok": True, - "members": ["U1", "U2"], - "response_metadata": {"next_cursor": None}, + { + "users": [{"id": "U2", "name": "user2"}], + "not_found": ["username_not_found"], + "available_users": [{"id": "U3", "name": "user3"}], }, - "users": [ - {"ok": True, "user": {"id": "U1", "team_id": "T123", "name": "user1"}}, - {"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}}, - ], - }, - { - "conversation": { - "id": "C4", - "type": ConversationTypeSlackName.PUBLIC_CHANNEL.value, - "name": "random", - "is_channel": True, - "is_im": False, - "is_member": True, - }, - "members": { - "ok": True, - "members": ["U1", "U2", "U3", "U4"], - "response_metadata": {"next_cursor": None}, - }, - "users": [ - {"ok": True, "user": {"id": "U1", "team_id": "T123", "name": "user1"}}, - {"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}}, - {"ok": True, "user": {"id": "U3", "team_id": "T123", "name": "user3"}}, - {"ok": True, "user": {"id": "U4", "team_id": "T123", "name": "user4"}}, - ], - }, - ] + ), + ], +) +@pytest.mark.asyncio +async def test_build_multiple_users_retrieval_response_not_found( + users_by_email, + users_by_username, + mock_context, +): + with pytest.raises(RetryableToolError) as error: + await build_multiple_users_retrieval_response( + context=mock_context, + users_responses=[users_by_email, users_by_username], + ) - conversations_listed = [ - conversation - for conversation in conversations - if conversation["conversation"]["type"] in conversation_types_slack_name_str - ] + emails_not_found = users_by_email.get("not_found", []) + usernames_not_found = users_by_username.get("not_found", []) + available_users = users_by_username.get("available_users", []) - mock_chat_slack_client.conversations_list.return_value = { - "ok": True, - "channels": [conversation["conversation"] for conversation in conversations_listed], - "response_metadata": {"next_cursor": None}, - } + for email in emails_not_found: + assert email in error.value.message + for username in usernames_not_found: + assert username in error.value.message + for user in available_users: + assert json.dumps(user) in error.value.additional_prompt_content - mock_chat_slack_client.conversations_members.side_effect = [ - conversation["members"] for conversation in conversations_listed - ] - mock_users_slack_client.users_info.side_effect = [ - user for conversation in conversations_listed for user in conversation["users"] - ] +def test_is_valid_email(): + assert is_valid_email("test@example.com") + assert is_valid_email("test+123@example.com") + assert is_valid_email("test-123@example.com") + assert is_valid_email("test_123@example.com") + assert is_valid_email("test.123@example.com") + assert is_valid_email("test123@example.com") + assert is_valid_email("test@example.co") + assert is_valid_email("test@example.com.co") + assert not is_valid_email("test123@example") + assert not is_valid_email("test@example") + assert not is_valid_email("test@example.") + assert not is_valid_email("test@.com") + assert not is_valid_email("test@example.c") + assert not is_valid_email("test@example.com.") + assert not is_valid_email("test@example.com.c") - conversations_found = await retrieve_conversations_by_user_ids( - list_conversations_func=list_conversations_metadata, - get_members_in_conversation_func=get_members_in_conversation_by_id, - context=context, - conversation_types=conversation_types, - user_ids=search_user_ids, - exact_match=exact_match, - limit=limit, - next_cursor=None, + +@pytest.mark.asyncio +async def test_gather_with_concurrency_limit(): + mock_func1 = AsyncMock() + mock_func2 = AsyncMock() + + caller1 = ConcurrencySafeCoroutineCaller(mock_func1, "arg1", "arg2", kwarg1="kwarg1") + caller2 = ConcurrencySafeCoroutineCaller(mock_func2, "arg1", "arg2", kwarg1="kwarg1") + + mock_semaphore = AsyncMock(spec=asyncio.Semaphore) + + response = await gather_with_concurrency_limit( + coroutine_callers=[caller1, caller2], + semaphore=mock_semaphore, ) - assert [conversation["id"] for conversation in conversations_found] == expected_conversation_ids + response = tuple(response) + + assert len(response) == 2 + assert response[0] == mock_func1.return_value + assert response[1] == mock_func2.return_value + + mock_func1.assert_awaited_once_with("arg1", "arg2", kwarg1="kwarg1") + mock_func2.assert_awaited_once_with("arg1", "arg2", kwarg1="kwarg1") + + assert mock_semaphore.__aenter__.await_count == 2 + assert mock_semaphore.__aexit__.await_count == 2 @pytest.mark.asyncio -@pytest.mark.parametrize( - ( - "search_user_ids, conversation_types, exact_match, limit, " - "expected_conversation_ids, expected_conversations_list_calls" - ), - [ - ( - ["U1", "U2", "U3"], - [ConversationType.MULTI_PERSON_DIRECT_MESSAGE], - False, - None, - ["C1", "C3"], - 2, - ), - ( - ["U1", "U2", "U3"], - [ConversationType.MULTI_PERSON_DIRECT_MESSAGE], - True, - None, - ["C1"], - 2, - ), - (["U1", "U2", "U99"], [ConversationType.MULTI_PERSON_DIRECT_MESSAGE], False, None, [], 2), - ( - ["U1", "U2"], - [ConversationType.MULTI_PERSON_DIRECT_MESSAGE, ConversationType.PUBLIC_CHANNEL], - False, - None, - ["C1", "C3", "C4", "C6"], - 2, - ), - ( - ["U1", "U2"], - [ConversationType.MULTI_PERSON_DIRECT_MESSAGE, ConversationType.PUBLIC_CHANNEL], - False, - 1, - ["C1"], - 2, - ), - ( - ["U1", "U2"], - [ConversationType.MULTI_PERSON_DIRECT_MESSAGE, ConversationType.PUBLIC_CHANNEL], - False, - 3, - ["C1", "C3", "C4"], - 2, - ), - ( - ["U1", "U2"], - [ConversationType.MULTI_PERSON_DIRECT_MESSAGE, ConversationType.PUBLIC_CHANNEL], - True, - None, - ["C4"], - 2, - ), - ], -) -async def test_retrieve_conversations_by_user_ids_with_pagination( - mock_chat_slack_client, - mock_users_slack_client, - search_user_ids, - conversation_types, - exact_match, - limit, - expected_conversation_ids, - expected_conversations_list_calls, +async def test_populate_users_in_messages( + mock_context, + mock_user_retrieval_slack_client, + dummy_message_factory, + dummy_reaction_factory, + dummy_user_factory, ): - context = MagicMock(spec=ToolContext) - context.authorization = MagicMock() - context.authorization.token = MagicMock() + user1 = dummy_user_factory(id_="U1", name="user1") + user2 = dummy_user_factory(id_="U2", name="user2") + user3 = dummy_user_factory(id_="U3", name="user3") + user4 = dummy_user_factory(id_="U4", name="user4") + user5 = dummy_user_factory(id_="U5", name="user5") - conversation_types_slack_name_str = [ - convert_conversation_type_to_slack_name(conv_type).value - for conv_type in conversation_types or ConversationType + user1_short = {"id": user1["id"], "name": user1["name"]} + user2_short = {"id": user2["id"], "name": user2["name"]} + user3_short = {"id": user3["id"], "name": user3["name"]} + user4_short = {"id": user4["id"], "name": user4["name"]} + + user2_mention = f"<@{user2['name']} (id:{user2['id']})>" + user5_mention = f"<@{user5['name']} (id:{user5['id']})>" + + reactions = [ + dummy_reaction_factory(name="thumbsup", user_ids=[user1["id"], user2["id"]]), + dummy_reaction_factory(name="partyparrot", user_ids=[user3["id"], user4["id"]]), ] - conversations = [ - { - "conversation": { - "id": "C1", - "type": ConversationTypeSlackName.MPIM.value, - "name": "mpim-1", - "is_channel": False, - "is_im": False, - "is_mpim": True, - "is_member": True, - }, - "members": { - "ok": True, - "members": ["U1", "U2", "U3"], - "response_metadata": {"next_cursor": None}, - }, - "users": [ - {"ok": True, "user": {"id": "U1", "team_id": "T123", "name": "user1"}}, - {"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}}, - {"ok": True, "user": {"id": "U3", "team_id": "T123", "name": "user3"}}, - ], - }, - { - "conversation": { - "id": "C2", - "type": ConversationTypeSlackName.MPIM.value, - "name": "mpim-2", - "is_channel": False, - "is_im": False, - "is_mpim": True, - "is_member": True, - }, - "members": { - "ok": True, - "members": ["U2", "U3"], - "response_metadata": {"next_cursor": None}, - }, - "users": [ - {"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}}, - {"ok": True, "user": {"id": "U3", "team_id": "T123", "name": "user3"}}, - ], - }, - { - "conversation": { - "id": "C3", - "type": ConversationTypeSlackName.MPIM.value, - "name": "mpim-3", - "is_channel": False, - "is_im": False, - "is_mpim": True, - "is_member": True, - }, - "members": { - "ok": True, - "members": ["U1", "U2", "U3", "U4"], - "response_metadata": {"next_cursor": None}, - }, - "users": [ - {"ok": True, "user": {"id": "U1", "team_id": "T123", "name": "user1"}}, - {"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}}, - {"ok": True, "user": {"id": "U3", "team_id": "T123", "name": "user3"}}, - {"ok": True, "user": {"id": "U4", "team_id": "T123", "name": "user4"}}, - ], - }, - { - "conversation": { - "id": "C4", - "type": ConversationTypeSlackName.PUBLIC_CHANNEL.value, - "name": "channel-4", - "is_channel": True, - "is_im": False, - "is_member": True, - }, - "members": { - "ok": True, - "members": ["U1", "U2"], - "response_metadata": {"next_cursor": None}, - }, - "users": [ - {"ok": True, "user": {"id": "U1", "team_id": "T123", "name": "user1"}}, - {"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}}, - ], - }, - { - "conversation": { - "id": "C5", - "type": ConversationTypeSlackName.PUBLIC_CHANNEL.value, - "name": "channel-5", - "is_channel": True, - "is_im": False, - "is_member": True, - }, - "members": { - "ok": True, - "members": ["U2", "U3", "U4"], - "response_metadata": {"next_cursor": None}, - }, - "users": [ - {"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}}, - {"ok": True, "user": {"id": "U3", "team_id": "T123", "name": "user3"}}, - {"ok": True, "user": {"id": "U4", "team_id": "T123", "name": "user4"}}, - ], - }, - { - "conversation": { - "id": "C6", - "type": ConversationTypeSlackName.PUBLIC_CHANNEL.value, - "name": "channel-6", - "is_channel": True, - "is_im": False, - "is_member": True, - }, - "members": { - "ok": True, - "members": ["U1", "U2", "U3", "U4"], - "response_metadata": {"next_cursor": None}, - }, - "users": [ - {"ok": True, "user": {"id": "U1", "team_id": "T123", "name": "user1"}}, - {"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}}, - {"ok": True, "user": {"id": "U3", "team_id": "T123", "name": "user3"}}, - {"ok": True, "user": {"id": "U4", "team_id": "T123", "name": "user4"}}, - ], - }, + messages = [ + dummy_message_factory( + user_id=user1["id"], + text=f"Hello <@{user2['id']}>", + ), + dummy_message_factory( + user_id=user2["id"], + text="foobar", + reactions=copy.deepcopy(reactions[:1]), + ), + dummy_message_factory( + user_id=user3["id"], + text=f"Is this @{user5['id']} a user mention?", + ), + dummy_message_factory( + user_id=user4["id"], + text="hello", + reactions=copy.deepcopy(reactions), + ), ] - conversations_listed = [ - conversation - for conversation in conversations - if conversation["conversation"]["type"] in conversation_types_slack_name_str - ] - - split_size = len(conversations_listed) // 2 - - conversations_listed_1 = conversations_listed[:split_size] - conversations_listed_2 = conversations_listed[split_size:] - - mock_chat_slack_client.conversations_list.side_effect = [ + mock_user_retrieval_slack_client.users_list.side_effect = [ { "ok": True, - "channels": [conversation["conversation"] for conversation in conversations_listed_1], - "response_metadata": {"next_cursor": "cursor_1"}, + "members": [user1, user2, user3], + "response_metadata": {"next_cursor": "cursor1"}, }, { "ok": True, - "channels": [conversation["conversation"] for conversation in conversations_listed_2], + "members": [user4, user5], "response_metadata": {"next_cursor": None}, }, ] - mock_chat_slack_client.conversations_members.side_effect = [ - conversation["members"] for conversation in conversations_listed - ] - - mock_users_slack_client.users_info.side_effect = [ - user for conversation in conversations_listed for user in conversation["users"] - ] - - conversations_found = await retrieve_conversations_by_user_ids( - list_conversations_func=list_conversations_metadata, - get_members_in_conversation_func=get_members_in_conversation_by_id, - context=context, - conversation_types=conversation_types, - user_ids=search_user_ids, - exact_match=exact_match, - limit=limit, - next_cursor=None, + response = await populate_users_in_messages( + auth_token=mock_context.get_auth_token_or_empty(), + messages=messages, ) - assert [conversation["id"] for conversation in conversations_found] == expected_conversation_ids - assert mock_chat_slack_client.conversations_list.call_count == expected_conversations_list_calls + msg1 = response[0] + msg2 = response[1] + msg3 = response[2] + msg4 = response[3] + + assert msg1["user"] == user1_short + assert msg1["text"] == f"Hello {user2_mention}" + assert "reactions" not in msg1 + + assert msg2["user"] == user2_short + assert msg2["text"] == "foobar" + assert "reactions" in msg2 + assert len(msg2["reactions"]) == 1 + assert msg2["reactions"][0]["name"] == "thumbsup" + assert msg2["reactions"][0]["users"] == [user1_short, user2_short] + + assert msg3["user"] == user3_short + assert msg3["text"] == f"Is this @{user5['id']} a user mention?" + assert "reactions" not in msg3 + assert user5_mention not in msg3["text"] + + assert msg4["user"] == user4_short + assert msg4["text"] == "hello" + assert "reactions" in msg4 + assert len(msg4["reactions"]) == 2 + assert msg4["reactions"][0]["name"] == "thumbsup" + assert msg4["reactions"][0]["users"] == [user1_short, user2_short] + assert msg4["reactions"][1]["name"] == "partyparrot" + assert msg4["reactions"][1]["users"] == [user3_short, user4_short]