diff --git a/toolkits/jira/arcade_jira/cache.py b/toolkits/jira/arcade_jira/cache.py deleted file mode 100644 index 53ea304f..00000000 --- a/toolkits/jira/arcade_jira/cache.py +++ /dev/null @@ -1,105 +0,0 @@ -import asyncio -from collections import OrderedDict -from threading import Lock -from typing import Generic, TypeVar - -from arcade_jira.constants import JIRA_CACHE_MAX_ITEMS - -T = TypeVar("T") - - -class LRUCache(Generic[T]): - def __init__(self, max_size: int): - self.cache: OrderedDict[str, T] = OrderedDict() - self.max_size = max_size - self.thread_lock = Lock() - self.async_lock = asyncio.Lock() - - # Thread-safe synchronous methods - def get(self, key: str) -> T | None: - with self.thread_lock: - if key not in self.cache: - return None - - value = self.cache.pop(key) - self.cache[key] = value - return value - - def set(self, key: str, value: T) -> None: - with self.thread_lock: - if key in self.cache: - self.cache.pop(key) - elif len(self.cache) >= self.max_size: - self.cache.popitem(last=False) - self.cache[key] = value - - # Async-safe methods - async def async_get(self, key: str) -> T | None: - async with self.async_lock: - if key not in self.cache: - return None - - value = self.cache.pop(key) - self.cache[key] = value - return value - - async def async_set(self, key: str, value: T) -> None: - async with self.async_lock: - if key in self.cache: - self.cache.pop(key) - elif len(self.cache) >= self.max_size: - self.cache.popitem(last=False) - self.cache[key] = value - - -CLOUD_ID_CACHE = LRUCache[str](max_size=JIRA_CACHE_MAX_ITEMS) -CLOUD_NAME_CACHE = LRUCache[str](max_size=JIRA_CACHE_MAX_ITEMS) -CLIENT_SEMAPHORE_CACHE = LRUCache[asyncio.Semaphore](max_size=JIRA_CACHE_MAX_ITEMS) - - -def get_cloud_id(auth_token: str) -> str | None: - return CLOUD_ID_CACHE.get(auth_token) - - -def get_cloud_name(auth_token: str) -> str | None: - return CLOUD_NAME_CACHE.get(auth_token) - - -def set_cloud_id(auth_token: str, cloud_id: str) -> None: - CLOUD_ID_CACHE.set(auth_token, cloud_id) - - -def set_cloud_name(auth_token: str, cloud_name: str) -> None: - CLOUD_NAME_CACHE.set(auth_token, cloud_name) - - -def get_jira_client_semaphore(auth_token: str) -> asyncio.Semaphore | None: - return CLIENT_SEMAPHORE_CACHE.get(auth_token) - - -def set_jira_client_semaphore(auth_token: str, semaphore: asyncio.Semaphore) -> None: - CLIENT_SEMAPHORE_CACHE.set(auth_token, semaphore) - - -async def async_get_cloud_id(auth_token: str) -> str | None: - return await CLOUD_ID_CACHE.async_get(auth_token) - - -async def async_get_cloud_name(auth_token: str) -> str | None: - return await CLOUD_NAME_CACHE.async_get(auth_token) - - -async def async_set_cloud_id(auth_token: str, cloud_id: str) -> None: - await CLOUD_ID_CACHE.async_set(auth_token, cloud_id) - - -async def async_set_cloud_name(auth_token: str, cloud_name: str) -> None: - await CLOUD_NAME_CACHE.async_set(auth_token, cloud_name) - - -async def async_get_jira_client_semaphore(auth_token: str) -> asyncio.Semaphore | None: - return await CLIENT_SEMAPHORE_CACHE.async_get(auth_token) - - -async def async_set_jira_client_semaphore(auth_token: str, semaphore: asyncio.Semaphore) -> None: - await CLIENT_SEMAPHORE_CACHE.async_set(auth_token, semaphore) diff --git a/toolkits/jira/arcade_jira/client.py b/toolkits/jira/arcade_jira/client.py index 03475d2f..dc52150f 100644 --- a/toolkits/jira/arcade_jira/client.py +++ b/toolkits/jira/arcade_jira/client.py @@ -2,83 +2,51 @@ import asyncio import json import json.decoder from dataclasses import dataclass -from typing import Any, cast +from typing import cast import httpx +from arcade_tdk import ToolContext +from arcade_tdk.errors import ToolExecutionError -import arcade_jira.cache as cache from arcade_jira.constants import JIRA_API_VERSION, JIRA_BASE_URL, JIRA_MAX_CONCURRENT_REQUESTS -from arcade_jira.exceptions import JiraToolExecutionError, NotFoundError +from arcade_jira.exceptions import NotFoundError @dataclass class JiraClient: - auth_token: str + context: ToolContext + cloud_id: str | None base_url: str = JIRA_BASE_URL api_version: str = JIRA_API_VERSION max_concurrent_requests: int = JIRA_MAX_CONCURRENT_REQUESTS _semaphore: asyncio.Semaphore | None = None - _cloud_id: str | None = None + + @property + def auth_token(self) -> str | None: + return self.context.get_auth_token_or_empty() def __post_init__(self) -> None: if not self._semaphore: - cached_semaphore = cache.get_jira_client_semaphore(self.auth_token) + cached_semaphore = getattr(self.context, "_global_jira_client_semaphore", None) + # If a semaphore was already cached in the context, we use it. Some tools + # may call other tools. Each tool will instantiate its own JiraClient. + # This is necessary to ensure that all instances will respect the + # concurrency limit. if cached_semaphore: self._semaphore = cached_semaphore else: self._semaphore = asyncio.Semaphore(self.max_concurrent_requests) - cache.set_jira_client_semaphore(self.auth_token, self._semaphore) + self.context._global_jira_client_semaphore = self._semaphore # type: ignore[attr-defined] self.base_url = self.base_url.rstrip("/") self.api_version = self.api_version.strip("/") - async def get_cloud_id(self) -> str: - if self._cloud_id is None: - if (cloud_id := await cache.async_get_cloud_id(self.auth_token)) is not None: - self._cloud_id = cloud_id - else: - cloud = await self._get_cloud_data_from_available_resources() - self._cloud_id = cloud["id"] - await cache.async_set_cloud_id(self.auth_token, cloud["id"]) - await cache.async_set_cloud_name(self.auth_token, cloud["name"]) - - return self._cloud_id - async def _build_url(self, endpoint: str) -> str: - cloud_id = await self.get_cloud_id() - return f"{self.base_url}/{cloud_id}/rest/api/{self.api_version}/{endpoint.lstrip('/')}" - - async def _get_cloud_data_from_available_resources(self) -> dict[str, Any]: - async with httpx.AsyncClient() as client: - response = await client.get( - "https://api.atlassian.com/oauth/token/accessible-resources", - headers={"Authorization": f"Bearer {self.auth_token}"}, - ) - - available_resources = deduplicate_available_resources(response.json()) - - if len(available_resources) == 0: - raise JiraToolExecutionError( - message="No cloud ID returned by Atlassian, cannot make API calls" - ) - if len(available_resources) > 1: - cloud_ids_found = json.dumps([ - { - "id": resource["id"], - "name": resource["name"], - "url": resource["url"], - } - for resource in available_resources - ]) - raise JiraToolExecutionError( - message=( - "Multiple cloud IDs returned by Atlassian, cannot resolve which one " - "to use. Please revoke your authorization access and authorize a single " - f"Atlassian Cloud. Available cloud IDs: {cloud_ids_found}. " - ) - ) - return cast(dict[str, Any], available_resources[0]) + return ( + f"{self.base_url}/{self.cloud_id or ''}" + f"/rest/api/{self.api_version}/{endpoint.lstrip('/')}" + ) def _build_error_messages(self, response: httpx.Response) -> tuple[str, str | None]: try: @@ -108,7 +76,7 @@ class JiraClient: return error_message, developer_message - def _raise_for_status(self, response: httpx.Response) -> None: + async def _raise_for_status(self, response: httpx.Response) -> None: if response.status_code < 300: return @@ -117,7 +85,7 @@ class JiraClient: if response.status_code == 404: raise NotFoundError(error_message, developer_message) - raise JiraToolExecutionError(error_message, developer_message) + raise ToolExecutionError(error_message, developer_message) def _set_request_body(self, kwargs: dict, data: dict | None, json_data: dict | None) -> dict: if data and json_data: @@ -159,7 +127,7 @@ class JiraClient: async with self._semaphore, httpx.AsyncClient() as client: # type: ignore[union-attr] response = await client.get(**kwargs) # type: ignore[arg-type] - self._raise_for_status(response) + await self._raise_for_status(response) return self._format_response_dict(response) @@ -195,7 +163,7 @@ class JiraClient: async with self._semaphore, httpx.AsyncClient() as client: # type: ignore[union-attr] response = await client.post(**kwargs) # type: ignore[arg-type] - self._raise_for_status(response) + await self._raise_for_status(response) return self._format_response_dict(response) @@ -224,18 +192,6 @@ class JiraClient: async with self._semaphore, httpx.AsyncClient() as client: # type: ignore[union-attr] response = await client.put(**kwargs) # type: ignore[arg-type] - self._raise_for_status(response) + await self._raise_for_status(response) return self._format_response_dict(response) - - -def deduplicate_available_resources(available_resources: list[dict]) -> list[dict]: - account_ids_seen = set() - deduplicated = [] - - for item in available_resources: - if item["id"] not in account_ids_seen: - deduplicated.append(item) - account_ids_seen.add(item["id"]) - - return deduplicated diff --git a/toolkits/jira/arcade_jira/constants.py b/toolkits/jira/arcade_jira/constants.py index b44bd6e6..2d30f168 100644 --- a/toolkits/jira/arcade_jira/constants.py +++ b/toolkits/jira/arcade_jira/constants.py @@ -14,11 +14,6 @@ try: except Exception: JIRA_API_REQUEST_TIMEOUT = 30 -try: - JIRA_CACHE_MAX_ITEMS = max(1, int(os.getenv("JIRA_CACHE_MAX_ITEMS", 5000))) -except Exception: - JIRA_CACHE_MAX_ITEMS = 5000 - STOP_WORDS = [ "a", diff --git a/toolkits/jira/arcade_jira/tools/__init__.py b/toolkits/jira/arcade_jira/tools/__init__.py index 6d0a2b83..c7bebed7 100644 --- a/toolkits/jira/arcade_jira/tools/__init__.py +++ b/toolkits/jira/arcade_jira/tools/__init__.py @@ -4,6 +4,7 @@ from arcade_jira.tools.attachments import ( get_attachment_metadata, list_issue_attachments_metadata, ) +from arcade_jira.tools.cloud import get_available_atlassian_clouds from arcade_jira.tools.comments import ( add_comment_to_issue, get_comment_by_id, @@ -43,6 +44,8 @@ __all__ = [ "download_attachment", "get_attachment_metadata", "list_issue_attachments_metadata", + # Cloud tools + "get_available_atlassian_clouds", # Comments tools "add_comment_to_issue", "get_comment_by_id", diff --git a/toolkits/jira/arcade_jira/tools/attachments.py b/toolkits/jira/arcade_jira/tools/attachments.py index 7bddf0c5..aefb72ed 100644 --- a/toolkits/jira/arcade_jira/tools/attachments.py +++ b/toolkits/jira/arcade_jira/tools/attachments.py @@ -4,10 +4,9 @@ from arcade_tdk import ToolContext, tool from arcade_tdk.auth import Atlassian from arcade_tdk.errors import ToolExecutionError -import arcade_jira.cache as cache from arcade_jira.client import JiraClient from arcade_jira.exceptions import NotFoundError -from arcade_jira.utils import build_file_data, clean_attachment_dict +from arcade_jira.utils import build_file_data, clean_attachment_dict, resolve_cloud_id @tool(requires_auth=Atlassian(scopes=["write:jira-work"])) @@ -40,11 +39,17 @@ async def attach_file_to_issue( "If the filename is not recognized, it will attach the file without specifying a type. " "Defaults to None (infer from filename or attach without type).", ] = None, + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict[str, Any], "Metadata about the attachment"]: """Add an attachment to an issue. Must provide exactly one of file_content_str or file_content_base64. """ + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) file_contents = [file_content_str, file_content_base64] if not any(file_contents) or all(file_contents): @@ -55,7 +60,7 @@ async def attach_file_to_issue( if not filename: raise ToolExecutionError(message="Must provide a filename.") - client = JiraClient(context.get_auth_token_or_empty()) + client = JiraClient(context=context, cloud_id=atlassian_cloud_id) response = await client.post( f"/issue/{issue}/attachments", @@ -70,13 +75,13 @@ async def attach_file_to_issue( file_encoding=file_encoding, ), ) - cloud_name = cache.get_cloud_name(context.get_auth_token_or_empty()) + return { "status": { "success": True, "message": f"Attachment '{filename}' successfully added to the issue '{issue}'", }, - "attachment": clean_attachment_dict(response[0], cloud_name), + "attachment": clean_attachment_dict(response[0]), } @@ -84,6 +89,11 @@ async def attach_file_to_issue( async def list_issue_attachments_metadata( context: ToolContext, issue: Annotated[str, "The ID or key of the issue to retrieve"], + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict, "Information about the issue"]: """Get the metadata about the files attached to an issue. @@ -92,7 +102,11 @@ async def list_issue_attachments_metadata( """ from arcade_jira.tools.issues import get_issue_by_id # Avoid circular imports - response = await get_issue_by_id(context, issue) + response = await get_issue_by_id( + context=context, + issue=issue, + atlassian_cloud_id=atlassian_cloud_id, + ) if response.get("error"): return cast(dict, response) return { @@ -108,26 +122,42 @@ async def list_issue_attachments_metadata( async def get_attachment_metadata( context: ToolContext, attachment_id: Annotated[str, "The ID of the attachment to retrieve"], + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict[str, Any], "The metadata of the attachment"]: """Get the metadata of an attachment.""" - client = JiraClient(context.get_auth_token_or_empty()) + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + client = JiraClient(context=context, cloud_id=atlassian_cloud_id) try: response = await client.get(f"/attachment/{attachment_id}") except NotFoundError: return {"error": f"Attachment not found with ID '{attachment_id}'."} - cloud_name = cache.get_cloud_name(context.get_auth_token_or_empty()) - return {"attachment": clean_attachment_dict(response, cloud_name)} + + return {"attachment": clean_attachment_dict(response)} @tool(requires_auth=Atlassian(scopes=["read:jira-work"])) async def download_attachment( context: ToolContext, attachment_id: Annotated[str, "The ID of the attachment to download"], + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict[str, Any], "The content of the attachment"]: """Download the contents of an attachment associated with an issue.""" - client = JiraClient(context.get_auth_token_or_empty()) + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + client = JiraClient(context=context, cloud_id=atlassian_cloud_id) - attachment = await get_attachment_metadata(context, attachment_id) + attachment = await get_attachment_metadata( + context=context, + attachment_id=attachment_id, + atlassian_cloud_id=atlassian_cloud_id, + ) if attachment.get("error"): return cast(dict, attachment) diff --git a/toolkits/jira/arcade_jira/tools/cloud.py b/toolkits/jira/arcade_jira/tools/cloud.py new file mode 100644 index 00000000..b54fd675 --- /dev/null +++ b/toolkits/jira/arcade_jira/tools/cloud.py @@ -0,0 +1,46 @@ +import asyncio +from typing import Annotated + +import httpx +from arcade_tdk import ToolContext, tool +from arcade_tdk.auth import Atlassian + +from arcade_jira.constants import JIRA_MAX_CONCURRENT_REQUESTS +from arcade_jira.utils import check_if_cloud_is_authorized + + +@tool(requires_auth=Atlassian(scopes=["read:jira-user"])) +async def get_available_atlassian_clouds( + context: ToolContext, +) -> Annotated[dict[str, list[dict[str, str]]], "Available Atlassian Clouds"]: + """Get available Atlassian Clouds.""" + async with httpx.AsyncClient() as client: + response = await client.get( + "https://api.atlassian.com/oauth/token/accessible-resources", + headers={"Authorization": f"Bearer {context.get_auth_token_or_empty()}"}, + ) + + verified_clouds = response.json() + cloud_ids_seen = set() + unique_clouds = [] + + for cloud in verified_clouds: + if cloud["id"] not in cloud_ids_seen: + unique_clouds.append({ + "atlassian_cloud_id": cloud["id"], + "atlassian_cloud_name": cloud["name"], + "atlassian_cloud_url": cloud["url"], + }) + cloud_ids_seen.add(cloud["id"]) + + semaphore = asyncio.Semaphore(JIRA_MAX_CONCURRENT_REQUESTS) + + verified_clouds = await asyncio.gather(*[ + check_if_cloud_is_authorized(context, cloud, semaphore) for cloud in unique_clouds + ]) + + return { + "clouds_available": [ + cloud_available for cloud_available in verified_clouds if cloud_available is not False + ] + } diff --git a/toolkits/jira/arcade_jira/tools/comments.py b/toolkits/jira/arcade_jira/tools/comments.py index c687e927..94940e15 100644 --- a/toolkits/jira/arcade_jira/tools/comments.py +++ b/toolkits/jira/arcade_jira/tools/comments.py @@ -13,6 +13,7 @@ from arcade_jira.utils import ( clean_comment_dict, find_multiple_unique_users, remove_none_values, + resolve_cloud_id, ) @@ -26,9 +27,15 @@ async def get_comment_by_id( "Whether to include the ADF (Atlassian Document Format) content of the comment in the " "response. Defaults to False (return only the HTML rendered content).", ] = False, + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict[str, Any], "Information about the comment"]: """Get a comment by its ID.""" - client = JiraClient(context.get_auth_token_or_empty()) + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + client = JiraClient(context=context, cloud_id=atlassian_cloud_id) response = await client.get( f"issue/{issue_id}/comment/{comment_id}", params={"expand": "renderedBody"}, @@ -66,10 +73,16 @@ async def get_issue_comments( "Whether to include the ADF (Atlassian Document Format) content of the comment in the " "response. Defaults to False (return only the HTML rendered content).", ] = False, + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict[str, Any], "Information about the issue comments"]: """Get the comments of a Jira issue by its ID.""" limit = max(min(limit, 100), 1) - client = JiraClient(context.get_auth_token_or_empty()) + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + client = JiraClient(context=context, cloud_id=atlassian_cloud_id) api_response = await client.get( f"issue/{issue}/comment", params=remove_none_values({ @@ -114,18 +127,29 @@ async def add_comment_to_issue( "The users to mention in the comment. Provide the user display name, email address, or ID. " "Ex: 'John Doe' or 'john.doe@example.com'. Defaults to None (no user mentions).", ] = None, + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict[str, Any], "Information about the comment created"]: """Add a comment to a Jira issue.""" if not body: raise ToolExecutionError(message="Comment body cannot be empty.") - client = JiraClient(context.get_auth_token_or_empty()) + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + client = JiraClient(context=context, cloud_id=atlassian_cloud_id) adf_body = build_adf_doc(body) if mention_users: try: - users = await find_multiple_unique_users(context, mention_users, exact_match=True) + users = await find_multiple_unique_users( + context=context, + user_identifiers=mention_users, + exact_match=True, + atlassian_cloud_id=atlassian_cloud_id, + ) except (NotFoundError, MultipleItemsFoundError) as exc: return {"error": f"Failed to mention user: {exc.message}"} mentions = [ @@ -138,7 +162,13 @@ async def add_comment_to_issue( adf_body["content"][0]["content"] = mentions + adf_body["content"][0]["content"] if reply_to_comment: - quote_comment = await get_comment_by_id(context, issue, reply_to_comment, True) + quote_comment = await get_comment_by_id( + context=context, + issue_id=issue, + comment_id=reply_to_comment, + include_adf_content=True, + atlassian_cloud_id=atlassian_cloud_id, + ) if not quote_comment["comment"]: raise ToolExecutionError( message=f"Cannot quote comment. No comment found with ID '{reply_to_comment}'." diff --git a/toolkits/jira/arcade_jira/tools/issues.py b/toolkits/jira/arcade_jira/tools/issues.py index 6cbe4785..89670873 100644 --- a/toolkits/jira/arcade_jira/tools/issues.py +++ b/toolkits/jira/arcade_jira/tools/issues.py @@ -3,7 +3,6 @@ from typing import Annotated, Any, cast from arcade_tdk import ToolContext, tool from arcade_tdk.auth import Atlassian -import arcade_jira.cache as cache from arcade_jira.client import JiraClient from arcade_jira.exceptions import JiraToolExecutionError, MultipleItemsFoundError, NotFoundError from arcade_jira.utils import ( @@ -19,6 +18,7 @@ from arcade_jira.utils import ( find_unique_project, get_single_project, remove_none_values, + resolve_cloud_id, resolve_issue_users, validate_issue_args, ) @@ -41,15 +41,25 @@ async def list_issue_types_by_project( int, "The number of issue types to skip. Defaults to 0 (start from the first issue type).", ] = 0, + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[ dict[str, Any], "Information about the issue types available for the specified project." ]: """Get the list of issue types (e.g. 'Task', 'Epic', etc.) available to a given project.""" limit = max(1, min(limit, 200)) - client = JiraClient(context.get_auth_token_or_empty()) + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + client = JiraClient(context=context, cloud_id=atlassian_cloud_id) try: - project_data = await find_unique_project(context, project) + project_data = await find_unique_project( + context=context, + project_identifier=project, + atlassian_cloud_id=atlassian_cloud_id, + ) except JiraToolExecutionError as error: return {"error": error.message} @@ -79,9 +89,15 @@ async def list_issue_types_by_project( async def get_issue_type_by_id( context: ToolContext, issue_type_id: Annotated[str, "The ID of the issue type to retrieve"], + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict, "Information about the issue type"]: """Get the details of a Jira issue type by its ID.""" - client = JiraClient(context.get_auth_token_or_empty()) + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + client = JiraClient(context=context, cloud_id=atlassian_cloud_id) try: response = await client.get(f"issuetype/{issue_type_id}") except NotFoundError: @@ -93,9 +109,15 @@ async def get_issue_type_by_id( async def get_issue_by_id( context: ToolContext, issue: Annotated[str, "The ID or key of the issue to retrieve"], + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict[str, Any], "Information about the issue"]: """Get the details of a Jira issue by its ID.""" - client = JiraClient(context.get_auth_token_or_empty()) + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + client = JiraClient(context=context, cloud_id=atlassian_cloud_id) try: response = await client.get( f"issue/{issue}", @@ -104,8 +126,7 @@ async def get_issue_by_id( except NotFoundError: return {"error": f"Issue not found with ID/key '{issue}'."} - cloud_name = cache.get_cloud_name(context.get_auth_token_or_empty()) - return {"issue": clean_issue_dict(response, cloud_name)} + return {"issue": clean_issue_dict(response)} # NOTE: This is not named `search_issues` because sometimes LLM's won't realize they can @@ -183,6 +204,11 @@ async def get_issues_without_id( str | None, "The token to use to get the next page of issues. Defaults to None (first page).", ] = None, + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict[str, Any], "Information about the issues matching the search criteria"]: """Search for Jira issues when you don't have the issue ID(s). @@ -193,7 +219,8 @@ async def get_issues_without_id( """ limit = max(1, min(limit, 100)) - client = JiraClient(context.get_auth_token_or_empty()) + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + client = JiraClient(context=context, cloud_id=atlassian_cloud_id) due_from_date = convert_date_string_to_date(due_from) if due_from else None due_until_date = convert_date_string_to_date(due_until) if due_until else None @@ -233,10 +260,8 @@ async def get_issues_without_id( if response.get("nextPageToken"): pagination["next_page_token"] = response["nextPageToken"] - cloud_name = cache.get_cloud_name(context.get_auth_token_or_empty()) - return { - "issues": [clean_issue_dict(issue, cloud_name) for issue in response["issues"]], + "issues": [clean_issue_dict(issue) for issue in response["issues"]], "pagination": pagination, } @@ -267,10 +292,18 @@ async def list_issues( str | None, "The token to use to get the next page of issues. Defaults to None (first page).", ] = None, + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict[str, Any], "Information about the issues matching the search criteria"]: """Get the issues for a given project.""" if not project: - project_data = await get_single_project(context) + project_data = await get_single_project( + context=context, + atlassian_cloud_id=atlassian_cloud_id, + ) project = project_data["id"] return cast( @@ -280,6 +313,7 @@ async def list_issues( project=project, limit=limit, next_page_token=next_page_token, + atlassian_cloud_id=atlassian_cloud_id, ), ) @@ -358,6 +392,11 @@ async def search_issues_without_jql( str | None, "The token to use to get the next page of issues. Defaults to None (first page).", ] = None, + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict[str, Any], "Information about the issues matching the search criteria"]: """Parameterized search for Jira issues (without having to provide a JQL query). @@ -381,6 +420,7 @@ async def search_issues_without_jql( parent_issue=parent_issue, limit=limit, next_page_token=next_page_token, + atlassian_cloud_id=atlassian_cloud_id, ), ) @@ -397,6 +437,11 @@ async def search_issues_with_jql( str | None, "The token to use to get the next page of issues. Defaults to None (first page).", ] = None, + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict[str, Any], "Information about the issues matching the search criteria"]: """Search for Jira issues using a JQL (Jira Query Language) query. @@ -406,7 +451,8 @@ async def search_issues_with_jql( `Jira_SearchIssuesWithoutJql` TOOL OR IF THE USER PROVIDES A JQL QUERY THEMSELVES. """ limit = max(1, min(limit, 100)) - client = JiraClient(context.get_auth_token_or_empty()) + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + client = JiraClient(context=context, cloud_id=atlassian_cloud_id) api_response = await client.post( "search/jql", json_data={ @@ -417,9 +463,9 @@ async def search_issues_with_jql( "expand": "renderedFields", }, ) - cloud_name = cache.get_cloud_name(context.get_auth_token_or_empty()) + response: dict[str, Any] = { - "issues": [clean_issue_dict(issue, cloud_name) for issue in api_response["issues"]] + "issues": [clean_issue_dict(issue) for issue in api_response["issues"]] } if api_response.get("isLast") is not False and api_response.get("nextPageToken"): @@ -504,6 +550,11 @@ async def create_issue( "provided, the tool will try to find a unique exact match among the available users. " "Defaults to None (no reporter).", ] = None, + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict, "The created issue"]: """Create a new Jira issue. @@ -522,9 +573,14 @@ async def create_issue( """ project_data: dict[str, Any] | None = None + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + if project is None and parent_issue is None: try: - project_data = await get_single_project(context) + project_data = await get_single_project( + context=context, + atlassian_cloud_id=atlassian_cloud_id, + ) except (NotFoundError, MultipleItemsFoundError) as exc: return {"error": str(exc)} else: @@ -536,15 +592,28 @@ async def create_issue( issue_type_data, priority_data, parent_data, - ) = await validate_issue_args(context, due_date, project, issue_type, priority, parent_issue) + ) = await validate_issue_args( + context=context, + due_date=due_date, + project=project, + issue_type=issue_type, + priority=priority, + parent_issue=parent_issue, + atlassian_cloud_id=atlassian_cloud_id, + ) if error: return error - error, assignee_data, reporter_data = await resolve_issue_users(context, assignee, reporter) + error, assignee_data, reporter_data = await resolve_issue_users( + context=context, + assignee=assignee, + reporter=reporter, + atlassian_cloud_id=atlassian_cloud_id, + ) if error: return error - client = JiraClient(context.get_auth_token_or_empty()) + client = JiraClient(context=context, cloud_id=atlassian_cloud_id) request_body = { "fields": remove_none_values({ @@ -576,7 +645,6 @@ async def create_issue( "issue": { "id": response["id"], "key": response["key"], - "url": response["self"], }, } @@ -603,9 +671,19 @@ async def add_labels_to_issue( bool, "Whether to notify the issue's watchers. Defaults to True (notifies watchers).", ] = True, + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict, "The updated issue"]: """Add labels to an existing Jira issue.""" - issue_data = await get_issue_by_id(context, issue) + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + issue_data = await get_issue_by_id( + context=context, + issue=issue, + atlassian_cloud_id=atlassian_cloud_id, + ) if issue_data.get("error"): return cast(dict, issue_data) @@ -616,6 +694,7 @@ async def add_labels_to_issue( issue=issue_data["issue"]["id"], labels=current_labels + labels, notify_watchers=notify_watchers, + atlassian_cloud_id=atlassian_cloud_id, ) return cast(dict, response) @@ -638,9 +717,15 @@ async def remove_labels_from_issue( bool, "Whether to notify the issue's watchers. Defaults to True (notifies watchers).", ] = True, + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict[str, Any], "The updated issue"]: """Remove labels from an existing Jira issue.""" - issue_data = await get_issue_by_id(context, issue) + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + issue_data = await get_issue_by_id(context, issue, atlassian_cloud_id=atlassian_cloud_id) if issue_data.get("error"): return cast(dict, issue_data) @@ -652,6 +737,7 @@ async def remove_labels_from_issue( issue=issue_data["issue"]["id"], labels=new_labels, notify_watchers=notify_watchers, + atlassian_cloud_id=atlassian_cloud_id, ) return cast(dict, response) @@ -732,6 +818,11 @@ async def update_issue( bool, "Whether to notify the issue's watchers. Defaults to True (notifies watchers).", ] = True, + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict[str, Any], "The updated issue"]: """Update an existing Jira issue. @@ -743,23 +834,39 @@ async def update_issue( DO NOT CALL OTHER TOOLS only to list available priorities, issue types, or users. Provide the name, key, or email and the tool will figure out the ID. """ - issue_data = await get_issue_by_id(context, issue) + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + issue_data = await get_issue_by_id( + context=context, + issue=issue, + atlassian_cloud_id=atlassian_cloud_id, + ) if issue_data.get("error"): return cast(dict, issue_data) project = issue_data["issue"]["project"]["id"] error, _, issue_type_data, priority_data, parent_issue_data = await validate_issue_args( - context, due_date, project, issue_type, priority, parent_issue + context=context, + due_date=due_date, + project=project, + issue_type=issue_type, + priority=priority, + parent_issue=parent_issue, + atlassian_cloud_id=atlassian_cloud_id, ) if error: return cast(dict, error) - error, assignee_data, reporter_data = await resolve_issue_users(context, assignee, reporter) + error, assignee_data, reporter_data = await resolve_issue_users( + context=context, + assignee=assignee, + reporter=reporter, + atlassian_cloud_id=atlassian_cloud_id, + ) if error: return cast(dict, error) - client = JiraClient(context.get_auth_token_or_empty()) + client = JiraClient(context=context, cloud_id=atlassian_cloud_id) params = {"notifyWatchers": notify_watchers, "expand": "renderedFields"} request_body = build_issue_update_request_body( title=title, diff --git a/toolkits/jira/arcade_jira/tools/labels.py b/toolkits/jira/arcade_jira/tools/labels.py index d83e0cf3..2d315a38 100644 --- a/toolkits/jira/arcade_jira/tools/labels.py +++ b/toolkits/jira/arcade_jira/tools/labels.py @@ -4,7 +4,7 @@ from arcade_tdk import ToolContext, tool from arcade_tdk.auth import Atlassian from arcade_jira.client import JiraClient -from arcade_jira.utils import add_pagination_to_response +from arcade_jira.utils import add_pagination_to_response, resolve_cloud_id @tool(requires_auth=Atlassian(scopes=["read:jira-work"])) @@ -16,10 +16,16 @@ async def list_labels( offset: Annotated[ int, "The number of labels to skip. Defaults to 0 (starts from the first label)" ] = 0, + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict[str, Any], "The existing labels (tags) in the user's Jira instance"]: """Get the existing labels (tags) in the user's Jira instance.""" limit = max(min(limit, 200), 1) - client = JiraClient(context.get_auth_token_or_empty()) + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + client = JiraClient(context=context, cloud_id=atlassian_cloud_id) api_response = await client.get( "/label", params={ diff --git a/toolkits/jira/arcade_jira/tools/priorities.py b/toolkits/jira/arcade_jira/tools/priorities.py index 83f86b3f..37f2cde5 100644 --- a/toolkits/jira/arcade_jira/tools/priorities.py +++ b/toolkits/jira/arcade_jira/tools/priorities.py @@ -4,7 +4,6 @@ from typing import Annotated, Any, cast from arcade_tdk import ToolContext, tool from arcade_tdk.auth import Atlassian -import arcade_jira.cache as cache from arcade_jira.client import JiraClient from arcade_jira.constants import JIRA_API_REQUEST_TIMEOUT, PrioritySchemeOrderBy from arcade_jira.exceptions import JiraToolExecutionError, MultipleItemsFoundError, NotFoundError @@ -16,6 +15,7 @@ from arcade_jira.utils import ( find_priorities_by_project, find_unique_project, remove_none_values, + resolve_cloud_id, ) @@ -23,9 +23,15 @@ from arcade_jira.utils import ( async def get_priority_by_id( context: ToolContext, priority_id: Annotated[str, "The ID of the priority to retrieve."], + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict[str, Any], "The priority"]: """Get the details of a priority by its ID.""" - client = JiraClient(context.get_auth_token_or_empty()) + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + client = JiraClient(context=context, cloud_id=atlassian_cloud_id) try: response = await client.get(f"/priority/{priority_id}") except NotFoundError: @@ -50,10 +56,16 @@ async def list_priority_schemes( PrioritySchemeOrderBy, "The order in which to return the priority schemes. Defaults to name ascending.", ] = PrioritySchemeOrderBy.NAME_ASCENDING, + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict[str, Any], "The priority schemes available"]: """Browse the priority schemes available in Jira.""" limit = max(min(limit, 50), 1) - client = JiraClient(context.get_auth_token_or_empty()) + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + client = JiraClient(context=context, cloud_id=atlassian_cloud_id) api_response = await client.get( "/priorityscheme", params=remove_none_values({ @@ -63,8 +75,8 @@ async def list_priority_schemes( "orderBy": order_by.to_api_value(), }), ) - cloud_name = cache.get_cloud_name(context.get_auth_token_or_empty()) - schemes = [clean_priority_scheme_dict(scheme, cloud_name) for scheme in api_response["values"]] + + schemes = [clean_priority_scheme_dict(scheme) for scheme in api_response["values"]] response = { "priority_schemes": schemes, "isLast": api_response.get("isLast"), @@ -83,9 +95,15 @@ async def list_priorities_associated_with_a_priority_scheme( offset: Annotated[ int, "The number of priority schemes to skip. Defaults to 0 (start from the first scheme)." ] = 0, + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict[str, Any], "The priorities associated with the priority scheme"]: """Browse the priorities associated with a priority scheme.""" - client = JiraClient(context.get_auth_token_or_empty()) + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + client = JiraClient(context=context, cloud_id=atlassian_cloud_id) api_response = await client.get( f"/priorityscheme/{scheme_id}/priorities", params={ @@ -115,17 +133,28 @@ async def list_projects_associated_with_a_priority_scheme( offset: Annotated[ int, "The number of projects to skip. Defaults to 0 (start from the first project)." ] = 0, + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict[str, Any], "The projects associated with the priority scheme"]: """Browse the projects associated with a priority scheme.""" + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + if project: try: - project_data = await find_unique_project(context, project) + project_data = await find_unique_project( + context=context, + project_identifier=project, + atlassian_cloud_id=atlassian_cloud_id, + ) except (NotFoundError, MultipleItemsFoundError) as exc: return {"error": exc.message} else: project = project_data["id"] - client = JiraClient(context.get_auth_token_or_empty()) + client = JiraClient(context=context, cloud_id=atlassian_cloud_id) api_response = await client.get( f"/priorityscheme/{scheme_id}/projects", params=remove_none_values({ @@ -134,8 +163,8 @@ async def list_projects_associated_with_a_priority_scheme( "projectId": project, }), ) - cloud_name = cache.get_cloud_name(context.get_auth_token_or_empty()) - projects = [clean_project_dict(project, cloud_name) for project in api_response["values"]] + + projects = [clean_project_dict(project) for project in api_response["values"]] response = { "projects": projects, "isLast": api_response.get("isLast"), @@ -147,6 +176,11 @@ async def list_projects_associated_with_a_priority_scheme( async def list_priorities_available_to_a_project( context: ToolContext, project: Annotated[str, "The ID, key or name of the project to retrieve priorities for."], + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[ dict[str, Any], "The priorities available to be used in issues in the specified Jira project", @@ -157,14 +191,24 @@ async def list_priorities_available_to_a_project( a specific project. In Jira environments with too many Projects or Priority Schemes, the search may take too long, and the tool call will timeout. """ + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + try: - project_data = await find_unique_project(context, project) + project_data = await find_unique_project( + context=context, + project_identifier=project, + atlassian_cloud_id=atlassian_cloud_id, + ) except (NotFoundError, MultipleItemsFoundError) as exc: return {"error": exc.message} try: return await asyncio.wait_for( - find_priorities_by_project(context, project_data), + find_priorities_by_project( + context=context, + project=project_data, + atlassian_cloud_id=atlassian_cloud_id, + ), timeout=JIRA_API_REQUEST_TIMEOUT, ) except asyncio.TimeoutError: @@ -177,18 +221,32 @@ async def list_priorities_available_to_a_project( async def list_priorities_available_to_an_issue( context: ToolContext, issue: Annotated[str, "The ID or key of the issue to retrieve priorities for."], + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict[str, Any], "The priorities available to be used in the specified Jira issue"]: """Browse the priorities available to be used in the specified Jira issue.""" from arcade_jira.tools.issues import get_issue_by_id - issue_response = await get_issue_by_id(context, issue) + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + issue_response = await get_issue_by_id( + context=context, + issue=issue, + atlassian_cloud_id=atlassian_cloud_id, + ) if issue_response.get("error"): return cast(dict[str, Any], issue_response) issue_data = issue_response["issue"] project = issue_data["project"]["id"] - response = await list_priorities_available_to_a_project(context, project) + response = await list_priorities_available_to_a_project( + context=context, + project=project, + atlassian_cloud_id=atlassian_cloud_id, + ) return { "issue": { diff --git a/toolkits/jira/arcade_jira/tools/projects.py b/toolkits/jira/arcade_jira/tools/projects.py index c602f92e..52456775 100644 --- a/toolkits/jira/arcade_jira/tools/projects.py +++ b/toolkits/jira/arcade_jira/tools/projects.py @@ -3,13 +3,13 @@ from typing import Annotated, Any, cast from arcade_tdk import ToolContext, tool from arcade_tdk.auth import Atlassian -import arcade_jira.cache as cache from arcade_jira.client import JiraClient from arcade_jira.exceptions import NotFoundError from arcade_jira.utils import ( add_pagination_to_response, clean_project_dict, remove_none_values, + resolve_cloud_id, ) @@ -22,10 +22,23 @@ async def list_projects( offset: Annotated[ int, "The number of projects to skip. Defaults to 0 (starts from the first project)" ] = 0, + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict[str, Any], "Information about the projects"]: """Browse projects available in Jira.""" + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) return cast( - dict[str, Any], await search_projects(context, keywords=None, limit=limit, offset=offset) + dict[str, Any], + await search_projects( + context=context, + keywords=None, + limit=limit, + offset=offset, + atlassian_cloud_id=atlassian_cloud_id, + ), ) @@ -43,10 +56,16 @@ async def search_projects( offset: Annotated[ int, "The number of projects to skip. Defaults to 0 (starts from the first project)" ] = 0, + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict[str, Any], "Information about the projects"]: """Get the details of all Jira projects.""" limit = max(min(limit, 50), 1) - client = JiraClient(context.get_auth_token_or_empty()) + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + client = JiraClient(context=context, cloud_id=atlassian_cloud_id) api_response = await client.get( "/project/search", params=remove_none_values({ @@ -59,8 +78,8 @@ async def search_projects( "query": keywords, }), ) - cloud_name = cache.get_cloud_name(context.get_auth_token_or_empty()) - projects = [clean_project_dict(project, cloud_name) for project in api_response["values"]] + + projects = [clean_project_dict(project) for project in api_response["values"]] response = { "projects": projects, "isLast": api_response.get("isLast"), @@ -72,14 +91,19 @@ async def search_projects( async def get_project_by_id( context: ToolContext, project: Annotated[str, "The ID or key of the project to retrieve"], + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict[str, Any], "Information about the project"]: """Get the details of a Jira project by its ID or key.""" - client = JiraClient(context.get_auth_token_or_empty()) + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + client = JiraClient(context=context, cloud_id=atlassian_cloud_id) try: response = await client.get(f"project/{project}") except NotFoundError: return {"error": f"Project not found: {project}"} - cloud_name = cache.get_cloud_name(context.get_auth_token_or_empty()) - return {"project": clean_project_dict(response, cloud_name)} + return {"project": clean_project_dict(response)} diff --git a/toolkits/jira/arcade_jira/tools/transitions.py b/toolkits/jira/arcade_jira/tools/transitions.py index ae4f3151..4cafeeea 100644 --- a/toolkits/jira/arcade_jira/tools/transitions.py +++ b/toolkits/jira/arcade_jira/tools/transitions.py @@ -4,6 +4,7 @@ from arcade_tdk import ToolContext, tool from arcade_tdk.auth import Atlassian from arcade_jira.client import JiraClient +from arcade_jira.utils import resolve_cloud_id @tool(requires_auth=Atlassian(scopes=["read:jira-work"])) @@ -11,6 +12,11 @@ async def get_transition_by_id( context: ToolContext, issue: Annotated[str, "The ID or key of the issue"], transition_id: Annotated[str, "The ID of the transition"], + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict, "The transition data"]: """Get a transition by its ID.""" if not transition_id: @@ -18,7 +24,8 @@ async def get_transition_by_id( if not transition_id.isdigit(): return {"error": "The transition ID must be a numeric string."} - client = JiraClient(context.get_auth_token_or_empty()) + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + client = JiraClient(context=context, cloud_id=atlassian_cloud_id) response = await client.get( f"/issue/{issue}/transitions", params={ @@ -49,12 +56,22 @@ async def get_transition_by_id( async def get_transitions_available_for_issue( context: ToolContext, issue: Annotated[str, "The ID or key of the issue"], + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict, "The transitions available and the issue's current status"]: """Get the transitions available for an existing Jira issue.""" from arcade_jira.tools.issues import get_issue_by_id # Avoid circular import - client = JiraClient(context.get_auth_token_or_empty()) - issue_data = await get_issue_by_id(context, issue) + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + client = JiraClient(context=context, cloud_id=atlassian_cloud_id) + issue_data = await get_issue_by_id( + context=context, + issue=issue, + atlassian_cloud_id=atlassian_cloud_id, + ) if issue_data.get("error"): return cast(dict, issue_data) response = await client.get( @@ -85,12 +102,21 @@ async def get_transition_by_status_name( context: ToolContext, issue: Annotated[str, "The ID or key of the issue"], transition: Annotated[str, "The name of the transition status"], + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict, "The transition data, including screen fields available"]: """Get a transition available for an issue by the transition name. The response will contain screen fields available for the transition, if any. """ - transitions = await get_transitions_available_for_issue(context, issue) + transitions = await get_transitions_available_for_issue( + context=context, + issue=issue, + atlassian_cloud_id=atlassian_cloud_id, + ) for available_transition in transitions["transitions_available"]: if available_transition["name"].casefold() == transition.casefold(): return {"issue": issue, "transition": available_transition} @@ -115,16 +141,32 @@ async def transition_issue_to_new_status( str, "The transition to perform. Provide the transition ID or its name (case insensitive).", ], + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict, "The updated issue"]: """Transition a Jira issue to a new status.""" - client = JiraClient(context.get_auth_token_or_empty()) + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + client = JiraClient(context=context, cloud_id=atlassian_cloud_id) # Try to get the transition by ID first - response = await get_transition_by_id(context, issue, transition) + response = await get_transition_by_id( + context=context, + issue=issue, + transition_id=transition, + atlassian_cloud_id=atlassian_cloud_id, + ) # If the transition is not found by ID, try to get it by name if response.get("error"): - response = await get_transition_by_status_name(context, issue, transition) + response = await get_transition_by_status_name( + context=context, + issue=issue, + transition=transition, + atlassian_cloud_id=atlassian_cloud_id, + ) if response.get("error"): return cast(dict, response) diff --git a/toolkits/jira/arcade_jira/tools/users.py b/toolkits/jira/arcade_jira/tools/users.py index bce49f17..1867a931 100644 --- a/toolkits/jira/arcade_jira/tools/users.py +++ b/toolkits/jira/arcade_jira/tools/users.py @@ -4,10 +4,14 @@ from arcade_tdk import ToolContext, tool from arcade_tdk.auth import Atlassian from arcade_tdk.errors import ToolExecutionError -import arcade_jira.cache as cache from arcade_jira.client import JiraClient from arcade_jira.exceptions import NotFoundError -from arcade_jira.utils import add_pagination_to_response, clean_user_dict, remove_none_values +from arcade_jira.utils import ( + add_pagination_to_response, + clean_user_dict, + remove_none_values, + resolve_cloud_id, +) @tool(requires_auth=Atlassian(scopes=["read:jira-user"])) @@ -29,10 +33,16 @@ async def list_users( "The number of users to skip before starting to return users. " "Defaults to 0 (start from the first user).", ] = 0, + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict[str, Any], "The information about all users."]: """Browse users in Jira.""" limit = max(min(limit, 50), 1) - client = JiraClient(context.get_auth_token_or_empty()) + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + client = JiraClient(context=context, cloud_id=atlassian_cloud_id) api_response = await client.get( "/users/search", params={ @@ -41,9 +51,9 @@ async def list_users( }, ) items = cast(list[dict[str, Any]], api_response) - cloud_name = cache.get_cloud_name(context.get_auth_token_or_empty()) + users = [ - clean_user_dict(user, cloud_name) + clean_user_dict(user) for user in api_response if not account_type or user["accountType"].casefold() == account_type.casefold() ] @@ -56,9 +66,15 @@ async def list_users( async def get_user_by_id( context: ToolContext, user_id: Annotated[str, "The the user's ID."], + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict[str, Any], "The user information."]: """Get user information by their ID.""" - client = JiraClient(context.get_auth_token_or_empty()) + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + client = JiraClient(context=context, cloud_id=atlassian_cloud_id) not_found = {"error": "User not found"} @@ -70,8 +86,7 @@ async def get_user_by_id( if not response: return not_found - cloud_name = cache.get_cloud_name(context.get_auth_token_or_empty()) - return {"user": clean_user_dict(response, cloud_name)} + return {"user": clean_user_dict(response)} @tool(requires_auth=Atlassian(scopes=["read:jira-user"])) @@ -100,6 +115,11 @@ async def get_users_without_id( "The number of users to skip before starting to return users. " "Defaults to 0 (start from the first user).", ] = 0, + atlassian_cloud_id: Annotated[ + str | None, + "The ID of the Atlassian Cloud to use (defaults to None). If not provided and the user has " + "a single cloud authorized, the tool will use that. Otherwise, an error will be raised.", + ] = None, ) -> Annotated[dict[str, Any], "The information about users that match the search criteria."]: """Get users without their account ID, searching by display name and email address. @@ -119,7 +139,8 @@ async def get_users_without_id( message="The `user_name_or_email` argument is required to search for users." ) - client = JiraClient(context.get_auth_token_or_empty()) + atlassian_cloud_id = await resolve_cloud_id(context, atlassian_cloud_id) + client = JiraClient(context=context, cloud_id=atlassian_cloud_id) api_response = await client.get( "/user/search", params=remove_none_values({ @@ -128,8 +149,8 @@ async def get_users_without_id( "maxResults": limit, }), ) - cloud_name = cache.get_cloud_name(context.get_auth_token_or_empty()) - users = [clean_user_dict(user, cloud_name) for user in api_response] + + users = [clean_user_dict(user) for user in api_response] if enforce_exact_match: users = [ diff --git a/toolkits/jira/arcade_jira/utils.py b/toolkits/jira/arcade_jira/utils.py index 26175fbb..50833d37 100644 --- a/toolkits/jira/arcade_jira/utils.py +++ b/toolkits/jira/arcade_jira/utils.py @@ -2,15 +2,17 @@ import asyncio import base64 import json import mimetypes +import uuid from collections.abc import Callable from contextlib import suppress from datetime import date, datetime from typing import Any, cast +import httpx from arcade_tdk import ToolContext -from arcade_tdk.errors import ToolExecutionError +from arcade_tdk.errors import RetryableToolError, ToolExecutionError -from arcade_jira.constants import STOP_WORDS +from arcade_jira.constants import JIRA_BASE_URL, STOP_WORDS from arcade_jira.exceptions import JiraToolExecutionError, MultipleItemsFoundError, NotFoundError @@ -91,7 +93,7 @@ def build_search_issues_jql( return " AND ".join(clauses) if clauses else "" -def clean_issue_dict(issue: dict, cloud_name: str | None = None) -> dict: +def clean_issue_dict(issue: dict) -> dict: fields = cast(dict, issue["fields"]) rendered_fields = issue.get("renderedFields", {}) @@ -103,13 +105,13 @@ def clean_issue_dict(issue: dict, cloud_name: str | None = None) -> dict: fields["parent"] = get_summarized_issue_dict(fields["parent"]) if fields["assignee"]: - fields["assignee"] = clean_user_dict(fields["assignee"], cloud_name) + fields["assignee"] = clean_user_dict(fields["assignee"]) if fields["creator"]: - fields["creator"] = clean_user_dict(fields["creator"], cloud_name) + fields["creator"] = clean_user_dict(fields["creator"]) if fields["reporter"]: - fields["reporter"] = clean_user_dict(fields["reporter"], cloud_name) + fields["reporter"] = clean_user_dict(fields["reporter"]) if fields.get("description"): fields["description"] = rendered_fields.get("description") @@ -125,8 +127,7 @@ def clean_issue_dict(issue: dict, cloud_name: str | None = None) -> dict: if fields.get("attachment"): fields["attachments"] = [ - clean_attachment_dict(attachment, cloud_name) - for attachment in fields.get("attachment", []) + clean_attachment_dict(attachment) for attachment in fields.get("attachment", []) ] add_identified_fields_to_issue(fields, ["status", "issuetype", "priority", "project"]) @@ -151,8 +152,6 @@ def clean_issue_dict(issue: dict, cloud_name: str | None = None) -> dict: ], ) - fields["url"] = build_issue_url(cloud_name, fields["project"]["key"], fields["key"]) - return fields @@ -190,15 +189,13 @@ def clean_comment_dict(comment: dict, include_adf_content: bool = False) -> dict return data -def clean_project_dict(project: dict, cloud_name: str | None = None) -> dict: +def clean_project_dict(project: dict) -> dict: data = { "id": project["id"], "key": project["key"], "name": project["name"], } - data["url"] = build_project_url(cloud_name, project["key"]) - if "description" in project: data["description"] = project["description"] @@ -227,15 +224,13 @@ def clean_issue_type_dict(issue_type: dict) -> dict: return data -def clean_user_dict(user: dict, cloud_name: str | None = None) -> dict: +def clean_user_dict(user: dict) -> dict: data = { "id": user["accountId"], "name": user["displayName"], "active": user["active"], } - data["url"] = build_user_url(cloud_name, user["accountId"]) - if user.get("emailAddress"): data["email"] = user["emailAddress"] @@ -251,17 +246,17 @@ def clean_user_dict(user: dict, cloud_name: str | None = None) -> dict: return data -def clean_attachment_dict(attachment: dict, cloud_name: str | None = None) -> dict: +def clean_attachment_dict(attachment: dict) -> dict: return { "id": attachment["id"], "filename": attachment["filename"], "mime_type": attachment["mimeType"], "size": {"bytes": attachment["size"]}, - "author": clean_user_dict(attachment["author"], cloud_name), + "author": clean_user_dict(attachment["author"]), } -def clean_priority_scheme_dict(scheme: dict, cloud_name: str | None = None) -> dict: +def clean_priority_scheme_dict(scheme: dict) -> dict: data = { "id": scheme["id"], "name": scheme["name"], @@ -290,9 +285,7 @@ def clean_priority_scheme_dict(scheme: dict, cloud_name: str | None = None) -> d if isinstance(scheme.get("projects"), dict): all_projects = scheme["projects"].get("isLast", True) - data["projects"] = [ - clean_project_dict(project, cloud_name) for project in scheme["projects"]["values"] - ] + data["projects"] = [clean_project_dict(project) for project in scheme["projects"]["values"]] if not all_projects: # Avoid circular import from arcade_jira.tools.priorities import list_projects_associated_with_a_priority_scheme @@ -378,6 +371,7 @@ async def find_multiple_unique_users( context: ToolContext, user_identifiers: list[str], exact_match: bool = False, + atlassian_cloud_id: str | None = None, ) -> list[dict[str, Any]]: """ Find users matching either their display name, email address, or account ID. @@ -400,6 +394,7 @@ async def find_multiple_unique_users( context=context, name_or_email=user_identifier, enforce_exact_match=exact_match, + atlassian_cloud_id=atlassian_cloud_id, ) for user_identifier in user_identifiers ]) @@ -424,7 +419,12 @@ async def find_multiple_unique_users( if search_by_id: responses = await asyncio.gather(*[ - get_user_by_id(context, user_id=user_id) for user_id in search_by_id + get_user_by_id( + context=context, + user_id=user_id, + atlassian_cloud_id=atlassian_cloud_id, + ) + for user_id in search_by_id ]) for response in responses: if response["user"]: @@ -440,6 +440,7 @@ async def find_multiple_unique_users( async def find_unique_project( context: ToolContext, project_identifier: str, + atlassian_cloud_id: str | None = None, ) -> dict[str, Any]: """Find a unique project by its ID, key, or name @@ -453,12 +454,20 @@ async def find_unique_project( from arcade_jira.tools.projects import get_project_by_id, search_projects # Try to find project by ID or key - response = await get_project_by_id(context, project=project_identifier) + response = await get_project_by_id( + context=context, + project=project_identifier, + atlassian_cloud_id=atlassian_cloud_id, + ) if response.get("project"): return cast(dict, response["project"]) # If not found, search by name - response = await search_projects(context, keywords=project_identifier) + response = await search_projects( + context=context, + keywords=project_identifier, + atlassian_cloud_id=atlassian_cloud_id, + ) projects = response["projects"] if len(projects) == 1: return cast(dict, projects[0]) @@ -482,6 +491,7 @@ async def find_unique_priority( context: ToolContext, priority_identifier: str, project_id: str, + atlassian_cloud_id: str | None = None, ) -> dict[str, Any]: """Find a unique priority by ID or name that is associated with a project @@ -499,12 +509,20 @@ async def find_unique_priority( ) # Try to get the priority by ID first - response = await get_priority_by_id(context, priority_identifier) + response = await get_priority_by_id( + context=context, + priority_id=priority_identifier, + atlassian_cloud_id=atlassian_cloud_id, + ) if response.get("priority"): return cast(dict, response["priority"]) # If not found, search by name - response = await list_priorities_available_to_a_project(context, project_id) + response = await list_priorities_available_to_a_project( + context=context, + project=project_id, + atlassian_cloud_id=atlassian_cloud_id, + ) if response.get("error"): raise JiraToolExecutionError(response["error"]) @@ -538,6 +556,7 @@ async def find_unique_issue_type( context: ToolContext, issue_type_identifier: str, project_id: str, + atlassian_cloud_id: str | None = None, ) -> dict[str, Any]: """Find a unique issue type by its ID or name that is associated with a project @@ -552,12 +571,20 @@ async def find_unique_issue_type( from arcade_jira.tools.issues import get_issue_type_by_id, list_issue_types_by_project # Try to get the issue type by ID first - response = await get_issue_type_by_id(context, issue_type_identifier) + response = await get_issue_type_by_id( + context=context, + issue_type_id=issue_type_identifier, + atlassian_cloud_id=atlassian_cloud_id, + ) if response.get("issue_type"): return cast(dict, response["issue_type"]) # If not found, search by name - response = await list_issue_types_by_project(context, project_id) + response = await list_issue_types_by_project( + context=context, + project=project_id, + atlassian_cloud_id=atlassian_cloud_id, + ) if response.get("error"): raise JiraToolExecutionError(response["error"]) @@ -601,19 +628,27 @@ async def find_unique_issue_type( async def find_unique_user( context: ToolContext, user_identifier: str, + atlassian_cloud_id: str | None = None, ) -> dict[str, Any]: """Find a unique user by their ID, key, email address, or display name.""" # Avoid circular import from arcade_jira.tools.users import get_user_by_id, get_users_without_id # Try to get the user by ID - response = await get_user_by_id(context, user_identifier) + response = await get_user_by_id( + context=context, + user_id=user_identifier, + atlassian_cloud_id=atlassian_cloud_id, + ) if response.get("user"): return cast(dict, response["user"]) # Search for the user name or email, if not found by ID response = await get_users_without_id( - context, name_or_email=user_identifier, enforce_exact_match=True + context=context, + name_or_email=user_identifier, + enforce_exact_match=True, + atlassian_cloud_id=atlassian_cloud_id, ) users = response["users"] @@ -636,13 +671,17 @@ async def find_unique_user( raise NotFoundError(message=f"User not found with ID, name or email '{user_identifier}'") -async def get_single_project(context: ToolContext) -> dict[str, Any]: +async def get_single_project( + context: ToolContext, + atlassian_cloud_id: str | None = None, +) -> dict[str, Any]: from arcade_jira.tools.projects import list_projects projects = await paginate_all_items( context=context, tool=list_projects, response_items_key="projects", + atlassian_cloud_id=atlassian_cloud_id, ) if len(projects) == 0: @@ -743,17 +782,26 @@ async def paginate_all_items( return items -async def paginate_all_priority_schemes(context: ToolContext) -> list[dict]: +async def paginate_all_priority_schemes( + context: ToolContext, + atlassian_cloud_id: str | None = None, +) -> list[dict]: """Get all priority schemes.""" # Avoid circular import from arcade_jira.tools.priorities import list_priority_schemes - return await paginate_all_items(context, list_priority_schemes, "priority_schemes") + return await paginate_all_items( + context=context, + tool=list_priority_schemes, + response_items_key="priority_schemes", + atlassian_cloud_id=atlassian_cloud_id, + ) async def paginate_all_priorities_by_priority_scheme( context: ToolContext, scheme_id: str, + atlassian_cloud_id: str | None = None, ) -> list[dict]: """Get all priorities associated with a priority scheme.""" # Avoid circular import @@ -764,19 +812,7 @@ async def paginate_all_priorities_by_priority_scheme( list_priorities_associated_with_a_priority_scheme, "priorities", scheme_id=scheme_id, - ) - - -async def paginate_all_issue_types(context: ToolContext, project_identifier: str) -> list[dict]: - """Get all issue types associated with a project.""" - # Avoid circular import - from arcade_jira.tools.issues import list_issue_types_by_project - - return await paginate_all_items( - context, - list_issue_types_by_project, - "issue_types", - project=project_identifier, + atlassian_cloud_id=atlassian_cloud_id, ) @@ -787,6 +823,7 @@ async def validate_issue_args( issue_type: str | None, priority: str | None, parent_issue: str | None, + atlassian_cloud_id: str | None = None, ) -> tuple[dict | None, dict | None, str | dict | None, str | dict | None, dict | None]: if due_date and not is_valid_date_string(due_date): return ( @@ -808,7 +845,10 @@ async def validate_issue_args( error: dict[str, Any] | None = None project_data = await get_project_by_project_identifier_or_by_parent_issue( - context, project, parent_issue + context=context, + project=project, + parent_issue_id=parent_issue, + atlassian_cloud_id=atlassian_cloud_id, ) issue_type_data: str | dict[str, Any] | None = None priority_data: str | dict[str, Any] | None = None @@ -818,15 +858,29 @@ async def validate_issue_args( error = project_data return error, None, issue_type_data, priority_data, parent_issue_data - error, issue_type_data = await resolve_issue_type(context, issue_type, project_data) + error, issue_type_data = await resolve_issue_type( + context=context, + issue_type=issue_type, + project_data=project_data, + atlassian_cloud_id=atlassian_cloud_id, + ) if error: return error, project_data, issue_type_data, priority_data, parent_issue_data - error, priority_data = await resolve_issue_priority(context, priority, project_data) + error, priority_data = await resolve_issue_priority( + context=context, + priority=priority, + project_data=project_data, + atlassian_cloud_id=atlassian_cloud_id, + ) if error: return error, project_data, issue_type_data, priority_data, parent_issue_data - error, parent_issue_data = await resolve_parent_issue(context, parent_issue) + error, parent_issue_data = await resolve_parent_issue( + context=context, + parent_issue=parent_issue, + atlassian_cloud_id=atlassian_cloud_id, + ) if error: return error, project_data, issue_type_data, priority_data, parent_issue_data @@ -837,12 +891,18 @@ async def resolve_issue_type( context: ToolContext, issue_type: str | None, project_data: dict, + atlassian_cloud_id: str | None = None, ) -> tuple[dict[str, Any] | None, str | dict[str, Any] | None]: if issue_type == "": return None, "" elif issue_type: try: - response = await find_unique_issue_type(context, issue_type, project_data["id"]) + response = await find_unique_issue_type( + context=context, + issue_type_identifier=issue_type, + project_id=project_data["id"], + atlassian_cloud_id=atlassian_cloud_id, + ) except JiraToolExecutionError as exc: return {"error": exc.message}, None else: @@ -855,12 +915,18 @@ async def resolve_issue_priority( context: ToolContext, priority: str | None, project_data: dict, + atlassian_cloud_id: str | None = None, ) -> tuple[dict[str, Any] | None, str | dict[str, Any] | None]: if priority == "": return None, "" elif priority: try: - priority_data = await find_unique_priority(context, priority, project_data["id"]) + priority_data = await find_unique_priority( + context=context, + priority_identifier=priority, + project_id=project_data["id"], + atlassian_cloud_id=atlassian_cloud_id, + ) except JiraToolExecutionError as exc: return {"error": exc.message}, None else: @@ -872,6 +938,7 @@ async def resolve_issue_priority( async def resolve_parent_issue( context: ToolContext, parent_issue: str | None, + atlassian_cloud_id: str | None = None, ) -> tuple[dict[str, Any] | None, dict[str, Any] | None]: if parent_issue == "": return {"error": "Parent issue cannot be empty"}, None @@ -879,7 +946,11 @@ async def resolve_parent_issue( from arcade_jira.tools.issues import get_issue_by_id # Avoid circular import try: - parent_issue_data = await get_issue_by_id(context, parent_issue) + parent_issue_data = await get_issue_by_id( + context=context, + issue=parent_issue, + atlassian_cloud_id=atlassian_cloud_id, + ) except JiraToolExecutionError as exc: return {"error": exc.message}, None else: @@ -892,6 +963,7 @@ async def get_project_by_project_identifier_or_by_parent_issue( context: ToolContext, project: str | None, parent_issue_id: str | None, + atlassian_cloud_id: str | None = None, ) -> dict[str, Any]: from arcade_jira.tools.issues import get_issue_by_id # Avoid circular import @@ -899,13 +971,21 @@ async def get_project_by_project_identifier_or_by_parent_issue( return {"error": "Must provide either `project` or `parent_issue_id` argument."} if not project: - parent_issue_data = await get_issue_by_id(context, parent_issue_id) + parent_issue_data = await get_issue_by_id( + context=context, + issue=parent_issue_id, + atlassian_cloud_id=atlassian_cloud_id, + ) if parent_issue_data.get("error"): return {"error": f"Parent issue not found with ID {parent_issue_id}."} project = cast(str, parent_issue_data["project"]["id"]) try: - project_data = await find_unique_project(context, project) + project_data = await find_unique_project( + context=context, + project_identifier=project, + atlassian_cloud_id=atlassian_cloud_id, + ) except JiraToolExecutionError as exc: return {"error": exc.message} @@ -916,6 +996,7 @@ async def resolve_issue_users( context: ToolContext, assignee: str | None, reporter: str | None, + atlassian_cloud_id: str | None = None, ) -> tuple[dict | None, str | dict | None, str | dict | None]: assignee_data: str | dict | None = None reporter_data: str | dict | None = None @@ -927,7 +1008,11 @@ async def resolve_issue_users( assignee_data = "" elif assignee: try: - assignee_data = await find_unique_user(context, assignee) + assignee_data = await find_unique_user( + context=context, + user_identifier=assignee, + atlassian_cloud_id=atlassian_cloud_id, + ) except JiraToolExecutionError as exc: return {"error": exc.message}, assignee_data, reporter_data @@ -935,7 +1020,11 @@ async def resolve_issue_users( reporter_data = "" elif reporter: try: - reporter_data = await find_unique_user(context, reporter) + reporter_data = await find_unique_user( + context=context, + user_identifier=reporter, + atlassian_cloud_id=atlassian_cloud_id, + ) except JiraToolExecutionError as exc: return {"error": exc.message}, assignee_data, reporter_data @@ -945,6 +1034,7 @@ async def resolve_issue_users( async def find_priorities_by_project( context: ToolContext, project: dict[str, Any], + atlassian_cloud_id: str | None = None, ) -> dict[str, Any]: # Avoid circular import from arcade_jira.tools.priorities import list_projects_associated_with_a_priority_scheme @@ -953,7 +1043,10 @@ async def find_priorities_by_project( priority_ids: set[str] = set() priorities: list[dict[str, Any]] = [] - priority_schemes = await paginate_all_priority_schemes(context) + priority_schemes = await paginate_all_priority_schemes( + context=context, + atlassian_cloud_id=atlassian_cloud_id, + ) if not priority_schemes: raise NotFoundError("No priority schemes found") # noqa: TRY003 @@ -963,6 +1056,7 @@ async def find_priorities_by_project( context=context, scheme_id=scheme["id"], project=project["id"], + atlassian_cloud_id=atlassian_cloud_id, ) for scheme in priority_schemes ]) @@ -981,7 +1075,12 @@ async def find_priorities_by_project( return {"error": f"No priority schemes found for the project {project['id']}"} priorities_by_scheme = await asyncio.gather(*[ - paginate_all_priorities_by_priority_scheme(context, scheme_id) for scheme_id in scheme_ids + paginate_all_priorities_by_priority_scheme( + context=context, + scheme_id=scheme_id, + atlassian_cloud_id=atlassian_cloud_id, + ) + for scheme_id in scheme_ids ]) for priorities_available in priorities_by_scheme: @@ -1123,22 +1222,129 @@ def extract_id(field: Any) -> dict[str, str] | None: return {"id": field["id"]} if isinstance(field, dict) else None -def build_issue_url(cloud_name: str | None, issue_id: str, issue_key: str) -> str | None: - if not cloud_name: - return None +async def resolve_cloud_id(context: ToolContext, cloud_id: str | None) -> str: + try: + uuid.UUID(cloud_id) + except (AttributeError, TypeError, ValueError): + is_valid_uuid = False + else: + is_valid_uuid = True - return f"https://{cloud_name}.atlassian.net/jira/software/projects/{issue_id}/list?selectedIssue={issue_key}" + # If this is already a valid Cloud ID, return it + if is_valid_uuid: + return cast(str, cloud_id) + + # If not, it's possibly a Cloud name, so we try to match that. + if isinstance(cloud_id, str) and cloud_id != "": + return await get_cloud_id_by_cloud_name(context, cloud_name=cloud_id) + + # As a last resort, try to get a unique Cloud ID from the available Atlassian Clouds + return await get_unique_cloud_id(context) -def build_project_url(cloud_name: str | None, project_key: str) -> str | None: - if not cloud_name: - return None +async def get_cloud_id_by_cloud_name(context: ToolContext, cloud_name: str) -> str: + from arcade_jira.tools.cloud import get_available_atlassian_clouds # Avoid circular import - return f"https://{cloud_name}.atlassian.net/jira/software/projects/{project_key}/summary" + response = await get_available_atlassian_clouds(context) + clouds = response["clouds_available"] + + for cloud in clouds: + if ( + # Case-insensitive match in case of cloud names. + cloud["atlassian_cloud_name"].casefold() == cloud_name.casefold() + # Match the ID as well just in case. Who knows, Atlassian may start + # using some weird values as cloud IDs. If the value provided matches + # an ID in the list of clouds, then it's a match. + or cloud["atlassian_cloud_id"] == cloud_name + ): + return cast(str, cloud["atlassian_cloud_id"]) + + message = f"No Atlassian Cloud found matching '{cloud_name}'" + available_clouds_str = f"Available Atlassian Clouds:\n\n```json\n{json.dumps(clouds)}\n```" + + raise RetryableToolError( + message=message, + developer_message=message, + additional_prompt_content=available_clouds_str, + ) -def build_user_url(cloud_name: str | None, user_id: str) -> str | None: - if not cloud_name: - return None +async def get_unique_cloud_id(context: ToolContext) -> str: + from arcade_jira.tools.cloud import get_available_atlassian_clouds # Avoid circular import - return f"https://{cloud_name}.atlassian.net/jira/people/{user_id}" + response = await get_available_atlassian_clouds(context) + clouds = response["clouds_available"] + + if len(clouds) == 0: + message = "No Atlassian Cloud is available. Please authorize an Atlassian Cloud." + raise ToolExecutionError( + message=message, + developer_message=message, + ) + + if len(clouds) > 1: + message = ( + "Multiple Atlassian Clouds are available. One Cloud ID has to be selected and provided " + "in the tool call using the `atlassian_cloud_id` argument." + ) + raise RetryableToolError( + message=message, + developer_message=message, + additional_prompt_content=( + f"Available Atlassian Clouds:\n\n```json\n{json.dumps(clouds)}\n```" + ), + ) + + return cast(str, clouds[0]["atlassian_cloud_id"]) + + +async def check_if_cloud_is_authorized( + context: ToolContext, + cloud: dict[str, Any], + semaphore: asyncio.Semaphore, +) -> dict[str, Any] | bool: + """Confirm whether an Atlassian Cloud is authorized for the current auth token. + + The Atlassian available-resources endpoint may return Clouds that have not been + authorized by the current user. This is a known Atlassian OAuth2 API bug [1]. + + We run this check against the '/myself' endpoint to confirm whether the Cloud + was actually authorized for the current auth token. + + [1] Reference about the Atlassian API bug: + https://community.developer.atlassian.com/t/urgent-api-accessible-resources-endpoint-returns-sites-resources-that-are-not-permitted-by-the-user/66899 + Archived (2025-07-22): https://archive.is/0noNX + """ + cloud_id = cloud["atlassian_cloud_id"] + + try: + async with semaphore, httpx.AsyncClient() as client: + response = await client.get( + f"{JIRA_BASE_URL}/{cloud_id}/rest/api/3/myself", + headers={"Authorization": f"Bearer {context.get_auth_token_or_empty()}"}, + ) + + if response.status_code == 200: + return cloud + + elif response.status_code == 429 or response.status_code >= 500: + response.raise_for_status() + + else: + return False + + except Exception as exc: + message = ( + f"An error occurred while checking if the Atlassian Cloud with ID '{cloud_id}' " + "is authorized." + ) + developer_message = f"{message} Error info: {type(exc).__name__}: {exc!s}" + + raise ToolExecutionError( + message=message, + developer_message=developer_message, + ) from exc + + # This is necessary otherwise mypy will complain + else: + return False diff --git a/toolkits/jira/conftest.py b/toolkits/jira/conftest.py index 4f859f59..8a955b65 100644 --- a/toolkits/jira/conftest.py +++ b/toolkits/jira/conftest.py @@ -1,6 +1,7 @@ import random import string -from collections.abc import Callable +import uuid +from collections.abc import Callable, Generator from typing import Any from unittest.mock import MagicMock, patch @@ -8,8 +9,6 @@ import httpx import pytest from arcade_tdk import ToolAuthorizationContext, ToolContext -from arcade_jira.cache import set_cloud_id, set_cloud_name - @pytest.fixture def fake_auth_token(generate_random_str: Callable) -> str: @@ -17,8 +16,8 @@ def fake_auth_token(generate_random_str: Callable) -> str: @pytest.fixture -def fake_cloud_id(generate_random_str: Callable) -> str: - return generate_random_str() +def fake_cloud_id() -> str: + return str(uuid.uuid4()) @pytest.fixture @@ -26,13 +25,6 @@ def fake_cloud_name(generate_random_str: Callable) -> str: return generate_random_str() -@pytest.fixture(autouse=True) -def set_cloud_id_cache(fake_auth_token: str, fake_cloud_id: str, fake_cloud_name: str) -> None: - """This fixture auto-sets cloud ID in the cache to skip the HTTP call to get it""" - set_cloud_id(fake_auth_token, fake_cloud_id) - set_cloud_name(fake_auth_token, fake_cloud_name) - - @pytest.fixture def generate_random_str() -> Callable[[int], str]: def random_str_builder(length: int = 10) -> str: @@ -83,6 +75,31 @@ def mock_httpx_response() -> Callable[[int, dict], httpx.Response]: return generate_mock_httpx_response +@pytest.fixture(autouse=True) +def mock_get_available_atlassian_clouds_globally( + fake_cloud_id: str, + fake_cloud_name: str, +) -> Generator[None, None, None]: + """Mock get_available_atlassian_clouds for all tests.""" + + def mock_func(context: ToolContext) -> list[dict]: + return { + "clouds_available": [ + { + "atlassian_cloud_id": fake_cloud_id, + "atlassian_cloud_name": fake_cloud_name, + "atlassian_cloud_url": f"https://{fake_cloud_name}.atlassian.net", + } + ] + } + + with patch( + "arcade_jira.tools.cloud.get_available_atlassian_clouds", + side_effect=mock_func, + ): + yield + + @pytest.fixture def build_user_dict( generate_random_str: Callable[[int], str], diff --git a/toolkits/jira/evals/eval_multi_cloud.py b/toolkits/jira/evals/eval_multi_cloud.py new file mode 100644 index 00000000..9a440c75 --- /dev/null +++ b/toolkits/jira/evals/eval_multi_cloud.py @@ -0,0 +1,291 @@ +import json +import uuid + +from arcade_evals import ( + EvalRubric, + EvalSuite, + ExpectedToolCall, + tool_eval, +) +from arcade_evals.critic import BinaryCritic +from arcade_tdk import ToolCatalog + +import arcade_jira +from arcade_jira.tools.comments import get_issue_comments +from arcade_jira.tools.issues import get_issue_by_id + +# Evaluation rubric +rubric = EvalRubric( + fail_threshold=0.85, + warn_threshold=0.95, +) + + +catalog = ToolCatalog() +catalog.add_module(arcade_jira) + + +@tool_eval() +def multi_cloud_eval_suite() -> EvalSuite: + suite = EvalSuite( + name="Atlassian multi-cloud evaluation suite", + system_message=( + "You are an AI assistant with access to Jira tools. " + "Use them to help the user with their tasks." + ), + catalog=catalog, + rubric=rubric, + ) + + suite.add_case( + name="Test calling tool without specifying a cloud id", + user_message="Get the issue with ID '10000'.", + expected_tool_calls=[ + ExpectedToolCall( + func=get_issue_by_id, + args={ + "issue_id": "10000", + "atlassian_cloud_id": None, + }, + ), + ], + rubric=rubric, + critics=[ + BinaryCritic(critic_field="issue_id", weight=0.5), + BinaryCritic(critic_field="atlassian_cloud_id", weight=0.5), + ], + ) + + cloud_id = str(uuid.uuid4()) + + suite.add_case( + name="Test calling tool specifying a cloud id directly with the request", + user_message=f"Get the issue with ID '10000' in the Cloud with ID '{cloud_id}'.", + expected_tool_calls=[ + ExpectedToolCall( + func=get_issue_by_id, + args={ + "issue_id": "10000", + "atlassian_cloud_id": cloud_id, + }, + ), + ], + rubric=rubric, + critics=[ + BinaryCritic(critic_field="issue_id", weight=0.5), + BinaryCritic(critic_field="atlassian_cloud_id", weight=0.5), + ], + ) + + cloud_1_id = str(uuid.uuid4()) + cloud_2_id = str(uuid.uuid4()) + available_clouds = [ + { + "atlassian_cloud_id": cloud_1_id, + "atlassian_cloud_name": "Foobar", + "atlassian_cloud_url": "https://foobar.atlassian.com", + }, + { + "atlassian_cloud_id": cloud_2_id, + "atlassian_cloud_name": "Quick Brown Fox", + "atlassian_cloud_url": "https://quickbrownfox.atlassian.com", + }, + ] + available_clouds_str = json.dumps(available_clouds) + + suite.add_case( + name="Test calling tool with multiple clouds error and specifying which cloud to use", + user_message="Let's use the Foobar Cloud", + expected_tool_calls=[ + ExpectedToolCall( + func=get_issue_by_id, + args={ + "issue_id": "10000", + "atlassian_cloud_id": cloud_1_id, + }, + ), + ], + rubric=rubric, + critics=[ + BinaryCritic(critic_field="issue_id", weight=0.5), + BinaryCritic(critic_field="atlassian_cloud_id", weight=0.5), + ], + additional_messages=[ + {"role": "user", "content": "Get the issue with id '10000' in Jira"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "Jira_GetIssueById", + "arguments": json.dumps({ + "issue": "10000", + }), + }, + } + ], + }, + { + "role": "tool", + "content": json.dumps({ + "name": "retryable_tool_call_error", + "message": ( + "Multiple Atlassian Clouds are available. One Cloud ID has to be selected " + "and provided in the tool call using the `atlassian_cloud_id` argument.", + ), + "developer_message": ( + "Multiple Atlassian Clouds are available. One Cloud ID has to be selected " + "and provided in the tool call using the `atlassian_cloud_id` argument.", + ), + "additional_prompt_content": ( + f"Available Atlassian Clouds:\n\n```json\n{available_clouds_str}\n```" + ), + }), + "tool_call_id": "call_1", + "name": "Jira_GetIssueById", + }, + { + "role": "assistant", + "content": ( + "Here is the list of available Atlassian clouds:\n\n" + "1. **Name:** Foobar\n" + " - **URL:** https://foobar.atlassian.com\n" + "2. **Name:** Quick Brown Fox\n" + " - **URL:** https://quickbrownfox.atlassian.com\n" + "Please select one of the above Clouds to get the Jira issue." + ), + }, + ], + ) + + suite.add_case( + name="Test calling tool one interaction after specifying a cloud id", + user_message="Get the comments on this issue", + expected_tool_calls=[ + ExpectedToolCall( + func=get_issue_comments, + args={ + "issue": "10000", + "atlassian_cloud_id": cloud_1_id, + }, + ), + ], + rubric=rubric, + critics=[ + BinaryCritic(critic_field="issue", weight=0.5), + BinaryCritic(critic_field="atlassian_cloud_id", weight=0.5), + ], + additional_messages=[ + {"role": "user", "content": "Get the issue with id '10000' in Jira"}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_1", + "type": "function", + "function": { + "name": "Jira_GetIssueById", + "arguments": json.dumps({ + "issue": "10000", + }), + }, + } + ], + }, + { + "role": "tool", + "content": json.dumps({ + "name": "retryable_tool_call_error", + "message": ( + "Multiple Atlassian Clouds are available. One Cloud ID has to be selected " + "and provided in the tool call using the `atlassian_cloud_id` argument.", + ), + "developer_message": ( + "Multiple Atlassian Clouds are available. One Cloud ID has to be selected " + "and provided in the tool call using the `atlassian_cloud_id` argument.", + ), + "additional_prompt_content": ( + f"Available Atlassian Clouds:\n\n```json\n{available_clouds_str}\n```" + ), + }), + "tool_call_id": "call_1", + "name": "Jira_GetIssueById", + }, + { + "role": "assistant", + "content": ( + "Here is the list of available Atlassian clouds:\n\n" + "1. **Name:** Foobar\n" + " - **URL:** https://foobar.atlassian.com\n" + "2. **Name:** Quick Brown Fox\n" + " - **URL:** https://quickbrownfox.atlassian.com\n" + "Please select one of the above Clouds to get the Jira issue." + ), + }, + {"role": "user", "content": "Let's use the Foobar Cloud from now on."}, + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_2", + "type": "function", + "function": { + "name": "Jira_GetIssueById", + "arguments": json.dumps({ + "issue": "10000", + "atlassian_cloud_id": cloud_1_id, + }), + }, + } + ], + }, + { + "role": "tool", + "content": json.dumps({ + "id": "10000", + "key": "ENG-101", + "assignee": { + "id": "10010", + "name": "John Doe", + "email": "john.doe@example.com", + }, + "description": "Implement the message queue", + "status": { + "id": "10020", + "name": "In Progress", + }, + "issuetype": { + "id": "10030", + "name": "Task", + }, + "project": { + "id": "10040", + "key": "ENG", + "name": "Engineering", + }, + }), + "tool_call_id": "call_2", + "name": "Jira_GetIssueById", + }, + { + "role": "assistant", + "content": ( + "Here is the issue:\n\n" + "1. **ID:** 10000\n" + " - **Key:** ENG-101\n" + " - **Assignee:** John Doe\n" + " - **Description:** Implement the message queue\n" + " - **Status:** In Progress\n" + " - **Issue Type:** Task\n" + " - **Project:** Engineering" + ), + }, + ], + ) + + return suite diff --git a/toolkits/jira/pyproject.toml b/toolkits/jira/pyproject.toml index 3e7a4e30..da7c96dc 100644 --- a/toolkits/jira/pyproject.toml +++ b/toolkits/jira/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "hatchling.build" [project] name = "arcade_jira" -version = "0.1.3" +version = "1.0.0" description = "Arcade.dev LLM tools for interacting with Atlassian Jira" requires-python = ">=3.10" dependencies = [ diff --git a/toolkits/jira/tests/test_client.py b/toolkits/jira/tests/test_client.py deleted file mode 100644 index 19cb2e9d..00000000 --- a/toolkits/jira/tests/test_client.py +++ /dev/null @@ -1,63 +0,0 @@ -import json - -import httpx -import pytest - -from arcade_jira.client import JiraClient -from arcade_jira.exceptions import JiraToolExecutionError - - -@pytest.mark.asyncio -async def test_get_cloud_data_from_available_resources_single_cloud( - mock_httpx_client, fake_auth_token -): - cloud = {"id": "123", "name": "Test Cloud", "url": "https://test.atlassian.net"} - - client = JiraClient(auth_token=fake_auth_token) - - mock_httpx_client.get.return_value = httpx.Response( - status_code=200, - json=[cloud], - ) - - response = await client._get_cloud_data_from_available_resources() - assert response == cloud - - -@pytest.mark.asyncio -async def test_get_cloud_data_from_available_resources_multiple_clouds( - mock_httpx_client, fake_auth_token -): - cloud1 = {"id": "123", "name": "Test Cloud", "url": "https://test.atlassian.net"} - cloud2 = {"id": "456", "name": "Test Cloud 2", "url": "https://test2.atlassian.net"} - - client = JiraClient(auth_token=fake_auth_token) - - mock_httpx_client.get.return_value = httpx.Response( - status_code=200, - json=[cloud1, cloud2], - ) - - with pytest.raises(JiraToolExecutionError) as error: - await client._get_cloud_data_from_available_resources() - - assert "Multiple cloud IDs returned by Atlassian" in error.value.message - assert json.dumps(cloud1) in error.value.message - assert json.dumps(cloud2) in error.value.message - - -@pytest.mark.asyncio -async def test_get_cloud_data_from_available_resources_duplicate_cloud( - mock_httpx_client, fake_auth_token -): - cloud = {"id": "123", "name": "Test Cloud", "url": "https://test.atlassian.net"} - - client = JiraClient(auth_token=fake_auth_token) - - mock_httpx_client.get.return_value = httpx.Response( - status_code=200, - json=[cloud, cloud], - ) - - response = await client._get_cloud_data_from_available_resources() - assert response == cloud diff --git a/toolkits/jira/tests/test_find_unique_project.py b/toolkits/jira/tests/test_find_unique_project.py index 97762edc..13fc661e 100644 --- a/toolkits/jira/tests/test_find_unique_project.py +++ b/toolkits/jira/tests/test_find_unique_project.py @@ -13,14 +13,13 @@ async def test_find_unique_project_by_id_success( mock_httpx_client, mock_httpx_response: Callable, build_project_dict: Callable, - fake_cloud_name: str, ): sample_project = build_project_dict() project_response = mock_httpx_response(200, sample_project) mock_httpx_client.get.return_value = project_response response = await find_unique_project(mock_context, sample_project["id"]) - assert response == clean_project_dict(sample_project, fake_cloud_name) + assert response == clean_project_dict(sample_project) @pytest.mark.asyncio @@ -30,7 +29,6 @@ async def test_find_unique_project_by_name_with_a_single_match( mock_httpx_response: Callable, build_project_dict: Callable, build_project_search_response_dict: Callable, - fake_cloud_name: str, ): sample_project = build_project_dict() get_project_by_id_response = mock_httpx_response(404, {}) @@ -43,7 +41,7 @@ async def test_find_unique_project_by_name_with_a_single_match( ] response = await find_unique_project(mock_context, sample_project["name"].lower()) - assert response == clean_project_dict(sample_project, fake_cloud_name) + assert response == clean_project_dict(sample_project) @pytest.mark.asyncio diff --git a/toolkits/jira/tests/test_find_unique_user.py b/toolkits/jira/tests/test_find_unique_user.py index eb30e0cf..dcd0ac46 100644 --- a/toolkits/jira/tests/test_find_unique_user.py +++ b/toolkits/jira/tests/test_find_unique_user.py @@ -17,14 +17,13 @@ async def test_find_unique_user_by_id_success( mock_httpx_client, mock_httpx_response: Callable, build_user_dict: Callable, - fake_cloud_name: str, ): sample_user = build_user_dict() user_response = mock_httpx_response(200, sample_user) mock_httpx_client.get.return_value = user_response response = await find_unique_user(mock_context, sample_user["accountId"]) - assert response == clean_user_dict(sample_user, fake_cloud_name) + assert response == clean_user_dict(sample_user) @pytest.mark.asyncio @@ -33,7 +32,6 @@ async def test_find_unique_user_by_name_with_a_single_match( mock_httpx_client, mock_httpx_response: Callable, build_user_dict: Callable, - fake_cloud_name: str, ): sample_user = build_user_dict() get_user_by_id_response = mock_httpx_response(404, {}) @@ -41,7 +39,7 @@ async def test_find_unique_user_by_name_with_a_single_match( mock_httpx_client.get.side_effect = [get_user_by_id_response, get_users_without_id_response] response = await find_unique_user(mock_context, sample_user["displayName"].lower()) - assert response == clean_user_dict(sample_user, fake_cloud_name) + assert response == clean_user_dict(sample_user) @pytest.mark.asyncio @@ -89,7 +87,6 @@ async def test_find_multiple_users_when_all_names_match_one_result( mock_httpx_client, mock_httpx_response: Callable, build_user_dict: Callable, - fake_cloud_name: str, ): user1 = build_user_dict() user2 = build_user_dict() @@ -104,8 +101,8 @@ async def test_find_multiple_users_when_all_names_match_one_result( ) assert response == [ - clean_user_dict(user1, fake_cloud_name), - clean_user_dict(user2, fake_cloud_name), + clean_user_dict(user1), + clean_user_dict(user2), ] @@ -138,7 +135,6 @@ async def test_find_multiple_users_when_user_is_not_found_by_name_but_found_by_i mock_httpx_client, mock_httpx_response: Callable, build_user_dict: Callable, - fake_cloud_name: str, ): user1 = build_user_dict() user2 = build_user_dict() @@ -154,8 +150,8 @@ async def test_find_multiple_users_when_user_is_not_found_by_name_but_found_by_i ) assert response == [ - clean_user_dict(user1, fake_cloud_name), - clean_user_dict(user2, fake_cloud_name), + clean_user_dict(user1), + clean_user_dict(user2), ] @@ -165,7 +161,6 @@ async def test_find_multiple_users_when_various_users_are_not_found_by_name_but_ mock_httpx_client, mock_httpx_response: Callable, build_user_dict: Callable, - fake_cloud_name: str, ): user1 = build_user_dict() user2 = build_user_dict() @@ -184,7 +179,7 @@ async def test_find_multiple_users_when_various_users_are_not_found_by_name_but_ ) assert response == [ - clean_user_dict(user1, fake_cloud_name), - clean_user_dict(user2, fake_cloud_name), - clean_user_dict(user3, fake_cloud_name), + clean_user_dict(user1), + clean_user_dict(user2), + clean_user_dict(user3), ] diff --git a/toolkits/jira/tests/test_multi_cloud.py b/toolkits/jira/tests/test_multi_cloud.py new file mode 100644 index 00000000..2aa3fb50 --- /dev/null +++ b/toolkits/jira/tests/test_multi_cloud.py @@ -0,0 +1,275 @@ +import asyncio +import json +import uuid +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest +from arcade_tdk import ToolContext +from arcade_tdk.errors import RetryableToolError, ToolExecutionError + +from arcade_jira.utils import check_if_cloud_is_authorized, resolve_cloud_id + + +@pytest.fixture +def mock_httpx_client(): + with patch("arcade_jira.utils.httpx") as mock_httpx: + yield mock_httpx.AsyncClient().__aenter__.return_value + + +@patch("arcade_jira.tools.cloud.get_available_atlassian_clouds") +@pytest.mark.asyncio +async def test_resolve_cloud_id_with_value_already_provided( + mock_get_available_atlassian_clouds: MagicMock, + mock_context: ToolContext, + fake_cloud_id: str, + fake_cloud_name: str, +): + another_cloud_id = str(uuid.uuid4()) + mock_get_available_atlassian_clouds.return_value = { + "clouds_available": [ + { + "atlassian_cloud_id": fake_cloud_id, + "atlassian_cloud_name": fake_cloud_name, + "atlassian_cloud_url": f"https://{fake_cloud_name}.atlassian.net", + } + ] + } + + cloud_id = await resolve_cloud_id(mock_context, another_cloud_id) + assert cloud_id == another_cloud_id + + +@patch("arcade_jira.tools.cloud.get_available_atlassian_clouds") +@pytest.mark.asyncio +async def test_resolve_cloud_id_providing_cloud_name( + mock_get_available_atlassian_clouds: MagicMock, + mock_context: ToolContext, + fake_cloud_id: str, + fake_cloud_name: str, +): + mock_get_available_atlassian_clouds.return_value = { + "clouds_available": [ + { + "atlassian_cloud_id": fake_cloud_id, + "atlassian_cloud_name": fake_cloud_name, + "atlassian_cloud_url": f"https://{fake_cloud_name}.atlassian.net", + } + ] + } + + cloud_id = await resolve_cloud_id(mock_context, fake_cloud_name) + assert cloud_id == fake_cloud_id + + +@patch("arcade_jira.tools.cloud.get_available_atlassian_clouds") +@pytest.mark.asyncio +async def test_resolve_cloud_id_with_single_cloud_available( + mock_get_available_atlassian_clouds: MagicMock, + mock_context: ToolContext, + fake_cloud_id: str, + fake_cloud_name: str, +): + mock_get_available_atlassian_clouds.return_value = { + "clouds_available": [ + { + "atlassian_cloud_id": fake_cloud_id, + "atlassian_cloud_name": fake_cloud_name, + "atlassian_cloud_url": f"https://{fake_cloud_name}.atlassian.net", + } + ] + } + + cloud_id = await resolve_cloud_id(mock_context, None) + assert cloud_id == fake_cloud_id + + +@patch("arcade_jira.tools.cloud.get_available_atlassian_clouds") +@pytest.mark.asyncio +async def test_resolve_cloud_id_with_multiple_distinct_clouds_available( + mock_get_available_atlassian_clouds: MagicMock, + mock_context: ToolContext, + fake_cloud_id: str, + fake_cloud_name: str, +): + cloud_id_2 = str(uuid.uuid4()) + mock_get_available_atlassian_clouds.return_value = { + "clouds_available": [ + { + "atlassian_cloud_id": fake_cloud_id, + "atlassian_cloud_name": fake_cloud_name, + "atlassian_cloud_url": f"https://{fake_cloud_name}.atlassian.net", + }, + { + "atlassian_cloud_id": cloud_id_2, + "atlassian_cloud_name": "Cloud 2", + "atlassian_cloud_url": "https://cloud2.atlassian.net", + }, + ] + } + + with pytest.raises(RetryableToolError) as exc: + await resolve_cloud_id(mock_context, None) + + assert "Multiple Atlassian Clouds are available" in exc.value.message + assert fake_cloud_id in exc.value.additional_prompt_content + assert fake_cloud_name in exc.value.additional_prompt_content + assert cloud_id_2 in exc.value.additional_prompt_content + assert "Cloud 2" in exc.value.additional_prompt_content + + +@patch("arcade_jira.tools.cloud.get_available_atlassian_clouds") +@pytest.mark.asyncio +async def test_resolve_cloud_id_with_no_clouds_available( + mock_get_available_atlassian_clouds: MagicMock, + mock_context: ToolContext, + fake_cloud_id: str, + fake_cloud_name: str, +): + mock_get_available_atlassian_clouds.return_value = {"clouds_available": []} + + with pytest.raises(ToolExecutionError) as exc: + await resolve_cloud_id(mock_context, None) + + assert "No Atlassian Cloud is available" in exc.value.message + + +@pytest.mark.asyncio +async def test_check_if_cloud_is_authorized_success( + mock_httpx_client: MagicMock, + mock_context: ToolContext, + fake_cloud_id: str, + fake_cloud_name: str, +): + cloud = { + "atlassian_cloud_id": fake_cloud_id, + "atlassian_cloud_name": fake_cloud_name, + "atlassian_cloud_url": f"https://{fake_cloud_name}.atlassian.net", + } + fake_user_id = uuid.uuid4() + mock_httpx_client.get.return_value.status_code = 200 + mock_httpx_client.get.return_value.json.return_value = { + "self": f"https://api.atlassian.com/ex/jira/{fake_cloud_id}/rest/api/3/user?accountId={fake_user_id!s}", + "accountId": fake_user_id, + "accountType": "atlassian", + "emailAddress": f"john.doe@{fake_cloud_name}.com", + "displayName": "John Doe", + } + + semaphore = asyncio.Semaphore(1) + + response = await check_if_cloud_is_authorized(mock_context, cloud, semaphore) + + assert response == cloud + + +@pytest.mark.asyncio +async def test_check_if_cloud_is_authorized_returning_401_error( + mock_httpx_client: MagicMock, + mock_context: ToolContext, + fake_cloud_id: str, + fake_cloud_name: str, +): + cloud = { + "atlassian_cloud_id": fake_cloud_id, + "atlassian_cloud_name": fake_cloud_name, + "atlassian_cloud_url": f"https://{fake_cloud_name}.atlassian.net", + } + + mock_httpx_client.get.return_value.status_code = 401 + mock_httpx_client.get.return_value.json.return_value = { + "code": 401, + "message": "Unauthorized", + } + + semaphore = asyncio.Semaphore(1) + + response = await check_if_cloud_is_authorized(mock_context, cloud, semaphore) + + assert response is False + + +@pytest.mark.asyncio +async def test_check_if_cloud_is_authorized_returning_404_no_message_available_error( + mock_httpx_client: MagicMock, + mock_context: ToolContext, + fake_cloud_id: str, + fake_cloud_name: str, +): + cloud = { + "atlassian_cloud_id": fake_cloud_id, + "atlassian_cloud_name": fake_cloud_name, + "atlassian_cloud_url": f"https://{fake_cloud_name}.atlassian.net", + } + + def mock_response_json() -> dict[str, Any]: + return { + "code": 404, + "message": "No message available", + } + + mock_httpx_client.get.return_value.status_code = 404 + mock_httpx_client.get.return_value.json = mock_response_json + + semaphore = asyncio.Semaphore(1) + + response = await check_if_cloud_is_authorized(mock_context, cloud, semaphore) + + assert response is False + + +@pytest.mark.asyncio +async def test_check_if_cloud_is_authorized_returning_404_unrecognized_error( + mock_httpx_client: MagicMock, + mock_context: ToolContext, + fake_cloud_id: str, + fake_cloud_name: str, +): + cloud = { + "atlassian_cloud_id": fake_cloud_id, + "atlassian_cloud_name": fake_cloud_name, + "atlassian_cloud_url": f"https://{fake_cloud_name}.atlassian.net", + } + + response_data = { + "code": 404, + "message": "Something else was not found", + } + + def mock_response_json() -> dict[str, Any]: + return response_data + + mock_httpx_client.get.return_value.status_code = 404 + mock_httpx_client.get.return_value.text = json.dumps(response_data) + mock_httpx_client.get.return_value.json = mock_response_json + + semaphore = asyncio.Semaphore(1) + + response = await check_if_cloud_is_authorized(mock_context, cloud, semaphore) + + assert response is False + + +@pytest.mark.asyncio +async def test_check_if_cloud_is_authorized_raising_unexpected_exception( + mock_httpx_client: MagicMock, + mock_context: ToolContext, + fake_cloud_id: str, + fake_cloud_name: str, +): + cloud = { + "atlassian_cloud_id": fake_cloud_id, + "atlassian_cloud_name": fake_cloud_name, + "atlassian_cloud_url": f"https://{fake_cloud_name}.atlassian.net", + } + + mock_httpx_client.get.side_effect = Exception("Something went wrong") + + semaphore = asyncio.Semaphore(1) + + with pytest.raises(ToolExecutionError) as exc: + await check_if_cloud_is_authorized(mock_context, cloud, semaphore) + + assert fake_cloud_id in exc.value.message + assert fake_cloud_id in exc.value.developer_message + assert "Something went wrong" in exc.value.developer_message diff --git a/toolkits/jira/tests/test_pagination_helpers.py b/toolkits/jira/tests/test_pagination_helpers.py index f33c47ff..0ddaf310 100644 --- a/toolkits/jira/tests/test_pagination_helpers.py +++ b/toolkits/jira/tests/test_pagination_helpers.py @@ -34,7 +34,6 @@ async def test_paginate_all_items_with_one_page( mock_httpx_response: Callable, build_project_dict: Callable, build_project_search_response_dict: Callable, - fake_cloud_name: str, ): projects = [build_project_dict(), build_project_dict()] response = mock_httpx_response(200, build_project_search_response_dict(projects, is_last=True)) @@ -46,7 +45,7 @@ async def test_paginate_all_items_with_one_page( response_items_key="projects", scheme_id="123", ) - assert response == [clean_project_dict(project, fake_cloud_name) for project in projects] + assert response == [clean_project_dict(project) for project in projects] @pytest.mark.asyncio @@ -56,7 +55,6 @@ async def test_paginate_all_items_with_multiple_pages( mock_httpx_response: Callable, build_project_dict: Callable, build_project_search_response_dict: Callable, - fake_cloud_name: str, ): page1 = [build_project_dict(), build_project_dict()] page2 = [build_project_dict(), build_project_dict()] @@ -75,9 +73,7 @@ async def test_paginate_all_items_with_multiple_pages( scheme_id="123", limit=2, ) - assert response == [ - clean_project_dict(project, fake_cloud_name) for project in page1 + page2 + page3 - ] + assert response == [clean_project_dict(project) for project in page1 + page2 + page3] @pytest.mark.asyncio