Clean up provider properties (scopes) (#42)

In this PR:
- Rename `scope` to `scopes` so it is more understandable by humans
- DRY up provider structs, it was starting to get silly with so many
providers that just have 1 property called `scopes`

Must go along with this Engine PR:
https://github.com/ArcadeAI/Engine/pull/79
This commit is contained in:
Nate Barbettini 2024-09-17 16:38:51 -07:00 committed by GitHub
parent ce4a9b28a9
commit f4fe8c7892
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 77 additions and 68 deletions

View file

@ -50,7 +50,7 @@ class AuthResource(BaseResource[ClientT]):
body = {
"auth_requirement": {
"provider": auth_provider,
auth_provider: AuthRequest(scope=scopes, authority=authority).model_dump(
auth_provider: AuthRequest(scopes=scopes, authority=authority).model_dump(
exclude_none=True
),
},
@ -200,7 +200,7 @@ class AsyncAuthResource(BaseResource[AsyncArcadeClient]):
body = {
"auth_requirement": {
"provider": auth_provider,
auth_provider: AuthRequest(scope=scopes, authority=authority).model_dump(
auth_provider: AuthRequest(scopes=scopes, authority=authority).model_dump(
exclude_none=True
),
},

View file

@ -30,7 +30,7 @@ class AuthRequest(BaseModel):
authority: AnyUrl | str | None = None
"""The URL of the OAuth 2.0 authorization server."""
scope: list[str]
scopes: list[str]
"""The scope(s) needed for authorization."""

View file

@ -24,10 +24,8 @@ from pydantic_core import PydanticUndefined
from arcade.core.errors import ToolDefinitionError
from arcade.core.schema import (
GoogleRequirement,
InputParameter,
OAuth2Requirement,
SlackUserRequirement,
ToolAuthRequirement,
ToolContext,
ToolDefinition,
@ -44,7 +42,7 @@ from arcade.core.utils import (
snake_to_pascal_case,
)
from arcade.sdk.annotations import Inferrable
from arcade.sdk.auth import Google, OAuth2, SlackUser, ToolAuthorization
from arcade.sdk.auth import BaseOAuth2, ToolAuthorization
InnerWireType = Literal["string", "integer", "number", "boolean", "json"]
WireType = Union[InnerWireType, Literal["array"]]
@ -204,14 +202,8 @@ class ToolCatalog(BaseModel):
new_auth_requirement = ToolAuthRequirement(
provider=auth_requirement.get_provider(),
)
if isinstance(auth_requirement, OAuth2):
if isinstance(auth_requirement, BaseOAuth2):
new_auth_requirement.oauth2 = OAuth2Requirement(**auth_requirement.model_dump())
elif isinstance(auth_requirement, Google):
new_auth_requirement.google = GoogleRequirement(**auth_requirement.model_dump())
elif isinstance(auth_requirement, SlackUser):
new_auth_requirement.slack_user = SlackUserRequirement(
**auth_requirement.model_dump()
)
auth_requirement = new_auth_requirement
return ToolDefinition(

View file

@ -68,25 +68,11 @@ class ToolOutput(BaseModel):
class OAuth2Requirement(BaseModel):
"""Indicates that the tool requires OAuth 2.0 authorization."""
authority: AnyUrl
authority: Optional[AnyUrl] = None
"""The URL of the OAuth 2.0 authorization server."""
scope: Optional[list[str]] = None
"""The scope(s) needed for authorization."""
class GoogleRequirement(BaseModel):
"""Indicates that the tool requires Google authorization."""
scope: Optional[list[str]] = None
"""The scope(s) needed for authorization."""
class SlackUserRequirement(BaseModel):
"""Indicates that the tool requires Slack (user token) authorization."""
scope: Optional[list[str]] = None
"""The scope(s) needed for authorization."""
scopes: Optional[list[str]] = None
"""The scope(s) needed for authorization, if any."""
class ToolAuthRequirement(BaseModel):
@ -98,12 +84,6 @@ class ToolAuthRequirement(BaseModel):
oauth2: Optional[OAuth2Requirement] = None
"""The OAuth 2.0 requirement, if any."""
google: Optional[GoogleRequirement] = None
"""The Google requirement, if any."""
slack_user: Optional[SlackUserRequirement] = None
"""The Slack (user token) requirement, if any."""
class ToolRequirements(BaseModel):
"""The requirements for a tool to run."""

View file

@ -15,38 +15,36 @@ class ToolAuthorization(BaseModel, ABC):
pass
class OAuth2(ToolAuthorization):
class BaseOAuth2(ToolAuthorization):
"""Base class for any provider supporting OAuth 2.0-like authorization."""
authority: Optional[AnyUrl] = None
"""The URL of the OAuth 2.0 authorization server."""
scopes: Optional[list[str]] = None
"""The scope(s) needed for the authorized action."""
class OAuth2(BaseOAuth2):
"""Marks a tool as requiring OAuth 2.0 authorization."""
def get_provider(self) -> str:
return "oauth2"
authority: AnyUrl
"""The URL of the OAuth 2.0 authorization server."""
scope: Optional[list[str]] = None
"""The scope(s) needed for the authorized action."""
class Google(ToolAuthorization):
class Google(BaseOAuth2):
"""Marks a tool as requiring Google authorization."""
def get_provider(self) -> str:
return "google"
scope: Optional[list[str]] = None
"""The scope(s) needed for the authorized action."""
class SlackUser(ToolAuthorization):
class SlackUser(BaseOAuth2):
"""Marks a tool as requiring Slack (user token) authorization."""
def get_provider(self) -> str:
return "slack_user"
scope: Optional[list[str]] = None
"""The scope(s) needed for the authorized action."""
class GitHubApp(ToolAuthorization):
"""Marks a tool as requiring GitHub App authorization."""

View file

@ -40,7 +40,7 @@ def test_tool_decorator_with_all_options():
desc="Test description",
requires_auth=OAuth2(
authority="https://example.com/oauth2/auth",
scope=["test_scope", "another.scope"],
scopes=["test_scope", "another.scope"],
),
)
def test_tool(x, y):
@ -49,4 +49,4 @@ def test_tool_decorator_with_all_options():
assert test_tool.__tool_name__ == "TestTool"
assert test_tool.__tool_description__ == "Test description"
assert str(test_tool.__tool_requires_auth__.authority) == "https://example.com/oauth2/auth"
assert test_tool.__tool_requires_auth__.scope == ["test_scope", "another.scope"]
assert test_tool.__tool_requires_auth__.scopes == ["test_scope", "another.scope"]

View file

@ -5,7 +5,6 @@ import pytest
from arcade.core.catalog import ToolCatalog
from arcade.core.schema import (
GoogleRequirement,
InputParameter,
OAuth2Requirement,
ToolAuthRequirement,
@ -17,7 +16,7 @@ from arcade.core.schema import (
)
from arcade.sdk import tool
from arcade.sdk.annotations import Inferrable
from arcade.sdk.auth import Google, OAuth2
from arcade.sdk.auth import GitHubApp, Google, OAuth2, SlackUser
### Tests on @tool decorator
@ -48,20 +47,36 @@ def func_with_name_and_description():
@tool(
desc="A function that requires authentication",
requires_auth=OAuth2(authority="https://example.com/oauth2/auth", scope=["scope1", "scope2"]),
requires_auth=OAuth2(authority="https://example.com/oauth2/auth", scopes=["scope1", "scope2"]),
)
def func_with_auth_requirement():
pass
@tool(
desc="A function that requires authentication",
requires_auth=Google(scope=["https://www.googleapis.com/auth/gmail.readonly"]),
desc="A function that requires Google authorization",
requires_auth=Google(scopes=["https://www.googleapis.com/auth/gmail.readonly"]),
)
def func_with_google_auth_requirement():
pass
@tool(
desc="A function that requires GitHub authorization",
requires_auth=GitHubApp(),
)
def func_with_github_auth_requirement():
pass
@tool(
desc="A function that requires Slack user authorization",
requires_auth=SlackUser(scopes=["chat:write", "channels:history"]),
)
def func_with_slack_user_auth_requirement():
pass
### Tests on input params
@tool(desc="A function with a non-inferrable input parameter")
def func_with_non_inferrable_param(param1: Annotated[str, "First param", Inferrable(False)]):
@ -229,7 +244,7 @@ def func_with_complex_return() -> dict[str, str]:
provider="oauth2",
oauth2=OAuth2Requirement(
authority="https://example.com/oauth2/auth",
scope=["scope1", "scope2"],
scopes=["scope1", "scope2"],
),
)
)
@ -242,14 +257,39 @@ def func_with_complex_return() -> dict[str, str]:
"requirements": ToolRequirements(
authorization=ToolAuthRequirement(
provider="google",
google=GoogleRequirement(
scope=["https://www.googleapis.com/auth/gmail.readonly"],
oauth2=OAuth2Requirement(
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
),
)
)
},
id="func_with_google_auth_requirement",
),
pytest.param(
func_with_github_auth_requirement,
{
"requirements": ToolRequirements(
authorization=ToolAuthRequirement(
provider="github_app",
)
)
},
id="func_with_github_auth_requirement",
),
pytest.param(
func_with_slack_user_auth_requirement,
{
"requirements": ToolRequirements(
authorization=ToolAuthRequirement(
provider="slack_user",
oauth2=OAuth2Requirement(
scopes=["chat:write", "channels:history"],
),
)
)
},
id="func_with_slack_user_auth_requirement",
),
# Tests on input params
pytest.param(
func_with_non_inferrable_param,

View file

@ -148,11 +148,10 @@
}
}
},
"required": ["authority"],
"additionalProperties": false
}
},
"required": [],
"required": ["provider"],
"additionalProperties": false
}
]

View file

@ -18,7 +18,7 @@ from arcade.sdk.auth import Google
@tool(
requires_auth=Google(
scope=["https://www.googleapis.com/auth/gmail.compose"],
scopes=["https://www.googleapis.com/auth/gmail.compose"],
)
)
async def write_draft(
@ -88,7 +88,7 @@ class DateRange(Enum):
@tool(
requires_auth=Google(
scope=["https://www.googleapis.com/auth/gmail.readonly"],
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
)
)
async def search_emails_by_header(
@ -168,7 +168,7 @@ async def search_emails_by_header(
@tool(
requires_auth=Google(
scope=["https://www.googleapis.com/auth/gmail.readonly"],
scopes=["https://www.googleapis.com/auth/gmail.readonly"],
)
)
async def get_emails(

View file

@ -11,7 +11,7 @@ from arcade.sdk.auth import SlackUser
@tool(
requires_auth=SlackUser(
scope=[
scopes=[
"chat:write",
"im:write",
"users.profile:read",
@ -72,7 +72,7 @@ def format_users(userListResponse: dict) -> str:
@tool(
requires_auth=SlackUser(
scope=[
scopes=[
"chat:write",
"channels:read",
"groups:read",