arcade-mcp/libs/tests/tool/test_create_tool_definition.py
Sam Partee 27a6cd31a3
Support Tool Output in ValueSchema of ToolDefinition (#487)
## Before

### Tool: ``GoogleNews.SearchNewsStories``

```python
@tool(requires_secrets=["SERP_API_KEY"])
async def search_news_stories(
    context: ToolContext,
    keywords: Annotated[
        str,
        "Keywords to search for news articles. E.g. 'Apple launches new iPhone'.",
    ],
    country_code: Annotated[
        CountryCode | None,
        "2-character country code to search for news articles. "
        "E.g. 'us' (United States). "
        f"Defaults to '{DEFAULT_GOOGLE_NEWS_COUNTRY}'.",
    ] = None,
    language_code: Annotated[
        LanguageCode,
        "2-character language code to search for news articles. E.g. 'en' (English). "
        f"Defaults to '{DEFAULT_GOOGLE_NEWS_LANGUAGE}'.",
    ] = DEFAULT_GOOGLE_NEWS_LANGUAGE,
    limit: Annotated[
        int | None,
        "Maximum number of news articles to return. Defaults to None "
        "(returns all results found by the API).",
    ] = None,
) -> Annotated[dict[str, Any], "News search results with article details."]:
    """Search for news articles related to a given query."""
    ...
```


### Tool Definition: ``GoogleNews.SearchNewsStories``
```
  {
    "name": "SearchNewsStories",
    "fully_qualified_name": "GoogleNews.SearchNewsStories",
    "description": "Search for news articles related to a given query.",
    "toolkit": {
      "name": "GoogleNews",
      "description": "Arcade.dev LLM tools for getting new via Google News",
      "version": "2.0.0"
    },
    "input": {
      "parameters": [
        {
          "name": "keywords",
          "required": true,
          "description": "Keywords to search for news articles. E.g. 'Apple launches new iPhone'.",
          "value_schema": {
            "val_type": "string",
            "inner_val_type": null,
            "enum": null
          },
          "inferrable": true
        },
        {
          "name": "country_code",
          "required": false,
          "description": "2-character country code to search for news articles. E.g. 'us' (United States). Defaults to 'None'.",
          "value_schema": {
            "val_type": "string",
            "inner_val_type": null,
            "enum": null
          },
          "inferrable": true
        },
        {
          "name": "language_code",
          "required": false,
          "description": "2-character language code to search for news articles. E.g. 'en' (English). Defaults to 'en'.",
          "value_schema": {
            "val_type": "string",
            "inner_val_type": null,
            "enum": null
          },
          "inferrable": true
        },
        {
          "name": "limit",
          "required": false,
          "description": "Maximum number of news articles to return. Defaults to None (returns all results found by the API).",
          "value_schema": {
            "val_type": "integer",
            "inner_val_type": null,
            "enum": null
          },
          "inferrable": true
        }
      ]
    },
    "output": {
      "description": "News search results with article details.",
      "available_modes": [
        "value",
        "error"
      ],
      "value_schema": {
        "val_type": "json"
      }
    },
    "requirements": {
      "authorization": null,
      "secrets": [
        {
          "key": "serp_api_key"
        }
      ],
      "metadata": null
    },
    "deprecation_message": null
  },
```

## After

### Enhanced Tool: ``GoogleNews.SearchNewsStories``

```python

"""Type definitions for Google News API responses and parameters."""

from typing_extensions import TypedDict

CountryCode = str
LanguageCode = str


class SearchNewsParams(TypedDict):
    """Input parameters for searching news articles."""

    keywords: str
    """Search query terms to find relevant news articles \
    (e.g., 'Apple launches new iPhone')."""

    country_code: CountryCode | None
    """Optional 2-letter country code to filter news by region \
    (e.g., 'us' for United States, 'uk' for United Kingdom)."""

    language_code: LanguageCode | None
    """Optional 2-letter language code to filter news by language \
    (e.g., 'en' for English, 'es' for Spanish)."""

    limit: int | None
    """Optional maximum number of news articles to return. \
    If not specified, returns all results from the API."""


class SourceInfo(TypedDict, total=False):
    """Information about the news source/publication."""

    name: str
    """Name of the publication (e.g., 'CNN', 'BBC News', 'The New York Times')."""

    icon: str
    """URL to the source's favicon or logo image."""

    authors: list[str]
    """List of author names for the article, if available."""


class NewsResult(TypedDict, total=False):
    """Individual news article from the Google News API response."""

    position: int
    """Ranking position of this result in the search results."""

    title: str
    """Headline or title of the news article."""

    link: str
    """Full URL to the original news article."""

    source: SourceInfo
    """Information about the publication source."""

    date: str
    """Publication date and time (e.g., '2 hours ago', 'Dec 15, 2023')."""

    snippet: str
    """Brief excerpt or summary from the article content."""

    thumbnail: str
    """URL to a high-resolution thumbnail image for the article."""

    thumbnail_small: str
    """URL to a low-resolution thumbnail image for the article."""

    story_token: str
    """Token for accessing full coverage of this news story across multiple sources."""

    stories: list["NewsResult"]
    """Related news stories from other sources covering the same topic."""

    highlight: dict
    """Additional highlighted information about the story."""


class SearchMetadata(TypedDict, total=False):
    """Metadata about the search request and processing."""

    id: str
    """Unique identifier for this search request within SerpApi."""

    status: str
    """Current processing status ('Processing', 'Success', or 'Error')."""

    json_endpoint: str
    """URL to retrieve the JSON results for this search."""

    created_at: str
    """Timestamp when the search request was created."""

    processed_at: str
    """Timestamp when the search request was processed."""

    google_news_url: str
    """Original Google News URL that would return these results."""

    total_time_taken: float
    """Total time in seconds taken to process this search."""


class SearchParameters(TypedDict, total=False):
    """Parameters used for the search request."""

    engine: str
    """Search engine used (always 'google_news' for this API)."""

    q: str
    """Search query string."""

    gl: str
    """Country code used for geographic filtering."""

    hl: str
    """Language code used for language filtering."""

    topic_token: str
    """Token for accessing specific news topics (e.g., 'World', 'Business', 'Technology')."""

    publication_token: str
    """Token for accessing news from specific publishers."""


class MenuLink(TypedDict):
    """Navigation link for news categories or topics."""

    title: str
    """Display text for the menu item (e.g., 'Technology', 'Sports', 'Business')."""

    topic_token: str
    """Token to access this specific topic or category."""

    serpapi_link: str
    """SerpApi URL to search within this topic."""


class TopStoriesLink(TypedDict):
    """Link to top stories section."""

    topic_token: str
    """Token to access top stories."""

    serpapi_link: str
    """SerpApi URL to retrieve top stories."""


class GoogleNewsResponse(TypedDict, total=False):
    """Complete response from the Google News API."""

    search_metadata: SearchMetadata
    """Metadata about the search request and processing."""

    search_parameters: SearchParameters
    """Parameters that were used for this search."""

    news_results: list[NewsResult]
    """List of news articles matching the search criteria."""

    menu_links: list[MenuLink]
    """Navigation links to different news categories and topics."""

    top_stories_link: TopStoriesLink
    """Link to access top stories."""

    title: str
    """Title of the page or topic being displayed."""


class SimplifiedNewsResult(TypedDict):
    """Simplified news article format for tool output."""

    title: str
    """Headline of the news article."""

    link: str
    """URL to the full article."""

    source: str | None
    """Name of the publication source."""

    date: str | None
    """When the article was published."""

    snippet: str | None
    """Brief excerpt from the article."""


class SearchNewsOutput(TypedDict):
    """Output format for the search_news_stories tool."""

    news_results: list[SimplifiedNewsResult]
    """List of news articles in simplified format."""

@tool(requires_secrets=["SERP_API_KEY"])
async def search_news_stories(
    context: ToolContext,
    keywords: Annotated[
        str,
        "Keywords to search for news articles. E.g. 'Apple launches new iPhone'.",
    ],
    country_code: Annotated[
        CountryCode | None,
        "2-character country code to search for news articles. "
        "E.g. 'us' (United States). "
        f"Defaults to '{DEFAULT_GOOGLE_NEWS_COUNTRY}'.",
    ] = None,
    language_code: Annotated[
        LanguageCode,
        "2-character language code to search for news articles. E.g. 'en' (English). "
        f"Defaults to '{DEFAULT_GOOGLE_NEWS_LANGUAGE}'.",
    ] = DEFAULT_GOOGLE_NEWS_LANGUAGE,
    limit: Annotated[
        int | None,
        "Maximum number of news articles to return. Defaults to None "
        "(returns all results found by the API).",
    ] = None,
) -> Annotated[SearchNewsOutput, "News search results with article details."]:
    """Search for news articles related to a given query."""
    ...

```

### Enhanced Tool Definition: ``GoogleNews.SearchNewsStories`` 

```json

  {
    "name": "SearchNewsStories",
    "fully_qualified_name": "GoogleNews.SearchNewsStories",
    "description": "Search for news articles related to a given query.",
    "toolkit": {
      "name": "GoogleNews",
      "description": "Arcade.dev LLM tools for getting new via Google News",
      "version": "2.0.0"
    },
    "input": {
      "parameters": [
        {
          "name": "keywords",
          "required": true,
          "description": "Keywords to search for news articles. E.g. 'Apple launches new iPhone'.",
          "value_schema": {
            "val_type": "string",
            "inner_val_type": null,
            "enum": null,
            "properties": null,
            "inner_properties": null,
            "description": null
          },
          "inferrable": true
        },
        {
          "name": "country_code",
          "required": false,
          "description": "2-character country code to search for news articles. E.g. 'us' (United States). Defaults to 'None'.",
          "value_schema": {
            "val_type": "string",
            "inner_val_type": null,
            "enum": null,
            "properties": null,
            "inner_properties": null,
            "description": null
          },
          "inferrable": true
        },
        {
          "name": "language_code",
          "required": false,
          "description": "2-character language code to search for news articles. E.g. 'en' (English). Defaults to 'en'.",
          "value_schema": {
            "val_type": "string",
            "inner_val_type": null,
            "enum": null,
            "properties": null,
            "inner_properties": null,
            "description": null
          },
          "inferrable": true
        },
        {
          "name": "limit",
          "required": false,
          "description": "Maximum number of news articles to return. Defaults to None (returns all results found by the API).",
          "value_schema": {
            "val_type": "integer",
            "inner_val_type": null,
            "enum": null,
            "properties": null,
            "inner_properties": null,
            "description": null
          },
          "inferrable": true
        }
      ]
    },
    "output": {
      "description": "News search results with article details.",
      "available_modes": [
        "value",
        "error"
      ],
      "value_schema": {
        "val_type": "json",
        "inner_val_type": null,
        "enum": null,
        "properties": {
          "news_results": {
            "val_type": "array",
            "inner_val_type": "json",
            "enum": null,
            "properties": null,
            "inner_properties": {
              "title": {
                "val_type": "string",
                "inner_val_type": null,
                "enum": null,
                "properties": null,
                "inner_properties": null,
                "description": "Headline of the news article."
              },
              "link": {
                "val_type": "string",
                "inner_val_type": null,
                "enum": null,
                "properties": null,
                "inner_properties": null,
                "description": "URL to the full article."
              },
              "source": {
                "val_type": "string",
                "inner_val_type": null,
                "enum": null,
                "properties": null,
                "inner_properties": null,
                "description": "Name of the publication source."
              },
              "date": {
                "val_type": "string",
                "inner_val_type": null,
                "enum": null,
                "properties": null,
                "inner_properties": null,
                "description": "When the article was published."
              },
              "snippet": {
                "val_type": "string",
                "inner_val_type": null,
                "enum": null,
                "properties": null,
                "inner_properties": null,
                "description": "Brief excerpt from the article."
              }
            },
            "description": "List of news articles in simplified format."
          }
        },
        "inner_properties": null,
        "description": null
      }
    },
    "requirements": {
      "authorization": null,
      "secrets": [
        {
          "key": "serp_api_key"
        }
      ],
      "metadata": null
    },
    "deprecation_message": null
  },

```

---------

Co-authored-by: Eric Gustin <eric@arcade.dev>
2025-07-24 15:32:35 -07:00

981 lines
32 KiB
Python

from enum import Enum
from typing import Annotated, Literal
import pytest
from arcade_core.catalog import ToolCatalog
from arcade_core.schema import (
InputParameter,
OAuth2Requirement,
ToolAuthRequirement,
ToolContext,
ToolInput,
ToolMetadataKey,
ToolMetadataRequirement,
ToolOutput,
ToolRequirements,
ToolSecretRequirement,
ValueSchema,
)
from arcade_core.utils import snake_to_pascal_case
from arcade_tdk import tool
from arcade_tdk.annotations import Inferrable
from arcade_tdk.auth import GitHub, Google, OAuth2, Slack, X
### Tests on @tool decorator


# Description supplied through the decorator's ``desc`` argument.
@tool(desc="A function with a description")
def func_with_description():
    pass


# No ``desc`` argument: the parametrized test in this file expects the
# definition's description to be the docstring text.
@tool
def func_with_docstring_description():
    """Docstring description"""
    pass


# Multi-line docstring: the parametrized test in this file expects the
# description "Docstring description\non multiple lines" (no surrounding
# blank lines).
@tool
def func_with_multiline_docstring_description():
    """
    Docstring description
    on multiple lines
    """
    pass


# Both the tool name and the description are overridden in the decorator.
@tool(name="MyCustomTool", desc="A function with a very cool description")
def func_with_name_and_description():
    pass
# Single secret requirement declared via ``requires_secrets``.
@tool(
    desc="A function that requires a secret",
    requires_secrets=["my_secret_id"],
)
def func_with_secret_requirement():
    pass


# "my_secret_id" is listed twice with different casing; the parametrized test
# in this file expects only "my_secret_id" and "my_secret_id2" to remain.
@tool(
    desc="A function that requires multiple secrets, deduped case-insensitively",
    requires_secrets=["my_secret_id", "my_secret_id2", "MY_SECRET_ID"],
)
def func_with_multiple_secret_requirement():
    pass


# Single metadata requirement, given as a ``ToolMetadataKey`` member.
@tool(
    desc="A function that requires metadata",
    requires_metadata=[ToolMetadataKey.COORDINATOR_URL],
)
def func_with_metadata_requirement():
    pass


# Mixes a ``ToolMetadataKey`` member with plain-string keys; the two string
# keys differ only by case and the parametrized test expects one of them.
@tool(
    desc="A function that requires multiple metadata fields, deduped case-insensitively",
    requires_metadata=[
        ToolMetadataKey.COORDINATOR_URL,
        "my_other_metadata_key",
        "MY_OTHER_METADATA_KEY",
    ],
)
def func_with_multiple_metadata_requirement():
    pass


# CLIENT_ID metadata is requested together with an OAuth2 auth requirement
# (per the description, this metadata key depends on the tool having auth).
@tool(
    desc="A function that requires a metadata field that depends on the tool having an auth requirement",
    requires_auth=OAuth2(
        id="my_example_provider123",
        scopes=["scope1", "scope2"],
    ),
    requires_metadata=[ToolMetadataKey.CLIENT_ID],
)
def func_with_metadata_and_auth_dependency():
    pass
# Generic OAuth2 requirement with an explicit provider id and scopes.
@tool(
    desc="A function that requires authentication",
    requires_auth=OAuth2(
        id="my_example_provider123",
        scopes=["scope1", "scope2"],
    ),
)
def func_with_auth_requirement():
    pass


# Google auth class with an explicit id and a single Gmail read-only scope.
@tool(
    desc="A function that requires Google authorization",
    requires_auth=Google(
        id="my_google_provider123",
        scopes=["https://www.googleapis.com/auth/gmail.readonly"],
    ),
)
def func_with_google_auth_requirement():
    pass


# GitHub auth class with an id but no scopes.
@tool(
    desc="A function that requires GitHub authorization",
    requires_auth=GitHub(
        id="my_github_provider123",
    ),
)
def func_with_github_auth_requirement():
    pass


# Slack auth class with scopes but no explicit id.
@tool(
    desc="A function that requires Slack user authorization",
    requires_auth=Slack(
        scopes=["chat:write", "channels:history"],
    ),
)
def func_with_slack_user_auth_requirement():
    pass


# X auth class with scopes but no explicit id.
@tool(
    desc="A function that requires X (Twitter) authorization",
    requires_auth=X(
        scopes=["tweet.write"],
    ),
)
def func_with_x_requirement():
    pass
### Tests on input params


# ``Inferrable(False)`` in the annotation metadata; the parametrized test in
# this file expects the resulting parameter to have ``inferrable=False``.
@tool(desc="A function with a non-inferrable input parameter")
def func_with_non_inferrable_param(param1: Annotated[str, "First param", Inferrable(False)]):
    pass


# Two string annotations on an input parameter is understood to be name, description
@tool(desc="A function with a renamed input parameter")
def func_with_renamed_param(param1: Annotated[str, "ParamOne", "First param"]):
    pass
# Two-member enum shared by the enum-typed parameter and return-value
# fixtures; the tests expect its values ("foo bar", "baz") in the schema's
# ``enum`` list.
class MyEnum(Enum):
    FOO_BAR = "foo bar"
    BAZ = "baz"
# Enum-typed parameter.
@tool(desc="A function that takes an enum")
def func_with_enum_param(param1: Annotated[MyEnum, "an enum"]):
    pass


# Bare ``dict`` parameter (expected ``val_type`` in the test table: "json").
@tool(desc="A function that takes a dictionary")
def func_with_dict_param(param1: Annotated[dict, "a cool dictionary"]):
    pass


# ``Literal`` of strings (expected in the test table: string with an enum list).
@tool(desc="A function that takes a string literal")
def func_with_string_literal_param(param1: Annotated[Literal["value1", "value2"], "a few choices"]):
    pass


# Non-optional type with a default value.
@tool(desc="A function with an input parameter with a default value (considered optional)")
def func_with_param_with_default(param1: Annotated[str, "First param"] = "default"):
    pass


# Optional type with no default value.
@tool(desc="A function with an optional input parameter")
def func_with_optional_param(param1: Annotated[str | None, "First param"]):
    pass


# Optional type with a ``None`` default.
@tool(desc="A function with an optional input parameter (default: None)")
def func_with_optional_param_with_default_None(
    param1: Annotated[str | None, "First param"] = None,
):
    pass


# Optional type with a non-None default.
@tool(desc="A function with an optional input parameter with default value")
def func_with_optional_param_with_default_value(
    param1: Annotated[str | None, "First param"] = "default",
):
    pass


# ``str | None`` ordering of the optional union.
@tool(desc="A function with an optional input parameter with bar syntax")
def func_with_optional_param_with_bar_syntax_1(
    param1: Annotated[str | None, "First param"] = None,
):
    pass


# ``None | str`` ordering of the optional union.
@tool(desc="A function with an optional input parameter with bar syntax")
def func_with_optional_param_with_bar_syntax_2(
    param1: Annotated[None | str, "First param"] = None,
):
    pass


# NOTE(review): the annotation here is textually identical to
# func_with_optional_param_with_bar_syntax_1; nothing in this fixture uses a
# distinct "union" spelling -- confirm whether that was intended.
@tool(desc="A function with an optional input parameter with union syntax")
def func_with_optional_param_with_union_syntax_1(
    param1: Annotated[str | None, "First param"] = None,
):
    pass


# NOTE(review): textually identical to
# func_with_optional_param_with_bar_syntax_2 -- confirm whether that was intended.
@tool(desc="A function with an optional input parameter with union syntax")
def func_with_optional_param_with_union_syntax_2(
    param1: Annotated[None | str, "First param"] = None,
):
    pass


# ToolContext parameter plus one parameter without and one with a default.
@tool(desc="A function with multiple parameters, some with default values")
def func_with_mixed_params(
    context: ToolContext,
    param1: Annotated[str, "First param"],
    param2: Annotated[int, "Second param"] = 42,
):
    pass


# List of strings.
@tool(desc="A function with a list[str] parameter")
def func_with_list_param(param1: Annotated[list[str], "A list of strings"]):
    pass


# List of floats (expected ``inner_val_type`` in the test table: "number").
@tool(desc="A function with a list[float] parameter")
def func_with_list_float_param(param1: Annotated[list[float], "A list of floats"]):
    pass


# List of enum members.
@tool(desc="A function with a list of enums parameter")
def func_with_list_of_enums_param(param1: Annotated[list[MyEnum], "A list of enums"]):
    pass


# Nested generic type (expected ``val_type`` in the test table: "json").
@tool(desc="A function with a complex parameter type")
def func_with_complex_param(
    param1: Annotated[dict[str, list[int]], "A dictionary with lists of integers"],
):
    pass


# Only a ToolContext parameter, under a non-default name.
@tool(desc="A function that takes a context")
def func_with_context(my_context: ToolContext):
    pass
### Tests on output/return values


# Return annotation: list of strings.
@tool(desc="A function that returns a list of strings")
def func_with_list_return() -> list[str]:
    return ["output1", "output2"]


# Return annotation: ``Literal`` of strings.
@tool(desc="A function that returns a known list of string literals")
def func_with_known_list_return() -> Literal["value1", "value2"]:
    return "value1"


# Return annotation: enum.
@tool(desc="A function that returns an enum")
def func_with_enum_return() -> MyEnum:
    return MyEnum.FOO_BAR


# Return annotation carries a description string in ``Annotated`` metadata.
@tool(desc="A function with an annotated return type")
def func_with_annotated_return() -> Annotated[str, "Annotated return description"]:
    return "output"


# Optional return type.
@tool(desc="A function with an optional return type")
def func_with_optional_return() -> str | None:
    return "maybe output"


# ``str | None`` ordering of the optional return union.
@tool(desc="A function with an optional return type that uses bar syntax")
def func_with_optional_return_with_bar_syntax_1() -> str | None:
    return "maybe output"


# ``None | str`` ordering of the optional return union.
@tool(desc="A function with an optional return type that uses bar syntax")
def func_with_optional_return_with_bar_syntax_2() -> None | str:
    return "maybe output"


# NOTE(review): return annotation is textually identical to
# func_with_optional_return_with_bar_syntax_1 -- confirm whether that was intended.
@tool(desc="A function with an optional return type that uses union syntax")
def func_with_optional_return_with_union_syntax_1() -> str | None:
    return "maybe output"


# NOTE(review): return annotation is textually identical to
# func_with_optional_return_with_bar_syntax_2 -- confirm whether that was intended.
@tool(desc="A function with an optional return type that uses union syntax")
def func_with_optional_return_with_union_syntax_2() -> None | str:
    return "maybe output"
# Fixture with a parameterised ``dict`` return annotation; the test table in
# this file expects an output schema with ``val_type="json"``.
# The returned value is a dict so that it agrees with the declared
# ``dict[str, str]`` annotation (it previously returned a list wrapping a dict).
# A ``#`` comment is used instead of a docstring so the function's ``__doc__``
# stays unset, as it is for the sibling fixtures that pass ``desc=``.
@tool(desc="A function with a complex return type")
def func_with_complex_return() -> dict[str, str]:
    return {"key": "value"}
# Each case pairs a fixture tool with the subset of ToolDefinition fields that
# must match exactly; fields that are not listed are not checked for that case.
@pytest.mark.parametrize(
    "func_under_test, expected_tool_def_fields",
    [
        # Tests on @tool decorator
        pytest.param(
            func_with_description,
            {"description": "A function with a description"},
            id="func_with_description",
        ),
        pytest.param(
            func_with_docstring_description,
            {"description": "Docstring description"},
            id="func_with_docstring_description",
        ),
        pytest.param(
            func_with_multiline_docstring_description,
            {"description": "Docstring description\non multiple lines"},
            id="func_with_multiline_docstring_description",
        ),
        pytest.param(
            func_with_name_and_description,
            {
                "name": "MyCustomTool",
                "fully_qualified_name": "TestToolkit.MyCustomTool",
                "description": "A function with a very cool description",
            },
            id="func_with_description_and_name",
        ),
        pytest.param(
            func_with_name_and_description,
            # Uses the same ``authorization`` keyword as every other case here.
            {"requirements": ToolRequirements(authorization=None)},
            id="func_with_no_auth_requirement",
        ),
        pytest.param(
            func_with_secret_requirement,
            {"requirements": ToolRequirements(secrets=[ToolSecretRequirement(key="my_secret_id")])},
            id="func_with_secret_requirement",
        ),
        pytest.param(
            func_with_multiple_secret_requirement,
            {
                "requirements": ToolRequirements(
                    secrets=[
                        ToolSecretRequirement(key="my_secret_id"),
                        ToolSecretRequirement(key="my_secret_id2"),
                    ]
                )
            },
            id="func_with_multiple_secret_requirement",
        ),
        pytest.param(
            func_with_metadata_requirement,
            {
                "requirements": ToolRequirements(
                    metadata=[ToolMetadataRequirement(key=ToolMetadataKey.COORDINATOR_URL)]
                )
            },
            id="func_with_metadata_requirement",
        ),
        pytest.param(
            func_with_multiple_metadata_requirement,
            {
                "requirements": ToolRequirements(
                    metadata=[
                        ToolMetadataRequirement(key=ToolMetadataKey.COORDINATOR_URL),
                        ToolMetadataRequirement(key="my_other_metadata_key"),
                    ]
                )
            },
            id="func_with_multiple_metadata_requirement",
        ),
        pytest.param(
            func_with_metadata_and_auth_dependency,
            {
                "requirements": ToolRequirements(
                    metadata=[ToolMetadataRequirement(key=ToolMetadataKey.CLIENT_ID)],
                    authorization=ToolAuthRequirement(
                        provider_type="oauth2",
                        id="my_example_provider123",
                        oauth2=OAuth2Requirement(
                            scopes=["scope1", "scope2"],
                        ),
                    ),
                )
            },
            id="func_with_metadata_and_auth_dependency",
        ),
        pytest.param(
            func_with_auth_requirement,
            {
                "requirements": ToolRequirements(
                    authorization=ToolAuthRequirement(
                        provider_type="oauth2",
                        id="my_example_provider123",
                        oauth2=OAuth2Requirement(
                            scopes=["scope1", "scope2"],
                        ),
                    )
                )
            },
            id="func_with_auth_requirement",
        ),
        pytest.param(
            func_with_google_auth_requirement,
            {
                "requirements": ToolRequirements(
                    authorization=ToolAuthRequirement(
                        provider_id="google",
                        provider_type="oauth2",
                        id="my_google_provider123",
                        oauth2=OAuth2Requirement(
                            scopes=["https://www.googleapis.com/auth/gmail.readonly"],
                        ),
                    )
                )
            },
            id="func_with_google_auth_requirement",
        ),
        pytest.param(
            func_with_github_auth_requirement,
            {
                "requirements": ToolRequirements(
                    authorization=ToolAuthRequirement(
                        provider_id="github",
                        provider_type="oauth2",
                        id="my_github_provider123",
                        oauth2=OAuth2Requirement(),
                    )
                )
            },
            id="func_with_github_auth_requirement",
        ),
        pytest.param(
            func_with_slack_user_auth_requirement,
            {
                "requirements": ToolRequirements(
                    authorization=ToolAuthRequirement(
                        provider_id="slack",
                        provider_type="oauth2",
                        oauth2=OAuth2Requirement(
                            scopes=["chat:write", "channels:history"],
                        ),
                    )
                )
            },
            id="func_with_slack_user_auth_requirement",
        ),
        pytest.param(
            func_with_x_requirement,
            {
                "requirements": ToolRequirements(
                    authorization=ToolAuthRequirement(
                        provider_id="x",
                        provider_type="oauth2",
                        oauth2=OAuth2Requirement(
                            scopes=["tweet.write"],
                        ),
                    )
                )
            },
            # Explicit id so this case is reported by name like its siblings.
            id="func_with_x_requirement",
        ),
        # Tests on input params
        pytest.param(
            func_with_non_inferrable_param,
            {
                "input": ToolInput(
                    parameters=[
                        InputParameter(
                            name="param1",
                            description="First param",
                            inferrable=False,  # Set using Inferrable(False)
                            required=True,
                            value_schema=ValueSchema(val_type="string", enum=None),
                        )
                    ]
                ),
            },
            id="func_with_non_inferrable_param",
        ),
        pytest.param(
            func_with_renamed_param,
            {
                "input": ToolInput(
                    parameters=[
                        InputParameter(
                            name="ParamOne",
                            description="First param",
                            inferrable=True,
                            required=True,
                            value_schema=ValueSchema(val_type="string", enum=None),
                        )
                    ]
                ),
            },
            id="func_with_renamed_param",
        ),
        pytest.param(
            func_with_enum_param,
            {
                "input": ToolInput(
                    parameters=[
                        InputParameter(
                            name="param1",
                            description="an enum",
                            inferrable=True,
                            required=True,
                            value_schema=ValueSchema(val_type="string", enum=["foo bar", "baz"]),
                        )
                    ]
                ),
            },
            id="func_with_enum_param",
        ),
        pytest.param(
            func_with_dict_param,
            {
                "input": ToolInput(
                    parameters=[
                        InputParameter(
                            name="param1",
                            description="a cool dictionary",
                            inferrable=True,
                            required=True,
                            value_schema=ValueSchema(val_type="json", enum=None),
                        )
                    ]
                ),
            },
            id="func_with_dict_param",
        ),
        pytest.param(
            func_with_string_literal_param,
            {
                "input": ToolInput(
                    parameters=[
                        InputParameter(
                            name="param1",
                            description="a few choices",
                            inferrable=True,
                            required=True,
                            value_schema=ValueSchema(val_type="string", enum=["value1", "value2"]),
                        )
                    ]
                ),
            },
            id="func_with_string_enum_param",
        ),
        pytest.param(
            func_with_param_with_default,
            {
                "input": ToolInput(
                    parameters=[
                        InputParameter(
                            name="param1",
                            description="First param",
                            inferrable=True,
                            required=False,  # Because a default value is provided
                            value_schema=ValueSchema(val_type="string", enum=None),
                        )
                    ]
                ),
                "output": ToolOutput(
                    available_modes=["null"], description="No description provided."
                ),
            },
            id="func_with_param_with_default",
        ),
        pytest.param(
            func_with_optional_param,
            {
                "input": ToolInput(
                    parameters=[
                        InputParameter(
                            name="param1",
                            description="First param",
                            inferrable=True,
                            required=False,  # Because of Optional[str]
                            value_schema=ValueSchema(val_type="string", enum=None),
                        )
                    ]
                ),
                "output": ToolOutput(
                    available_modes=["null"], description="No description provided."
                ),
            },
            id="func_with_optional_param",
        ),
        pytest.param(
            func_with_optional_param_with_default_None,
            {
                "input": ToolInput(
                    parameters=[
                        InputParameter(
                            name="param1",
                            description="First param",
                            inferrable=True,
                            required=False,  # Because of Optional[str]
                            value_schema=ValueSchema(val_type="string", enum=None),
                        )
                    ]
                ),
                "output": ToolOutput(
                    available_modes=["null"], description="No description provided."
                ),
            },
            id="func_with_optional_param_with_default_None",
        ),
        pytest.param(
            func_with_optional_param_with_default_value,
            {
                "input": ToolInput(
                    parameters=[
                        InputParameter(
                            name="param1",
                            description="First param",
                            inferrable=True,
                            required=False,  # Because of Optional[str] and default value
                            value_schema=ValueSchema(val_type="string", enum=None),
                        )
                    ]
                ),
                "output": ToolOutput(
                    available_modes=["null"], description="No description provided."
                ),
            },
            id="func_with_optional_param_with_default_value",
        ),
        pytest.param(
            func_with_optional_param_with_bar_syntax_1,
            {
                "input": ToolInput(
                    parameters=[
                        InputParameter(
                            name="param1",
                            description="First param",
                            inferrable=True,
                            required=False,  # Because of Optional[str]
                            value_schema=ValueSchema(val_type="string", enum=None),
                        )
                    ]
                ),
                "output": ToolOutput(
                    available_modes=["null"], description="No description provided."
                ),
            },
            # Id carries the ``_1`` suffix to match the fixture and its ``_2`` sibling.
            id="func_with_optional_param_with_bar_syntax_1",
        ),
        pytest.param(
            func_with_optional_param_with_bar_syntax_2,
            {
                "input": ToolInput(
                    parameters=[
                        InputParameter(
                            name="param1",
                            description="First param",
                            inferrable=True,
                            required=False,
                            value_schema=ValueSchema(val_type="string", enum=None),
                        )
                    ]
                ),
            },
            id="func_with_optional_param_with_bar_syntax_2",
        ),
        pytest.param(
            func_with_optional_param_with_union_syntax_1,
            {
                "input": ToolInput(
                    parameters=[
                        InputParameter(
                            name="param1",
                            description="First param",
                            inferrable=True,
                            required=False,
                            value_schema=ValueSchema(val_type="string", enum=None),
                        )
                    ]
                ),
            },
            id="func_with_optional_param_with_union_syntax_1",
        ),
        pytest.param(
            func_with_optional_param_with_union_syntax_2,
            {
                "input": ToolInput(
                    parameters=[
                        InputParameter(
                            name="param1",
                            description="First param",
                            inferrable=True,
                            required=False,
                            value_schema=ValueSchema(val_type="string", enum=None),
                        )
                    ]
                ),
            },
            id="func_with_optional_param_with_union_syntax_2",
        ),
        pytest.param(
            func_with_mixed_params,
            {
                "input": ToolInput(
                    parameters=[
                        InputParameter(
                            name="param1",
                            description="First param",
                            inferrable=True,
                            required=True,
                            value_schema=ValueSchema(val_type="string", enum=None),
                        ),
                        InputParameter(
                            name="param2",
                            description="Second param",
                            inferrable=True,
                            required=False,  # Because a default value is provided
                            value_schema=ValueSchema(val_type="integer", enum=None),
                        ),
                    ],
                    tool_context_parameter_name="context",
                ),
            },
            id="func_with_mixed_params",
        ),
        pytest.param(
            func_with_list_param,
            {
                "input": ToolInput(
                    parameters=[
                        InputParameter(
                            name="param1",
                            description="A list of strings",
                            inferrable=True,
                            required=True,
                            value_schema=ValueSchema(
                                val_type="array", inner_val_type="string", enum=None
                            ),
                        )
                    ]
                ),
            },
            id="func_with_list_param",
        ),
        pytest.param(
            func_with_list_float_param,
            {
                "input": ToolInput(
                    parameters=[
                        InputParameter(
                            name="param1",
                            description="A list of floats",
                            inferrable=True,
                            required=True,
                            value_schema=ValueSchema(
                                val_type="array", inner_val_type="number", enum=None
                            ),
                        )
                    ]
                ),
            },
            id="func_with_list_float_param",
        ),
        pytest.param(
            func_with_list_of_enums_param,
            {
                "input": ToolInput(
                    parameters=[
                        InputParameter(
                            name="param1",
                            description="A list of enums",
                            inferrable=True,
                            required=True,
                            value_schema=ValueSchema(
                                val_type="array", inner_val_type="string", enum=["foo bar", "baz"]
                            ),
                        )
                    ]
                ),
            },
            id="func_with_list_of_enums_param",
        ),
        pytest.param(
            func_with_complex_param,
            {
                "input": ToolInput(
                    parameters=[
                        InputParameter(
                            name="param1",
                            description="A dictionary with lists of integers",
                            inferrable=True,
                            required=True,
                            value_schema=ValueSchema(val_type="json", enum=None),
                        )
                    ]
                ),
            },
            id="func_with_complex_param",
        ),
        pytest.param(
            func_with_context,
            {
                "input": ToolInput(
                    parameters=[], tool_context_parameter_name="my_context"
                ),  # ToolContext type is not an input param, but it's stored in the input field
            },
            id="func_with_context",
        ),
        # Tests on output values
        pytest.param(
            func_with_list_return,
            {
                "input": ToolInput(parameters=[]),
                "output": ToolOutput(
                    value_schema=ValueSchema(val_type="array", inner_val_type="string", enum=None),
                    available_modes=["value", "error"],
                    description="No description provided.",
                ),
            },
            id="func_with_list_return",
        ),
        pytest.param(
            func_with_known_list_return,
            {
                "input": ToolInput(parameters=[]),
                "output": ToolOutput(
                    value_schema=ValueSchema(val_type="string", enum=["value1", "value2"]),
                    available_modes=["value", "error"],
                    description="No description provided.",
                ),
            },
            id="func_with_known_list_return",
        ),
        pytest.param(
            func_with_enum_return,
            {
                "input": ToolInput(parameters=[]),
                "output": ToolOutput(
                    value_schema=ValueSchema(val_type="string", enum=["foo bar", "baz"]),
                    available_modes=["value", "error"],
                    description="No description provided.",
                ),
            },
            id="func_with_enum_return",
        ),
        pytest.param(
            func_with_annotated_return,
            {
                "input": ToolInput(parameters=[]),
                "output": ToolOutput(
                    value_schema=ValueSchema(val_type="string", enum=None),
                    available_modes=["value", "error"],
                    description="Annotated return description",
                ),
            },
            id="func_with_annotated_return",
        ),
        pytest.param(
            func_with_optional_return,
            {
                "input": ToolInput(parameters=[]),
                "output": ToolOutput(
                    value_schema=ValueSchema(val_type="string", enum=None),
                    available_modes=["value", "error", "null"],
                    description="No description provided.",
                ),
            },
            id="func_with_optional_return",
        ),
        pytest.param(
            func_with_optional_return_with_bar_syntax_1,
            {
                "input": ToolInput(parameters=[]),
                "output": ToolOutput(
                    value_schema=ValueSchema(val_type="string", enum=None),
                    available_modes=["value", "error", "null"],
                    description="No description provided.",
                ),
            },
            id="func_with_optional_return_with_bar_syntax_1",
        ),
        pytest.param(
            func_with_optional_return_with_bar_syntax_2,
            {
                "input": ToolInput(parameters=[]),
                "output": ToolOutput(
                    value_schema=ValueSchema(val_type="string", enum=None),
                    available_modes=["value", "error", "null"],
                    description="No description provided.",
                ),
            },
            id="func_with_optional_return_with_bar_syntax_2",
        ),
        pytest.param(
            func_with_optional_return_with_union_syntax_1,
            {
                "input": ToolInput(parameters=[]),
                "output": ToolOutput(
                    value_schema=ValueSchema(val_type="string", enum=None),
                    available_modes=["value", "error", "null"],
                    description="No description provided.",
                ),
            },
            id="func_with_optional_return_with_union_syntax_1",
        ),
        pytest.param(
            func_with_optional_return_with_union_syntax_2,
            {
                "input": ToolInput(parameters=[]),
                "output": ToolOutput(
                    value_schema=ValueSchema(val_type="string", enum=None),
                    available_modes=["value", "error", "null"],
                    description="No description provided.",
                ),
            },
            id="func_with_optional_return_with_union_syntax_2",
        ),
        pytest.param(
            func_with_complex_return,
            {
                "input": ToolInput(parameters=[]),
                "output": ToolOutput(
                    value_schema=ValueSchema(val_type="json", enum=None),
                    available_modes=["value", "error"],
                    description="No description provided.",
                ),
            },
            id="func_with_complex_return",
        ),
    ],
)
def test_create_tool_def(func_under_test, expected_tool_def_fields):
    """Check selected fields of the definition built for one fixture tool.

    Args:
        func_under_test: A ``@tool``-decorated fixture function.
        expected_tool_def_fields: Mapping of definition attribute name to the
            value that attribute must equal.
    """
    tool_def = ToolCatalog.create_tool_definition(func_under_test, "test_toolkit", "1.0.0")

    # Only the attributes named in the mapping are compared.
    for field, expected_value in expected_tool_def_fields.items():
        assert getattr(tool_def, field) == expected_value
def test_tool_name_is_set_correctly():
    """The default tool name is the PascalCase form of the function name."""
    fixture = func_with_description
    definition = ToolCatalog.create_tool_definition(fixture, "test_toolkit", "1.0.0")

    expected_name = snake_to_pascal_case(fixture.__name__)
    assert definition.name == expected_name
    # The qualified name is prefixed with the PascalCase toolkit name.
    assert definition.fully_qualified_name == "TestToolkit.FuncWithDescription"
@pytest.mark.parametrize(
    "toolkit_name, toolkit_version, toolkit_desc",
    [
        ("test_toolkit", "1.0.0", "test_toolkit_desc"),  # Both specified
        ("test_toolkit", None, "test_toolkit_desc"),  # Version optional
        ("test_toolkit", "latest", None),  # Description optional
        ("test_toolkit", None, None),  # Both optional
    ],
)
def test_toolkit_info_is_set_correctly(toolkit_name, toolkit_version, toolkit_desc):
    """Toolkit name, description and version are carried onto the definition."""
    toolkit_info = ToolCatalog.create_tool_definition(
        func_with_description, toolkit_name, toolkit_version, toolkit_desc
    ).toolkit

    # The name is PascalCased; description and version pass through unchanged.
    pascal_name = snake_to_pascal_case(toolkit_name)
    assert toolkit_info.name == pascal_name
    assert toolkit_info.description == toolkit_desc
    assert toolkit_info.version == toolkit_version