Add ToolContext and OAuth tool support (#10)

- Adds initial `ToolContext` to tool invocations
- This unlocks the ability to call authenticated tools (e.g. Gmail),
which works in this branch against Nate's dev engine
This commit is contained in:
Nate Barbettini 2024-08-02 11:25:08 -07:00 committed by GitHub
parent 1b67cee667
commit 14998a43e3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
23 changed files with 421 additions and 282 deletions

View file

@ -3,15 +3,16 @@ from abc import ABC, abstractmethod
from datetime import datetime
from typing import Any, Callable
from arcade.actor.schema import (
InvokeToolRequest,
InvokeToolResponse,
ToolOutput,
ToolOutputError,
)
from arcade.core.catalog import ToolCatalog, Toolkit
from arcade.core.executor import ToolExecutor
from arcade.core.tool import ToolDefinition
from arcade.core.schema import (
InvokeToolError,
InvokeToolOutput,
InvokeToolRequest,
InvokeToolResponse,
ToolContext,
ToolDefinition,
)
class ActorComponent(ABC):
@ -72,15 +73,17 @@ class BaseActor:
response = await ToolExecutor.run(
func=materialized_tool.tool,
definition=materialized_tool.definition,
input_model=materialized_tool.input_model,
output_model=materialized_tool.output_model,
context=tool_request.context or ToolContext(),
**tool_request.inputs or {},
)
if response.code == 200:
# TODO remove ignore
output = ToolOutput(value=response.data.result) # type: ignore[union-attr]
output = InvokeToolOutput(value=response.data.result) # type: ignore[union-attr]
else:
output = ToolOutput(error=ToolOutputError(message=response.msg))
output = InvokeToolOutput(error=InvokeToolError(message=response.msg))
end_time = time.time() # End time in seconds
duration_ms = (end_time - start_time) * 1000 # Convert to milliseconds

View file

@ -1,45 +0,0 @@
from typing import ClassVar, Union
from pydantic import BaseModel
class ToolVersion(BaseModel):
name: str
version: str
class InvokeToolRequest(BaseModel):
run_id: str
invocation_id: str
created_at: str
tool: ToolVersion
inputs: dict | None
context: dict | None
class ToolOutputError(BaseModel):
message: str
developer_message: str | None = None
class ToolOutput(BaseModel):
value: Union[str, int, float, bool, dict] | None = None
error: ToolOutputError | None = None
class Config:
json_schema_extra: ClassVar[dict] = {
"oneOf": [
{"required": ["value"]},
{"required": ["error"]},
{"required": ["requires_authorization"]},
{"required": ["artifact"]},
]
}
class InvokeToolResponse(BaseModel):
invocation_id: str
finished_at: str
duration: float
success: bool
output: ToolOutput | None = None

View file

@ -153,6 +153,7 @@ def run( # noqa: C901
output = asyncio.run(
ToolExecutor.run(
called_tool.tool,
called_tool.definition,
called_tool.input_model,
called_tool.output_model,
**parameters,

View file

@ -119,7 +119,7 @@ def create_new_toolkit(directory: str) -> None:
os.path.join(toolkit_dir, "tools", "hello.py"),
dedent(
f"""
from arcade.sdk.tool import tool
from arcade.sdk import tool
@tool
def hello() -> str:

View file

@ -22,8 +22,11 @@ from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefined
from arcade.core.errors import ToolDefinitionError
from arcade.core.tool import (
from arcade.core.schema import (
InputParameter,
OAuth2Requirement,
ToolAuthRequirement,
ToolContext,
ToolDefinition,
ToolInputs,
ToolOutput,
@ -38,6 +41,7 @@ from arcade.core.utils import (
snake_to_pascal_case,
)
from arcade.sdk.annotations import Inferrable
from arcade.sdk.auth import OAuth2
WireType = Literal["string", "integer", "float", "boolean", "json"]
@ -176,6 +180,12 @@ class ToolCatalog(BaseModel):
if does_function_return_value(tool) and tool.__annotations__.get("return") is None:
raise ToolDefinitionError(f"Tool {tool_name} must have a return type annotation")
auth_requirement = getattr(tool, "__tool_requires_auth__", None)
if isinstance(auth_requirement, OAuth2):
auth_requirement = ToolAuthRequirement(
oauth2=OAuth2Requirement(**auth_requirement.model_dump())
)
return ToolDefinition(
name=snake_to_pascal_case(tool_name),
description=tool_description,
@ -183,7 +193,7 @@ class ToolCatalog(BaseModel):
inputs=create_input_definition(tool),
output=create_output_definition(tool),
requirements=ToolRequirements(
authorization=getattr(tool, "__tool_requires_auth__", None),
auth=auth_requirement,
),
)
@ -193,7 +203,18 @@ def create_input_definition(func: Callable) -> ToolInputs:
Create an input model for a function based on its parameters.
"""
input_parameters = []
tool_context_param_name: str | None = None
for _, param in inspect.signature(func, follow_wrapped=True).parameters.items():
if param.annotation is ToolContext:
if tool_context_param_name is not None:
raise ToolDefinitionError(
f"Only one ToolContext parameter is supported, but tool {func.__name__} has multiple."
)
tool_context_param_name = param.name
continue # No further processing of this param (don't add it to the list of inputs)
tool_field_info = extract_field_info(param)
is_enum = False
@ -222,7 +243,9 @@ def create_input_definition(func: Callable) -> ToolInputs:
)
)
return ToolInputs(parameters=input_parameters)
return ToolInputs(
parameters=input_parameters, tool_context_parameter_name=tool_context_param_name
)
def create_output_definition(func: Callable) -> ToolOutput:
@ -455,6 +478,10 @@ def create_func_models(func: Callable) -> tuple[type[BaseModel], type[BaseModel]
if asyncio.iscoroutinefunction(func) and hasattr(func, "__wrapped__"):
func = func.__wrapped__
for name, param in inspect.signature(func, follow_wrapped=True).parameters.items():
# Skip ToolContext parameters
if param.annotation is ToolContext:
continue
# TODO make this cleaner
tool_field_info = extract_field_info(param)
param_fields = {

View file

@ -10,14 +10,17 @@ from arcade.core.errors import (
ToolSerializationError,
)
from arcade.core.response import ToolResponse, tool_response
from arcade.core.schema import ToolContext, ToolDefinition
class ToolExecutor:
@staticmethod
async def run(
func: Callable,
definition: ToolDefinition,
input_model: type[BaseModel],
output_model: type[BaseModel],
context: ToolContext,
*args: Any,
**kwargs: Any,
) -> ToolResponse:
@ -28,11 +31,18 @@ class ToolExecutor:
# serialize the input model
inputs = await ToolExecutor._serialize_input(input_model, **kwargs)
# prepare the arguments for the function call
func_args = inputs.model_dump()
# inject ToolContext, if the target function supports it
if definition.inputs.tool_context_parameter_name is not None:
func_args[definition.inputs.tool_context_parameter_name] = context
# execute the tool function
if asyncio.iscoroutinefunction(func):
results = await func(**inputs.model_dump())
results = await func(**func_args)
else:
results = func(**inputs.model_dump())
results = func(**func_args)
# serialize the output model
output = await ToolExecutor._serialize_output(output_model, results)
@ -73,6 +83,7 @@ class ToolExecutor:
Serialize the output of a tool function.
"""
# TODO how to type this the results object?
# TODO how to ensure `results` contains only safe (serializable) stuff?
try:
# TODO Logging and telemetry

View file

@ -61,7 +61,11 @@ class ToolResponseFactory:
msg: str = CustomResponseCode.HTTP_400.msg,
data: Any = None,
) -> ToolResponse:
return await self.__response(res=res, data=data)
return await self.__response(
res=res,
msg=msg, # TODO this needs to map to developer_message in output.error
data=data,
)
tool_response = ToolResponseFactory()

View file

@ -0,0 +1,181 @@
from typing import Any, Literal, Optional, Union
from pydantic import AnyUrl, BaseModel, Field
class ValueSchema(BaseModel):
"""Value schema for input parameters and outputs."""
val_type: Literal["string", "integer", "float", "boolean", "json"]
"""The type of the value."""
enum: Optional[list[str]] = None
"""The list of possible values for the value, if it is a closed list."""
class InputParameter(BaseModel):
"""A parameter that can be passed to a tool."""
name: str = Field(..., description="The human-readable name of this parameter.")
required: bool = Field(
...,
description="Whether this parameter is required (true) or optional (false).",
)
description: Optional[str] = Field(
None, description="A descriptive, human-readable explanation of the parameter."
)
value_schema: ValueSchema = Field(
...,
description="The schema of the value of this parameter.",
)
inferrable: bool = Field(
True,
description="Whether a value for this parameter can be inferred by a model. Defaults to `true`.",
)
class ToolInputs(BaseModel):
"""The inputs that a tool accepts."""
parameters: list[InputParameter]
"""The list of parameters that the tool accepts."""
tool_context_parameter_name: str | None = Field(default=None, exclude=True)
"""
The name of the target parameter that will contain the tool context (if any).
This field will not be included in serialization.
"""
class ToolOutput(BaseModel):
"""The output of a tool."""
description: Optional[str] = Field(
None, description="A descriptive, human-readable explanation of the output."
)
available_modes: list[str] = Field(
default_factory=lambda: ["value", "error", "null"],
description="The available modes for the output.",
)
value_schema: Optional[ValueSchema] = Field(
None, description="The schema of the value of the output."
)
class OAuth2Requirement(BaseModel):
"""Indicates that the tool requires OAuth 2.0 authorization."""
authority: AnyUrl
"""The URL of the OAuth 2.0 authorization server."""
scope: Optional[list[str]] = None
"""The scope(s) needed for authorization."""
class ToolAuthRequirement(BaseModel):
"""A requirement for authorization to use a tool."""
oauth2: Optional[OAuth2Requirement] = None
class ToolRequirements(BaseModel):
"""The requirements for a tool to run."""
auth: Union[ToolAuthRequirement, None] = None
"""The authorization requirements for the tool, if any."""
class ToolDefinition(BaseModel):
"""The specification of a tool."""
name: str
description: str
version: str
inputs: ToolInputs
output: ToolOutput
requirements: ToolRequirements
class ToolVersion(BaseModel):
"""The name and version of a tool."""
name: str
"""The name of the tool."""
version: str
"""The version of the tool."""
class ToolAuthorizationContext(BaseModel):
"""The context for a tool invocation that requires authorization."""
token: str | None = None
"""The token for the tool invocation."""
class ToolContext(BaseModel):
"""The context for a tool invocation."""
authorization: ToolAuthorizationContext | None = None
"""The authorization context for the tool invocation that requires authorization."""
class InvokeToolRequest(BaseModel):
"""The request to invoke a tool."""
run_id: str
"""The globally-unique run ID provided by the Engine."""
invocation_id: str
"""The globally-unique ID for this tool invocation in the run."""
created_at: str
"""The timestamp when the tool invocation was created."""
tool: ToolVersion
"""The name and version of the tool."""
inputs: dict[str, Any] | None = None
"""The inputs for the tool."""
context: ToolContext
"""The context for the tool invocation."""
class InvokeToolError(BaseModel):
"""The error that occurred during the tool invocation."""
message: str
"""The user-facing error message."""
developer_message: str | None = None
"""The developer-facing error details."""
class InvokeToolOutput(BaseModel):
"""The output of a tool invocation."""
value: Union[str, int, float, bool, dict] | None = None
"""The value returned by the tool."""
error: InvokeToolError | None = None
"""The error that occurred during the tool invocation."""
model_config = {
"json_schema_extra": {
"oneOf": [
{"required": ["value"]},
{"required": ["error"]},
{"required": ["requires_authorization"]},
{"required": ["artifact"]},
]
}
}
class InvokeToolResponse(BaseModel):
"""The response to a tool invocation."""
invocation_id: str
"""The globally-unique ID for this tool invocation."""
finished_at: str
"""The timestamp when the tool invocation finished."""
duration: float
"""The duration of the tool invocation in milliseconds (ms)."""
success: bool
"""Whether the tool invocation was successful."""
output: InvokeToolOutput | None = None
"""The output of the tool invocation."""

View file

@ -1,89 +0,0 @@
from abc import ABC
from typing import Literal, Optional, Union
from pydantic import AnyUrl, BaseModel, Field
class ValueSchema(BaseModel):
"""Value schema for input parameters and outputs."""
val_type: Literal["string", "integer", "float", "boolean", "json"]
"""The type of the value."""
enum: Optional[list[str]] = None
class InputParameter(BaseModel):
"""A parameter that can be passed to a tool."""
name: str = Field(..., description="The human-readable name of this parameter.")
required: bool = Field(
...,
description="Whether this parameter is required (true) or optional (false).",
)
description: Optional[str] = Field(
None, description="A descriptive, human-readable explanation of the parameter."
)
value_schema: ValueSchema = Field(
...,
description="The schema of the value of this parameter.",
)
inferrable: bool = Field(
True,
description="Whether a value for this parameter can be inferred by a model. Defaults to `true`.",
)
class ToolInputs(BaseModel):
"""The inputs that a tool accepts."""
parameters: list[InputParameter]
"""The list of parameters that the tool accepts."""
class ToolOutput(BaseModel):
"""The output of a tool."""
description: Optional[str] = Field(
None, description="A descriptive, human-readable explanation of the output."
)
available_modes: list[str] = Field(
default_factory=lambda: ["value", "error", "null"],
description="The available modes for the output.",
)
value_schema: Optional[ValueSchema] = Field(
None, description="The schema of the value of the output."
)
class ToolAuthorizationRequirement(BaseModel, ABC):
"""A requirement for authorization to use a tool."""
pass
class OAuth2AuthorizationRequirement(ToolAuthorizationRequirement):
"""Specifies OAuth2 requirement for tool execution."""
url: AnyUrl
"""The URL to which the user should be redirected to authorize the tool."""
scope: Optional[list[str]] = None
"""The scope of the authorization."""
class ToolRequirements(BaseModel):
"""The requirements for a tool to run."""
authorization: Union[ToolAuthorizationRequirement, None] = None
class ToolDefinition(BaseModel):
"""The specification of a tool."""
name: str
description: str
version: str
inputs: ToolInputs
output: ToolOutput
requirements: ToolRequirements

View file

@ -0,0 +1,5 @@
from .tool import tool
__all__ = [
"tool",
]

20
arcade/arcade/sdk/auth.py Normal file
View file

@ -0,0 +1,20 @@
from abc import ABC
from typing import Optional
from pydantic import AnyUrl, BaseModel
class ToolAuthorization(BaseModel, ABC):
"""Marks a tool as requiring authorization."""
pass
class OAuth2(ToolAuthorization):
"""Marks a tool as requiring OAuth 2.0 authorization."""
authority: AnyUrl
"""The URL of the OAuth 2.0 authorization server."""
scope: Optional[list[str]] = None
"""The scope(s) needed for the authorized action."""

View file

@ -1,9 +1,8 @@
import inspect
import os
from typing import Any, Callable, Optional, TypeVar, Union
from typing import Callable, TypeVar, Union
from arcade.core.tool import ToolAuthorizationRequirement
from arcade.core.utils import snake_to_pascal_case
from arcade.sdk.auth import ToolAuthorization
T = TypeVar("T")
@ -13,7 +12,7 @@ def tool(
func: Callable | None = None,
desc: str | None = None,
name: str | None = None,
requires_auth: Union[ToolAuthorizationRequirement, None] = None,
requires_auth: Union[ToolAuthorization, None] = None,
) -> Callable:
def decorator(func: Callable) -> Callable:
func_name = str(getattr(func, "__name__", None))
@ -28,12 +27,3 @@ def tool(
if func: # This means the decorator is used without parameters
return decorator(func)
return decorator
def get_secret(name: str, default: Optional[Any] = None) -> Any:
secret = os.getenv(name)
if secret is None:
if default is not None:
return default
raise ValueError(f"Secret {name} is not set.")
return secret

View file

@ -2,8 +2,8 @@ import asyncio
import pytest
from arcade.core.tool import OAuth2AuthorizationRequirement
from arcade.sdk.tool import tool
from arcade.sdk import tool
from arcade.sdk.auth import OAuth2
def test_sync_function():
@ -38,8 +38,8 @@ def test_tool_decorator_with_all_options():
@tool(
name="TestTool",
desc="Test description",
requires_auth=OAuth2AuthorizationRequirement(
url="https://example.com/oauth2/auth",
requires_auth=OAuth2(
authority="https://example.com/oauth2/auth",
scope=["test_scope", "another.scope"],
),
)
@ -48,5 +48,5 @@ def test_tool_decorator_with_all_options():
assert test_tool.__tool_name__ == "TestTool"
assert test_tool.__tool_description__ == "Test description"
assert str(test_tool.__tool_requires_auth__.url) == "https://example.com/oauth2/auth"
assert str(test_tool.__tool_requires_auth__.authority) == "https://example.com/oauth2/auth"
assert test_tool.__tool_requires_auth__.scope == ["test_scope", "another.scope"]

View file

@ -3,16 +3,19 @@ from typing import Annotated, Literal, Optional
import pytest
from arcade.core.catalog import ToolCatalog
from arcade.core.tool import (
from arcade.core.schema import (
InputParameter,
OAuth2AuthorizationRequirement,
OAuth2Requirement,
ToolAuthRequirement,
ToolContext,
ToolInputs,
ToolOutput,
ToolRequirements,
ValueSchema,
)
from arcade.sdk import tool
from arcade.sdk.annotations import Inferrable
from arcade.sdk.tool import tool
from arcade.sdk.auth import OAuth2
### Tests on @tool decorator
@ -43,9 +46,7 @@ def func_with_name_and_description():
@tool(
desc="A function that requires authentication",
requires_auth=OAuth2AuthorizationRequirement(
url="https://example.com/oauth2/auth", scope=["scope1", "scope2"]
),
requires_auth=OAuth2(authority="https://example.com/oauth2/auth", scope=["scope1", "scope2"]),
)
def func_with_auth_requirement():
pass
@ -53,7 +54,7 @@ def func_with_auth_requirement():
### Tests on input params
@tool(desc="A function with an input parameter")
def func_with_param(param1: Annotated[str, "First param"]):
def func_with_param(context: Annotated[str, "First param"]):
pass
@ -114,6 +115,7 @@ def func_with_optional_param_with_default_value(
@tool(desc="A function with multiple parameters, some with default values")
def func_with_mixed_params(
context: ToolContext,
param1: Annotated[str, "First param"],
param2: Annotated[int, "Second param"] = 42,
):
@ -125,6 +127,11 @@ def func_with_complex_param(param1: Annotated[list[str], "A list of strings"]):
pass
@tool(desc="A function that takes a context")
def func_with_context(my_context: ToolContext):
pass
### Tests on output/return values
@tool(desc="A function that performs an action without returning anything")
def func_with_no_return():
@ -184,15 +191,18 @@ def func_with_complex_return() -> list[dict[str, str]]:
),
pytest.param(
func_with_name_and_description,
{"name": "MyCustomTool", "requirements": ToolRequirements(authorization=None)},
{"name": "MyCustomTool", "requirements": ToolRequirements(auth=None)},
id="func_with_no_auth_requirement",
),
pytest.param(
func_with_auth_requirement,
{
"requirements": ToolRequirements(
authorization=OAuth2AuthorizationRequirement(
url="https://example.com/oauth2/auth", scope=["scope1", "scope2"]
auth=ToolAuthRequirement(
oauth2=OAuth2Requirement(
authority="https://example.com/oauth2/auth",
scope=["scope1", "scope2"],
)
)
)
},
@ -212,7 +222,7 @@ def func_with_complex_return() -> list[dict[str, str]]:
"inputs": ToolInputs(
parameters=[
InputParameter(
name="param1",
name="context", # Nothing special about this name, parameters can be named anything
description="First param",
inferrable=True, # Defaults to true
required=True,
@ -428,7 +438,8 @@ def func_with_complex_return() -> list[dict[str, str]]:
required=False, # Because a default value is provided
value_schema=ValueSchema(val_type="integer", enum=None),
),
]
],
tool_context_parameter_name="context",
),
},
id="func_with_mixed_params",
@ -450,6 +461,15 @@ def func_with_complex_return() -> list[dict[str, str]]:
},
id="func_with_complex_param",
),
pytest.param(
func_with_context,
{
"inputs": ToolInputs(
parameters=[], tool_context_parameter_name="my_context"
), # ToolContext type is not an input param, but it's stored in the inputs field
},
id="func_with_context",
),
# Tests on output values
pytest.param(
func_with_no_return,

View file

@ -2,7 +2,8 @@ import pytest
from arcade.core.catalog import ToolCatalog
from arcade.core.errors import ToolDefinitionError
from arcade.sdk.tool import tool
from arcade.core.schema import ToolContext
from arcade.sdk import tool
@tool
@ -30,6 +31,11 @@ def func_with_unsupported_param(param1: complex):
pass
@tool(desc="A function with multiple context parameters (illegal)")
def func_with_multiple_context_params(context: ToolContext, context2: ToolContext):
pass
@pytest.mark.parametrize(
"func_under_test, exception_type",
[
@ -58,6 +64,11 @@ def func_with_unsupported_param(param1: complex):
ToolDefinitionError,
id=func_with_unsupported_param.__name__,
),
pytest.param(
func_with_multiple_context_params,
ToolDefinitionError,
id=func_with_multiple_context_params.__name__,
),
],
)
def test_missing_info_raises_error(func_under_test, exception_type):

View file

@ -4,13 +4,13 @@ import pytest
from pydantic import BaseModel, Field
from arcade.core.catalog import ToolCatalog
from arcade.core.tool import (
from arcade.core.schema import (
InputParameter,
ToolInputs,
ToolOutput,
ValueSchema,
)
from arcade.sdk.tool import tool
from arcade.sdk import tool
class ProductOutput(BaseModel):

View file

@ -5,7 +5,7 @@ from pydantic import Field
from arcade.core.catalog import ToolCatalog
from arcade.core.errors import ToolDefinitionError
from arcade.sdk.tool import tool
from arcade.sdk import tool
@tool

View file

@ -0,0 +1,65 @@
import os
from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import InstalledAppFlow
from google.auth.exceptions import RefreshError
from googleapiclient.discovery import build
from typing import Annotated
from arcade.sdk import tool
SECRET_FILE = "/Users/spartee/Dropbox/Arcade/gcp/credentials.json"
DRIVE_SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"]
@tool
async def list_drive_files(
n_files: Annotated[int, "Number of files to search"] = 5,
) -> list[str]:
"""List files from a Google Drive account and return their details."""
creds = None
# The file token.json stores the user's access and refresh tokens, and is
# created automatically when the authorization flow completes for the first time.
# TODO: use context.authorization.token like gmail.py
if os.path.exists("token.json"):
creds = Credentials.from_authorized_user_file("token.json")
# If there are no (valid) credentials available, let the user log in.
if not creds or not creds.valid:
if creds and creds.expired and creds.refresh_token:
try:
creds.refresh(Request())
except RefreshError:
flow = InstalledAppFlow.from_client_secrets_file(
SECRET_FILE, DRIVE_SCOPES
)
creds = flow.run_local_server(port=0)
# Save the credentials for the next run
with open("token.json", "w") as token:
token.write(creds.to_json())
else:
flow = InstalledAppFlow.from_client_secrets_file(SECRET_FILE, DRIVE_SCOPES)
creds = flow.run_local_server(port=0)
# Save the credentials for the next run
with open("token.json", "w") as token:
token.write(creds.to_json())
# Call the Drive v3 API
service = build("drive", "v3", credentials=creds)
# Request a list of all the files
results = (
service.files()
.list(pageSize=n_files, fields="nextPageToken, files(id, name)")
.execute()
)
items = results.get("files", [])
if not items:
print("No files found.")
else:
print("Files:")
for item in items:
print("{0} ({1})".format(item["name"], item["id"]))
return items

View file

@ -1,52 +1,29 @@
import os
import re
from base64 import urlsafe_b64decode
from arcade.core.schema import ToolContext
from arcade.sdk.auth import OAuth2
from bs4 import BeautifulSoup
from google.auth.transport.requests import Request
from google.oauth2.credentials import Credentials
from google_auth_oauthlib.flow import InstalledAppFlow
from google.auth.exceptions import RefreshError
from googleapiclient.discovery import build
from typing import Dict, List, Annotated
from arcade.sdk.tool import tool
from typing import Annotated
from arcade.sdk import tool
SCOPES = ["https://www.googleapis.com/auth/gmail.readonly"]
SECRET_FILE = "/Users/spartee/Dropbox/Arcade/gcp/credentials.json"
@tool
async def oauth_read_email(
@tool(
requires_auth=OAuth2(
authority="https://accounts.google.com",
scope=["https://www.googleapis.com/auth/gmail.readonly"],
)
)
async def get_emails(
context: ToolContext,
n_emails: Annotated[int, "Number of emails to read"] = 5,
) -> List[Dict[str, str]]:
) -> dict:
"""Read emails from a Gmail account and extract plain text content, removing any HTML."""
creds = None
# The file token.json stores the user's access and refresh tokens, and is
# created automatically when the authorization flow completes for the first time.
if os.path.exists("token.json"):
creds = Credentials.from_authorized_user_file("token.json")
# If there are no (valid) credentials available, let the user log in.
if not creds or not creds.valid:
if creds and creds.expired and creds.refresh_token:
try:
creds.refresh(Request())
except RefreshError:
flow = InstalledAppFlow.from_client_secrets_file(SECRET_FILE, SCOPES)
creds = flow.run_local_server(port=0)
# Save the credentials for the next run
with open("token.json", "w") as token:
token.write(creds.to_json())
else:
flow = InstalledAppFlow.from_client_secrets_file(SECRET_FILE, SCOPES)
creds = flow.run_local_server(port=0)
# Save the credentials for the next run
with open("token.json", "w") as token:
token.write(creds.to_json())
# Call the Gmail API
service = build("gmail", "v1", credentials=creds)
service = build("gmail", "v1", credentials=Credentials(context.authorization.token))
# Request a list of all the messages
result = service.users().messages().list(userId="me").execute()
@ -97,7 +74,7 @@ async def oauth_read_email(
print(f"Error reading email {msg['id']}: {e}", "ERROR")
continue
return emails
return {"emails": emails}
def clean_email_body(body: str) -> str:
@ -114,58 +91,3 @@ def clean_email_body(body: str) -> str:
text = " ".join(text.split()) # Remove extra whitespace
return text
DRIVE_SCOPES = ["https://www.googleapis.com/auth/drive.metadata.readonly"]
@tool
async def list_drive_files(
n_files: Annotated[int, "Number of files to search"] = 5,
) -> list[str]:
"""List files from a Google Drive account and return their details."""
creds = None
# The file token.json stores the user's access and refresh tokens, and is
# created automatically when the authorization flow completes for the first time.
if os.path.exists("token.json"):
creds = Credentials.from_authorized_user_file("token.json")
# If there are no (valid) credentials available, let the user log in.
if not creds or not creds.valid:
if creds and creds.expired and creds.refresh_token:
try:
creds.refresh(Request())
except RefreshError:
flow = InstalledAppFlow.from_client_secrets_file(
SECRET_FILE, DRIVE_SCOPES
)
creds = flow.run_local_server(port=0)
# Save the credentials for the next run
with open("token.json", "w") as token:
token.write(creds.to_json())
else:
flow = InstalledAppFlow.from_client_secrets_file(SECRET_FILE, DRIVE_SCOPES)
creds = flow.run_local_server(port=0)
# Save the credentials for the next run
with open("token.json", "w") as token:
token.write(creds.to_json())
# Call the Drive v3 API
service = build("drive", "v3", credentials=creds)
# Request a list of all the files
results = (
service.files()
.list(pageSize=n_files, fields="nextPageToken, files(id, name)")
.execute()
)
items = results.get("files", [])
if not items:
print("No files found.")
else:
print("Files:")
for item in items:
print("{0} ({1})".format(item["name"], item["id"]))
return items

View file

@ -1,7 +1,7 @@
import math
from typing import Annotated
from arcade.sdk.tool import tool
from arcade.sdk import tool
@tool

View file

@ -1,7 +1,8 @@
import json
import os
import serpapi
from typing import Annotated
from arcade.sdk.tool import tool, get_secret
from typing import Annotated, Any, Optional
from arcade.sdk import tool
@tool
@ -23,3 +24,12 @@ async def search_google(
organic_results = results.get("organic_results", [])
return json.dumps(organic_results[:n_results])
def get_secret(name: str, default: Optional[Any] = None) -> Any:
secret = os.getenv(name)
if secret is None:
if default is not None:
return default
raise ValueError(f"Secret {name} is not set.")
return secret

View file

@ -68,6 +68,6 @@
}
}
},
"required": ["run_id", "invocation_id", "created_at", "tool", "input", "context"],
"required": ["run_id", "invocation_id", "created_at", "tool", "inputs", "context"],
"additionalProperties": false
}

View file

@ -129,7 +129,10 @@
"format": "uri"
},
"scope": {
"type": "string"
"type": "array",
"items": {
"type": "string"
}
}
},
"required": ["url"],