Add is_enabled to FunctionTool (#808)

### Summary:
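Allows a user to pass `is_enabled` to `function_tool(...)` (or directly to
`FunctionTool`) as either a bool or a callable. A callable receives the run
context and the agent, may be sync or async, and is evaluated each time the
agent's tool list is built during a run.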

This allows you to dynamically enable/disable a tool based on the
context/env.

The meta-goal is to allow `Agent` to be effectively immutable. That
enables some nice things down the line, and this allows you to
dynamically modify the tools list without mutating the agent.

### Test Plan:
Unit tests
This commit is contained in:
Rohan Mehta 2025-06-03 13:44:16 -04:00 committed by GitHub
parent 995af4d83e
commit 4046fcb3fa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 102 additions and 24 deletions

View file

@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
import dataclasses
import inspect
from collections.abc import Awaitable
@ -17,7 +18,7 @@ from .mcp import MCPUtil
from .model_settings import ModelSettings
from .models.interface import Model
from .run_context import RunContextWrapper, TContext
from .tool import FunctionToolResult, Tool, function_tool
from .tool import FunctionTool, FunctionToolResult, Tool, function_tool
from .util import _transforms
from .util._types import MaybeAwaitable
@ -246,7 +247,22 @@ class Agent(Generic[TContext]):
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict)
# Pre-change signature (replaced by the next line in this diff): it took no run context.
async def get_all_tools(self) -> list[Tool]:
# Post-change signature: the run context is required so callable `is_enabled` checks can run.
async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
"""All agent tools, including MCP tools and function tools."""
mcp_tools = await self.get_mcp_tools()
# Pre-change return (replaced by the filtering logic below): all of `self.tools` was returned.
return mcp_tools + self.tools
# Decides whether one tool from `self.tools` is included in the result.
async def _check_tool_enabled(tool: Tool) -> bool:
# Only FunctionTool is inspected for `is_enabled`; any other tool type is always kept.
if not isinstance(tool, FunctionTool):
return True
attr = tool.is_enabled
# A plain bool is used as-is.
if isinstance(attr, bool):
return attr
# Otherwise `is_enabled` is a callable, invoked with the run context and this agent.
res = attr(run_context, self)
# The callable may be sync or async: await only when it handed back an awaitable.
if inspect.isawaitable(res):
return bool(await res)
return bool(res)
# Run every check concurrently. `asyncio.gather` returns results in input order, so the
# `zip` below pairs each tool with its own result.
results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
# MCP tools are not passed through the check; they come first, then the enabled agent tools.
return [*mcp_tools, *enabled]

View file

@ -181,7 +181,7 @@ class Runner:
try:
while True:
all_tools = await cls._get_all_tools(current_agent)
all_tools = await cls._get_all_tools(current_agent, context_wrapper)
# Start an agent span if we don't have one. This span is ended if the current
# agent changes, or if the agent loop ends.
@ -525,7 +525,7 @@ class Runner:
if streamed_result.is_complete:
break
all_tools = await cls._get_all_tools(current_agent)
all_tools = await cls._get_all_tools(current_agent, context_wrapper)
# Start an agent span if we don't have one. This span is ended if the current
# agent changes, or if the agent loop ends.
@ -980,8 +980,10 @@ class Runner:
return handoffs
@classmethod
# Pre-change version (next two lines), replaced by the definition that follows in this diff.
async def _get_all_tools(cls, agent: Agent[Any]) -> list[Tool]:
return await agent.get_all_tools()
# Post-change version: forwards the run's context wrapper to the agent, whose
# `get_all_tools` now takes it as an argument.
async def _get_all_tools(
cls, agent: Agent[Any], context_wrapper: RunContextWrapper[Any]
) -> list[Tool]:
return await agent.get_all_tools(context_wrapper)
@classmethod
def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:

View file

@ -4,7 +4,7 @@ import inspect
import json
from collections.abc import Awaitable
from dataclasses import dataclass
from typing import Any, Callable, Literal, Union, overload
from typing import TYPE_CHECKING, Any, Callable, Literal, Union, overload
from openai.types.responses.file_search_tool_param import Filters, RankingOptions
from openai.types.responses.response_output_item import LocalShellCall, McpApprovalRequest
@ -24,6 +24,9 @@ from .tracing import SpanError
from .util import _error_tracing
from .util._types import MaybeAwaitable
if TYPE_CHECKING:
from .agent import Agent
ToolParams = ParamSpec("ToolParams")
ToolFunctionWithoutContext = Callable[ToolParams, Any]
@ -74,6 +77,11 @@ class FunctionTool:
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
as it increases the likelihood of correct JSON input."""
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True
"""Whether the tool is enabled. Either a bool or a Callable that takes the run context and agent
and returns whether the tool is enabled. You can use this to dynamically enable/disable a tool
based on your context/state."""
@dataclass
class FileSearchTool:
@ -262,6 +270,7 @@ def function_tool(
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = None,
strict_mode: bool = True,
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
) -> FunctionTool:
"""Overload for usage as @function_tool (no parentheses)."""
...
@ -276,6 +285,7 @@ def function_tool(
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = None,
strict_mode: bool = True,
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
) -> Callable[[ToolFunction[...]], FunctionTool]:
"""Overload for usage as @function_tool(...)."""
...
@ -290,6 +300,7 @@ def function_tool(
use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
strict_mode: bool = True,
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]:
"""
Decorator to create a FunctionTool from a function. By default, we will:
@ -318,6 +329,9 @@ def function_tool(
If False, it allows non-strict JSON schemas. For example, if a parameter has a default
value, it will be optional, additional properties are allowed, etc. See here for more:
https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#supported-schemas
is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run
context and agent and returns whether the tool is enabled. Disabled tools are hidden
from the LLM at runtime.
"""
def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
@ -407,6 +421,7 @@ def function_tool(
params_json_schema=schema.params_json_schema,
on_invoke_tool=_on_invoke_tool,
strict_json_schema=strict_mode,
is_enabled=is_enabled,
)
# If func is actually a callable, we were used as @function_tool with no parentheses

View file

@ -5,7 +5,7 @@ import pytest
from pydantic import BaseModel
from typing_extensions import TypedDict
from agents import FunctionTool, ModelBehaviorError, RunContextWrapper, function_tool
from agents import Agent, FunctionTool, ModelBehaviorError, RunContextWrapper, function_tool
from agents.tool import default_tool_error_function
@ -255,3 +255,44 @@ async def test_async_custom_error_function_works():
result = await tool.on_invoke_tool(ctx, '{"a": 1, "b": 2}')
assert result == "error_ValueError"
class BoolCtx(BaseModel):
    """Run context for the `is_enabled` test; carries the flag the tool checks read."""

    enable_tools: bool
@pytest.mark.asyncio
async def test_is_enabled_bool_and_callable():
    """`is_enabled` filters tools for a static bool, an async callable and a sync callable."""

    # Statically disabled: must never be listed, whatever the context says.
    @function_tool(is_enabled=False)
    def disabled_tool():
        return "nope"

    # Async predicate driven by the run context.
    async def enabled_from_ctx(ctx: RunContextWrapper[BoolCtx], agent: Agent[Any]) -> bool:
        return ctx.context.enable_tools

    @function_tool(is_enabled=enabled_from_ctx)
    def another_tool():
        return "hi"

    async def invoke_third(ctx: RunContextWrapper[Any], args: str) -> str:
        return "third"

    # Built directly (not via the decorator) with a sync lambda as the predicate.
    third_tool = FunctionTool(
        name="third_tool",
        description="third tool",
        params_json_schema={},
        on_invoke_tool=invoke_third,
        is_enabled=lambda ctx, agent: ctx.context.enable_tools,
    )

    agent = Agent(name="t", tools=[disabled_tool, another_tool, third_tool])

    off_ctx = RunContextWrapper(BoolCtx(enable_tools=False))
    on_ctx = RunContextWrapper(BoolCtx(enable_tools=True))

    # Flag off: the two context-driven tools are hidden along with the disabled one.
    assert await agent.get_all_tools(off_ctx) == []

    # Flag on: only the statically disabled tool stays hidden; agent order is preserved.
    visible = await agent.get_all_tools(on_ctx)
    assert len(visible) == 2
    first, second = visible
    assert first.name == "another_tool"
    assert second.name == "third_tool"

View file

@ -290,7 +290,7 @@ async def get_execute_result(
processed_response = RunImpl.process_model_response(
agent=agent,
all_tools=await agent.get_all_tools(),
all_tools=await agent.get_all_tools(context_wrapper or RunContextWrapper(None)),
response=response,
output_schema=output_schema,
handoffs=handoffs,

View file

@ -34,6 +34,10 @@ from .test_responses import (
)
def _dummy_ctx() -> RunContextWrapper[None]:
    """Return a context wrapper holding no context, as a placeholder for tests."""
    placeholder: RunContextWrapper[None] = RunContextWrapper(context=None)
    return placeholder
def test_empty_response():
agent = Agent(name="test")
response = ModelResponse(
@ -83,7 +87,7 @@ async def test_single_tool_call():
response=response,
output_schema=None,
handoffs=[],
all_tools=await agent.get_all_tools(),
all_tools=await agent.get_all_tools(_dummy_ctx()),
)
assert not result.handoffs
assert result.functions and len(result.functions) == 1
@ -111,7 +115,7 @@ async def test_missing_tool_call_raises_error():
response=response,
output_schema=None,
handoffs=[],
all_tools=await agent.get_all_tools(),
all_tools=await agent.get_all_tools(_dummy_ctx()),
)
@ -140,7 +144,7 @@ async def test_multiple_tool_calls():
response=response,
output_schema=None,
handoffs=[],
all_tools=await agent.get_all_tools(),
all_tools=await agent.get_all_tools(_dummy_ctx()),
)
assert not result.handoffs
assert result.functions and len(result.functions) == 2
@ -169,7 +173,7 @@ async def test_handoffs_parsed_correctly():
response=response,
output_schema=None,
handoffs=[],
all_tools=await agent_3.get_all_tools(),
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
)
assert not result.handoffs, "Shouldn't have a handoff here"
@ -183,7 +187,7 @@ async def test_handoffs_parsed_correctly():
response=response,
output_schema=None,
handoffs=Runner._get_handoffs(agent_3),
all_tools=await agent_3.get_all_tools(),
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
)
assert len(result.handoffs) == 1, "Should have a handoff here"
handoff = result.handoffs[0]
@ -213,7 +217,7 @@ async def test_missing_handoff_fails():
response=response,
output_schema=None,
handoffs=Runner._get_handoffs(agent_3),
all_tools=await agent_3.get_all_tools(),
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
)
@ -236,7 +240,7 @@ async def test_multiple_handoffs_doesnt_error():
response=response,
output_schema=None,
handoffs=Runner._get_handoffs(agent_3),
all_tools=await agent_3.get_all_tools(),
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
)
assert len(result.handoffs) == 2, "Should have multiple handoffs here"
@ -262,7 +266,7 @@ async def test_final_output_parsed_correctly():
response=response,
output_schema=Runner._get_output_schema(agent),
handoffs=[],
all_tools=await agent.get_all_tools(),
all_tools=await agent.get_all_tools(_dummy_ctx()),
)
@ -288,7 +292,7 @@ async def test_file_search_tool_call_parsed_correctly():
response=response,
output_schema=None,
handoffs=[],
all_tools=await agent.get_all_tools(),
all_tools=await agent.get_all_tools(_dummy_ctx()),
)
# The final item should be a ToolCallItem for the file search call
assert any(
@ -313,7 +317,7 @@ async def test_function_web_search_tool_call_parsed_correctly():
response=response,
output_schema=None,
handoffs=[],
all_tools=await agent.get_all_tools(),
all_tools=await agent.get_all_tools(_dummy_ctx()),
)
assert any(
isinstance(item, ToolCallItem) and item.raw_item is web_search_call
@ -340,7 +344,7 @@ async def test_reasoning_item_parsed_correctly():
response=response,
output_schema=None,
handoffs=[],
all_tools=await Agent(name="test").get_all_tools(),
all_tools=await Agent(name="test").get_all_tools(_dummy_ctx()),
)
assert any(
isinstance(item, ReasoningItem) and item.raw_item is reasoning for item in result.new_items
@ -409,7 +413,7 @@ async def test_computer_tool_call_without_computer_tool_raises_error():
response=response,
output_schema=None,
handoffs=[],
all_tools=await Agent(name="test").get_all_tools(),
all_tools=await Agent(name="test").get_all_tools(_dummy_ctx()),
)
@ -437,7 +441,7 @@ async def test_computer_tool_call_with_computer_tool_parsed_correctly():
response=response,
output_schema=None,
handoffs=[],
all_tools=await agent.get_all_tools(),
all_tools=await agent.get_all_tools(_dummy_ctx()),
)
assert any(
isinstance(item, ToolCallItem) and item.raw_item is computer_call
@ -468,7 +472,7 @@ async def test_tool_and_handoff_parsed_correctly():
response=response,
output_schema=None,
handoffs=Runner._get_handoffs(agent_3),
all_tools=await agent_3.get_all_tools(),
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
)
assert result.functions and len(result.functions) == 1
assert len(result.handoffs) == 1, "Should have a handoff here"