Add is_enabled to FunctionTool (#808)

### Summary:
Allows a user to do `function_tool(is_enabled=<some_callable>)`; the
callable is called when the agent runs.

This allows you to dynamically enable/disable a tool based on the
context/env.

The meta-goal is to allow `Agent` to be effectively immutable. That
enables some nice things down the line, and this allows you to
dynamically modify the tools list without mutating the agent.

### Test Plan:
Unit tests
This commit is contained in:
Rohan Mehta 2025-06-03 13:44:16 -04:00 committed by GitHub
parent 995af4d83e
commit 4046fcb3fa
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 102 additions and 24 deletions

View file

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
import asyncio
import dataclasses import dataclasses
import inspect import inspect
from collections.abc import Awaitable from collections.abc import Awaitable
@ -17,7 +18,7 @@ from .mcp import MCPUtil
from .model_settings import ModelSettings from .model_settings import ModelSettings
from .models.interface import Model from .models.interface import Model
from .run_context import RunContextWrapper, TContext from .run_context import RunContextWrapper, TContext
from .tool import FunctionToolResult, Tool, function_tool from .tool import FunctionTool, FunctionToolResult, Tool, function_tool
from .util import _transforms from .util import _transforms
from .util._types import MaybeAwaitable from .util._types import MaybeAwaitable
@ -246,7 +247,22 @@ class Agent(Generic[TContext]):
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False) convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict) return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict)
async def get_all_tools(self) -> list[Tool]: async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
"""All agent tools, including MCP tools and function tools.""" """All agent tools, including MCP tools and function tools."""
mcp_tools = await self.get_mcp_tools() mcp_tools = await self.get_mcp_tools()
return mcp_tools + self.tools
async def _check_tool_enabled(tool: Tool) -> bool:
if not isinstance(tool, FunctionTool):
return True
attr = tool.is_enabled
if isinstance(attr, bool):
return attr
res = attr(run_context, self)
if inspect.isawaitable(res):
return bool(await res)
return bool(res)
results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
return [*mcp_tools, *enabled]

View file

@ -181,7 +181,7 @@ class Runner:
try: try:
while True: while True:
all_tools = await cls._get_all_tools(current_agent) all_tools = await cls._get_all_tools(current_agent, context_wrapper)
# Start an agent span if we don't have one. This span is ended if the current # Start an agent span if we don't have one. This span is ended if the current
# agent changes, or if the agent loop ends. # agent changes, or if the agent loop ends.
@ -525,7 +525,7 @@ class Runner:
if streamed_result.is_complete: if streamed_result.is_complete:
break break
all_tools = await cls._get_all_tools(current_agent) all_tools = await cls._get_all_tools(current_agent, context_wrapper)
# Start an agent span if we don't have one. This span is ended if the current # Start an agent span if we don't have one. This span is ended if the current
# agent changes, or if the agent loop ends. # agent changes, or if the agent loop ends.
@ -980,8 +980,10 @@ class Runner:
return handoffs return handoffs
@classmethod @classmethod
async def _get_all_tools(cls, agent: Agent[Any]) -> list[Tool]: async def _get_all_tools(
return await agent.get_all_tools() cls, agent: Agent[Any], context_wrapper: RunContextWrapper[Any]
) -> list[Tool]:
return await agent.get_all_tools(context_wrapper)
@classmethod @classmethod
def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model: def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:

View file

