Update auth provider_id (#173)
# PR Description Well known providers (Google, X, Dropbox, etc.) can optionally have an `id` in addition to their hardcoded `provider_id`. For non well known providers, they must provide an `id`, and the `provider_id` is hardcoded as `None`. ```python OAuth2() # INVALID OAuth2(provider_id="abc") # INVALID OAuth2(id="abc") # VALID OAuth2(provider_id="abc", id="def") # INVALID ``` ```python Google() # VALID Google(provider_id="abc") # INVALID Google(id="abc") # VALID Google(provider_id="abc", id="def") # INVALID ``` --------- Co-authored-by: Wils Dawson <wils@arcade-ai.com>
This commit is contained in:
parent
cdd90b4844
commit
6035cde920
6 changed files with 161 additions and 26 deletions
|
|
@ -13,20 +13,24 @@ class ToolAuthorization(BaseModel):
|
|||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
provider_id: str
|
||||
"""The unique provider ID configured in Arcade."""
|
||||
provider_id: Optional[str] = None
|
||||
"""The provider ID configured in Arcade that acts as an alias to well-known configuration."""
|
||||
|
||||
provider_type: AuthProviderType
|
||||
"""The type of the authorization provider."""
|
||||
|
||||
id: Optional[str] = None
|
||||
"""A provider's unique identifier, allowing the tool to specify a specific authorization provider. Recommended for private tools only."""
|
||||
|
||||
scopes: Optional[list[str]] = None
|
||||
"""The scope(s) needed for the authorized action."""
|
||||
|
||||
|
||||
class OAuth2(ToolAuthorization):
|
||||
"""Marks a tool as requiring OAuth 2.0 authorization."""
|
||||
|
||||
provider_type: AuthProviderType = AuthProviderType.oauth2
|
||||
|
||||
scopes: Optional[list[str]] = None
|
||||
"""The scope(s) needed for the authorized action."""
|
||||
def __init__(self, *, id: str | None, scopes: Optional[list[str]] = None): # noqa: A002
|
||||
super().__init__(id=id, scopes=scopes, provider_type=AuthProviderType.oauth2)
|
||||
|
||||
|
||||
class Atlassian(OAuth2):
|
||||
|
|
@ -34,56 +38,86 @@ class Atlassian(OAuth2):
|
|||
|
||||
provider_id: str = "atlassian"
|
||||
|
||||
def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002
|
||||
super().__init__(id=id, scopes=scopes)
|
||||
|
||||
|
||||
class Discord(OAuth2):
|
||||
"""Marks a tool as requiring Discord authorization."""
|
||||
|
||||
provider_id: str = "discord"
|
||||
|
||||
def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002
|
||||
super().__init__(id=id, scopes=scopes)
|
||||
|
||||
|
||||
class Dropbox(OAuth2):
|
||||
"""Marks a tool as requiring Dropbox authorization."""
|
||||
|
||||
provider_id: str = "dropbox"
|
||||
|
||||
def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002
|
||||
super().__init__(id=id, scopes=scopes)
|
||||
|
||||
|
||||
class Google(OAuth2):
|
||||
"""Marks a tool as requiring Google authorization."""
|
||||
|
||||
provider_id: str = "google"
|
||||
|
||||
def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002
|
||||
super().__init__(id=id, scopes=scopes)
|
||||
|
||||
|
||||
class Slack(OAuth2):
|
||||
"""Marks a tool as requiring Slack (user token) authorization."""
|
||||
|
||||
provider_id: str = "slack"
|
||||
|
||||
def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002
|
||||
super().__init__(id=id, scopes=scopes)
|
||||
|
||||
|
||||
class GitHub(OAuth2):
|
||||
"""Marks a tool as requiring GitHub App authorization."""
|
||||
|
||||
provider_id: str = "github"
|
||||
|
||||
def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002
|
||||
super().__init__(id=id, scopes=scopes)
|
||||
|
||||
|
||||
class X(OAuth2):
|
||||
"""Marks a tool as requiring X (Twitter) authorization."""
|
||||
|
||||
provider_id: str = "x"
|
||||
|
||||
def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002
|
||||
super().__init__(id=id, scopes=scopes)
|
||||
|
||||
|
||||
class LinkedIn(OAuth2):
|
||||
"""Marks a tool as requiring LinkedIn authorization."""
|
||||
|
||||
provider_id: str = "linkedin"
|
||||
|
||||
def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002
|
||||
super().__init__(id=id, scopes=scopes)
|
||||
|
||||
|
||||
class Spotify(OAuth2):
|
||||
"""Marks a tool as requiring Spotify authorization."""
|
||||
|
||||
provider_id: str = "spotify"
|
||||
|
||||
def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002
|
||||
super().__init__(id=id, scopes=scopes)
|
||||
|
||||
|
||||
class Zoom(OAuth2):
|
||||
"""Marks a tool as requiring Zoom authorization."""
|
||||
|
||||
provider_id: str = "zoom"
|
||||
|
||||
def __init__(self, *, id: Optional[str] = None, scopes: Optional[list[str]] = None): # noqa: A002
|
||||
super().__init__(id=id, scopes=scopes)
|
||||
|
|
|
|||
|
|
@ -298,6 +298,7 @@ class ToolCatalog(BaseModel):
|
|||
new_auth_requirement = ToolAuthRequirement(
|
||||
provider_id=auth_requirement.provider_id,
|
||||
provider_type=auth_requirement.provider_type,
|
||||
id=auth_requirement.id,
|
||||
)
|
||||
if isinstance(auth_requirement, OAuth2):
|
||||
new_auth_requirement.oauth2 = OAuth2Requirement(**auth_requirement.model_dump())
|
||||
|
|
|
|||
|
|
@ -74,26 +74,29 @@ class OAuth2Requirement(BaseModel):
|
|||
"""Indicates that the tool requires OAuth 2.0 authorization."""
|
||||
|
||||
scopes: Optional[list[str]] = None
|
||||
"""The scope(s) needed for authorization, if any."""
|
||||
"""The scope(s) needed for the authorized action."""
|
||||
|
||||
|
||||
class ToolAuthRequirement(BaseModel):
|
||||
"""A requirement for authorization to use a tool."""
|
||||
|
||||
# Provider ID and Type needed for the Arcade Engine to look up the auth provider.
|
||||
# Provider ID, Type, and ID needed for the Arcade Engine to look up the auth provider.
|
||||
# However, the developer generally does not need to set these directly.
|
||||
# Instead, they will use:
|
||||
# @tool(requires_auth=Google(scopes=["profile", "email"]))
|
||||
# or
|
||||
# client.auth.authorize(provider=AuthProvider.google, scopes=["profile", "email"])
|
||||
#
|
||||
# The Arcade SDK translates these into the appropriate provider ID and type.
|
||||
# The Arcade SDK translates these into the appropriate provider ID (Google) and type (OAuth2).
|
||||
# The only time the developer will set these is if they are using a custom auth provider.
|
||||
provider_id: Optional[str] = None
|
||||
"""A unique provider ID."""
|
||||
"""The provider ID configured in Arcade that acts as an alias to well-known configuration."""
|
||||
|
||||
provider_type: str
|
||||
"""The provider type."""
|
||||
"""The type of the authorization provider."""
|
||||
|
||||
id: Optional[str] = None
|
||||
"""A provider's unique identifier, allowing the tool to specify a specific authorization provider. Recommended for private tools only."""
|
||||
|
||||
oauth2: Optional[OAuth2Requirement] = None
|
||||
"""The OAuth 2.0 requirement, if any."""
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import asyncio
|
|||
|
||||
import pytest
|
||||
|
||||
from arcade.core.auth import AuthProviderType, Google
|
||||
from arcade.sdk import tool
|
||||
from arcade.sdk.auth import OAuth2
|
||||
|
||||
|
|
@ -34,18 +35,83 @@ async def test_async_function():
|
|||
assert result == 3
|
||||
|
||||
|
||||
def test_tool_decorator_with_all_options():
|
||||
@pytest.mark.parametrize(
|
||||
"auth_class, auth_kwargs, expected_provider_id, expected_id",
|
||||
[
|
||||
(
|
||||
OAuth2,
|
||||
{"id": "my_example_provider123", "scopes": ["test_scope", "another.scope"]},
|
||||
None,
|
||||
"my_example_provider123",
|
||||
),
|
||||
(Google, {"scopes": ["test_scope", "another.scope"]}, "google", None),
|
||||
(
|
||||
Google,
|
||||
{"id": "my_google_provider123", "scopes": ["test_scope", "another.scope"]},
|
||||
"google",
|
||||
"my_google_provider123",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_tool_decorator_with_auth_success(
|
||||
auth_class, auth_kwargs, expected_provider_id, expected_id
|
||||
):
|
||||
@tool(
|
||||
name="TestTool",
|
||||
desc="Test description",
|
||||
requires_auth=OAuth2(
|
||||
provider_id="example",
|
||||
scopes=["test_scope", "another.scope"],
|
||||
),
|
||||
requires_auth=auth_class(**auth_kwargs),
|
||||
)
|
||||
def test_tool(x, y):
|
||||
return x + y
|
||||
|
||||
assert test_tool.__tool_name__ == "TestTool"
|
||||
assert test_tool.__tool_description__ == "Test description"
|
||||
assert test_tool.__tool_requires_auth__.provider_id == expected_provider_id
|
||||
assert test_tool.__tool_requires_auth__.provider_type == AuthProviderType.oauth2
|
||||
assert test_tool.__tool_requires_auth__.id == expected_id
|
||||
assert test_tool.__tool_requires_auth__.scopes == ["test_scope", "another.scope"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"auth_class, auth_kwargs",
|
||||
[
|
||||
(OAuth2, {"scopes": ["test_scope", "another.scope"]}),
|
||||
(
|
||||
OAuth2,
|
||||
{"provider_id": "my_example_provider123", "scopes": ["test_scope", "another.scope"]},
|
||||
),
|
||||
(
|
||||
OAuth2,
|
||||
{
|
||||
"provider_id": "my_example_provider_id_123",
|
||||
"id": "my_example_id_123",
|
||||
"scopes": ["test_scope", "another.scope"],
|
||||
},
|
||||
),
|
||||
(
|
||||
Google,
|
||||
{
|
||||
"provider_id": "my_example_provider_id_123",
|
||||
"scopes": ["test_scope", "another.scope"],
|
||||
},
|
||||
),
|
||||
(
|
||||
Google,
|
||||
{
|
||||
"provider_id": "my_example_provider_id_123",
|
||||
"id": "my_example_id_123",
|
||||
"scopes": ["test_scope", "another.scope"],
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_tool_decorator_with_auth_failure(auth_class, auth_kwargs):
|
||||
with pytest.raises(TypeError):
|
||||
|
||||
@tool(
|
||||
name="TestTool",
|
||||
desc="Test description",
|
||||
requires_auth=auth_class(**auth_kwargs),
|
||||
)
|
||||
def test_tool(x, y):
|
||||
return x + y
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ def func_with_name_and_description():
|
|||
@tool(
|
||||
desc="A function that requires authentication",
|
||||
requires_auth=OAuth2(
|
||||
provider_id="example",
|
||||
id="my_example_provider123",
|
||||
scopes=["scope1", "scope2"],
|
||||
),
|
||||
)
|
||||
|
|
@ -59,7 +59,10 @@ def func_with_auth_requirement():
|
|||
|
||||
@tool(
|
||||
desc="A function that requires Google authorization",
|
||||
requires_auth=Google(scopes=["https://www.googleapis.com/auth/gmail.readonly"]),
|
||||
requires_auth=Google(
|
||||
id="my_google_provider123",
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
),
|
||||
)
|
||||
def func_with_google_auth_requirement():
|
||||
pass
|
||||
|
|
@ -67,7 +70,9 @@ def func_with_google_auth_requirement():
|
|||
|
||||
@tool(
|
||||
desc="A function that requires GitHub authorization",
|
||||
requires_auth=GitHub(),
|
||||
requires_auth=GitHub(
|
||||
id="my_github_provider123",
|
||||
),
|
||||
)
|
||||
def func_with_github_auth_requirement():
|
||||
pass
|
||||
|
|
@ -75,7 +80,9 @@ def func_with_github_auth_requirement():
|
|||
|
||||
@tool(
|
||||
desc="A function that requires Slack user authorization",
|
||||
requires_auth=Slack(scopes=["chat:write", "channels:history"]),
|
||||
requires_auth=Slack(
|
||||
scopes=["chat:write", "channels:history"],
|
||||
),
|
||||
)
|
||||
def func_with_slack_user_auth_requirement():
|
||||
pass
|
||||
|
|
@ -83,7 +90,9 @@ def func_with_slack_user_auth_requirement():
|
|||
|
||||
@tool(
|
||||
desc="A function that requires X (Twitter) authorization",
|
||||
requires_auth=X(scopes=["tweet.write"]),
|
||||
requires_auth=X(
|
||||
scopes=["tweet.write"],
|
||||
),
|
||||
)
|
||||
def func_with_x_requirement():
|
||||
pass
|
||||
|
|
@ -257,8 +266,8 @@ def func_with_complex_return() -> dict[str, str]:
|
|||
{
|
||||
"requirements": ToolRequirements(
|
||||
authorization=ToolAuthRequirement(
|
||||
provider_id="example",
|
||||
provider_type="oauth2",
|
||||
id="my_example_provider123",
|
||||
oauth2=OAuth2Requirement(
|
||||
authority="https://example.com/oauth2/auth",
|
||||
scopes=["scope1", "scope2"],
|
||||
|
|
@ -275,6 +284,7 @@ def func_with_complex_return() -> dict[str, str]:
|
|||
authorization=ToolAuthRequirement(
|
||||
provider_id="google",
|
||||
provider_type="oauth2",
|
||||
id="my_google_provider123",
|
||||
oauth2=OAuth2Requirement(
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
),
|
||||
|
|
@ -288,7 +298,10 @@ def func_with_complex_return() -> dict[str, str]:
|
|||
{
|
||||
"requirements": ToolRequirements(
|
||||
authorization=ToolAuthRequirement(
|
||||
provider_id="github", provider_type="oauth2", oauth2=OAuth2Requirement()
|
||||
provider_id="github",
|
||||
provider_type="oauth2",
|
||||
id="my_github_provider123",
|
||||
oauth2=OAuth2Requirement(),
|
||||
)
|
||||
)
|
||||
},
|
||||
|
|
@ -309,6 +322,20 @@ def func_with_complex_return() -> dict[str, str]:
|
|||
},
|
||||
id="func_with_slack_user_auth_requirement",
|
||||
),
|
||||
pytest.param(
|
||||
func_with_x_requirement,
|
||||
{
|
||||
"requirements": ToolRequirements(
|
||||
authorization=ToolAuthRequirement(
|
||||
provider_id="x",
|
||||
provider_type="oauth2",
|
||||
oauth2=OAuth2Requirement(
|
||||
scopes=["tweet.write"],
|
||||
),
|
||||
)
|
||||
)
|
||||
},
|
||||
),
|
||||
# Tests on input params
|
||||
pytest.param(
|
||||
func_with_non_inferrable_param,
|
||||
|
|
|
|||
|
|
@ -153,11 +153,15 @@
|
|||
"properties": {
|
||||
"provider_id": {
|
||||
"type": "string",
|
||||
"description": "A unique provider ID."
|
||||
"description": "The provider ID configured in Arcade that acts as an alias to well-known configuration."
|
||||
},
|
||||
"provider_type": {
|
||||
"type": "string",
|
||||
"description": "The provider type."
|
||||
"description": "The type of the authorization provider."
|
||||
},
|
||||
"id": {
|
||||
"type": "string",
|
||||
"description": "A provider's unique identifier, allowing the tool to specify a specific authorization provider. Recommended for private tools only."
|
||||
},
|
||||
"oauth2": {
|
||||
"type": "object",
|
||||
|
|
@ -172,7 +176,7 @@
|
|||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
"required": ["provider_id", "provider_type"],
|
||||
"required": ["provider_type"],
|
||||
"additionalProperties": false
|
||||
}
|
||||
]
|
||||
|
|
|
|||
Loading…
Reference in a new issue