Tool Metadata (#357)

This commit is contained in:
Eric Gustin 2025-04-16 19:17:36 -08:00 committed by GitHub
parent c7fba25488
commit ad713e4939
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 323 additions and 37 deletions

View file

@ -39,6 +39,8 @@ from arcade.core.schema import (
ToolDefinition, ToolDefinition,
ToolInput, ToolInput,
ToolkitDefinition, ToolkitDefinition,
ToolMetadataKey,
ToolMetadataRequirement,
ToolOutput, ToolOutput,
ToolRequirements, ToolRequirements,
ToolSecretRequirement, ToolSecretRequirement,
@ -369,31 +371,9 @@ class ToolCatalog(BaseModel):
if does_function_return_value(tool) and tool.__annotations__.get("return") is None: if does_function_return_value(tool) and tool.__annotations__.get("return") is None:
raise ToolDefinitionError(f"Tool {raw_tool_name} must have a return type annotation") raise ToolDefinitionError(f"Tool {raw_tool_name} must have a return type annotation")
auth_requirement = getattr(tool, "__tool_requires_auth__", None) auth_requirement = create_auth_requirement(tool)
if isinstance(auth_requirement, ToolAuthorization): secrets_requirement = create_secrets_requirement(tool)
new_auth_requirement = ToolAuthRequirement( metadata_requirement = create_metadata_requirement(tool, auth_requirement)
provider_id=auth_requirement.provider_id,
provider_type=auth_requirement.provider_type,
id=auth_requirement.id,
)
if isinstance(auth_requirement, OAuth2):
new_auth_requirement.oauth2 = OAuth2Requirement(**auth_requirement.model_dump())
auth_requirement = new_auth_requirement
secrets_requirement = getattr(tool, "__tool_requires_secrets__", None)
if isinstance(secrets_requirement, list):
if any(not isinstance(secret, str) for secret in secrets_requirement):
raise ToolDefinitionError(
f"Secret keys must be strings (error in tool {raw_tool_name})."
)
secrets_requirement = to_tool_secret_requirements(secrets_requirement)
if any(
secret.key is None or secret.key.strip() == "" for secret in secrets_requirement
):
raise ToolDefinitionError(
f"Secrets must have a non-empty key (error in tool {raw_tool_name})."
)
toolkit_definition = ToolkitDefinition( toolkit_definition = ToolkitDefinition(
name=snake_to_pascal_case(toolkit_name), name=snake_to_pascal_case(toolkit_name),
@ -415,6 +395,7 @@ class ToolCatalog(BaseModel):
requirements=ToolRequirements( requirements=ToolRequirements(
authorization=auth_requirement, authorization=auth_requirement,
secrets=secrets_requirement, secrets=secrets_requirement,
metadata=metadata_requirement,
), ),
deprecation_message=deprecation_message, deprecation_message=deprecation_message,
) )
@ -505,6 +486,77 @@ def create_output_definition(func: Callable) -> ToolOutput:
) )
def create_auth_requirement(tool: Callable) -> ToolAuthRequirement | None:
"""
Create an auth requirement for a tool.
"""
auth_requirement = getattr(tool, "__tool_requires_auth__", None)
if isinstance(auth_requirement, ToolAuthorization):
new_auth_requirement = ToolAuthRequirement(
provider_id=auth_requirement.provider_id,
provider_type=auth_requirement.provider_type,
id=auth_requirement.id,
)
if isinstance(auth_requirement, OAuth2):
new_auth_requirement.oauth2 = OAuth2Requirement(**auth_requirement.model_dump())
auth_requirement = new_auth_requirement
return auth_requirement
def create_secrets_requirement(tool: Callable) -> list[ToolSecretRequirement] | None:
"""
Create a secrets requirement for a tool.
"""
raw_tool_name = getattr(tool, "__tool_name__", tool.__name__)
secrets_requirement = getattr(tool, "__tool_requires_secrets__", None)
if isinstance(secrets_requirement, list):
if any(not isinstance(secret, str) for secret in secrets_requirement):
raise ToolDefinitionError(
f"Secret keys must be strings (error in tool {raw_tool_name})."
)
secrets_requirement = to_tool_secret_requirements(secrets_requirement)
if any(secret.key is None or secret.key.strip() == "" for secret in secrets_requirement):
raise ToolDefinitionError(
f"Secrets must have a non-empty key (error in tool {raw_tool_name})."
)
return secrets_requirement
def create_metadata_requirement(
tool: Callable, auth_requirement: ToolAuthRequirement | None
) -> list[ToolMetadataRequirement] | None:
"""
Create a metadata requirement for a tool.
"""
raw_tool_name = getattr(tool, "__tool_name__", tool.__name__)
metadata_requirement = getattr(tool, "__tool_requires_metadata__", None)
if isinstance(metadata_requirement, list):
for metadata in metadata_requirement:
if not isinstance(metadata, str):
raise ToolDefinitionError(
f"Metadata must be strings (error in tool {raw_tool_name})."
)
if ToolMetadataKey.requires_auth(metadata) and auth_requirement is None:
raise ToolDefinitionError(
f"Tool {raw_tool_name} declares metadata key '{metadata}', "
"which requires that the tool has an auth requirement, "
"but no auth requirement was provided. Please specify an auth requirement."
)
metadata_requirement = to_tool_metadata_requirements(metadata_requirement)
if any(
metadata.key is None or metadata.key.strip() == "" for metadata in metadata_requirement
):
raise ToolDefinitionError(
f"Metadata must have a non-empty key (error in tool {raw_tool_name})."
)
return metadata_requirement
@dataclass @dataclass
class ParamInfo: class ParamInfo:
""" """
@ -832,3 +884,11 @@ def to_tool_secret_requirements(
# Iterate through the list, de-dupe case-insensitively, and convert each string to a ToolSecretRequirement # Iterate through the list, de-dupe case-insensitively, and convert each string to a ToolSecretRequirement
unique_secrets = {name.lower(): name.lower() for name in secrets_requirement}.values() unique_secrets = {name.lower(): name.lower() for name in secrets_requirement}.values()
return [ToolSecretRequirement(key=name) for name in unique_secrets] return [ToolSecretRequirement(key=name) for name in unique_secrets]
def to_tool_metadata_requirements(
metadata_requirement: list[str],
) -> list[ToolMetadataRequirement]:
# Iterate through the list, de-dupe case-insensitively, and convert each string to a ToolMetadataRequirement
unique_metadata = {name.lower(): name.lower() for name in metadata_requirement}.values()
return [ToolMetadataRequirement(key=name) for name in unique_metadata]

View file

@ -1,5 +1,6 @@
import os import os
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum
from typing import Any, Literal, Optional, Union from typing import Any, Literal, Optional, Union
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -109,6 +110,26 @@ class ToolSecretRequirement(BaseModel):
"""The ID of the secret.""" """The ID of the secret."""
class ToolMetadataKey(str, Enum):
"""Convience enum for commonly used metadata keys."""
CLIENT_ID = "client_id"
COORDINATOR_URL = "coordinator_url"
@staticmethod
def requires_auth(key: str) -> bool:
"""Whether the key depends on the tool having an authorization requirement."""
keys_that_require_auth = [ToolMetadataKey.CLIENT_ID]
return key.strip().lower() in keys_that_require_auth
class ToolMetadataRequirement(BaseModel):
"""A requirement for a tool to run."""
key: str
"""The ID of the metadata."""
class ToolRequirements(BaseModel): class ToolRequirements(BaseModel):
"""The requirements for a tool to run.""" """The requirements for a tool to run."""
@ -116,6 +137,10 @@ class ToolRequirements(BaseModel):
"""The authorization requirements for the tool, if any.""" """The authorization requirements for the tool, if any."""
secrets: Union[list[ToolSecretRequirement], None] = None secrets: Union[list[ToolSecretRequirement], None] = None
"""The secret requirements for the tool, if any."""
metadata: Union[list[ToolMetadataRequirement], None] = None
"""The metadata requirements for the tool, if any."""
class ToolkitDefinition(BaseModel): class ToolkitDefinition(BaseModel):
@ -250,6 +275,16 @@ class ToolSecretItem(BaseModel):
"""The value of the secret.""" """The value of the secret."""
class ToolMetadataItem(BaseModel):
"""The context for a tool metadata."""
key: str
"""The key of the metadata."""
value: str
"""The value of the metadata."""
class ToolContext(BaseModel): class ToolContext(BaseModel):
"""The context for a tool invocation.""" """The context for a tool invocation."""
@ -259,6 +294,9 @@ class ToolContext(BaseModel):
secrets: list[ToolSecretItem] | None = None secrets: list[ToolSecretItem] | None = None
"""The secrets for the tool invocation.""" """The secrets for the tool invocation."""
metadata: list[ToolMetadataItem] | None = None
"""The metadata for the tool invocation."""
user_id: str | None = None user_id: str | None = None
"""The user ID for the tool invocation (if any).""" """The user ID for the tool invocation (if any)."""
@ -268,17 +306,27 @@ class ToolContext(BaseModel):
def get_secret(self, key: str) -> str: def get_secret(self, key: str) -> str:
"""Retrieve the secret for the tool invocation.""" """Retrieve the secret for the tool invocation."""
if not key or not key.strip(): return self._get_item(key, self.secrets, "secret")
raise ValueError("Secret key ID passed to get_secret cannot be empty.")
if not self.secrets: def get_metadata(self, key: str) -> str:
raise ValueError("Secrets not found in context.") """Retrieve the metadata for the tool invocation."""
return self._get_item(key, self.metadata, "metadata")
def _get_item(
self, key: str, items: list[ToolMetadataItem] | list[ToolSecretItem] | None, item_name: str
) -> str:
if not key or not key.strip():
raise ValueError(
f"{item_name.capitalize()} key passed to get_{item_name} cannot be empty."
)
if not items:
raise ValueError(f"{item_name.capitalize()}s not found in context.")
normalized_key = key.lower() normalized_key = key.lower()
for secret in self.secrets: for item in items:
if secret.key.lower() == normalized_key: if item.key.lower() == normalized_key:
return secret.value return item.value
raise ValueError(f"Secret {key} not found in context.") raise ValueError(f"{item_name.capitalize()} {key} not found in context.")
class ToolCallRequest(BaseModel): class ToolCallRequest(BaseModel):

View file

@ -1,5 +1,5 @@
from arcade.core.catalog import ToolCatalog from arcade.core.catalog import ToolCatalog
from arcade.core.schema import ToolAuthorizationContext, ToolContext from arcade.core.schema import ToolAuthorizationContext, ToolContext, ToolMetadataKey
from arcade.core.toolkit import Toolkit from arcade.core.toolkit import Toolkit
from .tool import tool from .tool import tool
@ -8,6 +8,7 @@ __all__ = [
"ToolAuthorizationContext", "ToolAuthorizationContext",
"ToolCatalog", "ToolCatalog",
"ToolContext", "ToolContext",
"ToolMetadataKey",
"Toolkit", "Toolkit",
"tool", "tool",
] ]

