diff --git a/arcade/arcade/client/client.py b/arcade/arcade/client/client.py index 499992b3..ba0b3296 100644 --- a/arcade/arcade/client/client.py +++ b/arcade/arcade/client/client.py @@ -11,6 +11,7 @@ from arcade.client.base import ( from arcade.client.errors import APIStatusError, EngineNotHealthyError, EngineOfflineError from arcade.client.schema import ( AuthProvider, + AuthProviderType, AuthRequest, AuthResponse, ExecuteToolResponse, @@ -29,10 +30,10 @@ class AuthResource(BaseResource[ClientT]): def authorize( self, - provider: AuthProvider, - scopes: list[str], user_id: str, - authority: str | None = None, + provider: AuthProvider | str, + provider_type: AuthProviderType = AuthProviderType.oauth2, + scopes: list[str] | None = None, ) -> AuthResponse: """ Initiate an authorization request. @@ -41,16 +42,14 @@ class AuthResource(BaseResource[ClientT]): provider: The authorization provider. scopes: The scopes required for the authorization. user_id: The user ID initiating the authorization. - authority: The authority initiating the authorization. """ - auth_provider = provider.value + auth_provider_type = provider_type.value body = { "auth_requirement": { - "provider": auth_provider, - auth_provider: AuthRequest(scopes=scopes, authority=authority).model_dump( - exclude_none=True - ), + "provider_id": provider.value if isinstance(provider, AuthProvider) else provider, + "provider_type": auth_provider_type, + auth_provider_type: AuthRequest(scopes=scopes or []).model_dump(exclude_none=True), }, "user_id": user_id, } @@ -190,22 +189,21 @@ class AsyncAuthResource(BaseResource[AsyncArcadeClient]): async def authorize( self, - provider: AuthProvider, - scopes: list[str], user_id: str, - authority: str | None = None, + provider: AuthProvider | str, + provider_type: AuthProviderType = AuthProviderType.oauth2, + scopes: list[str] | None = None, ) -> AuthResponse: """ Initiate an asynchronous authorization request. """ - auth_provider = provider.value + auth_provider_type = provider_type.value body = { "auth_requirement": { - "provider": auth_provider, - auth_provider: AuthRequest(scopes=scopes, authority=authority).model_dump( - exclude_none=True - ), + "provider_id": provider.value if isinstance(provider, AuthProvider) else provider, + "provider_type": auth_provider_type, + auth_provider_type: AuthRequest(scopes=scopes or []).model_dump(exclude_none=True), }, "user_id": user_id, } diff --git a/arcade/arcade/client/schema.py b/arcade/arcade/client/schema.py index 627c44ef..f255f22d 100644 --- a/arcade/arcade/client/schema.py +++ b/arcade/arcade/client/schema.py @@ -1,7 +1,7 @@ import os from enum import Enum -from pydantic import AnyUrl, BaseModel, Field +from pydantic import BaseModel, Field from arcade.core.schema import ToolAuthorizationContext, ToolCallOutput @@ -9,19 +9,21 @@ OPENAI_API_VERSION = os.getenv("OPENAI_API_VERSION", "v1") class AuthProvider(str, Enum): - """The supported authorization providers.""" - - oauth2 = "oauth2" - """OAuth 2.0 authorization""" - google = "google" """Google authorization""" - slack_user = "slack_user" + slack = "slack_user" """Slack (user token) authorization""" - github_app = "github_app" - """GitHub App authorization""" + github = "github" + """GitHub authorization""" + + +class AuthProviderType(str, Enum): + """The supported authorization provider types.""" + + oauth2 = "oauth2" + """OAuth 2.0 authorization""" class AuthRequest(BaseModel): @@ -30,9 +32,6 @@ class AuthRequest(BaseModel): # TODO (Nate): Make a validator here """ - authority: AnyUrl | str | None = None - """The URL of the OAuth 2.0 authorization server.""" - scopes: list[str] """The scope(s) needed for authorization.""" diff --git a/arcade/arcade/core/catalog.py b/arcade/arcade/core/catalog.py index ee64252f..a4a92397 100644 --- a/arcade/arcade/core/catalog.py +++ b/arcade/arcade/core/catalog.py @@ -46,7 +46,7 @@ from arcade.core.utils import ( snake_to_pascal_case, ) from arcade.sdk.annotations import Inferrable -from arcade.sdk.auth import BaseOAuth2, ToolAuthorization +from arcade.sdk.auth import OAuth2, ToolAuthorization InnerWireType = Literal["string", "integer", "number", "boolean", "json"] WireType = Union[InnerWireType, Literal["array"]] @@ -257,9 +257,10 @@ class ToolCatalog(BaseModel): auth_requirement = getattr(tool, "__tool_requires_auth__", None) if isinstance(auth_requirement, ToolAuthorization): new_auth_requirement = ToolAuthRequirement( - provider=auth_requirement.get_provider(), + provider_id=auth_requirement.provider_id, + provider_type=auth_requirement.provider_type, ) - if isinstance(auth_requirement, BaseOAuth2): + if isinstance(auth_requirement, OAuth2): new_auth_requirement.oauth2 = OAuth2Requirement(**auth_requirement.model_dump()) auth_requirement = new_auth_requirement @@ -274,7 +275,7 @@ class ToolCatalog(BaseModel): return ToolDefinition( name=tool_name, - full_name=str(fully_qualified_name), + fully_qualified_name=str(fully_qualified_name), description=tool_description, toolkit=toolkit_definition, inputs=create_input_definition(tool), diff --git a/arcade/arcade/core/schema.py b/arcade/arcade/core/schema.py index 0ff81a29..5a873752 100644 --- a/arcade/arcade/core/schema.py +++ b/arcade/arcade/core/schema.py @@ -2,7 +2,7 @@ import os from dataclasses import dataclass from typing import Any, Literal, Optional, Union -from pydantic import AnyUrl, BaseModel, Field +from pydantic import BaseModel, Field # allow for custom tool name separator TOOL_NAME_SEPARATOR = os.getenv("ARCADE_TOOL_NAME_SEPARATOR", ".") @@ -73,9 +73,6 @@ class ToolOutput(BaseModel): class OAuth2Requirement(BaseModel): """Indicates that the tool requires OAuth 2.0 authorization.""" - authority: Optional[AnyUrl] = None - """The URL of the OAuth 2.0 authorization server.""" - scopes: Optional[list[str]] = None """The scope(s) needed for authorization, if any.""" @@ -83,7 +80,19 @@ class OAuth2Requirement(BaseModel): class ToolAuthRequirement(BaseModel): """A requirement for authorization to use a tool.""" - provider: str + # Provider ID and Type needed for the Arcade Engine to look up the auth provider. + # However, the developer generally does not need to set these directly. + # Instead, they will use: + # @tool(requires_auth=Google(scopes=["profile", "email"])) + # or + # client.auth.authorize(provider=AuthProvider.google, scopes=["profile", "email"]) + # + # The Arcade SDK translates these into the appropriate provider ID and type. + # The only time the developer will set these is if they are using a custom auth provider. + provider_id: Optional[str] = None + """A unique provider ID.""" + + provider_type: str """The provider type.""" oauth2: Optional[OAuth2Requirement] = None @@ -161,7 +170,7 @@ class ToolDefinition(BaseModel): name: str """The name of the tool.""" - full_name: str + fully_qualified_name: str """The fully-qualified name of the tool.""" description: str @@ -205,6 +214,16 @@ class ToolAuthorizationContext(BaseModel): token: str | None = None """The token for the tool invocation.""" + user_info: dict = Field(default={}) + """ + The user information provided by the authorization server (if any). + + Some providers can provide structured user info, + for example an internal provider-specific user ID. + For those providers that support retrieving user info, + the Engine can automatically pass that to tool invocations. + """ + class ToolContext(BaseModel): """The context for a tool invocation.""" @@ -212,6 +231,9 @@ class ToolContext(BaseModel): authorization: ToolAuthorizationContext | None = None """The authorization context for the tool invocation that requires authorization.""" + user_id: str | None = None + """The user ID for the tool invocation (if any).""" + class ToolCallRequest(BaseModel): """The request to call (invoke) a tool.""" diff --git a/arcade/arcade/sdk/auth.py b/arcade/arcade/sdk/auth.py index 20901a78..ad024932 100644 --- a/arcade/arcade/sdk/auth.py +++ b/arcade/arcade/sdk/auth.py @@ -1,60 +1,71 @@ -from abc import ABC, abstractmethod +from enum import Enum from typing import Optional -from pydantic import AnyUrl, BaseModel +from pydantic import BaseModel, ConfigDict -class ToolAuthorization(BaseModel, ABC): +class AuthProviderType(str, Enum): + oauth2 = "oauth2" + + +class ToolAuthorization(BaseModel): """Marks a tool as requiring authorization.""" - @abstractmethod - def get_provider(self) -> str: - """Return the name of the authorization method.""" - pass + model_config = ConfigDict(frozen=True) - pass + provider_id: str + """The unique provider ID configured in Arcade.""" + + provider_type: AuthProviderType + """The type of the authorization provider.""" -class BaseOAuth2(ToolAuthorization): - """Base class for any provider supporting OAuth 2.0-like authorization.""" +class OAuth2(ToolAuthorization): + """Marks a tool as requiring OAuth 2.0 authorization.""" - authority: Optional[AnyUrl] = None - """The URL of the OAuth 2.0 authorization server.""" + provider_type: AuthProviderType = AuthProviderType.oauth2 scopes: Optional[list[str]] = None """The scope(s) needed for the authorized action.""" -class OAuth2(BaseOAuth2): - """Marks a tool as requiring OAuth 2.0 authorization.""" - - def get_provider(self) -> str: - return "oauth2" - - -class Google(BaseOAuth2): +class Google(OAuth2): """Marks a tool as requiring Google authorization.""" - def get_provider(self) -> str: - return "google" + provider_id: str = "google" -class SlackUser(BaseOAuth2): +class Slack(OAuth2): """Marks a tool as requiring Slack (user token) authorization.""" - def get_provider(self) -> str: - return "slack_user" + provider_id: str = "slack" -class GitHubApp(ToolAuthorization): +class GitHub(OAuth2): """Marks a tool as requiring GitHub App authorization.""" - def get_provider(self) -> str: - return "github_app" + provider_id: str = "github" -class X(BaseOAuth2): +class X(OAuth2): """Marks a tool as requiring X (Twitter) authorization.""" - def get_provider(self) -> str: - return "x" + provider_id: str = "x" + + +class LinkedIn(OAuth2): + """Marks a tool as requiring LinkedIn authorization.""" + + provider_id: str = "linkedin" + + +class Spotify(OAuth2): + """Marks a tool as requiring Spotify authorization.""" + + provider_id: str = "spotify" + + +class Zoom(OAuth2): + """Marks a tool as requiring Zoom authorization.""" + + provider_id: str = "zoom" diff --git a/arcade/pyproject.toml b/arcade/pyproject.toml index 519a4411..2ffb19b0 100644 --- a/arcade/pyproject.toml +++ b/arcade/pyproject.toml @@ -5,7 +5,7 @@ description = "" packages = [ {include="arcade", from="."} ] -authors = ["Arcade AI "] +authors = ["Arcade AI "] [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/arcade/tests/client/test_client.py b/arcade/tests/client/test_client.py index faa7e6af..fe5a259c 100644 --- a/arcade/tests/client/test_client.py +++ b/arcade/tests/client/test_client.py @@ -12,7 +12,7 @@ from arcade.client.errors import ( PermissionDeniedError, UnauthorizedError, ) -from arcade.client.schema import AuthResponse, ExecuteToolResponse +from arcade.client.schema import AuthProviderType, AuthResponse, ExecuteToolResponse from arcade.core.schema import ToolDefinition AUTH_RESPONSE_DATA = { @@ -23,6 +23,14 @@ AUTH_RESPONSE_DATA = { "scopes": ["https://www.googleapis.com/auth/gmail.readonly"], } +AUTH_RESPONSE_DATA_NO_SCOPES = { + "auth_id": "auth_123", + "authorization_url": "https://example.com/auth", + "status": "pending", + "authorization_id": "auth_123", + "scopes": [], +} + TOOL_RESPONSE_DATA = { "tool_name": "GetEmails", "tool_version": "0.1.0", @@ -36,7 +44,7 @@ TOOL_RESPONSE_DATA = { TOOL_DEFINITION_DATA = { "name": "GetEmails", - "full_name": "TestToolkit.GetEmails", + "fully_qualified_name": "TestToolkit.GetEmails", "description": "Retrieve emails from a user's inbox", "toolkit": { "name": "TestToolkit", @@ -130,6 +138,30 @@ def test_arcade_auth_authorize(test_sync_client, mock_response, monkeypatch): assert auth_response == AuthResponse(**AUTH_RESPONSE_DATA) +def test_arcade_auth_authorize_with_provider_type(test_sync_client, mock_response, monkeypatch): + """Test Arcade.auth.authorize method.""" + monkeypatch.setattr(Arcade, "_execute_request", lambda *args, **kwargs: AUTH_RESPONSE_DATA) + auth_response = test_sync_client.auth.authorize( + provider="hooli", + provider_type=AuthProviderType.oauth2, + scopes=["https://www.googleapis.com/auth/gmail.readonly"], + user_id="sam@arcade-ai.com", + ) + assert auth_response == AuthResponse(**AUTH_RESPONSE_DATA) + + +def test_arcade_auth_authorize_with_no_scopes(test_sync_client, mock_response, monkeypatch): + """Test Arcade.auth.authorize method.""" + monkeypatch.setattr( + Arcade, "_execute_request", lambda *args, **kwargs: AUTH_RESPONSE_DATA_NO_SCOPES + ) + auth_response = test_sync_client.auth.authorize( + provider=AuthProvider.google, + user_id="sam@arcade-ai.com", + ) + assert auth_response == AuthResponse(**AUTH_RESPONSE_DATA_NO_SCOPES) + + def test_arcade_auth_poll_authorization(test_sync_client, mock_response, monkeypatch): """Test Arcade.auth.poll_authorization method.""" monkeypatch.setattr(Arcade, "_execute_request", lambda *args, **kwargs: AUTH_RESPONSE_DATA) @@ -201,6 +233,42 @@ async def test_async_arcade_auth_authorize(test_async_client, mock_async_respons assert auth_response == AuthResponse(**AUTH_RESPONSE_DATA) +@pytest.mark.asyncio +async def test_async_arcade_auth_authorize_with_provider_type( + test_async_client, mock_async_response, monkeypatch +): + """Test AsyncArcade.auth.authorize method.""" + + async def mock_execute_request(*args, **kwargs): + return AUTH_RESPONSE_DATA + + monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request) + auth_response = await test_async_client.auth.authorize( + provider="hooli", + provider_type=AuthProviderType.oauth2, + scopes=["https://www.googleapis.com/auth/gmail.readonly"], + user_id="sam@arcade-ai.com", + ) + assert auth_response == AuthResponse(**AUTH_RESPONSE_DATA) + + +@pytest.mark.asyncio +async def test_async_arcade_auth_authorize_with_no_scopes( + test_async_client, mock_async_response, monkeypatch +): + """Test AsyncArcade.auth.authorize method.""" + + async def mock_execute_request(*args, **kwargs): + return AUTH_RESPONSE_DATA_NO_SCOPES + + monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request) + auth_response = await test_async_client.auth.authorize( + provider=AuthProvider.google, + user_id="sam@arcade-ai.com", + ) + assert auth_response == AuthResponse(**AUTH_RESPONSE_DATA_NO_SCOPES) + + @pytest.mark.asyncio async def test_async_arcade_auth_poll_authorization( test_async_client, mock_async_response, monkeypatch diff --git a/arcade/tests/sdk/test_tool_decorator.py b/arcade/tests/sdk/test_tool_decorator.py index d2cb642c..d7e7bcb7 100644 --- a/arcade/tests/sdk/test_tool_decorator.py +++ b/arcade/tests/sdk/test_tool_decorator.py @@ -39,7 +39,7 @@ def test_tool_decorator_with_all_options(): name="TestTool", desc="Test description", requires_auth=OAuth2( - authority="https://example.com/oauth2/auth", + provider_id="example", scopes=["test_scope", "another.scope"], ), ) @@ -48,5 +48,4 @@ def test_tool_decorator_with_all_options(): assert test_tool.__tool_name__ == "TestTool" assert test_tool.__tool_description__ == "Test description" - assert str(test_tool.__tool_requires_auth__.authority) == "https://example.com/oauth2/auth" assert test_tool.__tool_requires_auth__.scopes == ["test_scope", "another.scope"] diff --git a/arcade/tests/tool/test_create_tool_definition.py b/arcade/tests/tool/test_create_tool_definition.py index 82ba78a9..6e991fb5 100644 --- a/arcade/tests/tool/test_create_tool_definition.py +++ b/arcade/tests/tool/test_create_tool_definition.py @@ -17,7 +17,7 @@ from arcade.core.schema import ( from arcade.core.utils import snake_to_pascal_case from arcade.sdk import tool from arcade.sdk.annotations import Inferrable -from arcade.sdk.auth import GitHubApp, Google, OAuth2, SlackUser, X +from arcade.sdk.auth import GitHub, Google, OAuth2, Slack, X ### Tests on @tool decorator @@ -48,7 +48,10 @@ def func_with_name_and_description(): @tool( desc="A function that requires authentication", - requires_auth=OAuth2(authority="https://example.com/oauth2/auth", scopes=["scope1", "scope2"]), + requires_auth=OAuth2( + provider_id="example", + scopes=["scope1", "scope2"], + ), ) def func_with_auth_requirement(): pass @@ -64,7 +67,7 @@ def func_with_google_auth_requirement(): @tool( desc="A function that requires GitHub authorization", - requires_auth=GitHubApp(), + requires_auth=GitHub(), ) def func_with_github_auth_requirement(): pass @@ -72,7 +75,7 @@ def func_with_github_auth_requirement(): @tool( desc="A function that requires Slack user authorization", - requires_auth=SlackUser(scopes=["chat:write", "channels:history"]), + requires_auth=Slack(scopes=["chat:write", "channels:history"]), ) def func_with_slack_user_auth_requirement(): pass @@ -239,7 +242,7 @@ def func_with_complex_return() -> dict[str, str]: func_with_name_and_description, { "name": "MyCustomTool", - "full_name": "TestToolkit.MyCustomTool", + "fully_qualified_name": "TestToolkit.MyCustomTool", "description": "A function with a very cool description", }, id="func_with_description_and_name", @@ -254,7 +257,8 @@ def func_with_complex_return() -> dict[str, str]: { "requirements": ToolRequirements( authorization=ToolAuthRequirement( - provider="oauth2", + provider_id="example", + provider_type="oauth2", oauth2=OAuth2Requirement( authority="https://example.com/oauth2/auth", scopes=["scope1", "scope2"], @@ -269,7 +273,8 @@ def func_with_complex_return() -> dict[str, str]: { "requirements": ToolRequirements( authorization=ToolAuthRequirement( - provider="google", + provider_id="google", + provider_type="oauth2", oauth2=OAuth2Requirement( scopes=["https://www.googleapis.com/auth/gmail.readonly"], ), @@ -283,7 +288,7 @@ def func_with_complex_return() -> dict[str, str]: { "requirements": ToolRequirements( authorization=ToolAuthRequirement( - provider="github_app", + provider_id="github", provider_type="oauth2", oauth2=OAuth2Requirement() ) ) }, @@ -294,7 +299,8 @@ def func_with_complex_return() -> dict[str, str]: { "requirements": ToolRequirements( authorization=ToolAuthRequirement( - provider="slack_user", + provider_id="slack", + provider_type="oauth2", oauth2=OAuth2Requirement( scopes=["chat:write", "channels:history"], ), @@ -683,7 +689,7 @@ def test_tool_name_is_set_correctly(): tool_def = ToolCatalog.create_tool_definition(func_with_description, "test_toolkit", "1.0.0") assert tool_def.name == snake_to_pascal_case(func_with_description.__name__) - assert tool_def.full_name == "TestToolkit.FuncWithDescription" + assert tool_def.fully_qualified_name == "TestToolkit.FuncWithDescription" @pytest.mark.parametrize( diff --git a/cspell.config.yaml b/cspell.config.yaml index cc24d7ce..add8d12a 100644 --- a/cspell.config.yaml +++ b/cspell.config.yaml @@ -6,6 +6,7 @@ dictionaries: [] words: - conlist - fastapi + - httpx - openai - pydantic - pyproject diff --git a/examples/fastapi/pyproject.toml b/examples/fastapi/pyproject.toml index a2de251a..4e285d0d 100644 --- a/examples/fastapi/pyproject.toml +++ b/examples/fastapi/pyproject.toml @@ -2,7 +2,7 @@ name = "arcade_example_fastapi" version = "0.1.0" description = "FastAPI example app with Arcade" -authors = ["Nate Barbettini "] +authors = ["Arcade AI repo owned by titled 'Found a bug' with the body 'I'm having a problem with this.' Assign it to and label it 'bug'" -@tool(requires_auth=GitHubApp()) +@tool(requires_auth=GitHub()) async def create_issue( context: ToolContext, owner: Annotated[str, "The account owner of the repository. The name is not case sensitive."], @@ -85,7 +85,7 @@ async def create_issue( # Implements https://docs.github.com/en/rest/issues/comments?apiVersion=2022-11-28#create-an-issue-comment # Example `arcade chat` usage: "create a comment in the vscode repo owned by microsoft for issue 1347 that says 'Me too'" -@tool(requires_auth=GitHubApp()) +@tool(requires_auth=GitHub()) async def create_issue_comment( context: ToolContext, owner: Annotated[str, "The account owner of the repository. The name is not case sensitive."], diff --git a/toolkits/github/arcade_github/tools/pull_requests.py b/toolkits/github/arcade_github/tools/pull_requests.py index 81a2b2c8..54a63e3b 100644 --- a/toolkits/github/arcade_github/tools/pull_requests.py +++ b/toolkits/github/arcade_github/tools/pull_requests.py @@ -6,7 +6,7 @@ import httpx from arcade.core.errors import RetryableToolError from arcade.core.schema import ToolContext from arcade.sdk import tool -from arcade.sdk.auth import GitHubApp +from arcade.sdk.auth import GitHub from arcade_github.tools.models import ( DiffSide, PRSortProperty, @@ -28,7 +28,7 @@ from arcade_github.tools.utils import ( # Example `arcade chat` usage: "get all open PRs that has that are in the / repo" # TODO: Validate owner/repo combination is valid for the authenticated user. If not, return RetryableToolError with available repos. # TODO: list repo's branches and validate base is in the list (or default to main). If not, return RetryableToolError with available branches. -@tool(requires_auth=GitHubApp()) +@tool(requires_auth=GitHub()) async def list_pull_requests( context: ToolContext, owner: Annotated[str, "The account owner of the repository. The name is not case sensitive."], @@ -103,7 +103,7 @@ async def list_pull_requests( # Implements https://docs.github.com/en/rest/pulls/pulls?apiVersion=2022-11-28#get-a-pull-request # Example `arcade chat` usage: "get the PR #72 in the / repo. Include diff content in your response." -@tool(requires_auth=GitHubApp()) +@tool(requires_auth=GitHub()) async def get_pull_request( context: ToolContext, owner: Annotated[str, "The account owner of the repository. The name is not case sensitive."], @@ -177,7 +177,7 @@ async def get_pull_request( # Implements https://docs.github.com/en/rest/pulls/pulls?apiVersion=2022-11-28#update-a-pull-request # Example `arcade chat` usage: "update PR #72 in the / repo by changing the title to 'New Title' and setting the body to 'This PR description was added via arcade chat!'." # TODO: Enable this tool to append to the PR contents instead of only replacing content. -@tool(requires_auth=GitHubApp()) +@tool(requires_auth=GitHub()) async def update_pull_request( context: ToolContext, owner: Annotated[str, "The account owner of the repository. The name is not case sensitive."], @@ -242,7 +242,7 @@ async def update_pull_request( # Implements https://docs.github.com/en/rest/pulls/commits?apiVersion=2022-11-28#list-commits-on-a-pull-request # Example `arcade chat` usage: "list all of the commits for the PR 72 in the / repo" -@tool(requires_auth=GitHubApp()) +@tool(requires_auth=GitHub()) async def list_pull_request_commits( context: ToolContext, owner: Annotated[str, "The account owner of the repository. The name is not case sensitive."], @@ -308,7 +308,7 @@ async def list_pull_request_commits( # Example `arcade chat` usage: "create a reply to the review comment 1778019974 in arcadeai/arcade-ai for the PR 72 that says 'Thanks for the suggestion.'" # Note: This tool requires the ID of the review comment to reply to. To obtain this ID, you should first call the `list_review_comments_on_pull_request` function. # The returned JSON will contain the `id` field for each comment, which can be used as the `comment_id` parameter in this function. -@tool(requires_auth=GitHubApp()) +@tool(requires_auth=GitHub()) async def create_reply_for_review_comment( context: ToolContext, owner: Annotated[str, "The account owner of the repository. The name is not case sensitive."], @@ -350,7 +350,7 @@ async def create_reply_for_review_comment( # Implements https://docs.github.com/en/rest/pulls/comments?apiVersion=2022-11-28#list-review-comments-on-a-pull-request # Example `arcade chat` usage: "list all of the review comments for PR 72 in /" -@tool(requires_auth=GitHubApp()) +@tool(requires_auth=GitHub()) async def list_review_comments_on_pull_request( context: ToolContext, owner: Annotated[str, "The account owner of the repository. The name is not case sensitive."], @@ -437,7 +437,7 @@ async def list_review_comments_on_pull_request( # Implements https://docs.github.com/en/rest/pulls/comments?apiVersion=2022-11-28#create-a-review-comment-for-a-pull-request # Example `arcade chat` usage: "create a review comment for PR 72 in / that says 'Great stuff! This looks good to merge. Add the comment to README.md file.'" # TODO: Verify that path parameter exists in the PR's files that have changed (Or should we allow for any file in the repo?). If not, then throw RetryableToolError with all valid file paths. -@tool(requires_auth=GitHubApp()) +@tool(requires_auth=GitHub()) async def create_review_comment( context: ToolContext, owner: Annotated[str, "The account owner of the repository. The name is not case sensitive."], diff --git a/toolkits/github/arcade_github/tools/repositories.py b/toolkits/github/arcade_github/tools/repositories.py index df8ecda3..3d7b18a8 100644 --- a/toolkits/github/arcade_github/tools/repositories.py +++ b/toolkits/github/arcade_github/tools/repositories.py @@ -5,7 +5,7 @@ import httpx from arcade.core.schema import ToolContext from arcade.sdk import tool -from arcade.sdk.auth import GitHubApp +from arcade.sdk.auth import GitHub from arcade_github.tools.models import ( ActivityType, RepoSortProperty, @@ -24,7 +24,7 @@ from arcade_github.tools.utils import ( # Implements https://docs.github.com/en/rest/repos/repos?apiVersion=2022-11-28#get-a-repository and returns only the stargazers_count field. # Example arcade chat usage: "How many stargazers does the / repo have?" -@tool(requires_auth=GitHubApp()) +@tool(requires_auth=GitHub()) async def count_stargazers( owner: Annotated[str, "The owner of the repository"], name: Annotated[str, "The name of the repository"], @@ -49,7 +49,7 @@ async def count_stargazers( # Implements https://docs.github.com/en/rest/repos/repos?apiVersion=2022-11-28#list-organization-repositories # Example arcade chat usage: "List all repositories for the organization. Sort by creation date in descending order." -@tool(requires_auth=GitHubApp()) +@tool(requires_auth=GitHub()) async def list_org_repositories( context: ToolContext, org: Annotated[str, "The organization name. The name is not case sensitive"], @@ -111,7 +111,7 @@ async def list_org_repositories( # Implements https://docs.github.com/en/rest/repos/repos?apiVersion=2022-11-28#get-a-repository # Example arcade chat usage: "Tell me about the / repo." -@tool(requires_auth=GitHubApp()) +@tool(requires_auth=GitHub()) async def get_repository( context: ToolContext, owner: Annotated[str, "The account owner of the repository. The name is not case sensitive."], @@ -166,7 +166,7 @@ async def get_repository( # Implements https://docs.github.com/en/rest/repos/repos?apiVersion=2022-11-28#list-repository-activities # Example arcade chat usage: "List all merges into main for the / repo in the last week by " -@tool(requires_auth=GitHubApp()) +@tool(requires_auth=GitHub()) async def list_repository_activities( context: ToolContext, owner: Annotated[str, "The account owner of the repository. The name is not case sensitive."], @@ -260,7 +260,7 @@ async def list_repository_activities( # Implements https://docs.github.com/en/rest/pulls/comments?apiVersion=2022-11-28#list-review-comments-in-a-repository # Example arcade chat usage: "List all review comments for the / repo. Sort by update date in descending order." # TODO: Improve the 'since' input parameter such that language model can more easily specify a valid date/time. -@tool(requires_auth=GitHubApp()) +@tool(requires_auth=GitHub()) async def list_review_comments_in_a_repository( context: ToolContext, owner: Annotated[str, "The account owner of the repository. The name is not case sensitive."], diff --git a/toolkits/github/pyproject.toml b/toolkits/github/pyproject.toml index a7fe4f7d..b355cd3e 100644 --- a/toolkits/github/pyproject.toml +++ b/toolkits/github/pyproject.toml @@ -2,7 +2,7 @@ name = "arcade_github" version = "0.1.0" description = "LLM tools for interacting with Github" -authors = ["Eric Gustin "] +authors = ["Arcade AI ", "Eric Gustin "] +authors = ["Arcade AI httpx.Response: + """ + Send an asynchronous request to the LinkedIn API. + + Args: + context: The tool context containing the authorization token. + method: The HTTP method (GET, POST, PUT, DELETE, etc.). + endpoint: The API endpoint path (e.g., "/ugcPosts"). + params: Query parameters to include in the request. + json_data: JSON data to include in the request body. + + Returns: + The response object from the API request. + + Raises: + ToolExecutionError: If the request fails for any reason. + """ + url = f"{LINKEDIN_BASE_URL}{endpoint}" + headers = {"Authorization": f"Bearer {context.authorization.token}"} + + async with httpx.AsyncClient() as client: + try: + response = await client.request( + method, url, headers=headers, params=params, json=json_data + ) + response.raise_for_status() + except httpx.RequestError as e: + raise ToolExecutionError(f"Failed to send request to LinkedIn API: {e}") + + return response + + +def _handle_linkedin_api_error(response: httpx.Response): + """ + Handle errors from the LinkedIn API by mapping common status codes to ToolExecutionErrors. + + Args: + response: The response object from the API request. + + Raises: + ToolExecutionError: If the response contains an error status code. + """ + status_code_map = { + 401: ToolExecutionError("Unauthorized: Invalid or expired token"), + 403: ToolExecutionError("Forbidden: User does not have Spotify Premium"), + 429: ToolExecutionError("Too Many Requests: Rate limit exceeded"), + } + + if response.status_code in status_code_map: + raise status_code_map[response.status_code] + elif response.status_code >= 400: + raise ToolExecutionError(f"Error: {response.status_code} - {response.text}") + + +@tool( + requires_auth=LinkedIn( + scopes=["w_member_social"], + ) +) +async def create_text_post( + context: ToolContext, + text: Annotated[str, "The text content of the post"], +) -> Annotated[str, "URL of the shared post"]: + """Share a new text post to LinkedIn.""" + endpoint = "/ugcPosts" + + # The LinkedIn user ID is required to create a post, even though we're using the user's access token. + # Arcade Engine gets the current user's info from LinkedIn and automatically populates context.authorization.user_info. + # LinkedIn calls the user ID "sub" in their user_info data payload. See: + # https://learn.microsoft.com/en-us/linkedin/consumer/integrations/self-serve/sign-in-with-linkedin-v2#api-request-to-retreive-member-details + user_id = context.authorization.user_info.get("sub") + if not user_id: + raise ToolExecutionError( + "User ID not found.", + developer_message="User ID not found in `context.authorization.user_info.sub`", + ) + + author_id = f"urn:li:person:{user_id}" + payload = { + "author": author_id, + "lifecycleState": "PUBLISHED", + "specificContent": { + "com.linkedin.ugc.ShareContent": { + "shareCommentary": {"text": text}, + "shareMediaCategory": "NONE", + } + }, + "visibility": {"com.linkedin.ugc.MemberNetworkVisibility": "PUBLIC"}, + } + + response = await _send_linkedin_request(context, "POST", endpoint, json=payload) + if response.status_code >= 200 and response.status_code < 300: + share_id = response.json().get("id") + return f"https://www.linkedin.com/feed/update/{share_id}/" + else: + _handle_linkedin_api_error(response) diff --git a/toolkits/linkedin/pyproject.toml b/toolkits/linkedin/pyproject.toml new file mode 100644 index 00000000..8c6177f9 --- /dev/null +++ b/toolkits/linkedin/pyproject.toml @@ -0,0 +1,17 @@ +[tool.poetry] +name = "arcade_linkedin" +version = "0.1.0" +description = "Arcade tools for LinkedIn" +authors = ["Arcade AI =1.0.0"] +build-backend = "poetry.core.masonry.api" diff --git a/toolkits/math/pyproject.toml b/toolkits/math/pyproject.toml index 4012d6dd..19a39089 100644 --- a/toolkits/math/pyproject.toml +++ b/toolkits/math/pyproject.toml @@ -2,7 +2,7 @@ name = "arcade_math" version = "0.1.0" description = "Math toolkit for Arcade" -authors = ["Nate "] +authors = ["Arcade AI "] +authors = ["Arcade AI str: @tool( - requires_auth=SlackUser( + requires_auth=Slack( scopes=[ "chat:write", "channels:read", diff --git a/toolkits/slack/pyproject.toml b/toolkits/slack/pyproject.toml index 2fdf7f22..6558aee8 100644 --- a/toolkits/slack/pyproject.toml +++ b/toolkits/slack/pyproject.toml @@ -2,7 +2,7 @@ name = "arcade_slack" version = "0.1.0" description = "Slack tools for LLMs" -authors = ["Nate Barbettini "] +authors = ["Arcade AI httpx.Response: + """ + Send an asynchronous request to the Spotify API. + + Args: + context: The tool context containing the authorization token. + method: The HTTP method (GET, POST, PUT, DELETE, etc.). + endpoint: The API endpoint path (e.g., "/me/player/play"). + params: Query parameters to include in the request. + json_data: JSON data to include in the request body. + + Returns: + The response object from the API request. + + Raises: + ToolExecutionError: If the request fails for any reason. + """ + url = f"{SPOTIFY_BASE_URL}{endpoint}" + headers = {"Authorization": f"Bearer {context.authorization.token}"} + + async with httpx.AsyncClient() as client: + try: + response = await client.request( + method, url, headers=headers, params=params, json=json_data + ) + response.raise_for_status() + except httpx.RequestError as e: + raise ToolExecutionError(f"Failed to send request to Spotify API: {e}") + + return response + + +def _handle_spotify_api_error(response: httpx.Response): + """ + Handle errors from the Spotify API by mapping common status codes to ToolExecutionErrors. + + Args: + response: The response object from the API request. + + Raises: + ToolExecutionError: If the response contains an error status code. + """ + status_code_map = { + 401: ToolExecutionError("Unauthorized: Invalid or expired token"), + 403: ToolExecutionError("Forbidden: User does not have Spotify Premium"), + 429: ToolExecutionError("Too Many Requests: Rate limit exceeded"), + } + + if response.status_code in status_code_map: + raise status_code_map[response.status_code] + elif response.status_code >= 400: + raise ToolExecutionError(f"Error: {response.status_code} - {response.text}") + + +@tool( + requires_auth=Spotify( + scopes=["user-modify-playback-state"], + ) +) +async def pause( + context: ToolContext, + device_id: Annotated[ + Optional[str], + "The id of the device this command is targeting. If omitted, the active device is targeted.", + ] = None, +) -> Annotated[str, "Success string confirming the pause"]: + """Pause the current track""" + endpoint = "/me/player/pause" + params = {"device_id": device_id} if device_id else {} + + response = await _send_spotify_request(context, "PUT", endpoint, params=params) + if response.status_code >= 200 and response.status_code < 300: + return "Playback paused" + else: + _handle_spotify_api_error(response) + + +@tool( + requires_auth=Spotify( + scopes=["user-modify-playback-state"], + ) +) +async def resume( + context: ToolContext, + device_id: Annotated[ + Optional[str], + "The id of the device this command is targeting. If omitted, the active device is targeted.", + ] = None, +) -> Annotated[str, "Success string confirming the playback resume"]: + """Resume the current track, if any""" + endpoint = "/me/player/play" + params = {"device_id": device_id} if device_id else {} + + response = await _send_spotify_request(context, "PUT", endpoint, params=params) + if response.status_code >= 200 and response.status_code < 300: + return "Playback resumed" + else: + _handle_spotify_api_error(response) + + +@tool( + requires_auth=Spotify( + scopes=["user-read-playback-state"], + ) +) +async def get_playback_state( + context: ToolContext, +) -> Annotated[dict, "Information about the user's current playback state"]: + """Get information about the user's current playback state, including track or episode, progress, and active device.""" + endpoint = "/me/player" + + response = await _send_spotify_request(context, "GET", endpoint) + if response.status_code == 204: + return {"status": "Playback not available or active"} + elif response.status_code == 200: + data = response.json() + + # TODO: Return a more structured model + result = { + "device_name": data.get("device", {}).get("name"), + "currently_playing_type": data.get("currently_playing_type"), + } + + if data.get("currently_playing_type") == "track": + item = data.get("item", {}) + album = item.get("album", {}) + result.update({ + "album_name": album.get("name"), + "album_artists": [artist.get("name") for artist in album.get("artists", [])], + "album_spotify_url": album.get("external_urls", {}).get("spotify"), + "track_name": item.get("name"), + "track_artists": [artist.get("name") for artist in item.get("artists", [])], + }) + elif data.get("currently_playing_type") == "episode": + item = data.get("item", {}) + show = item.get("show", {}) + result.update({ + "show_name": show.get("name"), + "show_spotify_url": show.get("external_urls", {}).get("spotify"), + "episode_name": item.get("name"), + "episode_spotify_url": item.get("external_urls", {}).get("spotify"), + }) + return result + else: + _handle_spotify_api_error(response) diff --git a/toolkits/spotify/pyproject.toml b/toolkits/spotify/pyproject.toml new file mode 100644 index 00000000..9ec56366 --- /dev/null +++ b/toolkits/spotify/pyproject.toml @@ -0,0 +1,17 @@ +[tool.poetry] +name = "arcade_spotify" +version = "0.1.0" +description = "Arcade tools for Spotify" +authors = ["Arcade AI =1.0.0"] +build-backend = "poetry.core.masonry.api" diff --git a/toolkits/x/arcade_x/tools/tweets.py b/toolkits/x/arcade_x/tools/tweets.py index 7e113eff..ee80f378 100644 --- a/toolkits/x/arcade_x/tools/tweets.py +++ b/toolkits/x/arcade_x/tools/tweets.py @@ -12,7 +12,11 @@ TWEETS_URL = "https://api.x.com/2/tweets" # Manage Tweets Tools. See developer docs for additional available parameters: https://developer.x.com/en/docs/x-api/tweets/manage-tweets/api-reference -@tool(requires_auth=X(scopes=["tweet.read", "tweet.write", "users.read"])) +@tool( + requires_auth=X( + scopes=["tweet.read", "tweet.write", "users.read"], + ) +) async def post_tweet( context: ToolContext, tweet_text: Annotated[str, "The text content of the tweet you want to post"], diff --git a/toolkits/x/pyproject.toml b/toolkits/x/pyproject.toml index 0451f9e4..5a0a49aa 100644 --- a/toolkits/x/pyproject.toml +++ b/toolkits/x/pyproject.toml @@ -2,7 +2,7 @@ name = "arcade_x" version = "0.1.0" description = "LLM tools for interacting with X (Twitter)" -authors = ["Eric Gustin "] +authors = ["Arcade AI httpx.Response: + """ + Send an asynchronous request to the Zoom API. + + Args: + context: The tool context containing the authorization token. + method: The HTTP method (GET, POST, PUT, DELETE, etc.). + endpoint: The API endpoint path (e.g., "/users/me/upcoming_meetings"). + params: Query parameters to include in the request. + json_data: JSON data to include in the request body. + + Returns: + The response object from the API request. + + Raises: + ToolExecutionError: If the request fails for any reason. + """ + url = f"{ZOOM_BASE_URL}{endpoint}" + headers = {"Authorization": f"Bearer {context.authorization.token}"} + + async with httpx.AsyncClient() as client: + try: + response = await client.request( + method, url, headers=headers, params=params, json=json_data + ) + response.raise_for_status() + except httpx.RequestError as e: + raise ToolExecutionError(f"Failed to send request to Zoom API: {e}") + + return response + + +def _handle_zoom_api_error(response: httpx.Response): + """ + Handle errors from the Zoom API by mapping common status codes to ToolExecutionErrors. + + Args: + response: The response object from the API request. + + Raises: + ToolExecutionError: If the response contains an error status code. + """ + status_code_map = { + 401: ToolExecutionError("Unauthorized: Invalid or expired token"), + 403: ToolExecutionError("Forbidden: Access denied"), + 429: ToolExecutionError("Too Many Requests: Rate limit exceeded"), + } + + if response.status_code in status_code_map: + raise status_code_map[response.status_code] + elif response.status_code >= 400: + raise ToolExecutionError(f"Error: {response.status_code} - {response.text}") + + +@tool( + requires_auth=Zoom( + scopes=["meeting:read:list_upcoming_meetings"], + ) +) +async def list_upcoming_meetings( + context: ToolContext, + user_id: Annotated[ + Optional[str], + "The user's user ID or email address. Defaults to 'me' for the current user.", + ] = "me", +) -> Annotated[dict, "List of upcoming meetings within the next 24 hours"]: + """List a Zoom user's upcoming meetings within the next 24 hours.""" + endpoint = f"/users/{user_id}/upcoming_meetings" + + response = await _send_zoom_request(context, "GET", endpoint) + if response.status_code >= 200 and response.status_code < 300: + return response.json() + else: + _handle_zoom_api_error(response) + + +@tool( + requires_auth=Zoom( + scopes=["meeting:read:invitation"], + ) +) +async def get_meeting_invitation( + context: ToolContext, + meeting_id: Annotated[ + str, + "The meeting's numeric ID (as a string).", + ], +) -> Annotated[dict, "Meeting invitation string"]: + """Retrieve the invitation note for a specific Zoom meeting.""" + endpoint = f"/meetings/{meeting_id}/invitation" + + response = await _send_zoom_request(context, "GET", endpoint) + if response.status_code >= 200 and response.status_code < 300: + return response.json() + else: + _handle_zoom_api_error(response) diff --git a/toolkits/zoom/pyproject.toml b/toolkits/zoom/pyproject.toml new file mode 100644 index 00000000..d9d7d206 --- /dev/null +++ b/toolkits/zoom/pyproject.toml @@ -0,0 +1,17 @@ +[tool.poetry] +name = "arcade_zoom" +version = "0.1.0" +description = "Arcade tools for Zoom" +authors = ["Arcade AI =1.0.0"] +build-backend = "poetry.core.masonry.api"