diff --git a/libs/arcade-core/arcade_core/metadata.py b/libs/arcade-core/arcade_core/metadata.py index 0390c2a6..498ae490 100644 --- a/libs/arcade-core/arcade_core/metadata.py +++ b/libs/arcade-core/arcade_core/metadata.py @@ -13,10 +13,11 @@ Defines the metadata model for Arcade tools. This module provides three layers: - Extras: Arbitrary key/values for custom logic (IDP routing, feature flags, etc.) """ +import math from enum import Enum from typing import Any -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, ValidatorFunctionWrapHandler, field_validator from arcade_core.errors import ToolDefinitionError @@ -273,7 +274,28 @@ class ToolMetadata(BaseModel): """What effects the tool has.""" extras: dict[str, Any] | None = None - """Arbitrary key/values for custom logic.""" + """Arbitrary key/values for custom logic. Must contain only JSON-native types + (str, int, float, bool, None, dict with string keys, list) at all depths.""" + + @field_validator("extras", mode="wrap") + @classmethod + def _validate_extras_top_level_keys( + cls, v: dict[str, Any] | None, handler: ValidatorFunctionWrapHandler + ) -> dict[str, Any] | None: + """Intercept Pydantic's type validation to give a clear error for + non-string top-level keys. Full recursive JSON-safety validation + (nested keys + value types) is deferred to validate_for_tool() + which is called when the tool definition is created for the catalog.""" + if v is not None and isinstance(v, dict): + bad_keys = {k: type(k).__name__ for k in v if not isinstance(k, str)} + if bad_keys: + examples = ", ".join(f"{k!r} ({t})" for k, t in bad_keys.items()) + raise ToolDefinitionError( + f"All keys in ToolMetadata.extras must be strings. " + f"Found non-string key(s): {examples}. " + ) + result: dict[str, Any] | None = handler(v) + return result strict: bool = Field(default=True, exclude=True) """Enable validation for logical contradictions. Set False for edge cases. 
@@ -283,13 +305,26 @@ class ToolMetadata(BaseModel): def validate_for_tool(self) -> None: """ - Validate consistency between behavior and classification. + Validate metadata consistency and JSON-safety of extras. Called by the catalog when creating a tool definition. Raises: - ToolDefinitionError: If strict=True and validation fails + ToolDefinitionError: If extras is not JSON-safe, or a strict=True check fails """ + # JSON-safety check on extras + if self.extras is not None: + errors = _find_json_violations(self.extras, "extras") + if errors: + formatted = "; ".join(errors) + raise ToolDefinitionError( + f"ToolMetadata.extras must contain only JSON-safe " + f"types (str, int, float, bool, None, dict, list). " + f"Found violations: {formatted}. " + f"All dict keys must be strings, and all values must be " + f"JSON-native types." + ) + if not self.strict: return @@ -335,3 +370,36 @@ class ToolMetadata(BaseModel): "but is marked open_world=False. " "Fix the contradiction, or set strict=False to bypass." ) + + +def _find_json_violations(obj: Any, path: str) -> list[str]: + """Walk a nested structure and return human-readable descriptions of + any non-JSON-native keys or values. + + JSON-native: str, int, float, bool, None, dict (string keys only), list. 
+ """ + errors: list[str] = [] + if isinstance(obj, dict): + for k, v in obj.items(): + key_path = f"{path}[{k!r}]" + if not isinstance(k, str): + errors.append( + f"{key_path} has a non-string key of type {type(k).__name__} — " + f"all dict keys must be strings" + ) + errors.extend(_find_json_violations(v, key_path)) + elif isinstance(obj, list): + for i, item in enumerate(obj): + errors.extend(_find_json_violations(item, f"{path}[{i}]")) + # non-finite floats + elif isinstance(obj, float) and (math.isnan(obj) or math.isinf(obj)): + errors.append( + f"{path} has a non-JSON-safe float value {obj!r} — " + f"NaN and Infinity are not valid JSON numbers" + ) + # json primitive types + elif not isinstance(obj, (str, int, float, bool, type(None))): + errors.append( + f"{path} has a non-JSON-safe value of type {type(obj).__name__} (got {obj!r})" + ) + return errors diff --git a/libs/arcade-core/pyproject.toml b/libs/arcade-core/pyproject.toml index 5264d3a4..f813f102 100644 --- a/libs/arcade-core/pyproject.toml +++ b/libs/arcade-core/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "arcade-core" -version = "4.4.2" +version = "4.5.0" description = "Arcade Core - Core library for Arcade platform" readme = "README.md" license = { text = "MIT" } diff --git a/libs/tests/tool/test_tool_metadata.py b/libs/tests/tool/test_tool_metadata.py index c8193480..1990026b 100644 --- a/libs/tests/tool/test_tool_metadata.py +++ b/libs/tests/tool/test_tool_metadata.py @@ -1,3 +1,5 @@ +import datetime + import pytest from arcade_core.catalog import ToolCatalog from arcade_core.errors import ToolDefinitionError @@ -26,7 +28,9 @@ class TestEnumCoverage: def test_all_operations_are_categorized(self): """Every Operation must be in _READ_ONLY_OPERATIONS, _MUTATING_OPERATIONS, or _INDETERMINATE_OPERATIONS.""" all_operations = set(Operation) - categorized_operations = _READ_ONLY_OPERATIONS | _MUTATING_OPERATIONS | _INDETERMINATE_OPERATIONS + categorized_operations = ( + _READ_ONLY_OPERATIONS | 
_MUTATING_OPERATIONS | _INDETERMINATE_OPERATIONS + ) # Check that every operation is categorized uncategorized = all_operations - categorized_operations @@ -160,8 +164,8 @@ class TestToolMetadataValidation: ) assert len(metadata.classification.service_domains) == 2 - def test_extras_accepts_arbitrary_dict(self): - """Extras field accepts arbitrary key/value pairs.""" + def test_extras_accepts_json_native_values(self): + """Extras field accepts JSON-native key/value pairs.""" metadata = ToolMetadata( extras={"idp": "entraID", "requires_mfa": True, "max_requests": 100}, ) @@ -170,6 +174,123 @@ class TestToolMetadataValidation: assert metadata.extras["max_requests"] == 100 +class TestExtrasJsonSafety: + """Test that ToolMetadata.extras enforces JSON-native types at all depths. + + JSON-native types: str, int, float, bool, None, dict (str keys), list. + + Top-level non-string keys are caught at construction time (field_validator). + Nested keys and non-JSON-native values are caught at registration time + (validate_for_tool) where the tool name is available for error context. 
+ """ + + @pytest.mark.parametrize( + "extras", + [ + pytest.param(None, id="none"), + pytest.param({}, id="empty_dict"), + pytest.param( + {"string": "hello", "int": 42, "float": 3.14, "bool": True, "null": None}, + id="flat_json_native_values", + ), + pytest.param({"config": {"api_key": "abc", "retries": 3}}, id="nested_dict"), + pytest.param({"tags": ["a", "b"], "counts": [1, 2, 3]}, id="lists"), + pytest.param( + {"l1": {"l2": [{"l3": [1, "two", None, True, 3.0]}]}}, + id="deeply_nested", + ), + pytest.param( + {"empty_dict": {}, "empty_list": [], "nested": {"also_empty": []}}, + id="empty_nested_structures", + ), + ], + ) + def test_valid_json_safe_extras(self, extras: dict | None): + metadata = ToolMetadata(extras=extras) + assert metadata.extras == extras + metadata.validate_for_tool() + + # --- Top-level non-string keys: caught at construction time --- + + @pytest.mark.parametrize( + "extras", + [ + pytest.param({3: "three"}, id="int_key"), + pytest.param({True: "yes"}, id="bool_key"), + pytest.param({None: "null key"}, id="none_key"), + ], + ) + def test_non_string_top_level_key_rejected_at_construction(self, extras: dict): + with pytest.raises(ToolDefinitionError, match="must be strings"): + ToolMetadata(extras=extras) + + # --- Nested non-string keys + non-JSON values: caught by validate_for_tool --- + + @pytest.mark.parametrize( + "extras, match", + [ + # Non-string keys nested in dicts/lists + pytest.param({"o": {42: "v"}}, "must be strings", id="int_key_nested"), + pytest.param({"o": {True: "v"}}, "must be strings", id="bool_key_nested"), + pytest.param({"o": {(1, 2): "v"}}, "must be strings", id="tuple_key_nested"), + pytest.param({"a": {"b": {42: "v"}}}, "must be strings", id="int_key_deep"), + pytest.param({"items": [{True: "v"}]}, "must be strings", id="bool_key_in_list"), + # Non-JSON-native values at top level + pytest.param({"v": datetime.datetime(2023, 1, 1)}, "JSON-safe", id="datetime_value"), + pytest.param({"v": datetime.date(2023, 1, 1)}, 
"JSON-safe", id="date_value"), + pytest.param({"v": datetime.timedelta(seconds=60)}, "JSON-safe", id="timedelta_value"), + pytest.param({"v": {1, 2, 3}}, "JSON-safe", id="set_value"), + pytest.param({"v": frozenset([1, 2])}, "JSON-safe", id="frozenset_value"), + pytest.param({"v": (1, 2)}, "JSON-safe", id="tuple_value"), + pytest.param({"v": b"hello"}, "JSON-safe", id="bytes_value"), + # Non-finite floats (not valid JSON per RFC 8259) + pytest.param({"v": float("nan")}, "JSON-safe", id="float_nan"), + pytest.param({"v": float("inf")}, "JSON-safe", id="float_inf"), + pytest.param({"v": float("-inf")}, "JSON-safe", id="float_neg_inf"), + # Non-JSON-native values nested + pytest.param( + {"o": {"i": datetime.datetime(2023, 1, 1)}}, + "JSON-safe", + id="datetime_nested_in_dict", + ), + pytest.param({"items": [1, "ok", {3, 4}]}, "JSON-safe", id="set_nested_in_list"), + pytest.param( + {"a": [{"b": [datetime.date(2023, 1, 1)]}]}, + "JSON-safe", + id="date_deeply_nested", + ), + ], + ) + def test_rejects_non_json_safe_extras_at_validation(self, extras: dict, match: str): + metadata = ToolMetadata(extras=extras) + with pytest.raises(ToolDefinitionError, match=match): + metadata.validate_for_tool() + + # --- Error message quality --- + + def test_error_includes_path_for_nested_violations(self): + metadata = ToolMetadata(extras={"outer": {42: "bad"}}) + with pytest.raises(ToolDefinitionError, match=r"extras\['outer'\]"): + metadata.validate_for_tool() + + metadata = ToolMetadata(extras={"outer": datetime.datetime(2023, 1, 1)}) + with pytest.raises(ToolDefinitionError, match=r"extras\['outer'\]"): + metadata.validate_for_tool() + + def test_error_includes_type_name(self): + metadata = ToolMetadata(extras={"ts": datetime.datetime(2023, 1, 1)}) + with pytest.raises(ToolDefinitionError, match="datetime"): + metadata.validate_for_tool() + + def test_error_reports_all_violations(self): + metadata = ToolMetadata(extras={"ok_key": {True: "bool key"}, "bad": (1, 2)}) + with 
pytest.raises(ToolDefinitionError) as exc_info: + metadata.validate_for_tool() + msg = str(exc_info.value) + assert "True" in msg + assert "tuple" in msg + + class TestToolDecoratorWithMetadata: """Test @tool decorator with metadata parameter."""