From 14998a43e34fe8bab35adf0a465e1c3accce5016 Mon Sep 17 00:00:00 2001 From: Nate Barbettini Date: Fri, 2 Aug 2024 11:25:08 -0700 Subject: [PATCH] Add ToolContext and OAuth tool support (#10) - Adds initial `ToolContext` to tool invocations - This unlocks the ability to call authenticated tools (e.g. Gmail), which works in this branch against Nate's dev engine --- arcade/arcade/actor/core/base.py | 21 +- arcade/arcade/actor/schema.py | 45 ----- arcade/arcade/cli/main.py | 1 + arcade/arcade/cli/new.py | 2 +- arcade/arcade/core/catalog.py | 33 +++- arcade/arcade/core/executor.py | 15 +- arcade/arcade/core/response.py | 6 +- arcade/arcade/core/schema.py | 181 ++++++++++++++++++ arcade/arcade/core/tool.py | 89 --------- arcade/arcade/sdk/__init__.py | 5 + arcade/arcade/sdk/auth.py | 20 ++ arcade/arcade/sdk/tool.py | 16 +- arcade/tests/sdk/test_tool_decorator.py | 10 +- .../tests/tool/test_create_tool_definition.py | 44 +++-- .../test_create_tool_definition_errors.py | 13 +- .../test_create_tool_definition_pydantic.py | 4 +- ..._create_tool_definition_pydantic_errors.py | 2 +- examples/gmail/arcade_gmail/tools/gdrive.py | 65 +++++++ examples/gmail/arcade_gmail/tools/gmail.py | 108 ++--------- .../arcade_arithmetic/tools/arithmetic.py | 2 +- .../arcade_websearch/tools/google.py | 14 +- .../preview/invoke_tool_request.schema.jsonc | 2 +- schemas/preview/tool_definition.schema.jsonc | 5 +- 23 files changed, 421 insertions(+), 282 deletions(-) delete mode 100644 arcade/arcade/actor/schema.py create mode 100644 arcade/arcade/core/schema.py delete mode 100644 arcade/arcade/core/tool.py create mode 100644 arcade/arcade/sdk/auth.py create mode 100644 examples/gmail/arcade_gmail/tools/gdrive.py diff --git a/arcade/arcade/actor/core/base.py b/arcade/arcade/actor/core/base.py index 7750e821..0f18226d 100644 --- a/arcade/arcade/actor/core/base.py +++ b/arcade/arcade/actor/core/base.py @@ -3,15 +3,16 @@ from abc import ABC, abstractmethod from datetime import datetime from typing import Any, Callable -from arcade.actor.schema import ( - InvokeToolRequest, - InvokeToolResponse, - ToolOutput, - ToolOutputError, -) from arcade.core.catalog import ToolCatalog, Toolkit from arcade.core.executor import ToolExecutor -from arcade.core.tool import ToolDefinition +from arcade.core.schema import ( + InvokeToolError, + InvokeToolOutput, + InvokeToolRequest, + InvokeToolResponse, + ToolContext, + ToolDefinition, +) class ActorComponent(ABC): @@ -72,15 +73,17 @@ class BaseActor: response = await ToolExecutor.run( func=materialized_tool.tool, + definition=materialized_tool.definition, input_model=materialized_tool.input_model, output_model=materialized_tool.output_model, + context=tool_request.context or ToolContext(), **tool_request.inputs or {}, ) if response.code == 200: # TODO remove ignore - output = ToolOutput(value=response.data.result) # type: ignore[union-attr] + output = InvokeToolOutput(value=response.data.result) # type: ignore[union-attr] else: - output = ToolOutput(error=ToolOutputError(message=response.msg)) + output = InvokeToolOutput(error=InvokeToolError(message=response.msg)) end_time = time.time() # End time in seconds duration_ms = (end_time - start_time) * 1000 # Convert to milliseconds diff --git a/arcade/arcade/actor/schema.py b/arcade/arcade/actor/schema.py deleted file mode 100644 index cd1bff17..00000000 --- a/arcade/arcade/actor/schema.py +++ /dev/null @@ -1,45 +0,0 @@ -from typing import ClassVar, Union - -from pydantic import BaseModel - - -class ToolVersion(BaseModel): - name: str - version: str - - -class InvokeToolRequest(BaseModel): - run_id: str - invocation_id: str - created_at: str - tool: ToolVersion - inputs: dict | None - context: dict | None - - -class ToolOutputError(BaseModel): - message: str - developer_message: str | None = None - - -class ToolOutput(BaseModel): - value: Union[str, int, float, bool, dict] | None = None - error: ToolOutputError | None = None - - class Config: - json_schema_extra: ClassVar[dict] = { - "oneOf": [ - {"required": ["value"]}, - {"required": ["error"]}, - {"required": ["requires_authorization"]}, - {"required": ["artifact"]}, - ] - } - - -class InvokeToolResponse(BaseModel): - invocation_id: str - finished_at: str - duration: float - success: bool - output: ToolOutput | None = None diff --git a/arcade/arcade/cli/main.py b/arcade/arcade/cli/main.py index a6b7075b..5db1a2b1 100644 --- a/arcade/arcade/cli/main.py +++ b/arcade/arcade/cli/main.py @@ -153,6 +153,7 @@ def run( # noqa: C901 output = asyncio.run( ToolExecutor.run( called_tool.tool, + called_tool.definition, called_tool.input_model, called_tool.output_model, **parameters, diff --git a/arcade/arcade/cli/new.py b/arcade/arcade/cli/new.py index c1faa73f..9b9a7163 100644 --- a/arcade/arcade/cli/new.py +++ b/arcade/arcade/cli/new.py @@ -119,7 +119,7 @@ def create_new_toolkit(directory: str) -> None: os.path.join(toolkit_dir, "tools", "hello.py"), dedent( f""" - from arcade.sdk.tool import tool + from arcade.sdk import tool @tool def hello() -> str: diff --git a/arcade/arcade/core/catalog.py b/arcade/arcade/core/catalog.py index 3343bf8c..9431da11 100644 --- a/arcade/arcade/core/catalog.py +++ b/arcade/arcade/core/catalog.py @@ -22,8 +22,11 @@ from pydantic.fields import FieldInfo from pydantic_core import PydanticUndefined from arcade.core.errors import ToolDefinitionError -from arcade.core.tool import ( +from arcade.core.schema import ( InputParameter, + OAuth2Requirement, + ToolAuthRequirement, + ToolContext, ToolDefinition, ToolInputs, ToolOutput, @@ -38,6 +41,7 @@ from arcade.core.utils import ( snake_to_pascal_case, ) from arcade.sdk.annotations import Inferrable +from arcade.sdk.auth import OAuth2 WireType = Literal["string", "integer", "float", "boolean", "json"] @@ -176,6 +180,12 @@ class ToolCatalog(BaseModel): if does_function_return_value(tool) and tool.__annotations__.get("return") is None: raise ToolDefinitionError(f"Tool {tool_name} must have a return type annotation") + auth_requirement = getattr(tool, "__tool_requires_auth__", None) + if isinstance(auth_requirement, OAuth2): + auth_requirement = ToolAuthRequirement( + oauth2=OAuth2Requirement(**auth_requirement.model_dump()) + ) + return ToolDefinition( name=snake_to_pascal_case(tool_name), description=tool_description, @@ -183,7 +193,7 @@ class ToolCatalog(BaseModel): inputs=create_input_definition(tool), output=create_output_definition(tool), requirements=ToolRequirements( - authorization=getattr(tool, "__tool_requires_auth__", None), + auth=auth_requirement, ), ) @@ -193,7 +203,18 @@ def create_input_definition(func: Callable) -> ToolInputs: Create an input model for a function based on its parameters. """ input_parameters = [] + tool_context_param_name: str | None = None + for _, param in inspect.signature(func, follow_wrapped=True).parameters.items(): + if param.annotation is ToolContext: + if tool_context_param_name is not None: + raise ToolDefinitionError( + f"Only one ToolContext parameter is supported, but tool {func.__name__} has multiple." + ) + + tool_context_param_name = param.name + continue # No further processing of this param (don't add it to the list of inputs) + tool_field_info = extract_field_info(param) is_enum = False @@ -222,7 +243,9 @@ def create_input_definition(func: Callable) -> ToolInputs: ) ) - return ToolInputs(parameters=input_parameters) + return ToolInputs( + parameters=input_parameters, tool_context_parameter_name=tool_context_param_name + ) def create_output_definition(func: Callable) -> ToolOutput: @@ -455,6 +478,10 @@ def create_func_models(func: Callable) -> tuple[type[BaseModel], type[BaseModel] if asyncio.iscoroutinefunction(func) and hasattr(func, "__wrapped__"): func = func.__wrapped__ for name, param in inspect.signature(func, follow_wrapped=True).parameters.items(): + # Skip ToolContext parameters + if param.annotation is ToolContext: + continue + # TODO make this cleaner tool_field_info = extract_field_info(param) param_fields = { diff --git a/arcade/arcade/core/executor.py b/arcade/arcade/core/executor.py index 1ab52a81..6f561c8c 100644 --- a/arcade/arcade/core/executor.py +++ b/arcade/arcade/core/executor.py @@ -10,14 +10,17 @@ from arcade.core.errors import ( ToolSerializationError, ) from arcade.core.response import ToolResponse, tool_response +from arcade.core.schema import ToolContext, ToolDefinition class ToolExecutor: @staticmethod async def run( func: Callable, + definition: ToolDefinition, input_model: type[BaseModel], output_model: type[BaseModel], + context: ToolContext, *args: Any, **kwargs: Any, ) -> ToolResponse: @@ -28,11 +31,18 @@ class ToolExecutor: # serialize the input model inputs = await ToolExecutor._serialize_input(input_model, **kwargs) + # prepare the arguments for the function call + func_args = inputs.model_dump() + + # inject ToolContext, if the target function supports it + if definition.inputs.tool_context_parameter_name is not None: + func_args[definition.inputs.tool_context_parameter_name] = context + # execute the tool function if asyncio.iscoroutinefunction(func): - results = await func(**inputs.model_dump()) + results = await func(**func_args) else: - results = func(**inputs.model_dump()) + results = func(**func_args) # serialize the output model output = await ToolExecutor._serialize_output(output_model, results) @@ -73,6 +83,7 @@ class ToolExecutor: Serialize the output of a tool function. """ # TODO how to type this the results object? + # TODO how to ensure `results` contains only safe (serializable) stuff? try: # TODO Logging and telemetry diff --git a/arcade/arcade/core/response.py b/arcade/arcade/core/response.py index b415b446..df59cdb7 100644 --- a/arcade/arcade/core/response.py +++ b/arcade/arcade/core/response.py @@ -61,7 +61,11 @@ class ToolResponseFactory: msg: str = CustomResponseCode.HTTP_400.msg, data: Any = None, ) -> ToolResponse: - return await self.__response(res=res, data=data) + return await self.__response( + res=res, + msg=msg, # TODO this needs to map to developer_message in output.error + data=data, + ) tool_response = ToolResponseFactory() diff --git a/arcade/arcade/core/schema.py b/arcade/arcade/core/schema.py new file mode 100644 index 00000000..a597a291 --- /dev/null +++ b/arcade/arcade/core/schema.py @@ -0,0 +1,181 @@ +from typing import Any, Literal, Optional, Union + +from pydantic import AnyUrl, BaseModel, Field + + +class ValueSchema(BaseModel): + """Value schema for input parameters and outputs.""" + + val_type: Literal["string", "integer", "float", "boolean", "json"] + """The type of the value.""" + + enum: Optional[list[str]] = None + """The list of possible values for the value, if it is a closed list.""" + + +class InputParameter(BaseModel): + """A parameter that can be passed to a tool.""" + + name: str = Field(..., description="The human-readable name of this parameter.") + required: bool = Field( + ..., + description="Whether this parameter is required (true) or optional (false).", + ) + description: Optional[str] = Field( + None, description="A descriptive, human-readable explanation of the parameter." + ) + value_schema: ValueSchema = Field( + ..., + description="The schema of the value of this parameter.", + ) + inferrable: bool = Field( + True, + description="Whether a value for this parameter can be inferred by a model. Defaults to `true`.", + ) + + +class ToolInputs(BaseModel): + """The inputs that a tool accepts.""" + + parameters: list[InputParameter] + """The list of parameters that the tool accepts.""" + + tool_context_parameter_name: str | None = Field(default=None, exclude=True) + """ + The name of the target parameter that will contain the tool context (if any). + This field will not be included in serialization. + """ + + +class ToolOutput(BaseModel): + """The output of a tool.""" + + description: Optional[str] = Field( + None, description="A descriptive, human-readable explanation of the output." + ) + available_modes: list[str] = Field( + default_factory=lambda: ["value", "error", "null"], + description="The available modes for the output.", + ) + value_schema: Optional[ValueSchema] = Field( + None, description="The schema of the value of the output." + ) + + +class OAuth2Requirement(BaseModel): + """Indicates that the tool requires OAuth 2.0 authorization.""" + + authority: AnyUrl + """The URL of the OAuth 2.0 authorization server.""" + + scope: Optional[list[str]] = None + """The scope(s) needed for authorization.""" + + +class ToolAuthRequirement(BaseModel): + """A requirement for authorization to use a tool.""" + + oauth2: Optional[OAuth2Requirement] = None + + +class ToolRequirements(BaseModel): + """The requirements for a tool to run.""" + + auth: Union[ToolAuthRequirement, None] = None + """The authorization requirements for the tool, if any.""" + + +class ToolDefinition(BaseModel): + """The specification of a tool.""" + + name: str + description: str + version: str + inputs: ToolInputs + output: ToolOutput + requirements: ToolRequirements + + +class ToolVersion(BaseModel): + """The name and version of a tool.""" + + name: str + """The name of the tool.""" + + version: str + """The version of the tool.""" + + +class ToolAuthorizationContext(BaseModel): + """The context for a tool invocation that requires authorization.""" + + token: str | None = None + """The token for the tool invocation.""" + + +class ToolContext(BaseModel): + """The context for a tool invocation.""" + + authorization: ToolAuthorizationContext | None = None + """The authorization context for the tool invocation that requires authorization.""" + + +class InvokeToolRequest(BaseModel): + """The request to invoke a tool.""" + + run_id: str + """The globally-unique run ID provided by the Engine.""" + invocation_id: str + """The globally-unique ID for this tool invocation in the run.""" + created_at: str + """The timestamp when the tool invocation was created.""" + tool: ToolVersion + """The name and version of the tool.""" + inputs: dict[str, Any] | None = None + """The inputs for the tool.""" + context: ToolContext + """The context for the tool invocation.""" + + +class InvokeToolError(BaseModel): + """The error that occurred during the tool invocation.""" + + message: str + """The user-facing error message.""" + developer_message: str | None = None + """The developer-facing error details.""" + + +class InvokeToolOutput(BaseModel): + """The output of a tool invocation.""" + + value: Union[str, int, float, bool, dict] | None = None + """The value returned by the tool.""" + error: InvokeToolError | None = None + """The error that occurred during the tool invocation.""" + + model_config = { + "json_schema_extra": { + "oneOf": [ + {"required": ["value"]}, + {"required": ["error"]}, + {"required": ["requires_authorization"]}, + {"required": ["artifact"]}, + ] + } + } + + +class InvokeToolResponse(BaseModel): + """The response to a tool invocation.""" + + invocation_id: str + """The globally-unique ID for this tool invocation.""" + finished_at: str + """The timestamp when the tool invocation finished.""" + duration: float + """The duration of the tool invocation in milliseconds (ms).""" + success: bool + """Whether the tool invocation was successful.""" + output: InvokeToolOutput | None = None + """The output of the tool invocation.""" diff --git a/arcade/arcade/core/tool.py b/arcade/arcade/core/tool.py deleted file mode 100644 index ac327520..00000000 --- a/arcade/arcade/core/tool.py +++ /dev/null @@ -1,89 +0,0 @@ -from abc import ABC -from typing import Literal, Optional, Union - -from pydantic import AnyUrl, BaseModel, Field - - -class ValueSchema(BaseModel): - """Value schema for input parameters and outputs.""" - - val_type: Literal["string", "integer", "float", "boolean", "json"] - """The type of the value.""" - - enum: Optional[list[str]] = None - - -class InputParameter(BaseModel): - """A parameter that can be passed to a tool.""" - - name: str = Field(..., description="The human-readable name of this parameter.") - required: bool = Field( - ..., - description="Whether this parameter is required (true) or optional (false).", - ) - description: Optional[str] = Field( - None, description="A descriptive, human-readable explanation of the parameter." - ) - value_schema: ValueSchema = Field( - ..., - description="The schema of the value of this parameter.", - ) - inferrable: bool = Field( - True, - description="Whether a value for this parameter can be inferred by a model. Defaults to `true`.", - ) - - -class ToolInputs(BaseModel): - """The inputs that a tool accepts.""" - - parameters: list[InputParameter] - """The list of parameters that the tool accepts.""" - - -class ToolOutput(BaseModel): - """The output of a tool.""" - - description: Optional[str] = Field( - None, description="A descriptive, human-readable explanation of the output." - ) - available_modes: list[str] = Field( - default_factory=lambda: ["value", "error", "null"], - description="The available modes for the output.", - ) - value_schema: Optional[ValueSchema] = Field( - None, description="The schema of the value of the output." - ) - - -class ToolAuthorizationRequirement(BaseModel, ABC): - """A requirement for authorization to use a tool.""" - - pass - - -class OAuth2AuthorizationRequirement(ToolAuthorizationRequirement): - """Specifies OAuth2 requirement for tool execution.""" - - url: AnyUrl - """The URL to which the user should be redirected to authorize the tool.""" - - scope: Optional[list[str]] = None - """The scope of the authorization.""" - - -class ToolRequirements(BaseModel): - """The requirements for a tool to run.""" - - authorization: Union[ToolAuthorizationRequirement, None] = None - - -class ToolDefinition(BaseModel): - """The specification of a tool.""" - - name: str - description: str - version: str - inputs: ToolInputs - output: ToolOutput - requirements: ToolRequirements diff --git a/arcade/arcade/sdk/__init__.py b/arcade/arcade/sdk/__init__.py index e69de29b..26c0fc27 100644 --- a/arcade/arcade/sdk/__init__.py +++ b/arcade/arcade/sdk/__init__.py @@ -0,0 +1,5 @@ +from .tool import tool + +__all__ = [ + "tool", +] diff --git a/arcade/arcade/sdk/auth.py b/arcade/arcade/sdk/auth.py new file mode 100644 index 00000000..51899057 --- /dev/null +++ b/arcade/arcade/sdk/auth.py @@ -0,0 +1,20 @@ +from abc import ABC +from typing import Optional + +from pydantic import AnyUrl, BaseModel + + +class ToolAuthorization(BaseModel, ABC): + """Marks a tool as requiring authorization.""" + + pass + + +class OAuth2(ToolAuthorization): + """Marks a tool as requiring OAuth 2.0 authorization.""" + + authority: AnyUrl + """The URL of the OAuth 2.0 authorization server.""" + + scope: Optional[list[str]] = None + """The scope(s) needed for the authorized action.""" diff --git a/arcade/arcade/sdk/tool.py b/arcade/arcade/sdk/tool.py index a80d865c..574a7b67 100644 --- a/arcade/arcade/sdk/tool.py +++ b/arcade/arcade/sdk/tool.py @@ -1,9 +1,8 @@ import inspect -import os -from typing import Any, Callable, Optional, TypeVar, Union +from typing import Callable, TypeVar, Union -from arcade.core.tool import ToolAuthorizationRequirement from arcade.core.utils import snake_to_pascal_case +from arcade.sdk.auth import ToolAuthorization T = TypeVar("T") @@ -13,7 +12,7 @@ def tool( func: Callable | None = None, desc: str | None = None, name: str | None = None, - requires_auth: Union[ToolAuthorizationRequirement, None] = None, + requires_auth: Union[ToolAuthorization, None] = None, ) -> Callable: def decorator(func: Callable) -> Callable: func_name = str(getattr(func, "__name__", None)) @@ -28,12 +27,3 @@ def tool( if func: # This means the decorator is used without parameters return decorator(func) return decorator - - -def get_secret(name: str, default: Optional[Any] = None) -> Any: - secret = os.getenv(name) - if secret is None: - if default is not None: - return default - raise ValueError(f"Secret {name} is not set.") - return secret diff --git a/arcade/tests/sdk/test_tool_decorator.py b/arcade/tests/sdk/test_tool_decorator.py index e0e6f676..10ffb2a2 100644 --- a/arcade/tests/sdk/test_tool_decorator.py +++ b/arcade/tests/sdk/test_tool_decorator.py @@ -2,8 +2,8 @@ import asyncio import pytest -from arcade.core.tool import OAuth2AuthorizationRequirement -from arcade.sdk.tool import tool +from arcade.sdk import tool +from arcade.sdk.auth import OAuth2 def test_sync_function(): @@ -38,8 +38,8 @@ def test_tool_decorator_with_all_options(): @tool( name="TestTool", desc="Test description", - requires_auth=OAuth2AuthorizationRequirement( - url="https://example.com/oauth2/auth", + requires_auth=OAuth2( + authority="https://example.com/oauth2/auth", scope=["test_scope", "another.scope"], ), ) @@ -48,5 +48,5 @@ def test_tool_decorator_with_all_options(): assert test_tool.__tool_name__ == "TestTool" assert test_tool.__tool_description__ == "Test description" - assert str(test_tool.__tool_requires_auth__.url) == "https://example.com/oauth2/auth" + assert str(test_tool.__tool_requires_auth__.authority) == "https://example.com/oauth2/auth" assert test_tool.__tool_requires_auth__.scope == ["test_scope", "another.scope"] diff --git a/arcade/tests/tool/test_create_tool_definition.py b/arcade/tests/tool/test_create_tool_definition.py index c15111fc..18441225 100644 --- a/arcade/tests/tool/test_create_tool_definition.py +++ b/arcade/tests/tool/test_create_tool_definition.py @@ -3,16 +3,19 @@ from typing import Annotated, Literal, Optional import pytest from arcade.core.catalog import ToolCatalog -from arcade.core.tool import ( +from arcade.core.schema import ( InputParameter, - OAuth2AuthorizationRequirement, + OAuth2Requirement, + ToolAuthRequirement, + ToolContext, ToolInputs, ToolOutput, ToolRequirements, ValueSchema, ) +from arcade.sdk import tool from arcade.sdk.annotations import Inferrable -from arcade.sdk.tool import tool +from arcade.sdk.auth import OAuth2 ### Tests on @tool decorator @@ -43,9 +46,7 @@ def func_with_name_and_description(): @tool( desc="A function that requires authentication", - requires_auth=OAuth2AuthorizationRequirement( - url="https://example.com/oauth2/auth", scope=["scope1", "scope2"] - ), + requires_auth=OAuth2(authority="https://example.com/oauth2/auth", scope=["scope1", "scope2"]), ) def func_with_auth_requirement(): pass @@ -53,7 +54,7 @@ def func_with_auth_requirement(): ### Tests on input params @tool(desc="A function with an input parameter") -def func_with_param(param1: Annotated[str, "First param"]): +def func_with_param(context: Annotated[str, "First param"]): pass @@ -114,6 +115,7 @@ def func_with_optional_param_with_default_value( @tool(desc="A function with multiple parameters, some with default values") def func_with_mixed_params( + context: ToolContext, param1: Annotated[str, "First param"], param2: Annotated[int, "Second param"] = 42, ): @@ -125,6 +127,11 @@ def func_with_complex_param(param1: Annotated[list[str], "A list of strings"]): pass +@tool(desc="A function that takes a context") +def func_with_context(my_context: ToolContext): + pass + + ### Tests on output/return values @tool(desc="A function that performs an action without returning anything") def func_with_no_return(): @@ -184,15 +191,18 @@ def func_with_complex_return() -> list[dict[str, str]]: ), pytest.param( func_with_name_and_description, - {"name": "MyCustomTool", "requirements": ToolRequirements(authorization=None)}, + {"name": "MyCustomTool", "requirements": ToolRequirements(auth=None)}, id="func_with_no_auth_requirement", ), pytest.param( func_with_auth_requirement, { "requirements": ToolRequirements( - authorization=OAuth2AuthorizationRequirement( - url="https://example.com/oauth2/auth", scope=["scope1", "scope2"] + auth=ToolAuthRequirement( + oauth2=OAuth2Requirement( + authority="https://example.com/oauth2/auth", + scope=["scope1", "scope2"], + ) ) ) }, @@ -212,7 +222,7 @@ def func_with_complex_return() -> list[dict[str, str]]: "inputs": ToolInputs( parameters=[ InputParameter( - name="param1", + name="context", # Nothing special about this name, parameters can be named anything description="First param", inferrable=True, # Defaults to true required=True, @@ -428,7 +438,8 @@ def func_with_complex_return() -> list[dict[str, str]]: required=False, # Because a default value is provided value_schema=ValueSchema(val_type="integer", enum=None), ), - ] + ], + tool_context_parameter_name="context", ), }, id="func_with_mixed_params", @@ -450,6 +461,15 @@ def func_with_complex_return() -> list[dict[str, str]]: }, id="func_with_complex_param", ), + pytest.param( + func_with_context, + { + "inputs": ToolInputs( + parameters=[], tool_context_parameter_name="my_context" + ), # ToolContext type is not an input param, but it's stored in the inputs field + }, + id="func_with_context", + ), # Tests on output values pytest.param( func_with_no_return, diff --git a/arcade/tests/tool/test_create_tool_definition_errors.py b/arcade/tests/tool/test_create_tool_definition_errors.py index 4102ef22..fe1b8ef2 100644 --- a/arcade/tests/tool/test_create_tool_definition_errors.py +++ b/arcade/tests/tool/test_create_tool_definition_errors.py @@ -2,7 +2,8 @@ import pytest from arcade.core.catalog import ToolCatalog from arcade.core.errors import ToolDefinitionError -from arcade.sdk.tool import tool +from arcade.core.schema import ToolContext +from arcade.sdk import tool @tool @@ -30,6 +31,11 @@ def func_with_unsupported_param(param1: complex): pass +@tool(desc="A function with multiple context parameters (illegal)") +def func_with_multiple_context_params(context: ToolContext, context2: ToolContext): + pass + + @pytest.mark.parametrize( "func_under_test, exception_type", [ @@ -58,6 +64,11 @@ def func_with_unsupported_param(param1: complex): ToolDefinitionError, id=func_with_unsupported_param.__name__, ), + pytest.param( + func_with_multiple_context_params, + ToolDefinitionError, + id=func_with_multiple_context_params.__name__, + ), ], ) def test_missing_info_raises_error(func_under_test, exception_type): diff --git a/arcade/tests/tool/test_create_tool_definition_pydantic.py b/arcade/tests/tool/test_create_tool_definition_pydantic.py index 7f0f5f24..d3a1272b 100644 --- a/arcade/tests/tool/test_create_tool_definition_pydantic.py +++ b/arcade/tests/tool/test_create_tool_definition_pydantic.py @@ -4,13 +4,13 @@ import pytest from pydantic import BaseModel, Field from arcade.core.catalog import ToolCatalog -from arcade.core.tool import ( +from arcade.core.schema import ( InputParameter, ToolInputs, ToolOutput, ValueSchema, ) -from arcade.sdk.tool import tool +from arcade.sdk import tool class ProductOutput(BaseModel): diff --git a/arcade/tests/tool/test_create_tool_definition_pydantic_errors.py b/arcade/tests/tool/test_create_tool_definition_pydantic_errors.py index a875d407..e05fb92a 100644 --- a/arcade/tests/tool/test_create_tool_definition_pydantic_errors.py +++ b/arcade/tests/tool/test_create_tool_definition_pydantic_errors.py @@ -5,7 +5,7 @@ from pydantic import Field from arcade.core.catalog import ToolCatalog from arcade.core.errors import ToolDefinitionError -from arcade.sdk.tool import tool +from arcade.sdk import tool @tool diff --git a/examples/gmail/arcade_gmail/tools/gdrive.py b/examples/gmail/arcade_gmail/tools/gdrive.py new file mode 100644 index 00000000..4e7f8490 --- /dev/null +++ b/examples/gmail/arcade_gmail/tools/gdrive.py @@ -0,0 +1,65 @@ +import os + +from google.auth.transport.requests import Request +from google.oauth2.credentials import Credentials +from google_auth_oauthlib.flow import InstalledAppFlow +from google.auth.exceptions import RefreshError +from googleapiclient.discovery import build +from typing import Annotated +from arcade.sdk import tool + +SECRET_FILE = "/Users/spartee/Dropbox/Arcade/gcp/credentials.json" +DRIVE_SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"] + + +@tool +async def list_drive_files( + n_files: Annotated[int, "Number of files to search"] = 5, +) -> list[str]: + """List files from a Google Drive account and return their details.""" + + creds = None + # The file token.json stores the user's access and refresh tokens, and is + # created automatically when the authorization flow completes for the first time. + # TODO: use context.authorization.token like gmail.py + if os.path.exists("token.json"): + creds = Credentials.from_authorized_user_file("token.json") + # If there are no (valid) credentials available, let the user log in. + if not creds or not creds.valid: + if creds and creds.expired and creds.refresh_token: + try: + creds.refresh(Request()) + except RefreshError: + flow = InstalledAppFlow.from_client_secrets_file( + SECRET_FILE, DRIVE_SCOPES + ) + creds = flow.run_local_server(port=0) + # Save the credentials for the next run + with open("token.json", "w") as token: + token.write(creds.to_json()) + else: + flow = InstalledAppFlow.from_client_secrets_file(SECRET_FILE, DRIVE_SCOPES) + creds = flow.run_local_server(port=0) + # Save the credentials for the next run + with open("token.json", "w") as token: + token.write(creds.to_json()) + + # Call the Drive v3 API + service = build("drive", "v3", credentials=creds) + + # Request a list of all the files + results = ( + service.files() + .list(pageSize=n_files, fields="nextPageToken, files(id, name)") + .execute() + ) + items = results.get("files", []) + + if not items: + print("No files found.") + else: + print("Files:") + for item in items: + print("{0} ({1})".format(item["name"], item["id"])) + + return items diff --git a/examples/gmail/arcade_gmail/tools/gmail.py b/examples/gmail/arcade_gmail/tools/gmail.py index da0e583f..78f14659 100644 --- a/examples/gmail/arcade_gmail/tools/gmail.py +++ b/examples/gmail/arcade_gmail/tools/gmail.py @@ -1,52 +1,29 @@ -import os import re from base64 import urlsafe_b64decode +from arcade.core.schema import ToolContext +from arcade.sdk.auth import OAuth2 from bs4 import BeautifulSoup -from google.auth.transport.requests import Request from google.oauth2.credentials import Credentials -from google_auth_oauthlib.flow import InstalledAppFlow -from google.auth.exceptions import RefreshError from googleapiclient.discovery import build -from typing import Dict, List, Annotated -from arcade.sdk.tool import tool +from typing import Annotated +from arcade.sdk import tool -SCOPES = ["https://www.googleapis.com/auth/gmail.readonly"] -SECRET_FILE = "/Users/spartee/Dropbox/Arcade/gcp/credentials.json" - - -@tool -async def oauth_read_email( +@tool( + requires_auth=OAuth2( + authority="https://accounts.google.com", + scope=["https://www.googleapis.com/auth/gmail.readonly"], + ) +) +async def get_emails( + context: ToolContext, n_emails: Annotated[int, "Number of emails to read"] = 5, -) -> List[Dict[str, str]]: +) -> dict: """Read emails from a Gmail account and extract plain text content, removing any HTML.""" - creds = None - # The file token.json stores the user's access and refresh tokens, and is - # created automatically when the authorization flow completes for the first time. - if os.path.exists("token.json"): - creds = Credentials.from_authorized_user_file("token.json") - # If there are no (valid) credentials available, let the user log in. - if not creds or not creds.valid: - if creds and creds.expired and creds.refresh_token: - try: - creds.refresh(Request()) - except RefreshError: - flow = InstalledAppFlow.from_client_secrets_file(SECRET_FILE, SCOPES) - creds = flow.run_local_server(port=0) - # Save the credentials for the next run - with open("token.json", "w") as token: - token.write(creds.to_json()) - else: - flow = InstalledAppFlow.from_client_secrets_file(SECRET_FILE, SCOPES) - creds = flow.run_local_server(port=0) - # Save the credentials for the next run - with open("token.json", "w") as token: - token.write(creds.to_json()) - # Call the Gmail API - service = build("gmail", "v1", credentials=creds) + service = build("gmail", "v1", credentials=Credentials(context.authorization.token)) # Request a list of all the messages result = service.users().messages().list(userId="me").execute() @@ -97,7 +74,7 @@ async def oauth_read_email( print(f"Error reading email {msg['id']}: {e}", "ERROR") continue - return emails + return {"emails": emails} def clean_email_body(body: str) -> str: @@ -114,58 +91,3 @@ def clean_email_body(body: str) -> str: text = " ".join(text.split()) # Remove extra whitespace return text - - -DRIVE_SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"] - - -@tool -async def list_drive_files( - n_files: Annotated[int, "Number of files to search"] = 5, -) -> list[str]: - """List files from a Google Drive account and return their details.""" - - creds = None - # The file token.json stores the user's access and refresh tokens, and is - # created automatically when the authorization flow completes for the first time. - if os.path.exists("token.json"): - creds = Credentials.from_authorized_user_file("token.json") - # If there are no (valid) credentials available, let the user log in. - if not creds or not creds.valid: - if creds and creds.expired and creds.refresh_token: - try: - creds.refresh(Request()) - except RefreshError: - flow = InstalledAppFlow.from_client_secrets_file( - SECRET_FILE, DRIVE_SCOPES - ) - creds = flow.run_local_server(port=0) - # Save the credentials for the next run - with open("token.json", "w") as token: - token.write(creds.to_json()) - else: - flow = InstalledAppFlow.from_client_secrets_file(SECRET_FILE, DRIVE_SCOPES) - creds = flow.run_local_server(port=0) - # Save the credentials for the next run - with open("token.json", "w") as token: - token.write(creds.to_json()) - - # Call the Drive v3 API - service = build("drive", "v3", credentials=creds) - - # Request a list of all the files - results = ( - service.files() - .list(pageSize=n_files, fields="nextPageToken, files(id, name)") - .execute() - ) - items = results.get("files", []) - - if not items: - print("No files found.") - else: - print("Files:") - for item in items: - print("{0} ({1})".format(item["name"], item["id"])) - - return items diff --git a/examples/math/arcade_arithmetic/tools/arithmetic.py b/examples/math/arcade_arithmetic/tools/arithmetic.py index 7f8433b3..dde95fad 100644 --- a/examples/math/arcade_arithmetic/tools/arithmetic.py +++ b/examples/math/arcade_arithmetic/tools/arithmetic.py @@ -1,7 +1,7 @@ import math from typing import Annotated -from arcade.sdk.tool import tool +from arcade.sdk import tool @tool diff --git a/examples/websearch/arcade_websearch/tools/google.py b/examples/websearch/arcade_websearch/tools/google.py index 82c195ca..302fccd3 100644 --- a/examples/websearch/arcade_websearch/tools/google.py +++ b/examples/websearch/arcade_websearch/tools/google.py @@ -1,7 +1,8 @@ import json +import os import serpapi -from typing import Annotated -from arcade.sdk.tool import tool, get_secret +from typing import Annotated, Any, Optional +from arcade.sdk import tool @tool @@ -23,3 +24,12 @@ async def search_google( organic_results = results.get("organic_results", []) return json.dumps(organic_results[:n_results]) + + +def get_secret(name: str, default: Optional[Any] = None) -> Any: + secret = os.getenv(name) + if secret is None: + if default is not None: + return default + raise ValueError(f"Secret {name} is not set.") + return secret diff --git a/schemas/preview/invoke_tool_request.schema.jsonc b/schemas/preview/invoke_tool_request.schema.jsonc index 0ad76197..03f0e227 100644 --- a/schemas/preview/invoke_tool_request.schema.jsonc +++ b/schemas/preview/invoke_tool_request.schema.jsonc @@ -68,6 +68,6 @@ } } }, - "required": ["run_id", "invocation_id", "created_at", "tool", "input", "context"], + "required": ["run_id", "invocation_id", "created_at", "tool", "inputs", "context"], "additionalProperties": false } diff --git a/schemas/preview/tool_definition.schema.jsonc b/schemas/preview/tool_definition.schema.jsonc index 57c4b197..79d51ace 100644 --- a/schemas/preview/tool_definition.schema.jsonc +++ b/schemas/preview/tool_definition.schema.jsonc @@ -129,7 +129,10 @@ "format": "uri" }, "scope": { - "type": "string" + "type": "array", + "items": { + "type": "string" + } } }, "required": ["url"],