From ad713e4939b8c36d9c9e1e9e0f50054581f9b052 Mon Sep 17 00:00:00 2001 From: Eric Gustin <34000337+EricGustin@users.noreply.github.com> Date: Wed, 16 Apr 2025 19:17:36 -0800 Subject: [PATCH] Tool Metadata (#357) --- arcade/arcade/core/catalog.py | 110 ++++++++++++++---- arcade/arcade/core/schema.py | 64 ++++++++-- arcade/arcade/sdk/__init__.py | 3 +- arcade/arcade/sdk/tool.py | 2 + arcade/tests/core/test_schema.py | 51 +++++++- .../tests/tool/test_create_tool_definition.py | 71 +++++++++++ .../test_create_tool_definition_errors.py | 41 ++++++- schemas/preview/tool_definition.schema.jsonc | 18 +++ 8 files changed, 323 insertions(+), 37 deletions(-) diff --git a/arcade/arcade/core/catalog.py b/arcade/arcade/core/catalog.py index 1fc6e6b4..15e97d00 100644 --- a/arcade/arcade/core/catalog.py +++ b/arcade/arcade/core/catalog.py @@ -39,6 +39,8 @@ from arcade.core.schema import ( ToolDefinition, ToolInput, ToolkitDefinition, + ToolMetadataKey, + ToolMetadataRequirement, ToolOutput, ToolRequirements, ToolSecretRequirement, @@ -369,31 +371,9 @@ class ToolCatalog(BaseModel): if does_function_return_value(tool) and tool.__annotations__.get("return") is None: raise ToolDefinitionError(f"Tool {raw_tool_name} must have a return type annotation") - auth_requirement = getattr(tool, "__tool_requires_auth__", None) - if isinstance(auth_requirement, ToolAuthorization): - new_auth_requirement = ToolAuthRequirement( - provider_id=auth_requirement.provider_id, - provider_type=auth_requirement.provider_type, - id=auth_requirement.id, - ) - if isinstance(auth_requirement, OAuth2): - new_auth_requirement.oauth2 = OAuth2Requirement(**auth_requirement.model_dump()) - auth_requirement = new_auth_requirement - - secrets_requirement = getattr(tool, "__tool_requires_secrets__", None) - if isinstance(secrets_requirement, list): - if any(not isinstance(secret, str) for secret in secrets_requirement): - raise ToolDefinitionError( - f"Secret keys must be strings (error in tool {raw_tool_name})." - ) - - secrets_requirement = to_tool_secret_requirements(secrets_requirement) - if any( - secret.key is None or secret.key.strip() == "" for secret in secrets_requirement - ): - raise ToolDefinitionError( - f"Secrets must have a non-empty key (error in tool {raw_tool_name})." - ) + auth_requirement = create_auth_requirement(tool) + secrets_requirement = create_secrets_requirement(tool) + metadata_requirement = create_metadata_requirement(tool, auth_requirement) toolkit_definition = ToolkitDefinition( name=snake_to_pascal_case(toolkit_name), @@ -415,6 +395,7 @@ class ToolCatalog(BaseModel): requirements=ToolRequirements( authorization=auth_requirement, secrets=secrets_requirement, + metadata=metadata_requirement, ), deprecation_message=deprecation_message, ) @@ -505,6 +486,77 @@ def create_output_definition(func: Callable) -> ToolOutput: ) +def create_auth_requirement(tool: Callable) -> ToolAuthRequirement | None: + """ + Create an auth requirement for a tool. + """ + auth_requirement = getattr(tool, "__tool_requires_auth__", None) + if isinstance(auth_requirement, ToolAuthorization): + new_auth_requirement = ToolAuthRequirement( + provider_id=auth_requirement.provider_id, + provider_type=auth_requirement.provider_type, + id=auth_requirement.id, + ) + if isinstance(auth_requirement, OAuth2): + new_auth_requirement.oauth2 = OAuth2Requirement(**auth_requirement.model_dump()) + auth_requirement = new_auth_requirement + + return auth_requirement + + +def create_secrets_requirement(tool: Callable) -> list[ToolSecretRequirement] | None: + """ + Create a secrets requirement for a tool. + """ + raw_tool_name = getattr(tool, "__tool_name__", tool.__name__) + secrets_requirement = getattr(tool, "__tool_requires_secrets__", None) + if isinstance(secrets_requirement, list): + if any(not isinstance(secret, str) for secret in secrets_requirement): + raise ToolDefinitionError( + f"Secret keys must be strings (error in tool {raw_tool_name})." + ) + + secrets_requirement = to_tool_secret_requirements(secrets_requirement) + if any(secret.key is None or secret.key.strip() == "" for secret in secrets_requirement): + raise ToolDefinitionError( + f"Secrets must have a non-empty key (error in tool {raw_tool_name})." + ) + + return secrets_requirement + + +def create_metadata_requirement( + tool: Callable, auth_requirement: ToolAuthRequirement | None +) -> list[ToolMetadataRequirement] | None: + """ + Create a metadata requirement for a tool. + """ + raw_tool_name = getattr(tool, "__tool_name__", tool.__name__) + metadata_requirement = getattr(tool, "__tool_requires_metadata__", None) + if isinstance(metadata_requirement, list): + for metadata in metadata_requirement: + if not isinstance(metadata, str): + raise ToolDefinitionError( + f"Metadata must be strings (error in tool {raw_tool_name})." + ) + if ToolMetadataKey.requires_auth(metadata) and auth_requirement is None: + raise ToolDefinitionError( + f"Tool {raw_tool_name} declares metadata key '{metadata}', " + "which requires that the tool has an auth requirement, " + "but no auth requirement was provided. Please specify an auth requirement." + ) + + metadata_requirement = to_tool_metadata_requirements(metadata_requirement) + if any( + metadata.key is None or metadata.key.strip() == "" for metadata in metadata_requirement + ): + raise ToolDefinitionError( + f"Metadata must have a non-empty key (error in tool {raw_tool_name})." + ) + + return metadata_requirement + + @dataclass class ParamInfo: """ @@ -832,3 +884,11 @@ def to_tool_secret_requirements( # Iterate through the list, de-dupe case-insensitively, and convert each string to a ToolSecretRequirement unique_secrets = {name.lower(): name.lower() for name in secrets_requirement}.values() return [ToolSecretRequirement(key=name) for name in unique_secrets] + + +def to_tool_metadata_requirements( + metadata_requirement: list[str], +) -> list[ToolMetadataRequirement]: + # Iterate through the list, de-dupe case-insensitively, and convert each string to a ToolMetadataRequirement + unique_metadata = {name.lower(): name.lower() for name in metadata_requirement}.values() + return [ToolMetadataRequirement(key=name) for name in unique_metadata] diff --git a/arcade/arcade/core/schema.py b/arcade/arcade/core/schema.py index 4867f6d6..af21b3ba 100644 --- a/arcade/arcade/core/schema.py +++ b/arcade/arcade/core/schema.py @@ -1,5 +1,6 @@ import os from dataclasses import dataclass +from enum import Enum from typing import Any, Literal, Optional, Union from pydantic import BaseModel, Field @@ -109,6 +110,26 @@ class ToolSecretRequirement(BaseModel): """The ID of the secret.""" +class ToolMetadataKey(str, Enum): + """Convience enum for commonly used metadata keys.""" + + CLIENT_ID = "client_id" + COORDINATOR_URL = "coordinator_url" + + @staticmethod + def requires_auth(key: str) -> bool: + """Whether the key depends on the tool having an authorization requirement.""" + keys_that_require_auth = [ToolMetadataKey.CLIENT_ID] + return key.strip().lower() in keys_that_require_auth + + +class ToolMetadataRequirement(BaseModel): + """A requirement for a tool to run.""" + + key: str + """The ID of the metadata.""" + + class ToolRequirements(BaseModel): """The requirements for a tool to run.""" @@ -116,6 +137,10 @@ class ToolRequirements(BaseModel): """The authorization requirements for the tool, if any.""" secrets: Union[list[ToolSecretRequirement], None] = None + """The secret requirements for the tool, if any.""" + + metadata: Union[list[ToolMetadataRequirement], None] = None + """The metadata requirements for the tool, if any.""" class ToolkitDefinition(BaseModel): @@ -250,6 +275,16 @@ class ToolSecretItem(BaseModel): """The value of the secret.""" +class ToolMetadataItem(BaseModel): + """The context for a tool metadata.""" + + key: str + """The key of the metadata.""" + + value: str + """The value of the metadata.""" + + class ToolContext(BaseModel): """The context for a tool invocation.""" @@ -259,6 +294,9 @@ class ToolContext(BaseModel): secrets: list[ToolSecretItem] | None = None """The secrets for the tool invocation.""" + metadata: list[ToolMetadataItem] | None = None + """The metadata for the tool invocation.""" + user_id: str | None = None """The user ID for the tool invocation (if any).""" @@ -268,17 +306,27 @@ class ToolContext(BaseModel): def get_secret(self, key: str) -> str: """Retrieve the secret for the tool invocation.""" - if not key or not key.strip(): - raise ValueError("Secret key ID passed to get_secret cannot be empty.") + return self._get_item(key, self.secrets, "secret") - if not self.secrets: - raise ValueError("Secrets not found in context.") + def get_metadata(self, key: str) -> str: + """Retrieve the metadata for the tool invocation.""" + return self._get_item(key, self.metadata, "metadata") + + def _get_item( + self, key: str, items: list[ToolMetadataItem] | list[ToolSecretItem] | None, item_name: str + ) -> str: + if not key or not key.strip(): + raise ValueError( + f"{item_name.capitalize()} key passed to get_{item_name} cannot be empty." + ) + if not items: + raise ValueError(f"{item_name.capitalize()}s not found in context.") normalized_key = key.lower() - for secret in self.secrets: - if secret.key.lower() == normalized_key: - return secret.value - raise ValueError(f"Secret {key} not found in context.") + for item in items: + if item.key.lower() == normalized_key: + return item.value + raise ValueError(f"{item_name.capitalize()} {key} not found in context.") class ToolCallRequest(BaseModel): diff --git a/arcade/arcade/sdk/__init__.py b/arcade/arcade/sdk/__init__.py index 24bd83b4..35132e9e 100644 --- a/arcade/arcade/sdk/__init__.py +++ b/arcade/arcade/sdk/__init__.py @@ -1,5 +1,5 @@ from arcade.core.catalog import ToolCatalog -from arcade.core.schema import ToolAuthorizationContext, ToolContext +from arcade.core.schema import ToolAuthorizationContext, ToolContext, ToolMetadataKey from arcade.core.toolkit import Toolkit from .tool import tool @@ -8,6 +8,7 @@ __all__ = [ "ToolAuthorizationContext", "ToolCatalog", "ToolContext", + "ToolMetadataKey", "Toolkit", "tool", ] diff --git a/arcade/arcade/sdk/tool.py b/arcade/arcade/sdk/tool.py index ac9d6b3c..18bc1c69 100644 --- a/arcade/arcade/sdk/tool.py +++ b/arcade/arcade/sdk/tool.py @@ -15,6 +15,7 @@ def tool( name: str | None = None, requires_auth: Union[ToolAuthorization, None] = None, requires_secrets: Union[list[str], None] = None, + requires_metadata: Union[list[str], None] = None, ) -> Callable: def decorator(func: Callable) -> Callable: func_name = str(getattr(func, "__name__", None)) @@ -24,6 +25,7 @@ def tool( func.__tool_description__ = desc or inspect.cleandoc(func.__doc__ or "") # type: ignore[attr-defined] func.__tool_requires_auth__ = requires_auth # type: ignore[attr-defined] func.__tool_requires_secrets__ = requires_secrets # type: ignore[attr-defined] + func.__tool_requires_metadata__ = requires_metadata # type: ignore[attr-defined] if inspect.iscoroutinefunction(func): diff --git a/arcade/tests/core/test_schema.py b/arcade/tests/core/test_schema.py index 4623882a..cde95d07 100644 --- a/arcade/tests/core/test_schema.py +++ b/arcade/tests/core/test_schema.py @@ -1,6 +1,11 @@ import pytest -from arcade.core.schema import ToolAuthorizationContext, ToolContext, ToolSecretItem +from arcade.core.schema import ( + ToolAuthorizationContext, + ToolContext, + ToolMetadataItem, + ToolSecretItem, +) def test_get_auth_token_or_empty_with_token(): @@ -68,5 +73,47 @@ def test_get_secret_when_secrets_is_none(): def test_get_secret_with_empty_key(): tool_context = ToolContext(secrets=[]) - with pytest.raises(ValueError, match="Secret key ID passed to get_secret cannot be empty."): + with pytest.raises(ValueError, match="Secret key passed to get_secret cannot be empty."): tool_context.get_secret("") + + +def test_get_metadata_valid(): + key = "my_key" + val = "metadata_value" + metadata = [ToolMetadataItem(key=key, value=val)] + tool_context = ToolContext(metadata=metadata) + + assert tool_context.get_metadata(key) == val + + +def test_get_metadata_with_case_insensitive_key(): + key = "My_key" + val = "metadata_value" + metadata = [ToolMetadataItem(key=key, value=val)] + tool_context = ToolContext(metadata=metadata) + + assert tool_context.get_metadata(key.upper()) == val + assert tool_context.get_metadata(key.lower()) == val + + +def test_get_metadata_key_not_found(): + key = "nonexistent_key" + metadata = [ToolMetadataItem(key="other_key", value="another_metadata")] + tool_context = ToolContext(metadata=metadata) + + with pytest.raises(ValueError, match=f"Metadata {key} not found in context."): + tool_context.get_metadata(key) + + +def test_get_metadata_when_metadata_is_none(): + tool_context = ToolContext(metadata=None) + + with pytest.raises(ValueError, match="Metadatas not found in context."): + tool_context.get_metadata("missing_key") + + +def test_get_metadata_with_empty_key(): + tool_context = ToolContext(metadata=[]) + + with pytest.raises(ValueError, match="Metadata key passed to get_metadata cannot be empty."): + tool_context.get_metadata("") diff --git a/arcade/tests/tool/test_create_tool_definition.py b/arcade/tests/tool/test_create_tool_definition.py index 539a3727..ee4c07d2 100644 --- a/arcade/tests/tool/test_create_tool_definition.py +++ b/arcade/tests/tool/test_create_tool_definition.py @@ -10,6 +10,8 @@ from arcade.core.schema import ( ToolAuthRequirement, ToolContext, ToolInput, + ToolMetadataKey, + ToolMetadataRequirement, ToolOutput, ToolRequirements, ToolSecretRequirement, @@ -63,6 +65,38 @@ def func_with_multiple_secret_requirement(): pass +@tool( + desc="A function that requires metadata", + requires_metadata=[ToolMetadataKey.COORDINATOR_URL], +) +def func_with_metadata_requirement(): + pass + + +@tool( + desc="A function that requires multiple metadata fields, deduped case-insensitively", + requires_metadata=[ + ToolMetadataKey.COORDINATOR_URL, + "my_other_metadata_key", + "MY_OTHER_METADATA_KEY", + ], +) +def func_with_multiple_metadata_requirement(): + pass + + +@tool( + desc="A function that requires a metadata field that depends on the tool having an auth requirement", + requires_auth=OAuth2( + id="my_example_provider123", + scopes=["scope1", "scope2"], + ), + requires_metadata=[ToolMetadataKey.CLIENT_ID], +) +def func_with_metadata_and_auth_dependency(): + pass + + @tool( desc="A function that requires authentication", requires_auth=OAuth2( @@ -336,6 +370,43 @@ def func_with_complex_return() -> dict[str, str]: }, id="func_with_multiple_secret_requirement", ), + pytest.param( + func_with_metadata_requirement, + { + "requirements": ToolRequirements( + metadata=[ToolMetadataRequirement(key=ToolMetadataKey.COORDINATOR_URL)] + ) + }, + id="func_with_metadata_requirement", + ), + pytest.param( + func_with_multiple_metadata_requirement, + { + "requirements": ToolRequirements( + metadata=[ + ToolMetadataRequirement(key=ToolMetadataKey.COORDINATOR_URL), + ToolMetadataRequirement(key="my_other_metadata_key"), + ] + ) + }, + id="func_with_multiple_metadata_requirement", + ), + pytest.param( + func_with_metadata_and_auth_dependency, + { + "requirements": ToolRequirements( + metadata=[ToolMetadataRequirement(key=ToolMetadataKey.CLIENT_ID)], + authorization=ToolAuthRequirement( + provider_type="oauth2", + id="my_example_provider123", + oauth2=OAuth2Requirement( + scopes=["scope1", "scope2"], + ), + ), + ) + }, + id="func_with_metadata_and_auth_dependency", + ), pytest.param( func_with_auth_requirement, { diff --git a/arcade/tests/tool/test_create_tool_definition_errors.py b/arcade/tests/tool/test_create_tool_definition_errors.py index d2848b1c..3a3de328 100644 --- a/arcade/tests/tool/test_create_tool_definition_errors.py +++ b/arcade/tests/tool/test_create_tool_definition_errors.py @@ -4,7 +4,7 @@ import pytest from arcade.core.catalog import ToolCatalog from arcade.core.errors import ToolDefinitionError -from arcade.core.schema import ToolContext +from arcade.core.schema import ToolContext, ToolMetadataKey from arcade.sdk import tool @@ -81,6 +81,30 @@ def func_with_secret_requirement_invalid_type(): pass +@tool( + desc="A function with a required metadata with a missing key (illegal)", + requires_metadata=[""], +) +def func_with_missing_metadata_key(context: ToolContext): + pass + + +@tool( + desc="A function that requires metadata with an invalid type (illegal)", + requires_metadata=[True], +) +def func_with_metadata_requirement_invalid_type(): + pass + + +@tool( + desc="A function with a required metadata key that depends on the tool having an auth requirement, but the tool does not have an auth requirement (illegal)", + requires_metadata=[ToolMetadataKey.CLIENT_ID], +) +def func_with_metadata_and_auth_dependency(): + pass + + @pytest.mark.parametrize( "func_under_test, exception_type", [ @@ -139,6 +163,21 @@ def func_with_secret_requirement_invalid_type(): ToolDefinitionError, id=func_with_secret_requirement_invalid_type.__name__, ), + pytest.param( + func_with_missing_metadata_key, + ToolDefinitionError, + id=func_with_missing_metadata_key.__name__, + ), + pytest.param( + func_with_metadata_requirement_invalid_type, + ToolDefinitionError, + id=func_with_metadata_requirement_invalid_type.__name__, + ), + pytest.param( + func_with_metadata_and_auth_dependency, + ToolDefinitionError, + id=func_with_metadata_and_auth_dependency.__name__, + ), pytest.param( func_with_union_return_type_1, ToolDefinitionError, diff --git a/schemas/preview/tool_definition.schema.jsonc b/schemas/preview/tool_definition.schema.jsonc index 410c39cc..ee0dfabe 100644 --- a/schemas/preview/tool_definition.schema.jsonc +++ b/schemas/preview/tool_definition.schema.jsonc @@ -159,6 +159,24 @@ } ] }, + "metadata": { + "oneOf": [ + { "type": "null" }, // Can be unset, + { + "type": "array", + "items": { + "type": "object", + "properties": { + "key": { + "type": "string" + } + }, + "required": ["key"], + "additionalProperties": false + } + } + ] + }, "authorization": { "oneOf": [ { "type": "null" }, // Can be unset