Tool Metadata (#357)
This commit is contained in:
parent
c7fba25488
commit
ad713e4939
8 changed files with 323 additions and 37 deletions
|
|
@ -39,6 +39,8 @@ from arcade.core.schema import (
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
ToolInput,
|
ToolInput,
|
||||||
ToolkitDefinition,
|
ToolkitDefinition,
|
||||||
|
ToolMetadataKey,
|
||||||
|
ToolMetadataRequirement,
|
||||||
ToolOutput,
|
ToolOutput,
|
||||||
ToolRequirements,
|
ToolRequirements,
|
||||||
ToolSecretRequirement,
|
ToolSecretRequirement,
|
||||||
|
|
@ -369,31 +371,9 @@ class ToolCatalog(BaseModel):
|
||||||
if does_function_return_value(tool) and tool.__annotations__.get("return") is None:
|
if does_function_return_value(tool) and tool.__annotations__.get("return") is None:
|
||||||
raise ToolDefinitionError(f"Tool {raw_tool_name} must have a return type annotation")
|
raise ToolDefinitionError(f"Tool {raw_tool_name} must have a return type annotation")
|
||||||
|
|
||||||
auth_requirement = getattr(tool, "__tool_requires_auth__", None)
|
auth_requirement = create_auth_requirement(tool)
|
||||||
if isinstance(auth_requirement, ToolAuthorization):
|
secrets_requirement = create_secrets_requirement(tool)
|
||||||
new_auth_requirement = ToolAuthRequirement(
|
metadata_requirement = create_metadata_requirement(tool, auth_requirement)
|
||||||
provider_id=auth_requirement.provider_id,
|
|
||||||
provider_type=auth_requirement.provider_type,
|
|
||||||
id=auth_requirement.id,
|
|
||||||
)
|
|
||||||
if isinstance(auth_requirement, OAuth2):
|
|
||||||
new_auth_requirement.oauth2 = OAuth2Requirement(**auth_requirement.model_dump())
|
|
||||||
auth_requirement = new_auth_requirement
|
|
||||||
|
|
||||||
secrets_requirement = getattr(tool, "__tool_requires_secrets__", None)
|
|
||||||
if isinstance(secrets_requirement, list):
|
|
||||||
if any(not isinstance(secret, str) for secret in secrets_requirement):
|
|
||||||
raise ToolDefinitionError(
|
|
||||||
f"Secret keys must be strings (error in tool {raw_tool_name})."
|
|
||||||
)
|
|
||||||
|
|
||||||
secrets_requirement = to_tool_secret_requirements(secrets_requirement)
|
|
||||||
if any(
|
|
||||||
secret.key is None or secret.key.strip() == "" for secret in secrets_requirement
|
|
||||||
):
|
|
||||||
raise ToolDefinitionError(
|
|
||||||
f"Secrets must have a non-empty key (error in tool {raw_tool_name})."
|
|
||||||
)
|
|
||||||
|
|
||||||
toolkit_definition = ToolkitDefinition(
|
toolkit_definition = ToolkitDefinition(
|
||||||
name=snake_to_pascal_case(toolkit_name),
|
name=snake_to_pascal_case(toolkit_name),
|
||||||
|
|
@ -415,6 +395,7 @@ class ToolCatalog(BaseModel):
|
||||||
requirements=ToolRequirements(
|
requirements=ToolRequirements(
|
||||||
authorization=auth_requirement,
|
authorization=auth_requirement,
|
||||||
secrets=secrets_requirement,
|
secrets=secrets_requirement,
|
||||||
|
metadata=metadata_requirement,
|
||||||
),
|
),
|
||||||
deprecation_message=deprecation_message,
|
deprecation_message=deprecation_message,
|
||||||
)
|
)
|
||||||
|
|
@ -505,6 +486,77 @@ def create_output_definition(func: Callable) -> ToolOutput:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_auth_requirement(tool: Callable) -> ToolAuthRequirement | None:
|
||||||
|
"""
|
||||||
|
Create an auth requirement for a tool.
|
||||||
|
"""
|
||||||
|
auth_requirement = getattr(tool, "__tool_requires_auth__", None)
|
||||||
|
if isinstance(auth_requirement, ToolAuthorization):
|
||||||
|
new_auth_requirement = ToolAuthRequirement(
|
||||||
|
provider_id=auth_requirement.provider_id,
|
||||||
|
provider_type=auth_requirement.provider_type,
|
||||||
|
id=auth_requirement.id,
|
||||||
|
)
|
||||||
|
if isinstance(auth_requirement, OAuth2):
|
||||||
|
new_auth_requirement.oauth2 = OAuth2Requirement(**auth_requirement.model_dump())
|
||||||
|
auth_requirement = new_auth_requirement
|
||||||
|
|
||||||
|
return auth_requirement
|
||||||
|
|
||||||
|
|
||||||
|
def create_secrets_requirement(tool: Callable) -> list[ToolSecretRequirement] | None:
|
||||||
|
"""
|
||||||
|
Create a secrets requirement for a tool.
|
||||||
|
"""
|
||||||
|
raw_tool_name = getattr(tool, "__tool_name__", tool.__name__)
|
||||||
|
secrets_requirement = getattr(tool, "__tool_requires_secrets__", None)
|
||||||
|
if isinstance(secrets_requirement, list):
|
||||||
|
if any(not isinstance(secret, str) for secret in secrets_requirement):
|
||||||
|
raise ToolDefinitionError(
|
||||||
|
f"Secret keys must be strings (error in tool {raw_tool_name})."
|
||||||
|
)
|
||||||
|
|
||||||
|
secrets_requirement = to_tool_secret_requirements(secrets_requirement)
|
||||||
|
if any(secret.key is None or secret.key.strip() == "" for secret in secrets_requirement):
|
||||||
|
raise ToolDefinitionError(
|
||||||
|
f"Secrets must have a non-empty key (error in tool {raw_tool_name})."
|
||||||
|
)
|
||||||
|
|
||||||
|
return secrets_requirement
|
||||||
|
|
||||||
|
|
||||||
|
def create_metadata_requirement(
|
||||||
|
tool: Callable, auth_requirement: ToolAuthRequirement | None
|
||||||
|
) -> list[ToolMetadataRequirement] | None:
|
||||||
|
"""
|
||||||
|
Create a metadata requirement for a tool.
|
||||||
|
"""
|
||||||
|
raw_tool_name = getattr(tool, "__tool_name__", tool.__name__)
|
||||||
|
metadata_requirement = getattr(tool, "__tool_requires_metadata__", None)
|
||||||
|
if isinstance(metadata_requirement, list):
|
||||||
|
for metadata in metadata_requirement:
|
||||||
|
if not isinstance(metadata, str):
|
||||||
|
raise ToolDefinitionError(
|
||||||
|
f"Metadata must be strings (error in tool {raw_tool_name})."
|
||||||
|
)
|
||||||
|
if ToolMetadataKey.requires_auth(metadata) and auth_requirement is None:
|
||||||
|
raise ToolDefinitionError(
|
||||||
|
f"Tool {raw_tool_name} declares metadata key '{metadata}', "
|
||||||
|
"which requires that the tool has an auth requirement, "
|
||||||
|
"but no auth requirement was provided. Please specify an auth requirement."
|
||||||
|
)
|
||||||
|
|
||||||
|
metadata_requirement = to_tool_metadata_requirements(metadata_requirement)
|
||||||
|
if any(
|
||||||
|
metadata.key is None or metadata.key.strip() == "" for metadata in metadata_requirement
|
||||||
|
):
|
||||||
|
raise ToolDefinitionError(
|
||||||
|
f"Metadata must have a non-empty key (error in tool {raw_tool_name})."
|
||||||
|
)
|
||||||
|
|
||||||
|
return metadata_requirement
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ParamInfo:
|
class ParamInfo:
|
||||||
"""
|
"""
|
||||||
|
|
@ -832,3 +884,11 @@ def to_tool_secret_requirements(
|
||||||
# Iterate through the list, de-dupe case-insensitively, and convert each string to a ToolSecretRequirement
|
# Iterate through the list, de-dupe case-insensitively, and convert each string to a ToolSecretRequirement
|
||||||
unique_secrets = {name.lower(): name.lower() for name in secrets_requirement}.values()
|
unique_secrets = {name.lower(): name.lower() for name in secrets_requirement}.values()
|
||||||
return [ToolSecretRequirement(key=name) for name in unique_secrets]
|
return [ToolSecretRequirement(key=name) for name in unique_secrets]
|
||||||
|
|
||||||
|
|
||||||
|
def to_tool_metadata_requirements(
|
||||||
|
metadata_requirement: list[str],
|
||||||
|
) -> list[ToolMetadataRequirement]:
|
||||||
|
# Iterate through the list, de-dupe case-insensitively, and convert each string to a ToolMetadataRequirement
|
||||||
|
unique_metadata = {name.lower(): name.lower() for name in metadata_requirement}.values()
|
||||||
|
return [ToolMetadataRequirement(key=name) for name in unique_metadata]
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
from typing import Any, Literal, Optional, Union
|
from typing import Any, Literal, Optional, Union
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
@ -109,6 +110,26 @@ class ToolSecretRequirement(BaseModel):
|
||||||
"""The ID of the secret."""
|
"""The ID of the secret."""
|
||||||
|
|
||||||
|
|
||||||
|
class ToolMetadataKey(str, Enum):
|
||||||
|
"""Convience enum for commonly used metadata keys."""
|
||||||
|
|
||||||
|
CLIENT_ID = "client_id"
|
||||||
|
COORDINATOR_URL = "coordinator_url"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def requires_auth(key: str) -> bool:
|
||||||
|
"""Whether the key depends on the tool having an authorization requirement."""
|
||||||
|
keys_that_require_auth = [ToolMetadataKey.CLIENT_ID]
|
||||||
|
return key.strip().lower() in keys_that_require_auth
|
||||||
|
|
||||||
|
|
||||||
|
class ToolMetadataRequirement(BaseModel):
|
||||||
|
"""A requirement for a tool to run."""
|
||||||
|
|
||||||
|
key: str
|
||||||
|
"""The ID of the metadata."""
|
||||||
|
|
||||||
|
|
||||||
class ToolRequirements(BaseModel):
|
class ToolRequirements(BaseModel):
|
||||||
"""The requirements for a tool to run."""
|
"""The requirements for a tool to run."""
|
||||||
|
|
||||||
|
|
@ -116,6 +137,10 @@ class ToolRequirements(BaseModel):
|
||||||
"""The authorization requirements for the tool, if any."""
|
"""The authorization requirements for the tool, if any."""
|
||||||
|
|
||||||
secrets: Union[list[ToolSecretRequirement], None] = None
|
secrets: Union[list[ToolSecretRequirement], None] = None
|
||||||
|
"""The secret requirements for the tool, if any."""
|
||||||
|
|
||||||
|
metadata: Union[list[ToolMetadataRequirement], None] = None
|
||||||
|
"""The metadata requirements for the tool, if any."""
|
||||||
|
|
||||||
|
|
||||||
class ToolkitDefinition(BaseModel):
|
class ToolkitDefinition(BaseModel):
|
||||||
|
|
@ -250,6 +275,16 @@ class ToolSecretItem(BaseModel):
|
||||||
"""The value of the secret."""
|
"""The value of the secret."""
|
||||||
|
|
||||||
|
|
||||||
|
class ToolMetadataItem(BaseModel):
|
||||||
|
"""The context for a tool metadata."""
|
||||||
|
|
||||||
|
key: str
|
||||||
|
"""The key of the metadata."""
|
||||||
|
|
||||||
|
value: str
|
||||||
|
"""The value of the metadata."""
|
||||||
|
|
||||||
|
|
||||||
class ToolContext(BaseModel):
|
class ToolContext(BaseModel):
|
||||||
"""The context for a tool invocation."""
|
"""The context for a tool invocation."""
|
||||||
|
|
||||||
|
|
@ -259,6 +294,9 @@ class ToolContext(BaseModel):
|
||||||
secrets: list[ToolSecretItem] | None = None
|
secrets: list[ToolSecretItem] | None = None
|
||||||
"""The secrets for the tool invocation."""
|
"""The secrets for the tool invocation."""
|
||||||
|
|
||||||
|
metadata: list[ToolMetadataItem] | None = None
|
||||||
|
"""The metadata for the tool invocation."""
|
||||||
|
|
||||||
user_id: str | None = None
|
user_id: str | None = None
|
||||||
"""The user ID for the tool invocation (if any)."""
|
"""The user ID for the tool invocation (if any)."""
|
||||||
|
|
||||||
|
|
@ -268,17 +306,27 @@ class ToolContext(BaseModel):
|
||||||
|
|
||||||
def get_secret(self, key: str) -> str:
|
def get_secret(self, key: str) -> str:
|
||||||
"""Retrieve the secret for the tool invocation."""
|
"""Retrieve the secret for the tool invocation."""
|
||||||
if not key or not key.strip():
|
return self._get_item(key, self.secrets, "secret")
|
||||||
raise ValueError("Secret key ID passed to get_secret cannot be empty.")
|
|
||||||
|
|
||||||
if not self.secrets:
|
def get_metadata(self, key: str) -> str:
|
||||||
raise ValueError("Secrets not found in context.")
|
"""Retrieve the metadata for the tool invocation."""
|
||||||
|
return self._get_item(key, self.metadata, "metadata")
|
||||||
|
|
||||||
|
def _get_item(
|
||||||
|
self, key: str, items: list[ToolMetadataItem] | list[ToolSecretItem] | None, item_name: str
|
||||||
|
) -> str:
|
||||||
|
if not key or not key.strip():
|
||||||
|
raise ValueError(
|
||||||
|
f"{item_name.capitalize()} key passed to get_{item_name} cannot be empty."
|
||||||
|
)
|
||||||
|
if not items:
|
||||||
|
raise ValueError(f"{item_name.capitalize()}s not found in context.")
|
||||||
|
|
||||||
normalized_key = key.lower()
|
normalized_key = key.lower()
|
||||||
for secret in self.secrets:
|
for item in items:
|
||||||
if secret.key.lower() == normalized_key:
|
if item.key.lower() == normalized_key:
|
||||||
return secret.value
|
return item.value
|
||||||
raise ValueError(f"Secret {key} not found in context.")
|
raise ValueError(f"{item_name.capitalize()} {key} not found in context.")
|
||||||
|
|
||||||
|
|
||||||
class ToolCallRequest(BaseModel):
|
class ToolCallRequest(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from arcade.core.catalog import ToolCatalog
|
from arcade.core.catalog import ToolCatalog
|
||||||
from arcade.core.schema import ToolAuthorizationContext, ToolContext
|
from arcade.core.schema import ToolAuthorizationContext, ToolContext, ToolMetadataKey
|
||||||
from arcade.core.toolkit import Toolkit
|
from arcade.core.toolkit import Toolkit
|
||||||
|
|
||||||
from .tool import tool
|
from .tool import tool
|
||||||
|
|
@ -8,6 +8,7 @@ __all__ = [
|
||||||
"ToolAuthorizationContext",
|
"ToolAuthorizationContext",
|
||||||
"ToolCatalog",
|
"ToolCatalog",
|
||||||
"ToolContext",
|
"ToolContext",
|
||||||
|
"ToolMetadataKey",
|
||||||
"Toolkit",
|
"Toolkit",
|
||||||
"tool",
|
"tool",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,7 @@ def tool(
|
||||||
name: str | None = None,
|
name: str | None = None,
|
||||||
requires_auth: Union[ToolAuthorization, None] = None,
|
requires_auth: Union[ToolAuthorization, None] = None,
|
||||||
requires_secrets: Union[list[str], None] = None,
|
requires_secrets: Union[list[str], None] = None,
|
||||||
|
requires_metadata: Union[list[str], None] = None,
|
||||||
) -> Callable:
|
) -> Callable:
|
||||||
def decorator(func: Callable) -> Callable:
|
def decorator(func: Callable) -> Callable:
|
||||||
func_name = str(getattr(func, "__name__", None))
|
func_name = str(getattr(func, "__name__", None))
|
||||||
|
|
@ -24,6 +25,7 @@ def tool(
|
||||||
func.__tool_description__ = desc or inspect.cleandoc(func.__doc__ or "") # type: ignore[attr-defined]
|
func.__tool_description__ = desc or inspect.cleandoc(func.__doc__ or "") # type: ignore[attr-defined]
|
||||||
func.__tool_requires_auth__ = requires_auth # type: ignore[attr-defined]
|
func.__tool_requires_auth__ = requires_auth # type: ignore[attr-defined]
|
||||||
func.__tool_requires_secrets__ = requires_secrets # type: ignore[attr-defined]
|
func.__tool_requires_secrets__ = requires_secrets # type: ignore[attr-defined]
|
||||||
|
func.__tool_requires_metadata__ = requires_metadata # type: ignore[attr-defined]
|
||||||
|
|
||||||
if inspect.iscoroutinefunction(func):
|
if inspect.iscoroutinefunction(func):
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,11 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from arcade.core.schema import ToolAuthorizationContext, ToolContext, ToolSecretItem
|
from arcade.core.schema import (
|
||||||
|
ToolAuthorizationContext,
|
||||||
|
ToolContext,
|
||||||
|
ToolMetadataItem,
|
||||||
|
ToolSecretItem,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_get_auth_token_or_empty_with_token():
|
def test_get_auth_token_or_empty_with_token():
|
||||||
|
|
@ -68,5 +73,47 @@ def test_get_secret_when_secrets_is_none():
|
||||||
def test_get_secret_with_empty_key():
|
def test_get_secret_with_empty_key():
|
||||||
tool_context = ToolContext(secrets=[])
|
tool_context = ToolContext(secrets=[])
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Secret key ID passed to get_secret cannot be empty."):
|
with pytest.raises(ValueError, match="Secret key passed to get_secret cannot be empty."):
|
||||||
tool_context.get_secret("")
|
tool_context.get_secret("")
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_metadata_valid():
|
||||||
|
key = "my_key"
|
||||||
|
val = "metadata_value"
|
||||||
|
metadata = [ToolMetadataItem(key=key, value=val)]
|
||||||
|
tool_context = ToolContext(metadata=metadata)
|
||||||
|
|
||||||
|
assert tool_context.get_metadata(key) == val
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_metadata_with_case_insensitive_key():
|
||||||
|
key = "My_key"
|
||||||
|
val = "metadata_value"
|
||||||
|
metadata = [ToolMetadataItem(key=key, value=val)]
|
||||||
|
tool_context = ToolContext(metadata=metadata)
|
||||||
|
|
||||||
|
assert tool_context.get_metadata(key.upper()) == val
|
||||||
|
assert tool_context.get_metadata(key.lower()) == val
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_metadata_key_not_found():
|
||||||
|
key = "nonexistent_key"
|
||||||
|
metadata = [ToolMetadataItem(key="other_key", value="another_metadata")]
|
||||||
|
tool_context = ToolContext(metadata=metadata)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match=f"Metadata {key} not found in context."):
|
||||||
|
tool_context.get_metadata(key)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_metadata_when_metadata_is_none():
|
||||||
|
tool_context = ToolContext(metadata=None)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Metadatas not found in context."):
|
||||||
|
tool_context.get_metadata("missing_key")
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_metadata_with_empty_key():
|
||||||
|
tool_context = ToolContext(metadata=[])
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Metadata key passed to get_metadata cannot be empty."):
|
||||||
|
tool_context.get_metadata("")
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,8 @@ from arcade.core.schema import (
|
||||||
ToolAuthRequirement,
|
ToolAuthRequirement,
|
||||||
ToolContext,
|
ToolContext,
|
||||||
ToolInput,
|
ToolInput,
|
||||||
|
ToolMetadataKey,
|
||||||
|
ToolMetadataRequirement,
|
||||||
ToolOutput,
|
ToolOutput,
|
||||||
ToolRequirements,
|
ToolRequirements,
|
||||||
ToolSecretRequirement,
|
ToolSecretRequirement,
|
||||||
|
|
@ -63,6 +65,38 @@ def func_with_multiple_secret_requirement():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
desc="A function that requires metadata",
|
||||||
|
requires_metadata=[ToolMetadataKey.COORDINATOR_URL],
|
||||||
|
)
|
||||||
|
def func_with_metadata_requirement():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
desc="A function that requires multiple metadata fields, deduped case-insensitively",
|
||||||
|
requires_metadata=[
|
||||||
|
ToolMetadataKey.COORDINATOR_URL,
|
||||||
|
"my_other_metadata_key",
|
||||||
|
"MY_OTHER_METADATA_KEY",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def func_with_multiple_metadata_requirement():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
desc="A function that requires a metadata field that depends on the tool having an auth requirement",
|
||||||
|
requires_auth=OAuth2(
|
||||||
|
id="my_example_provider123",
|
||||||
|
scopes=["scope1", "scope2"],
|
||||||
|
),
|
||||||
|
requires_metadata=[ToolMetadataKey.CLIENT_ID],
|
||||||
|
)
|
||||||
|
def func_with_metadata_and_auth_dependency():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@tool(
|
@tool(
|
||||||
desc="A function that requires authentication",
|
desc="A function that requires authentication",
|
||||||
requires_auth=OAuth2(
|
requires_auth=OAuth2(
|
||||||
|
|
@ -336,6 +370,43 @@ def func_with_complex_return() -> dict[str, str]:
|
||||||
},
|
},
|
||||||
id="func_with_multiple_secret_requirement",
|
id="func_with_multiple_secret_requirement",
|
||||||
),
|
),
|
||||||
|
pytest.param(
|
||||||
|
func_with_metadata_requirement,
|
||||||
|
{
|
||||||
|
"requirements": ToolRequirements(
|
||||||
|
metadata=[ToolMetadataRequirement(key=ToolMetadataKey.COORDINATOR_URL)]
|
||||||
|
)
|
||||||
|
},
|
||||||
|
id="func_with_metadata_requirement",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
func_with_multiple_metadata_requirement,
|
||||||
|
{
|
||||||
|
"requirements": ToolRequirements(
|
||||||
|
metadata=[
|
||||||
|
ToolMetadataRequirement(key=ToolMetadataKey.COORDINATOR_URL),
|
||||||
|
ToolMetadataRequirement(key="my_other_metadata_key"),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
},
|
||||||
|
id="func_with_multiple_metadata_requirement",
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
func_with_metadata_and_auth_dependency,
|
||||||
|
{
|
||||||
|
"requirements": ToolRequirements(
|
||||||
|
metadata=[ToolMetadataRequirement(key=ToolMetadataKey.CLIENT_ID)],
|
||||||
|
authorization=ToolAuthRequirement(
|
||||||
|
provider_type="oauth2",
|
||||||
|
id="my_example_provider123",
|
||||||
|
oauth2=OAuth2Requirement(
|
||||||
|
scopes=["scope1", "scope2"],
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
},
|
||||||
|
id="func_with_metadata_and_auth_dependency",
|
||||||
|
),
|
||||||
pytest.param(
|
pytest.param(
|
||||||
func_with_auth_requirement,
|
func_with_auth_requirement,
|
||||||
{
|
{
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import pytest
|
||||||
|
|
||||||
from arcade.core.catalog import ToolCatalog
|
from arcade.core.catalog import ToolCatalog
|
||||||
from arcade.core.errors import ToolDefinitionError
|
from arcade.core.errors import ToolDefinitionError
|
||||||
from arcade.core.schema import ToolContext
|
from arcade.core.schema import ToolContext, ToolMetadataKey
|
||||||
from arcade.sdk import tool
|
from arcade.sdk import tool
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -81,6 +81,30 @@ def func_with_secret_requirement_invalid_type():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
desc="A function with a required metadata with a missing key (illegal)",
|
||||||
|
requires_metadata=[""],
|
||||||
|
)
|
||||||
|
def func_with_missing_metadata_key(context: ToolContext):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
desc="A function that requires metadata with an invalid type (illegal)",
|
||||||
|
requires_metadata=[True],
|
||||||
|
)
|
||||||
|
def func_with_metadata_requirement_invalid_type():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
desc="A function with a required metadata key that depends on the tool having an auth requirement, but the tool does not have an auth requirement (illegal)",
|
||||||
|
requires_metadata=[ToolMetadataKey.CLIENT_ID],
|
||||||
|
)
|
||||||
|
def func_with_metadata_and_auth_dependency():
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"func_under_test, exception_type",
|
"func_under_test, exception_type",
|
||||||
[
|
[
|
||||||
|
|
@ -139,6 +163,21 @@ def func_with_secret_requirement_invalid_type():
|
||||||
ToolDefinitionError,
|
ToolDefinitionError,
|
||||||
id=func_with_secret_requirement_invalid_type.__name__,
|
id=func_with_secret_requirement_invalid_type.__name__,
|
||||||
),
|
),
|
||||||
|
pytest.param(
|
||||||
|
func_with_missing_metadata_key,
|
||||||
|
ToolDefinitionError,
|
||||||
|
id=func_with_missing_metadata_key.__name__,
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
func_with_metadata_requirement_invalid_type,
|
||||||
|
ToolDefinitionError,
|
||||||
|
id=func_with_metadata_requirement_invalid_type.__name__,
|
||||||
|
),
|
||||||
|
pytest.param(
|
||||||
|
func_with_metadata_and_auth_dependency,
|
||||||
|
ToolDefinitionError,
|
||||||
|
id=func_with_metadata_and_auth_dependency.__name__,
|
||||||
|
),
|
||||||
pytest.param(
|
pytest.param(
|
||||||
func_with_union_return_type_1,
|
func_with_union_return_type_1,
|
||||||
ToolDefinitionError,
|
ToolDefinitionError,
|
||||||
|
|
|
||||||
|
|
@ -159,6 +159,24 @@
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
|
"metadata": {
|
||||||
|
"oneOf": [
|
||||||
|
{ "type": "null" }, // Can be unset,
|
||||||
|
{
|
||||||
|
"type": "array",
|
||||||
|
"items": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"key": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["key"],
|
||||||
|
"additionalProperties": false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
"authorization": {
|
"authorization": {
|
||||||
"oneOf": [
|
"oneOf": [
|
||||||
{ "type": "null" }, // Can be unset
|
{ "type": "null" }, // Can be unset
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue