Slack Toolkit Refactoring (#453)
# Backwards-compatible refactoring of the Slack toolkit Several performance improvements, streamlined tool set, and easier to understand tool interfaces. All "old" tools were kept for backwards compatibility, with the same interfaces and response structure (but using the new and more performant tools under the hood). Full revision of unit-tests and evals. ## Streamlined tool set Multiple groups of tools were streamlined into a single one: * "get conversation metadata" from 5 tools to one; * "send message" from 2 tools to one; * "get users in conversation" from 3 tools to one; * "get messages" from 4 tools to one ## New capabilities * Messages retrieved are now populated with the users' names, apart from ID: makes it easier for LLMs to reference who sent a message, were mentioned, or reacted to a message * Retrieve users by username, email, and/or ID (before we only supported ID) * Retrieve multiple users in a single tool call ## Concurrency controls All operations issuing multiple requests concurrently now have a `Semaphore` to limit the concurrency level. The limit can be controlled through the `SLACK_MAX_CONCURRENT_REQUESTS` env var (defaults to 3). ## Networking performance improvement Various operations that used to make multiple API calls are now executed more efficiently: ### Find users by username * Before: a full scan of `users_list` was required (potentially multiple pages for large workspaces); * Now it stops as soon as we have all users needed (yes, it was dumb before) ### Get multiple users by their IDs * Before: for each user ID, we made one API call to the `users_info` endpoint * Now: we call `list_users` and scan the results to match the user IDs (an estimate of 99.5% of Slack workspaces have < 200 users; for large workspaces, we may need to paginate `list_users`) ### Get a conversation by its users * Before: * Call to `list_conversations` (potentially paginating) * For each conversation, one call to `conversations_members` (potentially paginating) * Then loop and find which conversation matches the users' IDs * Now: * A single call to `conversations_open`
This commit is contained in:
parent
07c52100f3
commit
ed89af4b4d
27 changed files with 5385 additions and 2896 deletions
|
|
@ -49,6 +49,6 @@ def generate_google_file_picker_url(
|
|||
"url": url,
|
||||
"llm_instructions": (
|
||||
"Instruct the user to click the following link to open the Google Drive File Picker. "
|
||||
"This will allow them to select files and grant access permissions: {url}"
|
||||
f"This will allow them to select files and grant access permissions: {url}"
|
||||
),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,13 +1,18 @@
|
|||
import os
|
||||
|
||||
from arcade_slack.custom_types import PositiveInt
|
||||
from arcade_slack.custom_types import PositiveNonZeroInt
|
||||
|
||||
MAX_PAGINATION_SIZE_LIMIT = 200
|
||||
|
||||
MAX_PAGINATION_TIMEOUT_SECONDS = PositiveInt(
|
||||
MAX_PAGINATION_TIMEOUT_SECONDS = PositiveNonZeroInt(
|
||||
os.environ.get(
|
||||
"MAX_PAGINATION_TIMEOUT_SECONDS",
|
||||
os.environ.get("MAX_SLACK_PAGINATION_TIMEOUT_SECONDS", 30),
|
||||
),
|
||||
name="MAX_PAGINATION_TIMEOUT_SECONDS or MAX_SLACK_PAGINATION_TIMEOUT_SECONDS",
|
||||
)
|
||||
|
||||
MAX_CONCURRENT_REQUESTS = PositiveNonZeroInt(
|
||||
os.environ.get("SLACK_MAX_CONCURRENT_REQUESTS", 3),
|
||||
name="SLACK_MAX_CONCURRENT_REQUESTS",
|
||||
)
|
||||
|
|
|
|||
74
toolkits/slack/arcade_slack/conversation_retrieval.py
Normal file
74
toolkits/slack/arcade_slack/conversation_retrieval.py
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
import json
|
||||
from typing import cast
|
||||
|
||||
from arcade_tdk.errors import RetryableToolError, ToolExecutionError
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
|
||||
from arcade_slack.models import (
|
||||
ConversationType,
|
||||
FindChannelByNameSentinel,
|
||||
)
|
||||
from arcade_slack.utils import (
|
||||
async_paginate,
|
||||
extract_conversation_metadata,
|
||||
)
|
||||
|
||||
|
||||
async def get_conversation_by_id(
|
||||
auth_token: str,
|
||||
conversation_id: str,
|
||||
) -> dict:
|
||||
"""Get metadata of a conversation in Slack by the conversation_id."""
|
||||
try:
|
||||
slack_client = AsyncWebClient(token=auth_token)
|
||||
response = await slack_client.conversations_info(
|
||||
channel=conversation_id,
|
||||
include_locale=True,
|
||||
include_num_members=True,
|
||||
)
|
||||
return dict(**extract_conversation_metadata(response["channel"]))
|
||||
|
||||
except SlackApiError as e:
|
||||
slack_error = cast(str, e.response.get("error", ""))
|
||||
if "not_found" in slack_error.lower():
|
||||
message = f"Conversation with ID '{conversation_id}' not found."
|
||||
raise ToolExecutionError(message=message, developer_message=message)
|
||||
raise
|
||||
|
||||
|
||||
async def get_channel_by_name(
|
||||
auth_token: str,
|
||||
channel_name: str,
|
||||
) -> dict:
|
||||
channel_name_casefolded = channel_name.lstrip("#").casefold()
|
||||
|
||||
slack_client = AsyncWebClient(token=auth_token)
|
||||
|
||||
results, _ = await async_paginate(
|
||||
func=slack_client.conversations_list,
|
||||
response_key="channels",
|
||||
types=",".join([
|
||||
ConversationType.PUBLIC_CHANNEL.value,
|
||||
ConversationType.PRIVATE_CHANNEL.value,
|
||||
]),
|
||||
exclude_archived=True,
|
||||
sentinel=FindChannelByNameSentinel(channel_name_casefolded),
|
||||
)
|
||||
|
||||
available_channels = []
|
||||
|
||||
for channel in results:
|
||||
if channel["name"].casefold() == channel_name_casefolded:
|
||||
return dict(**extract_conversation_metadata(channel))
|
||||
else:
|
||||
available_channels.append({"id": channel["id"], "name": channel["name"]})
|
||||
|
||||
error_message = f"Channel with name '{channel_name}' not found."
|
||||
|
||||
raise RetryableToolError(
|
||||
message=error_message,
|
||||
developer_message=error_message,
|
||||
additional_prompt_content=f"Available channels: {json.dumps(available_channels)}",
|
||||
retry_after_ms=500,
|
||||
)
|
||||
|
|
@ -1,11 +1,11 @@
|
|||
from typing import NewType
|
||||
|
||||
|
||||
class PositiveInt(int):
|
||||
def __new__(cls, value: str | int, name: str = "value") -> "PositiveInt":
|
||||
class PositiveNonZeroInt(int):
|
||||
def __new__(cls, value: str | int, name: str = "value") -> "PositiveNonZeroInt":
|
||||
def validate(val: int) -> int:
|
||||
if val <= 0:
|
||||
raise ValueError(f"{name} must be positive, got {val}")
|
||||
if val < 1:
|
||||
raise ValueError(f"{name} must be a positive non-zero integer, got {val}")
|
||||
return val
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -8,23 +8,3 @@ class PaginationTimeoutError(SlackToolkitError):
|
|||
def __init__(self, timeout_seconds: int):
|
||||
self.timeout_seconds = timeout_seconds
|
||||
super().__init__(f"The pagination process timed out after {timeout_seconds} seconds.")
|
||||
|
||||
|
||||
class ItemNotFoundError(SlackToolkitError):
|
||||
"""Raised when an item is not found."""
|
||||
|
||||
|
||||
class UsernameNotFoundError(SlackToolkitError):
|
||||
"""Raised when a user is not found by the username searched"""
|
||||
|
||||
def __init__(self, usernames_found: list[str], username_not_found: str) -> None:
|
||||
self.usernames_found = usernames_found
|
||||
self.username_not_found = username_not_found
|
||||
|
||||
|
||||
class ConversationNotFoundError(SlackToolkitError):
|
||||
"""Raised when a conversation is not found"""
|
||||
|
||||
|
||||
class DirectMessageConversationNotFoundError(ConversationNotFoundError):
|
||||
"""Raised when a direct message conversation searched is not found"""
|
||||
|
|
|
|||
76
toolkits/slack/arcade_slack/message_retrieval.py
Normal file
76
toolkits/slack/arcade_slack/message_retrieval.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from arcade_tdk.errors import ToolExecutionError
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
|
||||
from arcade_slack.utils import (
|
||||
async_paginate,
|
||||
convert_datetime_to_unix_timestamp,
|
||||
convert_relative_datetime_to_unix_timestamp,
|
||||
enrich_message_datetime,
|
||||
)
|
||||
|
||||
|
||||
async def retrieve_messages_in_conversation(
|
||||
conversation_id: str,
|
||||
auth_token: str | None = None,
|
||||
oldest_relative: str | None = None,
|
||||
latest_relative: str | None = None,
|
||||
oldest_datetime: str | None = None,
|
||||
latest_datetime: str | None = None,
|
||||
limit: int | None = None,
|
||||
next_cursor: str | None = None,
|
||||
) -> dict:
|
||||
error_message = None
|
||||
if oldest_datetime and oldest_relative:
|
||||
error_message = "Cannot specify both 'oldest_datetime' and 'oldest_relative'."
|
||||
|
||||
if latest_datetime and latest_relative:
|
||||
error_message = "Cannot specify both 'latest_datetime' and 'latest_relative'."
|
||||
|
||||
if error_message:
|
||||
raise ToolExecutionError(error_message, developer_message=error_message)
|
||||
|
||||
current_unix_timestamp = int(datetime.now(timezone.utc).timestamp())
|
||||
|
||||
if latest_relative:
|
||||
latest_timestamp = convert_relative_datetime_to_unix_timestamp(
|
||||
latest_relative, current_unix_timestamp
|
||||
)
|
||||
elif latest_datetime:
|
||||
latest_timestamp = convert_datetime_to_unix_timestamp(latest_datetime)
|
||||
else:
|
||||
latest_timestamp = None
|
||||
|
||||
if oldest_relative:
|
||||
oldest_timestamp = convert_relative_datetime_to_unix_timestamp(
|
||||
oldest_relative, current_unix_timestamp
|
||||
)
|
||||
elif oldest_datetime:
|
||||
oldest_timestamp = convert_datetime_to_unix_timestamp(oldest_datetime)
|
||||
else:
|
||||
oldest_timestamp = None
|
||||
|
||||
datetime_args: dict[str, Any] = {}
|
||||
if oldest_timestamp:
|
||||
datetime_args["oldest"] = oldest_timestamp
|
||||
if latest_timestamp:
|
||||
datetime_args["latest"] = latest_timestamp
|
||||
|
||||
slackClient = AsyncWebClient(token=auth_token)
|
||||
|
||||
response, next_cursor = await async_paginate(
|
||||
slackClient.conversations_history,
|
||||
"messages",
|
||||
limit=limit,
|
||||
next_cursor=next_cursor,
|
||||
channel=conversation_id,
|
||||
include_all_metadata=True,
|
||||
inclusive=True, # Include messages at the start and end of the time range
|
||||
**datetime_args,
|
||||
)
|
||||
|
||||
messages = [enrich_message_datetime(message) for message in response]
|
||||
|
||||
return {"messages": messages, "next_cursor": next_cursor}
|
||||
|
|
@ -1,5 +1,12 @@
|
|||
import asyncio
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextlib import suppress
|
||||
from enum import Enum
|
||||
from typing import Literal, TypedDict
|
||||
from typing import Any, Literal, TypedDict
|
||||
|
||||
from arcade_tdk.errors import ToolExecutionError
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from arcade_slack.custom_types import (
|
||||
SlackOffsetSecondsFromUTC,
|
||||
|
|
@ -24,6 +31,16 @@ class ConversationType(str, Enum):
|
|||
MULTI_PERSON_DIRECT_MESSAGE = "multi_person_direct_message"
|
||||
DIRECT_MESSAGE = "direct_message"
|
||||
|
||||
def to_slack_name_str(self) -> str:
|
||||
mapping = {
|
||||
ConversationType.PUBLIC_CHANNEL: ConversationTypeSlackName.PUBLIC_CHANNEL.value,
|
||||
ConversationType.PRIVATE_CHANNEL: ConversationTypeSlackName.PRIVATE_CHANNEL.value,
|
||||
ConversationType.MULTI_PERSON_DIRECT_MESSAGE: ConversationTypeSlackName.MPIM.value,
|
||||
ConversationType.DIRECT_MESSAGE: ConversationTypeSlackName.IM.value,
|
||||
}
|
||||
|
||||
return mapping[self]
|
||||
|
||||
|
||||
"""
|
||||
About Slack dictionaries: Slack does not guarantee the presence of all fields for a given
|
||||
|
|
@ -204,3 +221,150 @@ class SlackConversationsToolResponse(TypedDict, total=True):
|
|||
|
||||
conversations: list[ConversationMetadata]
|
||||
next_cursor: SlackPaginationNextCursor | None
|
||||
|
||||
|
||||
class PaginationSentinel(ABC):
|
||||
"""Base class for pagination sentinel classes."""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
self.kwargs = kwargs
|
||||
|
||||
@abstractmethod
|
||||
def __call__(self, last_result: Any) -> bool:
|
||||
"""Determine if the pagination should stop."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class FindUserByUsernameSentinel(PaginationSentinel):
|
||||
"""Sentinel class for finding a user by username."""
|
||||
|
||||
def __call__(self, last_result: Any) -> bool:
|
||||
for user in last_result:
|
||||
if not isinstance(user.get("name"), str):
|
||||
continue
|
||||
if user.get("name").casefold() == self.kwargs["username"].casefold():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class FindMultipleUsersByUsernameSentinel(PaginationSentinel):
|
||||
"""Sentinel class for finding multiple users by username."""
|
||||
|
||||
def __init__(self, usernames: list[str]) -> None:
|
||||
if not usernames:
|
||||
raise ValueError("usernames must be a non-empty list of strings")
|
||||
super().__init__(usernames=usernames)
|
||||
self.usernames_pending = {username.casefold() for username in usernames}
|
||||
|
||||
def _flag_username_found(self, username: str) -> None:
|
||||
with suppress(KeyError):
|
||||
self.usernames_pending.remove(username.casefold())
|
||||
|
||||
def _all_usernames_found(self) -> bool:
|
||||
return not self.usernames_pending
|
||||
|
||||
def __call__(self, last_result: Any) -> bool:
|
||||
if not self.usernames_pending:
|
||||
return True
|
||||
for user in last_result:
|
||||
username = user.get("name")
|
||||
if not isinstance(username, str):
|
||||
continue
|
||||
if username.casefold() in self.usernames_pending:
|
||||
self._flag_username_found(username)
|
||||
if self._all_usernames_found():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class FindMultipleUsersByIdSentinel(PaginationSentinel):
|
||||
"""Sentinel class for finding multiple users by ID."""
|
||||
|
||||
def __init__(self, user_ids: list[str]) -> None:
|
||||
if not user_ids:
|
||||
raise ValueError("user_ids must be a non-empty list of strings")
|
||||
super().__init__(user_ids=user_ids)
|
||||
self.user_ids_pending = set(user_ids)
|
||||
|
||||
def _flag_user_id_found(self, user_id: str) -> None:
|
||||
with suppress(KeyError):
|
||||
self.user_ids_pending.remove(user_id.casefold())
|
||||
|
||||
def _all_user_ids_found(self) -> bool:
|
||||
return not self.user_ids_pending
|
||||
|
||||
def __call__(self, last_result: Any) -> bool:
|
||||
if not self.user_ids_pending:
|
||||
return True
|
||||
for user in last_result:
|
||||
user_id = user.get("id")
|
||||
if user_id in self.user_ids_pending:
|
||||
self._flag_user_id_found(user_id)
|
||||
if self._all_user_ids_found():
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class FindChannelByNameSentinel(PaginationSentinel):
|
||||
"""Sentinel class for finding a channel by name."""
|
||||
|
||||
def __init__(self, channel_name: str) -> None:
|
||||
super().__init__(channel_name=channel_name)
|
||||
self.channel_name_casefold = channel_name.casefold()
|
||||
|
||||
def __call__(self, last_result: Any) -> bool:
|
||||
for channel in last_result:
|
||||
channel_name = channel.get("name")
|
||||
if not isinstance(channel_name, str):
|
||||
continue
|
||||
if channel_name.casefold() == self.channel_name_casefold:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
class AbstractConcurrencySafeCoroutineCaller(ABC):
|
||||
"""Abstract base class for concurrency-safe coroutine callers."""
|
||||
|
||||
def __init__(self, func: Callable[..., Awaitable[Any]], *args: Any, **kwargs: Any) -> None:
|
||||
self.func = func
|
||||
self.args = args
|
||||
self.kwargs = kwargs
|
||||
|
||||
@abstractmethod
|
||||
async def __call__(self, semaphore: asyncio.Semaphore) -> Any:
|
||||
"""Call a coroutine with a semaphore."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ConcurrencySafeCoroutineCaller(AbstractConcurrencySafeCoroutineCaller):
|
||||
"""Calls a coroutine with an asyncio semaphore."""
|
||||
|
||||
async def __call__(self, semaphore: asyncio.Semaphore) -> Any:
|
||||
async with semaphore:
|
||||
return await self.func(*self.args, **self.kwargs)
|
||||
|
||||
|
||||
class GetUserByEmailCaller(AbstractConcurrencySafeCoroutineCaller):
|
||||
"""Call Slack's lookupByEmail method with an asyncio semaphore while handling API errors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
func: Callable[..., Awaitable[Any]],
|
||||
email: str,
|
||||
) -> None:
|
||||
super().__init__(func)
|
||||
self.email = email
|
||||
|
||||
async def __call__(self, semaphore: asyncio.Semaphore) -> dict[str, Any]:
|
||||
async with semaphore:
|
||||
try:
|
||||
user = await self.func(email=self.email)
|
||||
return {"user": user["user"], "email": self.email}
|
||||
except SlackApiError as e:
|
||||
if e.response.get("error") in ["user_not_found", "users_not_found"]:
|
||||
return {"user": None, "email": self.email}
|
||||
else:
|
||||
raise ToolExecutionError(
|
||||
message="Error getting user by email",
|
||||
developer_message=f"Error getting user by email: {e.response.get('error')}",
|
||||
)
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -2,12 +2,13 @@ from typing import Annotated, Any, cast
|
|||
|
||||
from arcade_tdk import ToolContext, tool
|
||||
from arcade_tdk.auth import Slack
|
||||
from arcade_tdk.errors import RetryableToolError
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
|
||||
from arcade_slack.constants import MAX_PAGINATION_TIMEOUT_SECONDS
|
||||
from arcade_slack.models import SlackPaginationNextCursor, SlackUser
|
||||
from arcade_slack.models import (
|
||||
SlackPaginationNextCursor,
|
||||
)
|
||||
from arcade_slack.user_retrieval import get_users_by_id_username_or_email
|
||||
from arcade_slack.utils import (
|
||||
async_paginate,
|
||||
extract_basic_user_info,
|
||||
|
|
@ -16,62 +17,58 @@ from arcade_slack.utils import (
|
|||
)
|
||||
|
||||
|
||||
@tool(
|
||||
requires_auth=Slack(
|
||||
scopes=["users:read", "users:read.email"],
|
||||
)
|
||||
)
|
||||
async def get_user_info_by_id(
|
||||
@tool(requires_auth=Slack(scopes=["users:read", "users:read.email"]))
|
||||
async def get_users_info(
|
||||
context: ToolContext,
|
||||
user_id: Annotated[str, "The ID of the user to get"],
|
||||
) -> Annotated[dict[str, Any], "The user's information"]:
|
||||
"""Get the information of a user in Slack."""
|
||||
user_ids: Annotated[list[str] | None, "The IDs of the users to get"] = None,
|
||||
usernames: Annotated[
|
||||
list[str] | None,
|
||||
"The usernames of the users to get. Prefer retrieving by user_ids and/or emails, "
|
||||
"when available, since the performance is better.",
|
||||
] = None,
|
||||
emails: Annotated[list[str] | None, "The emails of the users to get"] = None,
|
||||
) -> Annotated[dict[str, Any], "The users' information"]:
|
||||
"""Get the information of one or more users in Slack by ID, username, and/or email.
|
||||
|
||||
token = (
|
||||
context.authorization.token if context.authorization and context.authorization.token else ""
|
||||
)
|
||||
slackClient = AsyncWebClient(token=token)
|
||||
Provide any combination of user_ids, usernames, and/or emails. If you need to retrieve
|
||||
data about multiple users, DO NOT CALL THE TOOL MULTIPLE TIMES. Instead, call it once
|
||||
with all the user_ids, usernames, and/or emails. IF YOU CALL THIS TOOL MULTIPLE TIMES
|
||||
UNNECESSARILY, YOU WILL RELEASE MORE CO2 IN THE ATMOSPHERE AND CONTRIBUTE TO GLOBAL WARMING.
|
||||
|
||||
try:
|
||||
response = await slackClient.users_info(user=user_id)
|
||||
except SlackApiError as e:
|
||||
if e.response.get("error") == "user_not_found":
|
||||
users = await list_users(context)
|
||||
available_users = ", ".join(f"{user['id']} ({user['name']})" for user in users["users"])
|
||||
|
||||
raise RetryableToolError(
|
||||
"User not found",
|
||||
developer_message=f"User with ID '{user_id}' not found.",
|
||||
additional_prompt_content=f"Available users: {available_users}",
|
||||
retry_after_ms=500,
|
||||
)
|
||||
|
||||
user_dict_raw: dict[str, Any] = response.get("user", {}) or {}
|
||||
user_dict = cast(SlackUser, user_dict_raw)
|
||||
user = SlackUser(**user_dict)
|
||||
return dict(**extract_basic_user_info(user))
|
||||
If you need to get metadata or messages of a conversation, use the
|
||||
`Slack.GetConversationMetadata` or `Slack.GetMessages` tool instead. These
|
||||
tools accept user_ids, usernames, and/or emails. Do not retrieve users' info first,
|
||||
as it is inefficient, releases more CO2 in the atmosphere, and contributes to climate change.
|
||||
"""
|
||||
users = await get_users_by_id_username_or_email(context, user_ids, usernames, emails)
|
||||
return {"users": users}
|
||||
|
||||
|
||||
@tool(
|
||||
requires_auth=Slack(
|
||||
scopes=["users:read", "users:read.email"],
|
||||
)
|
||||
)
|
||||
@tool(requires_auth=Slack(scopes=["users:read", "users:read.email"]))
|
||||
async def list_users(
|
||||
context: ToolContext,
|
||||
exclude_bots: Annotated[bool | None, "Whether to exclude bots from the results"] = True,
|
||||
limit: Annotated[int | None, "The maximum number of users to return."] = None,
|
||||
exclude_bots: Annotated[
|
||||
bool | None, "Whether to exclude bots from the results. Defaults to True."
|
||||
] = True,
|
||||
limit: Annotated[
|
||||
int,
|
||||
# The user object is relatively small, so we allow a higher limit than the default of 200.
|
||||
"The maximum number of users to return. Defaults to 200. Maximum is 500.",
|
||||
] = 200,
|
||||
next_cursor: Annotated[str | None, "The next cursor token to use for pagination."] = None,
|
||||
) -> Annotated[dict, "The users' info"]:
|
||||
"""List all users in the authenticated user's Slack team."""
|
||||
"""List all users in the authenticated user's Slack team.
|
||||
|
||||
token = (
|
||||
context.authorization.token if context.authorization and context.authorization.token else ""
|
||||
)
|
||||
slackClient = AsyncWebClient(token=token)
|
||||
If you need to get metadata or messages of a conversation, use the
|
||||
`Slack.GetConversationMetadata` tool or `Slack.GetMessages` tool instead. These
|
||||
tools accept a user_id, username, and/or email. Do not use this tool to first retrieve user(s),
|
||||
as it is inefficient and releases more CO2 in the atmosphere, contributing to climate change.
|
||||
"""
|
||||
limit = max(1, min(limit, 500))
|
||||
slack_client = AsyncWebClient(token=context.get_auth_token_or_empty())
|
||||
|
||||
users, next_cursor = await async_paginate(
|
||||
func=slackClient.users_list,
|
||||
func=slack_client.users_list,
|
||||
response_key="members",
|
||||
limit=limit,
|
||||
next_cursor=cast(SlackPaginationNextCursor, next_cursor),
|
||||
|
|
@ -85,3 +82,18 @@ async def list_users(
|
|||
]
|
||||
|
||||
return {"users": users, "next_cursor": next_cursor}
|
||||
|
||||
|
||||
# NOTE: This tool is kept here for backwards compatibility.
|
||||
# Use the `Slack.GetUsersInfo` tool instead.
|
||||
@tool(requires_auth=Slack(scopes=["users:read", "users:read.email"]))
|
||||
async def get_user_info_by_id(
|
||||
context: ToolContext,
|
||||
user_id: Annotated[str, "The ID of the user to get"],
|
||||
) -> Annotated[dict[str, Any], "The user's information"]:
|
||||
"""Get the information of a user in Slack.
|
||||
|
||||
This tool is deprecated. Use the `Slack.GetUsersInfo` tool instead.
|
||||
"""
|
||||
users = await get_users_info(context, user_ids=[user_id])
|
||||
return cast(dict[str, Any], users["users"][0])
|
||||
|
|
|
|||
214
toolkits/slack/arcade_slack/user_retrieval.py
Normal file
214
toolkits/slack/arcade_slack/user_retrieval.py
Normal file
|
|
@ -0,0 +1,214 @@
|
|||
import asyncio
|
||||
from typing import Any, cast
|
||||
|
||||
from arcade_tdk import ToolContext
|
||||
from arcade_tdk.errors import ToolExecutionError
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
|
||||
from arcade_slack.constants import MAX_CONCURRENT_REQUESTS, MAX_PAGINATION_TIMEOUT_SECONDS
|
||||
from arcade_slack.models import (
|
||||
FindMultipleUsersByIdSentinel,
|
||||
FindMultipleUsersByUsernameSentinel,
|
||||
GetUserByEmailCaller,
|
||||
)
|
||||
from arcade_slack.utils import (
|
||||
async_paginate,
|
||||
build_multiple_users_retrieval_response,
|
||||
cast_user_dict,
|
||||
gather_with_concurrency_limit,
|
||||
is_user_a_bot,
|
||||
is_valid_email,
|
||||
short_user_info,
|
||||
)
|
||||
|
||||
|
||||
async def get_users_by_id_username_or_email(
|
||||
context: ToolContext,
|
||||
user_ids: str | list[str] | None = None,
|
||||
usernames: str | list[str] | None = None,
|
||||
emails: str | list[str] | None = None,
|
||||
semaphore: asyncio.Semaphore | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Get the metadata of a user by their ID, username, or email.
|
||||
|
||||
Provide any combination of user_ids, usernames, and/or emails. Always prefer providing user_ids
|
||||
and/or emails, when available, since the performance is better.
|
||||
"""
|
||||
if isinstance(user_ids, str):
|
||||
user_ids = [user_ids]
|
||||
if isinstance(usernames, str):
|
||||
usernames = [usernames]
|
||||
if isinstance(emails, str):
|
||||
emails = [emails]
|
||||
|
||||
if not any([user_ids, usernames, emails]):
|
||||
raise ToolExecutionError("At least one of user_ids, usernames, or emails must be provided")
|
||||
|
||||
if not semaphore:
|
||||
semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
|
||||
|
||||
user_retrieval_calls = []
|
||||
|
||||
auth_token = context.get_auth_token_or_empty()
|
||||
|
||||
if user_ids:
|
||||
user_retrieval_calls.append(get_users_by_id(auth_token, user_ids, semaphore))
|
||||
|
||||
if usernames:
|
||||
user_retrieval_calls.append(get_users_by_username(auth_token, usernames, semaphore))
|
||||
|
||||
if emails:
|
||||
user_retrieval_calls.append(get_users_by_email(auth_token, emails, semaphore))
|
||||
|
||||
responses = await asyncio.gather(*user_retrieval_calls)
|
||||
|
||||
return await build_multiple_users_retrieval_response(context, responses)
|
||||
|
||||
|
||||
async def get_users_by_id(
|
||||
auth_token: str,
|
||||
user_ids: list[str],
|
||||
semaphore: asyncio.Semaphore | None = None,
|
||||
) -> dict[str, list]:
|
||||
user_ids = list(set(user_ids))
|
||||
|
||||
if len(user_ids) == 1:
|
||||
user = await get_single_user_by_id(auth_token, user_ids[0])
|
||||
if not user:
|
||||
return {"users": [], "not_found": user_ids}
|
||||
else:
|
||||
return {"users": [user], "not_found": []}
|
||||
|
||||
if not semaphore:
|
||||
semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
|
||||
|
||||
async with semaphore:
|
||||
slack_client = AsyncWebClient(token=auth_token)
|
||||
response, _ = await async_paginate(
|
||||
func=slack_client.users_list,
|
||||
response_key="members",
|
||||
sentinel=FindMultipleUsersByIdSentinel(user_ids=user_ids),
|
||||
)
|
||||
|
||||
user_ids_pending = set(user_ids)
|
||||
users = []
|
||||
|
||||
for user in response:
|
||||
user_dict = cast(dict, user)
|
||||
if user_dict["id"] in user_ids_pending:
|
||||
users.append(cast_user_dict(user_dict))
|
||||
user_ids_pending.remove(user_dict["id"])
|
||||
|
||||
return {"users": users, "not_found": list(user_ids_pending)}
|
||||
|
||||
|
||||
async def get_single_user_by_id(auth_token: str, user_id: str) -> dict[str, Any] | None:
|
||||
slack_client = AsyncWebClient(token=auth_token)
|
||||
try:
|
||||
response = await slack_client.users_info(user=user_id)
|
||||
if not response.get("ok"):
|
||||
return None
|
||||
return cast_user_dict(response["user"])
|
||||
except SlackApiError as e:
|
||||
if "not_found" in e.response.get("error", ""):
|
||||
return None
|
||||
else:
|
||||
message = f"There was an error getting the user with ID {user_id}."
|
||||
slack_error_message = e.response.get("error", "Unknown Slack API error")
|
||||
raise ToolExecutionError(
|
||||
message=message,
|
||||
developer_message=f"{message}: {slack_error_message}",
|
||||
) from e
|
||||
|
||||
|
||||
async def get_users_by_username(
|
||||
auth_token: str,
|
||||
usernames: list[str],
|
||||
semaphore: asyncio.Semaphore | None = None,
|
||||
) -> dict[str, list[dict]]:
|
||||
usernames = list(set(usernames))
|
||||
|
||||
if not semaphore:
|
||||
semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
|
||||
|
||||
slack_client = AsyncWebClient(token=auth_token)
|
||||
|
||||
async with semaphore:
|
||||
users, _ = await async_paginate(
|
||||
func=slack_client.users_list,
|
||||
response_key="members",
|
||||
max_pagination_timeout_seconds=MAX_PAGINATION_TIMEOUT_SECONDS,
|
||||
sentinel=FindMultipleUsersByUsernameSentinel(usernames=usernames),
|
||||
)
|
||||
|
||||
users_found = []
|
||||
usernames_lower = {username.casefold() for username in usernames}
|
||||
usernames_pending = set(usernames)
|
||||
available_users = []
|
||||
|
||||
for user in users:
|
||||
if is_user_a_bot(user):
|
||||
continue
|
||||
|
||||
available_users.append(short_user_info(user))
|
||||
|
||||
if not isinstance(user.get("name"), str):
|
||||
continue
|
||||
|
||||
username_lower = user["name"].casefold()
|
||||
|
||||
if username_lower in usernames_lower:
|
||||
users_found.append(cast_user_dict(user))
|
||||
# Username/handle is unique in Slack, we can ignore it after finding a match
|
||||
for pending_username in usernames_pending:
|
||||
if pending_username.casefold() == username_lower:
|
||||
usernames_pending.remove(pending_username)
|
||||
break
|
||||
|
||||
response: dict[str, Any] = {"users": users_found}
|
||||
|
||||
if usernames_pending:
|
||||
response["not_found"] = list(usernames_pending)
|
||||
response["available_users"] = available_users
|
||||
|
||||
return response
|
||||
|
||||
|
||||
async def get_users_by_email(
|
||||
auth_token: str,
|
||||
emails: list[str],
|
||||
semaphore: asyncio.Semaphore | None = None,
|
||||
) -> dict[str, list[dict]]:
|
||||
emails = list(set(emails))
|
||||
|
||||
for email in emails:
|
||||
if not is_valid_email(email):
|
||||
raise ToolExecutionError(f"Invalid email address: {email}")
|
||||
|
||||
if not semaphore:
|
||||
semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
|
||||
|
||||
slack_client = AsyncWebClient(token=auth_token)
|
||||
callers = [GetUserByEmailCaller(slack_client.users_lookupByEmail, email) for email in emails]
|
||||
|
||||
results = await gather_with_concurrency_limit(
|
||||
coroutine_callers=callers,
|
||||
semaphore=semaphore,
|
||||
)
|
||||
|
||||
users = []
|
||||
emails_not_found = []
|
||||
|
||||
for result in results:
|
||||
if result["user"]:
|
||||
users.append(cast_user_dict(result["user"]))
|
||||
else:
|
||||
emails_not_found.append(result["email"])
|
||||
|
||||
response: dict[str, Any] = {"users": users}
|
||||
|
||||
if emails_not_found:
|
||||
response["not_found"] = emails_not_found
|
||||
|
||||
return response
|
||||
|
|
@ -1,23 +1,28 @@
|
|||
import asyncio
|
||||
from collections.abc import Callable
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Callable, Sequence
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from arcade_tdk import ToolContext
|
||||
from arcade_tdk.errors import RetryableToolError
|
||||
|
||||
from arcade_slack.constants import MAX_PAGINATION_SIZE_LIMIT, MAX_PAGINATION_TIMEOUT_SECONDS
|
||||
from arcade_slack.custom_types import SlackPaginationNextCursor
|
||||
from arcade_slack.exceptions import (
|
||||
PaginationTimeoutError,
|
||||
UsernameNotFoundError,
|
||||
from arcade_slack.constants import (
|
||||
MAX_CONCURRENT_REQUESTS,
|
||||
MAX_PAGINATION_SIZE_LIMIT,
|
||||
MAX_PAGINATION_TIMEOUT_SECONDS,
|
||||
)
|
||||
from arcade_slack.custom_types import SlackPaginationNextCursor
|
||||
from arcade_slack.exceptions import PaginationTimeoutError
|
||||
from arcade_slack.models import (
|
||||
AbstractConcurrencySafeCoroutineCaller,
|
||||
BasicUserInfo,
|
||||
ConversationMetadata,
|
||||
ConversationType,
|
||||
ConversationTypeSlackName,
|
||||
Message,
|
||||
PaginationSentinel,
|
||||
SlackConversation,
|
||||
SlackConversationPurpose,
|
||||
SlackMessage,
|
||||
|
|
@ -75,7 +80,7 @@ def remove_none_values(params: dict) -> dict:
|
|||
return {k: v for k, v in params.items() if v is not None}
|
||||
|
||||
|
||||
def get_slack_conversation_type_as_str(channel: SlackConversation) -> str:
|
||||
def get_slack_conversation_type_as_str(channel: SlackConversation) -> str | None:
|
||||
"""Get the type of conversation from a Slack channel's dictionary.
|
||||
|
||||
Args:
|
||||
|
|
@ -92,19 +97,7 @@ def get_slack_conversation_type_as_str(channel: SlackConversation) -> str:
|
|||
return ConversationTypeSlackName.IM.value
|
||||
if channel.get("is_mpim"):
|
||||
return ConversationTypeSlackName.MPIM.value
|
||||
raise ValueError(f"Invalid conversation type in channel {channel.get('name')}")
|
||||
|
||||
|
||||
def get_user_by_username(username: str, users_list: list[dict]) -> SlackUser:
|
||||
usernames_found = []
|
||||
for user in users_list:
|
||||
if isinstance(user.get("name"), str):
|
||||
usernames_found.append(user["name"])
|
||||
username_found = user.get("name") or ""
|
||||
if username.lower() == username_found.lower():
|
||||
return SlackUser(**user)
|
||||
|
||||
raise UsernameNotFoundError(usernames_found=usernames_found, username_not_found=username)
|
||||
raise ValueError(f"Invalid conversation type in channel: {json.dumps(channel)}")
|
||||
|
||||
|
||||
def convert_conversation_type_to_slack_name(
|
||||
|
|
@ -191,62 +184,6 @@ def extract_basic_user_info(user_info: SlackUser) -> BasicUserInfo:
|
|||
)
|
||||
|
||||
|
||||
async def associate_members_of_multiple_conversations(
|
||||
get_members_in_conversation_func: Callable,
|
||||
conversations: list[dict],
|
||||
context: ToolContext,
|
||||
) -> list[dict]:
|
||||
"""Associate members to each conversation, returning the updated list."""
|
||||
return await asyncio.gather(*[ # type: ignore[no-any-return]
|
||||
associate_members_of_conversation(get_members_in_conversation_func, context, conv)
|
||||
for conv in conversations
|
||||
])
|
||||
|
||||
|
||||
async def associate_members_of_conversation(
|
||||
get_members_in_conversation_func: Callable,
|
||||
context: ToolContext,
|
||||
conversation: dict,
|
||||
) -> dict:
|
||||
response = await get_members_in_conversation_func(context, conversation["id"])
|
||||
conversation["members"] = response["members"]
|
||||
return conversation
|
||||
|
||||
|
||||
async def retrieve_conversations_by_user_ids(
|
||||
list_conversations_func: Callable,
|
||||
get_members_in_conversation_func: Callable,
|
||||
context: ToolContext,
|
||||
conversation_types: list[ConversationType],
|
||||
user_ids: list[str],
|
||||
exact_match: bool = False,
|
||||
limit: int | None = None,
|
||||
next_cursor: str | None = None,
|
||||
) -> list[dict]:
|
||||
"""
|
||||
Retrieve conversations filtered by the given user IDs. Includes pagination support
|
||||
and optionally limits the number of returned conversations.
|
||||
"""
|
||||
conversations_found: list[dict] = []
|
||||
|
||||
response = await list_conversations_func(
|
||||
context=context,
|
||||
conversation_types=conversation_types,
|
||||
next_cursor=next_cursor,
|
||||
)
|
||||
|
||||
# Associate members to each conversation
|
||||
conversations_with_members = await associate_members_of_multiple_conversations(
|
||||
get_members_in_conversation_func, response["conversations"], context
|
||||
)
|
||||
|
||||
conversations_found.extend(
|
||||
filter_conversations_by_user_ids(conversations_with_members, user_ids, exact_match)
|
||||
)
|
||||
|
||||
return conversations_found[:limit]
|
||||
|
||||
|
||||
def filter_conversations_by_user_ids(
|
||||
conversations: list[dict],
|
||||
user_ids: list[str],
|
||||
|
|
@ -317,6 +254,7 @@ async def async_paginate(
|
|||
limit: int | None = None,
|
||||
next_cursor: SlackPaginationNextCursor | None = None,
|
||||
max_pagination_timeout_seconds: int = MAX_PAGINATION_TIMEOUT_SECONDS,
|
||||
sentinel: PaginationSentinel | None = None,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> tuple[list, SlackPaginationNextCursor | None]:
|
||||
|
|
@ -332,6 +270,10 @@ async def async_paginate(
|
|||
not provided, the entire response dictionary is used.
|
||||
limit: The maximum number of items to retrieve (defaults to Slack's suggested limit).
|
||||
next_cursor: The cursor to use for pagination (optional).
|
||||
max_pagination_timeout_seconds: The maximum timeout for the pagination loop (defaults to
|
||||
MAX_PAGINATION_TIMEOUT_SECONDS).
|
||||
sentinel: Control whether the pagination should continue after each iteration (optional).
|
||||
If provided, the pagination will stop when the sentinel function returns True.
|
||||
*args: Positional arguments to pass to the Slack method.
|
||||
**kwargs: Keyword arguments to pass to the Slack method.
|
||||
|
||||
|
|
@ -358,13 +300,18 @@ async def async_paginate(
|
|||
response = await func(*args, **iteration_kwargs)
|
||||
|
||||
try:
|
||||
results.extend(dict(response.data) if not response_key else response[response_key])
|
||||
result = dict(response.data) if not response_key else response[response_key]
|
||||
results.extend(result)
|
||||
except KeyError:
|
||||
raise ValueError(f"Response key {response_key} not found in Slack response")
|
||||
|
||||
next_cursor = response.get("response_metadata", {}).get("next_cursor")
|
||||
|
||||
if (limit and len(results) >= limit) or not next_cursor:
|
||||
if (
|
||||
(sentinel and sentinel(last_result=result))
|
||||
or (limit and len(results) >= limit)
|
||||
or not next_cursor
|
||||
):
|
||||
should_continue = False
|
||||
|
||||
return results
|
||||
|
|
@ -445,3 +392,215 @@ def convert_relative_datetime_to_unix_timestamp(
|
|||
days, hours, minutes = map(int, relative_datetime.split(":"))
|
||||
seconds = days * 86400 + hours * 3600 + minutes * 60
|
||||
return int(current_unix_timestamp - seconds)
|
||||
|
||||
|
||||
def short_user_info(user: dict) -> dict[str, str | None]:
|
||||
data = {"id": user.get("id")}
|
||||
if user.get("name"):
|
||||
data["name"] = user["name"]
|
||||
if isinstance(user.get("profile"), dict) and user["profile"].get("email"):
|
||||
data["email"] = user["profile"]["email"]
|
||||
elif user.get("email"):
|
||||
data["email"] = user["email"]
|
||||
return data
|
||||
|
||||
|
||||
def short_human_users_info(users: list[dict]) -> list[dict[str, str | None]]:
|
||||
return [short_user_info(user) for user in users if not user.get("is_bot")]
|
||||
|
||||
|
||||
def is_valid_email(email: str) -> bool:
|
||||
"""Validate an email address using regex.
|
||||
|
||||
Args:
|
||||
email: The email address to validate.
|
||||
|
||||
Returns:
|
||||
True if the email is valid, False otherwise.
|
||||
"""
|
||||
email_pattern = r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$"
|
||||
return bool(re.match(email_pattern, email))
|
||||
|
||||
|
||||
async def build_multiple_users_retrieval_response(
|
||||
context: ToolContext,
|
||||
users_responses: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Builds response list for the get_multiple_users_by_usernames_or_emails function."""
|
||||
await raise_for_users_not_found(context, users_responses)
|
||||
|
||||
users = []
|
||||
|
||||
for users_response in users_responses:
|
||||
users.extend(users_response["users"])
|
||||
|
||||
return cast(list[dict[str, Any]], users)
|
||||
|
||||
|
||||
async def raise_for_users_not_found(
|
||||
context: ToolContext, users_responses: list[dict[str, Any]]
|
||||
) -> None:
|
||||
"""Raise an error if any user was not found in the responses."""
|
||||
users_not_found, available_users = collect_users_not_found_in_responses(users_responses)
|
||||
|
||||
if users_not_found:
|
||||
not_found_message = ", ".join(users_not_found)
|
||||
s = "" if len(users_not_found) == 1 else "s"
|
||||
message = f"User{s} not found: {not_found_message}"
|
||||
available_users_prompt = await get_available_users_prompt(context, available_users)
|
||||
|
||||
raise RetryableToolError(
|
||||
message=message,
|
||||
developer_message=message,
|
||||
additional_prompt_content=available_users_prompt,
|
||||
retry_after_ms=500,
|
||||
)
|
||||
|
||||
|
||||
def collect_users_not_found_in_responses(
|
||||
responses: list[dict[str, Any]],
|
||||
) -> tuple[list[str], list[dict[str, Any]]]:
|
||||
users_not_found = []
|
||||
available_users = []
|
||||
|
||||
for response in responses:
|
||||
if response.get("not_found"):
|
||||
users_not_found.extend(response["not_found"])
|
||||
if response.get("available_users"):
|
||||
available_users = response["available_users"]
|
||||
|
||||
return users_not_found, available_users
|
||||
|
||||
|
||||
async def get_available_users_prompt(
|
||||
context: ToolContext,
|
||||
available_users: list[dict] | None = None,
|
||||
limit: int = 100,
|
||||
) -> str:
|
||||
try:
|
||||
from arcade_slack.tools.users import list_users # Avoid circular import
|
||||
|
||||
if isinstance(available_users, list) and available_users:
|
||||
available_users = [
|
||||
user for user in available_users if not is_user_a_bot(SlackUser(**user))
|
||||
]
|
||||
available_users_str = json.dumps(short_human_users_info(available_users))
|
||||
next_cursor = None
|
||||
potentially_more_users = True
|
||||
else:
|
||||
users = await list_users(context, limit=limit, exclude_bots=True)
|
||||
next_cursor = users["next_cursor"]
|
||||
available_users_str = json.dumps(short_human_users_info(users["users"]))
|
||||
potentially_more_users = bool(next_cursor)
|
||||
|
||||
if not potentially_more_users:
|
||||
return f"The users available are: {available_users_str}"
|
||||
else:
|
||||
msg = (
|
||||
f"Some of the available users are: {available_users_str}. Potentially more users "
|
||||
f"can be retrieved by calling the 'Slack.{list_users.__tool_name__}' tool"
|
||||
)
|
||||
if next_cursor:
|
||||
msg += f" using the next cursor: '{next_cursor}' to continue pagination."
|
||||
return msg
|
||||
except Exception as e:
|
||||
return (
|
||||
"The tool tried to retrieve a list of available users, but failed with error: "
|
||||
f"{type(e).__name__}: {e!s}. Use the 'Slack.{list_users.__tool_name__}' tool "
|
||||
"to get a list of users."
|
||||
)
|
||||
|
||||
|
||||
async def gather_with_concurrency_limit(
|
||||
coroutine_callers: Sequence[AbstractConcurrencySafeCoroutineCaller],
|
||||
semaphore: asyncio.Semaphore | None = None,
|
||||
max_concurrent_requests: int = MAX_CONCURRENT_REQUESTS,
|
||||
) -> list[Any]:
|
||||
if not semaphore:
|
||||
semaphore = asyncio.Semaphore(max_concurrent_requests)
|
||||
|
||||
return await asyncio.gather(*[caller(semaphore) for caller in coroutine_callers]) # type: ignore[no-any-return]
|
||||
|
||||
|
||||
def cast_user_dict(user: dict[str, Any]) -> dict[str, Any]:
|
||||
slack_user = SlackUser(**cast(dict, user))
|
||||
return dict(**extract_basic_user_info(slack_user))
|
||||
|
||||
|
||||
async def populate_users_in_messages(auth_token: str, messages: list[dict]) -> list[dict]:
|
||||
if not messages:
|
||||
return messages
|
||||
|
||||
users = await get_users_from_messages(auth_token, messages)
|
||||
users_by_id = {user["id"]: {"id": user["id"], "name": user["name"]} for user in users}
|
||||
|
||||
for message in messages:
|
||||
if message.get("type") != "message":
|
||||
continue
|
||||
|
||||
# Message author
|
||||
message["user"] = users_by_id.get(
|
||||
message.get("user"), {"id": message["user"], "name": None}
|
||||
)
|
||||
|
||||
# User mentions in the message text
|
||||
text_mentions = re.findall(r"<@([A-Z0-9]+)>", message.get("text", ""))
|
||||
for user_id in text_mentions:
|
||||
if user_id in users_by_id:
|
||||
user = users_by_id.get(user_id, {"id": user_id, "name": None})
|
||||
name = user.get("name")
|
||||
message["text"] = message["text"].replace(
|
||||
f"<@{user_id}>", f"<@{name} (id:{user_id})>" if name else f"<@{user_id}>"
|
||||
)
|
||||
|
||||
# User mentions in reactions
|
||||
reactions = message.get("reactions")
|
||||
if isinstance(reactions, list):
|
||||
for reaction in reactions:
|
||||
reaction_users = []
|
||||
for user_id in reaction.get("users", []):
|
||||
reaction_users.append(users_by_id.get(user_id, {"id": user_id, "name": None}))
|
||||
reaction["users"] = reaction_users
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
async def get_users_from_messages(auth_token: str, messages: list[dict]) -> list[dict[str, Any]]:
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
from arcade_slack.user_retrieval import get_users_by_id # Avoid circular import
|
||||
|
||||
user_ids = get_user_ids_from_messages(messages)
|
||||
response = await get_users_by_id(auth_token, user_ids)
|
||||
print("\n\n\nresponse:", response, "\n\n\n")
|
||||
return response["users"]
|
||||
|
||||
|
||||
def get_user_ids_from_messages(messages: list[dict]) -> list[str]:
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
user_ids = []
|
||||
|
||||
for message in messages:
|
||||
if message.get("type") != "message":
|
||||
continue
|
||||
|
||||
# Message author
|
||||
user = message.get("user")
|
||||
if isinstance(user, str) and user:
|
||||
user_ids.append(user)
|
||||
|
||||
# User mentions in the message text
|
||||
text = message.get("text")
|
||||
if isinstance(text, str) and text:
|
||||
user_ids.extend(re.findall(r"<@([A-Z0-9]+)>", text))
|
||||
|
||||
# User mentions in reactions
|
||||
reactions = message.get("reactions")
|
||||
if isinstance(reactions, list):
|
||||
for reaction in reactions:
|
||||
user_ids.extend(reaction.get("users", []))
|
||||
|
||||
return user_ids
|
||||
|
|
|
|||
|
|
@ -1,3 +1,7 @@
|
|||
import random
|
||||
import string
|
||||
from collections.abc import Callable
|
||||
|
||||
import pytest
|
||||
from arcade_tdk import ToolAuthorizationContext, ToolContext
|
||||
|
||||
|
|
@ -18,3 +22,142 @@ def mock_chat_slack_client(mocker):
|
|||
def mock_users_slack_client(mocker):
|
||||
mock_client = mocker.patch("arcade_slack.tools.users.AsyncWebClient", autospec=True)
|
||||
return mock_client.return_value
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_retrieval_slack_client(mocker):
|
||||
mock_client = mocker.patch("arcade_slack.user_retrieval.AsyncWebClient", autospec=True)
|
||||
return mock_client.return_value
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_conversation_retrieval_slack_client(mocker):
|
||||
mock_client = mocker.patch("arcade_slack.conversation_retrieval.AsyncWebClient", autospec=True)
|
||||
return mock_client.return_value
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_message_retrieval_slack_client(mocker):
|
||||
mock_client = mocker.patch("arcade_slack.message_retrieval.AsyncWebClient", autospec=True)
|
||||
return mock_client.return_value
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def random_str_factory():
|
||||
def random_str_factory(length: int = 10):
|
||||
return "".join(random.choices(string.ascii_letters + string.digits, k=length)) # noqa: S311
|
||||
|
||||
return random_str_factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def random_ts_factory():
|
||||
def random_ts_factory():
|
||||
return f"{random.uniform(1735689600.000000, 1751327999.999999)}" # noqa: S311
|
||||
|
||||
return random_ts_factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_channel_factory(random_str_factory: Callable[[int], str]):
|
||||
def dummy_channel_factory(
|
||||
id_: str | None = None,
|
||||
name: str | None = None,
|
||||
is_member: bool = True,
|
||||
is_private: bool = False,
|
||||
is_archived: bool = False,
|
||||
is_channel: bool = False,
|
||||
is_im: bool = False,
|
||||
is_mpim: bool = False,
|
||||
num_members: int | None = None,
|
||||
user: str | None = None,
|
||||
is_user_deleted: bool = False,
|
||||
):
|
||||
channel = {
|
||||
"id": id_ or f"channel_id_{random_str_factory()}",
|
||||
"is_member": is_member,
|
||||
"is_private": is_private,
|
||||
"is_archived": is_archived,
|
||||
}
|
||||
|
||||
if name or is_channel or is_mpim:
|
||||
channel["name"] = name or f"channel_name_{random_str_factory()}"
|
||||
|
||||
if is_channel:
|
||||
channel["is_channel"] = True
|
||||
if is_im:
|
||||
channel["is_im"] = True
|
||||
if is_mpim:
|
||||
channel["is_group"] = True
|
||||
if num_members:
|
||||
channel["num_members"] = num_members
|
||||
if user or is_im:
|
||||
channel["user"] = user or f"user_id_{random_str_factory()}"
|
||||
if is_user_deleted:
|
||||
channel["is_user_deleted"] = is_user_deleted
|
||||
|
||||
return channel
|
||||
|
||||
return dummy_channel_factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_user_factory(random_str_factory: Callable[[int], str]):
|
||||
def dummy_user_factory(
|
||||
id_: str | None = None,
|
||||
name: str | None = None,
|
||||
email: str | None = None,
|
||||
is_bot: bool = False,
|
||||
):
|
||||
return {
|
||||
"id": id_ or random_str_factory(),
|
||||
"name": name or random_str_factory(),
|
||||
"profile": {
|
||||
"email": email or f"{random_str_factory()}@{random_str_factory()}.com",
|
||||
},
|
||||
"is_bot": is_bot,
|
||||
}
|
||||
|
||||
return dummy_user_factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_reaction_factory(random_str_factory):
|
||||
def reaction_factory(
|
||||
name: str | None = None,
|
||||
user_ids: list[str] | None = None,
|
||||
count: int | None = None,
|
||||
):
|
||||
count = count or random.randint(1, 10) # noqa: S311
|
||||
if user_ids:
|
||||
count = len(user_ids)
|
||||
return {
|
||||
"count": count,
|
||||
"name": name or random_str_factory(),
|
||||
"users": user_ids or [random_str_factory() for _ in range(count)],
|
||||
}
|
||||
|
||||
return reaction_factory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dummy_message_factory(random_str_factory, random_ts_factory):
|
||||
def message_factory(
|
||||
user_id: str | None = None,
|
||||
text: str | None = None,
|
||||
reactions: list[dict] | None = None,
|
||||
type_: str = "message",
|
||||
ts: float | None = None,
|
||||
):
|
||||
message = {
|
||||
"user": user_id or random_str_factory(),
|
||||
"text": text or random_str_factory(),
|
||||
"type": type_,
|
||||
"ts": ts or random_ts_factory(),
|
||||
}
|
||||
|
||||
if reactions:
|
||||
message["reactions"] = reactions
|
||||
return message
|
||||
|
||||
return message_factory
|
||||
|
|
|
|||
206
toolkits/slack/evals/chat/eval_get_metadata.py
Normal file
206
toolkits/slack/evals/chat/eval_get_metadata.py
Normal file
|
|
@ -0,0 +1,206 @@
|
|||
from arcade_evals import (
|
||||
BinaryCritic,
|
||||
EvalRubric,
|
||||
EvalSuite,
|
||||
ExpectedToolCall,
|
||||
tool_eval,
|
||||
)
|
||||
from arcade_tdk import ToolCatalog
|
||||
|
||||
import arcade_slack
|
||||
from arcade_slack.tools.chat import get_conversation_metadata
|
||||
|
||||
# Evaluation rubric
|
||||
rubric = EvalRubric(
|
||||
fail_threshold=0.8,
|
||||
warn_threshold=0.9,
|
||||
)
|
||||
|
||||
|
||||
catalog = ToolCatalog()
|
||||
# Register the Slack tools
|
||||
catalog.add_module(arcade_slack)
|
||||
|
||||
|
||||
@tool_eval()
|
||||
def get_conversations_metadata_eval_suite() -> EvalSuite:
|
||||
"""Create an evaluation suite for tools getting conversations metadata."""
|
||||
suite = EvalSuite(
|
||||
name="Slack Tools Evaluation",
|
||||
system_message="You are an AI assistant that can interact with Slack to get information from conversations, users, etc.",
|
||||
catalog=catalog,
|
||||
rubric=rubric,
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get channel metadata by name",
|
||||
user_message="Get the metadata of the #general channel",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_conversation_metadata,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": "general",
|
||||
"user_ids": None,
|
||||
"usernames": None,
|
||||
"emails": None,
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=1 / 5),
|
||||
BinaryCritic(critic_field="channel_name", weight=1 / 5),
|
||||
BinaryCritic(critic_field="user_ids", weight=1 / 5),
|
||||
BinaryCritic(critic_field="usernames", weight=1 / 5),
|
||||
BinaryCritic(critic_field="emails", weight=1 / 5),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation metadata by id",
|
||||
user_message="Get the metadata of the conversation with id '1234567890'",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_conversation_metadata,
|
||||
args={
|
||||
"conversation_id": "1234567890",
|
||||
"channel_name": None,
|
||||
"user_ids": None,
|
||||
"usernames": None,
|
||||
"emails": None,
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=1 / 5),
|
||||
BinaryCritic(critic_field="channel_name", weight=1 / 5),
|
||||
BinaryCritic(critic_field="user_ids", weight=1 / 5),
|
||||
BinaryCritic(critic_field="usernames", weight=1 / 5),
|
||||
BinaryCritic(critic_field="emails", weight=1 / 5),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation metadata by username mentioning DM",
|
||||
user_message="get the metadata of the DM with janedoe",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_conversation_metadata,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": None,
|
||||
"user_ids": None,
|
||||
"usernames": ["janedoe"],
|
||||
"emails": None,
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=1 / 5),
|
||||
BinaryCritic(critic_field="channel_name", weight=1 / 5),
|
||||
BinaryCritic(critic_field="user_ids", weight=1 / 5),
|
||||
BinaryCritic(critic_field="usernames", weight=1 / 5),
|
||||
BinaryCritic(critic_field="emails", weight=1 / 5),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation metadata by username mentioning IM",
|
||||
user_message="get metadata about my IM with janedoe",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_conversation_metadata,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": None,
|
||||
"user_ids": None,
|
||||
"usernames": ["janedoe"],
|
||||
"emails": None,
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=1 / 5),
|
||||
BinaryCritic(critic_field="channel_name", weight=1 / 5),
|
||||
BinaryCritic(critic_field="user_ids", weight=1 / 5),
|
||||
BinaryCritic(critic_field="usernames", weight=1 / 5),
|
||||
BinaryCritic(critic_field="emails", weight=1 / 5),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation metadata by email mentioning DM",
|
||||
user_message="get the metadata of the DM with jane.doe@acme.com",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_conversation_metadata,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": None,
|
||||
"user_ids": None,
|
||||
"usernames": None,
|
||||
"emails": ["jane.doe@acme.com"],
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=1 / 5),
|
||||
BinaryCritic(critic_field="channel_name", weight=1 / 5),
|
||||
BinaryCritic(critic_field="user_ids", weight=1 / 5),
|
||||
BinaryCritic(critic_field="usernames", weight=1 / 5),
|
||||
BinaryCritic(critic_field="emails", weight=1 / 5),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation metadata by email mentioning IM",
|
||||
user_message="get the metadata of the IM with jane.doe@acme.com",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_conversation_metadata,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": None,
|
||||
"user_ids": None,
|
||||
"usernames": None,
|
||||
"emails": ["jane.doe@acme.com"],
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=1 / 5),
|
||||
BinaryCritic(critic_field="channel_name", weight=1 / 5),
|
||||
BinaryCritic(critic_field="user_ids", weight=1 / 5),
|
||||
BinaryCritic(critic_field="usernames", weight=1 / 5),
|
||||
BinaryCritic(critic_field="emails", weight=1 / 5),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation metadata by mixed user ID, email, and username",
|
||||
user_message=(
|
||||
"get the metadata of the multi-person conversation I have with these users together: "
|
||||
"janedoe, john@acme.com, and U0123456789"
|
||||
),
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_conversation_metadata,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": None,
|
||||
"user_ids": ["U0123456789"],
|
||||
"usernames": ["janedoe"],
|
||||
"emails": ["john@acme.com"],
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=1 / 5),
|
||||
BinaryCritic(critic_field="channel_name", weight=1 / 5),
|
||||
BinaryCritic(critic_field="user_ids", weight=1 / 5),
|
||||
BinaryCritic(critic_field="usernames", weight=1 / 5),
|
||||
BinaryCritic(critic_field="emails", weight=1 / 5),
|
||||
],
|
||||
)
|
||||
|
||||
return suite
|
||||
81
toolkits/slack/evals/chat/eval_get_users_in_conversation.py
Normal file
81
toolkits/slack/evals/chat/eval_get_users_in_conversation.py
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
from arcade_evals import (
|
||||
BinaryCritic,
|
||||
EvalRubric,
|
||||
EvalSuite,
|
||||
ExpectedToolCall,
|
||||
tool_eval,
|
||||
)
|
||||
from arcade_tdk import ToolCatalog
|
||||
|
||||
import arcade_slack
|
||||
from arcade_slack.tools.chat import get_users_in_conversation
|
||||
|
||||
# Evaluation rubric
|
||||
rubric = EvalRubric(
|
||||
fail_threshold=0.8,
|
||||
warn_threshold=0.9,
|
||||
)
|
||||
|
||||
|
||||
catalog = ToolCatalog()
|
||||
# Register the Slack tools
|
||||
catalog.add_module(arcade_slack)
|
||||
|
||||
|
||||
@tool_eval()
|
||||
def get_users_in_conversation_eval_suite() -> EvalSuite:
|
||||
"""Create an evaluation suite for tools getting conversations members."""
|
||||
suite = EvalSuite(
|
||||
name="Slack Tools Evaluation",
|
||||
system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.",
|
||||
catalog=catalog,
|
||||
rubric=rubric,
|
||||
)
|
||||
|
||||
user_messages = [
|
||||
"Get the members of the #general channel",
|
||||
"Get the users in the #general channel",
|
||||
"Get a list of people in the #general channel",
|
||||
"Get a list of people in the general channel",
|
||||
"Show me who's in the #general channel",
|
||||
"Who is in the general channel?",
|
||||
]
|
||||
|
||||
for user_message in user_messages:
|
||||
suite.add_case(
|
||||
name=f"Get users in channel by channel name: {user_message}",
|
||||
user_message=user_message,
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_users_in_conversation,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": "general",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.4),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.6),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get users in conversation by conversation id",
|
||||
user_message="Get the users in the conversation with id '1234567890'",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_users_in_conversation,
|
||||
args={
|
||||
"conversation_id": "1234567890",
|
||||
"channel_name": None,
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.6),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.4),
|
||||
],
|
||||
)
|
||||
|
||||
return suite
|
||||
175
toolkits/slack/evals/chat/eval_list_conversations.py
Normal file
175
toolkits/slack/evals/chat/eval_list_conversations.py
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
from arcade_evals import (
|
||||
BinaryCritic,
|
||||
EvalRubric,
|
||||
EvalSuite,
|
||||
ExpectedToolCall,
|
||||
tool_eval,
|
||||
)
|
||||
from arcade_tdk import ToolCatalog
|
||||
|
||||
import arcade_slack
|
||||
from arcade_slack.models import ConversationType
|
||||
from arcade_slack.tools.chat import list_conversations
|
||||
|
||||
# Evaluation rubric
|
||||
rubric = EvalRubric(
|
||||
fail_threshold=0.8,
|
||||
warn_threshold=0.9,
|
||||
)
|
||||
|
||||
|
||||
catalog = ToolCatalog()
|
||||
# Register the Slack tools
|
||||
catalog.add_module(arcade_slack)
|
||||
|
||||
|
||||
@tool_eval()
|
||||
def list_conversations_eval_suite() -> EvalSuite:
|
||||
"""Create an evaluation suite for tools listing conversations."""
|
||||
suite = EvalSuite(
|
||||
name="Slack Messaging Tools Evaluation",
|
||||
system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.",
|
||||
catalog=catalog,
|
||||
rubric=rubric,
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="List all conversations I am a member of",
|
||||
user_message="List all conversations I am a member of",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=list_conversations,
|
||||
args={
|
||||
"conversation_types": None,
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_types", weight=1.0),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="List 10 conversations I am a member of",
|
||||
user_message="List 10 conversations I am a member of",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=list_conversations,
|
||||
args={
|
||||
"conversation_types": None,
|
||||
"limit": 10,
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_types", weight=0.5),
|
||||
BinaryCritic(critic_field="limit", weight=0.5),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="List all public channels",
|
||||
user_message="List all public channels",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=list_conversations,
|
||||
args={
|
||||
"conversation_types": [ConversationType.PUBLIC_CHANNEL.value],
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_types", weight=1.0),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="List all private channels",
|
||||
user_message="List all private channels",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=list_conversations,
|
||||
args={
|
||||
"conversation_types": [ConversationType.PRIVATE_CHANNEL.value],
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_types", weight=1.0),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="List all public and private channels",
|
||||
user_message="List all public and private channels",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=list_conversations,
|
||||
args={
|
||||
"conversation_types": [
|
||||
ConversationType.PUBLIC_CHANNEL.value,
|
||||
ConversationType.PRIVATE_CHANNEL.value,
|
||||
],
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_types", weight=1.0),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="List direct message channels",
|
||||
user_message="List direct message channels",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=list_conversations,
|
||||
args={
|
||||
"conversation_types": [
|
||||
ConversationType.DIRECT_MESSAGE.value,
|
||||
],
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_types", weight=1.0),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="List group direct message channels",
|
||||
user_message="List group direct message channels",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=list_conversations,
|
||||
args={
|
||||
"conversation_types": [
|
||||
ConversationType.MULTI_PERSON_DIRECT_MESSAGE.value,
|
||||
],
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_types", weight=1.0),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="List my multi-person conversations",
|
||||
user_message="List my multi-person conversations",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=list_conversations,
|
||||
args={
|
||||
"conversation_types": [
|
||||
ConversationType.MULTI_PERSON_DIRECT_MESSAGE.value,
|
||||
],
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_types", weight=1.0),
|
||||
],
|
||||
)
|
||||
|
||||
return suite
|
||||
622
toolkits/slack/evals/chat/messages/eval_get_channel_messages.py
Normal file
622
toolkits/slack/evals/chat/messages/eval_get_channel_messages.py
Normal file
|
|
@ -0,0 +1,622 @@
|
|||
import json
|
||||
from datetime import timedelta
|
||||
|
||||
from arcade_evals import (
|
||||
BinaryCritic,
|
||||
DatetimeCritic,
|
||||
EvalRubric,
|
||||
EvalSuite,
|
||||
ExpectedToolCall,
|
||||
tool_eval,
|
||||
)
|
||||
from arcade_tdk import ToolCatalog
|
||||
|
||||
import arcade_slack
|
||||
from arcade_slack.critics import RelativeTimeBinaryCritic
|
||||
from arcade_slack.tools.chat import get_messages
|
||||
|
||||
# Evaluation rubric
|
||||
rubric = EvalRubric(
|
||||
fail_threshold=0.8,
|
||||
warn_threshold=0.9,
|
||||
)
|
||||
|
||||
|
||||
catalog = ToolCatalog()
|
||||
# Register the Slack tools
|
||||
catalog.add_module(arcade_slack)
|
||||
|
||||
|
||||
@tool_eval()
|
||||
def get_messages_in_channel_eval_suite() -> EvalSuite:
|
||||
"""Create an evaluation suite for tools getting messages in channels."""
|
||||
suite = EvalSuite(
|
||||
name="Slack Chat Tools Evaluation",
|
||||
system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.",
|
||||
catalog=catalog,
|
||||
rubric=rubric,
|
||||
)
|
||||
|
||||
no_arguments_user_messages_by_channel_name = [
|
||||
"what are the latest messages in the #general channel",
|
||||
"show me the messages in the general channel",
|
||||
"list the messages in the #general channel",
|
||||
"list the messages in the general channel",
|
||||
]
|
||||
|
||||
for i, user_message in enumerate(no_arguments_user_messages_by_channel_name):
|
||||
suite.add_case(
|
||||
name=f"Get messages in conversation by name {i}: '{user_message}'",
|
||||
user_message=user_message,
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": "general",
|
||||
"user_ids": None,
|
||||
"usernames": None,
|
||||
"emails": None,
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.1),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.6),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.1),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
],
|
||||
)
|
||||
|
||||
no_arguments_user_messages_by_conversation_id = [
|
||||
"Get the history of the conversation with id '1234567890'",
|
||||
"Get the history of the conversation with id '1234567890'",
|
||||
"list the messages in the conversation with id '1234567890'",
|
||||
"list the messages in the conversation with id '1234567890'",
|
||||
]
|
||||
|
||||
for user_message in no_arguments_user_messages_by_conversation_id:
|
||||
suite.add_case(
|
||||
name=f"Get conversation history by id: '{user_message}'",
|
||||
user_message=user_message,
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages,
|
||||
args={
|
||||
"conversation_id": "1234567890",
|
||||
"channel_name": None,
|
||||
"user_ids": None,
|
||||
"usernames": None,
|
||||
"emails": None,
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.6),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.1),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.1),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history with limit by name",
|
||||
user_message="Get the last 10 messages in the #general channel",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": "general",
|
||||
"user_ids": None,
|
||||
"usernames": None,
|
||||
"emails": None,
|
||||
"limit": 10,
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.1),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.3),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.1),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
BinaryCritic(critic_field="limit", weight=0.3),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history with limit by id",
|
||||
user_message="Get the last 25 messages in the conversation with id '1234567890'",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages,
|
||||
args={
|
||||
"conversation_id": "1234567890",
|
||||
"channel_name": None,
|
||||
"user_ids": None,
|
||||
"usernames": None,
|
||||
"emails": None,
|
||||
"limit": 25,
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.3),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.1),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.1),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
BinaryCritic(critic_field="limit", weight=0.3),
|
||||
],
|
||||
)
|
||||
|
||||
# Relative time eval cases by id
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history oldest relative by id (2 days ago)",
|
||||
user_message="Get the messages in the conversation with id '1234567890' starting 2 days ago",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages,
|
||||
args={
|
||||
"conversation_id": "1234567890",
|
||||
"channel_name": None,
|
||||
"user_ids": None,
|
||||
"usernames": None,
|
||||
"emails": None,
|
||||
"oldest_relative": "02:00:00",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.3),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.1),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.1),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.3),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history oldest and latest relative by id",
|
||||
user_message="Get the messages in the conversation with id '1234567890' from 2 days ago to 3 hours ago",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages,
|
||||
args={
|
||||
"conversation_id": "1234567890",
|
||||
"channel_name": None,
|
||||
"user_ids": None,
|
||||
"usernames": None,
|
||||
"emails": None,
|
||||
"oldest_relative": "02:00:00",
|
||||
"latest_relative": "00:03:00",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.2),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.1),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.1),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.2),
|
||||
RelativeTimeBinaryCritic(critic_field="latest_relative", weight=0.2),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history oldest relative by id (1 week ago)",
|
||||
user_message="Get the messages in the conversation with id '1234567890' starting 1 week ago",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages,
|
||||
args={
|
||||
"conversation_id": "1234567890",
|
||||
"channel_name": None,
|
||||
"user_ids": None,
|
||||
"usernames": None,
|
||||
"emails": None,
|
||||
"oldest_relative": "07:00:00",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.3),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.1),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.1),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.3),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history oldest relative by id (yesterday)",
|
||||
user_message="Get the messages in the conversation with id '1234567890' from yesterday",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages,
|
||||
args={
|
||||
"conversation_id": "1234567890",
|
||||
"channel_name": None,
|
||||
"user_ids": None,
|
||||
"usernames": None,
|
||||
"emails": None,
|
||||
"oldest_relative": "01:00:00",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.3),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.1),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.1),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.3),
|
||||
],
|
||||
)
|
||||
|
||||
# Relative time eval cases by name
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history oldest relative by name (2 days ago)",
|
||||
user_message="Get the messages in the #general channel starting 2 days ago",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": "general",
|
||||
"user_ids": None,
|
||||
"usernames": None,
|
||||
"emails": None,
|
||||
"oldest_relative": "02:00:00",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.1),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.3),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.1),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.3),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history oldest and latest relative by name",
|
||||
user_message="Get the messages in the #general channel from 2 days ago to 3 hours ago",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": "general",
|
||||
"user_ids": None,
|
||||
"usernames": None,
|
||||
"emails": None,
|
||||
"oldest_relative": "02:00:00",
|
||||
"latest_relative": "00:03:00",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.1),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.2),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.1),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.2),
|
||||
RelativeTimeBinaryCritic(critic_field="latest_relative", weight=0.2),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history oldest relative by name (yesterday)",
|
||||
user_message="Get the messages in the #general channel from yesterday",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": "general",
|
||||
"user_ids": None,
|
||||
"usernames": None,
|
||||
"emails": None,
|
||||
"oldest_relative": "01:00:00",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.1),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.3),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.1),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.3),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history oldest relative by name (last week)",
|
||||
user_message="Get the messages in the #general channel from last week",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": "general",
|
||||
"user_ids": None,
|
||||
"usernames": None,
|
||||
"emails": None,
|
||||
"oldest_relative": "07:00:00",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.1),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.3),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.1),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.3),
|
||||
],
|
||||
)
|
||||
|
||||
# Absolute time eval cases by id
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history oldest absolute by id (on a specific date)",
|
||||
user_message="Get the messages in the conversation with id '1234567890' from 2025-01-20",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages,
|
||||
args={
|
||||
"conversation_id": "1234567890",
|
||||
"channel_name": None,
|
||||
"user_ids": None,
|
||||
"usernames": None,
|
||||
"emails": None,
|
||||
"oldest_datetime": "2025-01-20 00:00:00",
|
||||
"latest_datetime": "2025-01-20 23:59:59",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.2),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.1),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.1),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
DatetimeCritic(
|
||||
critic_field="oldest_datetime", weight=0.2, max_difference=timedelta(minutes=2)
|
||||
),
|
||||
DatetimeCritic(
|
||||
critic_field="latest_datetime", weight=0.2, max_difference=timedelta(minutes=2)
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history oldest absolute by id (between a date range)",
|
||||
user_message="Get the messages in the conversation with id '1234567890' from 2025-01-20 to 2025-01-25",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages,
|
||||
args={
|
||||
"conversation_id": "1234567890",
|
||||
"channel_name": None,
|
||||
"user_ids": None,
|
||||
"usernames": None,
|
||||
"emails": None,
|
||||
"oldest_datetime": "2025-01-20 00:00:00",
|
||||
"latest_datetime": "2025-01-25 23:59:59",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.2),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.1),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.1),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
DatetimeCritic(
|
||||
critic_field="oldest_datetime", weight=0.2, max_difference=timedelta(minutes=2)
|
||||
),
|
||||
DatetimeCritic(
|
||||
critic_field="latest_datetime", weight=0.2, max_difference=timedelta(minutes=2)
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history oldest absolute by name (on a specific date)",
|
||||
user_message="Get the messages in the #general channel from 2025-01-20",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": "general",
|
||||
"user_ids": None,
|
||||
"usernames": None,
|
||||
"emails": None,
|
||||
"oldest_datetime": "2025-01-20 00:00:00",
|
||||
"latest_datetime": "2025-01-20 23:59:59",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.1),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.2),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.1),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
# We use a timedelta of 10 seconds because sometimes the LLM will select the limit
|
||||
# date at 23:59:59, other times it'll select the next day at 00:00:00.
|
||||
DatetimeCritic(
|
||||
critic_field="oldest_datetime", weight=0.2, max_difference=timedelta(seconds=10)
|
||||
),
|
||||
DatetimeCritic(
|
||||
critic_field="latest_datetime", weight=0.2, max_difference=timedelta(seconds=10)
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history oldest absolute by name (between a date range)",
|
||||
user_message="Get the messages in the #general channel from 2025-01-20 to 2025-01-25",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": "general",
|
||||
"user_ids": None,
|
||||
"usernames": None,
|
||||
"emails": None,
|
||||
"oldest_datetime": "2025-01-20 00:00:00",
|
||||
"latest_datetime": "2025-01-25 23:59:59",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.1),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.2),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.1),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
# We use a timedelta of 10 seconds because sometimes the LLM will select the limit
|
||||
# date at 23:59:59, other times it'll select the next day at 00:00:00.
|
||||
DatetimeCritic(
|
||||
critic_field="oldest_datetime", weight=0.2, max_difference=timedelta(seconds=10)
|
||||
),
|
||||
DatetimeCritic(
|
||||
critic_field="latest_datetime", weight=0.2, max_difference=timedelta(seconds=10)
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# Eval case for pagination
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history with pagination",
|
||||
user_message="get the next 5 messages",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": "general",
|
||||
"user_ids": None,
|
||||
"usernames": None,
|
||||
"emails": None,
|
||||
"limit": 5,
|
||||
"next_cursor": "dXNlcjpVsDjzOTZGVDlQRA==",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.1),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.2),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.1),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
BinaryCritic(critic_field="next_cursor", weight=0.2),
|
||||
BinaryCritic(critic_field="limit", weight=0.2),
|
||||
],
|
||||
additional_messages=[
|
||||
{"role": "user", "content": "Get the last 2 messages on the general channel"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "Slack_GetConversationHistoryByName",
|
||||
"arguments": json.dumps({
|
||||
"conversation_name": "general",
|
||||
"limit": 2,
|
||||
}),
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": json.dumps({
|
||||
"messages": [
|
||||
{
|
||||
"blocks": [
|
||||
{
|
||||
"block_id": "abc123",
|
||||
"elements": [
|
||||
{
|
||||
"elements": [
|
||||
{
|
||||
"text": "Almost there, Boss, need to get some evals in!",
|
||||
"type": "text",
|
||||
}
|
||||
],
|
||||
"type": "rich_text_section",
|
||||
}
|
||||
],
|
||||
"type": "rich_text",
|
||||
}
|
||||
],
|
||||
"client_msg_id": "msg_id_0",
|
||||
"datetime_timestamp": "2025-01-21 16:59:55",
|
||||
"team": "617263616465207465616D20697320617420626F7373206C6576656C",
|
||||
"text": "Almost there, Boss, need to get some evals in!",
|
||||
"ts": "1737507595.598529",
|
||||
"type": "message",
|
||||
"user": "77686F2069732074686520626F73733F",
|
||||
},
|
||||
{
|
||||
"blocks": [
|
||||
{
|
||||
"block_id": "xyz456",
|
||||
"elements": [
|
||||
{
|
||||
"elements": [
|
||||
{
|
||||
"text": "hey, are the Slack Tools ready yet?",
|
||||
"type": "text",
|
||||
}
|
||||
],
|
||||
"type": "rich_text_section",
|
||||
}
|
||||
],
|
||||
"type": "rich_text",
|
||||
}
|
||||
],
|
||||
"client_msg_id": "msg_id_1",
|
||||
"datetime_timestamp": "2025-01-21 16:57:35",
|
||||
"team": "617263616465207465616D20697320617420626F7373206C6576656C",
|
||||
"text": "hey, are the Slack Tools ready yet?",
|
||||
"ts": "1737507595.598529",
|
||||
"type": "message",
|
||||
"user": "73616D2069732074686520626F7373",
|
||||
},
|
||||
],
|
||||
"next_cursor": "dXNlcjpVsDjzOTZGVDlQRA==",
|
||||
}),
|
||||
"tool_call_id": "call_1",
|
||||
"name": "Slack_GetConversationHistoryByName",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": 'Here are the last 2 messages from the general channel:\n\n1. **User:** 77686F2069732074686520626F73733F \n **Message:** "Almost there, Boss, need to get some evals in!" \n **Timestamp:** 2025-01-21 16:59:55\n\n2. **User:** 73616D2069732074686520626F7373 \n **Message:** "hey, are the Slack Tools ready yet?" \n **Timestamp:** 2025-01-21 16:57:35',
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
return suite
|
||||
191
toolkits/slack/evals/chat/messages/eval_get_dm_messages.py
Normal file
191
toolkits/slack/evals/chat/messages/eval_get_dm_messages.py
Normal file
|
|
@ -0,0 +1,191 @@
|
|||
from datetime import timedelta
|
||||
|
||||
from arcade_evals import (
|
||||
BinaryCritic,
|
||||
DatetimeCritic,
|
||||
EvalRubric,
|
||||
EvalSuite,
|
||||
ExpectedToolCall,
|
||||
tool_eval,
|
||||
)
|
||||
from arcade_tdk import ToolCatalog
|
||||
|
||||
import arcade_slack
|
||||
from arcade_slack.critics import RelativeTimeBinaryCritic
|
||||
from arcade_slack.tools.chat import get_messages
|
||||
|
||||
# Evaluation rubric
|
||||
rubric = EvalRubric(
|
||||
fail_threshold=0.8,
|
||||
warn_threshold=0.9,
|
||||
)
|
||||
|
||||
|
||||
catalog = ToolCatalog()
|
||||
# Register the Slack tools
|
||||
catalog.add_module(arcade_slack)
|
||||
|
||||
|
||||
@tool_eval()
|
||||
def get_messages_in_direct_message_eval_suite() -> EvalSuite:
|
||||
"""Create an evaluation suite for tools getting messages in direct messages."""
|
||||
suite = EvalSuite(
|
||||
name="Slack Chat Tools Evaluation",
|
||||
system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.",
|
||||
catalog=catalog,
|
||||
rubric=rubric,
|
||||
)
|
||||
|
||||
no_arguments_user_messages_by_username = [
|
||||
"what are the latest messages I exchanged with janedoe",
|
||||
"show my messages with janedoe on Slack",
|
||||
"get the messages I exchanged with janedoe",
|
||||
"get the message history with janedoe",
|
||||
]
|
||||
|
||||
for i, user_message in enumerate(no_arguments_user_messages_by_username):
|
||||
suite.add_case(
|
||||
name=f"{user_message} [{i}]",
|
||||
user_message=user_message,
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": None,
|
||||
"user_ids": None,
|
||||
"usernames": ["janedoe"],
|
||||
"emails": None,
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.1),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.1),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.6),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
],
|
||||
)
|
||||
|
||||
no_arguments_user_messages_by_email = [
|
||||
"what are the latest messages I exchanged with jane.doe@acme.com",
|
||||
"show my messages with jane.doe@acme.com on Slack",
|
||||
]
|
||||
|
||||
for i, user_message in enumerate(no_arguments_user_messages_by_email):
|
||||
suite.add_case(
|
||||
name=f"{user_message} [{i}]",
|
||||
user_message=user_message,
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": None,
|
||||
"user_ids": None,
|
||||
"usernames": None,
|
||||
"emails": ["jane.doe@acme.com"],
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.1),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.1),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.1),
|
||||
BinaryCritic(critic_field="emails", weight=0.6),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="get messages in direct conversation by username (on a specific date)",
|
||||
user_message="get the messages I exchanged with janedoe on 2025-01-31",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": None,
|
||||
"user_ids": None,
|
||||
"usernames": ["janedoe"],
|
||||
"emails": None,
|
||||
"oldest_datetime": "2025-01-31 00:00:00",
|
||||
"latest_datetime": "2025-01-31 23:59:59",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.1),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.1),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.2),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
DatetimeCritic(
|
||||
critic_field="oldest_datetime", weight=0.2, max_difference=timedelta(minutes=2)
|
||||
),
|
||||
DatetimeCritic(
|
||||
critic_field="latest_datetime", weight=0.2, max_difference=timedelta(minutes=2)
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="get messages in direct conversation by email (on a specific date)",
|
||||
user_message="get the messages I exchanged with jane.doe@acme.com on 2025-01-31",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": None,
|
||||
"user_ids": None,
|
||||
"usernames": None,
|
||||
"emails": ["jane.doe@acme.com"],
|
||||
"oldest_datetime": "2025-01-31 00:00:00",
|
||||
"latest_datetime": "2025-01-31 23:59:59",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.1),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.1),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.1),
|
||||
BinaryCritic(critic_field="emails", weight=0.2),
|
||||
DatetimeCritic(
|
||||
critic_field="oldest_datetime", weight=0.2, max_difference=timedelta(minutes=2)
|
||||
),
|
||||
DatetimeCritic(
|
||||
critic_field="latest_datetime", weight=0.2, max_difference=timedelta(minutes=2)
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history oldest relative by username (2 days ago)",
|
||||
user_message="Get the messages I exchanged with janedoe starting 2 days ago",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": None,
|
||||
"user_ids": None,
|
||||
"usernames": ["janedoe"],
|
||||
"emails": None,
|
||||
"oldest_relative": "02:00:00",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.1),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.1),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.3),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.3),
|
||||
],
|
||||
)
|
||||
|
||||
return suite
|
||||
165
toolkits/slack/evals/chat/messages/eval_get_mpim_messages.py
Normal file
165
toolkits/slack/evals/chat/messages/eval_get_mpim_messages.py
Normal file
|
|
@ -0,0 +1,165 @@
|
|||
from datetime import timedelta
|
||||
|
||||
from arcade_evals import (
|
||||
BinaryCritic,
|
||||
DatetimeCritic,
|
||||
EvalRubric,
|
||||
EvalSuite,
|
||||
ExpectedToolCall,
|
||||
tool_eval,
|
||||
)
|
||||
from arcade_tdk import ToolCatalog
|
||||
|
||||
import arcade_slack
|
||||
from arcade_slack.critics import RelativeTimeBinaryCritic
|
||||
from arcade_slack.tools.chat import get_messages
|
||||
|
||||
# Evaluation rubric
|
||||
rubric = EvalRubric(
|
||||
fail_threshold=0.8,
|
||||
warn_threshold=0.9,
|
||||
)
|
||||
|
||||
|
||||
catalog = ToolCatalog()
|
||||
# Register the Slack tools
|
||||
catalog.add_module(arcade_slack)
|
||||
|
||||
|
||||
@tool_eval()
|
||||
def get_messages_in_multi_person_direct_message_eval_suite() -> EvalSuite:
|
||||
"""Create an evaluation suite for tools getting messages in multi-person direct messages."""
|
||||
suite = EvalSuite(
|
||||
name="Slack Chat Tools Evaluation",
|
||||
system_message="You are an AI assistant that can interact with Slack to send messages and get information from conversations, users, etc.",
|
||||
catalog=catalog,
|
||||
rubric=rubric,
|
||||
)
|
||||
|
||||
no_arguments_user_messages_by_username = [
|
||||
(
|
||||
"what are the latest messages I exchanged in the MPIM "
|
||||
"with the usernames john, ryan, and jennifer"
|
||||
),
|
||||
("show the messages in the MPIM with the usernames john, ryan, and jennifer on Slack"),
|
||||
("list the messages I exchanged in the MPIM with the usernames john, ryan, and jennifer"),
|
||||
("list the message history in the MPIM with the usernames john, ryan, and jennifer"),
|
||||
]
|
||||
|
||||
for i, user_message in enumerate(no_arguments_user_messages_by_username):
|
||||
suite.add_case(
|
||||
name=f"{user_message} [{i}]",
|
||||
user_message=user_message,
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": None,
|
||||
"user_ids": None,
|
||||
"usernames": ["john", "ryan", "jennifer"],
|
||||
"emails": None,
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.1),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.1),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.6),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="get messages in multi person direct conversation with mixed usernames and emails",
|
||||
user_message=(
|
||||
"get the messages I exchanged in the mpim with "
|
||||
"the usernames john, ryan, and jennifer@acme.com"
|
||||
),
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": None,
|
||||
"user_ids": None,
|
||||
"usernames": ["john", "ryan"],
|
||||
"emails": ["jennifer@acme.com"],
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.1),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.1),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.35),
|
||||
BinaryCritic(critic_field="emails", weight=0.35),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="get messages in direct conversation by username (on a specific date)",
|
||||
user_message=(
|
||||
"get the messages I exchanged in the mpim with "
|
||||
"the usernames john, ryan, and jennifer on 2025-01-31"
|
||||
),
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": None,
|
||||
"user_ids": None,
|
||||
"usernames": ["john", "ryan", "jennifer"],
|
||||
"emails": None,
|
||||
"oldest_datetime": "2025-01-31 00:00:00",
|
||||
"latest_datetime": "2025-01-31 23:59:59",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.1),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.1),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.2),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
DatetimeCritic(
|
||||
critic_field="oldest_datetime", weight=0.2, max_difference=timedelta(minutes=2)
|
||||
),
|
||||
DatetimeCritic(
|
||||
critic_field="latest_datetime", weight=0.2, max_difference=timedelta(minutes=2)
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Get conversation history oldest relative by username (2 days ago)",
|
||||
user_message=(
|
||||
"Get the messages I exchanged in the MPIM with "
|
||||
"the usernames john, ryan, and jennifer starting 2 days ago"
|
||||
),
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_messages,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": None,
|
||||
"user_ids": None,
|
||||
"usernames": ["john", "ryan", "jennifer"],
|
||||
"emails": None,
|
||||
"oldest_relative": "02:00:00",
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.1),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.1),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.3),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
RelativeTimeBinaryCritic(critic_field="oldest_relative", weight=0.3),
|
||||
],
|
||||
)
|
||||
|
||||
return suite
|
||||
279
toolkits/slack/evals/chat/messages/eval_send_messages.py
Normal file
279
toolkits/slack/evals/chat/messages/eval_send_messages.py
Normal file
|
|
@ -0,0 +1,279 @@
|
|||
import json
|
||||
|
||||
from arcade_evals import (
|
||||
BinaryCritic,
|
||||
EvalRubric,
|
||||
EvalSuite,
|
||||
ExpectedToolCall,
|
||||
SimilarityCritic,
|
||||
tool_eval,
|
||||
)
|
||||
from arcade_tdk import ToolCatalog
|
||||
|
||||
import arcade_slack
|
||||
from arcade_slack.tools.chat import send_message
|
||||
|
||||
# Evaluation rubric
|
||||
rubric = EvalRubric(
|
||||
fail_threshold=0.8,
|
||||
warn_threshold=0.9,
|
||||
)
|
||||
|
||||
|
||||
catalog = ToolCatalog()
|
||||
# Register the Slack tools
|
||||
catalog.add_module(arcade_slack)
|
||||
|
||||
|
||||
@tool_eval()
|
||||
def send_message_eval_suite() -> EvalSuite:
|
||||
"""Create an evaluation suite for Slack messaging tools."""
|
||||
suite = EvalSuite(
|
||||
name="Slack Messaging Tools Evaluation",
|
||||
system_message="You are an AI assistant that can send direct messages and post messages to channels in Slack using the provided tools.",
|
||||
catalog=catalog,
|
||||
rubric=rubric,
|
||||
)
|
||||
|
||||
# Send DM to User Scenarios
|
||||
suite.add_case(
|
||||
name="Send DM to user with clear username",
|
||||
user_message="Send a direct message to johndoe saying 'Hello, can we meet at 3 PM?'",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=send_message,
|
||||
args={
|
||||
"message": "Hello, can we meet at 3 PM?",
|
||||
"conversation_id": None,
|
||||
"channel_name": None,
|
||||
"user_ids": None,
|
||||
"emails": None,
|
||||
"usernames": ["johndoe"],
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
SimilarityCritic(critic_field="message", weight=0.3, similarity_threshold=0.9),
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.1),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.1),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.3),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Send message to channel with clear name",
|
||||
user_message="Post 'The new feature is now live!' in the #announcements channel",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=send_message,
|
||||
args={
|
||||
"message": "The new feature is now live!",
|
||||
"conversation_id": None,
|
||||
"channel_name": "announcements",
|
||||
"user_ids": None,
|
||||
"emails": None,
|
||||
"usernames": None,
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
SimilarityCritic(critic_field="message", weight=0.3, similarity_threshold=0.9),
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.1),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.3),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.1),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Send message to channel with ambiguous name",
|
||||
user_message="Inform the team in the general channel about the upcoming maintenance",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=send_message,
|
||||
args={
|
||||
"message": "Attention team: There will be upcoming maintenance. Please save your work and expect some downtime.",
|
||||
"conversation_id": None,
|
||||
"channel_name": "general",
|
||||
"user_ids": None,
|
||||
"emails": None,
|
||||
"usernames": None,
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
SimilarityCritic(critic_field="message", weight=0.3, similarity_threshold=0.9),
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.1),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.3),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.1),
|
||||
],
|
||||
)
|
||||
|
||||
# Adversarial Scenarios
|
||||
suite.add_case(
|
||||
name="Ambiguous between DM and channel message",
|
||||
user_message="general",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=send_message,
|
||||
args={
|
||||
"message": "Great job on the presentation!",
|
||||
"conversation_id": None,
|
||||
"channel_name": "general",
|
||||
"user_ids": None,
|
||||
"emails": None,
|
||||
"usernames": None,
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
SimilarityCritic(critic_field="message", weight=0.3, similarity_threshold=0.9),
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.1),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.3),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.1),
|
||||
],
|
||||
additional_messages=[
|
||||
{"role": "user", "content": "Send 'Great job on the presentation!' to the team"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_1",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "Slack_ListConversationsMetadata",
|
||||
"arguments": '{"limit":20, "conversation_types":["public_channel"]}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"content": json.dumps({
|
||||
"conversations": [
|
||||
{
|
||||
"conversation_type": "public_channel",
|
||||
"id": "channel1",
|
||||
"is_archived": False,
|
||||
"is_member": True,
|
||||
"is_private": False,
|
||||
"name": "random",
|
||||
"num_members": 999,
|
||||
"purpose": "Random stuff",
|
||||
},
|
||||
{
|
||||
"conversation_type": "public_channel",
|
||||
"id": "channel2",
|
||||
"is_archived": False,
|
||||
"is_member": True,
|
||||
"is_private": False,
|
||||
"name": "general",
|
||||
"num_members": 999,
|
||||
"purpose": "Just a general channel",
|
||||
},
|
||||
],
|
||||
"next_cursor": "",
|
||||
}),
|
||||
"tool_call_id": "call_1",
|
||||
"name": "Slack_ListPublicChannelsMetadata",
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": 'To send the message "Great job on the presentation!" to the team, please let me know which Slack channel you\'d like to use:\n\n1. #random\n2. #general\n\nPlease let me know your choice!',
|
||||
},
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Multiple recipients in DM request",
|
||||
user_message="Send DMs to the users 'alice' and 'bob' about rescheduling our meeting tomorrow. I have too much work to do.",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=send_message,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": None,
|
||||
"user_ids": None,
|
||||
"emails": None,
|
||||
"usernames": ["alice"],
|
||||
},
|
||||
),
|
||||
ExpectedToolCall(
|
||||
func=send_message,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": None,
|
||||
"user_ids": None,
|
||||
"emails": None,
|
||||
"usernames": ["bob"],
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.1),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.1),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.6),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Multiple recipients in MPIM request",
|
||||
user_message="Send a message to the users 'alice' and 'bob' about rescheduling our meeting tomorrow. I have too much work to do.",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=send_message,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": None,
|
||||
"user_ids": None,
|
||||
"emails": None,
|
||||
"usernames": ["alice", "bob"],
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.1),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.1),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.6),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="Channel name similar to username",
|
||||
user_message="Post 'sounds great!' in john-project channel",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=send_message,
|
||||
args={
|
||||
"conversation_id": None,
|
||||
"channel_name": "john-project",
|
||||
"user_ids": None,
|
||||
"emails": None,
|
||||
"usernames": None,
|
||||
"message": "Sounds great!",
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[
|
||||
SimilarityCritic(critic_field="message", weight=0.3, similarity_threshold=0.9),
|
||||
BinaryCritic(critic_field="conversation_id", weight=0.1),
|
||||
BinaryCritic(critic_field="channel_name", weight=0.3),
|
||||
BinaryCritic(critic_field="user_ids", weight=0.1),
|
||||
BinaryCritic(critic_field="emails", weight=0.1),
|
||||
BinaryCritic(critic_field="usernames", weight=0.1),
|
||||
],
|
||||
)
|
||||
|
||||
return suite
|
||||
File diff suppressed because it is too large
Load diff
|
|
@ -10,7 +10,10 @@ from arcade_evals import (
|
|||
from arcade_tdk import ToolCatalog
|
||||
|
||||
import arcade_slack
|
||||
from arcade_slack.tools.users import get_user_info_by_id, list_users
|
||||
from arcade_slack.tools.users import (
|
||||
get_users_info,
|
||||
list_users,
|
||||
)
|
||||
|
||||
# Evaluation rubric
|
||||
rubric = EvalRubric(
|
||||
|
|
@ -25,7 +28,7 @@ catalog.add_module(arcade_slack)
|
|||
|
||||
|
||||
@tool_eval()
|
||||
def get_user_info_by_id_eval_suite() -> EvalSuite:
|
||||
def get_user_info_eval_suite() -> EvalSuite:
|
||||
"""Create an evaluation suite for tools getting user info by id."""
|
||||
suite = EvalSuite(
|
||||
name="Slack Users Tools Evaluation",
|
||||
|
|
@ -53,13 +56,101 @@ def get_user_info_by_id_eval_suite() -> EvalSuite:
|
|||
user_message=user_message,
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_user_info_by_id,
|
||||
args={"user_id": expected_user_id},
|
||||
func=get_users_info,
|
||||
args={
|
||||
"user_ids": [expected_user_id],
|
||||
"usernames": None,
|
||||
"emails": None,
|
||||
},
|
||||
)
|
||||
],
|
||||
critics=[BinaryCritic(critic_field="user_id", weight=1.0)],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="user_ids", weight=1 / 3),
|
||||
BinaryCritic(critic_field="usernames", weight=1 / 3),
|
||||
BinaryCritic(critic_field="emails", weight=1 / 3),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="get user by username",
|
||||
user_message="get the user 'johndoe'",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_users_info,
|
||||
args={
|
||||
"usernames": ["johndoe"],
|
||||
"user_ids": None,
|
||||
"emails": None,
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="usernames", weight=1 / 3),
|
||||
BinaryCritic(critic_field="user_ids", weight=1 / 3),
|
||||
BinaryCritic(critic_field="emails", weight=1 / 3),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="get user by email",
|
||||
user_message="get the user 'john.doe@acme.com'",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_users_info,
|
||||
args={
|
||||
"usernames": None,
|
||||
"user_ids": None,
|
||||
"emails": ["john.doe@acme.com"],
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="usernames", weight=1 / 3),
|
||||
BinaryCritic(critic_field="user_ids", weight=1 / 3),
|
||||
BinaryCritic(critic_field="emails", weight=1 / 3),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="get multiple users by username",
|
||||
user_message="get the users with the usernames 'johndoe' and 'foobar'",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_users_info,
|
||||
args={
|
||||
"usernames": ["johndoe", "foobar"],
|
||||
"user_ids": None,
|
||||
"emails": None,
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="usernames", weight=1 / 3),
|
||||
BinaryCritic(critic_field="user_ids", weight=1 / 3),
|
||||
BinaryCritic(critic_field="emails", weight=1 / 3),
|
||||
],
|
||||
)
|
||||
|
||||
suite.add_case(
|
||||
name="get multiple users by email",
|
||||
user_message="get the users with the emails 'john.doe@acme.com' and 'jane.doe@acme.com'",
|
||||
expected_tool_calls=[
|
||||
ExpectedToolCall(
|
||||
func=get_users_info,
|
||||
args={
|
||||
"usernames": None,
|
||||
"user_ids": None,
|
||||
"emails": ["john.doe@acme.com", "jane.doe@acme.com"],
|
||||
},
|
||||
),
|
||||
],
|
||||
critics=[
|
||||
BinaryCritic(critic_field="usernames", weight=1 / 3),
|
||||
BinaryCritic(critic_field="user_ids", weight=1 / 3),
|
||||
BinaryCritic(critic_field="emails", weight=1 / 3),
|
||||
],
|
||||
)
|
||||
|
||||
return suite
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ build-backend = "hatchling.build"
|
|||
|
||||
[project]
|
||||
name = "arcade_slack"
|
||||
version = "0.4.6"
|
||||
version = "0.5.0"
|
||||
description = "Arcade.dev LLM tools for Slack"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [ "aiodns>=1.0,<2.0.0", "typing; python_version < '3.7'", "aiohttp>=3.7.3,<4.0.0", "arcade-tdk>=2.0.0,<3.0.0", "slack-sdk>=3.31.0,<4.0.0",]
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
107
toolkits/slack/tests/test_models.py
Normal file
107
toolkits/slack/tests/test_models.py
Normal file
|
|
@ -0,0 +1,107 @@
|
|||
import asyncio
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from arcade_tdk.errors import ToolExecutionError
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from arcade_slack.models import (
|
||||
ConcurrencySafeCoroutineCaller,
|
||||
FindChannelByNameSentinel,
|
||||
FindMultipleUsersByUsernameSentinel,
|
||||
FindUserByUsernameSentinel,
|
||||
GetUserByEmailCaller,
|
||||
)
|
||||
|
||||
|
||||
def test_find_user_by_username_sentinel():
|
||||
sentinel = FindUserByUsernameSentinel(username="jenifer")
|
||||
assert sentinel(last_result=[{"name": "jack"}]) is False
|
||||
assert sentinel(last_result=[{"name": "john"}, {"name": "jack"}]) is False
|
||||
assert sentinel(last_result=[{"name": "hello"}, {"name": "jenifer"}]) is True
|
||||
assert sentinel(last_result=[{"name": "JENIFER"}]) is True
|
||||
|
||||
|
||||
def test_find_multiple_users_by_username_sentinel():
|
||||
sentinel = FindMultipleUsersByUsernameSentinel(usernames=["jenifer", "jack"])
|
||||
assert sentinel(last_result=[{"name": "jack"}]) is False
|
||||
assert sentinel(last_result=[{"name": "john"}, {"name": "jack"}]) is False
|
||||
assert sentinel(last_result=[{"name": "hello"}, {"name": "JENIFER"}]) is True
|
||||
assert sentinel(last_result=[{"name": "world"}]) is True
|
||||
|
||||
|
||||
def test_find_channel_by_name_sentinel():
|
||||
sentinel = FindChannelByNameSentinel(channel_name="foobar")
|
||||
assert sentinel(last_result=[{"name": "foo"}]) is False
|
||||
assert sentinel(last_result=[{"name": "foo"}, {"name": "bar"}]) is False
|
||||
assert sentinel(last_result=[{"name": "foo"}, {"name": "foobar"}]) is True
|
||||
assert sentinel(last_result=[{"name": "FOObar"}]) is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrency_safe_coroutine_caller():
|
||||
mock_func = AsyncMock()
|
||||
mock_semaphore = AsyncMock(spec=asyncio.Semaphore)
|
||||
|
||||
caller = ConcurrencySafeCoroutineCaller(mock_func, "arg1", "arg2", kwarg1="kwarg1")
|
||||
response = await caller(semaphore=mock_semaphore)
|
||||
|
||||
assert response == mock_func.return_value
|
||||
mock_func.assert_awaited_once_with("arg1", "arg2", kwarg1="kwarg1")
|
||||
mock_semaphore.__aenter__.assert_awaited_once()
|
||||
mock_semaphore.__aexit__.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_by_email_caller_success():
|
||||
mock_func = AsyncMock()
|
||||
mock_func.return_value = {"user": {"id": "U1234567890", "name": "John Doe"}}
|
||||
mock_semaphore = AsyncMock(spec=asyncio.Semaphore)
|
||||
|
||||
caller = GetUserByEmailCaller(mock_func, "test@example.com")
|
||||
response = await caller(semaphore=mock_semaphore)
|
||||
|
||||
assert response == {
|
||||
"user": {"id": "U1234567890", "name": "John Doe"},
|
||||
"email": "test@example.com",
|
||||
}
|
||||
mock_func.assert_awaited_once_with(email="test@example.com")
|
||||
mock_semaphore.__aenter__.assert_awaited_once()
|
||||
mock_semaphore.__aexit__.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_by_email_caller_not_found_error():
|
||||
mock_func = AsyncMock(
|
||||
side_effect=SlackApiError(message="User not found", response={"error": "user_not_found"})
|
||||
)
|
||||
mock_semaphore = AsyncMock(spec=asyncio.Semaphore)
|
||||
|
||||
caller = GetUserByEmailCaller(mock_func, "test@example.com")
|
||||
response = await caller(semaphore=mock_semaphore)
|
||||
|
||||
assert response == {
|
||||
"user": None,
|
||||
"email": "test@example.com",
|
||||
}
|
||||
mock_func.assert_awaited_once_with(email="test@example.com")
|
||||
mock_semaphore.__aenter__.assert_awaited_once()
|
||||
mock_semaphore.__aexit__.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_by_email_caller_unknown_slack_api_error():
|
||||
mock_func = AsyncMock(
|
||||
side_effect=SlackApiError(message="Unknown error", response={"error": "unknown_error"})
|
||||
)
|
||||
mock_semaphore = AsyncMock(spec=asyncio.Semaphore)
|
||||
|
||||
caller = GetUserByEmailCaller(mock_func, "test@example.com")
|
||||
with pytest.raises(ToolExecutionError) as error:
|
||||
await caller(semaphore=mock_semaphore)
|
||||
|
||||
assert error.value.message == "Error getting user by email"
|
||||
assert error.value.developer_message == "Error getting user by email: unknown_error"
|
||||
mock_func.assert_awaited_once_with(email="test@example.com")
|
||||
mock_semaphore.__aenter__.assert_awaited_once()
|
||||
mock_semaphore.__aexit__.assert_awaited_once()
|
||||
313
toolkits/slack/tests/test_user_retrieval.py
Normal file
313
toolkits/slack/tests/test_user_retrieval.py
Normal file
|
|
@ -0,0 +1,313 @@
|
|||
import json
|
||||
|
||||
import pytest
|
||||
from arcade_tdk.errors import RetryableToolError, ToolExecutionError
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from arcade_slack.user_retrieval import (
|
||||
get_single_user_by_id,
|
||||
get_users_by_id,
|
||||
get_users_by_id_username_or_email,
|
||||
)
|
||||
from arcade_slack.utils import (
|
||||
cast_user_dict,
|
||||
extract_basic_user_info,
|
||||
short_user_info,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multiple_users_by_emails_success(
|
||||
mock_context, mock_user_retrieval_slack_client, dummy_user_factory
|
||||
):
|
||||
user1 = dummy_user_factory()
|
||||
user2 = dummy_user_factory()
|
||||
|
||||
emails = [user1["profile"]["email"], user2["profile"]["email"]]
|
||||
|
||||
mock_user_retrieval_slack_client.users_lookupByEmail.side_effect = [
|
||||
{"ok": True, "user": user1},
|
||||
{"ok": True, "user": user2},
|
||||
]
|
||||
|
||||
response = await get_users_by_id_username_or_email(context=mock_context, emails=emails)
|
||||
|
||||
assert response == [extract_basic_user_info(user1), extract_basic_user_info(user2)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multiple_users_by_usernames_or_emails_with_emails_not_found(
|
||||
mock_context,
|
||||
mock_user_retrieval_slack_client,
|
||||
mock_users_slack_client,
|
||||
dummy_user_factory,
|
||||
):
|
||||
user1 = dummy_user_factory(email="user1@example.com")
|
||||
|
||||
emails = [user1["profile"]["email"], "not_found@example.com"]
|
||||
|
||||
async def lookup_by_email_side_effect(*, email):
|
||||
if email == "user1@example.com":
|
||||
return {"ok": True, "user": user1}
|
||||
raise SlackApiError(
|
||||
message="User not found",
|
||||
response={"error": "user_not_found"},
|
||||
)
|
||||
|
||||
mock_user_retrieval_slack_client.users_lookupByEmail.side_effect = lookup_by_email_side_effect
|
||||
|
||||
mock_users_slack_client.users_list.return_value = {
|
||||
"ok": True,
|
||||
"members": [user1],
|
||||
}
|
||||
|
||||
with pytest.raises(RetryableToolError) as error:
|
||||
await get_users_by_id_username_or_email(context=mock_context, emails=emails)
|
||||
|
||||
assert "not_found@example.com" in error.value.message
|
||||
assert json.dumps(short_user_info(user1)) in error.value.additional_prompt_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multiple_users_by_usernames_or_emails_with_usernames_success(
|
||||
mock_context, mock_user_retrieval_slack_client, dummy_user_factory
|
||||
):
|
||||
user1 = dummy_user_factory()
|
||||
user2 = dummy_user_factory()
|
||||
|
||||
usernames = [user1["name"], user2["name"]]
|
||||
|
||||
mock_user_retrieval_slack_client.users_list.return_value = {
|
||||
"ok": True,
|
||||
"members": [user1, user2],
|
||||
}
|
||||
|
||||
response = await get_users_by_id_username_or_email(context=mock_context, usernames=usernames)
|
||||
|
||||
assert response == [extract_basic_user_info(user1), extract_basic_user_info(user2)]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multiple_users_by_usernames_or_emails_with_usernames_not_found(
|
||||
mock_context, mock_user_retrieval_slack_client, dummy_user_factory
|
||||
):
|
||||
user1 = dummy_user_factory()
|
||||
user2 = dummy_user_factory()
|
||||
user3 = dummy_user_factory()
|
||||
|
||||
usernames = [user1["name"], "username_not_found"]
|
||||
|
||||
mock_user_retrieval_slack_client.users_list.return_value = {
|
||||
"ok": True,
|
||||
"members": [user1, user2, user3],
|
||||
}
|
||||
|
||||
with pytest.raises(RetryableToolError) as error:
|
||||
await get_users_by_id_username_or_email(context=mock_context, usernames=usernames)
|
||||
|
||||
assert "username_not_found" in error.value.message
|
||||
assert json.dumps(short_user_info(user1)) in error.value.additional_prompt_content
|
||||
assert json.dumps(short_user_info(user2)) in error.value.additional_prompt_content
|
||||
assert json.dumps(short_user_info(user3)) in error.value.additional_prompt_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multiple_users_by_mixed_usernames_and_emails_success(
|
||||
mock_context, mock_user_retrieval_slack_client, dummy_user_factory
|
||||
):
|
||||
user1 = dummy_user_factory()
|
||||
user2 = dummy_user_factory()
|
||||
user3 = dummy_user_factory()
|
||||
user4 = dummy_user_factory()
|
||||
|
||||
mock_user_retrieval_slack_client.users_list.return_value = {
|
||||
"ok": True,
|
||||
"members": [user1, user2],
|
||||
}
|
||||
mock_user_retrieval_slack_client.users_lookupByEmail.side_effect = [
|
||||
{"ok": True, "user": user3},
|
||||
{"ok": True, "user": user4},
|
||||
]
|
||||
|
||||
response = await get_users_by_id_username_or_email(
|
||||
context=mock_context,
|
||||
usernames=[user1["name"], user2["name"]],
|
||||
emails=[user3["profile"]["email"], user4["profile"]["email"]],
|
||||
)
|
||||
|
||||
assert response == [
|
||||
extract_basic_user_info(user1),
|
||||
extract_basic_user_info(user2),
|
||||
extract_basic_user_info(user3),
|
||||
extract_basic_user_info(user4),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_single_user_by_id_success(
|
||||
mock_context, mock_user_retrieval_slack_client, dummy_user_factory
|
||||
):
|
||||
user = dummy_user_factory()
|
||||
|
||||
mock_user_retrieval_slack_client.users_info.return_value = {"ok": True, "user": user}
|
||||
|
||||
response = await get_single_user_by_id(
|
||||
auth_token=mock_context.get_auth_token_or_empty(),
|
||||
user_id=user["id"],
|
||||
)
|
||||
|
||||
assert response == cast_user_dict(user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_single_user_by_id_not_found(
|
||||
mock_context, mock_user_retrieval_slack_client, dummy_user_factory
|
||||
):
|
||||
user = dummy_user_factory()
|
||||
|
||||
mock_user_retrieval_slack_client.users_info.side_effect = SlackApiError(
|
||||
message="User not found",
|
||||
response={"error": "user_not_found"},
|
||||
)
|
||||
|
||||
response = await get_single_user_by_id(
|
||||
auth_token=mock_context.get_auth_token_or_empty(),
|
||||
user_id=user["id"],
|
||||
)
|
||||
|
||||
assert response is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_single_user_by_id_not_ok(
|
||||
mock_context, mock_user_retrieval_slack_client, dummy_user_factory
|
||||
):
|
||||
user = dummy_user_factory()
|
||||
|
||||
mock_user_retrieval_slack_client.users_info.return_value = {"ok": False, "error": "not_ok"}
|
||||
|
||||
response = await get_single_user_by_id(
|
||||
auth_token=mock_context.get_auth_token_or_empty(),
|
||||
user_id=user["id"],
|
||||
)
|
||||
|
||||
assert response is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_single_user_by_id_unknown_error(
|
||||
mock_context, mock_user_retrieval_slack_client, dummy_user_factory
|
||||
):
|
||||
user = dummy_user_factory()
|
||||
|
||||
mock_user_retrieval_slack_client.users_info.side_effect = SlackApiError(
|
||||
message="Unknown error",
|
||||
response={"error": "unknown_error_string"},
|
||||
)
|
||||
|
||||
with pytest.raises(ToolExecutionError) as error:
|
||||
await get_single_user_by_id(
|
||||
auth_token=mock_context.get_auth_token_or_empty(),
|
||||
user_id=user["id"],
|
||||
)
|
||||
|
||||
assert user["id"] in error.value.message
|
||||
assert "unknown_error_string" in error.value.developer_message
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_users_by_id_with_one_user_id_success(
|
||||
mock_context, mock_user_retrieval_slack_client, dummy_user_factory
|
||||
):
|
||||
user1 = dummy_user_factory()
|
||||
|
||||
mock_user_retrieval_slack_client.users_info.return_value = {"ok": True, "user": user1}
|
||||
|
||||
response = await get_users_by_id(
|
||||
auth_token=mock_context.get_auth_token_or_empty(),
|
||||
user_ids=[user1["id"]],
|
||||
)
|
||||
|
||||
assert response == {"users": [cast_user_dict(user1)], "not_found": []}
|
||||
|
||||
mock_user_retrieval_slack_client.users_list.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_users_by_id_with_one_user_id_not_found(
|
||||
mock_context, mock_user_retrieval_slack_client, dummy_user_factory
|
||||
):
|
||||
user1 = dummy_user_factory()
|
||||
|
||||
mock_user_retrieval_slack_client.users_info.side_effect = SlackApiError(
|
||||
message="User not found",
|
||||
response={"error": "user_not_found"},
|
||||
)
|
||||
|
||||
response = await get_users_by_id(
|
||||
auth_token=mock_context.get_auth_token_or_empty(),
|
||||
user_ids=[user1["id"]],
|
||||
)
|
||||
|
||||
assert response == {"users": [], "not_found": [user1["id"]]}
|
||||
|
||||
mock_user_retrieval_slack_client.users_list.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_users_by_id_with_multiple_user_ids_success(
|
||||
mock_context, mock_user_retrieval_slack_client, dummy_user_factory
|
||||
):
|
||||
user1 = dummy_user_factory()
|
||||
user2 = dummy_user_factory()
|
||||
user3 = dummy_user_factory()
|
||||
user4 = dummy_user_factory()
|
||||
|
||||
mock_user_retrieval_slack_client.users_list.side_effect = [
|
||||
{
|
||||
"ok": True,
|
||||
"members": [user1, user2],
|
||||
"response_metadata": {"next_cursor": "next_cursor"},
|
||||
},
|
||||
{"ok": True, "members": [user3, user4]},
|
||||
]
|
||||
|
||||
response = await get_users_by_id(
|
||||
auth_token=mock_context.get_auth_token_or_empty(),
|
||||
user_ids=[user1["id"], user4["id"]],
|
||||
)
|
||||
|
||||
assert response == {"users": [cast_user_dict(user1), cast_user_dict(user4)], "not_found": []}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_users_by_id_with_multiple_user_ids_some_not_found(
|
||||
mock_context, mock_user_retrieval_slack_client, dummy_user_factory
|
||||
):
|
||||
user1 = dummy_user_factory(id_="U1")
|
||||
user2 = dummy_user_factory(id_="U2")
|
||||
user3 = dummy_user_factory(id_="U3")
|
||||
user4 = dummy_user_factory(id_="U4")
|
||||
|
||||
mock_user_retrieval_slack_client.users_list.side_effect = [
|
||||
{
|
||||
"ok": True,
|
||||
"members": [user1, user2],
|
||||
"response_metadata": {"next_cursor": "next_cursor"},
|
||||
},
|
||||
{
|
||||
"ok": True,
|
||||
"members": [user3, user4],
|
||||
"response_metadata": {"next_cursor": None},
|
||||
},
|
||||
]
|
||||
|
||||
response = await get_users_by_id(
|
||||
auth_token=mock_context.get_auth_token_or_empty(),
|
||||
user_ids=[user1["id"], user4["id"], "user_not_exists"],
|
||||
)
|
||||
|
||||
assert response == {
|
||||
"users": [cast_user_dict(user1), cast_user_dict(user4)],
|
||||
"not_found": ["user_not_exists"],
|
||||
}
|
||||
|
|
@ -1,21 +1,19 @@
|
|||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from arcade_tdk.errors import RetryableToolError
|
||||
from arcade_tdk.errors import RetryableToolError, ToolExecutionError
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from arcade_slack.tools.users import get_user_info_by_id, list_users
|
||||
from arcade_slack.utils import extract_basic_user_info
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_slack_client(mocker):
|
||||
mock_client = mocker.patch("arcade_slack.tools.users.AsyncWebClient", autospec=True)
|
||||
return mock_client.return_value
|
||||
from arcade_slack.tools.users import (
|
||||
get_users_info,
|
||||
list_users,
|
||||
)
|
||||
from arcade_slack.utils import extract_basic_user_info, short_user_info
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_info_by_id_success(mock_context, mock_slack_client):
|
||||
async def test_get_user_info_by_id_success(mock_context, mock_user_retrieval_slack_client):
|
||||
# Mock the response from slackClient.users_info
|
||||
mock_user = {
|
||||
"id": "U12345",
|
||||
|
|
@ -23,44 +21,46 @@ async def test_get_user_info_by_id_success(mock_context, mock_slack_client):
|
|||
"real_name": "Test User",
|
||||
"profile": {"email": "testuser@example.com"},
|
||||
}
|
||||
mock_slack_client.users_info.return_value = {"ok": True, "user": mock_user}
|
||||
mock_user_retrieval_slack_client.users_info.return_value = {"ok": True, "user": mock_user}
|
||||
|
||||
# Call the function
|
||||
response = await get_user_info_by_id(mock_context, user_id="U12345")
|
||||
response = await get_users_info(mock_context, user_ids=["U12345"])
|
||||
|
||||
# Verify that the correct Slack API method was called
|
||||
mock_slack_client.users_info.assert_called_once_with(user="U12345")
|
||||
mock_user_retrieval_slack_client.users_info.assert_called_once_with(user="U12345")
|
||||
|
||||
# Verify the response
|
||||
expected_response = extract_basic_user_info(mock_user)
|
||||
assert response == expected_response
|
||||
assert response == {"users": [expected_response]}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("arcade_slack.tools.users.list_users")
|
||||
async def test_get_user_info_by_id_user_not_found(mock_list_users, mock_context, mock_slack_client):
|
||||
async def test_get_user_info_by_id_user_not_found(
|
||||
mock_list_users, mock_context, mock_user_retrieval_slack_client
|
||||
):
|
||||
error_response = {"ok": False, "error": "user_not_found"}
|
||||
mock_slack_client.users_info.side_effect = SlackApiError(
|
||||
mock_user_retrieval_slack_client.users_info.side_effect = SlackApiError(
|
||||
message="User not found",
|
||||
response=error_response,
|
||||
)
|
||||
|
||||
existing_user = {"id": "U12345", "name": "testuser"}
|
||||
mock_list_users.return_value = {"users": [existing_user]}
|
||||
mock_list_users.__tool_name__ = list_users.__tool_name__
|
||||
|
||||
with pytest.raises(RetryableToolError) as e:
|
||||
await get_user_info_by_id(mock_context, user_id="U99999")
|
||||
await get_users_info(mock_context, user_ids=["U99999"])
|
||||
|
||||
assert existing_user["id"] in e.value.additional_prompt_content
|
||||
assert existing_user["name"] in e.value.additional_prompt_content
|
||||
|
||||
mock_slack_client.users_info.assert_called_once_with(user="U99999")
|
||||
mock_list_users.assert_called_once_with(mock_context)
|
||||
mock_user_retrieval_slack_client.users_info.assert_called_once_with(user="U99999")
|
||||
mock_list_users.assert_called_once_with(mock_context, limit=100, exclude_bots=True)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users_success(mock_context, mock_slack_client):
|
||||
mock_slack_client.users_list.return_value = {"ok": True, "members": [{"id": "U12345"}]}
|
||||
async def test_list_users_success(mock_context, mock_users_slack_client):
|
||||
mock_users_slack_client.users_list.return_value = {
|
||||
"ok": True,
|
||||
"members": [{"id": "U12345"}],
|
||||
}
|
||||
response = await list_users(mock_context)
|
||||
assert response == {
|
||||
"users": [extract_basic_user_info({"id": "U12345"})],
|
||||
|
|
@ -69,24 +69,306 @@ async def test_list_users_success(mock_context, mock_slack_client):
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users_with_pagination_success(mock_context, mock_slack_client):
|
||||
mock_slack_client.users_list.side_effect = [
|
||||
async def test_list_users_with_pagination_success(
|
||||
dummy_user_factory, mock_context, mock_users_slack_client
|
||||
):
|
||||
user1 = dummy_user_factory()
|
||||
user2 = dummy_user_factory()
|
||||
user3 = dummy_user_factory()
|
||||
user4 = dummy_user_factory()
|
||||
|
||||
mock_users_slack_client.users_list.side_effect = [
|
||||
{
|
||||
"ok": True,
|
||||
"members": [{"id": "U12345"}],
|
||||
"members": [user1, user2],
|
||||
"response_metadata": {"next_cursor": "cursor_xyz"},
|
||||
},
|
||||
{
|
||||
"ok": True,
|
||||
"members": [{"id": "U123456"}],
|
||||
"members": [user3, user4],
|
||||
"response_metadata": {"next_cursor": None},
|
||||
},
|
||||
]
|
||||
response = await list_users(mock_context, limit=3)
|
||||
assert response == {
|
||||
"users": [
|
||||
extract_basic_user_info({"id": "U12345"}),
|
||||
extract_basic_user_info({"id": "U123456"}),
|
||||
extract_basic_user_info(user1),
|
||||
extract_basic_user_info(user2),
|
||||
extract_basic_user_info(user3),
|
||||
extract_basic_user_info(user4),
|
||||
],
|
||||
"next_cursor": None,
|
||||
}
|
||||
|
||||
assert mock_users_slack_client.users_list.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_by_username_success(
|
||||
mock_context,
|
||||
mock_user_retrieval_slack_client,
|
||||
dummy_user_factory,
|
||||
):
|
||||
user1 = dummy_user_factory()
|
||||
user2 = dummy_user_factory()
|
||||
|
||||
mock_user_retrieval_slack_client.users_list.return_value = {
|
||||
"ok": True,
|
||||
"members": [user1, user2],
|
||||
}
|
||||
|
||||
response = await get_users_info(mock_context, usernames=[user1["name"]])
|
||||
|
||||
assert response == {"users": [extract_basic_user_info(user1)]}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_by_username_with_pagination_success(
|
||||
mock_context, mock_user_retrieval_slack_client, dummy_user_factory
|
||||
):
|
||||
user1 = dummy_user_factory()
|
||||
user2 = dummy_user_factory()
|
||||
user3 = dummy_user_factory()
|
||||
user4 = dummy_user_factory()
|
||||
user5 = dummy_user_factory()
|
||||
|
||||
mock_user_retrieval_slack_client.users_list.side_effect = [
|
||||
{
|
||||
"ok": True,
|
||||
"members": [user1, user2],
|
||||
"response_metadata": {"next_cursor": "cursor1"},
|
||||
},
|
||||
{
|
||||
"ok": True,
|
||||
"members": [user3, user4],
|
||||
"response_metadata": {"next_cursor": "cursor2"},
|
||||
},
|
||||
{
|
||||
"ok": True,
|
||||
"members": [user5],
|
||||
"response_metadata": {"next_cursor": None},
|
||||
},
|
||||
]
|
||||
|
||||
response = await get_users_info(mock_context, usernames=[user3["name"]])
|
||||
|
||||
assert response == {"users": [extract_basic_user_info(user3)]}
|
||||
|
||||
assert mock_user_retrieval_slack_client.users_list.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_by_username_not_found(
|
||||
mock_context, mock_user_retrieval_slack_client, dummy_user_factory
|
||||
):
|
||||
user1 = dummy_user_factory()
|
||||
user2 = dummy_user_factory()
|
||||
user3 = dummy_user_factory(is_bot=True)
|
||||
|
||||
mock_user_retrieval_slack_client.users_list.return_value = {
|
||||
"ok": True,
|
||||
"members": [user1, user2, user3],
|
||||
}
|
||||
|
||||
with pytest.raises(RetryableToolError) as e:
|
||||
await get_users_info(mock_context, usernames=[user1["name"] + "not_found"])
|
||||
|
||||
# Check that the error message contains the available users
|
||||
assert user1["id"] in e.value.additional_prompt_content
|
||||
assert user2["id"] in e.value.additional_prompt_content
|
||||
assert user3["id"] not in e.value.additional_prompt_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multiple_users_by_username_success(
|
||||
mock_context, mock_user_retrieval_slack_client, dummy_user_factory
|
||||
):
|
||||
user1 = dummy_user_factory()
|
||||
user2 = dummy_user_factory()
|
||||
user3 = dummy_user_factory()
|
||||
|
||||
mock_user_retrieval_slack_client.users_list.return_value = {
|
||||
"ok": True,
|
||||
"members": [user1, user2, user3],
|
||||
}
|
||||
|
||||
response = await get_users_info(
|
||||
mock_context,
|
||||
usernames=[user1["name"], user2["name"]],
|
||||
)
|
||||
|
||||
assert response == {
|
||||
"users": [
|
||||
extract_basic_user_info(user1),
|
||||
extract_basic_user_info(user2),
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multiple_users_by_username_with_pagination_success(
|
||||
mock_context, mock_user_retrieval_slack_client, dummy_user_factory
|
||||
):
|
||||
user1 = dummy_user_factory()
|
||||
user2 = dummy_user_factory()
|
||||
user3 = dummy_user_factory()
|
||||
user4 = dummy_user_factory()
|
||||
user5 = dummy_user_factory()
|
||||
|
||||
mock_user_retrieval_slack_client.users_list.side_effect = [
|
||||
{"ok": True, "members": [user1, user2], "response_metadata": {"next_cursor": "cursor1"}},
|
||||
{"ok": True, "members": [user3, user4], "response_metadata": {"next_cursor": "cursor2"}},
|
||||
{"ok": True, "members": [user5], "response_metadata": {"next_cursor": None}},
|
||||
]
|
||||
|
||||
response = await get_users_info(mock_context, usernames=[user1["name"], user3["name"]])
|
||||
|
||||
assert response == {"users": [extract_basic_user_info(user1), extract_basic_user_info(user3)]}
|
||||
assert mock_user_retrieval_slack_client.users_list.call_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multiple_users_by_username_not_found(
|
||||
mock_context, mock_user_retrieval_slack_client, dummy_user_factory
|
||||
):
|
||||
user1 = dummy_user_factory()
|
||||
user2 = dummy_user_factory()
|
||||
user3 = dummy_user_factory(is_bot=True)
|
||||
|
||||
mock_user_retrieval_slack_client.users_list.return_value = {
|
||||
"ok": True,
|
||||
"members": [user1, user2, user3],
|
||||
}
|
||||
|
||||
not_found_username = f"{user2['name']} not_found"
|
||||
|
||||
with pytest.raises(RetryableToolError) as e:
|
||||
await get_users_info(mock_context, usernames=[user1["name"], not_found_username])
|
||||
|
||||
assert user1["id"] in e.value.additional_prompt_content
|
||||
assert user2["id"] in e.value.additional_prompt_content
|
||||
assert user3["id"] not in e.value.additional_prompt_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_by_email_success(
|
||||
mock_context, mock_user_retrieval_slack_client, dummy_user_factory
|
||||
):
|
||||
user = dummy_user_factory()
|
||||
mock_user_retrieval_slack_client.users_lookupByEmail.return_value = {
|
||||
"ok": True,
|
||||
"user": user,
|
||||
}
|
||||
response = await get_users_info(mock_context, emails=[user["profile"]["email"]])
|
||||
assert response == {"users": [extract_basic_user_info(user)]}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_user_by_email_not_found(
|
||||
mock_context, mock_users_slack_client, mock_user_retrieval_slack_client, dummy_user_factory
|
||||
):
|
||||
additional_user = dummy_user_factory(email="additional_user@example.com")
|
||||
|
||||
async def lookup_by_email_side_effect(*, email):
|
||||
if email == "additional_user@example.com":
|
||||
return {"ok": True, "user": additional_user}
|
||||
raise SlackApiError(
|
||||
message="User not found",
|
||||
response={"ok": False, "error": "user_not_found"},
|
||||
)
|
||||
|
||||
mock_user_retrieval_slack_client.users_lookupByEmail.side_effect = lookup_by_email_side_effect
|
||||
mock_users_slack_client.users_list.return_value = {
|
||||
"ok": True,
|
||||
"members": [additional_user],
|
||||
}
|
||||
|
||||
with pytest.raises(RetryableToolError) as e:
|
||||
await get_users_info(mock_context, emails=["not_found@example.com"])
|
||||
|
||||
assert "not_found@example.com" in e.value.message
|
||||
assert json.dumps(short_user_info(additional_user)) in e.value.additional_prompt_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multiple_users_by_email_success(
|
||||
mock_context, mock_user_retrieval_slack_client, dummy_user_factory
|
||||
):
|
||||
user1 = dummy_user_factory()
|
||||
user2 = dummy_user_factory()
|
||||
user3 = dummy_user_factory()
|
||||
|
||||
mock_user_retrieval_slack_client.users_lookupByEmail.side_effect = [
|
||||
{"ok": True, "user": user1},
|
||||
{"ok": True, "user": user2},
|
||||
{"ok": True, "user": user3},
|
||||
]
|
||||
|
||||
response = await get_users_info(
|
||||
mock_context,
|
||||
emails=[
|
||||
user1["profile"]["email"],
|
||||
user2["profile"]["email"],
|
||||
user3["profile"]["email"],
|
||||
],
|
||||
)
|
||||
|
||||
assert response == {
|
||||
"users": [
|
||||
extract_basic_user_info(user1),
|
||||
extract_basic_user_info(user2),
|
||||
extract_basic_user_info(user3),
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multiple_users_by_email_with_invalid_address(
|
||||
mock_context, mock_user_retrieval_slack_client
|
||||
):
|
||||
with pytest.raises(ToolExecutionError) as e:
|
||||
await get_users_info(
|
||||
mock_context,
|
||||
emails=["amyra@jades.com", "invalid_address"],
|
||||
)
|
||||
|
||||
mock_user_retrieval_slack_client.assert_not_called()
|
||||
assert e.value.message == "Invalid email address: invalid_address"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_multiple_users_by_email_not_found(
|
||||
mock_context, mock_user_retrieval_slack_client, mock_users_slack_client, dummy_user_factory
|
||||
):
|
||||
user1 = dummy_user_factory()
|
||||
user2 = dummy_user_factory()
|
||||
|
||||
async def lookup_by_email_side_effect(*, email):
|
||||
if email == user1["profile"]["email"]:
|
||||
return {"ok": True, "user": user1}
|
||||
if email == user2["profile"]["email"]:
|
||||
return {"ok": True, "user": user2}
|
||||
raise SlackApiError(
|
||||
message="User not found",
|
||||
response={"ok": False, "error": "user_not_found"},
|
||||
)
|
||||
|
||||
mock_user_retrieval_slack_client.users_lookupByEmail.side_effect = lookup_by_email_side_effect
|
||||
mock_users_slack_client.users_list.return_value = {
|
||||
"ok": True,
|
||||
"members": [user1, user2],
|
||||
}
|
||||
|
||||
with pytest.raises(RetryableToolError) as e:
|
||||
await get_users_info(
|
||||
mock_context,
|
||||
emails=[
|
||||
user1["profile"]["email"],
|
||||
user2["profile"]["email"],
|
||||
"not_found@example.com",
|
||||
],
|
||||
)
|
||||
|
||||
assert "not_found@example.com" in e.value.message
|
||||
assert json.dumps(short_user_info(user1)) in e.value.additional_prompt_content
|
||||
assert json.dumps(short_user_info(user2)) in e.value.additional_prompt_content
|
||||
|
|
|
|||
|
|
@ -1,25 +1,26 @@
|
|||
import asyncio
|
||||
from unittest.mock import AsyncMock, MagicMock, call, patch
|
||||
import copy
|
||||
import json
|
||||
from unittest.mock import AsyncMock, call, patch
|
||||
|
||||
import pytest
|
||||
from arcade_tdk import ToolContext
|
||||
from arcade_tdk.errors import RetryableToolError
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
|
||||
from arcade_slack.exceptions import PaginationTimeoutError
|
||||
from arcade_slack.models import (
|
||||
ConversationType,
|
||||
ConversationTypeSlackName,
|
||||
)
|
||||
from arcade_slack.tools.chat import (
|
||||
get_members_in_conversation_by_id,
|
||||
list_conversations_metadata,
|
||||
ConcurrencySafeCoroutineCaller,
|
||||
FindMultipleUsersByUsernameSentinel,
|
||||
FindUserByUsernameSentinel,
|
||||
)
|
||||
from arcade_slack.utils import (
|
||||
async_paginate,
|
||||
convert_conversation_type_to_slack_name,
|
||||
build_multiple_users_retrieval_response,
|
||||
filter_conversations_by_user_ids,
|
||||
retrieve_conversations_by_user_ids,
|
||||
gather_with_concurrency_limit,
|
||||
is_valid_email,
|
||||
populate_users_in_messages,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -41,6 +42,170 @@ async def test_async_paginate():
|
|||
assert next_cursor is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_paginate_with_find_user_sentinel():
|
||||
mock_slack_client = AsyncMock()
|
||||
mock_slack_client.users_list.side_effect = [
|
||||
{
|
||||
"ok": True,
|
||||
"members": [
|
||||
{"id": "123", "name": "Jack"},
|
||||
{"id": "456", "name": "John"},
|
||||
],
|
||||
"response_metadata": {"next_cursor": "cursor1"},
|
||||
},
|
||||
{
|
||||
"ok": True,
|
||||
"members": [{"id": "789", "name": "Jenifer"}],
|
||||
"response_metadata": {"next_cursor": "cursor2"},
|
||||
},
|
||||
{
|
||||
"ok": True,
|
||||
"members": [{"id": "007", "name": "James"}],
|
||||
"response_metadata": {"next_cursor": None},
|
||||
},
|
||||
]
|
||||
|
||||
results, next_cursor = await async_paginate(
|
||||
func=mock_slack_client.users_list,
|
||||
response_key="members",
|
||||
sentinel=FindUserByUsernameSentinel(username="jenifer"),
|
||||
)
|
||||
|
||||
assert results == [
|
||||
{"id": "123", "name": "Jack"},
|
||||
{"id": "456", "name": "John"},
|
||||
{"id": "789", "name": "Jenifer"},
|
||||
]
|
||||
assert next_cursor == "cursor2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_paginate_with_find_user_sentinel_not_found():
|
||||
mock_slack_client = AsyncMock()
|
||||
mock_slack_client.users_list.side_effect = [
|
||||
{
|
||||
"ok": True,
|
||||
"members": [
|
||||
{"id": "123", "name": "Jack"},
|
||||
{"id": "456", "name": "John"},
|
||||
],
|
||||
"response_metadata": {"next_cursor": "cursor1"},
|
||||
},
|
||||
{
|
||||
"ok": True,
|
||||
"members": [{"id": "789", "name": "Jenifer"}],
|
||||
"response_metadata": {"next_cursor": "cursor2"},
|
||||
},
|
||||
{
|
||||
"ok": True,
|
||||
"members": [{"id": "007", "name": "James"}],
|
||||
"response_metadata": {"next_cursor": None},
|
||||
},
|
||||
]
|
||||
|
||||
results, next_cursor = await async_paginate(
|
||||
func=mock_slack_client.users_list,
|
||||
response_key="members",
|
||||
sentinel=FindUserByUsernameSentinel(username="Do not find me"),
|
||||
)
|
||||
|
||||
assert results == [
|
||||
{"id": "123", "name": "Jack"},
|
||||
{"id": "456", "name": "John"},
|
||||
{"id": "789", "name": "Jenifer"},
|
||||
{"id": "007", "name": "James"},
|
||||
]
|
||||
assert next_cursor is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_paginate_with_find_multiple_users_sentinel():
|
||||
mock_slack_client = AsyncMock()
|
||||
mock_slack_client.users_list.side_effect = [
|
||||
{
|
||||
"ok": True,
|
||||
"members": [
|
||||
{"id": "123", "name": "Jack"},
|
||||
{"id": "456", "name": "John"},
|
||||
],
|
||||
"response_metadata": {"next_cursor": "cursor1"},
|
||||
},
|
||||
{
|
||||
"ok": True,
|
||||
"members": [
|
||||
{"id": "789", "name": "Jenifer"},
|
||||
{"id": "101", "name": "Janis"},
|
||||
],
|
||||
"response_metadata": {"next_cursor": "cursor2"},
|
||||
},
|
||||
{
|
||||
"ok": True,
|
||||
"members": [{"id": "007", "name": "James"}],
|
||||
"response_metadata": {"next_cursor": None},
|
||||
},
|
||||
]
|
||||
|
||||
results, next_cursor = await async_paginate(
|
||||
func=mock_slack_client.users_list,
|
||||
response_key="members",
|
||||
sentinel=FindMultipleUsersByUsernameSentinel(usernames=["jenifer", "jack"]),
|
||||
)
|
||||
|
||||
assert results == [
|
||||
{"id": "123", "name": "Jack"},
|
||||
{"id": "456", "name": "John"},
|
||||
{"id": "789", "name": "Jenifer"},
|
||||
{"id": "101", "name": "Janis"},
|
||||
]
|
||||
assert next_cursor == "cursor2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_paginate_with_find_multiple_users_sentinel_not_found():
|
||||
mock_slack_client = AsyncMock()
|
||||
mock_slack_client.users_list.side_effect = [
|
||||
{
|
||||
"ok": True,
|
||||
"members": [
|
||||
{"id": "123", "name": "Jack"},
|
||||
{"id": "456", "name": "John"},
|
||||
],
|
||||
"response_metadata": {"next_cursor": "cursor1"},
|
||||
},
|
||||
{
|
||||
"ok": True,
|
||||
"members": [
|
||||
{"id": "789", "name": "Jenifer"},
|
||||
{"id": "101", "name": "Janis"},
|
||||
],
|
||||
"response_metadata": {"next_cursor": "cursor2"},
|
||||
},
|
||||
{
|
||||
"ok": True,
|
||||
"members": [{"id": "007", "name": "James"}],
|
||||
"response_metadata": {"next_cursor": None},
|
||||
},
|
||||
]
|
||||
|
||||
results, next_cursor = await async_paginate(
|
||||
func=mock_slack_client.users_list,
|
||||
response_key="members",
|
||||
sentinel=FindMultipleUsersByUsernameSentinel(
|
||||
usernames=["jenifer", "jack", "do not find me"]
|
||||
),
|
||||
)
|
||||
|
||||
assert results == [
|
||||
{"id": "123", "name": "Jack"},
|
||||
{"id": "456", "name": "John"},
|
||||
{"id": "789", "name": "Jenifer"},
|
||||
{"id": "101", "name": "Janis"},
|
||||
{"id": "007", "name": "James"},
|
||||
]
|
||||
assert next_cursor is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_paginate_with_response_error():
|
||||
mock_slack_client = AsyncMock()
|
||||
|
|
@ -278,407 +443,247 @@ def test_filter_conversations_by_user_ids_exact_match_empty_response():
|
|||
assert response == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"search_user_ids, conversation_types, exact_match, limit, expected_conversation_ids",
|
||||
"users_by_email, users_by_username, expected_response",
|
||||
[
|
||||
(["U1", "U2"], [ConversationType.DIRECT_MESSAGE], False, 1, ["C1"]),
|
||||
(["U1", "U2"], [ConversationType.DIRECT_MESSAGE], True, 1, ["C1"]),
|
||||
(["U1", "U2", "U3"], [ConversationType.DIRECT_MESSAGE], False, 1, []),
|
||||
(
|
||||
["U1", "U2"],
|
||||
[ConversationType.DIRECT_MESSAGE, ConversationType.PUBLIC_CHANNEL],
|
||||
False,
|
||||
10,
|
||||
["C1", "C3", "C4"],
|
||||
{"users": [{"id": "U1", "name": "user1"}]},
|
||||
{"users": [{"id": "U2", "name": "user2"}]},
|
||||
[{"id": "U1", "name": "user1"}, {"id": "U2", "name": "user2"}],
|
||||
),
|
||||
(
|
||||
["U1", "U2"],
|
||||
[ConversationType.DIRECT_MESSAGE, ConversationType.PUBLIC_CHANNEL],
|
||||
True,
|
||||
10,
|
||||
["C1", "C3"],
|
||||
{"users": [{"id": "U1", "name": "user1"}]},
|
||||
{"users": []},
|
||||
[{"id": "U1", "name": "user1"}],
|
||||
),
|
||||
(
|
||||
{"users": []},
|
||||
{"users": [{"id": "U2", "name": "user2"}]},
|
||||
[{"id": "U2", "name": "user2"}],
|
||||
),
|
||||
(
|
||||
{"users": []},
|
||||
{"users": []},
|
||||
[],
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_retrieve_conversations_by_user_ids(
|
||||
mock_chat_slack_client,
|
||||
mock_users_slack_client,
|
||||
search_user_ids,
|
||||
conversation_types,
|
||||
exact_match,
|
||||
limit,
|
||||
expected_conversation_ids,
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_multiple_users_retrieval_response_success(
|
||||
users_by_email,
|
||||
users_by_username,
|
||||
expected_response,
|
||||
mock_context,
|
||||
):
|
||||
context = MagicMock(spec=ToolContext)
|
||||
context.authorization = MagicMock()
|
||||
context.authorization.token = MagicMock()
|
||||
response = await build_multiple_users_retrieval_response(
|
||||
context=mock_context,
|
||||
users_responses=[users_by_email, users_by_username],
|
||||
)
|
||||
assert response == expected_response
|
||||
|
||||
conversation_types_slack_name_str = [
|
||||
convert_conversation_type_to_slack_name(conv_type).value
|
||||
for conv_type in conversation_types or ConversationType
|
||||
]
|
||||
|
||||
conversations = [
|
||||
{
|
||||
"conversation": {
|
||||
"id": "C1",
|
||||
"type": ConversationTypeSlackName.IM.value,
|
||||
"name": "im-1",
|
||||
"is_channel": False,
|
||||
"is_im": True,
|
||||
"is_member": True,
|
||||
@pytest.mark.parametrize(
|
||||
"users_by_email, users_by_username",
|
||||
[
|
||||
# Both emails and usernames not found
|
||||
(
|
||||
{
|
||||
"users": [{"id": "U1", "name": "user1"}],
|
||||
"not_found": ["email_not_found"],
|
||||
},
|
||||
"members": {
|
||||
"ok": True,
|
||||
"members": ["U1", "U2"],
|
||||
"response_metadata": {"next_cursor": None},
|
||||
{
|
||||
"users": [{"id": "U2", "name": "user2"}],
|
||||
"not_found": ["username_not_found"],
|
||||
"available_users": [{"id": "U3", "name": "user3"}],
|
||||
},
|
||||
"users": [
|
||||
{"ok": True, "user": {"id": "U1", "team_id": "T123", "name": "user1"}},
|
||||
{"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}},
|
||||
],
|
||||
},
|
||||
{
|
||||
"conversation": {
|
||||
"id": "C2",
|
||||
"type": ConversationTypeSlackName.IM.value,
|
||||
"name": "im-2",
|
||||
"is_channel": False,
|
||||
"is_im": True,
|
||||
"is_member": True,
|
||||
),
|
||||
# Email not found, usernames found
|
||||
(
|
||||
{
|
||||
"users": [{"id": "U1", "name": "user1"}],
|
||||
"not_found": ["email_not_found"],
|
||||
},
|
||||
"members": {
|
||||
"ok": True,
|
||||
"members": ["U2", "U3"],
|
||||
"response_metadata": {"next_cursor": None},
|
||||
{
|
||||
"users": [{"id": "U2", "name": "user2"}],
|
||||
"not_found": [],
|
||||
},
|
||||
"users": [
|
||||
{"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}},
|
||||
{"ok": True, "user": {"id": "U3", "team_id": "T123", "name": "user3"}},
|
||||
],
|
||||
},
|
||||
{
|
||||
"conversation": {
|
||||
"id": "C3",
|
||||
"type": ConversationTypeSlackName.PUBLIC_CHANNEL.value,
|
||||
"name": "general",
|
||||
"is_channel": True,
|
||||
"is_im": False,
|
||||
"is_member": True,
|
||||
),
|
||||
# Email found, username not found
|
||||
(
|
||||
{
|
||||
"users": [{"id": "U1", "name": "user1"}],
|
||||
"not_found": [],
|
||||
},
|
||||
"members": {
|
||||
"ok": True,
|
||||
"members": ["U1", "U2"],
|
||||
"response_metadata": {"next_cursor": None},
|
||||
{
|
||||
"users": [{"id": "U2", "name": "user2"}],
|
||||
"not_found": ["username_not_found"],
|
||||
"available_users": [{"id": "U3", "name": "user3"}],
|
||||
},
|
||||
"users": [
|
||||
{"ok": True, "user": {"id": "U1", "team_id": "T123", "name": "user1"}},
|
||||
{"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}},
|
||||
],
|
||||
},
|
||||
{
|
||||
"conversation": {
|
||||
"id": "C4",
|
||||
"type": ConversationTypeSlackName.PUBLIC_CHANNEL.value,
|
||||
"name": "random",
|
||||
"is_channel": True,
|
||||
"is_im": False,
|
||||
"is_member": True,
|
||||
},
|
||||
"members": {
|
||||
"ok": True,
|
||||
"members": ["U1", "U2", "U3", "U4"],
|
||||
"response_metadata": {"next_cursor": None},
|
||||
},
|
||||
"users": [
|
||||
{"ok": True, "user": {"id": "U1", "team_id": "T123", "name": "user1"}},
|
||||
{"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}},
|
||||
{"ok": True, "user": {"id": "U3", "team_id": "T123", "name": "user3"}},
|
||||
{"ok": True, "user": {"id": "U4", "team_id": "T123", "name": "user4"}},
|
||||
],
|
||||
},
|
||||
]
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_build_multiple_users_retrieval_response_not_found(
|
||||
users_by_email,
|
||||
users_by_username,
|
||||
mock_context,
|
||||
):
|
||||
with pytest.raises(RetryableToolError) as error:
|
||||
await build_multiple_users_retrieval_response(
|
||||
context=mock_context,
|
||||
users_responses=[users_by_email, users_by_username],
|
||||
)
|
||||
|
||||
conversations_listed = [
|
||||
conversation
|
||||
for conversation in conversations
|
||||
if conversation["conversation"]["type"] in conversation_types_slack_name_str
|
||||
]
|
||||
emails_not_found = users_by_email.get("not_found", [])
|
||||
usernames_not_found = users_by_username.get("not_found", [])
|
||||
available_users = users_by_username.get("available_users", [])
|
||||
|
||||
mock_chat_slack_client.conversations_list.return_value = {
|
||||
"ok": True,
|
||||
"channels": [conversation["conversation"] for conversation in conversations_listed],
|
||||
"response_metadata": {"next_cursor": None},
|
||||
}
|
||||
for email in emails_not_found:
|
||||
assert email in error.value.message
|
||||
for username in usernames_not_found:
|
||||
assert username in error.value.message
|
||||
for user in available_users:
|
||||
assert json.dumps(user) in error.value.additional_prompt_content
|
||||
|
||||
mock_chat_slack_client.conversations_members.side_effect = [
|
||||
conversation["members"] for conversation in conversations_listed
|
||||
]
|
||||
|
||||
mock_users_slack_client.users_info.side_effect = [
|
||||
user for conversation in conversations_listed for user in conversation["users"]
|
||||
]
|
||||
def test_is_valid_email():
|
||||
assert is_valid_email("test@example.com")
|
||||
assert is_valid_email("test+123@example.com")
|
||||
assert is_valid_email("test-123@example.com")
|
||||
assert is_valid_email("test_123@example.com")
|
||||
assert is_valid_email("test.123@example.com")
|
||||
assert is_valid_email("test123@example.com")
|
||||
assert is_valid_email("test@example.co")
|
||||
assert is_valid_email("test@example.com.co")
|
||||
assert not is_valid_email("test123@example")
|
||||
assert not is_valid_email("test@example")
|
||||
assert not is_valid_email("test@example.")
|
||||
assert not is_valid_email("test@.com")
|
||||
assert not is_valid_email("test@example.c")
|
||||
assert not is_valid_email("test@example.com.")
|
||||
assert not is_valid_email("test@example.com.c")
|
||||
|
||||
conversations_found = await retrieve_conversations_by_user_ids(
|
||||
list_conversations_func=list_conversations_metadata,
|
||||
get_members_in_conversation_func=get_members_in_conversation_by_id,
|
||||
context=context,
|
||||
conversation_types=conversation_types,
|
||||
user_ids=search_user_ids,
|
||||
exact_match=exact_match,
|
||||
limit=limit,
|
||||
next_cursor=None,
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gather_with_concurrency_limit():
|
||||
mock_func1 = AsyncMock()
|
||||
mock_func2 = AsyncMock()
|
||||
|
||||
caller1 = ConcurrencySafeCoroutineCaller(mock_func1, "arg1", "arg2", kwarg1="kwarg1")
|
||||
caller2 = ConcurrencySafeCoroutineCaller(mock_func2, "arg1", "arg2", kwarg1="kwarg1")
|
||||
|
||||
mock_semaphore = AsyncMock(spec=asyncio.Semaphore)
|
||||
|
||||
response = await gather_with_concurrency_limit(
|
||||
coroutine_callers=[caller1, caller2],
|
||||
semaphore=mock_semaphore,
|
||||
)
|
||||
|
||||
assert [conversation["id"] for conversation in conversations_found] == expected_conversation_ids
|
||||
response = tuple(response)
|
||||
|
||||
assert len(response) == 2
|
||||
assert response[0] == mock_func1.return_value
|
||||
assert response[1] == mock_func2.return_value
|
||||
|
||||
mock_func1.assert_awaited_once_with("arg1", "arg2", kwarg1="kwarg1")
|
||||
mock_func2.assert_awaited_once_with("arg1", "arg2", kwarg1="kwarg1")
|
||||
|
||||
assert mock_semaphore.__aenter__.await_count == 2
|
||||
assert mock_semaphore.__aexit__.await_count == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
(
|
||||
"search_user_ids, conversation_types, exact_match, limit, "
|
||||
"expected_conversation_ids, expected_conversations_list_calls"
|
||||
),
|
||||
[
|
||||
(
|
||||
["U1", "U2", "U3"],
|
||||
[ConversationType.MULTI_PERSON_DIRECT_MESSAGE],
|
||||
False,
|
||||
None,
|
||||
["C1", "C3"],
|
||||
2,
|
||||
),
|
||||
(
|
||||
["U1", "U2", "U3"],
|
||||
[ConversationType.MULTI_PERSON_DIRECT_MESSAGE],
|
||||
True,
|
||||
None,
|
||||
["C1"],
|
||||
2,
|
||||
),
|
||||
(["U1", "U2", "U99"], [ConversationType.MULTI_PERSON_DIRECT_MESSAGE], False, None, [], 2),
|
||||
(
|
||||
["U1", "U2"],
|
||||
[ConversationType.MULTI_PERSON_DIRECT_MESSAGE, ConversationType.PUBLIC_CHANNEL],
|
||||
False,
|
||||
None,
|
||||
["C1", "C3", "C4", "C6"],
|
||||
2,
|
||||
),
|
||||
(
|
||||
["U1", "U2"],
|
||||
[ConversationType.MULTI_PERSON_DIRECT_MESSAGE, ConversationType.PUBLIC_CHANNEL],
|
||||
False,
|
||||
1,
|
||||
["C1"],
|
||||
2,
|
||||
),
|
||||
(
|
||||
["U1", "U2"],
|
||||
[ConversationType.MULTI_PERSON_DIRECT_MESSAGE, ConversationType.PUBLIC_CHANNEL],
|
||||
False,
|
||||
3,
|
||||
["C1", "C3", "C4"],
|
||||
2,
|
||||
),
|
||||
(
|
||||
["U1", "U2"],
|
||||
[ConversationType.MULTI_PERSON_DIRECT_MESSAGE, ConversationType.PUBLIC_CHANNEL],
|
||||
True,
|
||||
None,
|
||||
["C4"],
|
||||
2,
|
||||
),
|
||||
],
|
||||
)
|
||||
async def test_retrieve_conversations_by_user_ids_with_pagination(
|
||||
mock_chat_slack_client,
|
||||
mock_users_slack_client,
|
||||
search_user_ids,
|
||||
conversation_types,
|
||||
exact_match,
|
||||
limit,
|
||||
expected_conversation_ids,
|
||||
expected_conversations_list_calls,
|
||||
async def test_populate_users_in_messages(
|
||||
mock_context,
|
||||
mock_user_retrieval_slack_client,
|
||||
dummy_message_factory,
|
||||
dummy_reaction_factory,
|
||||
dummy_user_factory,
|
||||
):
|
||||
context = MagicMock(spec=ToolContext)
|
||||
context.authorization = MagicMock()
|
||||
context.authorization.token = MagicMock()
|
||||
user1 = dummy_user_factory(id_="U1", name="user1")
|
||||
user2 = dummy_user_factory(id_="U2", name="user2")
|
||||
user3 = dummy_user_factory(id_="U3", name="user3")
|
||||
user4 = dummy_user_factory(id_="U4", name="user4")
|
||||
user5 = dummy_user_factory(id_="U5", name="user5")
|
||||
|
||||
conversation_types_slack_name_str = [
|
||||
convert_conversation_type_to_slack_name(conv_type).value
|
||||
for conv_type in conversation_types or ConversationType
|
||||
user1_short = {"id": user1["id"], "name": user1["name"]}
|
||||
user2_short = {"id": user2["id"], "name": user2["name"]}
|
||||
user3_short = {"id": user3["id"], "name": user3["name"]}
|
||||
user4_short = {"id": user4["id"], "name": user4["name"]}
|
||||
|
||||
user2_mention = f"<@{user2['name']} (id:{user2['id']})>"
|
||||
user5_mention = f"<@{user5['name']} (id:{user5['id']})>"
|
||||
|
||||
reactions = [
|
||||
dummy_reaction_factory(name="thumbsup", user_ids=[user1["id"], user2["id"]]),
|
||||
dummy_reaction_factory(name="partyparrot", user_ids=[user3["id"], user4["id"]]),
|
||||
]
|
||||
|
||||
conversations = [
|
||||
{
|
||||
"conversation": {
|
||||
"id": "C1",
|
||||
"type": ConversationTypeSlackName.MPIM.value,
|
||||
"name": "mpim-1",
|
||||
"is_channel": False,
|
||||
"is_im": False,
|
||||
"is_mpim": True,
|
||||
"is_member": True,
|
||||
},
|
||||
"members": {
|
||||
"ok": True,
|
||||
"members": ["U1", "U2", "U3"],
|
||||
"response_metadata": {"next_cursor": None},
|
||||
},
|
||||
"users": [
|
||||
{"ok": True, "user": {"id": "U1", "team_id": "T123", "name": "user1"}},
|
||||
{"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}},
|
||||
{"ok": True, "user": {"id": "U3", "team_id": "T123", "name": "user3"}},
|
||||
],
|
||||
},
|
||||
{
|
||||
"conversation": {
|
||||
"id": "C2",
|
||||
"type": ConversationTypeSlackName.MPIM.value,
|
||||
"name": "mpim-2",
|
||||
"is_channel": False,
|
||||
"is_im": False,
|
||||
"is_mpim": True,
|
||||
"is_member": True,
|
||||
},
|
||||
"members": {
|
||||
"ok": True,
|
||||
"members": ["U2", "U3"],
|
||||
"response_metadata": {"next_cursor": None},
|
||||
},
|
||||
"users": [
|
||||
{"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}},
|
||||
{"ok": True, "user": {"id": "U3", "team_id": "T123", "name": "user3"}},
|
||||
],
|
||||
},
|
||||
{
|
||||
"conversation": {
|
||||
"id": "C3",
|
||||
"type": ConversationTypeSlackName.MPIM.value,
|
||||
"name": "mpim-3",
|
||||
"is_channel": False,
|
||||
"is_im": False,
|
||||
"is_mpim": True,
|
||||
"is_member": True,
|
||||
},
|
||||
"members": {
|
||||
"ok": True,
|
||||
"members": ["U1", "U2", "U3", "U4"],
|
||||
"response_metadata": {"next_cursor": None},
|
||||
},
|
||||
"users": [
|
||||
{"ok": True, "user": {"id": "U1", "team_id": "T123", "name": "user1"}},
|
||||
{"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}},
|
||||
{"ok": True, "user": {"id": "U3", "team_id": "T123", "name": "user3"}},
|
||||
{"ok": True, "user": {"id": "U4", "team_id": "T123", "name": "user4"}},
|
||||
],
|
||||
},
|
||||
{
|
||||
"conversation": {
|
||||
"id": "C4",
|
||||
"type": ConversationTypeSlackName.PUBLIC_CHANNEL.value,
|
||||
"name": "channel-4",
|
||||
"is_channel": True,
|
||||
"is_im": False,
|
||||
"is_member": True,
|
||||
},
|
||||
"members": {
|
||||
"ok": True,
|
||||
"members": ["U1", "U2"],
|
||||
"response_metadata": {"next_cursor": None},
|
||||
},
|
||||
"users": [
|
||||
{"ok": True, "user": {"id": "U1", "team_id": "T123", "name": "user1"}},
|
||||
{"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}},
|
||||
],
|
||||
},
|
||||
{
|
||||
"conversation": {
|
||||
"id": "C5",
|
||||
"type": ConversationTypeSlackName.PUBLIC_CHANNEL.value,
|
||||
"name": "channel-5",
|
||||
"is_channel": True,
|
||||
"is_im": False,
|
||||
"is_member": True,
|
||||
},
|
||||
"members": {
|
||||
"ok": True,
|
||||
"members": ["U2", "U3", "U4"],
|
||||
"response_metadata": {"next_cursor": None},
|
||||
},
|
||||
"users": [
|
||||
{"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}},
|
||||
{"ok": True, "user": {"id": "U3", "team_id": "T123", "name": "user3"}},
|
||||
{"ok": True, "user": {"id": "U4", "team_id": "T123", "name": "user4"}},
|
||||
],
|
||||
},
|
||||
{
|
||||
"conversation": {
|
||||
"id": "C6",
|
||||
"type": ConversationTypeSlackName.PUBLIC_CHANNEL.value,
|
||||
"name": "channel-6",
|
||||
"is_channel": True,
|
||||
"is_im": False,
|
||||
"is_member": True,
|
||||
},
|
||||
"members": {
|
||||
"ok": True,
|
||||
"members": ["U1", "U2", "U3", "U4"],
|
||||
"response_metadata": {"next_cursor": None},
|
||||
},
|
||||
"users": [
|
||||
{"ok": True, "user": {"id": "U1", "team_id": "T123", "name": "user1"}},
|
||||
{"ok": True, "user": {"id": "U2", "team_id": "T123", "name": "user2"}},
|
||||
{"ok": True, "user": {"id": "U3", "team_id": "T123", "name": "user3"}},
|
||||
{"ok": True, "user": {"id": "U4", "team_id": "T123", "name": "user4"}},
|
||||
],
|
||||
},
|
||||
messages = [
|
||||
dummy_message_factory(
|
||||
user_id=user1["id"],
|
||||
text=f"Hello <@{user2['id']}>",
|
||||
),
|
||||
dummy_message_factory(
|
||||
user_id=user2["id"],
|
||||
text="foobar",
|
||||
reactions=copy.deepcopy(reactions[:1]),
|
||||
),
|
||||
dummy_message_factory(
|
||||
user_id=user3["id"],
|
||||
text=f"Is this @{user5['id']} a user mention?",
|
||||
),
|
||||
dummy_message_factory(
|
||||
user_id=user4["id"],
|
||||
text="hello",
|
||||
reactions=copy.deepcopy(reactions),
|
||||
),
|
||||
]
|
||||
|
||||
conversations_listed = [
|
||||
conversation
|
||||
for conversation in conversations
|
||||
if conversation["conversation"]["type"] in conversation_types_slack_name_str
|
||||
]
|
||||
|
||||
split_size = len(conversations_listed) // 2
|
||||
|
||||
conversations_listed_1 = conversations_listed[:split_size]
|
||||
conversations_listed_2 = conversations_listed[split_size:]
|
||||
|
||||
mock_chat_slack_client.conversations_list.side_effect = [
|
||||
mock_user_retrieval_slack_client.users_list.side_effect = [
|
||||
{
|
||||
"ok": True,
|
||||
"channels": [conversation["conversation"] for conversation in conversations_listed_1],
|
||||
"response_metadata": {"next_cursor": "cursor_1"},
|
||||
"members": [user1, user2, user3],
|
||||
"response_metadata": {"next_cursor": "cursor1"},
|
||||
},
|
||||
{
|
||||
"ok": True,
|
||||
"channels": [conversation["conversation"] for conversation in conversations_listed_2],
|
||||
"members": [user4, user5],
|
||||
"response_metadata": {"next_cursor": None},
|
||||
},
|
||||
]
|
||||
|
||||
mock_chat_slack_client.conversations_members.side_effect = [
|
||||
conversation["members"] for conversation in conversations_listed
|
||||
]
|
||||
|
||||
mock_users_slack_client.users_info.side_effect = [
|
||||
user for conversation in conversations_listed for user in conversation["users"]
|
||||
]
|
||||
|
||||
conversations_found = await retrieve_conversations_by_user_ids(
|
||||
list_conversations_func=list_conversations_metadata,
|
||||
get_members_in_conversation_func=get_members_in_conversation_by_id,
|
||||
context=context,
|
||||
conversation_types=conversation_types,
|
||||
user_ids=search_user_ids,
|
||||
exact_match=exact_match,
|
||||
limit=limit,
|
||||
next_cursor=None,
|
||||
response = await populate_users_in_messages(
|
||||
auth_token=mock_context.get_auth_token_or_empty(),
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
assert [conversation["id"] for conversation in conversations_found] == expected_conversation_ids
|
||||
assert mock_chat_slack_client.conversations_list.call_count == expected_conversations_list_calls
|
||||
msg1 = response[0]
|
||||
msg2 = response[1]
|
||||
msg3 = response[2]
|
||||
msg4 = response[3]
|
||||
|
||||
assert msg1["user"] == user1_short
|
||||
assert msg1["text"] == f"Hello {user2_mention}"
|
||||
assert "reactions" not in msg1
|
||||
|
||||
assert msg2["user"] == user2_short
|
||||
assert msg2["text"] == "foobar"
|
||||
assert "reactions" in msg2
|
||||
assert len(msg2["reactions"]) == 1
|
||||
assert msg2["reactions"][0]["name"] == "thumbsup"
|
||||
assert msg2["reactions"][0]["users"] == [user1_short, user2_short]
|
||||
|
||||
assert msg3["user"] == user3_short
|
||||
assert msg3["text"] == f"Is this @{user5['id']} a user mention?"
|
||||
assert "reactions" not in msg3
|
||||
assert user5_mention not in msg3["text"]
|
||||
|
||||
assert msg4["user"] == user4_short
|
||||
assert msg4["text"] == "hello"
|
||||
assert "reactions" in msg4
|
||||
assert len(msg4["reactions"]) == 2
|
||||
assert msg4["reactions"][0]["name"] == "thumbsup"
|
||||
assert msg4["reactions"][0]["users"] == [user1_short, user2_short]
|
||||
assert msg4["reactions"][1]["name"] == "partyparrot"
|
||||
assert msg4["reactions"][1]["users"] == [user3_short, user4_short]
|
||||
|
|
|
|||
Loading…
Reference in a new issue