feat: Tool secrets (#252)

SDK support for tool secrets (stored and managed by the engine):
- [x] New `requires_secrets=` option in the `@tool` decorator
- [x] Internal plumbing in the catalog and `ToolContext`
- [x] Full test coverage of all added code
- [x] Bumped minor version (new feature)

This PR can be merged without waiting for Engine changes, because it is
additive only (no breaking changes).

After this is merged, I will open another PR to update existing toolkits
that will benefit from this feature!
This commit is contained in:
Nate Barbettini 2025-02-27 15:56:11 -08:00 committed by GitHub
parent 7466543bde
commit 3f7226709f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 206 additions and 8 deletions

View file

@ -41,6 +41,7 @@ from arcade.core.schema import (
ToolkitDefinition,
ToolOutput,
ToolRequirements,
ToolSecretRequirement,
ValueSchema,
)
from arcade.core.toolkit import Toolkit
@ -374,6 +375,21 @@ class ToolCatalog(BaseModel):
new_auth_requirement.oauth2 = OAuth2Requirement(**auth_requirement.model_dump())
auth_requirement = new_auth_requirement
secrets_requirement = getattr(tool, "__tool_requires_secrets__", None)
if isinstance(secrets_requirement, list):
if any(not isinstance(secret, str) for secret in secrets_requirement):
raise ToolDefinitionError(
f"Secret keys must be strings (error in tool {raw_tool_name})."
)
secrets_requirement = to_tool_secret_requirements(secrets_requirement)
if any(
secret.key is None or secret.key.strip() == "" for secret in secrets_requirement
):
raise ToolDefinitionError(
f"Secrets must have a non-empty key (error in tool {raw_tool_name})."
)
toolkit_definition = ToolkitDefinition(
name=snake_to_pascal_case(toolkit_name),
description=toolkit_desc,
@ -393,6 +409,7 @@ class ToolCatalog(BaseModel):
output=create_output_definition(tool),
requirements=ToolRequirements(
authorization=auth_requirement,
secrets=secrets_requirement,
),
deprecation_message=deprecation_message,
)
@ -798,3 +815,11 @@ def determine_output_model(func: Callable) -> type[BaseModel]:
output_model_name,
result=(return_annotation, Field(description="No description provided.")),
)
def to_tool_secret_requirements(
secrets_requirement: list[str],
) -> list[ToolSecretRequirement]:
# Iterate through the list, de-dupe case-insensitively, and convert each string to a ToolSecretRequirement
unique_secrets = {name.lower(): name.lower() for name in secrets_requirement}.values()
return [ToolSecretRequirement(key=name) for name in unique_secrets]

View file

@ -2,7 +2,7 @@ import os
from dataclasses import dataclass
from typing import Any, Literal, Optional, Union
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
# allow for custom tool name separator
TOOL_NAME_SEPARATOR = os.getenv("ARCADE_TOOL_NAME_SEPARATOR", ".")
@ -102,12 +102,21 @@ class ToolAuthRequirement(BaseModel):
"""The OAuth 2.0 requirement, if any."""
class ToolSecretRequirement(BaseModel):
"""A requirement for a tool to run."""
key: str
"""The ID of the secret."""
class ToolRequirements(BaseModel):
"""The requirements for a tool to run."""
authorization: Union[ToolAuthRequirement, None] = None
"""The authorization requirements for the tool, if any."""
secrets: Union[list[ToolSecretRequirement], None] = None
class ToolkitDefinition(BaseModel):
"""The specification of a toolkit."""
@ -231,19 +240,52 @@ class ToolAuthorizationContext(BaseModel):
"""
class ToolSecretItem(BaseModel):
"""The context for a tool secret."""
key: str
"""The key of the secret."""
value: str
"""The value of the secret."""
class ToolContext(BaseModel):
"""The context for a tool invocation."""
authorization: ToolAuthorizationContext | None = None
"""The authorization context for the tool invocation that requires authorization."""
secrets: list[ToolSecretItem] | None = None
"""The secrets for the tool invocation."""
user_id: str | None = None
"""The user ID for the tool invocation (if any)."""
@field_validator("secrets", mode="before")
def lower_keys(cls, v: dict[str, ToolSecretItem] | None) -> dict[str, ToolSecretItem] | None:
if isinstance(v, dict):
return {k.lower(): value for k, value in v.items()}
return v
def get_auth_token_or_empty(self) -> str:
"""Retrieve the authorization token, or return an empty string if not available."""
return self.authorization.token if self.authorization and self.authorization.token else ""
def get_secret(self, key: str) -> str:
"""Retrieve the secret for the tool invocation."""
if not key or not key.strip():
raise ValueError("Secret key ID passed to get_secret cannot be empty.")
if not self.secrets:
raise ValueError("Secrets not found in context.")
normalized_key = key.lower()
for secret in self.secrets:
if secret.key == normalized_key:
return secret.value
raise ValueError(f"Secret {key} not found in context.")
class ToolCallRequest(BaseModel):
"""The request to call (invoke) a tool."""

View file

@ -14,6 +14,7 @@ def tool(
desc: str | None = None,
name: str | None = None,
requires_auth: Union[ToolAuthorization, None] = None,
requires_secrets: Union[list[str], None] = None,
) -> Callable:
def decorator(func: Callable) -> Callable:
func_name = str(getattr(func, "__name__", None))
@ -22,6 +23,7 @@ def tool(
func.__tool_name__ = tool_name # type: ignore[attr-defined]
func.__tool_description__ = desc or inspect.cleandoc(func.__doc__ or "") # type: ignore[attr-defined]
func.__tool_requires_auth__ = requires_auth # type: ignore[attr-defined]
func.__tool_requires_secrets__ = requires_secrets # type: ignore[attr-defined]
if inspect.iscoroutinefunction(func):

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "arcade-ai"
version = "1.1.0"
version = "1.2.0"
description = "Arcade Python SDK and CLI"
readme = "README.md"
packages = [

View file

@ -1,4 +1,6 @@
from arcade.core.schema import ToolAuthorizationContext, ToolContext
import pytest
from arcade.core.schema import ToolAuthorizationContext, ToolContext, ToolSecretItem
def test_get_auth_token_or_empty_with_token():
@ -22,3 +24,49 @@ def test_get_auth_token_or_empty_no_authorization():
tool_context = ToolContext(authorization=None)
assert tool_context.get_auth_token_or_empty() == ""
def test_get_secret_valid():
key = "my_key"
val = "secret_value"
secrets = [ToolSecretItem(key=key, value=val)]
tool_context = ToolContext(secrets=secrets)
# When the secret exists, get_secret should return its value.
actual_secret = tool_context.get_secret(key)
assert actual_secret == val
def test_get_secret_with_case_insensitive_key():
key = "my_key"
val = "secret_value"
secrets = [ToolSecretItem(key=key, value=val)]
tool_context = ToolContext(secrets=secrets)
assert tool_context.get_secret(key.upper()) == val
assert tool_context.get_secret(key.lower()) == val
def test_get_secret_key_not_found():
key = "nonexistent_key"
secrets = [ToolSecretItem(key="other_key", value="another_secret")]
tool_context = ToolContext(secrets=secrets)
# When the key is not found, get_secret should raise a ValueError.
with pytest.raises(ValueError, match=f"Secret {key} not found in context."):
tool_context.get_secret(key)
def test_get_secret_when_secrets_is_none():
tool_context = ToolContext(secrets=None)
# When no secrets dictionary is provided, get_secret should raise a ValueError.
with pytest.raises(ValueError, match="Secrets not found in context."):
tool_context.get_secret("missing_key")
def test_get_secret_with_empty_key():
tool_context = ToolContext(secrets=[])
with pytest.raises(ValueError, match="Secret key ID passed to get_secret cannot be empty."):
tool_context.get_secret("")

View file

@ -12,6 +12,7 @@ from arcade.core.schema import (
ToolInput,
ToolOutput,
ToolRequirements,
ToolSecretRequirement,
ValueSchema,
)
from arcade.core.utils import snake_to_pascal_case
@ -46,6 +47,22 @@ def func_with_name_and_description():
pass
@tool(
desc="A function that requires a secret",
requires_secrets=["my_secret_id"],
)
def func_with_secret_requirement():
pass
@tool(
desc="A function that requires multiple secrets, deduped case-insensitively",
requires_secrets=["my_secret_id", "my_secret_id2", "MY_SECRET_ID"],
)
def func_with_multiple_secret_requirement():
pass
@tool(
desc="A function that requires authentication",
requires_auth=OAuth2(
@ -261,6 +278,23 @@ def func_with_complex_return() -> dict[str, str]:
{"requirements": ToolRequirements(auth=None)},
id="func_with_no_auth_requirement",
),
pytest.param(
func_with_secret_requirement,
{"requirements": ToolRequirements(secrets=[ToolSecretRequirement(key="my_secret_id")])},
id="func_with_secret_requirement",
),
pytest.param(
func_with_multiple_secret_requirement,
{
"requirements": ToolRequirements(
secrets=[
ToolSecretRequirement(key="my_secret_id"),
ToolSecretRequirement(key="my_secret_id2"),
]
)
},
id="func_with_multiple_secret_requirement",
),
pytest.param(
func_with_auth_requirement,
{
@ -269,7 +303,6 @@ def func_with_complex_return() -> dict[str, str]:
provider_type="oauth2",
id="my_example_provider123",
oauth2=OAuth2Requirement(
authority="https://example.com/oauth2/auth",
scopes=["scope1", "scope2"],
),
)

View file

@ -41,6 +41,22 @@ def func_with_multiple_context_params(context: ToolContext, context2: ToolContex
pass
@tool(
desc="A function with a required secret with a missing key (illegal)",
requires_secrets=[""],
)
def func_with_missing_secret_key(context: ToolContext):
pass
@tool(
desc="A function that requires a secret (invalid type)",
requires_secrets=[True],
)
def func_with_secret_requirement_invalid_type():
pass
@pytest.mark.parametrize(
"func_under_test, exception_type",
[
@ -79,6 +95,16 @@ def func_with_multiple_context_params(context: ToolContext, context2: ToolContex
ToolDefinitionError,
id=func_with_multiple_context_params.__name__,
),
pytest.param(
func_with_missing_secret_key,
ToolDefinitionError,
id=func_with_missing_secret_key.__name__,
),
pytest.param(
func_with_secret_requirement_invalid_type,
ToolDefinitionError,
id=func_with_secret_requirement_invalid_type.__name__,
),
],
)
def test_missing_info_raises_error(func_under_test, exception_type):

View file

@ -35,7 +35,7 @@
"description": "Version of the toolkit containing the tool to call"
}
},
"required": ["name", "toolkit_name", "version"],
"required": ["name", "toolkit", "version"],
"additionalProperties": false
},
"inputs": {
@ -54,6 +54,10 @@
"additionalProperties": false
}
},
"secrets": {
"type": "object",
"additionalProperties": true
},
"user_id": {
"type": "string",
"description": "A unique ID that identifies the user (if any)"
@ -65,6 +69,6 @@
}
}
},
"required": ["run_id", "invocation_id", "created_at", "tool", "inputs", "context"],
"required": ["run_id", "execution_id", "created_at", "tool", "inputs", "context"],
"additionalProperties": false
}

View file

@ -14,8 +14,8 @@
"oneOf": [{ "$ref": "#/$defs/primitives" }, { "type": "string", "enum": ["array"] }]
},
"inner_val_type": {
"$ref": "#/$defs/primitives",
"description": "If the value type is a list, the type of the list values."
"description": "If the value type is a list, the type of the list values.",
"oneOf": [{ "$ref": "#/$defs/primitives" }, { "type": "null" }]
},
"enum": {
"oneOf": [
@ -141,6 +141,24 @@
"requirements": {
"type": "object",
"properties": {
"secrets": {
"oneOf": [
{ "type": "null" }, // Can be unset,
{
"type": "array",
"items": {
"type": "object",
"properties": {
"key": {
"type": "string"
}
},
"required": ["key"],
"additionalProperties": false
}
}
]
},
"authorization": {
"oneOf": [
{ "type": "null" }, // Can be unset