Clean up provider properties (scopes) (#42)
In this PR: - Rename `scope` to `scopes` so it is more understandable by humans - DRY up provider structs, it was starting to get silly with so many providers that just have 1 property called `scopes` Must go along with this Engine PR: https://github.com/ArcadeAI/Engine/pull/79
This commit is contained in:
parent
ce4a9b28a9
commit
f4fe8c7892
10 changed files with 77 additions and 68 deletions
|
|
@ -50,7 +50,7 @@ class AuthResource(BaseResource[ClientT]):
|
|||
body = {
|
||||
"auth_requirement": {
|
||||
"provider": auth_provider,
|
||||
auth_provider: AuthRequest(scope=scopes, authority=authority).model_dump(
|
||||
auth_provider: AuthRequest(scopes=scopes, authority=authority).model_dump(
|
||||
exclude_none=True
|
||||
),
|
||||
},
|
||||
|
|
@ -200,7 +200,7 @@ class AsyncAuthResource(BaseResource[AsyncArcadeClient]):
|
|||
body = {
|
||||
"auth_requirement": {
|
||||
"provider": auth_provider,
|
||||
auth_provider: AuthRequest(scope=scopes, authority=authority).model_dump(
|
||||
auth_provider: AuthRequest(scopes=scopes, authority=authority).model_dump(
|
||||
exclude_none=True
|
||||
),
|
||||
},
|
||||
|
|
|
|||
|
|
@ -30,7 +30,7 @@ class AuthRequest(BaseModel):
|
|||
authority: AnyUrl | str | None = None
|
||||
"""The URL of the OAuth 2.0 authorization server."""
|
||||
|
||||
scope: list[str]
|
||||
scopes: list[str]
|
||||
"""The scope(s) needed for authorization."""
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -24,10 +24,8 @@ from pydantic_core import PydanticUndefined
|
|||
|
||||
from arcade.core.errors import ToolDefinitionError
|
||||
from arcade.core.schema import (
|
||||
GoogleRequirement,
|
||||
InputParameter,
|
||||
OAuth2Requirement,
|
||||
SlackUserRequirement,
|
||||
ToolAuthRequirement,
|
||||
ToolContext,
|
||||
ToolDefinition,
|
||||
|
|
@ -44,7 +42,7 @@ from arcade.core.utils import (
|
|||
snake_to_pascal_case,
|
||||
)
|
||||
from arcade.sdk.annotations import Inferrable
|
||||
from arcade.sdk.auth import Google, OAuth2, SlackUser, ToolAuthorization
|
||||
from arcade.sdk.auth import BaseOAuth2, ToolAuthorization
|
||||
|
||||
InnerWireType = Literal["string", "integer", "number", "boolean", "json"]
|
||||
WireType = Union[InnerWireType, Literal["array"]]
|
||||
|
|
@ -204,14 +202,8 @@ class ToolCatalog(BaseModel):
|
|||
new_auth_requirement = ToolAuthRequirement(
|
||||
provider=auth_requirement.get_provider(),
|
||||
)
|
||||
if isinstance(auth_requirement, OAuth2):
|
||||
if isinstance(auth_requirement, BaseOAuth2):
|
||||
new_auth_requirement.oauth2 = OAuth2Requirement(**auth_requirement.model_dump())
|
||||
elif isinstance(auth_requirement, Google):
|
||||
new_auth_requirement.google = GoogleRequirement(**auth_requirement.model_dump())
|
||||
elif isinstance(auth_requirement, SlackUser):
|
||||
new_auth_requirement.slack_user = SlackUserRequirement(
|
||||
**auth_requirement.model_dump()
|
||||
)
|
||||
auth_requirement = new_auth_requirement
|
||||
|
||||
return ToolDefinition(
|
||||
|
|
|
|||
|
|
@ -68,25 +68,11 @@ class ToolOutput(BaseModel):
|
|||
class OAuth2Requirement(BaseModel):
|
||||
"""Indicates that the tool requires OAuth 2.0 authorization."""
|
||||
|
||||
authority: AnyUrl
|
||||
authority: Optional[AnyUrl] = None
|
||||
"""The URL of the OAuth 2.0 authorization server."""
|
||||
|
||||
scope: Optional[list[str]] = None
|
||||
"""The scope(s) needed for authorization."""
|
||||
|
||||
|
||||
class GoogleRequirement(BaseModel):
|
||||
"""Indicates that the tool requires Google authorization."""
|
||||
|
||||
scope: Optional[list[str]] = None
|
||||
"""The scope(s) needed for authorization."""
|
||||
|
||||
|
||||
class SlackUserRequirement(BaseModel):
|
||||
"""Indicates that the tool requires Slack (user token) authorization."""
|
||||
|
||||
scope: Optional[list[str]] = None
|
||||
"""The scope(s) needed for authorization."""
|
||||
scopes: Optional[list[str]] = None
|
||||
"""The scope(s) needed for authorization, if any."""
|
||||
|
||||
|
||||
class ToolAuthRequirement(BaseModel):
|
||||
|
|
@ -98,12 +84,6 @@ class ToolAuthRequirement(BaseModel):
|
|||
oauth2: Optional[OAuth2Requirement] = None
|
||||
"""The OAuth 2.0 requirement, if any."""
|
||||
|
||||
google: Optional[GoogleRequirement] = None
|
||||
"""The Google requirement, if any."""
|
||||
|
||||
slack_user: Optional[SlackUserRequirement] = None
|
||||
"""The Slack (user token) requirement, if any."""
|
||||
|
||||
|
||||
class ToolRequirements(BaseModel):
|
||||
"""The requirements for a tool to run."""
|
||||
|
|
|
|||
|
|
@ -15,38 +15,36 @@ class ToolAuthorization(BaseModel, ABC):
|
|||
pass
|
||||
|
||||
|
||||
class OAuth2(ToolAuthorization):
|
||||
class BaseOAuth2(ToolAuthorization):
|
||||
"""Base class for any provider supporting OAuth 2.0-like authorization."""
|
||||
|
||||
authority: Optional[AnyUrl] = None
|
||||
"""The URL of the OAuth 2.0 authorization server."""
|
||||
|
||||
scopes: Optional[list[str]] = None
|
||||
"""The scope(s) needed for the authorized action."""
|
||||
|
||||
|
||||
class OAuth2(BaseOAuth2):
|
||||
"""Marks a tool as requiring OAuth 2.0 authorization."""
|
||||
|
||||
def get_provider(self) -> str:
|
||||
return "oauth2"
|
||||
|
||||
authority: AnyUrl
|
||||
"""The URL of the OAuth 2.0 authorization server."""
|
||||
|
||||
scope: Optional[list[str]] = None
|
||||
"""The scope(s) needed for the authorized action."""
|
||||
|
||||
|
||||
class Google(ToolAuthorization):
|
||||
class Google(BaseOAuth2):
|
||||
"""Marks a tool as requiring Google authorization."""
|
||||
|
||||
def get_provider(self) -> str:
|
||||
return "google"
|
||||
|
||||
scope: Optional[list[str]] = None
|
||||
"""The scope(s) needed for the authorized action."""
|
||||
|
||||
|
||||
class SlackUser(ToolAuthorization):
|
||||
class SlackUser(BaseOAuth2):
|
||||
"""Marks a tool as requiring Slack (user token) authorization."""
|
||||
|
||||
def get_provider(self) -> str:
|
||||
return "slack_user"
|
||||
|
||||
scope: Optional[list[str]] = None
|
||||
"""The scope(s) needed for the authorized action."""
|
||||
|
||||
|
||||
class GitHubApp(ToolAuthorization):
|
||||
"""Marks a tool as requiring GitHub App authorization."""
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ def test_tool_decorator_with_all_options():
|
|||
desc="Test description",
|
||||
requires_auth=OAuth2(
|
||||
authority="https://example.com/oauth2/auth",
|
||||
scope=["test_scope", "another.scope"],
|
||||
scopes=["test_scope", "another.scope"],
|
||||
),
|
||||
)
|
||||
def test_tool(x, y):
|
||||
|
|
@ -49,4 +49,4 @@ def test_tool_decorator_with_all_options():
|
|||
assert test_tool.__tool_name__ == "TestTool"
|
||||
assert test_tool.__tool_description__ == "Test description"
|
||||
assert str(test_tool.__tool_requires_auth__.authority) == "https://example.com/oauth2/auth"
|
||||
assert test_tool.__tool_requires_auth__.scope == ["test_scope", "another.scope"]
|
||||
assert test_tool.__tool_requires_auth__.scopes == ["test_scope", "another.scope"]
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@ import pytest
|
|||
|
||||
from arcade.core.catalog import ToolCatalog
|
||||
from arcade.core.schema import (
|
||||
GoogleRequirement,
|
||||
InputParameter,
|
||||
OAuth2Requirement,
|
||||
ToolAuthRequirement,
|
||||
|
|
@ -17,7 +16,7 @@ from arcade.core.schema import (
|
|||
)
|
||||
from arcade.sdk import tool
|
||||
from arcade.sdk.annotations import Inferrable
|
||||
from arcade.sdk.auth import Google, OAuth2
|
||||
from arcade.sdk.auth import GitHubApp, Google, OAuth2, SlackUser
|
||||
|
||||
|
||||
### Tests on @tool decorator
|
||||
|
|
@ -48,20 +47,36 @@ def func_with_name_and_description():
|
|||
|
||||
@tool(
|
||||
desc="A function that requires authentication",
|
||||
requires_auth=OAuth2(authority="https://example.com/oauth2/auth", scope=["scope1", "scope2"]),
|
||||
requires_auth=OAuth2(authority="https://example.com/oauth2/auth", scopes=["scope1", "scope2"]),
|
||||
)
|
||||
def func_with_auth_requirement():
|
||||
pass
|
||||
|
||||
|
||||
@tool(
|
||||
desc="A function that requires authentication",
|
||||
requires_auth=Google(scope=["https://www.googleapis.com/auth/gmail.readonly"]),
|
||||
desc="A function that requires Google authorization",
|
||||
requires_auth=Google(scopes=["https://www.googleapis.com/auth/gmail.readonly"]),
|
||||
)
|
||||
def func_with_google_auth_requirement():
|
||||
pass
|
||||
|
||||
|
||||
@tool(
|
||||
desc="A function that requires GitHub authorization",
|
||||
requires_auth=GitHubApp(),
|
||||
)
|
||||
def func_with_github_auth_requirement():
|
||||
pass
|
||||
|
||||
|
||||
@tool(
|
||||
desc="A function that requires Slack user authorization",
|
||||
requires_auth=SlackUser(scopes=["chat:write", "channels:history"]),
|
||||
)
|
||||
def func_with_slack_user_auth_requirement():
|
||||
pass
|
||||
|
||||
|
||||
### Tests on input params
|
||||
@tool(desc="A function with a non-inferrable input parameter")
|
||||
def func_with_non_inferrable_param(param1: Annotated[str, "First param", Inferrable(False)]):
|
||||
|
|
@ -229,7 +244,7 @@ def func_with_complex_return() -> dict[str, str]:
|
|||
provider="oauth2",
|
||||
oauth2=OAuth2Requirement(
|
||||
authority="https://example.com/oauth2/auth",
|
||||
scope=["scope1", "scope2"],
|
||||
scopes=["scope1", "scope2"],
|
||||
),
|
||||
)
|
||||
)
|
||||
|
|
@ -242,14 +257,39 @@ def func_with_complex_return() -> dict[str, str]:
|
|||
"requirements": ToolRequirements(
|
||||
authorization=ToolAuthRequirement(
|
||||
provider="google",
|
||||
google=GoogleRequirement(
|
||||
scope=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
oauth2=OAuth2Requirement(
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
),
|
||||
)
|
||||
)
|
||||
},
|
||||
id="func_with_google_auth_requirement",
|
||||
),
|
||||
pytest.param(
|
||||
func_with_github_auth_requirement,
|
||||
{
|
||||
"requirements": ToolRequirements(
|
||||
authorization=ToolAuthRequirement(
|
||||
provider="github_app",
|
||||
)
|
||||
)
|
||||
},
|
||||
id="func_with_github_auth_requirement",
|
||||
),
|
||||
pytest.param(
|
||||
func_with_slack_user_auth_requirement,
|
||||
{
|
||||
"requirements": ToolRequirements(
|
||||
authorization=ToolAuthRequirement(
|
||||
provider="slack_user",
|
||||
oauth2=OAuth2Requirement(
|
||||
scopes=["chat:write", "channels:history"],
|
||||
),
|
||||
)
|
||||
)
|
||||
},
|
||||
id="func_with_slack_user_auth_requirement",
|
||||
),
|
||||
# Tests on input params
|
||||
pytest.param(
|
||||
func_with_non_inferrable_param,
|
||||
|
|
|
|||
|
|
@ -148,11 +148,10 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"required": ["authority"],
|
||||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
"required": [],
|
||||
"required": ["provider"],
|
||||
"additionalProperties": false
|
||||
}
|
||||
]
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from arcade.sdk.auth import Google
|
|||
|
||||
@tool(
|
||||
requires_auth=Google(
|
||||
scope=["https://www.googleapis.com/auth/gmail.compose"],
|
||||
scopes=["https://www.googleapis.com/auth/gmail.compose"],
|
||||
)
|
||||
)
|
||||
async def write_draft(
|
||||
|
|
@ -88,7 +88,7 @@ class DateRange(Enum):
|
|||
|
||||
@tool(
|
||||
requires_auth=Google(
|
||||
scope=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
)
|
||||
)
|
||||
async def search_emails_by_header(
|
||||
|
|
@ -168,7 +168,7 @@ async def search_emails_by_header(
|
|||
|
||||
@tool(
|
||||
requires_auth=Google(
|
||||
scope=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
)
|
||||
)
|
||||
async def get_emails(
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from arcade.sdk.auth import SlackUser
|
|||
|
||||
@tool(
|
||||
requires_auth=SlackUser(
|
||||
scope=[
|
||||
scopes=[
|
||||
"chat:write",
|
||||
"im:write",
|
||||
"users.profile:read",
|
||||
|
|
@ -72,7 +72,7 @@ def format_users(userListResponse: dict) -> str:
|
|||
|
||||
@tool(
|
||||
requires_auth=SlackUser(
|
||||
scope=[
|
||||
scopes=[
|
||||
"chat:write",
|
||||
"channels:read",
|
||||
"groups:read",
|
||||
|
|
|
|||
Loading…
Reference in a new issue