@ -4,7 +4,7 @@ import inspect
import json import json
from collections.abc import Awaitable from collections.abc import Awaitable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, Literal, Union, overload from typing import TYPE_CHECKING, Any, Callable, Literal, Union, overload
from openai.types.responses.file_search_tool_param import Filters, RankingOptions from openai.types.responses.file_search_tool_param import Filters, RankingOptions
from openai.types.responses.response_output_item import LocalShellCall, McpApprovalRequest from openai.types.responses.response_output_item import LocalShellCall, McpApprovalRequest
@ -24,6 +24,9 @@ from .tracing import SpanError
from .util import _error_tracing from .util import _error_tracing
from .util._types import MaybeAwaitable from .util._types import MaybeAwaitable
if TYPE_CHECKING:
from .agent import Agent
ToolParams = ParamSpec("ToolParams") ToolParams = ParamSpec("ToolParams")
ToolFunctionWithoutContext = Callable[ToolParams, Any] ToolFunctionWithoutContext = Callable[ToolParams, Any]
@ -74,6 +77,11 @@ class FunctionTool:
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True, """Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
as it increases the likelihood of correct JSON input.""" as it increases the likelihood of correct JSON input."""
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True
"""Whether the tool is enabled. Either a bool or a Callable that takes the run context and agent
and returns whether the tool is enabled. You can use this to dynamically enable/disable a tool
based on your context/state."""
@dataclass @dataclass
class FileSearchTool: class FileSearchTool:
@ -262,6 +270,7 @@ def function_tool(
use_docstring_info: bool = True, use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = None, failure_error_function: ToolErrorFunction | None = None,
strict_mode: bool = True, strict_mode: bool = True,
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
) -> FunctionTool: ) -> FunctionTool:
"""Overload for usage as @function_tool (no parentheses).""" """Overload for usage as @function_tool (no parentheses)."""
... ...
@ -276,6 +285,7 @@ def function_tool(
use_docstring_info: bool = True, use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = None, failure_error_function: ToolErrorFunction | None = None,
strict_mode: bool = True, strict_mode: bool = True,
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
) -> Callable[[ToolFunction[...]], FunctionTool]: ) -> Callable[[ToolFunction[...]], FunctionTool]:
"""Overload for usage as @function_tool(...).""" """Overload for usage as @function_tool(...)."""
... ...
@ -290,6 +300,7 @@ def function_tool(
use_docstring_info: bool = True, use_docstring_info: bool = True,
failure_error_function: ToolErrorFunction | None = default_tool_error_function, failure_error_function: ToolErrorFunction | None = default_tool_error_function,
strict_mode: bool = True, strict_mode: bool = True,
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]: ) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]:
""" """
Decorator to create a FunctionTool from a function. By default, we will: Decorator to create a FunctionTool from a function. By default, we will:
@ -318,6 +329,9 @@ def function_tool(
If False, it allows non-strict JSON schemas. For example, if a parameter has a default If False, it allows non-strict JSON schemas. For example, if a parameter has a default
value, it will be optional, additional properties are allowed, etc. See here for more: value, it will be optional, additional properties are allowed, etc. See here for more:
https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#supported-schemas https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#supported-schemas
is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run
context and agent and returns whether the tool is enabled. Disabled tools are hidden
from the LLM at runtime.
""" """
def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool: def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
@ -407,6 +421,7 @@ def function_tool(
params_json_schema=schema.params_json_schema, params_json_schema=schema.params_json_schema,
on_invoke_tool=_on_invoke_tool, on_invoke_tool=_on_invoke_tool,
strict_json_schema=strict_mode, strict_json_schema=strict_mode,
is_enabled=is_enabled,
) )
# If func is actually a callable, we were used as @function_tool with no parentheses # If func is actually a callable, we were used as @function_tool with no parentheses

View file

@ -5,7 +5,7 @@ import pytest
from pydantic import BaseModel from pydantic import BaseModel
from typing_extensions import TypedDict from typing_extensions import TypedDict
from agents import FunctionTool, ModelBehaviorError, RunContextWrapper, function_tool from agents import Agent, FunctionTool, ModelBehaviorError, RunContextWrapper, function_tool
from agents.tool import default_tool_error_function from agents.tool import default_tool_error_function
@ -255,3 +255,44 @@ async def test_async_custom_error_function_works():
result = await tool.on_invoke_tool(ctx, '{"a": 1, "b": 2}') result = await tool.on_invoke_tool(ctx, '{"a": 1, "b": 2}')
assert result == "error_ValueError" assert result == "error_ValueError"
class BoolCtx(BaseModel):
enable_tools: bool
@pytest.mark.asyncio
async def test_is_enabled_bool_and_callable():
@function_tool(is_enabled=False)
def disabled_tool():
return "nope"
async def cond_enabled(ctx: RunContextWrapper[BoolCtx], agent: Agent[Any]) -> bool:
return ctx.context.enable_tools
@function_tool(is_enabled=cond_enabled)
def another_tool():
return "hi"
async def third_tool_on_invoke_tool(ctx: RunContextWrapper[Any], args: str) -> str:
return "third"
third_tool = FunctionTool(
name="third_tool",
description="third tool",
on_invoke_tool=third_tool_on_invoke_tool,
is_enabled=lambda ctx, agent: ctx.context.enable_tools,
params_json_schema={},
)
agent = Agent(name="t", tools=[disabled_tool, another_tool, third_tool])
context_1 = RunContextWrapper(BoolCtx(enable_tools=False))
context_2 = RunContextWrapper(BoolCtx(enable_tools=True))
tools_with_ctx = await agent.get_all_tools(context_1)
assert tools_with_ctx == []
tools_with_ctx = await agent.get_all_tools(context_2)
assert len(tools_with_ctx) == 2
assert tools_with_ctx[0].name == "another_tool"
assert tools_with_ctx[1].name == "third_tool"

View file

@ -290,7 +290,7 @@ async def get_execute_result(
processed_response = RunImpl.process_model_response( processed_response = RunImpl.process_model_response(
agent=agent, agent=agent,
all_tools=await agent.get_all_tools(), all_tools=await agent.get_all_tools(context_wrapper or RunContextWrapper(None)),
response=response, response=response,
output_schema=output_schema, output_schema=output_schema,
handoffs=handoffs, handoffs=handoffs,

View file

@ -34,6 +34,10 @@ from .test_responses import (
) )
def _dummy_ctx() -> RunContextWrapper[None]:
return RunContextWrapper(context=None)
def test_empty_response(): def test_empty_response():
agent = Agent(name="test") agent = Agent(name="test")
response = ModelResponse( response = ModelResponse(
@ -83,7 +87,7 @@ async def test_single_tool_call():
response=response, response=response,
output_schema=None, output_schema=None,
handoffs=[], handoffs=[],
all_tools=await agent.get_all_tools(), all_tools=await agent.get_all_tools(_dummy_ctx()),
) )
assert not result.handoffs assert not result.handoffs
assert result.functions and len(result.functions) == 1 assert result.functions and len(result.functions) == 1
@ -111,7 +115,7 @@ async def test_missing_tool_call_raises_error():
response=response, response=response,
output_schema=None, output_schema=None,
handoffs=[], handoffs=[],
all_tools=await agent.get_all_tools(), all_tools=await agent.get_all_tools(_dummy_ctx()),
) )
@ -140,7 +144,7 @@ async def test_multiple_tool_calls():
response=response, response=response,
output_schema=None, output_schema=None,
handoffs=[], handoffs=[],
all_tools=await agent.get_all_tools(), all_tools=await agent.get_all_tools(_dummy_ctx()),
) )
assert not result.handoffs assert not result.handoffs
assert result.functions and len(result.functions) == 2 assert result.functions and len(result.functions) == 2
@ -169,7 +173,7 @@ async def test_handoffs_parsed_correctly():
response=response, response=response,
output_schema=None, output_schema=None,
handoffs=[], handoffs=[],
all_tools=await agent_3.get_all_tools(), all_tools=await agent_3.get_all_tools(_dummy_ctx()),
) )
assert not result.handoffs, "Shouldn't have a handoff here" assert not result.handoffs, "Shouldn't have a handoff here"
@ -183,7 +187,7 @@ async def test_handoffs_parsed_correctly():
response=response, response=response,
output_schema=None, output_schema=None,
handoffs=Runner._get_handoffs(agent_3), handoffs=Runner._get_handoffs(agent_3),
all_tools=await agent_3.get_all_tools(), all_tools=await agent_3.get_all_tools(_dummy_ctx()),
) )
assert len(result.handoffs) == 1, "Should have a handoff here" assert len(result.handoffs) == 1, "Should have a handoff here"
handoff = result.handoffs[0] handoff = result.handoffs[0]
@ -213,7 +217,7 @@ async def test_missing_handoff_fails():
response=response, response=response,
output_schema=None, output_schema=None,
handoffs=Runner._get_handoffs(agent_3), handoffs=Runner._get_handoffs(agent_3),
all_tools=await agent_3.get_all_tools(), all_tools=await agent_3.get_all_tools(_dummy_ctx()),
) )
@ -236,7 +240,7 @@ async def test_multiple_handoffs_doesnt_error():
response=response, response=response,
output_schema=None, output_schema=None,
handoffs=Runner._get_handoffs(agent_3), handoffs=Runner._get_handoffs(agent_3),
all_tools=await agent_3.get_all_tools(), all_tools=await agent_3.get_all_tools(_dummy_ctx()),
) )
assert len(result.handoffs) == 2, "Should have multiple handoffs here" assert len(result.handoffs) == 2, "Should have multiple handoffs here"
@ -262,7 +266,7 @@ async def test_final_output_parsed_correctly():
response=response, response=response,
output_schema=Runner._get_output_schema(agent), output_schema=Runner._get_output_schema(agent),
handoffs=[], handoffs=[],
all_tools=await agent.get_all_tools(), all_tools=await agent.get_all_tools(_dummy_ctx()),
) )
@ -288,7 +292,7 @@ async def test_file_search_tool_call_parsed_correctly():
response=response, response=response,
output_schema=None, output_schema=None,
handoffs=[], handoffs=[],
all_tools=await agent.get_all_tools(), all_tools=await agent.get_all_tools(_dummy_ctx()),
) )
# The final item should be a ToolCallItem for the file search call # The final item should be a ToolCallItem for the file search call
assert any( assert any(
@ -313,7 +317,7 @@ async def test_function_web_search_tool_call_parsed_correctly():
response=response, response=response,
output_schema=None, output_schema=None,
handoffs=[], handoffs=[],
all_tools=await agent.get_all_tools(), all_tools=await agent.get_all_tools(_dummy_ctx()),
) )
assert any( assert any(
isinstance(item, ToolCallItem) and item.raw_item is web_search_call isinstance(item, ToolCallItem) and item.raw_item is web_search_call
@ -340,7 +344,7 @@ async def test_reasoning_item_parsed_correctly():
response=response, response=response,
output_schema=None, output_schema=None,
handoffs=[], handoffs=[],
all_tools=await Agent(name="test").get_all_tools(), all_tools=await Agent(name="test").get_all_tools(_dummy_ctx()),
) )
assert any( assert any(
isinstance(item, ReasoningItem) and item.raw_item is reasoning for item in result.new_items isinstance(item, ReasoningItem) and item.raw_item is reasoning for item in result.new_items
@ -409,7 +413,7 @@ async def test_computer_tool_call_without_computer_tool_raises_error():
response=response, response=response,
output_schema=None, output_schema=None,
handoffs=[], handoffs=[],
all_tools=await Agent(name="test").get_all_tools(), all_tools=await Agent(name="test").get_all_tools(_dummy_ctx()),
) )
@ -437,7 +441,7 @@ async def test_computer_tool_call_with_computer_tool_parsed_correctly():
response=response, response=response,
output_schema=None, output_schema=None,
handoffs=[], handoffs=[],
all_tools=await agent.get_all_tools(), all_tools=await agent.get_all_tools(_dummy_ctx()),
) )
assert any( assert any(
isinstance(item, ToolCallItem) and item.raw_item is computer_call isinstance(item, ToolCallItem) and item.raw_item is computer_call
@ -468,7 +472,7 @@ async def test_tool_and_handoff_parsed_correctly():
response=response, response=response,
output_schema=None, output_schema=None,
handoffs=Runner._get_handoffs(agent_3), handoffs=Runner._get_handoffs(agent_3),
all_tools=await agent_3.get_all_tools(), all_tools=await agent_3.get_all_tools(_dummy_ctx()),
) )
assert result.functions and len(result.functions) == 1 assert result.functions and len(result.functions) == 1
assert len(result.handoffs) == 1, "Should have a handoff here" assert len(result.handoffs) == 1, "Should have a handoff here"