Add is_enabled to FunctionTool (#808)
### Summary

Allows a user to write `function_tool(is_enabled=<some_callable>)`; the callable is called with the run context and the agent when the agent runs. This lets you dynamically enable or disable a tool based on the context/environment.

The meta-goal is to allow `Agent` to be effectively immutable. That enables some nice things down the line, and `is_enabled` lets you dynamically change the tools list without mutating the agent.

### Test Plan

Unit tests.
This commit is contained in:
parent
995af4d83e
commit
4046fcb3fa
6 changed files with 102 additions and 24 deletions
|
|
@ -1,5 +1,6 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import inspect
|
import inspect
|
||||||
from collections.abc import Awaitable
|
from collections.abc import Awaitable
|
||||||
|
|
@ -17,7 +18,7 @@ from .mcp import MCPUtil
|
||||||
from .model_settings import ModelSettings
|
from .model_settings import ModelSettings
|
||||||
from .models.interface import Model
|
from .models.interface import Model
|
||||||
from .run_context import RunContextWrapper, TContext
|
from .run_context import RunContextWrapper, TContext
|
||||||
from .tool import FunctionToolResult, Tool, function_tool
|
from .tool import FunctionTool, FunctionToolResult, Tool, function_tool
|
||||||
from .util import _transforms
|
from .util import _transforms
|
||||||
from .util._types import MaybeAwaitable
|
from .util._types import MaybeAwaitable
|
||||||
|
|
||||||
|
|
@ -246,7 +247,22 @@ class Agent(Generic[TContext]):
|
||||||
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
|
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
|
||||||
return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict)
|
return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict)
|
||||||
|
|
||||||
async def get_all_tools(self) -> list[Tool]:
|
async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
|
||||||
"""All agent tools, including MCP tools and function tools."""
|
"""All agent tools, including MCP tools and function tools."""
|
||||||
mcp_tools = await self.get_mcp_tools()
|
mcp_tools = await self.get_mcp_tools()
|
||||||
return mcp_tools + self.tools
|
|
||||||
|
async def _check_tool_enabled(tool: Tool) -> bool:
|
||||||
|
if not isinstance(tool, FunctionTool):
|
||||||
|
return True
|
||||||
|
|
||||||
|
attr = tool.is_enabled
|
||||||
|
if isinstance(attr, bool):
|
||||||
|
return attr
|
||||||
|
res = attr(run_context, self)
|
||||||
|
if inspect.isawaitable(res):
|
||||||
|
return bool(await res)
|
||||||
|
return bool(res)
|
||||||
|
|
||||||
|
results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
|
||||||
|
enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
|
||||||
|
return [*mcp_tools, *enabled]
|
||||||
|
|
|
||||||
|
|
@ -181,7 +181,7 @@ class Runner:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
all_tools = await cls._get_all_tools(current_agent)
|
all_tools = await cls._get_all_tools(current_agent, context_wrapper)
|
||||||
|
|
||||||
# Start an agent span if we don't have one. This span is ended if the current
|
# Start an agent span if we don't have one. This span is ended if the current
|
||||||
# agent changes, or if the agent loop ends.
|
# agent changes, or if the agent loop ends.
|
||||||
|
|
@ -525,7 +525,7 @@ class Runner:
|
||||||
if streamed_result.is_complete:
|
if streamed_result.is_complete:
|
||||||
break
|
break
|
||||||
|
|
||||||
all_tools = await cls._get_all_tools(current_agent)
|
all_tools = await cls._get_all_tools(current_agent, context_wrapper)
|
||||||
|
|
||||||
# Start an agent span if we don't have one. This span is ended if the current
|
# Start an agent span if we don't have one. This span is ended if the current
|
||||||
# agent changes, or if the agent loop ends.
|
# agent changes, or if the agent loop ends.
|
||||||
|
|
@ -980,8 +980,10 @@ class Runner:
|
||||||
return handoffs
|
return handoffs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def _get_all_tools(cls, agent: Agent[Any]) -> list[Tool]:
|
async def _get_all_tools(
|
||||||
return await agent.get_all_tools()
|
cls, agent: Agent[Any], context_wrapper: RunContextWrapper[Any]
|
||||||
|
) -> list[Tool]:
|
||||||
|
return await agent.get_all_tools(context_wrapper)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:
|
def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ import inspect
|
||||||
import json
|
import json
|
||||||
from collections.abc import Awaitable
|
from collections.abc import Awaitable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, Literal, Union, overload
|
from typing import TYPE_CHECKING, Any, Callable, Literal, Union, overload
|
||||||
|
|
||||||
from openai.types.responses.file_search_tool_param import Filters, RankingOptions
|
from openai.types.responses.file_search_tool_param import Filters, RankingOptions
|
||||||
from openai.types.responses.response_output_item import LocalShellCall, McpApprovalRequest
|
from openai.types.responses.response_output_item import LocalShellCall, McpApprovalRequest
|
||||||
|
|
@ -24,6 +24,9 @@ from .tracing import SpanError
|
||||||
from .util import _error_tracing
|
from .util import _error_tracing
|
||||||
from .util._types import MaybeAwaitable
|
from .util._types import MaybeAwaitable
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .agent import Agent
|
||||||
|
|
||||||
ToolParams = ParamSpec("ToolParams")
|
ToolParams = ParamSpec("ToolParams")
|
||||||
|
|
||||||
ToolFunctionWithoutContext = Callable[ToolParams, Any]
|
ToolFunctionWithoutContext = Callable[ToolParams, Any]
|
||||||
|
|
@ -74,6 +77,11 @@ class FunctionTool:
|
||||||
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
|
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
|
||||||
as it increases the likelihood of correct JSON input."""
|
as it increases the likelihood of correct JSON input."""
|
||||||
|
|
||||||
|
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True
|
||||||
|
"""Whether the tool is enabled. Either a bool or a Callable that takes the run context and agent
|
||||||
|
and returns whether the tool is enabled. You can use this to dynamically enable/disable a tool
|
||||||
|
based on your context/state."""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FileSearchTool:
|
class FileSearchTool:
|
||||||
|
|
@ -262,6 +270,7 @@ def function_tool(
|
||||||
use_docstring_info: bool = True,
|
use_docstring_info: bool = True,
|
||||||
failure_error_function: ToolErrorFunction | None = None,
|
failure_error_function: ToolErrorFunction | None = None,
|
||||||
strict_mode: bool = True,
|
strict_mode: bool = True,
|
||||||
|
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
|
||||||
) -> FunctionTool:
|
) -> FunctionTool:
|
||||||
"""Overload for usage as @function_tool (no parentheses)."""
|
"""Overload for usage as @function_tool (no parentheses)."""
|
||||||
...
|
...
|
||||||
|
|
@ -276,6 +285,7 @@ def function_tool(
|
||||||
use_docstring_info: bool = True,
|
use_docstring_info: bool = True,
|
||||||
failure_error_function: ToolErrorFunction | None = None,
|
failure_error_function: ToolErrorFunction | None = None,
|
||||||
strict_mode: bool = True,
|
strict_mode: bool = True,
|
||||||
|
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
|
||||||
) -> Callable[[ToolFunction[...]], FunctionTool]:
|
) -> Callable[[ToolFunction[...]], FunctionTool]:
|
||||||
"""Overload for usage as @function_tool(...)."""
|
"""Overload for usage as @function_tool(...)."""
|
||||||
...
|
...
|
||||||
|
|
@ -290,6 +300,7 @@ def function_tool(
|
||||||
use_docstring_info: bool = True,
|
use_docstring_info: bool = True,
|
||||||
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
|
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
|
||||||
strict_mode: bool = True,
|
strict_mode: bool = True,
|
||||||
|
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
|
||||||
) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]:
|
) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]:
|
||||||
"""
|
"""
|
||||||
Decorator to create a FunctionTool from a function. By default, we will:
|
Decorator to create a FunctionTool from a function. By default, we will:
|
||||||
|
|
@ -318,6 +329,9 @@ def function_tool(
|
||||||
If False, it allows non-strict JSON schemas. For example, if a parameter has a default
|
If False, it allows non-strict JSON schemas. For example, if a parameter has a default
|
||||||
value, it will be optional, additional properties are allowed, etc. See here for more:
|
value, it will be optional, additional properties are allowed, etc. See here for more:
|
||||||
https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#supported-schemas
|
https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#supported-schemas
|
||||||
|
is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run
|
||||||
|
context and agent and returns whether the tool is enabled. Disabled tools are hidden
|
||||||
|
from the LLM at runtime.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
|
def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
|
||||||
|
|
@ -407,6 +421,7 @@ def function_tool(
|
||||||
params_json_schema=schema.params_json_schema,
|
params_json_schema=schema.params_json_schema,
|
||||||
on_invoke_tool=_on_invoke_tool,
|
on_invoke_tool=_on_invoke_tool,
|
||||||
strict_json_schema=strict_mode,
|
strict_json_schema=strict_mode,
|
||||||
|
is_enabled=is_enabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
# If func is actually a callable, we were used as @function_tool with no parentheses
|
# If func is actually a callable, we were used as @function_tool with no parentheses
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import pytest
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from agents import FunctionTool, ModelBehaviorError, RunContextWrapper, function_tool
|
from agents import Agent, FunctionTool, ModelBehaviorError, RunContextWrapper, function_tool
|
||||||
from agents.tool import default_tool_error_function
|
from agents.tool import default_tool_error_function
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -255,3 +255,44 @@ async def test_async_custom_error_function_works():
|
||||||
|
|
||||||
result = await tool.on_invoke_tool(ctx, '{"a": 1, "b": 2}')
|
result = await tool.on_invoke_tool(ctx, '{"a": 1, "b": 2}')
|
||||||
assert result == "error_ValueError"
|
assert result == "error_ValueError"
|
||||||
|
|
||||||
|
|
||||||
|
class BoolCtx(BaseModel):
|
||||||
|
enable_tools: bool
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_is_enabled_bool_and_callable():
|
||||||
|
@function_tool(is_enabled=False)
|
||||||
|
def disabled_tool():
|
||||||
|
return "nope"
|
||||||
|
|
||||||
|
async def cond_enabled(ctx: RunContextWrapper[BoolCtx], agent: Agent[Any]) -> bool:
|
||||||
|
return ctx.context.enable_tools
|
||||||
|
|
||||||
|
@function_tool(is_enabled=cond_enabled)
|
||||||
|
def another_tool():
|
||||||
|
return "hi"
|
||||||
|
|
||||||
|
async def third_tool_on_invoke_tool(ctx: RunContextWrapper[Any], args: str) -> str:
|
||||||
|
return "third"
|
||||||
|
|
||||||
|
third_tool = FunctionTool(
|
||||||
|
name="third_tool",
|
||||||
|
description="third tool",
|
||||||
|
on_invoke_tool=third_tool_on_invoke_tool,
|
||||||
|
is_enabled=lambda ctx, agent: ctx.context.enable_tools,
|
||||||
|
params_json_schema={},
|
||||||
|
)
|
||||||
|
|
||||||
|
agent = Agent(name="t", tools=[disabled_tool, another_tool, third_tool])
|
||||||
|
context_1 = RunContextWrapper(BoolCtx(enable_tools=False))
|
||||||
|
context_2 = RunContextWrapper(BoolCtx(enable_tools=True))
|
||||||
|
|
||||||
|
tools_with_ctx = await agent.get_all_tools(context_1)
|
||||||
|
assert tools_with_ctx == []
|
||||||
|
|
||||||
|
tools_with_ctx = await agent.get_all_tools(context_2)
|
||||||
|
assert len(tools_with_ctx) == 2
|
||||||
|
assert tools_with_ctx[0].name == "another_tool"
|
||||||
|
assert tools_with_ctx[1].name == "third_tool"
|
||||||
|
|
|
||||||
|
|
@ -290,7 +290,7 @@ async def get_execute_result(
|
||||||
|
|
||||||
processed_response = RunImpl.process_model_response(
|
processed_response = RunImpl.process_model_response(
|
||||||
agent=agent,
|
agent=agent,
|
||||||
all_tools=await agent.get_all_tools(),
|
all_tools=await agent.get_all_tools(context_wrapper or RunContextWrapper(None)),
|
||||||
response=response,
|
response=response,
|
||||||
output_schema=output_schema,
|
output_schema=output_schema,
|
||||||
handoffs=handoffs,
|
handoffs=handoffs,
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,10 @@ from .test_responses import (
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _dummy_ctx() -> RunContextWrapper[None]:
|
||||||
|
return RunContextWrapper(context=None)
|
||||||
|
|
||||||
|
|
||||||
def test_empty_response():
|
def test_empty_response():
|
||||||
agent = Agent(name="test")
|
agent = Agent(name="test")
|
||||||
response = ModelResponse(
|
response = ModelResponse(
|
||||||
|
|
@ -83,7 +87,7 @@ async def test_single_tool_call():
|
||||||
response=response,
|
response=response,
|
||||||
output_schema=None,
|
output_schema=None,
|
||||||
handoffs=[],
|
handoffs=[],
|
||||||
all_tools=await agent.get_all_tools(),
|
all_tools=await agent.get_all_tools(_dummy_ctx()),
|
||||||
)
|
)
|
||||||
assert not result.handoffs
|
assert not result.handoffs
|
||||||
assert result.functions and len(result.functions) == 1
|
assert result.functions and len(result.functions) == 1
|
||||||
|
|
@ -111,7 +115,7 @@ async def test_missing_tool_call_raises_error():
|
||||||
response=response,
|
response=response,
|
||||||
output_schema=None,
|
output_schema=None,
|
||||||
handoffs=[],
|
handoffs=[],
|
||||||
all_tools=await agent.get_all_tools(),
|
all_tools=await agent.get_all_tools(_dummy_ctx()),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -140,7 +144,7 @@ async def test_multiple_tool_calls():
|
||||||
response=response,
|
response=response,
|
||||||
output_schema=None,
|
output_schema=None,
|
||||||
handoffs=[],
|
handoffs=[],
|
||||||
all_tools=await agent.get_all_tools(),
|
all_tools=await agent.get_all_tools(_dummy_ctx()),
|
||||||
)
|
)
|
||||||
assert not result.handoffs
|
assert not result.handoffs
|
||||||
assert result.functions and len(result.functions) == 2
|
assert result.functions and len(result.functions) == 2
|
||||||
|
|
@ -169,7 +173,7 @@ async def test_handoffs_parsed_correctly():
|
||||||
response=response,
|
response=response,
|
||||||
output_schema=None,
|
output_schema=None,
|
||||||
handoffs=[],
|
handoffs=[],
|
||||||
all_tools=await agent_3.get_all_tools(),
|
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
|
||||||
)
|
)
|
||||||
assert not result.handoffs, "Shouldn't have a handoff here"
|
assert not result.handoffs, "Shouldn't have a handoff here"
|
||||||
|
|
||||||
|
|
@ -183,7 +187,7 @@ async def test_handoffs_parsed_correctly():
|
||||||
response=response,
|
response=response,
|
||||||
output_schema=None,
|
output_schema=None,
|
||||||
handoffs=Runner._get_handoffs(agent_3),
|
handoffs=Runner._get_handoffs(agent_3),
|
||||||
all_tools=await agent_3.get_all_tools(),
|
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
|
||||||
)
|
)
|
||||||
assert len(result.handoffs) == 1, "Should have a handoff here"
|
assert len(result.handoffs) == 1, "Should have a handoff here"
|
||||||
handoff = result.handoffs[0]
|
handoff = result.handoffs[0]
|
||||||
|
|
@ -213,7 +217,7 @@ async def test_missing_handoff_fails():
|
||||||
response=response,
|
response=response,
|
||||||
output_schema=None,
|
output_schema=None,
|
||||||
handoffs=Runner._get_handoffs(agent_3),
|
handoffs=Runner._get_handoffs(agent_3),
|
||||||
all_tools=await agent_3.get_all_tools(),
|
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -236,7 +240,7 @@ async def test_multiple_handoffs_doesnt_error():
|
||||||
response=response,
|
response=response,
|
||||||
output_schema=None,
|
output_schema=None,
|
||||||
handoffs=Runner._get_handoffs(agent_3),
|
handoffs=Runner._get_handoffs(agent_3),
|
||||||
all_tools=await agent_3.get_all_tools(),
|
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
|
||||||
)
|
)
|
||||||
assert len(result.handoffs) == 2, "Should have multiple handoffs here"
|
assert len(result.handoffs) == 2, "Should have multiple handoffs here"
|
||||||
|
|
||||||
|
|
@ -262,7 +266,7 @@ async def test_final_output_parsed_correctly():
|
||||||
response=response,
|
response=response,
|
||||||
output_schema=Runner._get_output_schema(agent),
|
output_schema=Runner._get_output_schema(agent),
|
||||||
handoffs=[],
|
handoffs=[],
|
||||||
all_tools=await agent.get_all_tools(),
|
all_tools=await agent.get_all_tools(_dummy_ctx()),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -288,7 +292,7 @@ async def test_file_search_tool_call_parsed_correctly():
|
||||||
response=response,
|
response=response,
|
||||||
output_schema=None,
|
output_schema=None,
|
||||||
handoffs=[],
|
handoffs=[],
|
||||||
all_tools=await agent.get_all_tools(),
|
all_tools=await agent.get_all_tools(_dummy_ctx()),
|
||||||
)
|
)
|
||||||
# The final item should be a ToolCallItem for the file search call
|
# The final item should be a ToolCallItem for the file search call
|
||||||
assert any(
|
assert any(
|
||||||
|
|
@ -313,7 +317,7 @@ async def test_function_web_search_tool_call_parsed_correctly():
|
||||||
response=response,
|
response=response,
|
||||||
output_schema=None,
|
output_schema=None,
|
||||||
handoffs=[],
|
handoffs=[],
|
||||||
all_tools=await agent.get_all_tools(),
|
all_tools=await agent.get_all_tools(_dummy_ctx()),
|
||||||
)
|
)
|
||||||
assert any(
|
assert any(
|
||||||
isinstance(item, ToolCallItem) and item.raw_item is web_search_call
|
isinstance(item, ToolCallItem) and item.raw_item is web_search_call
|
||||||
|
|
@ -340,7 +344,7 @@ async def test_reasoning_item_parsed_correctly():
|
||||||
response=response,
|
response=response,
|
||||||
output_schema=None,
|
output_schema=None,
|
||||||
handoffs=[],
|
handoffs=[],
|
||||||
all_tools=await Agent(name="test").get_all_tools(),
|
all_tools=await Agent(name="test").get_all_tools(_dummy_ctx()),
|
||||||
)
|
)
|
||||||
assert any(
|
assert any(
|
||||||
isinstance(item, ReasoningItem) and item.raw_item is reasoning for item in result.new_items
|
isinstance(item, ReasoningItem) and item.raw_item is reasoning for item in result.new_items
|
||||||
|
|
@ -409,7 +413,7 @@ async def test_computer_tool_call_without_computer_tool_raises_error():
|
||||||
response=response,
|
response=response,
|
||||||
output_schema=None,
|
output_schema=None,
|
||||||
handoffs=[],
|
handoffs=[],
|
||||||
all_tools=await Agent(name="test").get_all_tools(),
|
all_tools=await Agent(name="test").get_all_tools(_dummy_ctx()),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -437,7 +441,7 @@ async def test_computer_tool_call_with_computer_tool_parsed_correctly():
|
||||||
response=response,
|
response=response,
|
||||||
output_schema=None,
|
output_schema=None,
|
||||||
handoffs=[],
|
handoffs=[],
|
||||||
all_tools=await agent.get_all_tools(),
|
all_tools=await agent.get_all_tools(_dummy_ctx()),
|
||||||
)
|
)
|
||||||
assert any(
|
assert any(
|
||||||
isinstance(item, ToolCallItem) and item.raw_item is computer_call
|
isinstance(item, ToolCallItem) and item.raw_item is computer_call
|
||||||
|
|
@ -468,7 +472,7 @@ async def test_tool_and_handoff_parsed_correctly():
|
||||||
response=response,
|
response=response,
|
||||||
output_schema=None,
|
output_schema=None,
|
||||||
handoffs=Runner._get_handoffs(agent_3),
|
handoffs=Runner._get_handoffs(agent_3),
|
||||||
all_tools=await agent_3.get_all_tools(),
|
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
|
||||||
)
|
)
|
||||||
assert result.functions and len(result.functions) == 1
|
assert result.functions and len(result.functions) == 1
|
||||||
assert len(result.handoffs) == 1, "Should have a handoff here"
|
assert len(result.handoffs) == 1, "Should have a handoff here"
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue