feat: Tool secrets (#252)
SDK support for tool secrets (stored and managed by the engine): - [x] New `requires_secrets=` option in the `@tool` decorator - [x] Internal plumbing in the catalog and `ToolContext` - [x] Full test coverage of all added code - [x] Bumped minor version (new feature) This PR can be merged without waiting for Engine changes, because it is additive only (no breaking changes). After this is merged, I will open another PR to update existing toolkits that will benefit from this feature!
This commit is contained in:
parent
7466543bde
commit
3f7226709f
9 changed files with 206 additions and 8 deletions
|
|
@ -41,6 +41,7 @@ from arcade.core.schema import (
|
|||
ToolkitDefinition,
|
||||
ToolOutput,
|
||||
ToolRequirements,
|
||||
ToolSecretRequirement,
|
||||
ValueSchema,
|
||||
)
|
||||
from arcade.core.toolkit import Toolkit
|
||||
|
|
@ -374,6 +375,21 @@ class ToolCatalog(BaseModel):
|
|||
new_auth_requirement.oauth2 = OAuth2Requirement(**auth_requirement.model_dump())
|
||||
auth_requirement = new_auth_requirement
|
||||
|
||||
secrets_requirement = getattr(tool, "__tool_requires_secrets__", None)
|
||||
if isinstance(secrets_requirement, list):
|
||||
if any(not isinstance(secret, str) for secret in secrets_requirement):
|
||||
raise ToolDefinitionError(
|
||||
f"Secret keys must be strings (error in tool {raw_tool_name})."
|
||||
)
|
||||
|
||||
secrets_requirement = to_tool_secret_requirements(secrets_requirement)
|
||||
if any(
|
||||
secret.key is None or secret.key.strip() == "" for secret in secrets_requirement
|
||||
):
|
||||
raise ToolDefinitionError(
|
||||
f"Secrets must have a non-empty key (error in tool {raw_tool_name})."
|
||||
)
|
||||
|
||||
toolkit_definition = ToolkitDefinition(
|
||||
name=snake_to_pascal_case(toolkit_name),
|
||||
description=toolkit_desc,
|
||||
|
|
@ -393,6 +409,7 @@ class ToolCatalog(BaseModel):
|
|||
output=create_output_definition(tool),
|
||||
requirements=ToolRequirements(
|
||||
authorization=auth_requirement,
|
||||
secrets=secrets_requirement,
|
||||
),
|
||||
deprecation_message=deprecation_message,
|
||||
)
|
||||
|
|
@ -798,3 +815,11 @@ def determine_output_model(func: Callable) -> type[BaseModel]:
|
|||
output_model_name,
|
||||
result=(return_annotation, Field(description="No description provided.")),
|
||||
)
|
||||
|
||||
|
||||
def to_tool_secret_requirements(
|
||||
secrets_requirement: list[str],
|
||||
) -> list[ToolSecretRequirement]:
|
||||
# Iterate through the list, de-dupe case-insensitively, and convert each string to a ToolSecretRequirement
|
||||
unique_secrets = {name.lower(): name.lower() for name in secrets_requirement}.values()
|
||||
return [ToolSecretRequirement(key=name) for name in unique_secrets]
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import os
|
|||
from dataclasses import dataclass
|
||||
from typing import Any, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
# allow for custom tool name separator
|
||||
TOOL_NAME_SEPARATOR = os.getenv("ARCADE_TOOL_NAME_SEPARATOR", ".")
|
||||
|
|
@ -102,12 +102,21 @@ class ToolAuthRequirement(BaseModel):
|
|||
"""The OAuth 2.0 requirement, if any."""
|
||||
|
||||
|
||||
class ToolSecretRequirement(BaseModel):
|
||||
"""A requirement for a tool to run."""
|
||||
|
||||
key: str
|
||||
"""The ID of the secret."""
|
||||
|
||||
|
||||
class ToolRequirements(BaseModel):
|
||||
"""The requirements for a tool to run."""
|
||||
|
||||
authorization: Union[ToolAuthRequirement, None] = None
|
||||
"""The authorization requirements for the tool, if any."""
|
||||
|
||||
secrets: Union[list[ToolSecretRequirement], None] = None
|
||||
|
||||
|
||||
class ToolkitDefinition(BaseModel):
|
||||
"""The specification of a toolkit."""
|
||||
|
|
@ -231,19 +240,52 @@ class ToolAuthorizationContext(BaseModel):
|
|||
"""
|
||||
|
||||
|
||||
class ToolSecretItem(BaseModel):
|
||||
"""The context for a tool secret."""
|
||||
|
||||
key: str
|
||||
"""The key of the secret."""
|
||||
|
||||
value: str
|
||||
"""The value of the secret."""
|
||||
|
||||
|
||||
class ToolContext(BaseModel):
|
||||
"""The context for a tool invocation."""
|
||||
|
||||
authorization: ToolAuthorizationContext | None = None
|
||||
"""The authorization context for the tool invocation that requires authorization."""
|
||||
|
||||
secrets: list[ToolSecretItem] | None = None
|
||||
"""The secrets for the tool invocation."""
|
||||
|
||||
user_id: str | None = None
|
||||
"""The user ID for the tool invocation (if any)."""
|
||||
|
||||
@field_validator("secrets", mode="before")
|
||||
def lower_keys(cls, v: dict[str, ToolSecretItem] | None) -> dict[str, ToolSecretItem] | None:
|
||||
if isinstance(v, dict):
|
||||
return {k.lower(): value for k, value in v.items()}
|
||||
return v
|
||||
|
||||
def get_auth_token_or_empty(self) -> str:
|
||||
"""Retrieve the authorization token, or return an empty string if not available."""
|
||||
return self.authorization.token if self.authorization and self.authorization.token else ""
|
||||
|
||||
def get_secret(self, key: str) -> str:
|
||||
"""Retrieve the secret for the tool invocation."""
|
||||
if not key or not key.strip():
|
||||
raise ValueError("Secret key ID passed to get_secret cannot be empty.")
|
||||
|
||||
if not self.secrets:
|
||||
raise ValueError("Secrets not found in context.")
|
||||
|
||||
normalized_key = key.lower()
|
||||
for secret in self.secrets:
|
||||
if secret.key == normalized_key:
|
||||
return secret.value
|
||||
raise ValueError(f"Secret {key} not found in context.")
|
||||
|
||||
|
||||
class ToolCallRequest(BaseModel):
|
||||
"""The request to call (invoke) a tool."""
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ def tool(
|
|||
desc: str | None = None,
|
||||
name: str | None = None,
|
||||
requires_auth: Union[ToolAuthorization, None] = None,
|
||||
requires_secrets: Union[list[str], None] = None,
|
||||
) -> Callable:
|
||||
def decorator(func: Callable) -> Callable:
|
||||
func_name = str(getattr(func, "__name__", None))
|
||||
|
|
@ -22,6 +23,7 @@ def tool(
|
|||
func.__tool_name__ = tool_name # type: ignore[attr-defined]
|
||||
func.__tool_description__ = desc or inspect.cleandoc(func.__doc__ or "") # type: ignore[attr-defined]
|
||||
func.__tool_requires_auth__ = requires_auth # type: ignore[attr-defined]
|
||||
func.__tool_requires_secrets__ = requires_secrets # type: ignore[attr-defined]
|
||||
|
||||
if inspect.iscoroutinefunction(func):
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "arcade-ai"
|
||||
version = "1.1.0"
|
||||
version = "1.2.0"
|
||||
description = "Arcade Python SDK and CLI"
|
||||
readme = "README.md"
|
||||
packages = [
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
from arcade.core.schema import ToolAuthorizationContext, ToolContext
|
||||
import pytest
|
||||
|
||||
from arcade.core.schema import ToolAuthorizationContext, ToolContext, ToolSecretItem
|
||||
|
||||
|
||||
def test_get_auth_token_or_empty_with_token():
|
||||
|
|
@ -22,3 +24,49 @@ def test_get_auth_token_or_empty_no_authorization():
|
|||
tool_context = ToolContext(authorization=None)
|
||||
|
||||
assert tool_context.get_auth_token_or_empty() == ""
|
||||
|
||||
|
||||
def test_get_secret_valid():
|
||||
key = "my_key"
|
||||
val = "secret_value"
|
||||
secrets = [ToolSecretItem(key=key, value=val)]
|
||||
tool_context = ToolContext(secrets=secrets)
|
||||
|
||||
# When the secret exists, get_secret should return its value.
|
||||
actual_secret = tool_context.get_secret(key)
|
||||
assert actual_secret == val
|
||||
|
||||
|
||||
def test_get_secret_with_case_insensitive_key():
|
||||
key = "my_key"
|
||||
val = "secret_value"
|
||||
secrets = [ToolSecretItem(key=key, value=val)]
|
||||
tool_context = ToolContext(secrets=secrets)
|
||||
|
||||
assert tool_context.get_secret(key.upper()) == val
|
||||
assert tool_context.get_secret(key.lower()) == val
|
||||
|
||||
|
||||
def test_get_secret_key_not_found():
|
||||
key = "nonexistent_key"
|
||||
secrets = [ToolSecretItem(key="other_key", value="another_secret")]
|
||||
tool_context = ToolContext(secrets=secrets)
|
||||
|
||||
# When the key is not found, get_secret should raise a ValueError.
|
||||
with pytest.raises(ValueError, match=f"Secret {key} not found in context."):
|
||||
tool_context.get_secret(key)
|
||||
|
||||
|
||||
def test_get_secret_when_secrets_is_none():
|
||||
tool_context = ToolContext(secrets=None)
|
||||
|
||||
# When no secrets dictionary is provided, get_secret should raise a ValueError.
|
||||
with pytest.raises(ValueError, match="Secrets not found in context."):
|
||||
tool_context.get_secret("missing_key")
|
||||
|
||||
|
||||
def test_get_secret_with_empty_key():
|
||||
tool_context = ToolContext(secrets=[])
|
||||
|
||||
with pytest.raises(ValueError, match="Secret key ID passed to get_secret cannot be empty."):
|
||||
tool_context.get_secret("")
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from arcade.core.schema import (
|
|||
ToolInput,
|
||||
ToolOutput,
|
||||
ToolRequirements,
|
||||
ToolSecretRequirement,
|
||||
ValueSchema,
|
||||
)
|
||||
from arcade.core.utils import snake_to_pascal_case
|
||||
|
|
@ -46,6 +47,22 @@ def func_with_name_and_description():
|
|||
pass
|
||||
|
||||
|
||||
@tool(
|
||||
desc="A function that requires a secret",
|
||||
requires_secrets=["my_secret_id"],
|
||||
)
|
||||
def func_with_secret_requirement():
|
||||
pass
|
||||
|
||||
|
||||
@tool(
|
||||
desc="A function that requires multiple secrets, deduped case-insensitively",
|
||||
requires_secrets=["my_secret_id", "my_secret_id2", "MY_SECRET_ID"],
|
||||
)
|
||||
def func_with_multiple_secret_requirement():
|
||||
pass
|
||||
|
||||
|
||||
@tool(
|
||||
desc="A function that requires authentication",
|
||||
requires_auth=OAuth2(
|
||||
|
|
@ -261,6 +278,23 @@ def func_with_complex_return() -> dict[str, str]:
|
|||
{"requirements": ToolRequirements(auth=None)},
|
||||
id="func_with_no_auth_requirement",
|
||||
),
|
||||
pytest.param(
|
||||
func_with_secret_requirement,
|
||||
{"requirements": ToolRequirements(secrets=[ToolSecretRequirement(key="my_secret_id")])},
|
||||
id="func_with_secret_requirement",
|
||||
),
|
||||
pytest.param(
|
||||
func_with_multiple_secret_requirement,
|
||||
{
|
||||
"requirements": ToolRequirements(
|
||||
secrets=[
|
||||
ToolSecretRequirement(key="my_secret_id"),
|
||||
ToolSecretRequirement(key="my_secret_id2"),
|
||||
]
|
||||
)
|
||||
},
|
||||
id="func_with_multiple_secret_requirement",
|
||||
),
|
||||
pytest.param(
|
||||
func_with_auth_requirement,
|
||||
{
|
||||
|
|
@ -269,7 +303,6 @@ def func_with_complex_return() -> dict[str, str]:
|
|||
provider_type="oauth2",
|
||||
id="my_example_provider123",
|
||||
oauth2=OAuth2Requirement(
|
||||
authority="https://example.com/oauth2/auth",
|
||||
scopes=["scope1", "scope2"],
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -41,6 +41,22 @@ def func_with_multiple_context_params(context: ToolContext, context2: ToolContex
|
|||
pass
|
||||
|
||||
|
||||
@tool(
|
||||
desc="A function with a required secret with a missing key (illegal)",
|
||||
requires_secrets=[""],
|
||||
)
|
||||
def func_with_missing_secret_key(context: ToolContext):
|
||||
pass
|
||||
|
||||
|
||||
@tool(
|
||||
desc="A function that requires a secret (invalid type)",
|
||||
requires_secrets=[True],
|
||||
)
|
||||
def func_with_secret_requirement_invalid_type():
|
||||
pass
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"func_under_test, exception_type",
|
||||
[
|
||||
|
|
@ -79,6 +95,16 @@ def func_with_multiple_context_params(context: ToolContext, context2: ToolContex
|
|||
ToolDefinitionError,
|
||||
id=func_with_multiple_context_params.__name__,
|
||||
),
|
||||
pytest.param(
|
||||
func_with_missing_secret_key,
|
||||
ToolDefinitionError,
|
||||
id=func_with_missing_secret_key.__name__,
|
||||
),
|
||||
pytest.param(
|
||||
func_with_secret_requirement_invalid_type,
|
||||
ToolDefinitionError,
|
||||
id=func_with_secret_requirement_invalid_type.__name__,
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_missing_info_raises_error(func_under_test, exception_type):
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@
|
|||
"description": "Version of the toolkit containing the tool to call"
|
||||
}
|
||||
},
|
||||
"required": ["name", "toolkit_name", "version"],
|
||||
"required": ["name", "toolkit", "version"],
|
||||
"additionalProperties": false
|
||||
},
|
||||
"inputs": {
|
||||
|
|
@ -54,6 +54,10 @@
|
|||
"additionalProperties": false
|
||||
}
|
||||
},
|
||||
"secrets": {
|
||||
"type": "object",
|
||||
"additionalProperties": true
|
||||
},
|
||||
"user_id": {
|
||||
"type": "string",
|
||||
"description": "A unique ID that identifies the user (if any)"
|
||||
|
|
@ -65,6 +69,6 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
"required": ["run_id", "invocation_id", "created_at", "tool", "inputs", "context"],
|
||||
"required": ["run_id", "execution_id", "created_at", "tool", "inputs", "context"],
|
||||
"additionalProperties": false
|
||||
}
|
||||
|
|
|
|||
|
|
@ -14,8 +14,8 @@
|
|||
"oneOf": [{ "$ref": "#/$defs/primitives" }, { "type": "string", "enum": ["array"] }]
|
||||
},
|
||||
"inner_val_type": {
|
||||
"$ref": "#/$defs/primitives",
|
||||
"description": "If the value type is a list, the type of the list values."
|
||||
"description": "If the value type is a list, the type of the list values.",
|
||||
"oneOf": [{ "$ref": "#/$defs/primitives" }, { "type": "null" }]
|
||||
},
|
||||
"enum": {
|
||||
"oneOf": [
|
||||
|
|
@ -141,6 +141,24 @@
|
|||
"requirements": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"secrets": {
|
||||
"oneOf": [
|
||||
{ "type": "null" }, // Can be unset,
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"key": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["key"],
|
||||
"additionalProperties": false
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
"authorization": {
|
||||
"oneOf": [
|
||||
{ "type": "null" }, // Can be unset
|
||||
|
|
|
|||
Loading…
Reference in a new issue