From f4fe8c78922082fc72973967798f8a710c129709 Mon Sep 17 00:00:00 2001 From: Nate Barbettini Date: Tue, 17 Sep 2024 16:38:51 -0700 Subject: [PATCH] Clean up provider properties (scopes) (#42) In this PR: - Rename `scope` to `scopes` so it is more understandable by humans - DRY up provider structs, it was starting to get silly with so many providers that just have 1 property called `scopes` Must go along with this Engine PR: https://github.com/ArcadeAI/Engine/pull/79 --- arcade/arcade/client/client.py | 4 +- arcade/arcade/client/schema.py | 2 +- arcade/arcade/core/catalog.py | 12 +--- arcade/arcade/core/schema.py | 26 +-------- arcade/arcade/sdk/auth.py | 28 +++++----- arcade/tests/sdk/test_tool_decorator.py | 4 +- .../tests/tool/test_create_tool_definition.py | 56 ++++++++++++++++--- schemas/preview/tool_definition.schema.jsonc | 3 +- toolkits/gmail/arcade_gmail/tools/gmail.py | 6 +- toolkits/slack/arcade_slack/tools/chat.py | 4 +- 10 files changed, 77 insertions(+), 68 deletions(-) diff --git a/arcade/arcade/client/client.py b/arcade/arcade/client/client.py index 48d22d0c..110e574f 100644 --- a/arcade/arcade/client/client.py +++ b/arcade/arcade/client/client.py @@ -50,7 +50,7 @@ class AuthResource(BaseResource[ClientT]): body = { "auth_requirement": { "provider": auth_provider, - auth_provider: AuthRequest(scope=scopes, authority=authority).model_dump( + auth_provider: AuthRequest(scopes=scopes, authority=authority).model_dump( exclude_none=True ), }, @@ -200,7 +200,7 @@ class AsyncAuthResource(BaseResource[AsyncArcadeClient]): body = { "auth_requirement": { "provider": auth_provider, - auth_provider: AuthRequest(scope=scopes, authority=authority).model_dump( + auth_provider: AuthRequest(scopes=scopes, authority=authority).model_dump( exclude_none=True ), }, diff --git a/arcade/arcade/client/schema.py b/arcade/arcade/client/schema.py index 5ef4664c..0954542f 100644 --- a/arcade/arcade/client/schema.py +++ b/arcade/arcade/client/schema.py @@ -30,7 +30,7 @@ class AuthRequest(BaseModel): authority: AnyUrl | str | None = None """The URL of the OAuth 2.0 authorization server.""" - scope: list[str] + scopes: list[str] """The scope(s) needed for authorization.""" diff --git a/arcade/arcade/core/catalog.py b/arcade/arcade/core/catalog.py index 02d39ca5..7bc2e271 100644 --- a/arcade/arcade/core/catalog.py +++ b/arcade/arcade/core/catalog.py @@ -24,10 +24,8 @@ from pydantic_core import PydanticUndefined from arcade.core.errors import ToolDefinitionError from arcade.core.schema import ( - GoogleRequirement, InputParameter, OAuth2Requirement, - SlackUserRequirement, ToolAuthRequirement, ToolContext, ToolDefinition, @@ -44,7 +42,7 @@ from arcade.core.utils import ( snake_to_pascal_case, ) from arcade.sdk.annotations import Inferrable -from arcade.sdk.auth import Google, OAuth2, SlackUser, ToolAuthorization +from arcade.sdk.auth import BaseOAuth2, ToolAuthorization InnerWireType = Literal["string", "integer", "number", "boolean", "json"] WireType = Union[InnerWireType, Literal["array"]] @@ -204,14 +202,8 @@ class ToolCatalog(BaseModel): new_auth_requirement = ToolAuthRequirement( provider=auth_requirement.get_provider(), ) - if isinstance(auth_requirement, OAuth2): + if isinstance(auth_requirement, BaseOAuth2): new_auth_requirement.oauth2 = OAuth2Requirement(**auth_requirement.model_dump()) - elif isinstance(auth_requirement, Google): - new_auth_requirement.google = GoogleRequirement(**auth_requirement.model_dump()) - elif isinstance(auth_requirement, SlackUser): - new_auth_requirement.slack_user = SlackUserRequirement( - **auth_requirement.model_dump() - ) auth_requirement = new_auth_requirement return ToolDefinition( diff --git a/arcade/arcade/core/schema.py b/arcade/arcade/core/schema.py index 12eda81b..c6707907 100644 --- a/arcade/arcade/core/schema.py +++ b/arcade/arcade/core/schema.py @@ -68,25 +68,11 @@ class ToolOutput(BaseModel): class OAuth2Requirement(BaseModel): """Indicates that the tool requires OAuth 2.0 authorization.""" - authority: AnyUrl + authority: Optional[AnyUrl] = None """The URL of the OAuth 2.0 authorization server.""" - scope: Optional[list[str]] = None - """The scope(s) needed for authorization.""" - - -class GoogleRequirement(BaseModel): - """Indicates that the tool requires Google authorization.""" - - scope: Optional[list[str]] = None - """The scope(s) needed for authorization.""" - - -class SlackUserRequirement(BaseModel): - """Indicates that the tool requires Slack (user token) authorization.""" - - scope: Optional[list[str]] = None - """The scope(s) needed for authorization.""" + scopes: Optional[list[str]] = None + """The scope(s) needed for authorization, if any.""" class ToolAuthRequirement(BaseModel): @@ -98,12 +84,6 @@ class ToolAuthRequirement(BaseModel): oauth2: Optional[OAuth2Requirement] = None """The OAuth 2.0 requirement, if any.""" - google: Optional[GoogleRequirement] = None - """The Google requirement, if any.""" - - slack_user: Optional[SlackUserRequirement] = None - """The Slack (user token) requirement, if any.""" - class ToolRequirements(BaseModel): """The requirements for a tool to run.""" diff --git a/arcade/arcade/sdk/auth.py b/arcade/arcade/sdk/auth.py index 23b88fa4..84401a68 100644 --- a/arcade/arcade/sdk/auth.py +++ b/arcade/arcade/sdk/auth.py @@ -15,38 +15,36 @@ class ToolAuthorization(BaseModel, ABC): pass -class OAuth2(ToolAuthorization): +class BaseOAuth2(ToolAuthorization): + """Base class for any provider supporting OAuth 2.0-like authorization.""" + + authority: Optional[AnyUrl] = None + """The URL of the OAuth 2.0 authorization server.""" + + scopes: Optional[list[str]] = None + """The scope(s) needed for the authorized action.""" + + +class OAuth2(BaseOAuth2): """Marks a tool as requiring OAuth 2.0 authorization.""" def get_provider(self) -> str: return "oauth2" - authority: AnyUrl - """The URL of the OAuth 2.0 authorization server.""" - scope: Optional[list[str]] = None - """The scope(s) needed for the authorized action.""" - - -class Google(ToolAuthorization): +class Google(BaseOAuth2): """Marks a tool as requiring Google authorization.""" def get_provider(self) -> str: return "google" - scope: Optional[list[str]] = None - """The scope(s) needed for the authorized action.""" - -class SlackUser(ToolAuthorization): +class SlackUser(BaseOAuth2): """Marks a tool as requiring Slack (user token) authorization.""" def get_provider(self) -> str: return "slack_user" - scope: Optional[list[str]] = None - """The scope(s) needed for the authorized action.""" - class GitHubApp(ToolAuthorization): """Marks a tool as requiring GitHub App authorization.""" diff --git a/arcade/tests/sdk/test_tool_decorator.py b/arcade/tests/sdk/test_tool_decorator.py index 10ffb2a2..d2cb642c 100644 --- a/arcade/tests/sdk/test_tool_decorator.py +++ b/arcade/tests/sdk/test_tool_decorator.py @@ -40,7 +40,7 @@ def test_tool_decorator_with_all_options(): desc="Test description", requires_auth=OAuth2( authority="https://example.com/oauth2/auth", - scope=["test_scope", "another.scope"], + scopes=["test_scope", "another.scope"], ), ) def test_tool(x, y): @@ -49,4 +49,4 @@ def test_tool_decorator_with_all_options(): assert test_tool.__tool_name__ == "TestTool" assert test_tool.__tool_description__ == "Test description" assert str(test_tool.__tool_requires_auth__.authority) == "https://example.com/oauth2/auth" - assert test_tool.__tool_requires_auth__.scope == ["test_scope", "another.scope"] + assert test_tool.__tool_requires_auth__.scopes == ["test_scope", "another.scope"] diff --git a/arcade/tests/tool/test_create_tool_definition.py b/arcade/tests/tool/test_create_tool_definition.py index 424303fc..63798a5c 100644 --- a/arcade/tests/tool/test_create_tool_definition.py +++ b/arcade/tests/tool/test_create_tool_definition.py @@ -5,7 +5,6 @@ import pytest from arcade.core.catalog import ToolCatalog from arcade.core.schema import ( - GoogleRequirement, InputParameter, OAuth2Requirement, ToolAuthRequirement, @@ -17,7 +16,7 @@ from arcade.core.schema import ( ) from arcade.sdk import tool from arcade.sdk.annotations import Inferrable -from arcade.sdk.auth import Google, OAuth2 +from arcade.sdk.auth import GitHubApp, Google, OAuth2, SlackUser ### Tests on @tool decorator @@ -48,20 +47,36 @@ def func_with_name_and_description(): @tool( desc="A function that requires authentication", - requires_auth=OAuth2(authority="https://example.com/oauth2/auth", scope=["scope1", "scope2"]), + requires_auth=OAuth2(authority="https://example.com/oauth2/auth", scopes=["scope1", "scope2"]), ) def func_with_auth_requirement(): pass @tool( - desc="A function that requires authentication", - requires_auth=Google(scope=["https://www.googleapis.com/auth/gmail.readonly"]), + desc="A function that requires Google authorization", + requires_auth=Google(scopes=["https://www.googleapis.com/auth/gmail.readonly"]), ) def func_with_google_auth_requirement(): pass +@tool( + desc="A function that requires GitHub authorization", + requires_auth=GitHubApp(), +) +def func_with_github_auth_requirement(): + pass + + +@tool( + desc="A function that requires Slack user authorization", + requires_auth=SlackUser(scopes=["chat:write", "channels:history"]), +) +def func_with_slack_user_auth_requirement(): + pass + + ### Tests on input params @tool(desc="A function with a non-inferrable input parameter") def func_with_non_inferrable_param(param1: Annotated[str, "First param", Inferrable(False)]): @@ -229,7 +244,7 @@ def func_with_complex_return() -> dict[str, str]: provider="oauth2", oauth2=OAuth2Requirement( authority="https://example.com/oauth2/auth", - scope=["scope1", "scope2"], + scopes=["scope1", "scope2"], ), ) ) @@ -242,14 +257,39 @@ def func_with_complex_return() -> dict[str, str]: "requirements": ToolRequirements( authorization=ToolAuthRequirement( provider="google", - google=GoogleRequirement( - scope=["https://www.googleapis.com/auth/gmail.readonly"], + oauth2=OAuth2Requirement( + scopes=["https://www.googleapis.com/auth/gmail.readonly"], ), ) ) }, id="func_with_google_auth_requirement", ), + pytest.param( + func_with_github_auth_requirement, + { + "requirements": ToolRequirements( + authorization=ToolAuthRequirement( + provider="github_app", + ) + ) + }, + id="func_with_github_auth_requirement", + ), + pytest.param( + func_with_slack_user_auth_requirement, + { + "requirements": ToolRequirements( + authorization=ToolAuthRequirement( + provider="slack_user", + oauth2=OAuth2Requirement( + scopes=["chat:write", "channels:history"], + ), + ) + ) + }, + id="func_with_slack_user_auth_requirement", + ), # Tests on input params pytest.param( func_with_non_inferrable_param, diff --git a/schemas/preview/tool_definition.schema.jsonc b/schemas/preview/tool_definition.schema.jsonc index 757a04d0..a9679fb5 100644 --- a/schemas/preview/tool_definition.schema.jsonc +++ b/schemas/preview/tool_definition.schema.jsonc @@ -148,11 +148,10 @@ } } }, - "required": ["authority"], "additionalProperties": false } }, - "required": [], + "required": ["provider"], "additionalProperties": false } ] diff --git a/toolkits/gmail/arcade_gmail/tools/gmail.py b/toolkits/gmail/arcade_gmail/tools/gmail.py index 5e476a75..fd6f3054 100644 --- a/toolkits/gmail/arcade_gmail/tools/gmail.py +++ b/toolkits/gmail/arcade_gmail/tools/gmail.py @@ -18,7 +18,7 @@ from arcade.sdk.auth import Google @tool( requires_auth=Google( - scope=["https://www.googleapis.com/auth/gmail.compose"], + scopes=["https://www.googleapis.com/auth/gmail.compose"], ) ) async def write_draft( @@ -88,7 +88,7 @@ class DateRange(Enum): @tool( requires_auth=Google( - scope=["https://www.googleapis.com/auth/gmail.readonly"], + scopes=["https://www.googleapis.com/auth/gmail.readonly"], ) ) async def search_emails_by_header( @@ -168,7 +168,7 @@ async def search_emails_by_header( @tool( requires_auth=Google( - scope=["https://www.googleapis.com/auth/gmail.readonly"], + scopes=["https://www.googleapis.com/auth/gmail.readonly"], ) ) async def get_emails( diff --git a/toolkits/slack/arcade_slack/tools/chat.py b/toolkits/slack/arcade_slack/tools/chat.py index 217071ff..cdc76489 100644 --- a/toolkits/slack/arcade_slack/tools/chat.py +++ b/toolkits/slack/arcade_slack/tools/chat.py @@ -11,7 +11,7 @@ from arcade.sdk.auth import SlackUser @tool( requires_auth=SlackUser( - scope=[ + scopes=[ "chat:write", "im:write", "users.profile:read", @@ -72,7 +72,7 @@ def format_users(userListResponse: dict) -> str: @tool( requires_auth=SlackUser( - scope=[ + scopes=[ "chat:write", "channels:read", "groups:read",