From 6d854b01101bd1e7c71080f12f361a2c8bb515d9 Mon Sep 17 00:00:00 2001 From: Nate Barbettini Date: Thu, 19 Sep 2024 09:51:28 -0700 Subject: [PATCH] Update requires_authorization and other naming (#39) Corresponds to Engine PR: https://github.com/ArcadeAI/Engine/pull/73 --- arcade/arcade/cli/main.py | 2 ++ arcade/arcade/client/client.py | 34 +++++++++++++++---- arcade/arcade/client/schema.py | 4 +-- arcade/arcade/core/schema.py | 15 ++++++++ arcade/tests/client/test_client.py | 8 ++--- .../preview/invoke_tool_response.schema.jsonc | 32 +++++++++-------- 6 files changed, 67 insertions(+), 28 deletions(-) diff --git a/arcade/arcade/cli/main.py b/arcade/arcade/cli/main.py index 317d0177..f9358aa1 100644 --- a/arcade/arcade/cli/main.py +++ b/arcade/arcade/cli/main.py @@ -201,6 +201,8 @@ def chat( "bold blue", ), ) + if stream: + chat_header.append(" (streaming)") console.print(chat_header) # Try to hit /health endpoint on engine and warn if it is down diff --git a/arcade/arcade/client/client.py b/arcade/arcade/client/client.py index 110e574f..346a3253 100644 --- a/arcade/arcade/client/client.py +++ b/arcade/arcade/client/client.py @@ -67,7 +67,8 @@ class AuthResource(BaseResource[ClientT]): def status( self, auth_id_or_response: Union[str, AuthResponse], scopes: list[str] | None = None ) -> AuthResponse: - """Poll for the status of an authorization + """ + Poll for the status of an authorization Polls using either the authorization ID or the data returned from the authorize method. @@ -85,7 +86,7 @@ class AuthResource(BaseResource[ClientT]): data = self._client._execute_request( # type: ignore[attr-defined] "GET", f"{self._base_path}/status", - params={"authorizationID": auth_id, "scopes": " ".join(scopes) if scopes else None}, + params={"authorizationId": auth_id, "scopes": " ".join(scopes) if scopes else None}, ) return AuthResponse(**data) @@ -129,7 +130,7 @@ class ToolResource(BaseResource[ClientT]): data = self._client._execute_request( # type: ignore[attr-defined] "GET", f"{self._base_path}/definition", - params={"director_id": director_id, "tool_id": tool_id}, + params={"directorId": director_id, "toolId": tool_id}, ) return ToolDefinition(**data) @@ -214,10 +215,29 @@ class AsyncAuthResource(BaseResource[AsyncArcadeClient]): ) return AuthResponse(**data) - async def status(self, auth_id: str) -> AuthResponse: - """Poll for the status of an authorization asynchronously""" + async def status( + self, auth_id_or_response: Union[str, AuthResponse], scopes: list[str] | None = None + ) -> AuthResponse: + """ + Poll for the status of an authorization asynchronously + + Polls using either the authorization ID or the data returned from the authorize method. + + Example: + auth_response = await client.auth.authorize(...) + auth_status = await client.auth.poll_authorization(auth_response) + auth_status = await client.auth.poll_authorization("auth_123", ["scope1", "scope2"]) + """ + if isinstance(auth_id_or_response, AuthResponse): + auth_id = auth_id_or_response.auth_id + scopes = auth_id_or_response.scopes + else: + auth_id = auth_id_or_response + data = await self._client._execute_request( # type: ignore[attr-defined] - "GET", f"{self._base_path}/status", params={"authorizationID": auth_id} + "GET", + f"{self._base_path}/status", + params={"authorizationId": auth_id, "scopes": " ".join(scopes) if scopes else None}, ) return AuthResponse(**data) @@ -255,7 +275,7 @@ class AsyncToolResource(BaseResource[AsyncArcadeClient]): data = await self._client._execute_request( # type: ignore[attr-defined] "GET", f"{self._base_path}/definition", - params={"director_id": director_id, "tool_id": tool_id}, + params={"directorId": director_id, "toolId": tool_id}, ) return ToolDefinition(**data) diff --git a/arcade/arcade/client/schema.py b/arcade/arcade/client/schema.py index 0954542f..b8a48091 100644 --- a/arcade/arcade/client/schema.py +++ b/arcade/arcade/client/schema.py @@ -52,14 +52,14 @@ class HealthCheckResponse(BaseModel): class AuthResponse(BaseModel): """Response from an authorization request.""" - auth_id: str = Field(alias="authorizationID") + auth_id: str = Field(alias="authorization_id") """The ID of the authorization request""" scopes: list[str] """The scope(s) requested in the authorization request""" # TODO: Use AnyUrl? - auth_url: str | None = Field(None, alias="authorizationURL") + auth_url: str | None = Field(None, alias="authorization_url") """The URL for the authorization""" status: AuthStatus diff --git a/arcade/arcade/core/schema.py b/arcade/arcade/core/schema.py index c6707907..1847319b 100644 --- a/arcade/arcade/core/schema.py +++ b/arcade/arcade/core/schema.py @@ -159,6 +159,19 @@ class ToolCallError(BaseModel): """The number of milliseconds (if any) to wait before retrying the tool call.""" +class ToolCallRequiresAuthorization(BaseModel): + """The authorization requirements for the tool invocation.""" + + authorization_url: str | None = None + """The URL to redirect the user to for authorization.""" + authorization_id: str | None = None + """The ID for checking the status of the authorization.""" + scopes: list[str] | None = None + """The scopes that are required for authorization.""" + status: str | None = None + """The status of the authorization.""" + + class ToolCallOutput(BaseModel): """The output of a tool invocation.""" @@ -166,6 +179,8 @@ class ToolCallOutput(BaseModel): """The value returned by the tool.""" error: ToolCallError | None = None """The error that occurred during the tool invocation.""" + requires_authorization: ToolCallRequiresAuthorization | None = None + """The authorization requirements for the tool invocation.""" model_config = { "json_schema_extra": { diff --git a/arcade/tests/client/test_client.py b/arcade/tests/client/test_client.py index 4db7d258..0d34835f 100644 --- a/arcade/tests/client/test_client.py +++ b/arcade/tests/client/test_client.py @@ -17,9 +17,9 @@ from arcade.core.schema import ToolDefinition AUTH_RESPONSE_DATA = { "auth_id": "auth_123", - "auth_url": "https://example.com/auth", + "authorization_url": "https://example.com/auth", "status": "pending", - "authorizationID": "auth_123", + "authorization_id": "auth_123", "scopes": ["https://www.googleapis.com/auth/gmail.readonly"], } @@ -46,8 +46,8 @@ TOOL_DEFINITION_DATA = { } TOOL_AUTHORIZE_RESPONSE_DATA = { - "authorizationID": "auth_456", - "authorizationURL": "https://example.com/auth", + "authorization_id": "auth_456", + "authorization_url": "https://example.com/auth", "scopes": ["scope1", "scope2"], "status": "pending", } diff --git a/schemas/preview/invoke_tool_response.schema.jsonc b/schemas/preview/invoke_tool_response.schema.jsonc index 46c0bba5..e5098a8a 100644 --- a/schemas/preview/invoke_tool_response.schema.jsonc +++ b/schemas/preview/invoke_tool_response.schema.jsonc @@ -81,26 +81,28 @@ "requires_authorization": { "type": "object", "properties": { - "message": { + "authorization_url": { "type": "string", - "description": "A message that can be shown to the user or AI model that explains the authorization requirement" + "format": "uri", + "description": "The URL to redirect the user to for authorization" }, - "oauth2": { - "type": "object", - "properties": { - "url": { - "type": "string", - "format": "uri" - }, - "scope": { - "type": "string" - } + "authorization_id": { + "type": "string", + "description": "The ID for checking the status of the authorization" + }, + "scopes": { + "type": "array", + "items": { + "type": "string" }, - "required": ["url"], - "additionalProperties": false + "description": "The scopes that are required for authorization" + }, + "status": { + "type": "string", + "description": "The status of the authorization" } }, - "required": ["message"], + "required": ["status"], "additionalProperties": false } },