SDK: Generic OAuth 2.0 connector (#81)

- Implements https://app.clickup.com/t/86b1whxb3 on the SDK side
- - Corresponding Engine PR:
https://github.com/ArcadeAI/Engine/pull/113/files?w=1
- Updates existing toolkits with new syntax.
This commit is contained in:
Nate Barbettini 2024-10-03 16:40:02 -07:00 committed by GitHub
parent 28ce4d0dfc
commit 799d376ae5
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
37 changed files with 686 additions and 137 deletions

View file

@ -11,6 +11,7 @@ from arcade.client.base import (
from arcade.client.errors import APIStatusError, EngineNotHealthyError, EngineOfflineError
from arcade.client.schema import (
AuthProvider,
AuthProviderType,
AuthRequest,
AuthResponse,
ExecuteToolResponse,
@ -29,10 +30,10 @@ class AuthResource(BaseResource[ClientT]):
def authorize(
self,
provider: AuthProvider,
scopes: list[str],
user_id: str,
authority: str | None = None,
provider: AuthProvider | str,
provider_type: AuthProviderType = AuthProviderType.oauth2,
scopes: list[str] | None = None,
) -> AuthResponse:
"""
Initiate an authorization request.
@ -41,16 +42,14 @@ class AuthResource(BaseResource[ClientT]):
provider: The authorization provider.
scopes: The scopes required for the authorization.
user_id: The user ID initiating the authorization.
authority: The authority initiating the authorization.
"""
auth_provider = provider.value
auth_provider_type = provider_type.value
body = {
"auth_requirement": {
"provider": auth_provider,
auth_provider: AuthRequest(scopes=scopes, authority=authority).model_dump(
exclude_none=True
),
"provider_id": provider.value if isinstance(provider, AuthProvider) else provider,
"provider_type": auth_provider_type,
auth_provider_type: AuthRequest(scopes=scopes or []).model_dump(exclude_none=True),
},
"user_id": user_id,
}
@ -190,22 +189,21 @@ class AsyncAuthResource(BaseResource[AsyncArcadeClient]):
async def authorize(
self,
provider: AuthProvider,
scopes: list[str],
user_id: str,
authority: str | None = None,
provider: AuthProvider | str,
provider_type: AuthProviderType = AuthProviderType.oauth2,
scopes: list[str] | None = None,
) -> AuthResponse:
"""
Initiate an asynchronous authorization request.
"""
auth_provider = provider.value
auth_provider_type = provider_type.value
body = {
"auth_requirement": {
"provider": auth_provider,
auth_provider: AuthRequest(scopes=scopes, authority=authority).model_dump(
exclude_none=True
),
"provider_id": provider.value if isinstance(provider, AuthProvider) else provider,
"provider_type": auth_provider_type,
auth_provider_type: AuthRequest(scopes=scopes or []).model_dump(exclude_none=True),
},
"user_id": user_id,
}

View file

@ -1,7 +1,7 @@
import os
from enum import Enum
from pydantic import AnyUrl, BaseModel, Field
from pydantic import BaseModel, Field
from arcade.core.schema import ToolAuthorizationContext, ToolCallOutput
@ -9,19 +9,21 @@ OPENAI_API_VERSION = os.getenv("OPENAI_API_VERSION", "v1")
class AuthProvider(str, Enum):
"""The supported authorization providers."""
oauth2 = "oauth2"
"""OAuth 2.0 authorization"""
google = "google"
"""Google authorization"""
slack_user = "slack_user"
slack = "slack_user"
"""Slack (user token) authorization"""
github_app = "github_app"
"""GitHub App authorization"""
github = "github"
"""GitHub authorization"""
class AuthProviderType(str, Enum):
"""The supported authorization provider types."""
oauth2 = "oauth2"
"""OAuth 2.0 authorization"""
class AuthRequest(BaseModel):
@ -30,9 +32,6 @@ class AuthRequest(BaseModel):
# TODO (Nate): Make a validator here
"""
authority: AnyUrl | str | None = None
"""The URL of the OAuth 2.0 authorization server."""
scopes: list[str]
"""The scope(s) needed for authorization."""

View file

@ -46,7 +46,7 @@ from arcade.core.utils import (
snake_to_pascal_case,
)
from arcade.sdk.annotations import Inferrable
from arcade.sdk.auth import BaseOAuth2, ToolAuthorization
from arcade.sdk.auth import OAuth2, ToolAuthorization
InnerWireType = Literal["string", "integer", "number", "boolean", "json"]
WireType = Union[InnerWireType, Literal["array"]]
@ -257,9 +257,10 @@ class ToolCatalog(BaseModel):
auth_requirement = getattr(tool, "__tool_requires_auth__", None)
if isinstance(auth_requirement, ToolAuthorization):
new_auth_requirement = ToolAuthRequirement(
provider=auth_requirement.get_provider(),
provider_id=auth_requirement.provider_id,
provider_type=auth_requirement.provider_type,
)
if isinstance(auth_requirement, BaseOAuth2):
if isinstance(auth_requirement, OAuth2):
new_auth_requirement.oauth2 = OAuth2Requirement(**auth_requirement.model_dump())
auth_requirement = new_auth_requirement
@ -274,7 +275,7 @@ class ToolCatalog(BaseModel):
return ToolDefinition(
name=tool_name,
full_name=str(fully_qualified_name),
fully_qualified_name=str(fully_qualified_name),
description=tool_description,
toolkit=toolkit_definition,
inputs=create_input_definition(tool),

View file

@ -2,7 +2,7 @@ import os
from dataclasses import dataclass
from typing import Any, Literal, Optional, Union
from pydantic import AnyUrl, BaseModel, Field
from pydantic import BaseModel, Field
# allow for custom tool name separator
TOOL_NAME_SEPARATOR = os.getenv("ARCADE_TOOL_NAME_SEPARATOR", ".")
@ -73,9 +73,6 @@ class ToolOutput(BaseModel):
class OAuth2Requirement(BaseModel):
"""Indicates that the tool requires OAuth 2.0 authorization."""
authority: Optional[AnyUrl] = None
"""The URL of the OAuth 2.0 authorization server."""
scopes: Optional[list[str]] = None
"""The scope(s) needed for authorization, if any."""
@ -83,7 +80,19 @@ class OAuth2Requirement(BaseModel):
class ToolAuthRequirement(BaseModel):
"""A requirement for authorization to use a tool."""
provider: str
# Provider ID and Type needed for the Arcade Engine to look up the auth provider.
# However, the developer generally does not need to set these directly.
# Instead, they will use:
# @tool(requires_auth=Google(scopes=["profile", "email"]))
# or
# client.auth.authorize(provider=AuthProvider.google, scopes=["profile", "email"])
#
# The Arcade SDK translates these into the appropriate provider ID and type.
# The only time the developer will set these is if they are using a custom auth provider.
provider_id: Optional[str] = None
"""A unique provider ID."""
provider_type: str
"""The provider type."""
oauth2: Optional[OAuth2Requirement] = None
@ -161,7 +170,7 @@ class ToolDefinition(BaseModel):
name: str
"""The name of the tool."""
full_name: str
fully_qualified_name: str
"""The fully-qualified name of the tool."""
description: str
@ -205,6 +214,16 @@ class ToolAuthorizationContext(BaseModel):
token: str | None = None
"""The token for the tool invocation."""
user_info: dict = Field(default={})
"""
The user information provided by the authorization server (if any).
Some providers can provide structured user info,
for example an internal provider-specific user ID.
For those providers that support retrieving user info,
the Engine can automatically pass that to tool invocations.
"""
class ToolContext(BaseModel):
"""The context for a tool invocation."""
@ -212,6 +231,9 @@ class ToolContext(BaseModel):
authorization: ToolAuthorizationContext | None = None
"""The authorization context for the tool invocation that requires authorization."""
user_id: str | None = None
"""The user ID for the tool invocation (if any)."""
class ToolCallRequest(BaseModel):
"""The request to call (invoke) a tool."""

View file

@ -1,60 +1,71 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional
from pydantic import AnyUrl, BaseModel
from pydantic import BaseModel, ConfigDict
class ToolAuthorization(BaseModel, ABC):
class AuthProviderType(str, Enum):
oauth2 = "oauth2"
class ToolAuthorization(BaseModel):
"""Marks a tool as requiring authorization."""
@abstractmethod
def get_provider(self) -> str:
"""Return the name of the authorization method."""
pass
model_config = ConfigDict(frozen=True)
pass
provider_id: str
"""The unique provider ID configured in Arcade."""
provider_type: AuthProviderType
"""The type of the authorization provider."""
class BaseOAuth2(ToolAuthorization):
"""Base class for any provider supporting OAuth 2.0-like authorization."""
class OAuth2(ToolAuthorization):
"""Marks a tool as requiring OAuth 2.0 authorization."""
authority: Optional[AnyUrl] = None
"""The URL of the OAuth 2.0 authorization server."""
provider_type: AuthProviderType = AuthProviderType.oauth2
scopes: Optional[list[str]] = None
"""The scope(s) needed for the authorized action."""
class OAuth2(BaseOAuth2):
"""Marks a tool as requiring OAuth 2.0 authorization."""
def get_provider(self) -> str:
return "oauth2"
class Google(BaseOAuth2):
class Google(OAuth2):
"""Marks a tool as requiring Google authorization."""
def get_provider(self) -> str:
return "google"
provider_id: str = "google"
class SlackUser(BaseOAuth2):
class Slack(OAuth2):
"""Marks a tool as requiring Slack (user token) authorization."""
def get_provider(self) -> str:
return "slack_user"
provider_id: str = "slack"
class GitHubApp(ToolAuthorization):
class GitHub(OAuth2):
"""Marks a tool as requiring GitHub App authorization."""
def get_provider(self) -> str:
return "github_app"
provider_id: str = "github"
class X(BaseOAuth2):
class X(OAuth2):
"""Marks a tool as requiring X (Twitter) authorization."""
def get_provider(self) -> str:
return "x"
provider_id: str = "x"
class LinkedIn(OAuth2):
"""Marks a tool as requiring LinkedIn authorization."""
provider_id: str = "linkedin"
class Spotify(OAuth2):
"""Marks a tool as requiring Spotify authorization."""
provider_id: str = "spotify"
class Zoom(OAuth2):
"""Marks a tool as requiring Zoom authorization."""
provider_id: str = "zoom"

View file

@ -5,7 +5,7 @@ description = ""
packages = [
{include="arcade", from="."}
]
authors = ["Arcade AI <sam@arcade-ai.com>"]
authors = ["Arcade AI <dev@arcade-ai.com>"]
[build-system]
requires = ["poetry-core>=1.0.0"]

View file

@ -12,7 +12,7 @@ from arcade.client.errors import (
PermissionDeniedError,
UnauthorizedError,
)
from arcade.client.schema import AuthResponse, ExecuteToolResponse
from arcade.client.schema import AuthProviderType, AuthResponse, ExecuteToolResponse
from arcade.core.schema import ToolDefinition
AUTH_RESPONSE_DATA = {
@ -23,6 +23,14 @@ AUTH_RESPONSE_DATA = {
"scopes": ["https://www.googleapis.com/auth/gmail.readonly"],
}
AUTH_RESPONSE_DATA_NO_SCOPES = {
"auth_id": "auth_123",
"authorization_url": "https://example.com/auth",
"status": "pending",
"authorization_id": "auth_123",
"scopes": [],
}
TOOL_RESPONSE_DATA = {
"tool_name": "GetEmails",
"tool_version": "0.1.0",
@ -36,7 +44,7 @@ TOOL_RESPONSE_DATA = {
TOOL_DEFINITION_DATA = {
"name": "GetEmails",
"full_name": "TestToolkit.GetEmails",
"fully_qualified_name": "TestToolkit.GetEmails",
"description": "Retrieve emails from a user's inbox",
"toolkit": {
"name": "TestToolkit",
@ -130,6 +138,30 @@ def test_arcade_auth_authorize(test_sync_client, mock_response, monkeypatch):
assert auth_response == AuthResponse(**AUTH_RESPONSE_DATA)
def test_arcade_auth_authorize_with_provider_type(test_sync_client, mock_response, monkeypatch):
"""Test Arcade.auth.authorize method."""
monkeypatch.setattr(Arcade, "_execute_request", lambda *args, **kwargs: AUTH_RESPONSE_DATA)
auth_response = test_sync_client.auth.authorize(
provider="hooli",
provider_type=AuthProviderType.oauth2,
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
user_id="sam@arcade-ai.com",
)
assert auth_response == AuthResponse(**AUTH_RESPONSE_DATA)
def test_arcade_auth_authorize_with_no_scopes(test_sync_client, mock_response, monkeypatch):
"""Test Arcade.auth.authorize method."""
monkeypatch.setattr(
Arcade, "_execute_request", lambda *args, **kwargs: AUTH_RESPONSE_DATA_NO_SCOPES
)
auth_response = test_sync_client.auth.authorize(
provider=AuthProvider.google,
user_id="sam@arcade-ai.com",
)
assert auth_response == AuthResponse(**AUTH_RESPONSE_DATA_NO_SCOPES)
def test_arcade_auth_poll_authorization(test_sync_client, mock_response, monkeypatch):
"""Test Arcade.auth.poll_authorization method."""
monkeypatch.setattr(Arcade, "_execute_request", lambda *args, **kwargs: AUTH_RESPONSE_DATA)
@ -201,6 +233,42 @@ async def test_async_arcade_auth_authorize(test_async_client, mock_async_respons
assert auth_response == AuthResponse(**AUTH_RESPONSE_DATA)
@pytest.mark.asyncio
async def test_async_arcade_auth_authorize_with_provider_type(
test_async_client, mock_async_response, monkeypatch
):
"""Test AsyncArcade.auth.authorize method."""
async def mock_execute_request(*args, **kwargs):
return AUTH_RESPONSE_DATA
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
auth_response = await test_async_client.auth.authorize(
provider="hooli",
provider_type=AuthProviderType.oauth2,
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
user_id="sam@arcade-ai.com",
)
assert auth_response == AuthResponse(**AUTH_RESPONSE_DATA)
@pytest.mark.asyncio
async def test_async_arcade_auth_authorize_with_no_scopes(
test_async_client, mock_async_response, monkeypatch
):
"""Test AsyncArcade.auth.authorize method."""
async def mock_execute_request(*args, **kwargs):
return AUTH_RESPONSE_DATA_NO_SCOPES
monkeypatch.setattr(AsyncArcade, "_execute_request", mock_execute_request)
auth_response = await test_async_client.auth.authorize(
provider=AuthProvider.google,
user_id="sam@arcade-ai.com",
)
assert auth_response == AuthResponse(**AUTH_RESPONSE_DATA_NO_SCOPES)
@pytest.mark.asyncio
async def test_async_arcade_auth_poll_authorization(
test_async_client, mock_async_response, monkeypatch

View file

@ -39,7 +39,7 @@ def test_tool_decorator_with_all_options():
name="TestTool",
desc="Test description",
requires_auth=OAuth2(
authority="https://example.com/oauth2/auth",
provider_id="example",
scopes=["test_scope", "another.scope"],
),
)
@ -48,5 +48,4 @@ def test_tool_decorator_with_all_options():
assert test_tool.__tool_name__ == "TestTool"
assert test_tool.__tool_description__ == "Test description"
assert str(test_tool.__tool_requires_auth__.authority) == "https://example.com/oauth2/auth"
assert test_tool.__tool_requires_auth__.scopes == ["test_scope", "another.scope"]

View file

@ -17,7 +17,7 @@ from arcade.core.schema import (
from arcade.core.utils import snake_to_pascal_case
from arcade.sdk import tool
from arcade.sdk.annotations import Inferrable
from arcade.sdk.auth import GitHubApp, Google, OAuth2, SlackUser, X
from arcade.sdk.auth import GitHub, Google, OAuth2, Slack, X
### Tests on @tool decorator
@ -48,7 +48,10 @@ def func_with_name_and_description():
@tool(
desc="A function that requires authentication",
requires_auth=OAuth2(authority="https://example.com/oauth2/auth", scopes=["scope1", "scope2"]),
requires_auth=OAuth2(
provider_id="example",
scopes=["scope1", "scope2"],
),
)
def func_with_auth_requirement():
pass
@ -64,7 +67,7 @@ def func_with_google_auth_requirement():
@tool(
desc="A function that requires GitHub authorization",
requires_auth=GitHubApp(),
requires_auth=GitHub(),
)
def func_with_github_auth_requirement():
pass
@ -72,7 +75,7 @@ def func_with_github_auth_requirement():
@tool(
desc="A function that requires Slack user authorization",
requires_auth=SlackUser(scopes=["chat:write", "channels:history"]),
requires_auth=Slack(scopes=["chat:write", "channels:history"]),
)
def func_with_slack_user_auth_requirement():
pass
@ -239,7 +242,7 @@ def func_with_complex_return() -> dict[str, str]:
func_with_name_and_description,
{
"name": "MyCustomTool",
"full_name": "TestToolkit.MyCustomTool",
"fully_qualified_name": "TestToolkit.MyCustomTool",
"description": "A function with a very cool description",
},
id="func_with_description_and_name",
@ -254,7 +257,8 @@ def func_with_complex_return() -> dict[str, str]:
{
"requirements": ToolRequirements(
authorization=ToolAuthRequirement(
provider="oauth2",
provider_id="example",
provider_type="oauth2",
oauth2=OAuth2Requirement(
authority="https://example.com/oauth2/auth",
scopes=["scope1", "scope2"],
@ -269,7 +273,8 @@ def func_with_complex_return() -> dict[str, str]:
{
"requirements": ToolRequirements(
authorization=ToolAuthRequirement(
provider="google",
provider_id="google",
provider_type="oauth2",
oauth2=OAuth2Requirement(
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
),
@ -283,7 +288,7 @@ def func_with_complex_return() -> dict[str, str]:
{
"requirements": ToolRequirements(
authorization=ToolAuthRequirement(
provider="github_app",
provider_id="github", provider_type="oauth2", oauth2=OAuth2Requirement()
)
)
},
@ -294,7 +299,8 @@ def func_with_complex_return() -> dict[str, str]:
{
"requirements": ToolRequirements(
authorization=ToolAuthRequirement(
provider="slack_user",
provider_id="slack",
provider_type="oauth2",
oauth2=OAuth2Requirement(
scopes=["chat:write", "channels:history"],
),
@ -683,7 +689,7 @@ def test_tool_name_is_set_correctly():
tool_def = ToolCatalog.create_tool_definition(func_with_description, "test_toolkit", "1.0.0")
assert tool_def.name == snake_to_pascal_case(func_with_description.__name__)
assert tool_def.full_name == "TestToolkit.FuncWithDescription"
assert tool_def.fully_qualified_name == "TestToolkit.FuncWithDescription"
@pytest.mark.parametrize(

View file

@ -6,6 +6,7 @@ dictionaries: []
words:
- conlist
- fastapi
- httpx
- openai
- pydantic
- pyproject

View file

@ -2,7 +2,7 @@
name = "arcade_example_fastapi"
version = "0.1.0"
description = "FastAPI example app with Arcade"
authors = ["Nate Barbettini <nate@arcade-ai.com>"]
authors = ["Arcade AI <dev@arcade-ai.com"]
[tool.poetry.dependencies]
python = "^3.10"

View file

@ -54,20 +54,13 @@
"additionalProperties": false
}
},
"user": {
"user_id": {
"type": "string",
"description": "A unique ID that identifies the user (if any)"
},
"user_info": {
"type": "object",
"properties": {
"id": {
"type": "string",
"description": "A unique ID that identifies the user"
},
"name": {
"type": "string",
"description": "The name of the user"
}
},
"required": ["id"],
"additionalProperties": false
"description": "The user information provided by the authorization server (if any)"
}
}
}

View file

@ -50,6 +50,10 @@
"type": "string",
"description": "The tool name"
},
"fully_qualified_name": {
"type": "string",
"description": "The tool's fully-qualified name"
},
"description": {
"type": "string",
"description": "A human-readable description of the tool and when to use it"
@ -147,17 +151,18 @@
{
"type": "object",
"properties": {
"provider": {
"type": "string"
"provider_id": {
"type": "string",
"description": "A unique provider ID."
},
"provider_type": {
"type": "string",
"description": "The provider type."
},
"oauth2": {
"type": "object",
"properties": {
"authority": {
"type": "string",
"format": "uri"
},
"scope": {
"scopes": {
"type": "array",
"items": {
"type": "string"
@ -167,7 +172,7 @@
"additionalProperties": false
}
},
"required": ["provider"],
"required": ["provider_id", "provider_type"],
"additionalProperties": false
}
]
@ -176,6 +181,6 @@
"additionalProperties": false
}
},
"required": ["name", "toolkit", "inputs", "output"],
"required": ["name", "fully_qualified_name", "toolkit", "inputs", "output"],
"additionalProperties": false
}

View file

@ -4,13 +4,13 @@ import httpx
from arcade.core.schema import ToolContext
from arcade.sdk import tool
from arcade.sdk.auth import GitHubApp
from arcade.sdk.auth import GitHub
from arcade_github.tools.utils import get_github_json_headers, get_url, handle_github_response
# Implements https://docs.github.com/en/rest/activity/starring?apiVersion=2022-11-28#star-a-repository-for-the-authenticated-user and https://docs.github.com/en/rest/activity/starring?apiVersion=2022-11-28#unstar-a-repository-for-the-authenticated-user
# Example `arcade chat` usage: "star the vscode repo owned by microsoft"
@tool(requires_auth=GitHubApp())
@tool(requires_auth=GitHub())
async def set_starred(
context: ToolContext,
owner: Annotated[str, "The owner of the repository"],

View file

@ -5,7 +5,7 @@ import httpx
from arcade.core.schema import ToolContext
from arcade.sdk import tool
from arcade.sdk.auth import GitHubApp
from arcade.sdk.auth import GitHub
from arcade_github.tools.utils import (
get_github_json_headers,
get_url,
@ -16,7 +16,7 @@ from arcade_github.tools.utils import (
# Implements https://docs.github.com/en/rest/issues/issues?apiVersion=2022-11-28#create-an-issue
# Example `arcade chat` usage: "create an issue in the <REPO> repo owned by <OWNER> titled 'Found a bug' with the body 'I'm having a problem with this.' Assign it to <USER> and label it 'bug'"
@tool(requires_auth=GitHubApp())
@tool(requires_auth=GitHub())
async def create_issue(
context: ToolContext,
owner: Annotated[str, "The account owner of the repository. The name is not case sensitive."],
@ -85,7 +85,7 @@ async def create_issue(
# Implements https://docs.github.com/en/rest/issues/comments?apiVersion=2022-11-28#create-an-issue-comment
# Example `arcade chat` usage: "create a comment in the vscode repo owned by microsoft for issue 1347 that says 'Me too'"
@tool(requires_auth=GitHubApp())
@tool(requires_auth=GitHub())
async def create_issue_comment(
context: ToolContext,
owner: Annotated[str, "The account owner of the repository. The name is not case sensitive."],

View file

@ -6,7 +6,7 @@ import httpx
from arcade.core.errors import RetryableToolError
from arcade.core.schema import ToolContext
from arcade.sdk import tool
from arcade.sdk.auth import GitHubApp
from arcade.sdk.auth import GitHub
from arcade_github.tools.models import (
DiffSide,
PRSortProperty,
@ -28,7 +28,7 @@ from arcade_github.tools.utils import (
# Example `arcade chat` usage: "get all open PRs that <USER> has that are in the <OWNER>/<REPO> repo"
# TODO: Validate owner/repo combination is valid for the authenticated user. If not, return RetryableToolError with available repos.
# TODO: list repo's branches and validate base is in the list (or default to main). If not, return RetryableToolError with available branches.
@tool(requires_auth=GitHubApp())
@tool(requires_auth=GitHub())
async def list_pull_requests(
context: ToolContext,
owner: Annotated[str, "The account owner of the repository. The name is not case sensitive."],
@ -103,7 +103,7 @@ async def list_pull_requests(
# Implements https://docs.github.com/en/rest/pulls/pulls?apiVersion=2022-11-28#get-a-pull-request
# Example `arcade chat` usage: "get the PR #72 in the <OWNER>/<REPO> repo. Include diff content in your response."
@tool(requires_auth=GitHubApp())
@tool(requires_auth=GitHub())
async def get_pull_request(
context: ToolContext,
owner: Annotated[str, "The account owner of the repository. The name is not case sensitive."],
@ -177,7 +177,7 @@ async def get_pull_request(
# Implements https://docs.github.com/en/rest/pulls/pulls?apiVersion=2022-11-28#update-a-pull-request
# Example `arcade chat` usage: "update PR #72 in the <OWNER>/<REPO> repo by changing the title to 'New Title' and setting the body to 'This PR description was added via arcade chat!'."
# TODO: Enable this tool to append to the PR contents instead of only replacing content.
@tool(requires_auth=GitHubApp())
@tool(requires_auth=GitHub())
async def update_pull_request(
context: ToolContext,
owner: Annotated[str, "The account owner of the repository. The name is not case sensitive."],
@ -242,7 +242,7 @@ async def update_pull_request(
# Implements https://docs.github.com/en/rest/pulls/commits?apiVersion=2022-11-28#list-commits-on-a-pull-request
# Example `arcade chat` usage: "list all of the commits for the PR 72 in the <OWNER>/<REPO> repo"
@tool(requires_auth=GitHubApp())
@tool(requires_auth=GitHub())
async def list_pull_request_commits(
context: ToolContext,
owner: Annotated[str, "The account owner of the repository. The name is not case sensitive."],
@ -308,7 +308,7 @@ async def list_pull_request_commits(
# Example `arcade chat` usage: "create a reply to the review comment 1778019974 in arcadeai/arcade-ai for the PR 72 that says 'Thanks for the suggestion.'"
# Note: This tool requires the ID of the review comment to reply to. To obtain this ID, you should first call the `list_review_comments_on_pull_request` function.
# The returned JSON will contain the `id` field for each comment, which can be used as the `comment_id` parameter in this function.
@tool(requires_auth=GitHubApp())
@tool(requires_auth=GitHub())
async def create_reply_for_review_comment(
context: ToolContext,
owner: Annotated[str, "The account owner of the repository. The name is not case sensitive."],
@ -350,7 +350,7 @@ async def create_reply_for_review_comment(
# Implements https://docs.github.com/en/rest/pulls/comments?apiVersion=2022-11-28#list-review-comments-on-a-pull-request
# Example `arcade chat` usage: "list all of the review comments for PR 72 in <OWNER>/<REPO>"
@tool(requires_auth=GitHubApp())
@tool(requires_auth=GitHub())
async def list_review_comments_on_pull_request(
context: ToolContext,
owner: Annotated[str, "The account owner of the repository. The name is not case sensitive."],
@ -437,7 +437,7 @@ async def list_review_comments_on_pull_request(
# Implements https://docs.github.com/en/rest/pulls/comments?apiVersion=2022-11-28#create-a-review-comment-for-a-pull-request
# Example `arcade chat` usage: "create a review comment for PR 72 in <OWNER>/<REPO> that says 'Great stuff! This looks good to merge. Add the comment to README.md file.'"
# TODO: Verify that path parameter exists in the PR's files that have changed (Or should we allow for any file in the repo?). If not, then throw RetryableToolError with all valid file paths.
@tool(requires_auth=GitHubApp())
@tool(requires_auth=GitHub())
async def create_review_comment(
context: ToolContext,
owner: Annotated[str, "The account owner of the repository. The name is not case sensitive."],

View file

@ -5,7 +5,7 @@ import httpx
from arcade.core.schema import ToolContext
from arcade.sdk import tool
from arcade.sdk.auth import GitHubApp
from arcade.sdk.auth import GitHub
from arcade_github.tools.models import (
ActivityType,
RepoSortProperty,
@ -24,7 +24,7 @@ from arcade_github.tools.utils import (
# Implements https://docs.github.com/en/rest/repos/repos?apiVersion=2022-11-28#get-a-repository and returns only the stargazers_count field.
# Example arcade chat usage: "How many stargazers does the <OWNER>/<REPO> repo have?"
@tool(requires_auth=GitHubApp())
@tool(requires_auth=GitHub())
async def count_stargazers(
owner: Annotated[str, "The owner of the repository"],
name: Annotated[str, "The name of the repository"],
@ -49,7 +49,7 @@ async def count_stargazers(
# Implements https://docs.github.com/en/rest/repos/repos?apiVersion=2022-11-28#list-organization-repositories
# Example arcade chat usage: "List all repositories for the <ORG> organization. Sort by creation date in descending order."
@tool(requires_auth=GitHubApp())
@tool(requires_auth=GitHub())
async def list_org_repositories(
context: ToolContext,
org: Annotated[str, "The organization name. The name is not case sensitive"],
@ -111,7 +111,7 @@ async def list_org_repositories(
# Implements https://docs.github.com/en/rest/repos/repos?apiVersion=2022-11-28#get-a-repository
# Example arcade chat usage: "Tell me about the <OWNER>/<REPO> repo."
@tool(requires_auth=GitHubApp())
@tool(requires_auth=GitHub())
async def get_repository(
context: ToolContext,
owner: Annotated[str, "The account owner of the repository. The name is not case sensitive."],
@ -166,7 +166,7 @@ async def get_repository(
# Implements https://docs.github.com/en/rest/repos/repos?apiVersion=2022-11-28#list-repository-activities
# Example arcade chat usage: "List all merges into main for the <OWNER>/<REPO> repo in the last week by <USER>"
@tool(requires_auth=GitHubApp())
@tool(requires_auth=GitHub())
async def list_repository_activities(
context: ToolContext,
owner: Annotated[str, "The account owner of the repository. The name is not case sensitive."],
@ -260,7 +260,7 @@ async def list_repository_activities(
# Implements https://docs.github.com/en/rest/pulls/comments?apiVersion=2022-11-28#list-review-comments-in-a-repository
# Example arcade chat usage: "List all review comments for the <OWNER>/<REPO> repo. Sort by update date in descending order."
# TODO: Improve the 'since' input parameter such that language model can more easily specify a valid date/time.
@tool(requires_auth=GitHubApp())
@tool(requires_auth=GitHub())
async def list_review_comments_in_a_repository(
context: ToolContext,
owner: Annotated[str, "The account owner of the repository. The name is not case sensitive."],

View file

@ -2,7 +2,7 @@
name = "arcade_github"
version = "0.1.0"
description = "LLM tools for interacting with Github"
authors = ["Eric Gustin <eric@arcade-ai.com>"]
authors = ["Arcade AI <dev@arcade-ai.com"]
[tool.poetry.dependencies]
python = "^3.10"

View file

@ -2,7 +2,7 @@
name = "arcade_google"
version = "0.1.0"
description = "Arcade tools for the entire google suite"
authors = ["Sam Partee <sam@arcade-ai.com>", "Eric Gustin <eric@arcade-ai.com>"]
authors = ["Arcade AI <dev@arcade-ai.com"]
[tool.poetry.dependencies]
python = "^3.10"

View file

@ -0,0 +1,114 @@
from typing import Annotated
import httpx
from arcade.core.errors import ToolExecutionError
from arcade.core.schema import ToolContext
from arcade.sdk import tool
from arcade.sdk.auth import LinkedIn
LINKEDIN_BASE_URL = "https://api.linkedin.com/v2"
async def _send_linkedin_request(
context: ToolContext,
method: str,
endpoint: str,
params: dict | None = None,
json_data: dict | None = None,
) -> httpx.Response:
"""
Send an asynchronous request to the LinkedIn API.
Args:
context: The tool context containing the authorization token.
method: The HTTP method (GET, POST, PUT, DELETE, etc.).
endpoint: The API endpoint path (e.g., "/ugcPosts").
params: Query parameters to include in the request.
json_data: JSON data to include in the request body.
Returns:
The response object from the API request.
Raises:
ToolExecutionError: If the request fails for any reason.
"""
url = f"{LINKEDIN_BASE_URL}{endpoint}"
headers = {"Authorization": f"Bearer {context.authorization.token}"}
async with httpx.AsyncClient() as client:
try:
response = await client.request(
method, url, headers=headers, params=params, json=json_data
)
response.raise_for_status()
except httpx.RequestError as e:
raise ToolExecutionError(f"Failed to send request to LinkedIn API: {e}")
return response
def _handle_linkedin_api_error(response: httpx.Response):
"""
Handle errors from the LinkedIn API by mapping common status codes to ToolExecutionErrors.
Args:
response: The response object from the API request.
Raises:
ToolExecutionError: If the response contains an error status code.
"""
status_code_map = {
401: ToolExecutionError("Unauthorized: Invalid or expired token"),
403: ToolExecutionError("Forbidden: User does not have Spotify Premium"),
429: ToolExecutionError("Too Many Requests: Rate limit exceeded"),
}
if response.status_code in status_code_map:
raise status_code_map[response.status_code]
elif response.status_code >= 400:
raise ToolExecutionError(f"Error: {response.status_code} - {response.text}")
@tool(
requires_auth=LinkedIn(
scopes=["w_member_social"],
)
)
async def create_text_post(
context: ToolContext,
text: Annotated[str, "The text content of the post"],
) -> Annotated[str, "URL of the shared post"]:
"""Share a new text post to LinkedIn."""
endpoint = "/ugcPosts"
# The LinkedIn user ID is required to create a post, even though we're using the user's access token.
# Arcade Engine gets the current user's info from LinkedIn and automatically populates context.authorization.user_info.
# LinkedIn calls the user ID "sub" in their user_info data payload. See:
# https://learn.microsoft.com/en-us/linkedin/consumer/integrations/self-serve/sign-in-with-linkedin-v2#api-request-to-retreive-member-details
user_id = context.authorization.user_info.get("sub")
if not user_id:
raise ToolExecutionError(
"User ID not found.",
developer_message="User ID not found in `context.authorization.user_info.sub`",
)
author_id = f"urn:li:person:{user_id}"
payload = {
"author": author_id,
"lifecycleState": "PUBLISHED",
"specificContent": {
"com.linkedin.ugc.ShareContent": {
"shareCommentary": {"text": text},
"shareMediaCategory": "NONE",
}
},
"visibility": {"com.linkedin.ugc.MemberNetworkVisibility": "PUBLIC"},
}
response = await _send_linkedin_request(context, "POST", endpoint, json=payload)
if response.status_code >= 200 and response.status_code < 300:
share_id = response.json().get("id")
return f"https://www.linkedin.com/feed/update/{share_id}/"
else:
_handle_linkedin_api_error(response)

View file

@ -0,0 +1,17 @@
[tool.poetry]
name = "arcade_linkedin"
version = "0.1.0"
description = "Arcade tools for LinkedIn"
authors = ["Arcade AI <dev@arcade-ai.com"]
[tool.poetry.dependencies]
python = "^3.10"
arcade-ai = "0.1.*"
httpx = "^0.27.2"
[tool.poetry.dev-dependencies]
pytest = "^8.3.0"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

View file

@ -2,7 +2,7 @@
name = "arcade_math"
version = "0.1.0"
description = "Math toolkit for Arcade"
authors = ["Nate <nate@arcade-ai.com>"]
authors = ["Arcade AI <dev@arcade-ai.com"]
[tool.poetry.dependencies]

View file

@ -2,7 +2,7 @@
name = "arcade_search"
version = "0.1.0"
description = "Tools for searching the web"
authors = ["Sam Partee <sam@arcade-ai.com>"]
authors = ["Arcade AI <dev@arcade-ai.com"]
[tool.poetry.dependencies]
python = "^3.10"

View file

@ -6,11 +6,11 @@ from slack_sdk.errors import SlackApiError
from arcade.core.errors import RetryableToolError, ToolExecutionError
from arcade.core.schema import ToolContext
from arcade.sdk import tool
from arcade.sdk.auth import SlackUser
from arcade.sdk.auth import Slack
@tool(
requires_auth=SlackUser(
requires_auth=Slack(
scopes=[
"chat:write",
"im:write",
@ -74,7 +74,7 @@ def format_users(userListResponse: dict) -> str:
@tool(
requires_auth=SlackUser(
requires_auth=Slack(
scopes=[
"chat:write",
"channels:read",

View file

@ -2,7 +2,7 @@
name = "arcade_slack"
version = "0.1.0"
description = "Slack tools for LLMs"
authors = ["Nate Barbettini <nate@arcade-ai.com>"]
authors = ["Arcade AI <dev@arcade-ai.com"]
[tool.poetry.dependencies]
python = "^3.10"

View file

@ -0,0 +1,163 @@
from typing import Annotated, Optional
import httpx
from arcade.core.errors import ToolExecutionError
from arcade.core.schema import ToolContext
from arcade.sdk import tool
from arcade.sdk.auth import Spotify
SPOTIFY_BASE_URL = "https://api.spotify.com/v1"
async def _send_spotify_request(
context: ToolContext,
method: str,
endpoint: str,
params: dict | None = None,
json_data: dict | None = None,
) -> httpx.Response:
"""
Send an asynchronous request to the Spotify API.
Args:
context: The tool context containing the authorization token.
method: The HTTP method (GET, POST, PUT, DELETE, etc.).
endpoint: The API endpoint path (e.g., "/me/player/play").
params: Query parameters to include in the request.
json_data: JSON data to include in the request body.
Returns:
The response object from the API request.
Raises:
ToolExecutionError: If the request fails for any reason.
"""
url = f"{SPOTIFY_BASE_URL}{endpoint}"
headers = {"Authorization": f"Bearer {context.authorization.token}"}
async with httpx.AsyncClient() as client:
try:
response = await client.request(
method, url, headers=headers, params=params, json=json_data
)
response.raise_for_status()
except httpx.RequestError as e:
raise ToolExecutionError(f"Failed to send request to Spotify API: {e}")
return response
def _handle_spotify_api_error(response: httpx.Response):
"""
Handle errors from the Spotify API by mapping common status codes to ToolExecutionErrors.
Args:
response: The response object from the API request.
Raises:
ToolExecutionError: If the response contains an error status code.
"""
status_code_map = {
401: ToolExecutionError("Unauthorized: Invalid or expired token"),
403: ToolExecutionError("Forbidden: User does not have Spotify Premium"),
429: ToolExecutionError("Too Many Requests: Rate limit exceeded"),
}
if response.status_code in status_code_map:
raise status_code_map[response.status_code]
elif response.status_code >= 400:
raise ToolExecutionError(f"Error: {response.status_code} - {response.text}")
@tool(
requires_auth=Spotify(
scopes=["user-modify-playback-state"],
)
)
async def pause(
context: ToolContext,
device_id: Annotated[
Optional[str],
"The id of the device this command is targeting. If omitted, the active device is targeted.",
] = None,
) -> Annotated[str, "Success string confirming the pause"]:
"""Pause the current track"""
endpoint = "/me/player/pause"
params = {"device_id": device_id} if device_id else {}
response = await _send_spotify_request(context, "PUT", endpoint, params=params)
if response.status_code >= 200 and response.status_code < 300:
return "Playback paused"
else:
_handle_spotify_api_error(response)
@tool(
requires_auth=Spotify(
scopes=["user-modify-playback-state"],
)
)
async def resume(
context: ToolContext,
device_id: Annotated[
Optional[str],
"The id of the device this command is targeting. If omitted, the active device is targeted.",
] = None,
) -> Annotated[str, "Success string confirming the playback resume"]:
"""Resume the current track, if any"""
endpoint = "/me/player/play"
params = {"device_id": device_id} if device_id else {}
response = await _send_spotify_request(context, "PUT", endpoint, params=params)
if response.status_code >= 200 and response.status_code < 300:
return "Playback resumed"
else:
_handle_spotify_api_error(response)
@tool(
requires_auth=Spotify(
scopes=["user-read-playback-state"],
)
)
async def get_playback_state(
context: ToolContext,
) -> Annotated[dict, "Information about the user's current playback state"]:
"""Get information about the user's current playback state, including track or episode, progress, and active device."""
endpoint = "/me/player"
response = await _send_spotify_request(context, "GET", endpoint)
if response.status_code == 204:
return {"status": "Playback not available or active"}
elif response.status_code == 200:
data = response.json()
# TODO: Return a more structured model
result = {
"device_name": data.get("device", {}).get("name"),
"currently_playing_type": data.get("currently_playing_type"),
}
if data.get("currently_playing_type") == "track":
item = data.get("item", {})
album = item.get("album", {})
result.update({
"album_name": album.get("name"),
"album_artists": [artist.get("name") for artist in album.get("artists", [])],
"album_spotify_url": album.get("external_urls", {}).get("spotify"),
"track_name": item.get("name"),
"track_artists": [artist.get("name") for artist in item.get("artists", [])],
})
elif data.get("currently_playing_type") == "episode":
item = data.get("item", {})
show = item.get("show", {})
result.update({
"show_name": show.get("name"),
"show_spotify_url": show.get("external_urls", {}).get("spotify"),
"episode_name": item.get("name"),
"episode_spotify_url": item.get("external_urls", {}).get("spotify"),
})
return result
else:
_handle_spotify_api_error(response)

View file

@ -0,0 +1,17 @@
[tool.poetry]
name = "arcade_spotify"
version = "0.1.0"
description = "Arcade tools for Spotify"
authors = ["Arcade AI <dev@arcade-ai.com"]
[tool.poetry.dependencies]
python = "^3.10"
arcade-ai = "0.1.*"
httpx = "^0.27.2"
[tool.poetry.dev-dependencies]
pytest = "^8.3.0"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

View file

@ -12,7 +12,11 @@ TWEETS_URL = "https://api.x.com/2/tweets"
# Manage Tweets Tools. See developer docs for additional available parameters: https://developer.x.com/en/docs/x-api/tweets/manage-tweets/api-reference
@tool(requires_auth=X(scopes=["tweet.read", "tweet.write", "users.read"]))
@tool(
requires_auth=X(
scopes=["tweet.read", "tweet.write", "users.read"],
)
)
async def post_tweet(
context: ToolContext,
tweet_text: Annotated[str, "The text content of the tweet you want to post"],

View file

@ -2,7 +2,7 @@
name = "arcade_x"
version = "0.1.0"
description = "LLM tools for interacting with X (Twitter)"
authors = ["Eric Gustin <eric@arcade-ai.com>"]
authors = ["Arcade AI <dev@arcade-ai.com"]
[tool.poetry.dependencies]
python = "^3.10"

View file

View file

@ -0,0 +1,114 @@
from typing import Annotated, Optional
import httpx
from arcade.core.errors import ToolExecutionError
from arcade.core.schema import ToolContext
from arcade.sdk import tool
from arcade.sdk.auth import Zoom
ZOOM_BASE_URL = "https://api.zoom.us/v2"
async def _send_zoom_request(
context: ToolContext,
method: str,
endpoint: str,
params: dict | None = None,
json_data: dict | None = None,
) -> httpx.Response:
"""
Send an asynchronous request to the Zoom API.
Args:
context: The tool context containing the authorization token.
method: The HTTP method (GET, POST, PUT, DELETE, etc.).
endpoint: The API endpoint path (e.g., "/users/me/upcoming_meetings").
params: Query parameters to include in the request.
json_data: JSON data to include in the request body.
Returns:
The response object from the API request.
Raises:
ToolExecutionError: If the request fails for any reason.
"""
url = f"{ZOOM_BASE_URL}{endpoint}"
headers = {"Authorization": f"Bearer {context.authorization.token}"}
async with httpx.AsyncClient() as client:
try:
response = await client.request(
method, url, headers=headers, params=params, json=json_data
)
response.raise_for_status()
except httpx.RequestError as e:
raise ToolExecutionError(f"Failed to send request to Zoom API: {e}")
return response
def _handle_zoom_api_error(response: httpx.Response):
"""
Handle errors from the Zoom API by mapping common status codes to ToolExecutionErrors.
Args:
response: The response object from the API request.
Raises:
ToolExecutionError: If the response contains an error status code.
"""
status_code_map = {
401: ToolExecutionError("Unauthorized: Invalid or expired token"),
403: ToolExecutionError("Forbidden: Access denied"),
429: ToolExecutionError("Too Many Requests: Rate limit exceeded"),
}
if response.status_code in status_code_map:
raise status_code_map[response.status_code]
elif response.status_code >= 400:
raise ToolExecutionError(f"Error: {response.status_code} - {response.text}")
@tool(
requires_auth=Zoom(
scopes=["meeting:read:list_upcoming_meetings"],
)
)
async def list_upcoming_meetings(
context: ToolContext,
user_id: Annotated[
Optional[str],
"The user's user ID or email address. Defaults to 'me' for the current user.",
] = "me",
) -> Annotated[dict, "List of upcoming meetings within the next 24 hours"]:
"""List a Zoom user's upcoming meetings within the next 24 hours."""
endpoint = f"/users/{user_id}/upcoming_meetings"
response = await _send_zoom_request(context, "GET", endpoint)
if response.status_code >= 200 and response.status_code < 300:
return response.json()
else:
_handle_zoom_api_error(response)
@tool(
requires_auth=Zoom(
scopes=["meeting:read:invitation"],
)
)
async def get_meeting_invitation(
context: ToolContext,
meeting_id: Annotated[
str,
"The meeting's numeric ID (as a string).",
],
) -> Annotated[dict, "Meeting invitation string"]:
"""Retrieve the invitation note for a specific Zoom meeting."""
endpoint = f"/meetings/{meeting_id}/invitation"
response = await _send_zoom_request(context, "GET", endpoint)
if response.status_code >= 200 and response.status_code < 300:
return response.json()
else:
_handle_zoom_api_error(response)

View file

@ -0,0 +1,17 @@
[tool.poetry]
name = "arcade_zoom"
version = "0.1.0"
description = "Arcade tools for Zoom"
authors = ["Arcade AI <dev@arcade-ai.com"]
[tool.poetry.dependencies]
python = "^3.10"
arcade-ai = "0.1.*"
httpx = "^0.27.2"
[tool.poetry.dev-dependencies]
pytest = "^8.3.0"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"