From 6035cde920af36c5dcc30c5c2fe7a9298b223716 Mon Sep 17 00:00:00 2001 From: Eric Gustin <34000337+EricGustin@users.noreply.github.com> Date: Fri, 17 Jan 2025 11:38:06 -0800 Subject: [PATCH] Update auth `provider_id` (#173) # PR Description Well known providers (Google, X, Dropbox, etc.) can optionally have an `id` in addition to their hardcoded `provider_id`. For non well known providers, they must provide an `id`, and the `provider_id` is hardcoded as `None`. ```python OAuth2() # INVALID OAuth2(provider_id="abc") # INVALID OAuth2(id="abc") # VALID OAuth2(provider_id="abc", id="def") # INVALID ``` ```python Google() # VALID Google(provider_id="abc") # INVALID Google(id="abc") # VALID Google(provider_id="abc", id="def") # INVALID ``` --------- Co-authored-by: Wils Dawson --- arcade/arcade/core/auth.py | 46 +++++++++-- arcade/arcade/core/catalog.py | 1 + arcade/arcade/core/schema.py | 13 ++-- arcade/tests/sdk/test_tool_decorator.py | 76 +++++++++++++++++-- .../tests/tool/test_create_tool_definition.py | 41 ++++++++-- schemas/preview/tool_definition.schema.jsonc | 10 ++- 6 files changed, 161 insertions(+), 26 deletions(-) diff --git a/arcade/arcade/core/auth.py b/arcade/arcade/core/auth.py index f9a6d487..7ff48bc2 100644 --- a/arcade/arcade/core/auth.py +++ b/arcade/arcade/core/auth.py @@ -13,20 +13,24 @@ class ToolAuthorization(BaseModel): model_config = ConfigDict(frozen=True) - provider_id: str - """The unique provider ID configured in Arcade.""" + provider_id: Optional[str] = None + """The provider ID configured in Arcade that acts as an alias to well-known configuration.""" provider_type: AuthProviderType """The type of the authorization provider.""" + id: Optional[str] = None + """A provider's unique identifier, allowing the tool to specify a specific authorization provider. Recommended for private tools only.""" + + scopes: Optional[list[str]] = None + """The scope(s) needed for the authorized action.""" + class OAuth2(ToolAuthorization): """Marks a tool as requiring OAuth 2.0 authorization.""" - provider_type: AuthProviderType = AuthProviderType.oauth2 - - scopes: Optional[list[str]] = None - """The scope(s) needed for the authorized action.""" + def __init__(self, *, id: str | None, scopes: Optional[list[str]] = None): # noqa: A002 + super().__init__(id=id, scopes=scopes, provider_type=AuthProviderType.oauth2) class Atlassian(OAuth2): @@ -34,56 +38,86 @@ class Atlassian(OAuth2): provider_id: str = "atlassian" + def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002 + super().__init__(id=id, scopes=scopes) + class Discord(OAuth2): """Marks a tool as requiring Discord authorization.""" provider_id: str = "discord" + def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002 + super().__init__(id=id, scopes=scopes) + class Dropbox(OAuth2): """Marks a tool as requiring Dropbox authorization.""" provider_id: str = "dropbox" + def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002 + super().__init__(id=id, scopes=scopes) + class Google(OAuth2): """Marks a tool as requiring Google authorization.""" provider_id: str = "google" + def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002 + super().__init__(id=id, scopes=scopes) + class Slack(OAuth2): """Marks a tool as requiring Slack (user token) authorization.""" provider_id: str = "slack" + def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002 + super().__init__(id=id, scopes=scopes) + class GitHub(OAuth2): """Marks a tool as requiring GitHub App authorization.""" provider_id: str = "github" + def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002 + super().__init__(id=id, scopes=scopes) + class X(OAuth2): """Marks a tool as requiring X (Twitter) authorization.""" provider_id: str = "x" + def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002 + super().__init__(id=id, scopes=scopes) + class LinkedIn(OAuth2): """Marks a tool as requiring LinkedIn authorization.""" provider_id: str = "linkedin" + def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002 + super().__init__(id=id, scopes=scopes) + class Spotify(OAuth2): """Marks a tool as requiring Spotify authorization.""" provider_id: str = "spotify" + def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002 + super().__init__(id=id, scopes=scopes) + class Zoom(OAuth2): """Marks a tool as requiring Zoom authorization.""" provider_id: str = "zoom" + + def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002 + super().__init__(id=id, scopes=scopes) diff --git a/arcade/arcade/core/catalog.py b/arcade/arcade/core/catalog.py index 50f6e0d5..d150f7bc 100644 --- a/arcade/arcade/core/catalog.py +++ b/arcade/arcade/core/catalog.py @@ -298,6 +298,7 @@ class ToolCatalog(BaseModel): new_auth_requirement = ToolAuthRequirement( provider_id=auth_requirement.provider_id, provider_type=auth_requirement.provider_type, + id=auth_requirement.id, ) if isinstance(auth_requirement, OAuth2): new_auth_requirement.oauth2 = OAuth2Requirement(**auth_requirement.model_dump()) diff --git a/arcade/arcade/core/schema.py b/arcade/arcade/core/schema.py index 73394783..cbea0319 100644 --- a/arcade/arcade/core/schema.py +++ b/arcade/arcade/core/schema.py @@ -74,26 +74,29 @@ class OAuth2Requirement(BaseModel): """Indicates that the tool requires OAuth 2.0 authorization.""" scopes: Optional[list[str]] = None - """The scope(s) needed for authorization, if any.""" + """The scope(s) needed for the authorized action.""" class ToolAuthRequirement(BaseModel): """A requirement for authorization to use a tool.""" - # Provider ID and Type needed for the Arcade Engine to look up the auth provider. + # Provider ID, Type, and ID needed for the Arcade Engine to look up the auth provider. # However, the developer generally does not need to set these directly. # Instead, they will use: # @tool(requires_auth=Google(scopes=["profile", "email"])) # or # client.auth.authorize(provider=AuthProvider.google, scopes=["profile", "email"]) # - # The Arcade SDK translates these into the appropriate provider ID and type. + # The Arcade SDK translates these into the appropriate provider ID (Google) and type (OAuth2). # The only time the developer will set these is if they are using a custom auth provider. provider_id: Optional[str] = None - """A unique provider ID.""" + """The provider ID configured in Arcade that acts as an alias to well-known configuration.""" provider_type: str - """The provider type.""" + """The type of the authorization provider.""" + + id: Optional[str] = None + """A provider's unique identifier, allowing the tool to specify a specific authorization provider. Recommended for private tools only.""" oauth2: Optional[OAuth2Requirement] = None """The OAuth 2.0 requirement, if any.""" diff --git a/arcade/tests/sdk/test_tool_decorator.py b/arcade/tests/sdk/test_tool_decorator.py index d7e7bcb7..f86b0166 100644 --- a/arcade/tests/sdk/test_tool_decorator.py +++ b/arcade/tests/sdk/test_tool_decorator.py @@ -2,6 +2,7 @@ import asyncio import pytest +from arcade.core.auth import AuthProviderType, Google from arcade.sdk import tool from arcade.sdk.auth import OAuth2 @@ -34,18 +35,83 @@ async def test_async_function(): assert result == 3 -def test_tool_decorator_with_all_options(): +@pytest.mark.parametrize( + "auth_class, auth_kwargs, expected_provider_id, expected_id", + [ + ( + OAuth2, + {"id": "my_example_provider123", "scopes": ["test_scope", "another.scope"]}, + None, + "my_example_provider123", + ), + (Google, {"scopes": ["test_scope", "another.scope"]}, "google", None), + ( + Google, + {"id": "my_google_provider123", "scopes": ["test_scope", "another.scope"]}, + "google", + "my_google_provider123", + ), + ], +) +def test_tool_decorator_with_auth_success( + auth_class, auth_kwargs, expected_provider_id, expected_id +): @tool( name="TestTool", desc="Test description", - requires_auth=OAuth2( - provider_id="example", - scopes=["test_scope", "another.scope"], - ), + requires_auth=auth_class(**auth_kwargs), ) def test_tool(x, y): return x + y assert test_tool.__tool_name__ == "TestTool" assert test_tool.__tool_description__ == "Test description" + assert test_tool.__tool_requires_auth__.provider_id == expected_provider_id + assert test_tool.__tool_requires_auth__.provider_type == AuthProviderType.oauth2 + assert test_tool.__tool_requires_auth__.id == expected_id assert test_tool.__tool_requires_auth__.scopes == ["test_scope", "another.scope"] + + +@pytest.mark.parametrize( + "auth_class, auth_kwargs", + [ + (OAuth2, {"scopes": ["test_scope", "another.scope"]}), + ( + OAuth2, + {"provider_id": "my_example_provider123", "scopes": ["test_scope", "another.scope"]}, + ), + ( + OAuth2, + { + "provider_id": "my_example_provider_id_123", + "id": "my_example_id_123", + "scopes": ["test_scope", "another.scope"], + }, + ), + ( + Google, + { + "provider_id": "my_example_provider_id_123", + "scopes": ["test_scope", "another.scope"], + }, + ), + ( + Google, + { + "provider_id": "my_example_provider_id_123", + "id": "my_example_id_123", + "scopes": ["test_scope", "another.scope"], + }, + ), + ], +) +def test_tool_decorator_with_auth_failure(auth_class, auth_kwargs): + with pytest.raises(TypeError): + + @tool( + name="TestTool", + desc="Test description", + requires_auth=auth_class(**auth_kwargs), + ) + def test_tool(x, y): + return x + y diff --git a/arcade/tests/tool/test_create_tool_definition.py b/arcade/tests/tool/test_create_tool_definition.py index bbbfc0b6..27a2ee9f 100644 --- a/arcade/tests/tool/test_create_tool_definition.py +++ b/arcade/tests/tool/test_create_tool_definition.py @@ -49,7 +49,7 @@ def func_with_name_and_description(): @tool( desc="A function that requires authentication", requires_auth=OAuth2( - provider_id="example", + id="my_example_provider123", scopes=["scope1", "scope2"], ), ) @@ -59,7 +59,10 @@ def func_with_auth_requirement(): @tool( desc="A function that requires Google authorization", - requires_auth=Google(scopes=["https://www.googleapis.com/auth/gmail.readonly"]), + requires_auth=Google( + id="my_google_provider123", + scopes=["https://www.googleapis.com/auth/gmail.readonly"], + ), ) def func_with_google_auth_requirement(): pass @@ -67,7 +70,9 @@ def func_with_google_auth_requirement(): @tool( desc="A function that requires GitHub authorization", - requires_auth=GitHub(), + requires_auth=GitHub( + id="my_github_provider123", + ), ) def func_with_github_auth_requirement(): pass @@ -75,7 +80,9 @@ def func_with_github_auth_requirement(): @tool( desc="A function that requires Slack user authorization", - requires_auth=Slack(scopes=["chat:write", "channels:history"]), + requires_auth=Slack( + scopes=["chat:write", "channels:history"], + ), ) def func_with_slack_user_auth_requirement(): pass @@ -83,7 +90,9 @@ def func_with_slack_user_auth_requirement(): @tool( desc="A function that requires X (Twitter) authorization", - requires_auth=X(scopes=["tweet.write"]), + requires_auth=X( + scopes=["tweet.write"], + ), ) def func_with_x_requirement(): pass @@ -257,8 +266,8 @@ def func_with_complex_return() -> dict[str, str]: { "requirements": ToolRequirements( authorization=ToolAuthRequirement( - provider_id="example", provider_type="oauth2", + id="my_example_provider123", oauth2=OAuth2Requirement( authority="https://example.com/oauth2/auth", scopes=["scope1", "scope2"], @@ -275,6 +284,7 @@ def func_with_complex_return() -> dict[str, str]: authorization=ToolAuthRequirement( provider_id="google", provider_type="oauth2", + id="my_google_provider123", oauth2=OAuth2Requirement( scopes=["https://www.googleapis.com/auth/gmail.readonly"], ), @@ -288,7 +298,10 @@ def func_with_complex_return() -> dict[str, str]: { "requirements": ToolRequirements( authorization=ToolAuthRequirement( - provider_id="github", provider_type="oauth2", oauth2=OAuth2Requirement() + provider_id="github", + provider_type="oauth2", + id="my_github_provider123", + oauth2=OAuth2Requirement(), ) ) }, @@ -309,6 +322,20 @@ def func_with_complex_return() -> dict[str, str]: }, id="func_with_slack_user_auth_requirement", ), + pytest.param( + func_with_x_requirement, + { + "requirements": ToolRequirements( + authorization=ToolAuthRequirement( + provider_id="x", + provider_type="oauth2", + oauth2=OAuth2Requirement( + scopes=["tweet.write"], + ), + ) + ) + }, + ), # Tests on input params pytest.param( func_with_non_inferrable_param, diff --git a/schemas/preview/tool_definition.schema.jsonc b/schemas/preview/tool_definition.schema.jsonc index 3f64191a..de4da4d9 100644 --- a/schemas/preview/tool_definition.schema.jsonc +++ b/schemas/preview/tool_definition.schema.jsonc @@ -153,11 +153,15 @@ "properties": { "provider_id": { "type": "string", - "description": "A unique provider ID." + "description": "The provider ID configured in Arcade that acts as an alias to well-known configuration." }, "provider_type": { "type": "string", - "description": "The provider type." + "description": "The type of the authorization provider." + }, + "id": { + "type": "string", + "description": "A provider's unique identifier, allowing the tool to specify a specific authorization provider. Recommended for private tools only." }, "oauth2": { "type": "object", @@ -172,7 +176,7 @@ "additionalProperties": false } }, - "required": ["provider_id", "provider_type"], + "required": ["provider_type"], "additionalProperties": false } ]