Convert MCP schemas to strict where possible (#414)
## Summary: Towards #404. I made this configurable because it's not clear this is always a good thing to do. I also made it default to False because I'm not sure if this could cause errors. If it works out well, we can switch the default in the future as a small breaking changes ## Test Plan: Unit tests
This commit is contained in:
parent
45c25f8ab0
commit
01f5e86ea5
5 changed files with 202 additions and 24 deletions
|
|
@ -6,7 +6,7 @@ from collections.abc import Awaitable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast
|
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast
|
||||||
|
|
||||||
from typing_extensions import TypeAlias, TypedDict
|
from typing_extensions import NotRequired, TypeAlias, TypedDict
|
||||||
|
|
||||||
from .guardrail import InputGuardrail, OutputGuardrail
|
from .guardrail import InputGuardrail, OutputGuardrail
|
||||||
from .handoffs import Handoff
|
from .handoffs import Handoff
|
||||||
|
|
@ -53,6 +53,15 @@ class StopAtTools(TypedDict):
|
||||||
"""A list of tool names, any of which will stop the agent from running further."""
|
"""A list of tool names, any of which will stop the agent from running further."""
|
||||||
|
|
||||||
|
|
||||||
|
class MCPConfig(TypedDict):
|
||||||
|
"""Configuration for MCP servers."""
|
||||||
|
|
||||||
|
convert_schemas_to_strict: NotRequired[bool]
|
||||||
|
"""If True, we will attempt to convert the MCP schemas to strict-mode schemas. This is a
|
||||||
|
best-effort conversion, so some schemas may not be convertible. Defaults to False.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Agent(Generic[TContext]):
|
class Agent(Generic[TContext]):
|
||||||
"""An agent is an AI model configured with instructions, tools, guardrails, handoffs and more.
|
"""An agent is an AI model configured with instructions, tools, guardrails, handoffs and more.
|
||||||
|
|
@ -119,6 +128,9 @@ class Agent(Generic[TContext]):
|
||||||
longer needed.
|
longer needed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
mcp_config: MCPConfig = field(default_factory=lambda: MCPConfig())
|
||||||
|
"""Configuration for MCP servers."""
|
||||||
|
|
||||||
input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list)
|
input_guardrails: list[InputGuardrail[TContext]] = field(default_factory=list)
|
||||||
"""A list of checks that run in parallel to the agent's execution, before generating a
|
"""A list of checks that run in parallel to the agent's execution, before generating a
|
||||||
response. Runs only if the agent is the first agent in the chain.
|
response. Runs only if the agent is the first agent in the chain.
|
||||||
|
|
@ -224,7 +236,8 @@ class Agent(Generic[TContext]):
|
||||||
|
|
||||||
async def get_mcp_tools(self) -> list[Tool]:
|
async def get_mcp_tools(self) -> list[Tool]:
|
||||||
"""Fetches the available tools from the MCP servers."""
|
"""Fetches the available tools from the MCP servers."""
|
||||||
return await MCPUtil.get_all_function_tools(self.mcp_servers)
|
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
|
||||||
|
return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict)
|
||||||
|
|
||||||
async def get_all_tools(self) -> list[Tool]:
|
async def get_all_tools(self) -> list[Tool]:
|
||||||
"""All agent tools, including MCP tools and function tools."""
|
"""All agent tools, including MCP tools and function tools."""
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,8 @@ import functools
|
||||||
import json
|
import json
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
|
from agents.strict_schema import ensure_strict_json_schema
|
||||||
|
|
||||||
from .. import _debug
|
from .. import _debug
|
||||||
from ..exceptions import AgentsException, ModelBehaviorError, UserError
|
from ..exceptions import AgentsException, ModelBehaviorError, UserError
|
||||||
from ..logger import logger
|
from ..logger import logger
|
||||||
|
|
@ -19,12 +21,14 @@ class MCPUtil:
|
||||||
"""Set of utilities for interop between MCP and Agents SDK tools."""
|
"""Set of utilities for interop between MCP and Agents SDK tools."""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_all_function_tools(cls, servers: list["MCPServer"]) -> list[Tool]:
|
async def get_all_function_tools(
|
||||||
|
cls, servers: list["MCPServer"], convert_schemas_to_strict: bool
|
||||||
|
) -> list[Tool]:
|
||||||
"""Get all function tools from a list of MCP servers."""
|
"""Get all function tools from a list of MCP servers."""
|
||||||
tools = []
|
tools = []
|
||||||
tool_names: set[str] = set()
|
tool_names: set[str] = set()
|
||||||
for server in servers:
|
for server in servers:
|
||||||
server_tools = await cls.get_function_tools(server)
|
server_tools = await cls.get_function_tools(server, convert_schemas_to_strict)
|
||||||
server_tool_names = {tool.name for tool in server_tools}
|
server_tool_names = {tool.name for tool in server_tools}
|
||||||
if len(server_tool_names & tool_names) > 0:
|
if len(server_tool_names & tool_names) > 0:
|
||||||
raise UserError(
|
raise UserError(
|
||||||
|
|
@ -37,25 +41,37 @@ class MCPUtil:
|
||||||
return tools
|
return tools
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_function_tools(cls, server: "MCPServer") -> list[Tool]:
|
async def get_function_tools(
|
||||||
|
cls, server: "MCPServer", convert_schemas_to_strict: bool
|
||||||
|
) -> list[Tool]:
|
||||||
"""Get all function tools from a single MCP server."""
|
"""Get all function tools from a single MCP server."""
|
||||||
|
|
||||||
with mcp_tools_span(server=server.name) as span:
|
with mcp_tools_span(server=server.name) as span:
|
||||||
tools = await server.list_tools()
|
tools = await server.list_tools()
|
||||||
span.span_data.result = [tool.name for tool in tools]
|
span.span_data.result = [tool.name for tool in tools]
|
||||||
|
|
||||||
return [cls.to_function_tool(tool, server) for tool in tools]
|
return [cls.to_function_tool(tool, server, convert_schemas_to_strict) for tool in tools]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def to_function_tool(cls, tool: "MCPTool", server: "MCPServer") -> FunctionTool:
|
def to_function_tool(
|
||||||
|
cls, tool: "MCPTool", server: "MCPServer", convert_schemas_to_strict: bool
|
||||||
|
) -> FunctionTool:
|
||||||
"""Convert an MCP tool to an Agents SDK function tool."""
|
"""Convert an MCP tool to an Agents SDK function tool."""
|
||||||
invoke_func = functools.partial(cls.invoke_mcp_tool, server, tool)
|
invoke_func = functools.partial(cls.invoke_mcp_tool, server, tool)
|
||||||
|
schema, is_strict = tool.inputSchema, False
|
||||||
|
if convert_schemas_to_strict:
|
||||||
|
try:
|
||||||
|
schema = ensure_strict_json_schema(schema)
|
||||||
|
is_strict = True
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(f"Error converting MCP schema to strict mode: {e}")
|
||||||
|
|
||||||
return FunctionTool(
|
return FunctionTool(
|
||||||
name=tool.name,
|
name=tool.name,
|
||||||
description=tool.description or "",
|
description=tool.description or "",
|
||||||
params_json_schema=tool.inputSchema,
|
params_json_schema=schema,
|
||||||
on_invoke_tool=invoke_func,
|
on_invoke_tool=invoke_func,
|
||||||
strict_json_schema=False,
|
strict_json_schema=is_strict,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
|
||||||
|
|
@ -2,10 +2,11 @@ import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from inline_snapshot import snapshot
|
||||||
from mcp.types import Tool as MCPTool
|
from mcp.types import Tool as MCPTool
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, TypeAdapter
|
||||||
|
|
||||||
from agents import FunctionTool, RunContextWrapper
|
from agents import Agent, FunctionTool, RunContextWrapper
|
||||||
from agents.exceptions import AgentsException, ModelBehaviorError
|
from agents.exceptions import AgentsException, ModelBehaviorError
|
||||||
from agents.mcp import MCPServer, MCPUtil
|
from agents.mcp import MCPServer, MCPUtil
|
||||||
|
|
||||||
|
|
@ -18,7 +19,16 @@ class Foo(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class Bar(BaseModel):
|
class Bar(BaseModel):
|
||||||
qux: str
|
qux: dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
Baz = TypeAdapter(dict[str, str])
|
||||||
|
|
||||||
|
|
||||||
|
def _convertible_schema() -> dict[str, Any]:
|
||||||
|
schema = Foo.model_json_schema()
|
||||||
|
schema["additionalProperties"] = False
|
||||||
|
return schema
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
@ -47,7 +57,7 @@ async def test_get_all_function_tools():
|
||||||
server3.add_tool(names[4], schemas[4])
|
server3.add_tool(names[4], schemas[4])
|
||||||
|
|
||||||
servers: list[MCPServer] = [server1, server2, server3]
|
servers: list[MCPServer] = [server1, server2, server3]
|
||||||
tools = await MCPUtil.get_all_function_tools(servers)
|
tools = await MCPUtil.get_all_function_tools(servers, convert_schemas_to_strict=False)
|
||||||
assert len(tools) == 5
|
assert len(tools) == 5
|
||||||
assert all(tool.name in names for tool in tools)
|
assert all(tool.name in names for tool in tools)
|
||||||
|
|
||||||
|
|
@ -56,6 +66,11 @@ async def test_get_all_function_tools():
|
||||||
assert tool.params_json_schema == schemas[idx]
|
assert tool.params_json_schema == schemas[idx]
|
||||||
assert tool.name == names[idx]
|
assert tool.name == names[idx]
|
||||||
|
|
||||||
|
# Also make sure it works with strict schemas
|
||||||
|
tools = await MCPUtil.get_all_function_tools(servers, convert_schemas_to_strict=True)
|
||||||
|
assert len(tools) == 5
|
||||||
|
assert all(tool.name in names for tool in tools)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_invoke_mcp_tool():
|
async def test_invoke_mcp_tool():
|
||||||
|
|
@ -107,3 +122,141 @@ async def test_mcp_invocation_crash_causes_error(caplog: pytest.LogCaptureFixtur
|
||||||
await MCPUtil.invoke_mcp_tool(server, tool, ctx, "")
|
await MCPUtil.invoke_mcp_tool(server, tool, ctx, "")
|
||||||
|
|
||||||
assert "Error invoking MCP tool test_tool_1" in caplog.text
|
assert "Error invoking MCP tool test_tool_1" in caplog.text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_agent_convert_schemas_true():
|
||||||
|
"""Test that setting convert_schemas_to_strict to True converts non-strict schemas to strict.
|
||||||
|
- 'foo' tool is already strict and remains strict.
|
||||||
|
- 'bar' tool is non-strict and becomes strict (additionalProperties set to False, etc).
|
||||||
|
"""
|
||||||
|
strict_schema = Foo.model_json_schema()
|
||||||
|
non_strict_schema = Baz.json_schema()
|
||||||
|
possible_to_convert_schema = _convertible_schema()
|
||||||
|
|
||||||
|
server = FakeMCPServer()
|
||||||
|
server.add_tool("foo", strict_schema)
|
||||||
|
server.add_tool("bar", non_strict_schema)
|
||||||
|
server.add_tool("baz", possible_to_convert_schema)
|
||||||
|
agent = Agent(
|
||||||
|
name="test_agent", mcp_servers=[server], mcp_config={"convert_schemas_to_strict": True}
|
||||||
|
)
|
||||||
|
tools = await agent.get_mcp_tools()
|
||||||
|
|
||||||
|
foo_tool = next(tool for tool in tools if tool.name == "foo")
|
||||||
|
assert isinstance(foo_tool, FunctionTool)
|
||||||
|
bar_tool = next(tool for tool in tools if tool.name == "bar")
|
||||||
|
assert isinstance(bar_tool, FunctionTool)
|
||||||
|
baz_tool = next(tool for tool in tools if tool.name == "baz")
|
||||||
|
assert isinstance(baz_tool, FunctionTool)
|
||||||
|
|
||||||
|
# Checks that additionalProperties is set to False
|
||||||
|
assert foo_tool.params_json_schema == snapshot(
|
||||||
|
{
|
||||||
|
"properties": {
|
||||||
|
"bar": {"title": "Bar", "type": "string"},
|
||||||
|
"baz": {"title": "Baz", "type": "integer"},
|
||||||
|
},
|
||||||
|
"required": ["bar", "baz"],
|
||||||
|
"title": "Foo",
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert foo_tool.strict_json_schema is True, "foo_tool should be strict"
|
||||||
|
|
||||||
|
# Checks that additionalProperties is set to False
|
||||||
|
assert bar_tool.params_json_schema == snapshot(
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": {"type": "string"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert bar_tool.strict_json_schema is False, "bar_tool should not be strict"
|
||||||
|
|
||||||
|
# Checks that additionalProperties is set to False
|
||||||
|
assert baz_tool.params_json_schema == snapshot(
|
||||||
|
{
|
||||||
|
"properties": {
|
||||||
|
"bar": {"title": "Bar", "type": "string"},
|
||||||
|
"baz": {"title": "Baz", "type": "integer"},
|
||||||
|
},
|
||||||
|
"required": ["bar", "baz"],
|
||||||
|
"title": "Foo",
|
||||||
|
"type": "object",
|
||||||
|
"additionalProperties": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert baz_tool.strict_json_schema is True, "baz_tool should be strict"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_agent_convert_schemas_false():
|
||||||
|
"""Test that setting convert_schemas_to_strict to False leaves tool schemas as non-strict.
|
||||||
|
- 'foo' tool remains strict.
|
||||||
|
- 'bar' tool remains non-strict (additionalProperties remains True).
|
||||||
|
"""
|
||||||
|
strict_schema = Foo.model_json_schema()
|
||||||
|
non_strict_schema = Baz.json_schema()
|
||||||
|
possible_to_convert_schema = _convertible_schema()
|
||||||
|
|
||||||
|
server = FakeMCPServer()
|
||||||
|
server.add_tool("foo", strict_schema)
|
||||||
|
server.add_tool("bar", non_strict_schema)
|
||||||
|
server.add_tool("baz", possible_to_convert_schema)
|
||||||
|
|
||||||
|
agent = Agent(
|
||||||
|
name="test_agent", mcp_servers=[server], mcp_config={"convert_schemas_to_strict": False}
|
||||||
|
)
|
||||||
|
tools = await agent.get_mcp_tools()
|
||||||
|
|
||||||
|
foo_tool = next(tool for tool in tools if tool.name == "foo")
|
||||||
|
assert isinstance(foo_tool, FunctionTool)
|
||||||
|
bar_tool = next(tool for tool in tools if tool.name == "bar")
|
||||||
|
assert isinstance(bar_tool, FunctionTool)
|
||||||
|
baz_tool = next(tool for tool in tools if tool.name == "baz")
|
||||||
|
assert isinstance(baz_tool, FunctionTool)
|
||||||
|
|
||||||
|
assert foo_tool.params_json_schema == strict_schema
|
||||||
|
assert foo_tool.strict_json_schema is False, "Shouldn't be converted unless specified"
|
||||||
|
|
||||||
|
assert bar_tool.params_json_schema == non_strict_schema
|
||||||
|
assert bar_tool.strict_json_schema is False
|
||||||
|
|
||||||
|
assert baz_tool.params_json_schema == possible_to_convert_schema
|
||||||
|
assert baz_tool.strict_json_schema is False, "Shouldn't be converted unless specified"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_agent_convert_schemas_unset():
|
||||||
|
"""Test that leaving convert_schemas_to_strict unset (defaulting to False) leaves tool schemas
|
||||||
|
as non-strict.
|
||||||
|
- 'foo' tool remains strict.
|
||||||
|
- 'bar' tool remains non-strict.
|
||||||
|
"""
|
||||||
|
strict_schema = Foo.model_json_schema()
|
||||||
|
non_strict_schema = Baz.json_schema()
|
||||||
|
possible_to_convert_schema = _convertible_schema()
|
||||||
|
|
||||||
|
server = FakeMCPServer()
|
||||||
|
server.add_tool("foo", strict_schema)
|
||||||
|
server.add_tool("bar", non_strict_schema)
|
||||||
|
server.add_tool("baz", possible_to_convert_schema)
|
||||||
|
agent = Agent(name="test_agent", mcp_servers=[server])
|
||||||
|
tools = await agent.get_mcp_tools()
|
||||||
|
|
||||||
|
foo_tool = next(tool for tool in tools if tool.name == "foo")
|
||||||
|
assert isinstance(foo_tool, FunctionTool)
|
||||||
|
bar_tool = next(tool for tool in tools if tool.name == "bar")
|
||||||
|
assert isinstance(bar_tool, FunctionTool)
|
||||||
|
baz_tool = next(tool for tool in tools if tool.name == "baz")
|
||||||
|
assert isinstance(baz_tool, FunctionTool)
|
||||||
|
|
||||||
|
assert foo_tool.params_json_schema == strict_schema
|
||||||
|
assert foo_tool.strict_json_schema is False, "Shouldn't be converted unless specified"
|
||||||
|
|
||||||
|
assert bar_tool.params_json_schema == non_strict_schema
|
||||||
|
assert bar_tool.strict_json_schema is False
|
||||||
|
|
||||||
|
assert baz_tool.params_json_schema == possible_to_convert_schema
|
||||||
|
assert baz_tool.strict_json_schema is False, "Shouldn't be converted unless specified"
|
||||||
|
|
|
||||||
|
|
@ -642,9 +642,7 @@ async def test_tool_use_behavior_custom_function():
|
||||||
async def test_model_settings_override():
|
async def test_model_settings_override():
|
||||||
model = FakeModel()
|
model = FakeModel()
|
||||||
agent = Agent(
|
agent = Agent(
|
||||||
name="test",
|
name="test", model=model, model_settings=ModelSettings(temperature=1.0, max_tokens=1000)
|
||||||
model=model,
|
|
||||||
model_settings=ModelSettings(temperature=1.0, max_tokens=1000)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
model.add_multiple_turn_outputs(
|
model.add_multiple_turn_outputs(
|
||||||
|
|
|
||||||
|
|
@ -244,9 +244,10 @@ async def test_multiple_handoff_doesnt_error():
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{"type": "generation"},
|
{"type": "generation"},
|
||||||
{"type": "handoff",
|
{
|
||||||
"data": {"from_agent": "test", "to_agent": "test"},
|
"type": "handoff",
|
||||||
"error": {
|
"data": {"from_agent": "test", "to_agent": "test"},
|
||||||
|
"error": {
|
||||||
"data": {
|
"data": {
|
||||||
"requested_agents": [
|
"requested_agents": [
|
||||||
"test",
|
"test",
|
||||||
|
|
@ -255,7 +256,7 @@ async def test_multiple_handoff_doesnt_error():
|
||||||
},
|
},
|
||||||
"message": "Multiple handoffs requested",
|
"message": "Multiple handoffs requested",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
|
@ -383,10 +384,7 @@ async def test_handoffs_lead_to_correct_agent_spans():
|
||||||
{"type": "generation"},
|
{"type": "generation"},
|
||||||
{
|
{
|
||||||
"type": "handoff",
|
"type": "handoff",
|
||||||
"data": {
|
"data": {"from_agent": "test_agent_3", "to_agent": "test_agent_1"},
|
||||||
"from_agent": "test_agent_3",
|
|
||||||
"to_agent": "test_agent_1"
|
|
||||||
},
|
|
||||||
"error": {
|
"error": {
|
||||||
"data": {
|
"data": {
|
||||||
"requested_agents": [
|
"requested_agents": [
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue