Convert MCP schemas to strict where possible (#414)
## Summary: Towards #404. I made this configurable because it's not clear that this is always a good thing to do. I also made it default to False because I'm not sure whether it could cause errors. If it works out well, we can switch the default in the future as a small breaking change. ## Test Plan: Unit tests
This commit is contained in:
parent
45c25f8ab0
commit
01f5e86ea5
5 changed files with 202 additions and 24 deletions
|
|
@ -6,7 +6,7 @@ from collections.abc import Awaitable
|
|||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast
|
||||
|
||||
from typing_extensions import TypeAlias, TypedDict
|
||||
from typing_extensions import NotRequired, TypeAlias, TypedDict
|
||||
|
||||
from .guardrail import InputGuardrail, OutputGuardrail
|
||||
from .handoffs import Handoff
|
||||
|
|
@ -53,6 +53,15 @@ class StopAtTools(TypedDict):
|
|||
"""A list of tool names, any of which will stop the agent from running further."""
|
||||
|
||||
|
||||
class MCPConfig(TypedDict):
|
||||
"""Configuration for MCP servers."""
|
||||
|
||||
convert_schemas_to_strict: NotRequired[bool]
|
||||
"""If True, we will attempt to convert the MCP schemas to strict-mode schemas. This is a
|
||||
best-effort conversion, so some schemas may not be convertible. Defaults to False.
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class Agent(Generic[TContext]):
|
||||
"""An agent is an AI model configured with instructions, tools, guardrails, handoffs and more.
|
||||
|
|
@ -119,6 +128,9 @@ class Agent(Generic[TContext]):
|
|||
longer needed.
|
||||
"""
|
||||
|
||||
mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig())
|
||||
"""Configuration for MCP servers."""
|
||||
|
||||
input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list)
|
||||
"""A list of checks that run in parallel to the agent's execution, before generating a
|
||||
response. Runs only if the agent is the first agent in the chain.
|
||||
|
|
@ -224,7 +236,8 @@ class Agent(Generic[TContext]):
|
|||
|
||||
async def get_mcp_tools(self) -> list[Tool]:
|
||||
"""Fetches the available tools from the MCP servers."""
|
||||
return await MCPUtil.get_all_function_tools(self.mcp_servers)
|
||||
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
|
||||
return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict)
|
||||
|
||||
async def get_all_tools(self) -> list[Tool]:
|
||||
"""All agent tools, including MCP tools and function tools."""
|
||||
|
|
|
|||
|
|
@ -2,6 +2,8 @@ import functools
|
|||
import json
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from agents.strict_schema import ensure_strict_json_schema
|
||||
|
||||
from .. import _debug
|
||||
from ..exceptions import AgentsException, ModelBehaviorError, UserError
|
||||
from ..logger import logger
|
||||
|
|
@ -19,12 +21,14 @@ class MCPUtil:
|
|||
"""Set of utilities for interop between MCP and Agents SDK tools."""
|
||||
|
||||
@classmethod
|
||||
async def get_all_function_tools(cls, servers: list["MCPServer"]) -> list[Tool]:
|
||||
async def get_all_function_tools(
|
||||
cls, servers: list["MCPServer"], convert_schemas_to_strict: bool
|
||||
) -> list[Tool]:
|
||||
"""Get all function tools from a list of MCP servers."""
|
||||
tools = []
|
||||
tool_names: set[str] = set()
|
||||
for server in servers:
|
||||
server_tools = await cls.get_function_tools(server)
|
||||
server_tools = await cls.get_function_tools(server, convert_schemas_to_strict)
|
||||
server_tool_names = {tool.name for tool in server_tools}
|
||||
if len(server_tool_names & tool_names) > 0:
|
||||
raise UserError(
|
||||
|
|
@ -37,25 +41,37 @@ class MCPUtil:
|
|||
return tools
|
||||
|
||||
@classmethod
|
||||
async def get_function_tools(cls, server: "MCPServer") -> list[Tool]:
|
||||
async def get_function_tools(
|
||||
cls, server: "MCPServer", convert_schemas_to_strict: bool
|
||||
) -> list[Tool]:
|
||||
"""Get all function tools from a single MCP server."""
|
||||
|
||||
with mcp_tools_span(server=server.name) as span:
|
||||
tools = await server.list_tools()
|
||||
span.span_data.result = [tool.name for tool in tools]
|
||||
|
||||
return [cls.to_function_tool(tool, server) for tool in tools]
|
||||
return [cls.to_function_tool(tool, server, convert_schemas_to_strict) for tool in tools]
|
||||
|
||||
@classmethod
|
||||
def to_function_tool(cls, tool: "MCPTool", server: "MCPServer") -> FunctionTool:
|
||||
def to_function_tool(
|
||||
cls, tool: "MCPTool", server: "MCPServer", convert_schemas_to_strict: bool
|
||||
) -> FunctionTool:
|
||||
"""Convert an MCP tool to an Agents SDK function tool."""
|
||||
invoke_func = functools.partial(cls.invoke_mcp_tool, server, tool)
|
||||
schema, is_strict = tool.inputSchema, False
|
||||
if convert_schemas_to_strict:
|
||||
try:
|
||||
schema = ensure_strict_json_schema(schema)
|
||||
is_strict = True
|
||||
except Exception as e:
|
||||
logger.info(f"Error converting MCP schema to strict mode: {e}")
|
||||
|
||||
return FunctionTool(
|
||||
name=tool.name,
|
||||
description=tool.description or "",
|
||||
params_json_schema=tool.inputSchema,
|
||||
params_json_schema=schema,
|
||||
on_invoke_tool=invoke_func,
|
||||
strict_json_schema=False,
|
||||
strict_json_schema=is_strict,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -2,10 +2,11 @@ import logging
|
|||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from inline_snapshot import snapshot
|
||||
from mcp.types import Tool as MCPTool
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from agents import FunctionTool, RunContextWrapper
|
||||
from agents import Agent, FunctionTool, RunContextWrapper
|
||||
from agents.exceptions import AgentsException, ModelBehaviorError
|
||||
from agents.mcp import MCPServer, MCPUtil
|
||||
|
||||
|
|
@ -18,7 +19,16 @@ class Foo(BaseModel):
|
|||
|
||||
|
||||
class Bar(BaseModel):
|
||||
qux: str
|
||||
qux: dict[str, str]
|
||||
|
||||
|
||||
Baz = TypeAdapter(dict[str, str])
|
||||
|
||||
|
||||
def _convertible_schema() -> dict[str, Any]:
|
||||
schema = Foo.model_json_schema()
|
||||
schema["additionalProperties"] = False
|
||||
return schema
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
@ -47,7 +57,7 @@ async def test_get_all_function_tools():
|
|||
server3.add_tool(names[4], schemas[4])
|
||||
|
||||
servers: list[MCPServer] = [server1, server2, server3]
|
||||
tools = await MCPUtil.get_all_function_tools(servers)
|
||||
tools = await MCPUtil.get_all_function_tools(servers, convert_schemas_to_strict=False)
|
||||
assert len(tools) == 5
|
||||
assert all(tool.name in names for tool in tools)
|
||||
|
||||
|
|
@ -56,6 +66,11 @@ async def test_get_all_function_tools():
|
|||
assert tool.params_json_schema == schemas[idx]
|
||||
assert tool.name == names[idx]
|
||||
|
||||
# Also make sure it works with strict schemas
|
||||
tools = await MCPUtil.get_all_function_tools(servers, convert_schemas_to_strict=True)
|
||||
assert len(tools) == 5
|
||||
assert all(tool.name in names for tool in tools)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_mcp_tool():
|
||||
|
|
@ -107,3 +122,141 @@ async def test_mcp_invocation_crash_causes_error(caplog: pytest.LogCaptureFixtur
|
|||
await MCPUtil.invoke_mcp_tool(server, tool, ctx, "")
|
||||
|
||||
assert "Error invoking MCP tool test_tool_1" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_convert_schemas_true():
|
||||
"""Test that setting convert_schemas_to_strict to True converts non-strict schemas to strict.
|
||||
- 'foo' tool's schema is convertible, so it gets additionalProperties=False and is marked strict.
|
||||
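- 'bar' tool's schema cannot be converted (additionalProperties is a typed map), so it is left unchanged and stays non-strict.
- 'baz' tool's schema already sets additionalProperties to False and is marked strict.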
|
||||
"""
|
||||
strict_schema = Foo.model_json_schema()
|
||||
non_strict_schema = Baz.json_schema()
|
||||
possible_to_convert_schema = _convertible_schema()
|
||||
|
||||
server = FakeMCPServer()
|
||||
server.add_tool("foo", strict_schema)
|
||||
server.add_tool("bar", non_strict_schema)
|
||||
server.add_tool("baz", possible_to_convert_schema)
|
||||
agent = Agent(
|
||||
name="test_agent", mcp_servers=[server], mcp_config={"convert_schemas_to_strict": True}
|
||||
)
|
||||
tools = await agent.get_mcp_tools()
|
||||
|
||||
foo_tool = next(tool for tool in tools if tool.name == "foo")
|
||||
assert isinstance(foo_tool, FunctionTool)
|
||||
bar_tool = next(tool for tool in tools if tool.name == "bar")
|
||||
assert isinstance(bar_tool, FunctionTool)
|
||||
baz_tool = next(tool for tool in tools if tool.name == "baz")
|
||||
assert isinstance(baz_tool, FunctionTool)
|
||||
|
||||
# Checks that additionalProperties is set to False
|
||||
assert foo_tool.params_json_schema == snapshot(
|
||||
{
|
||||
"properties": {
|
||||
"bar": {"title": "Bar", "type": "string"},
|
||||
"baz": {"title": "Baz", "type": "integer"},
|
||||
},
|
||||
"required": ["bar", "baz"],
|
||||
"title": "Foo",
|
||||
"type": "object",
|
||||
"additionalProperties": False,
|
||||
}
|
||||
)
|
||||
assert foo_tool.strict_json_schema is True, "foo_tool should be strict"
|
||||
|
||||
# Checks that the schema is left unchanged, since it cannot be converted to strict mode
|
||||
assert bar_tool.params_json_schema == snapshot(
|
||||
{
|
||||
"type": "object",
|
||||
"additionalProperties": {"type": "string"},
|
||||
}
|
||||
)
|
||||
assert bar_tool.strict_json_schema is False, "bar_tool should not be strict"
|
||||
|
||||
# Checks that additionalProperties is set to False
|
||||
assert baz_tool.params_json_schema == snapshot(
|
||||
{
|
||||
"properties": {
|
||||
"bar": {"title": "Bar", "type": "string"},
|
||||
"baz": {"title": "Baz", "type": "integer"},
|
||||
},
|
||||
"required": ["bar", "baz"],
|
||||
"title": "Foo",
|
||||
"type": "object",
|
||||
"additionalProperties": False,
|
||||
}
|
||||
)
|
||||
assert baz_tool.strict_json_schema is True, "baz_tool should be strict"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_convert_schemas_false():
|
||||
"""Test that setting convert_schemas_to_strict to False leaves tool schemas as non-strict.
|
||||
- 'foo' tool keeps its original schema and is not marked strict.
|
||||
- 'bar' tool remains non-strict (its additionalProperties map is left as-is).
|
||||
"""
|
||||
strict_schema = Foo.model_json_schema()
|
||||
non_strict_schema = Baz.json_schema()
|
||||
possible_to_convert_schema = _convertible_schema()
|
||||
|
||||
server = FakeMCPServer()
|
||||
server.add_tool("foo", strict_schema)
|
||||
server.add_tool("bar", non_strict_schema)
|
||||
server.add_tool("baz", possible_to_convert_schema)
|
||||
|
||||
agent = Agent(
|
||||
name="test_agent", mcp_servers=[server], mcp_config={"convert_schemas_to_strict": False}
|
||||
)
|
||||
tools = await agent.get_mcp_tools()
|
||||
|
||||
foo_tool = next(tool for tool in tools if tool.name == "foo")
|
||||
assert isinstance(foo_tool, FunctionTool)
|
||||
bar_tool = next(tool for tool in tools if tool.name == "bar")
|
||||
assert isinstance(bar_tool, FunctionTool)
|
||||
baz_tool = next(tool for tool in tools if tool.name == "baz")
|
||||
assert isinstance(baz_tool, FunctionTool)
|
||||
|
||||
assert foo_tool.params_json_schema == strict_schema
|
||||
assert foo_tool.strict_json_schema is False, "Shouldn't be converted unless specified"
|
||||
|
||||
assert bar_tool.params_json_schema == non_strict_schema
|
||||
assert bar_tool.strict_json_schema is False
|
||||
|
||||
assert baz_tool.params_json_schema == possible_to_convert_schema
|
||||
assert baz_tool.strict_json_schema is False, "Shouldn't be converted unless specified"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_convert_schemas_unset():
|
||||
"""Test that leaving convert_schemas_to_strict unset (defaulting to False) leaves tool schemas
|
||||
as non-strict.
|
||||
- 'foo' tool keeps its original schema and is not marked strict.
|
||||
- 'bar' tool remains non-strict.
|
||||
"""
|
||||
strict_schema = Foo.model_json_schema()
|
||||
non_strict_schema = Baz.json_schema()
|
||||
possible_to_convert_schema = _convertible_schema()
|
||||
|
||||
server = FakeMCPServer()
|
||||
server.add_tool("foo", strict_schema)
|
||||
server.add_tool("bar", non_strict_schema)
|
||||
server.add_tool("baz", possible_to_convert_schema)
|
||||
agent = Agent(name="test_agent", mcp_servers=[server])
|
||||
tools = await agent.get_mcp_tools()
|
||||
|
||||
foo_tool = next(tool for tool in tools if tool.name == "foo")
|
||||
assert isinstance(foo_tool, FunctionTool)
|
||||
bar_tool = next(tool for tool in tools if tool.name == "bar")
|
||||
assert isinstance(bar_tool, FunctionTool)
|
||||
baz_tool = next(tool for tool in tools if tool.name == "baz")
|
||||
assert isinstance(baz_tool, FunctionTool)
|
||||
|
||||
assert foo_tool.params_json_schema == strict_schema
|
||||
assert foo_tool.strict_json_schema is False, "Shouldn't be converted unless specified"
|
||||
|
||||
assert bar_tool.params_json_schema == non_strict_schema
|
||||
assert bar_tool.strict_json_schema is False
|
||||
|
||||
assert baz_tool.params_json_schema == possible_to_convert_schema
|
||||
assert baz_tool.strict_json_schema is False, "Shouldn't be converted unless specified"
|
||||
|
|
|
|||
|
|
@ -642,9 +642,7 @@ async def test_tool_use_behavior_custom_function():
|
|||
async def test_model_settings_override():
|
||||
model = FakeModel()
|
||||
agent = Agent(
|
||||
name="test",
|
||||
model=model,
|
||||
model_settings=ModelSettings(temperature=1.0, max_tokens=1000)
|
||||
name="test", model=model, model_settings=ModelSettings(temperature=1.0, max_tokens=1000)
|
||||
)
|
||||
|
||||
model.add_multiple_turn_outputs(
|
||||
|
|
|
|||
|
|
@ -244,9 +244,10 @@ async def test_multiple_handoff_doesnt_error():
|
|||
},
|
||||
},
|
||||
{"type": "generation"},
|
||||
{"type": "handoff",
|
||||
"data": {"from_agent": "test", "to_agent": "test"},
|
||||
"error": {
|
||||
{
|
||||
"type": "handoff",
|
||||
"data": {"from_agent": "test", "to_agent": "test"},
|
||||
"error": {
|
||||
"data": {
|
||||
"requested_agents": [
|
||||
"test",
|
||||
|
|
@ -255,7 +256,7 @@ async def test_multiple_handoff_doesnt_error():
|
|||
},
|
||||
"message": "Multiple handoffs requested",
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
|
|
@ -383,10 +384,7 @@ async def test_handoffs_lead_to_correct_agent_spans():
|
|||
{"type": "generation"},
|
||||
{
|
||||
"type": "handoff",
|
||||
"data": {
|
||||
"from_agent": "test_agent_3",
|
||||
"to_agent": "test_agent_1"
|
||||
},
|
||||
"data": {"from_agent": "test_agent_3", "to_agent": "test_agent_1"},
|
||||
"error": {
|
||||
"data": {
|
||||
"requested_agents": [
|
||||
|
|
|
|||
Loading…
Reference in a new issue