Tool Metadata (#357)

This commit is contained in:
Eric Gustin 2025-04-16 19:17:36 -08:00 committed by GitHub
parent c7fba25488
commit ad713e4939
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 323 additions and 37 deletions

View file

@ -39,6 +39,8 @@ from arcade.core.schema import (
ToolDefinition,
ToolInput,
ToolkitDefinition,
ToolMetadataKey,
ToolMetadataRequirement,
ToolOutput,
ToolRequirements,
ToolSecretRequirement,
@ -369,31 +371,9 @@ class ToolCatalog(BaseModel):
if does_function_return_value(tool) and tool.__annotations__.get("return") is None:
raise ToolDefinitionError(f"Tool {raw_tool_name} must have a return type annotation")
auth_requirement = getattr(tool, "__tool_requires_auth__", None)
if isinstance(auth_requirement, ToolAuthorization):
new_auth_requirement = ToolAuthRequirement(
provider_id=auth_requirement.provider_id,
provider_type=auth_requirement.provider_type,
id=auth_requirement.id,
)
if isinstance(auth_requirement, OAuth2):
new_auth_requirement.oauth2 = OAuth2Requirement(**auth_requirement.model_dump())
auth_requirement = new_auth_requirement
secrets_requirement = getattr(tool, "__tool_requires_secrets__", None)
if isinstance(secrets_requirement, list):
if any(not isinstance(secret, str) for secret in secrets_requirement):
raise ToolDefinitionError(
f"Secret keys must be strings (error in tool {raw_tool_name})."
)
secrets_requirement = to_tool_secret_requirements(secrets_requirement)
if any(
secret.key is None or secret.key.strip() == "" for secret in secrets_requirement
):
raise ToolDefinitionError(
f"Secrets must have a non-empty key (error in tool {raw_tool_name})."
)
auth_requirement = create_auth_requirement(tool)
secrets_requirement = create_secrets_requirement(tool)
metadata_requirement = create_metadata_requirement(tool, auth_requirement)
toolkit_definition = ToolkitDefinition(
name=snake_to_pascal_case(toolkit_name),
@ -415,6 +395,7 @@ class ToolCatalog(BaseModel):
requirements=ToolRequirements(
authorization=auth_requirement,
secrets=secrets_requirement,
metadata=metadata_requirement,
),
deprecation_message=deprecation_message,
)
@ -505,6 +486,77 @@ def create_output_definition(func: Callable) -> ToolOutput:
)
def create_auth_requirement(tool: Callable) -> ToolAuthRequirement | None:
"""
Create an auth requirement for a tool.
"""
auth_requirement = getattr(tool, "__tool_requires_auth__", None)
if isinstance(auth_requirement, ToolAuthorization):
new_auth_requirement = ToolAuthRequirement(
provider_id=auth_requirement.provider_id,
provider_type=auth_requirement.provider_type,
id=auth_requirement.id,
)
if isinstance(auth_requirement, OAuth2):
new_auth_requirement.oauth2 = OAuth2Requirement(**auth_requirement.model_dump())
auth_requirement = new_auth_requirement
return auth_requirement
def create_secrets_requirement(tool: Callable) -> list[ToolSecretRequirement] | None:
"""
Create a secrets requirement for a tool.
"""
raw_tool_name = getattr(tool, "__tool_name__", tool.__name__)
secrets_requirement = getattr(tool, "__tool_requires_secrets__", None)
if isinstance(secrets_requirement, list):
if any(not isinstance(secret, str) for secret in secrets_requirement):
raise ToolDefinitionError(
f"Secret keys must be strings (error in tool {raw_tool_name})."
)
secrets_requirement = to_tool_secret_requirements(secrets_requirement)
if any(secret.key is None or secret.key.strip() == "" for secret in secrets_requirement):
raise ToolDefinitionError(
f"Secrets must have a non-empty key (error in tool {raw_tool_name})."
)
return secrets_requirement
def create_metadata_requirement(
tool: Callable, auth_requirement: ToolAuthRequirement | None
) -> list[ToolMetadataRequirement] | None:
"""
Create a metadata requirement for a tool.
"""
raw_tool_name = getattr(tool, "__tool_name__", tool.__name__)
metadata_requirement = getattr(tool, "__tool_requires_metadata__", None)
if isinstance(metadata_requirement, list):
for metadata in metadata_requirement:
if not isinstance(metadata, str):
raise ToolDefinitionError(
f"Metadata must be strings (error in tool {raw_tool_name})."
)
if ToolMetadataKey.requires_auth(metadata) and auth_requirement is None:
raise ToolDefinitionError(
f"Tool {raw_tool_name} declares metadata key '{metadata}', "
"which requires that the tool has an auth requirement, "
"but no auth requirement was provided. Please specify an auth requirement."
)
metadata_requirement = to_tool_metadata_requirements(metadata_requirement)
if any(
metadata.key is None or metadata.key.strip() == "" for metadata in metadata_requirement
):
raise ToolDefinitionError(
f"Metadata must have a non-empty key (error in tool {raw_tool_name})."
)
return metadata_requirement
@dataclass
class ParamInfo:
"""
@ -832,3 +884,11 @@ def to_tool_secret_requirements(
# Iterate through the list, de-dupe case-insensitively, and convert each string to a ToolSecretRequirement
unique_secrets = {name.lower(): name.lower() for name in secrets_requirement}.values()
return [ToolSecretRequirement(key=name) for name in unique_secrets]
def to_tool_metadata_requirements(
metadata_requirement: list[str],
) -> list[ToolMetadataRequirement]:
# Iterate through the list, de-dupe case-insensitively, and convert each string to a ToolMetadataRequirement
unique_metadata = {name.lower(): name.lower() for name in metadata_requirement}.values()
return [ToolMetadataRequirement(key=name) for name in unique_metadata]

View file

@ -1,5 +1,6 @@
import os
from dataclasses import dataclass
from enum import Enum
from typing import Any, Literal, Optional, Union
from pydantic import BaseModel, Field
@ -109,6 +110,26 @@ class ToolSecretRequirement(BaseModel):
"""The ID of the secret."""
class ToolMetadataKey(str, Enum):
"""Convience enum for commonly used metadata keys."""
CLIENT_ID = "client_id"
COORDINATOR_URL = "coordinator_url"
@staticmethod
def requires_auth(key: str) -> bool:
"""Whether the key depends on the tool having an authorization requirement."""
keys_that_require_auth = [ToolMetadataKey.CLIENT_ID]
return key.strip().lower() in keys_that_require_auth
class ToolMetadataRequirement(BaseModel):
"""A requirement for a tool to run."""
key: str
"""The ID of the metadata."""
class ToolRequirements(BaseModel):
"""The requirements for a tool to run."""
@ -116,6 +137,10 @@ class ToolRequirements(BaseModel):
"""The authorization requirements for the tool, if any."""
secrets: Union[list[ToolSecretRequirement], None] = None
"""The secret requirements for the tool, if any."""
metadata: Union[list[ToolMetadataRequirement], None] = None
"""The metadata requirements for the tool, if any."""
class ToolkitDefinition(BaseModel):
@ -250,6 +275,16 @@ class ToolSecretItem(BaseModel):
"""The value of the secret."""
class ToolMetadataItem(BaseModel):
"""The context for a tool metadata."""
key: str
"""The key of the metadata."""
value: str
"""The value of the metadata."""
class ToolContext(BaseModel):
"""The context for a tool invocation."""
@ -259,6 +294,9 @@ class ToolContext(BaseModel):
secrets: list[ToolSecretItem] | None = None
"""The secrets for the tool invocation."""
metadata: list[ToolMetadataItem] | None = None
"""The metadata for the tool invocation."""
user_id: str | None = None
"""The user ID for the tool invocation (if any)."""
@ -268,17 +306,27 @@ class ToolContext(BaseModel):
def get_secret(self, key: str) -> str:
"""Retrieve the secret for the tool invocation."""
if not key or not key.strip():
raise ValueError("Secret key ID passed to get_secret cannot be empty.")
return self._get_item(key, self.secrets, "secret")
if not self.secrets:
raise ValueError("Secrets not found in context.")
def get_metadata(self, key: str) -> str:
"""Retrieve the metadata for the tool invocation."""
return self._get_item(key, self.metadata, "metadata")
def _get_item(
self, key: str, items: list[ToolMetadataItem] | list[ToolSecretItem] | None, item_name: str
) -> str:
if not key or not key.strip():
raise ValueError(
f"{item_name.capitalize()} key passed to get_{item_name} cannot be empty."
)
if not items:
raise ValueError(f"{item_name.capitalize()}s not found in context.")
normalized_key = key.lower()
for secret in self.secrets:
if secret.key.lower() == normalized_key:
return secret.value
raise ValueError(f"Secret {key} not found in context.")
for item in items:
if item.key.lower() == normalized_key:
return item.value
raise ValueError(f"{item_name.capitalize()} {key} not found in context.")
class ToolCallRequest(BaseModel):

View file

@ -1,5 +1,5 @@
from arcade.core.catalog import ToolCatalog
from arcade.core.schema import ToolAuthorizationContext, ToolContext
from arcade.core.schema import ToolAuthorizationContext, ToolContext, ToolMetadataKey
from arcade.core.toolkit import Toolkit
from .tool import tool
@ -8,6 +8,7 @@ __all__ = [
"ToolAuthorizationContext",
"ToolCatalog",
"ToolContext",
"ToolMetadataKey",
"Toolkit",
"tool",
]

View file

@ -15,6 +15,7 @@ def tool(
name: str | None = None,
requires_auth: Union[ToolAuthorization, None] = None,
requires_secrets: Union[list[str], None] = None,
requires_metadata: Union[list[str], None] = None,
) -> Callable:
def decorator(func: Callable) -> Callable:
func_name = str(getattr(func, "__name__", None))
@ -24,6 +25,7 @@ def tool(
func.__tool_description__ = desc or inspect.cleandoc(func.__doc__ or "") # type: ignore[attr-defined]
func.__tool_requires_auth__ = requires_auth # type: ignore[attr-defined]
func.__tool_requires_secrets__ = requires_secrets # type: ignore[attr-defined]
func.__tool_requires_metadata__ = requires_metadata # type: ignore[attr-defined]
if inspect.iscoroutinefunction(func):

View file

@ -1,6 +1,11 @@
import pytest
from arcade.core.schema import ToolAuthorizationContext, ToolContext, ToolSecretItem
from arcade.core.schema import (
ToolAuthorizationContext,
ToolContext,
ToolMetadataItem,
ToolSecretItem,
)
def test_get_auth_token_or_empty_with_token():
@ -68,5 +73,47 @@ def test_get_secret_when_secrets_is_none():
def test_get_secret_with_empty_key():
tool_context = ToolContext(secrets=[])
with pytest.raises(ValueError, match="Secret key ID passed to get_secret cannot be empty."):
with pytest.raises(ValueError, match="Secret key passed to get_secret cannot be empty."):
tool_context.get_secret("")
def test_get_metadata_valid():
key = "my_key"
val = "metadata_value"
metadata = [ToolMetadataItem(key=key, value=val)]
tool_context = ToolContext(metadata=metadata)
assert tool_context.get_metadata(key) == val
def test_get_metadata_with_case_insensitive_key():
key = "My_key"
val = "metadata_value"
metadata = [ToolMetadataItem(key=key, value=val)]
tool_context = ToolContext(metadata=metadata)
assert tool_context.get_metadata(key.upper()) == val
assert tool_context.get_metadata(key.lower()) == val
def test_get_metadata_key_not_found():
key = "nonexistent_key"
metadata = [ToolMetadataItem(key="other_key", value="another_metadata")]
tool_context = ToolContext(metadata=metadata)
with pytest.raises(ValueError, match=f"Metadata {key} not found in context."):
tool_context.get_metadata(key)
def test_get_metadata_when_metadata_is_none():
tool_context = ToolContext(metadata=None)
with pytest.raises(ValueError, match="Metadatas not found in context."):
tool_context.get_metadata("missing_key")
def test_get_metadata_with_empty_key():
tool_context = ToolContext(metadata=[])
with pytest.raises(ValueError, match="Metadata key passed to get_metadata cannot be empty."):
tool_context.get_metadata("")

View file

@ -10,6 +10,8 @@ from arcade.core.schema import (
ToolAuthRequirement,
ToolContext,
ToolInput,
ToolMetadataKey,
ToolMetadataRequirement,
ToolOutput,
ToolRequirements,
ToolSecretRequirement,
@ -63,6 +65,38 @@ def func_with_multiple_secret_requirement():
pass
@tool(
desc="A function that requires metadata",
requires_metadata=[ToolMetadataKey.COORDINATOR_URL],
)
def func_with_metadata_requirement():
pass
@tool(
desc="A function that requires multiple metadata fields, deduped case-insensitively",
requires_metadata=[
ToolMetadataKey.COORDINATOR_URL,
"my_other_metadata_key",
"MY_OTHER_METADATA_KEY",
],
)
def func_with_multiple_metadata_requirement():
pass
@tool(
desc="A function that requires a metadata field that depends on the tool having an auth requirement",
requires_auth=OAuth2(
id="my_example_provider123",
scopes=["scope1", "scope2"],
),
requires_metadata=[ToolMetadataKey.CLIENT_ID],
)
def func_with_metadata_and_auth_dependency():
pass
@tool(
desc="A function that requires authentication",
requires_auth=OAuth2(
@ -336,6 +370,43 @@ def func_with_complex_return() -> dict[str, str]:
},
id="func_with_multiple_secret_requirement",
),
pytest.param(
func_with_metadata_requirement,
{
"requirements": ToolRequirements(
metadata=[ToolMetadataRequirement(key=ToolMetadataKey.COORDINATOR_URL)]
)
},
id="func_with_metadata_requirement",
),
pytest.param(
func_with_multiple_metadata_requirement,
{
"requirements": ToolRequirements(
metadata=[
ToolMetadataRequirement(key=ToolMetadataKey.COORDINATOR_URL),
ToolMetadataRequirement(key="my_other_metadata_key"),
]
)
},
id="func_with_multiple_metadata_requirement",
),
pytest.param(
func_with_metadata_and_auth_dependency,
{
"requirements": ToolRequirements(
metadata=[ToolMetadataRequirement(key=ToolMetadataKey.CLIENT_ID)],
authorization=ToolAuthRequirement(
provider_type="oauth2",
id="my_example_provider123",
oauth2=OAuth2Requirement(
scopes=["scope1", "scope2"],
),
),
)
},
id="func_with_metadata_and_auth_dependency",
),
pytest.param(
func_with_auth_requirement,
{

View file

@ -4,7 +4,7 @@ import pytest
from arcade.core.catalog import ToolCatalog
from arcade.core.errors import ToolDefinitionError
from arcade.core.schema import ToolContext
from arcade.core.schema import ToolContext, ToolMetadataKey
from arcade.sdk import tool
@ -81,6 +81,30 @@ def func_with_secret_requirement_invalid_type():
pass
@tool(
desc="A function with a required metadata with a missing key (illegal)",
requires_metadata=[""],
)
def func_with_missing_metadata_key(context: ToolContext):
pass
@tool(
desc="A function that requires metadata with an invalid type (illegal)",
requires_metadata=[True],
)
def func_with_metadata_requirement_invalid_type():
pass
@tool(
desc="A function with a required metadata key that depends on the tool having an auth requirement, but the tool does not have an auth requirement (illegal)",
requires_metadata=[ToolMetadataKey.CLIENT_ID],
)
def func_with_metadata_and_auth_dependency():
pass
@pytest.mark.parametrize(
"func_under_test, exception_type",
[
@ -139,6 +163,21 @@ def func_with_secret_requirement_invalid_type():
ToolDefinitionError,
id=func_with_secret_requirement_invalid_type.__name__,
),
pytest.param(
func_with_missing_metadata_key,
ToolDefinitionError,
id=func_with_missing_metadata_key.__name__,
),
pytest.param(
func_with_metadata_requirement_invalid_type,
ToolDefinitionError,
id=func_with_metadata_requirement_invalid_type.__name__,
),
pytest.param(
func_with_metadata_and_auth_dependency,
ToolDefinitionError,
id=func_with_metadata_and_auth_dependency.__name__,
),
pytest.param(
func_with_union_return_type_1,
ToolDefinitionError,

View file

@ -159,6 +159,24 @@
}
]
},
"metadata": {
"oneOf": [
{ "type": "null" }, // Can be unset,
{
"type": "array",
"items": {
"type": "object",
"properties": {
"key": {
"type": "string"
}
},
"required": ["key"],
"additionalProperties": false
}
}
]
},
"authorization": {
"oneOf": [
{ "type": "null" }, // Can be unset