Add ToolContext and OAuth tool support (#10)
- Adds initial `ToolContext` to tool invocations - This unlocks the ability to call authenticated tools (e.g. Gmail), which works in this branch against Nate's dev engine
This commit is contained in:
parent
1b67cee667
commit
14998a43e3
23 changed files with 421 additions and 282 deletions
|
|
@ -3,15 +3,16 @@ from abc import ABC, abstractmethod
|
|||
from datetime import datetime
|
||||
from typing import Any, Callable
|
||||
|
||||
from arcade.actor.schema import (
|
||||
InvokeToolRequest,
|
||||
InvokeToolResponse,
|
||||
ToolOutput,
|
||||
ToolOutputError,
|
||||
)
|
||||
from arcade.core.catalog import ToolCatalog, Toolkit
|
||||
from arcade.core.executor import ToolExecutor
|
||||
from arcade.core.tool import ToolDefinition
|
||||
from arcade.core.schema import (
|
||||
InvokeToolError,
|
||||
InvokeToolOutput,
|
||||
InvokeToolRequest,
|
||||
InvokeToolResponse,
|
||||
ToolContext,
|
||||
ToolDefinition,
|
||||
)
|
||||
|
||||
|
||||
class ActorComponent(ABC):
|
||||
|
|
@ -72,15 +73,17 @@ class BaseActor:
|
|||
|
||||
response = await ToolExecutor.run(
|
||||
func=materialized_tool.tool,
|
||||
definition=materialized_tool.definition,
|
||||
input_model=materialized_tool.input_model,
|
||||
output_model=materialized_tool.output_model,
|
||||
context=tool_request.context or ToolContext(),
|
||||
**tool_request.inputs or {},
|
||||
)
|
||||
if response.code == 200:
|
||||
# TODO remove ignore
|
||||
output = ToolOutput(value=response.data.result) # type: ignore[union-attr]
|
||||
output = InvokeToolOutput(value=response.data.result) # type: ignore[union-attr]
|
||||
else:
|
||||
output = ToolOutput(error=ToolOutputError(message=response.msg))
|
||||
output = InvokeToolOutput(error=InvokeToolError(message=response.msg))
|
||||
|
||||
end_time = time.time() # End time in seconds
|
||||
duration_ms = (end_time - start_time) * 1000 # Convert to milliseconds
|
||||
|
|
|
|||
|
|
@ -1,45 +0,0 @@
|
|||
from typing import ClassVar, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ToolVersion(BaseModel):
|
||||
name: str
|
||||
version: str
|
||||
|
||||
|
||||
class InvokeToolRequest(BaseModel):
|
||||
run_id: str
|
||||
invocation_id: str
|
||||
created_at: str
|
||||
tool: ToolVersion
|
||||
inputs: dict | None
|
||||
context: dict | None
|
||||
|
||||
|
||||
class ToolOutputError(BaseModel):
|
||||
message: str
|
||||
developer_message: str | None = None
|
||||
|
||||
|
||||
class ToolOutput(BaseModel):
|
||||
value: Union[str, int, float, bool, dict] | None = None
|
||||
error: ToolOutputError | None = None
|
||||
|
||||
class Config:
|
||||
json_schema_extra: ClassVar[dict] = {
|
||||
"oneOf": [
|
||||
{"required": ["value"]},
|
||||
{"required": ["error"]},
|
||||
{"required": ["requires_authorization"]},
|
||||
{"required": ["artifact"]},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
class InvokeToolResponse(BaseModel):
|
||||
invocation_id: str
|
||||
finished_at: str
|
||||
duration: float
|
||||
success: bool
|
||||
output: ToolOutput | None = None
|
||||
|
|
@ -153,6 +153,7 @@ def run( # noqa: C901
|
|||
output = asyncio.run(
|
||||
ToolExecutor.run(
|
||||
called_tool.tool,
|
||||
called_tool.definition,
|
||||
called_tool.input_model,
|
||||
called_tool.output_model,
|
||||
**parameters,
|
||||
|
|
|
|||
|
|
@ -119,7 +119,7 @@ def create_new_toolkit(directory: str) -> None:
|
|||
os.path.join(toolkit_dir, "tools", "hello.py"),
|
||||
dedent(
|
||||
f"""
|
||||
from arcade.sdk.tool import tool
|
||||
from arcade.sdk import tool
|
||||
|
||||
@tool
|
||||
def hello() -> str:
|
||||
|
|
|
|||
|
|
@ -22,8 +22,11 @@ from pydantic.fields import FieldInfo
|
|||
from pydantic_core import PydanticUndefined
|
||||
|
||||
from arcade.core.errors import ToolDefinitionError
|
||||
from arcade.core.tool import (
|
||||
from arcade.core.schema import (
|
||||
InputParameter,
|
||||
OAuth2Requirement,
|
||||
ToolAuthRequirement,
|
||||
ToolContext,
|
||||
ToolDefinition,
|
||||
ToolInputs,
|
||||
ToolOutput,
|
||||
|
|
@ -38,6 +41,7 @@ from arcade.core.utils import (
|
|||
snake_to_pascal_case,
|
||||
)
|
||||
from arcade.sdk.annotations import Inferrable
|
||||
from arcade.sdk.auth import OAuth2
|
||||
|
||||
WireType = Literal["string", "integer", "float", "boolean", "json"]
|
||||
|
||||
|
|
@ -176,6 +180,12 @@ class ToolCatalog(BaseModel):
|
|||
if does_function_return_value(tool) and tool.__annotations__.get("return") is None:
|
||||
raise ToolDefinitionError(f"Tool {tool_name} must have a return type annotation")
|
||||
|
||||
auth_requirement = getattr(tool, "__tool_requires_auth__", None)
|
||||
if isinstance(auth_requirement, OAuth2):
|
||||
auth_requirement = ToolAuthRequirement(
|
||||
oauth2=OAuth2Requirement(**auth_requirement.model_dump())
|
||||
)
|
||||
|
||||
return ToolDefinition(
|
||||
name=snake_to_pascal_case(tool_name),
|
||||
description=tool_description,
|
||||
|
|
@ -183,7 +193,7 @@ class ToolCatalog(BaseModel):
|
|||
inputs=create_input_definition(tool),
|
||||
output=create_output_definition(tool),
|
||||
requirements=ToolRequirements(
|
||||
authorization=getattr(tool, "__tool_requires_auth__", None),
|
||||
auth=auth_requirement,
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -193,7 +203,18 @@ def create_input_definition(func: Callable) -> ToolInputs:
|
|||
Create an input model for a function based on its parameters.
|
||||
"""
|
||||
input_parameters = []
|
||||
tool_context_param_name: str | None = None
|
||||
|
||||
for _, param in inspect.signature(func, follow_wrapped=True).parameters.items():
|
||||
if param.annotation is ToolContext:
|
||||
if tool_context_param_name is not None:
|
||||
raise ToolDefinitionError(
|
||||
f"Only one ToolContext parameter is supported, but tool {func.__name__} has multiple."
|
||||
)
|
||||
|
||||
tool_context_param_name = param.name
|
||||
continue # No further processing of this param (don't add it to the list of inputs)
|
||||
|
||||
tool_field_info = extract_field_info(param)
|
||||
|
||||
is_enum = False
|
||||
|
|
@ -222,7 +243,9 @@ def create_input_definition(func: Callable) -> ToolInputs:
|
|||
)
|
||||
)
|
||||
|
||||
return ToolInputs(parameters=input_parameters)
|
||||
return ToolInputs(
|
||||
parameters=input_parameters, tool_context_parameter_name=tool_context_param_name
|
||||
)
|
||||
|
||||
|
||||
def create_output_definition(func: Callable) -> ToolOutput:
|
||||
|
|
@ -455,6 +478,10 @@ def create_func_models(func: Callable) -> tuple[type[BaseModel], type[BaseModel]
|
|||
if asyncio.iscoroutinefunction(func) and hasattr(func, "__wrapped__"):
|
||||
func = func.__wrapped__
|
||||
for name, param in inspect.signature(func, follow_wrapped=True).parameters.items():
|
||||
# Skip ToolContext parameters
|
||||
if param.annotation is ToolContext:
|
||||
continue
|
||||
|
||||
# TODO make this cleaner
|
||||
tool_field_info = extract_field_info(param)
|
||||
param_fields = {
|
||||
|
|
|
|||
|
|
@ -10,14 +10,17 @@ from arcade.core.errors import (
|
|||
ToolSerializationError,
|
||||
)
|
||||
from arcade.core.response import ToolResponse, tool_response
|
||||
from arcade.core.schema import ToolContext, ToolDefinition
|
||||
|
||||
|
||||
class ToolExecutor:
|
||||
@staticmethod
|
||||
async def run(
|
||||
func: Callable,
|
||||
definition: ToolDefinition,
|
||||
input_model: type[BaseModel],
|
||||
output_model: type[BaseModel],
|
||||
context: ToolContext,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> ToolResponse:
|
||||
|
|
@ -28,11 +31,18 @@ class ToolExecutor:
|
|||
# serialize the input model
|
||||
inputs = await ToolExecutor._serialize_input(input_model, **kwargs)
|
||||
|
||||
# prepare the arguments for the function call
|
||||
func_args = inputs.model_dump()
|
||||
|
||||
# inject ToolContext, if the target function supports it
|
||||
if definition.inputs.tool_context_parameter_name is not None:
|
||||
func_args[definition.inputs.tool_context_parameter_name] = context
|
||||
|
||||
# execute the tool function
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
results = await func(**inputs.model_dump())
|
||||
results = await func(**func_args)
|
||||
else:
|
||||
results = func(**inputs.model_dump())
|
||||
results = func(**func_args)
|
||||
|
||||
# serialize the output model
|
||||
output = await ToolExecutor._serialize_output(output_model, results)
|
||||
|
|
@ -73,6 +83,7 @@ class ToolExecutor:
|
|||
Serialize the output of a tool function.
|
||||
"""
|
||||
# TODO how to type this the results object?
|
||||
# TODO how to ensure `results` contains only safe (serializable) stuff?
|
||||
try:
|
||||
# TODO Logging and telemetry
|
||||
|
||||
|
|
|
|||
|
|
@ -61,7 +61,11 @@ class ToolResponseFactory:
|
|||
msg: str = CustomResponseCode.HTTP_400.msg,
|
||||
data: Any = None,
|
||||
) -> ToolResponse:
|
||||
return await self.__response(res=res, data=data)
|
||||
return await self.__response(
|
||||
res=res,
|
||||
msg=msg, # TODO this needs to map to developer_message in output.error
|
||||
data=data,
|
||||
)
|
||||
|
||||
|
||||
tool_response = ToolResponseFactory()
|
||||
|
|
|
|||
181
arcade/arcade/core/schema.py
Normal file
181
arcade/arcade/core/schema.py
Normal file
|
|
@ -0,0 +1,181 @@
|
|||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from pydantic import AnyUrl, BaseModel, Field
|
||||
|
||||
|
||||
class ValueSchema(BaseModel):
|
||||
"""Value schema for input parameters and outputs."""
|
||||
|
||||
val_type: Literal["string", "integer", "float", "boolean", "json"]
|
||||
"""The type of the value."""
|
||||
|
||||
enum: Optional[list[str]] = None
|
||||
"""The list of possible values for the value, if it is a closed list."""
|
||||
|
||||
|
||||
class InputParameter(BaseModel):
|
||||
"""A parameter that can be passed to a tool."""
|
||||
|
||||
name: str = Field(..., description="The human-readable name of this parameter.")
|
||||
required: bool = Field(
|
||||
...,
|
||||
description="Whether this parameter is required (true) or optional (false).",
|
||||
)
|
||||
description: Optional[str] = Field(
|
||||
None, description="A descriptive, human-readable explanation of the parameter."
|
||||
)
|
||||
value_schema: ValueSchema = Field(
|
||||
...,
|
||||
description="The schema of the value of this parameter.",
|
||||
)
|
||||
inferrable: bool = Field(
|
||||
True,
|
||||
description="Whether a value for this parameter can be inferred by a model. Defaults to `true`.",
|
||||
)
|
||||
|
||||
|
||||
class ToolInputs(BaseModel):
|
||||
"""The inputs that a tool accepts."""
|
||||
|
||||
parameters: list[InputParameter]
|
||||
"""The list of parameters that the tool accepts."""
|
||||
|
||||
tool_context_parameter_name: str | None = Field(default=None, exclude=True)
|
||||
"""
|
||||
The name of the target parameter that will contain the tool context (if any).
|
||||
This field will not be included in serialization.
|
||||
"""
|
||||
|
||||
|
||||
class ToolOutput(BaseModel):
|
||||
"""The output of a tool."""
|
||||
|
||||
description: Optional[str] = Field(
|
||||
None, description="A descriptive, human-readable explanation of the output."
|
||||
)
|
||||
available_modes: list[str] = Field(
|
||||
default_factory=lambda: ["value", "error", "null"],
|
||||
description="The available modes for the output.",
|
||||
)
|
||||
value_schema: Optional[ValueSchema] = Field(
|
||||
None, description="The schema of the value of the output."
|
||||
)
|
||||
|
||||
|
||||
class OAuth2Requirement(BaseModel):
|
||||
"""Indicates that the tool requires OAuth 2.0 authorization."""
|
||||
|
||||
authority: AnyUrl
|
||||
"""The URL of the OAuth 2.0 authorization server."""
|
||||
|
||||
scope: Optional[list[str]] = None
|
||||
"""The scope(s) needed for authorization."""
|
||||
|
||||
|
||||
class ToolAuthRequirement(BaseModel):
|
||||
"""A requirement for authorization to use a tool."""
|
||||
|
||||
oauth2: Optional[OAuth2Requirement] = None
|
||||
|
||||
|
||||
class ToolRequirements(BaseModel):
|
||||
"""The requirements for a tool to run."""
|
||||
|
||||
auth: Union[ToolAuthRequirement, None] = None
|
||||
"""The authorization requirements for the tool, if any."""
|
||||
|
||||
|
||||
class ToolDefinition(BaseModel):
|
||||
"""The specification of a tool."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
version: str
|
||||
inputs: ToolInputs
|
||||
output: ToolOutput
|
||||
requirements: ToolRequirements
|
||||
|
||||
|
||||
class ToolVersion(BaseModel):
|
||||
"""The name and version of a tool."""
|
||||
|
||||
name: str
|
||||
"""The name of the tool."""
|
||||
|
||||
version: str
|
||||
"""The version of the tool."""
|
||||
|
||||
|
||||
class ToolAuthorizationContext(BaseModel):
|
||||
"""The context for a tool invocation that requires authorization."""
|
||||
|
||||
token: str | None = None
|
||||
"""The token for the tool invocation."""
|
||||
|
||||
|
||||
class ToolContext(BaseModel):
|
||||
"""The context for a tool invocation."""
|
||||
|
||||
authorization: ToolAuthorizationContext | None = None
|
||||
"""The authorization context for the tool invocation that requires authorization."""
|
||||
|
||||
|
||||
class InvokeToolRequest(BaseModel):
|
||||
"""The request to invoke a tool."""
|
||||
|
||||
run_id: str
|
||||
"""The globally-unique run ID provided by the Engine."""
|
||||
invocation_id: str
|
||||
"""The globally-unique ID for this tool invocation in the run."""
|
||||
created_at: str
|
||||
"""The timestamp when the tool invocation was created."""
|
||||
tool: ToolVersion
|
||||
"""The name and version of the tool."""
|
||||
inputs: dict[str, Any] | None = None
|
||||
"""The inputs for the tool."""
|
||||
context: ToolContext
|
||||
"""The context for the tool invocation."""
|
||||
|
||||
|
||||
class InvokeToolError(BaseModel):
|
||||
"""The error that occurred during the tool invocation."""
|
||||
|
||||
message: str
|
||||
"""The user-facing error message."""
|
||||
developer_message: str | None = None
|
||||
"""The developer-facing error details."""
|
||||
|
||||
|
||||
class InvokeToolOutput(BaseModel):
|
||||
"""The output of a tool invocation."""
|
||||
|
||||
value: Union[str, int, float, bool, dict] | None = None
|
||||
"""The value returned by the tool."""
|
||||
error: InvokeToolError | None = None
|
||||
"""The error that occurred during the tool invocation."""
|
||||
|
||||
model_config = {
|
||||
"json_schema_extra": {
|
||||
"oneOf": [
|
||||
{"required": ["value"]},
|
||||
{"required": ["error"]},
|
||||
{"required": ["requires_authorization"]},
|
||||
{"required": ["artifact"]},
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class InvokeToolResponse(BaseModel):
|
||||
"""The response to a tool invocation."""
|
||||
|
||||
invocation_id: str
|
||||
"""The globally-unique ID for this tool invocation."""
|
||||
finished_at: str
|
||||
"""The timestamp when the tool invocation finished."""
|
||||
duration: float
|
||||
"""The duration of the tool invocation in milliseconds (ms)."""
|
||||
success: bool
|
||||
"""Whether the tool invocation was successful."""
|
||||
output: InvokeToolOutput | None = None
|
||||
"""The output of the tool invocation."""
|
||||
|
|
@ -1,89 +0,0 @@
|
|||
from abc import ABC
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from pydantic import AnyUrl, BaseModel, Field
|
||||
|
||||
|
||||
class ValueSchema(BaseModel):
|
||||
"""Value schema for input parameters and outputs."""
|
||||
|
||||
val_type: Literal["string", "integer", "float", "boolean", "json"]
|
||||
"""The type of the value."""
|
||||
|
||||
enum: Optional[list[str]] = None
|
||||
|
||||
|
||||
class InputParameter(BaseModel):
|
||||
"""A parameter that can be passed to a tool."""
|
||||
|
||||
name: str = Field(..., description="The human-readable name of this parameter.")
|
||||
required: bool = Field(
|
||||
...,
|
||||
description="Whether this parameter is required (true) or optional (false).",
|
||||
)
|
||||
description: Optional[str] = Field(
|
||||
None, description="A descriptive, human-readable explanation of the parameter."
|
||||
)
|
||||
value_schema: ValueSchema = Field(
|
||||
...,
|
||||
description="The schema of the value of this parameter.",
|
||||
)
|
||||
inferrable: bool = Field(
|
||||
True,
|
||||
description="Whether a value for this parameter can be inferred by a model. Defaults to `true`.",
|
||||
)
|
||||
|
||||
|
||||
class ToolInputs(BaseModel):
|
||||
"""The inputs that a tool accepts."""
|
||||
|
||||
parameters: list[InputParameter]
|
||||
"""The list of parameters that the tool accepts."""
|
||||
|
||||
|
||||
class ToolOutput(BaseModel):
|
||||
"""The output of a tool."""
|
||||
|
||||
description: Optional[str] = Field(
|
||||
None, description="A descriptive, human-readable explanation of the output."
|
||||
)
|
||||
available_modes: list[str] = Field(
|
||||
default_factory=lambda: ["value", "error", "null"],
|
||||
description="The available modes for the output.",
|
||||
)
|
||||
value_schema: Optional[ValueSchema] = Field(
|
||||
None, description="The schema of the value of the output."
|
||||
)
|
||||
|
||||
|
||||
class ToolAuthorizationRequirement(BaseModel, ABC):
|
||||
"""A requirement for authorization to use a tool."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class OAuth2AuthorizationRequirement(ToolAuthorizationRequirement):
|
||||
"""Specifies OAuth2 requirement for tool execution."""
|
||||
|
||||
url: AnyUrl
|
||||
"""The URL to which the user should be redirected to authorize the tool."""
|
||||
|
||||
scope: Optional[list[str]] = None
|
||||
"""The scope of the authorization."""
|
||||
|
||||
|
||||
class ToolRequirements(BaseModel):
|
||||
"""The requirements for a tool to run."""
|
||||
|
||||
authorization: Union[ToolAuthorizationRequirement, None] = None
|
||||
|
||||
|
||||
class ToolDefinition(BaseModel):
|
||||
"""The specification of a tool."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
version: str
|
||||
inputs: ToolInputs
|
||||
output: ToolOutput
|
||||
requirements: ToolRequirements
|
||||
|
|
@ -0,0 +1,5 @@
|
|||
from .tool import tool
|
||||
|
||||
__all__ = [
|
||||
"tool",
|
||||
]
|
||||
20
arcade/arcade/sdk/auth.py
Normal file
20
arcade/arcade/sdk/auth.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
from abc import ABC
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import AnyUrl, BaseModel
|
||||
|
||||
|
||||
class ToolAuthorization(BaseModel, ABC):
|
||||
"""Marks a tool as requiring authorization."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class OAuth2(ToolAuthorization):
|
||||
"""Marks a tool as requiring OAuth 2.0 authorization."""
|
||||
|
||||
authority: AnyUrl
|
||||
"""The URL of the OAuth 2.0 authorization server."""
|
||||
|
||||
scope: Optional[list[str]] = None
|
||||
"""The scope(s) needed for the authorized action."""
|
||||
|
|
@ -1,9 +1,8 @@
|
|||
import inspect
|
||||
import os
|
||||
from typing import Any, Callable, Optional, TypeVar, Union
|
||||
from typing import Callable, TypeVar, Union
|
||||
|
||||
from arcade.core.tool import ToolAuthorizationRequirement
|
||||
from arcade.core.utils import snake_to_pascal_case
|
||||
from arcade.sdk.auth import ToolAuthorization
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
|
@ -13,7 +12,7 @@ def tool(
|
|||
func: Callable | None = None,
|
||||
desc: str | None = None,
|
||||
name: str | None = None,
|
||||
requires_auth: Union[ToolAuthorizationRequirement, None] = None,
|
||||
requires_auth: Union[ToolAuthorization, None] = None,
|
||||
) -> Callable:
|
||||
def decorator(func: Callable) -> Callable:
|
||||
func_name = str(getattr(func, "__name__", None))
|
||||
|
|
@ -28,12 +27,3 @@ def tool(
|
|||
if func: # This means the decorator is used without parameters
|
||||
return decorator(func)
|
||||
return decorator
|
||||
|
||||
|
||||
def get_secret(name: str, default: Optional[Any] = None) -> Any:
|
||||
secret = os.getenv(name)
|
||||
if secret is None:
|
||||
if default is not None:
|
||||
return default
|
||||
raise ValueError(f"Secret {name} is not set.")
|
||||
return secret
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@ import asyncio
|
|||
|
||||
import pytest
|
||||
|
||||
from arcade.core.tool import OAuth2AuthorizationRequirement
|
||||
from arcade.sdk.tool import tool
|
||||
from arcade.sdk import tool
|
||||
from arcade.sdk.auth import OAuth2
|
||||
|
||||
|
||||
def test_sync_function():
|
||||
|
|
@ -38,8 +38,8 @@ def test_tool_decorator_with_all_options():
|
|||
@tool(
|
||||
name="TestTool",
|
||||
desc="Test description",
|
||||
requires_auth=OAuth2AuthorizationRequirement(
|
||||
url="https://example.com/oauth2/auth",
|
||||
requires_auth=OAuth2(
|
||||
authority="https://example.com/oauth2/auth",
|
||||
scope=["test_scope", "another.scope"],
|
||||
),
|
||||
)
|
||||
|
|
@ -48,5 +48,5 @@ def test_tool_decorator_with_all_options():
|
|||
|
||||
assert test_tool.__tool_name__ == "TestTool"
|
||||
assert test_tool.__tool_description__ == "Test description"
|
||||
assert str(test_tool.__tool_requires_auth__.url) == "https://example.com/oauth2/auth"
|
||||
assert str(test_tool.__tool_requires_auth__.authority) == "https://example.com/oauth2/auth"
|
||||
assert test_tool.__tool_requires_auth__.scope == ["test_scope", "another.scope"]
|
||||
|
|
|
|||
|
|
@ -3,16 +3,19 @@ from typing import Annotated, Literal, Optional
|
|||
import pytest
|
||||
|
||||
from arcade.core.catalog import ToolCatalog
|
||||
from arcade.core.tool import (
|
||||
from arcade.core.schema import (
|
||||
InputParameter,
|
||||
OAuth2AuthorizationRequirement,
|
||||
OAuth2Requirement,
|
||||
ToolAuthRequirement,
|
||||
ToolContext,
|
||||
ToolInputs,
|
||||
ToolOutput,
|
||||
ToolRequirements,
|
||||
ValueSchema,
|
||||
)
|
||||
from arcade.sdk import tool
|
||||
from arcade.sdk.annotations import Inferrable
|
||||
from arcade.sdk.tool import tool
|
||||
from arcade.sdk.auth import OAuth2
|
||||
|
||||
|
||||
### Tests on @tool decorator
|
||||
|
|
@ -43,9 +46,7 @@ def func_with_name_and_description():
|
|||
|
||||
@tool(
|
||||
desc="A function that requires authentication",
|
||||
requires_auth=OAuth2AuthorizationRequirement(
|
||||
url="https://example.com/oauth2/auth", scope=["scope1", "scope2"]
|
||||
),
|
||||
requires_auth=OAuth2(authority="https://example.com/oauth2/auth", scope=["scope1", "scope2"]),
|
||||
)
|
||||
def func_with_auth_requirement():
|
||||
pass
|
||||
|
|
@ -53,7 +54,7 @@ def func_with_auth_requirement():
|
|||
|
||||
### Tests on input params
|
||||
@tool(desc="A function with an input parameter")
|
||||
def func_with_param(param1: Annotated[str, "First param"]):
|
||||
def func_with_param(context: Annotated[str, "First param"]):
|
||||
pass
|
||||
|
||||
|
||||
|
|
@ -114,6 +115,7 @@ def func_with_optional_param_with_default_value(
|
|||
|
||||
@tool(desc="A function with multiple parameters, some with default values")
|
||||
def func_with_mixed_params(
|
||||
context: ToolContext,
|
||||
param1: Annotated[str, "First param"],
|
||||
param2: Annotated[int, "Second param"] = 42,
|
||||
):
|
||||
|
|
@ -125,6 +127,11 @@ def func_with_complex_param(param1: Annotated[list[str], "A list of strings"]):
|
|||
pass
|
||||
|
||||
|
||||
@tool(desc="A function that takes a context")
|
||||
def func_with_context(my_context: ToolContext):
|
||||
pass
|
||||
|
||||
|
||||
### Tests on output/return values
|
||||
@tool(desc="A function that performs an action without returning anything")
|
||||
def func_with_no_return():
|
||||
|
|
@ -184,15 +191,18 @@ def func_with_complex_return() -> list[dict[str, str]]:
|
|||
),
|
||||
pytest.param(
|
||||
func_with_name_and_description,
|
||||
{"name": "MyCustomTool", "requirements": ToolRequirements(authorization=None)},
|
||||
{"name": "MyCustomTool", "requirements": ToolRequirements(auth=None)},
|
||||
id="func_with_no_auth_requirement",
|
||||
),
|
||||
pytest.param(
|
||||
func_with_auth_requirement,
|
||||
{
|
||||
"requirements": ToolRequirements(
|
||||
authorization=OAuth2AuthorizationRequirement(
|
||||
url="https://example.com/oauth2/auth", scope=["scope1", "scope2"]
|
||||
auth=ToolAuthRequirement(
|
||||
oauth2=OAuth2Requirement(
|
||||
authority="https://example.com/oauth2/auth",
|
||||
scope=["scope1", "scope2"],
|
||||
)
|
||||
)
|
||||
)
|
||||
},
|
||||
|
|
@ -212,7 +222,7 @@ def func_with_complex_return() -> list[dict[str, str]]:
|
|||
"inputs": ToolInputs(
|
||||
parameters=[
|
||||
InputParameter(
|
||||
name="param1",
|
||||
name="context", # Nothing special about this name, parameters can be named anything
|
||||
description="First param",
|
||||
inferrable=True, # Defaults to true
|
||||
required=True,
|
||||
|
|
@ -428,7 +438,8 @@ def func_with_complex_return() -> list[dict[str, str]]:
|
|||
required=False, # Because a default value is provided
|
||||
value_schema=ValueSchema(val_type="integer", enum=None),
|
||||
),
|
||||
]
|
||||
],
|
||||
tool_context_parameter_name="context",
|
||||
),
|
||||
},
|
||||
id="func_with_mixed_params",
|
||||
|
|
@ -450,6 +461,15 @@ def func_with_complex_return() -> list[dict[str, str]]:
|
|||
},
|
||||
id="func_with_complex_param",
|
||||
),
|
||||
pytest.param(
|
||||
func_with_context,
|
||||
{
|
||||
"inputs": ToolInputs(
|
||||
parameters=[], tool_context_parameter_name="my_context"
|
||||
), # ToolContext type is not an input param, but it's stored in the inputs field
|
||||
},
|
||||
id="func_with_context",
|
||||
),
|
||||
# Tests on output values
|
||||
pytest.param(
|
||||
func_with_no_return,
|
||||
|
|
|
|||
|
|
@ -2,7 +2,8 @@ import pytest
|
|||
|
||||
from arcade.core.catalog import ToolCatalog
|
||||
from arcade.core.errors import ToolDefinitionError
|
||||
from arcade.sdk.tool import tool
|
||||
from arcade.core.schema import ToolContext
|
||||
from arcade.sdk import tool
|
||||
|
||||
|
||||
@tool
|
||||
|
|
@ -30,6 +31,11 @@ def func_with_unsupported_param(param1: complex):
|
|||
pass
|
||||
|
||||
|
||||
@tool(desc="A function with multiple context parameters (illegal)")
|
||||
def func_with_multiple_context_params(context: ToolContext, context2: ToolContext):
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"func_under_test, exception_type",
|
||||
[
|
||||
|
|
@ -58,6 +64,11 @@ def func_with_unsupported_param(param1: complex):
|
|||
ToolDefinitionError,
|
||||
id=func_with_unsupported_param.__name__,
|
||||
),
|
||||
pytest.param(
|
||||
func_with_multiple_context_params,
|
||||
ToolDefinitionError,
|
||||
id=func_with_multiple_context_params.__name__,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_missing_info_raises_error(func_under_test, exception_type):
|
||||
|
|
|
|||
|
|
@ -4,13 +4,13 @@ import pytest
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
from arcade.core.catalog import ToolCatalog
|
||||
from arcade.core.tool import (
|
||||
from arcade.core.schema import (
|
||||
InputParameter,
|
||||
ToolInputs,
|
||||
ToolOutput,
|
||||
ValueSchema,
|
||||
)
|
||||
from arcade.sdk.tool import tool
|
||||
from arcade.sdk import tool
|
||||
|
||||
|
||||
class ProductOutput(BaseModel):
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from pydantic import Field
|
|||
|
||||
from arcade.core.catalog import ToolCatalog
|
||||
from arcade.core.errors import ToolDefinitionError
|
||||
from arcade.sdk.tool import tool
|
||||
from arcade.sdk import tool
|
||||
|
||||
|
||||
@tool
|
||||
|
|
|
|||
65
examples/gmail/arcade_gmail/tools/gdrive.py
Normal file
65
examples/gmail/arcade_gmail/tools/gdrive.py
Normal file
|
|
@ -0,0 +1,65 @@
|
|||
import os
|
||||
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2.credentials import Credentials
|
||||
from google_auth_oauthlib.flow import InstalledAppFlow
|
||||
from google.auth.exceptions import RefreshError
|
||||
from googleapiclient.discovery import build
|
||||
from typing import Annotated
|
||||
from arcade.sdk import tool
|
||||
|
||||
SECRET_FILE = "/Users/spartee/Dropbox/Arcade/gcp/credentials.json"
|
||||
DRIVE_SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"]
|
||||
|
||||
|
||||
@tool
|
||||
async def list_drive_files(
|
||||
n_files: Annotated[int, "Number of files to search"] = 5,
|
||||
) -> list[str]:
|
||||
"""List files from a Google Drive account and return their details."""
|
||||
|
||||
creds = None
|
||||
# The file token.json stores the user's access and refresh tokens, and is
|
||||
# created automatically when the authorization flow completes for the first time.
|
||||
# TODO: use context.authorization.token like gmail.py
|
||||
if os.path.exists("token.json"):
|
||||
creds = Credentials.from_authorized_user_file("token.json")
|
||||
# If there are no (valid) credentials available, let the user log in.
|
||||
if not creds or not creds.valid:
|
||||
if creds and creds.expired and creds.refresh_token:
|
||||
try:
|
||||
creds.refresh(Request())
|
||||
except RefreshError:
|
||||
flow = InstalledAppFlow.from_client_secrets_file(
|
||||
SECRET_FILE, DRIVE_SCOPES
|
||||
)
|
||||
creds = flow.run_local_server(port=0)
|
||||
# Save the credentials for the next run
|
||||
with open("token.json", "w") as token:
|
||||
token.write(creds.to_json())
|
||||
else:
|
||||
flow = InstalledAppFlow.from_client_secrets_file(SECRET_FILE, DRIVE_SCOPES)
|
||||
creds = flow.run_local_server(port=0)
|
||||
# Save the credentials for the next run
|
||||
with open("token.json", "w") as token:
|
||||
token.write(creds.to_json())
|
||||
|
||||
# Call the Drive v3 API
|
||||
service = build("drive", "v3", credentials=creds)
|
||||
|
||||
# Request a list of all the files
|
||||
results = (
|
||||
service.files()
|
||||
.list(pageSize=n_files, fields="nextPageToken, files(id, name)")
|
||||
.execute()
|
||||
)
|
||||
items = results.get("files", [])
|
||||
|
||||
if not items:
|
||||
print("No files found.")
|
||||
else:
|
||||
print("Files:")
|
||||
for item in items:
|
||||
print("{0} ({1})".format(item["name"], item["id"]))
|
||||
|
||||
return items
|
||||
|
|
@ -1,52 +1,29 @@
|
|||
import os
|
||||
import re
|
||||
|
||||
from base64 import urlsafe_b64decode
|
||||
from arcade.core.schema import ToolContext
|
||||
from arcade.sdk.auth import OAuth2
|
||||
from bs4 import BeautifulSoup
|
||||
from google.auth.transport.requests import Request
|
||||
from google.oauth2.credentials import Credentials
|
||||
from google_auth_oauthlib.flow import InstalledAppFlow
|
||||
from google.auth.exceptions import RefreshError
|
||||
from googleapiclient.discovery import build
|
||||
from typing import Dict, List, Annotated
|
||||
from arcade.sdk.tool import tool
|
||||
from typing import Annotated
|
||||
from arcade.sdk import tool
|
||||
|
||||
|
||||
SCOPES = ["https://www.googleapis.com/auth/gmail.readonly"]
|
||||
SECRET_FILE = "/Users/spartee/Dropbox/Arcade/gcp/credentials.json"
|
||||
|
||||
|
||||
@tool
|
||||
async def oauth_read_email(
|
||||
@tool(
|
||||
requires_auth=OAuth2(
|
||||
authority="https://accounts.google.com",
|
||||
scope=["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
)
|
||||
)
|
||||
async def get_emails(
|
||||
context: ToolContext,
|
||||
n_emails: Annotated[int, "Number of emails to read"] = 5,
|
||||
) -> List[Dict[str, str]]:
|
||||
) -> dict:
|
||||
"""Read emails from a Gmail account and extract plain text content, removing any HTML."""
|
||||
|
||||
creds = None
|
||||
# The file token.json stores the user's access and refresh tokens, and is
|
||||
# created automatically when the authorization flow completes for the first time.
|
||||
if os.path.exists("token.json"):
|
||||
creds = Credentials.from_authorized_user_file("token.json")
|
||||
# If there are no (valid) credentials available, let the user log in.
|
||||
if not creds or not creds.valid:
|
||||
if creds and creds.expired and creds.refresh_token:
|
||||
try:
|
||||
creds.refresh(Request())
|
||||
except RefreshError:
|
||||
flow = InstalledAppFlow.from_client_secrets_file(SECRET_FILE, SCOPES)
|
||||
creds = flow.run_local_server(port=0)
|
||||
# Save the credentials for the next run
|
||||
with open("token.json", "w") as token:
|
||||
token.write(creds.to_json())
|
||||
else:
|
||||
flow = InstalledAppFlow.from_client_secrets_file(SECRET_FILE, SCOPES)
|
||||
creds = flow.run_local_server(port=0)
|
||||
# Save the credentials for the next run
|
||||
with open("token.json", "w") as token:
|
||||
token.write(creds.to_json())
|
||||
|
||||
# Call the Gmail API
|
||||
service = build("gmail", "v1", credentials=creds)
|
||||
service = build("gmail", "v1", credentials=Credentials(context.authorization.token))
|
||||
|
||||
# Request a list of all the messages
|
||||
result = service.users().messages().list(userId="me").execute()
|
||||
|
|
@ -97,7 +74,7 @@ async def oauth_read_email(
|
|||
print(f"Error reading email {msg['id']}: {e}", "ERROR")
|
||||
continue
|
||||
|
||||
return emails
|
||||
return {"emails": emails}
|
||||
|
||||
|
||||
def clean_email_body(body: str) -> str:
|
||||
|
|
@ -114,58 +91,3 @@ def clean_email_body(body: str) -> str:
|
|||
text = " ".join(text.split()) # Remove extra whitespace
|
||||
|
||||
return text
|
||||
|
||||
|
||||
DRIVE_SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"]
|
||||
|
||||
|
||||
@tool
|
||||
async def list_drive_files(
|
||||
n_files: Annotated[int, "Number of files to search"] = 5,
|
||||
) -> list[str]:
|
||||
"""List files from a Google Drive account and return their details."""
|
||||
|
||||
creds = None
|
||||
# The file token.json stores the user's access and refresh tokens, and is
|
||||
# created automatically when the authorization flow completes for the first time.
|
||||
if os.path.exists("token.json"):
|
||||
creds = Credentials.from_authorized_user_file("token.json")
|
||||
# If there are no (valid) credentials available, let the user log in.
|
||||
if not creds or not creds.valid:
|
||||
if creds and creds.expired and creds.refresh_token:
|
||||
try:
|
||||
creds.refresh(Request())
|
||||
except RefreshError:
|
||||
flow = InstalledAppFlow.from_client_secrets_file(
|
||||
SECRET_FILE, DRIVE_SCOPES
|
||||
)
|
||||
creds = flow.run_local_server(port=0)
|
||||
# Save the credentials for the next run
|
||||
with open("token.json", "w") as token:
|
||||
token.write(creds.to_json())
|
||||
else:
|
||||
flow = InstalledAppFlow.from_client_secrets_file(SECRET_FILE, DRIVE_SCOPES)
|
||||
creds = flow.run_local_server(port=0)
|
||||
# Save the credentials for the next run
|
||||
with open("token.json", "w") as token:
|
||||
token.write(creds.to_json())
|
||||
|
||||
# Call the Drive v3 API
|
||||
service = build("drive", "v3", credentials=creds)
|
||||
|
||||
# Request a list of all the files
|
||||
results = (
|
||||
service.files()
|
||||
.list(pageSize=n_files, fields="nextPageToken, files(id, name)")
|
||||
.execute()
|
||||
)
|
||||
items = results.get("files", [])
|
||||
|
||||
if not items:
|
||||
print("No files found.")
|
||||
else:
|
||||
print("Files:")
|
||||
for item in items:
|
||||
print("{0} ({1})".format(item["name"], item["id"]))
|
||||
|
||||
return items
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import math
|
||||
from typing import Annotated
|
||||
|
||||
from arcade.sdk.tool import tool
|
||||
from arcade.sdk import tool
|
||||
|
||||
|
||||
@tool
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
import json
|
||||
import os
|
||||
import serpapi
|
||||
from typing import Annotated
|
||||
from arcade.sdk.tool import tool, get_secret
|
||||
from typing import Annotated, Any, Optional
|
||||
from arcade.sdk import tool
|
||||
|
||||
|
||||
@tool
|
||||
|
|
@ -23,3 +24,12 @@ async def search_google(
|
|||
organic_results = results.get("organic_results", [])
|
||||
|
||||
return json.dumps(organic_results[:n_results])
|
||||
|
||||
|
||||
def get_secret(name: str, default: Optional[Any] = None) -> Any:
|
||||
secret = os.getenv(name)
|
||||
if secret is None:
|
||||
if default is not None:
|
||||
return default
|
||||
raise ValueError(f"Secret {name} is not set.")
|
||||
return secret
|
||||
|
|
|
|||
|
|
@ -68,6 +68,6 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"required": ["run_id", "invocation_id", "created_at", "tool", "input", "context"],
|
||||
"required": ["run_id", "invocation_id", "created_at", "tool", "inputs", "context"],
|
||||
"additionalProperties": false
|
||||
}
|
||||
|
|
|
|||
|
|
@ -129,7 +129,10 @@
|
|||
"format": "uri"
|
||||
},
|
||||
"scope": {
|
||||
"type": "string"
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
},
|
||||
"required": ["url"],
|
||||
|
|
|
|||
Loading…
Reference in a new issue