View file

@ -15,6 +15,7 @@ def tool(
name: str | None = None, name: str | None = None,
requires_auth: Union[ToolAuthorization, None] = None, requires_auth: Union[ToolAuthorization, None] = None,
requires_secrets: Union[list[str], None] = None, requires_secrets: Union[list[str], None] = None,
requires_metadata: Union[list[str], None] = None,
) -> Callable: ) -> Callable:
def decorator(func: Callable) -> Callable: def decorator(func: Callable) -> Callable:
func_name = str(getattr(func, "__name__", None)) func_name = str(getattr(func, "__name__", None))
@ -24,6 +25,7 @@ def tool(
func.__tool_description__ = desc or inspect.cleandoc(func.__doc__ or "") # type: ignore[attr-defined] func.__tool_description__ = desc or inspect.cleandoc(func.__doc__ or "") # type: ignore[attr-defined]
func.__tool_requires_auth__ = requires_auth # type: ignore[attr-defined] func.__tool_requires_auth__ = requires_auth # type: ignore[attr-defined]
func.__tool_requires_secrets__ = requires_secrets # type: ignore[attr-defined] func.__tool_requires_secrets__ = requires_secrets # type: ignore[attr-defined]
func.__tool_requires_metadata__ = requires_metadata # type: ignore[attr-defined]
if inspect.iscoroutinefunction(func): if inspect.iscoroutinefunction(func):

View file

@ -1,6 +1,11 @@
import pytest import pytest
from arcade.core.schema import ToolAuthorizationContext, ToolContext, ToolSecretItem from arcade.core.schema import (
ToolAuthorizationContext,
ToolContext,
ToolMetadataItem,
ToolSecretItem,
)
def test_get_auth_token_or_empty_with_token(): def test_get_auth_token_or_empty_with_token():
@ -68,5 +73,47 @@ def test_get_secret_when_secrets_is_none():
def test_get_secret_with_empty_key(): def test_get_secret_with_empty_key():
tool_context = ToolContext(secrets=[]) tool_context = ToolContext(secrets=[])
with pytest.raises(ValueError, match="Secret key ID passed to get_secret cannot be empty."): with pytest.raises(ValueError, match="Secret key passed to get_secret cannot be empty."):
tool_context.get_secret("") tool_context.get_secret("")
def test_get_metadata_valid():
key = "my_key"
val = "metadata_value"
metadata = [ToolMetadataItem(key=key, value=val)]
tool_context = ToolContext(metadata=metadata)
assert tool_context.get_metadata(key) == val
def test_get_metadata_with_case_insensitive_key():
key = "My_key"
val = "metadata_value"
metadata = [ToolMetadataItem(key=key, value=val)]
tool_context = ToolContext(metadata=metadata)
assert tool_context.get_metadata(key.upper()) == val
assert tool_context.get_metadata(key.lower()) == val
def test_get_metadata_key_not_found():
key = "nonexistent_key"
metadata = [ToolMetadataItem(key="other_key", value="another_metadata")]
tool_context = ToolContext(metadata=metadata)
with pytest.raises(ValueError, match=f"Metadata {key} not found in context."):
tool_context.get_metadata(key)
def test_get_metadata_when_metadata_is_none():
tool_context = ToolContext(metadata=None)
with pytest.raises(ValueError, match="Metadatas not found in context."):
tool_context.get_metadata("missing_key")
def test_get_metadata_with_empty_key():
tool_context = ToolContext(metadata=[])
with pytest.raises(ValueError, match="Metadata key passed to get_metadata cannot be empty."):
tool_context.get_metadata("")

View file

@ -10,6 +10,8 @@ from arcade.core.schema import (
ToolAuthRequirement, ToolAuthRequirement,
ToolContext, ToolContext,
ToolInput, ToolInput,
ToolMetadataKey,
ToolMetadataRequirement,
ToolOutput, ToolOutput,
ToolRequirements, ToolRequirements,
ToolSecretRequirement, ToolSecretRequirement,
@ -63,6 +65,38 @@ def func_with_multiple_secret_requirement():
pass pass
@tool(
desc="A function that requires metadata",
requires_metadata=[ToolMetadataKey.COORDINATOR_URL],
)
def func_with_metadata_requirement():
pass
@tool(
desc="A function that requires multiple metadata fields, deduped case-insensitively",
requires_metadata=[
ToolMetadataKey.COORDINATOR_URL,
"my_other_metadata_key",
"MY_OTHER_METADATA_KEY",
],
)
def func_with_multiple_metadata_requirement():
pass
@tool(
desc="A function that requires a metadata field that depends on the tool having an auth requirement",
requires_auth=OAuth2(
id="my_example_provider123",
scopes=["scope1", "scope2"],
),
requires_metadata=[ToolMetadataKey.CLIENT_ID],
)
def func_with_metadata_and_auth_dependency():
pass
@tool( @tool(
desc="A function that requires authentication", desc="A function that requires authentication",
requires_auth=OAuth2( requires_auth=OAuth2(
@ -336,6 +370,43 @@ def func_with_complex_return() -> dict[str, str]:
}, },
id="func_with_multiple_secret_requirement", id="func_with_multiple_secret_requirement",
), ),
pytest.param(
func_with_metadata_requirement,
{
"requirements": ToolRequirements(
metadata=[ToolMetadataRequirement(key=ToolMetadataKey.COORDINATOR_URL)]
)
},
id="func_with_metadata_requirement",
),
pytest.param(
func_with_multiple_metadata_requirement,
{
"requirements": ToolRequirements(
metadata=[
ToolMetadataRequirement(key=ToolMetadataKey.COORDINATOR_URL),
ToolMetadataRequirement(key="my_other_metadata_key"),
]
)
},
id="func_with_multiple_metadata_requirement",
),
pytest.param(
func_with_metadata_and_auth_dependency,
{
"requirements": ToolRequirements(
metadata=[ToolMetadataRequirement(key=ToolMetadataKey.CLIENT_ID)],
authorization=ToolAuthRequirement(
provider_type="oauth2",
id="my_example_provider123",
oauth2=OAuth2Requirement(
scopes=["scope1", "scope2"],
),
),
)
},
id="func_with_metadata_and_auth_dependency",
),
pytest.param( pytest.param(
func_with_auth_requirement, func_with_auth_requirement,
{ {

View file

@ -4,7 +4,7 @@ import pytest
from arcade.core.catalog import ToolCatalog from arcade.core.catalog import ToolCatalog
from arcade.core.errors import ToolDefinitionError from arcade.core.errors import ToolDefinitionError
from arcade.core.schema import ToolContext from arcade.core.schema import ToolContext, ToolMetadataKey
from arcade.sdk import tool from arcade.sdk import tool
@ -81,6 +81,30 @@ def func_with_secret_requirement_invalid_type():
pass pass
@tool(
desc="A function with a required metadata with a missing key (illegal)",
requires_metadata=[""],
)
def func_with_missing_metadata_key(context: ToolContext):
pass
@tool(
desc="A function that requires metadata with an invalid type (illegal)",
requires_metadata=[True],
)
def func_with_metadata_requirement_invalid_type():
pass
@tool(
desc="A function with a required metadata key that depends on the tool having an auth requirement, but the tool does not have an auth requirement (illegal)",
requires_metadata=[ToolMetadataKey.CLIENT_ID],
)
def func_with_metadata_and_auth_dependency():
pass
@pytest.mark.parametrize( @pytest.mark.parametrize(
"func_under_test, exception_type", "func_under_test, exception_type",
[ [
@ -139,6 +163,21 @@ def func_with_secret_requirement_invalid_type():
ToolDefinitionError, ToolDefinitionError,
id=func_with_secret_requirement_invalid_type.__name__, id=func_with_secret_requirement_invalid_type.__name__,
), ),
pytest.param(
func_with_missing_metadata_key,
ToolDefinitionError,
id=func_with_missing_metadata_key.__name__,
),
pytest.param(
func_with_metadata_requirement_invalid_type,
ToolDefinitionError,
id=func_with_metadata_requirement_invalid_type.__name__,
),
pytest.param(
func_with_metadata_and_auth_dependency,
ToolDefinitionError,
id=func_with_metadata_and_auth_dependency.__name__,
),
pytest.param( pytest.param(
func_with_union_return_type_1, func_with_union_return_type_1,
ToolDefinitionError, ToolDefinitionError,

View file

@ -159,6 +159,24 @@
} }
] ]
}, },
"metadata": {
"oneOf": [
{ "type": "null" }, // Can be unset,
{
"type": "array",
"items": {
"type": "object",
"properties": {
"key": {
"type": "string"
}
},
"required": ["key"],
"additionalProperties": false
}
}
]
},
"authorization": { "authorization": {
"oneOf": [ "oneOf": [
{ "type": "null" }, // Can be unset { "type": "null" }, // Can be unset