Update requires_authorization and other naming (#39)
Corresponds to Engine PR: https://github.com/ArcadeAI/Engine/pull/73
This commit is contained in:
parent
db948125d5
commit
6d854b0110
6 changed files with 67 additions and 28 deletions
|
|
@ -201,6 +201,8 @@ def chat(
|
|||
"bold blue",
|
||||
),
|
||||
)
|
||||
if stream:
|
||||
chat_header.append(" (streaming)")
|
||||
console.print(chat_header)
|
||||
|
||||
# Try to hit /health endpoint on engine and warn if it is down
|
||||
|
|
|
|||
|
|
@ -67,7 +67,8 @@ class AuthResource(BaseResource[ClientT]):
|
|||
def status(
|
||||
self, auth_id_or_response: Union[str, AuthResponse], scopes: list[str] | None = None
|
||||
) -> AuthResponse:
|
||||
"""Poll for the status of an authorization
|
||||
"""
|
||||
Poll for the status of an authorization
|
||||
|
||||
Polls using either the authorization ID or the data returned from the authorize method.
|
||||
|
||||
|
|
@ -85,7 +86,7 @@ class AuthResource(BaseResource[ClientT]):
|
|||
data = self._client._execute_request( # type: ignore[attr-defined]
|
||||
"GET",
|
||||
f"{self._base_path}/status",
|
||||
params={"authorizationID": auth_id, "scopes": " ".join(scopes) if scopes else None},
|
||||
params={"authorizationId": auth_id, "scopes": " ".join(scopes) if scopes else None},
|
||||
)
|
||||
return AuthResponse(**data)
|
||||
|
||||
|
|
@ -129,7 +130,7 @@ class ToolResource(BaseResource[ClientT]):
|
|||
data = self._client._execute_request( # type: ignore[attr-defined]
|
||||
"GET",
|
||||
f"{self._base_path}/definition",
|
||||
params={"director_id": director_id, "tool_id": tool_id},
|
||||
params={"directorId": director_id, "toolId": tool_id},
|
||||
)
|
||||
return ToolDefinition(**data)
|
||||
|
||||
|
|
@ -214,10 +215,29 @@ class AsyncAuthResource(BaseResource[AsyncArcadeClient]):
|
|||
)
|
||||
return AuthResponse(**data)
|
||||
|
||||
async def status(self, auth_id: str) -> AuthResponse:
|
||||
"""Poll for the status of an authorization asynchronously"""
|
||||
async def status(
|
||||
self, auth_id_or_response: Union[str, AuthResponse], scopes: list[str] | None = None
|
||||
) -> AuthResponse:
|
||||
"""
|
||||
Poll for the status of an authorization asynchronously
|
||||
|
||||
Polls using either the authorization ID or the data returned from the authorize method.
|
||||
|
||||
Example:
|
||||
auth_response = await client.auth.authorize(...)
|
||||
auth_status = await client.auth.poll_authorization(auth_response)
|
||||
auth_status = await client.auth.poll_authorization("auth_123", ["scope1", "scope2"])
|
||||
"""
|
||||
if isinstance(auth_id_or_response, AuthResponse):
|
||||
auth_id = auth_id_or_response.auth_id
|
||||
scopes = auth_id_or_response.scopes
|
||||
else:
|
||||
auth_id = auth_id_or_response
|
||||
|
||||
data = await self._client._execute_request( # type: ignore[attr-defined]
|
||||
"GET", f"{self._base_path}/status", params={"authorizationID": auth_id}
|
||||
"GET",
|
||||
f"{self._base_path}/status",
|
||||
params={"authorizationId": auth_id, "scopes": " ".join(scopes) if scopes else None},
|
||||
)
|
||||
return AuthResponse(**data)
|
||||
|
||||
|
|
@ -255,7 +275,7 @@ class AsyncToolResource(BaseResource[AsyncArcadeClient]):
|
|||
data = await self._client._execute_request( # type: ignore[attr-defined]
|
||||
"GET",
|
||||
f"{self._base_path}/definition",
|
||||
params={"director_id": director_id, "tool_id": tool_id},
|
||||
params={"directorId": director_id, "toolId": tool_id},
|
||||
)
|
||||
return ToolDefinition(**data)
|
||||
|
||||
|
|
|
|||
|
|
@ -52,14 +52,14 @@ class HealthCheckResponse(BaseModel):
|
|||
class AuthResponse(BaseModel):
|
||||
"""Response from an authorization request."""
|
||||
|
||||
auth_id: str = Field(alias="authorizationID")
|
||||
auth_id: str = Field(alias="authorization_id")
|
||||
"""The ID of the authorization request"""
|
||||
|
||||
scopes: list[str]
|
||||
"""The scope(s) requested in the authorization request"""
|
||||
|
||||
# TODO: Use AnyUrl?
|
||||
auth_url: str | None = Field(None, alias="authorizationURL")
|
||||
auth_url: str | None = Field(None, alias="authorization_url")
|
||||
"""The URL for the authorization"""
|
||||
|
||||
status: AuthStatus
|
||||
|
|
|
|||
|
|
@ -159,6 +159,19 @@ class ToolCallError(BaseModel):
|
|||
"""The number of milliseconds (if any) to wait before retrying the tool call."""
|
||||
|
||||
|
||||
class ToolCallRequiresAuthorization(BaseModel):
|
||||
"""The authorization requirements for the tool invocation."""
|
||||
|
||||
authorization_url: str | None = None
|
||||
"""The URL to redirect the user to for authorization."""
|
||||
authorization_id: str | None = None
|
||||
"""The ID for checking the status of the authorization."""
|
||||
scopes: list[str] | None = None
|
||||
"""The scopes that are required for authorization."""
|
||||
status: str | None = None
|
||||
"""The status of the authorization."""
|
||||
|
||||
|
||||
class ToolCallOutput(BaseModel):
|
||||
"""The output of a tool invocation."""
|
||||
|
||||
|
|
@ -166,6 +179,8 @@ class ToolCallOutput(BaseModel):
|
|||
"""The value returned by the tool."""
|
||||
error: ToolCallError | None = None
|
||||
"""The error that occurred during the tool invocation."""
|
||||
requires_authorization: ToolCallRequiresAuthorization | None = None
|
||||
"""The authorization requirements for the tool invocation."""
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
|
|
|
|||
|
|
@ -17,9 +17,9 @@ from arcade.core.schema import ToolDefinition
|
|||
|
||||
AUTH_RESPONSE_DATA = {
|
||||
"auth_id": "auth_123",
|
||||
"auth_url": "https://example.com/auth",
|
||||
"authorization_url": "https://example.com/auth",
|
||||
"status": "pending",
|
||||
"authorizationID": "auth_123",
|
||||
"authorization_id": "auth_123",
|
||||
"scopes": ["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
}
|
||||
|
||||
|
|
@ -46,8 +46,8 @@ TOOL_DEFINITION_DATA = {
|
|||
}
|
||||
|
||||
TOOL_AUTHORIZE_RESPONSE_DATA = {
|
||||
"authorizationID": "auth_456",
|
||||
"authorizationURL": "https://example.com/auth",
|
||||
"authorization_id": "auth_456",
|
||||
"authorization_url": "https://example.com/auth",
|
||||
"scopes": ["scope1", "scope2"],
|
||||
"status": "pending",
|
||||
}
|
||||
|
|
|
|||
|
|
@ -81,26 +81,28 @@
|
|||
"requires_authorization": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": {
|
||||
"authorization_url": {
|
||||
"type": "string",
|
||||
"description": "A message that can be shown to the user or AI model that explains the authorization requirement"
|
||||
"format": "uri",
|
||||
"description": "The URL to redirect the user to for authorization"
|
||||
},
|
||||
"oauth2": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"format": "uri"
|
||||
},
|
||||
"scope": {
|
||||
"type": "string"
|
||||
}
|
||||
"authorization_id": {
|
||||
"type": "string",
|
||||
"description": "The ID for checking the status of the authorization"
|
||||
},
|
||||
"scopes": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
},
|
||||
"required": ["url"],
|
||||
"additionalProperties": false
|
||||
"description": "The scopes that are required for authorization"
|
||||
},
|
||||
"status": {
|
||||
"type": "string",
|
||||
"description": "The status of the authorization"
|
||||
}
|
||||
},
|
||||
"required": ["message"],
|
||||
"required": ["status"],
|
||||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
|
|
|
|||
Loading…
Reference in a new issue