diff --git a/arcade/arcade/core/catalog.py b/arcade/arcade/core/catalog.py index 383eefbf..e5ffbdc0 100644 --- a/arcade/arcade/core/catalog.py +++ b/arcade/arcade/core/catalog.py @@ -41,6 +41,7 @@ from arcade.core.schema import ( ToolkitDefinition, ToolOutput, ToolRequirements, + ToolSecretRequirement, ValueSchema, ) from arcade.core.toolkit import Toolkit @@ -374,6 +375,21 @@ class ToolCatalog(BaseModel): new_auth_requirement.oauth2 = OAuth2Requirement(**auth_requirement.model_dump()) auth_requirement = new_auth_requirement + secrets_requirement = getattr(tool, "__tool_requires_secrets__", None) + if isinstance(secrets_requirement, list): + if any(not isinstance(secret, str) for secret in secrets_requirement): + raise ToolDefinitionError( + f"Secret keys must be strings (error in tool {raw_tool_name})." + ) + + secrets_requirement = to_tool_secret_requirements(secrets_requirement) + if any( + secret.key is None or secret.key.strip() == "" for secret in secrets_requirement + ): + raise ToolDefinitionError( + f"Secrets must have a non-empty key (error in tool {raw_tool_name})." + ) + toolkit_definition = ToolkitDefinition( name=snake_to_pascal_case(toolkit_name), description=toolkit_desc, @@ -393,6 +409,7 @@ class ToolCatalog(BaseModel): output=create_output_definition(tool), requirements=ToolRequirements( authorization=auth_requirement, + secrets=secrets_requirement, ), deprecation_message=deprecation_message, ) @@ -798,3 +815,11 @@ def determine_output_model(func: Callable) -> type[BaseModel]: output_model_name, result=(return_annotation, Field(description="No description provided.")), ) + + +def to_tool_secret_requirements( + secrets_requirement: list[str], +) -> list[ToolSecretRequirement]: + # Iterate through the list, de-dupe case-insensitively, and convert each string to a ToolSecretRequirement + unique_secrets = {name.lower(): name.lower() for name in secrets_requirement}.values() + return [ToolSecretRequirement(key=name) for name in unique_secrets] diff --git a/arcade/arcade/core/schema.py b/arcade/arcade/core/schema.py index daa8087b..6834751d 100644 --- a/arcade/arcade/core/schema.py +++ b/arcade/arcade/core/schema.py @@ -2,7 +2,7 @@ import os from dataclasses import dataclass from typing import Any, Literal, Optional, Union -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, field_validator # allow for custom tool name separator TOOL_NAME_SEPARATOR = os.getenv("ARCADE_TOOL_NAME_SEPARATOR", ".") @@ -102,12 +102,21 @@ class ToolAuthRequirement(BaseModel): """The OAuth 2.0 requirement, if any.""" +class ToolSecretRequirement(BaseModel): + """A requirement for a tool to run.""" + + key: str + """The ID of the secret.""" + + class ToolRequirements(BaseModel): """The requirements for a tool to run.""" authorization: Union[ToolAuthRequirement, None] = None """The authorization requirements for the tool, if any.""" + secrets: Union[list[ToolSecretRequirement], None] = None + class ToolkitDefinition(BaseModel): """The specification of a toolkit.""" @@ -231,19 +240,52 @@ class ToolAuthorizationContext(BaseModel): """ +class ToolSecretItem(BaseModel): + """The context for a tool secret.""" + + key: str + """The key of the secret.""" + + value: str + """The value of the secret.""" + + class ToolContext(BaseModel): """The context for a tool invocation.""" authorization: ToolAuthorizationContext | None = None """The authorization context for the tool invocation that requires authorization.""" + secrets: list[ToolSecretItem] | None = None + """The secrets for the tool invocation.""" + user_id: str | None = None """The user ID for the tool invocation (if any).""" + @field_validator("secrets", mode="before") + def lower_keys(cls, v: dict[str, ToolSecretItem] | None) -> dict[str, ToolSecretItem] | None: + if isinstance(v, dict): + return {k.lower(): value for k, value in v.items()} + return v + def get_auth_token_or_empty(self) -> str: """Retrieve the authorization token, or return an empty string if not available.""" return self.authorization.token if self.authorization and self.authorization.token else "" + def get_secret(self, key: str) -> str: + """Retrieve the secret for the tool invocation.""" + if not key or not key.strip(): + raise ValueError("Secret key ID passed to get_secret cannot be empty.") + + if not self.secrets: + raise ValueError("Secrets not found in context.") + + normalized_key = key.lower() + for secret in self.secrets: + if secret.key == normalized_key: + return secret.value + raise ValueError(f"Secret {key} not found in context.") + class ToolCallRequest(BaseModel): """The request to call (invoke) a tool.""" diff --git a/arcade/arcade/sdk/tool.py b/arcade/arcade/sdk/tool.py index d2a7f0dd..ac9d6b3c 100644 --- a/arcade/arcade/sdk/tool.py +++ b/arcade/arcade/sdk/tool.py @@ -14,6 +14,7 @@ def tool( desc: str | None = None, name: str | None = None, requires_auth: Union[ToolAuthorization, None] = None, + requires_secrets: Union[list[str], None] = None, ) -> Callable: def decorator(func: Callable) -> Callable: func_name = str(getattr(func, "__name__", None)) @@ -22,6 +23,7 @@ def tool( func.__tool_name__ = tool_name # type: ignore[attr-defined] func.__tool_description__ = desc or inspect.cleandoc(func.__doc__ or "") # type: ignore[attr-defined] func.__tool_requires_auth__ = requires_auth # type: ignore[attr-defined] + func.__tool_requires_secrets__ = requires_secrets # type: ignore[attr-defined] if inspect.iscoroutinefunction(func): diff --git a/arcade/pyproject.toml b/arcade/pyproject.toml index 447370b5..91690fb6 100644 --- a/arcade/pyproject.toml +++ b/arcade/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "arcade-ai" -version = "1.1.0" +version = "1.2.0" description = "Arcade Python SDK and CLI" readme = "README.md" packages = [ diff --git a/arcade/tests/core/test_schema.py b/arcade/tests/core/test_schema.py index 5fd7de53..8a91e2f9 100644 --- a/arcade/tests/core/test_schema.py +++ b/arcade/tests/core/test_schema.py @@ -1,4 +1,6 @@ -from arcade.core.schema import ToolAuthorizationContext, ToolContext +import pytest + +from arcade.core.schema import ToolAuthorizationContext, ToolContext, ToolSecretItem def test_get_auth_token_or_empty_with_token(): @@ -22,3 +24,49 @@ def test_get_auth_token_or_empty_no_authorization(): tool_context = ToolContext(authorization=None) assert tool_context.get_auth_token_or_empty() == "" + + +def test_get_secret_valid(): + key = "my_key" + val = "secret_value" + secrets = [ToolSecretItem(key=key, value=val)] + tool_context = ToolContext(secrets=secrets) + + # When the secret exists, get_secret should return its value. + actual_secret = tool_context.get_secret(key) + assert actual_secret == val + + +def test_get_secret_with_case_insensitive_key(): + key = "my_key" + val = "secret_value" + secrets = [ToolSecretItem(key=key, value=val)] + tool_context = ToolContext(secrets=secrets) + + assert tool_context.get_secret(key.upper()) == val + assert tool_context.get_secret(key.lower()) == val + + +def test_get_secret_key_not_found(): + key = "nonexistent_key" + secrets = [ToolSecretItem(key="other_key", value="another_secret")] + tool_context = ToolContext(secrets=secrets) + + # When the key is not found, get_secret should raise a ValueError. + with pytest.raises(ValueError, match=f"Secret {key} not found in context."): + tool_context.get_secret(key) + + +def test_get_secret_when_secrets_is_none(): + tool_context = ToolContext(secrets=None) + + # When no secrets dictionary is provided, get_secret should raise a ValueError. + with pytest.raises(ValueError, match="Secrets not found in context."): + tool_context.get_secret("missing_key") + + +def test_get_secret_with_empty_key(): + tool_context = ToolContext(secrets=[]) + + with pytest.raises(ValueError, match="Secret key ID passed to get_secret cannot be empty."): + tool_context.get_secret("") diff --git a/arcade/tests/tool/test_create_tool_definition.py b/arcade/tests/tool/test_create_tool_definition.py index 27a2ee9f..bc94bed0 100644 --- a/arcade/tests/tool/test_create_tool_definition.py +++ b/arcade/tests/tool/test_create_tool_definition.py @@ -12,6 +12,7 @@ from arcade.core.schema import ( ToolInput, ToolOutput, ToolRequirements, + ToolSecretRequirement, ValueSchema, ) from arcade.core.utils import snake_to_pascal_case @@ -46,6 +47,22 @@ def func_with_name_and_description(): pass +@tool( + desc="A function that requires a secret", + requires_secrets=["my_secret_id"], +) +def func_with_secret_requirement(): + pass + + +@tool( + desc="A function that requires multiple secrets, deduped case-insensitively", + requires_secrets=["my_secret_id", "my_secret_id2", "MY_SECRET_ID"], +) +def func_with_multiple_secret_requirement(): + pass + + @tool( desc="A function that requires authentication", requires_auth=OAuth2( @@ -261,6 +278,23 @@ def func_with_complex_return() -> dict[str, str]: {"requirements": ToolRequirements(auth=None)}, id="func_with_no_auth_requirement", ), + pytest.param( + func_with_secret_requirement, + {"requirements": ToolRequirements(secrets=[ToolSecretRequirement(key="my_secret_id")])}, + id="func_with_secret_requirement", + ), + pytest.param( + func_with_multiple_secret_requirement, + { + "requirements": ToolRequirements( + secrets=[ + ToolSecretRequirement(key="my_secret_id"), + ToolSecretRequirement(key="my_secret_id2"), + ] + ) + }, + id="func_with_multiple_secret_requirement", + ), pytest.param( func_with_auth_requirement, { @@ -269,7 +303,6 @@ def func_with_complex_return() -> dict[str, str]: provider_type="oauth2", id="my_example_provider123", oauth2=OAuth2Requirement( - authority="https://example.com/oauth2/auth", scopes=["scope1", "scope2"], ), ) diff --git a/arcade/tests/tool/test_create_tool_definition_errors.py b/arcade/tests/tool/test_create_tool_definition_errors.py index 959ea184..fa40d712 100644 --- a/arcade/tests/tool/test_create_tool_definition_errors.py +++ b/arcade/tests/tool/test_create_tool_definition_errors.py @@ -41,6 +41,22 @@ def func_with_multiple_context_params(context: ToolContext, context2: ToolContex pass +@tool( + desc="A function with a required secret with a missing key (illegal)", + requires_secrets=[""], +) +def func_with_missing_secret_key(context: ToolContext): + pass + + +@tool( + desc="A function that requires a secret (invalid type)", + requires_secrets=[True], +) +def func_with_secret_requirement_invalid_type(): + pass + + @pytest.mark.parametrize( "func_under_test, exception_type", [ @@ -79,6 +95,16 @@ def func_with_multiple_context_params(context: ToolContext, context2: ToolContex ToolDefinitionError, id=func_with_multiple_context_params.__name__, ), + pytest.param( + func_with_missing_secret_key, + ToolDefinitionError, + id=func_with_missing_secret_key.__name__, + ), + pytest.param( + func_with_secret_requirement_invalid_type, + ToolDefinitionError, + id=func_with_secret_requirement_invalid_type.__name__, + ), ], ) def test_missing_info_raises_error(func_under_test, exception_type): diff --git a/schemas/preview/invoke_tool_request.schema.jsonc b/schemas/preview/invoke_tool_request.schema.jsonc index 4cf21e03..7cfbcd8f 100644 --- a/schemas/preview/invoke_tool_request.schema.jsonc +++ b/schemas/preview/invoke_tool_request.schema.jsonc @@ -35,7 +35,7 @@ "description": "Version of the toolkit containing the tool to call" } }, - "required": ["name", "toolkit_name", "version"], + "required": ["name", "toolkit", "version"], "additionalProperties": false }, "inputs": { @@ -54,6 +54,10 @@ "additionalProperties": false } }, + "secrets": { + "type": "object", + "additionalProperties": true + }, "user_id": { "type": "string", "description": "A unique ID that identifies the user (if any)" @@ -65,6 +69,6 @@ } } }, - "required": ["run_id", "invocation_id", "created_at", "tool", "inputs", "context"], + "required": ["run_id", "execution_id", "created_at", "tool", "inputs", "context"], "additionalProperties": false } diff --git a/schemas/preview/tool_definition.schema.jsonc b/schemas/preview/tool_definition.schema.jsonc index de4da4d9..410c39cc 100644 --- a/schemas/preview/tool_definition.schema.jsonc +++ b/schemas/preview/tool_definition.schema.jsonc @@ -14,8 +14,8 @@ "oneOf": [{ "$ref": "#/$defs/primitives" }, { "type": "string", "enum": ["array"] }] }, "inner_val_type": { - "$ref": "#/$defs/primitives", - "description": "If the value type is a list, the type of the list values." + "description": "If the value type is a list, the type of the list values.", + "oneOf": [{ "$ref": "#/$defs/primitives" }, { "type": "null" }] }, "enum": { "oneOf": [ @@ -141,6 +141,24 @@ "requirements": { "type": "object", "properties": { + "secrets": { + "oneOf": [ + { "type": "null" }, // Can be unset, + { + "type": "array", + "items": { + "type": "object", + "properties": { + "key": { + "type": "string" + } + }, + "required": ["key"], + "additionalProperties": false + } + } + ] + }, "authorization": { "oneOf": [ { "type": "null" }, // Can be unset