Update auth provider_id (#173)

# PR Description
Well known providers (Google, X, Dropbox, etc.) can optionally have an
`id` in addition to their hardcoded `provider_id`. For non well known
providers, they must provide an `id`, and the `provider_id` is hardcoded
as `None`.

```python
OAuth2() # INVALID
OAuth2(provider_id="abc") # INVALID
OAuth2(id="abc") # VALID
OAuth2(provider_id="abc", id="def") # INVALID
```
```python
Google() # VALID
Google(provider_id="abc") # INVALID
Google(id="abc") # VALID
Google(provider_id="abc", id="def") # INVALID
```

---------

Co-authored-by: Wils Dawson <wils@arcade-ai.com>
This commit is contained in:
Eric Gustin 2025-01-17 11:38:06 -08:00 committed by GitHub
parent cdd90b4844
commit 6035cde920
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 161 additions and 26 deletions

View file

@ -13,20 +13,24 @@ class ToolAuthorization(BaseModel):
model_config = ConfigDict(frozen=True)
provider_id: str
"""The unique provider ID configured in Arcade."""
provider_id: Optional[str] = None
"""The provider ID configured in Arcade that acts as an alias to well-known configuration."""
provider_type: AuthProviderType
"""The type of the authorization provider."""
id: Optional[str] = None
"""A provider's unique identifier, allowing the tool to specify a specific authorization provider. Recommended for private tools only."""
scopes: Optional[list[str]] = None
"""The scope(s) needed for the authorized action."""
class OAuth2(ToolAuthorization):
"""Marks a tool as requiring OAuth 2.0 authorization."""
provider_type: AuthProviderType = AuthProviderType.oauth2
scopes: Optional[list[str]] = None
"""The scope(s) needed for the authorized action."""
def __init__(self, *, id: str | None, scopes: Optional[list[str]] = None): # noqa: A002
super().__init__(id=id, scopes=scopes, provider_type=AuthProviderType.oauth2)
class Atlassian(OAuth2):
@ -34,56 +38,86 @@ class Atlassian(OAuth2):
provider_id: str = "atlassian"
def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002
super().__init__(id=id, scopes=scopes)
class Discord(OAuth2):
"""Marks a tool as requiring Discord authorization."""
provider_id: str = "discord"
def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002
super().__init__(id=id, scopes=scopes)
class Dropbox(OAuth2):
"""Marks a tool as requiring Dropbox authorization."""
provider_id: str = "dropbox"
def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002
super().__init__(id=id, scopes=scopes)
class Google(OAuth2):
"""Marks a tool as requiring Google authorization."""
provider_id: str = "google"
def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002
super().__init__(id=id, scopes=scopes)
class Slack(OAuth2):
"""Marks a tool as requiring Slack (user token) authorization."""
provider_id: str = "slack"
def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002
super().__init__(id=id, scopes=scopes)
class GitHub(OAuth2):
"""Marks a tool as requiring GitHub App authorization."""
provider_id: str = "github"
def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002
super().__init__(id=id, scopes=scopes)
class X(OAuth2):
"""Marks a tool as requiring X (Twitter) authorization."""
provider_id: str = "x"
def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002
super().__init__(id=id, scopes=scopes)
class LinkedIn(OAuth2):
"""Marks a tool as requiring LinkedIn authorization."""
provider_id: str = "linkedin"
def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002
super().__init__(id=id, scopes=scopes)
class Spotify(OAuth2):
"""Marks a tool as requiring Spotify authorization."""
provider_id: str = "spotify"
def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002
super().__init__(id=id, scopes=scopes)
class Zoom(OAuth2):
"""Marks a tool as requiring Zoom authorization."""
provider_id: str = "zoom"
def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002
super().__init__(id=id, scopes=scopes)

View file

@ -298,6 +298,7 @@ class ToolCatalog(BaseModel):
new_auth_requirement = ToolAuthRequirement(
provider_id=auth_requirement.provider_id,
provider_type=auth_requirement.provider_type,
id=auth_requirement.id,
)
if isinstance(auth_requirement, OAuth2):
new_auth_requirement.oauth2 = OAuth2Requirement(**auth_requirement.model_dump())

View file

@ -74,26 +74,29 @@ class OAuth2Requirement(BaseModel):
"""Indicates that the tool requires OAuth 2.0 authorization."""
scopes: Optional[list[str]] = None
"""The scope(s) needed for authorization, if any."""
"""The scope(s) needed for the authorized action."""
class ToolAuthRequirement(BaseModel):
"""A requirement for authorization to use a tool."""
# Provider ID and Type needed for the Arcade Engine to look up the auth provider.
# Provider ID, Type, and ID needed for the Arcade Engine to look up the auth provider.
# However, the developer generally does not need to set these directly.
# Instead, they will use:
# @tool(requires_auth=Google(scopes=["profile", "email"]))
# or
# client.auth.authorize(provider=AuthProvider.google, scopes=["profile", "email"])
#
# The Arcade SDK translates these into the appropriate provider ID and type.
# The Arcade SDK translates these into the appropriate provider ID (Google) and type (OAuth2).
# The only time the developer will set these is if they are using a custom auth provider.
provider_id: Optional[str] = None
"""A unique provider ID."""
"""The provider ID configured in Arcade that acts as an alias to well-known configuration."""
provider_type: str
"""The provider type."""
"""The type of the authorization provider."""
id: Optional[str] = None
"""A provider's unique identifier, allowing the tool to specify a specific authorization provider. Recommended for private tools only."""
oauth2: Optional[OAuth2Requirement] = None
"""The OAuth 2.0 requirement, if any."""

View file

@ -2,6 +2,7 @@ import asyncio
import pytest
from arcade.core.auth import AuthProviderType, Google
from arcade.sdk import tool
from arcade.sdk.auth import OAuth2
@ -34,18 +35,83 @@ async def test_async_function():
assert result == 3
def test_tool_decorator_with_all_options():
@pytest.mark.parametrize(
"auth_class, auth_kwargs, expected_provider_id, expected_id",
[
(
OAuth2,
{"id": "my_example_provider123", "scopes": ["test_scope", "another.scope"]},
None,
"my_example_provider123",
),
(Google, {"scopes": ["test_scope", "another.scope"]}, "google", None),
(
Google,
{"id": "my_google_provider123", "scopes": ["test_scope", "another.scope"]},
"google",
"my_google_provider123",
),
],
)
def test_tool_decorator_with_auth_success(
auth_class, auth_kwargs, expected_provider_id, expected_id
):
@tool(
name="TestTool",
desc="Test description",
requires_auth=OAuth2(
provider_id="example",
scopes=["test_scope", "another.scope"],
),
requires_auth=auth_class(**auth_kwargs),
)
def test_tool(x, y):
return x + y
assert test_tool.__tool_name__ == "TestTool"
assert test_tool.__tool_description__ == "Test description"
assert test_tool.__tool_requires_auth__.provider_id == expected_provider_id
assert test_tool.__tool_requires_auth__.provider_type == AuthProviderType.oauth2
assert test_tool.__tool_requires_auth__.id == expected_id
assert test_tool.__tool_requires_auth__.scopes == ["test_scope", "another.scope"]
@pytest.mark.parametrize(
"auth_class, auth_kwargs",
[
(OAuth2, {"scopes": ["test_scope", "another.scope"]}),
(
OAuth2,
{"provider_id": "my_example_provider123", "scopes": ["test_scope", "another.scope"]},
),
(
OAuth2,
{
"provider_id": "my_example_provider_id_123",
"id": "my_example_id_123",
"scopes": ["test_scope", "another.scope"],
},
),
(
Google,
{
"provider_id": "my_example_provider_id_123",
"scopes": ["test_scope", "another.scope"],
},
),
(
Google,
{
"provider_id": "my_example_provider_id_123",
"id": "my_example_id_123",
"scopes": ["test_scope", "another.scope"],
},
),
],
)
def test_tool_decorator_with_auth_failure(auth_class, auth_kwargs):
with pytest.raises(TypeError):
@tool(
name="TestTool",
desc="Test description",
requires_auth=auth_class(**auth_kwargs),
)
def test_tool(x, y):
return x + y

View file

@ -49,7 +49,7 @@ def func_with_name_and_description():
@tool(
desc="A function that requires authentication",
requires_auth=OAuth2(
provider_id="example",
id="my_example_provider123",
scopes=["scope1", "scope2"],
),
)
@ -59,7 +59,10 @@ def func_with_auth_requirement():
@tool(
desc="A function that requires Google authorization",
requires_auth=Google(scopes=["https://www.googleapis.com/auth/gmail.readonly"]),
requires_auth=Google(
id="my_google_provider123",
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
),
)
def func_with_google_auth_requirement():
pass
@ -67,7 +70,9 @@ def func_with_google_auth_requirement():
@tool(
desc="A function that requires GitHub authorization",
requires_auth=GitHub(),
requires_auth=GitHub(
id="my_github_provider123",
),
)
def func_with_github_auth_requirement():
pass
@ -75,7 +80,9 @@ def func_with_github_auth_requirement():
@tool(
desc="A function that requires Slack user authorization",
requires_auth=Slack(scopes=["chat:write", "channels:history"]),
requires_auth=Slack(
scopes=["chat:write", "channels:history"],
),
)
def func_with_slack_user_auth_requirement():
pass
@ -83,7 +90,9 @@ def func_with_slack_user_auth_requirement():
@tool(
desc="A function that requires X (Twitter) authorization",
requires_auth=X(scopes=["tweet.write"]),
requires_auth=X(
scopes=["tweet.write"],
),
)
def func_with_x_requirement():
pass
@ -257,8 +266,8 @@ def func_with_complex_return() -> dict[str, str]:
{
"requirements": ToolRequirements(
authorization=ToolAuthRequirement(
provider_id="example",
provider_type="oauth2",
id="my_example_provider123",
oauth2=OAuth2Requirement(
authority="https://example.com/oauth2/auth",
scopes=["scope1", "scope2"],
@ -275,6 +284,7 @@ def func_with_complex_return() -> dict[str, str]:
authorization=ToolAuthRequirement(
provider_id="google",
provider_type="oauth2",
id="my_google_provider123",
oauth2=OAuth2Requirement(
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
),
@ -288,7 +298,10 @@ def func_with_complex_return() -> dict[str, str]:
{
"requirements": ToolRequirements(
authorization=ToolAuthRequirement(
provider_id="github", provider_type="oauth2", oauth2=OAuth2Requirement()
provider_id="github",
provider_type="oauth2",
id="my_github_provider123",
oauth2=OAuth2Requirement(),
)
)
},
@ -309,6 +322,20 @@ def func_with_complex_return() -> dict[str, str]:
},
id="func_with_slack_user_auth_requirement",
),
pytest.param(
func_with_x_requirement,
{
"requirements": ToolRequirements(
authorization=ToolAuthRequirement(
provider_id="x",
provider_type="oauth2",
oauth2=OAuth2Requirement(
scopes=["tweet.write"],
),
)
)
},
),
# Tests on input params
pytest.param(
func_with_non_inferrable_param,

View file

@ -153,11 +153,15 @@
"properties": {
"provider_id": {
"type": "string",
"description": "A unique provider ID."
"description": "The provider ID configured in Arcade that acts as an alias to well-known configuration."
},
"provider_type": {
"type": "string",
"description": "The provider type."
"description": "The type of the authorization provider."
},
"id": {
"type": "string",
"description": "A provider's unique identifier, allowing the tool to specify a specific authorization provider. Recommended for private tools only."
},
"oauth2": {
"type": "object",
@ -172,7 +176,7 @@
"additionalProperties": false
}
},
"required": ["provider_id", "provider_type"],
"required": ["provider_type"],
"additionalProperties": false
}
]