Add is_enabled to FunctionTool (#808)
### Summary: Allows a user to do `function_tool(is_enabled=<some_callable>)`; the callable is called when the agent runs. This allows you to dynamically enable/disable a tool based on the context/env. The meta-goal is to allow `Agent` to be effectively immutable. That enables some nice things down the line, and this allows you to dynamically modify the tools list without mutating the agent. ### Test Plan: Unit tests
This commit is contained in:
parent
995af4d83e
commit
4046fcb3fa
6 changed files with 102 additions and 24 deletions
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import inspect
|
||||
from collections.abc import Awaitable
|
||||
|
|
@ -17,7 +18,7 @@ from .mcp import MCPUtil
|
|||
from .model_settings import ModelSettings
|
||||
from .models.interface import Model
|
||||
from .run_context import RunContextWrapper, TContext
|
||||
from .tool import FunctionToolResult, Tool, function_tool
|
||||
from .tool import FunctionTool, FunctionToolResult, Tool, function_tool
|
||||
from .util import _transforms
|
||||
from .util._types import MaybeAwaitable
|
||||
|
||||
|
|
@ -246,7 +247,22 @@ class Agent(Generic[TContext]):
|
|||
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
|
||||
return await MCPUtil.get_all_function_tools(self.mcp_servers, convert_schemas_to_strict)
|
||||
|
||||
async def get_all_tools(self) -> list[Tool]:
|
||||
async def get_all_tools(self, run_context: RunContextWrapper[Any]) -> list[Tool]:
|
||||
"""All agent tools, including MCP tools and function tools."""
|
||||
mcp_tools = await self.get_mcp_tools()
|
||||
return mcp_tools + self.tools
|
||||
|
||||
async def _check_tool_enabled(tool: Tool) -> bool:
|
||||
if not isinstance(tool, FunctionTool):
|
||||
return True
|
||||
|
||||
attr = tool.is_enabled
|
||||
if isinstance(attr, bool):
|
||||
return attr
|
||||
res = attr(run_context, self)
|
||||
if inspect.isawaitable(res):
|
||||
return bool(await res)
|
||||
return bool(res)
|
||||
|
||||
results = await asyncio.gather(*(_check_tool_enabled(t) for t in self.tools))
|
||||
enabled: list[Tool] = [t for t, ok in zip(self.tools, results) if ok]
|
||||
return [*mcp_tools, *enabled]
|
||||
|
|
|
|||
|
|
@ -181,7 +181,7 @@ class Runner:
|
|||
|
||||
try:
|
||||
while True:
|
||||
all_tools = await cls._get_all_tools(current_agent)
|
||||
all_tools = await cls._get_all_tools(current_agent, context_wrapper)
|
||||
|
||||
# Start an agent span if we don't have one. This span is ended if the current
|
||||
# agent changes, or if the agent loop ends.
|
||||
|
|
@ -525,7 +525,7 @@ class Runner:
|
|||
if streamed_result.is_complete:
|
||||
break
|
||||
|
||||
all_tools = await cls._get_all_tools(current_agent)
|
||||
all_tools = await cls._get_all_tools(current_agent, context_wrapper)
|
||||
|
||||
# Start an agent span if we don't have one. This span is ended if the current
|
||||
# agent changes, or if the agent loop ends.
|
||||
|
|
@ -980,8 +980,10 @@ class Runner:
|
|||
return handoffs
|
||||
|
||||
@classmethod
|
||||
async def _get_all_tools(cls, agent: Agent[Any]) -> list[Tool]:
|
||||
return await agent.get_all_tools()
|
||||
async def _get_all_tools(
|
||||
cls, agent: Agent[Any], context_wrapper: RunContextWrapper[Any]
|
||||
) -> list[Tool]:
|
||||
return await agent.get_all_tools(context_wrapper)
|
||||
|
||||
@classmethod
|
||||
def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import inspect
|
|||
import json
|
||||
from collections.abc import Awaitable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Literal, Union, overload
|
||||
from typing import TYPE_CHECKING, Any, Callable, Literal, Union, overload
|
||||
|
||||
from openai.types.responses.file_search_tool_param import Filters, RankingOptions
|
||||
from openai.types.responses.response_output_item import LocalShellCall, McpApprovalRequest
|
||||
|
|
@ -24,6 +24,9 @@ from .tracing import SpanError
|
|||
from .util import _error_tracing
|
||||
from .util._types import MaybeAwaitable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .agent import Agent
|
||||
|
||||
ToolParams = ParamSpec("ToolParams")
|
||||
|
||||
ToolFunctionWithoutContext = Callable[ToolParams, Any]
|
||||
|
|
@ -74,6 +77,11 @@ class FunctionTool:
|
|||
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
|
||||
as it increases the likelihood of correct JSON input."""
|
||||
|
||||
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True
|
||||
"""Whether the tool is enabled. Either a bool or a Callable that takes the run context and agent
|
||||
and returns whether the tool is enabled. You can use this to dynamically enable/disable a tool
|
||||
based on your context/state."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class FileSearchTool:
|
||||
|
|
@ -262,6 +270,7 @@ def function_tool(
|
|||
use_docstring_info: bool = True,
|
||||
failure_error_function: ToolErrorFunction | None = None,
|
||||
strict_mode: bool = True,
|
||||
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
|
||||
) -> FunctionTool:
|
||||
"""Overload for usage as @function_tool (no parentheses)."""
|
||||
...
|
||||
|
|
@ -276,6 +285,7 @@ def function_tool(
|
|||
use_docstring_info: bool = True,
|
||||
failure_error_function: ToolErrorFunction | None = None,
|
||||
strict_mode: bool = True,
|
||||
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
|
||||
) -> Callable[[ToolFunction[...]], FunctionTool]:
|
||||
"""Overload for usage as @function_tool(...)."""
|
||||
...
|
||||
|
|
@ -290,6 +300,7 @@ def function_tool(
|
|||
use_docstring_info: bool = True,
|
||||
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
|
||||
strict_mode: bool = True,
|
||||
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
|
||||
) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]:
|
||||
"""
|
||||
Decorator to create a FunctionTool from a function. By default, we will:
|
||||
|
|
@ -318,6 +329,9 @@ def function_tool(
|
|||
If False, it allows non-strict JSON schemas. For example, if a parameter has a default
|
||||
value, it will be optional, additional properties are allowed, etc. See here for more:
|
||||
https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#supported-schemas
|
||||
is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run
|
||||
context and agent and returns whether the tool is enabled. Disabled tools are hidden
|
||||
from the LLM at runtime.
|
||||
"""
|
||||
|
||||
def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
|
||||
|
|
@ -407,6 +421,7 @@ def function_tool(
|
|||
params_json_schema=schema.params_json_schema,
|
||||
on_invoke_tool=_on_invoke_tool,
|
||||
strict_json_schema=strict_mode,
|
||||
is_enabled=is_enabled,
|
||||
)
|
||||
|
||||
# If func is actually a callable, we were used as @function_tool with no parentheses
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ import pytest
|
|||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from agents import FunctionTool, ModelBehaviorError, RunContextWrapper, function_tool
|
||||
from agents import Agent, FunctionTool, ModelBehaviorError, RunContextWrapper, function_tool
|
||||
from agents.tool import default_tool_error_function
|
||||
|
||||
|
||||
|
|
@ -255,3 +255,44 @@ async def test_async_custom_error_function_works():
|
|||
|
||||
result = await tool.on_invoke_tool(ctx, '{"a": 1, "b": 2}')
|
||||
assert result == "error_ValueError"
|
||||
|
||||
|
||||
class BoolCtx(BaseModel):
|
||||
enable_tools: bool
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_is_enabled_bool_and_callable():
|
||||
@function_tool(is_enabled=False)
|
||||
def disabled_tool():
|
||||
return "nope"
|
||||
|
||||
async def cond_enabled(ctx: RunContextWrapper[BoolCtx], agent: Agent[Any]) -> bool:
|
||||
return ctx.context.enable_tools
|
||||
|
||||
@function_tool(is_enabled=cond_enabled)
|
||||
def another_tool():
|
||||
return "hi"
|
||||
|
||||
async def third_tool_on_invoke_tool(ctx: RunContextWrapper[Any], args: str) -> str:
|
||||
return "third"
|
||||
|
||||
third_tool = FunctionTool(
|
||||
name="third_tool",
|
||||
description="third tool",
|
||||
on_invoke_tool=third_tool_on_invoke_tool,
|
||||
is_enabled=lambda ctx, agent: ctx.context.enable_tools,
|
||||
params_json_schema={},
|
||||
)
|
||||
|
||||
agent = Agent(name="t", tools=[disabled_tool, another_tool, third_tool])
|
||||
context_1 = RunContextWrapper(BoolCtx(enable_tools=False))
|
||||
context_2 = RunContextWrapper(BoolCtx(enable_tools=True))
|
||||
|
||||
tools_with_ctx = await agent.get_all_tools(context_1)
|
||||
assert tools_with_ctx == []
|
||||
|
||||
tools_with_ctx = await agent.get_all_tools(context_2)
|
||||
assert len(tools_with_ctx) == 2
|
||||
assert tools_with_ctx[0].name == "another_tool"
|
||||
assert tools_with_ctx[1].name == "third_tool"
|
||||
|
|
|
|||
|
|
@ -290,7 +290,7 @@ async def get_execute_result(
|
|||
|
||||
processed_response = RunImpl.process_model_response(
|
||||
agent=agent,
|
||||
all_tools=await agent.get_all_tools(),
|
||||
all_tools=await agent.get_all_tools(context_wrapper or RunContextWrapper(None)),
|
||||
response=response,
|
||||
output_schema=output_schema,
|
||||
handoffs=handoffs,
|
||||
|
|
|
|||
|
|
@ -34,6 +34,10 @@ from .test_responses import (
|
|||
)
|
||||
|
||||
|
||||
def _dummy_ctx() -> RunContextWrapper[None]:
|
||||
return RunContextWrapper(context=None)
|
||||
|
||||
|
||||
def test_empty_response():
|
||||
agent = Agent(name="test")
|
||||
response = ModelResponse(
|
||||
|
|
@ -83,7 +87,7 @@ async def test_single_tool_call():
|
|||
response=response,
|
||||
output_schema=None,
|
||||
handoffs=[],
|
||||
all_tools=await agent.get_all_tools(),
|
||||
all_tools=await agent.get_all_tools(_dummy_ctx()),
|
||||
)
|
||||
assert not result.handoffs
|
||||
assert result.functions and len(result.functions) == 1
|
||||
|
|
@ -111,7 +115,7 @@ async def test_missing_tool_call_raises_error():
|
|||
response=response,
|
||||
output_schema=None,
|
||||
handoffs=[],
|
||||
all_tools=await agent.get_all_tools(),
|
||||
all_tools=await agent.get_all_tools(_dummy_ctx()),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -140,7 +144,7 @@ async def test_multiple_tool_calls():
|
|||
response=response,
|
||||
output_schema=None,
|
||||
handoffs=[],
|
||||
all_tools=await agent.get_all_tools(),
|
||||
all_tools=await agent.get_all_tools(_dummy_ctx()),
|
||||
)
|
||||
assert not result.handoffs
|
||||
assert result.functions and len(result.functions) == 2
|
||||
|
|
@ -169,7 +173,7 @@ async def test_handoffs_parsed_correctly():
|
|||
response=response,
|
||||
output_schema=None,
|
||||
handoffs=[],
|
||||
all_tools=await agent_3.get_all_tools(),
|
||||
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
|
||||
)
|
||||
assert not result.handoffs, "Shouldn't have a handoff here"
|
||||
|
||||
|
|
@ -183,7 +187,7 @@ async def test_handoffs_parsed_correctly():
|
|||
response=response,
|
||||
output_schema=None,
|
||||
handoffs=Runner._get_handoffs(agent_3),
|
||||
all_tools=await agent_3.get_all_tools(),
|
||||
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
|
||||
)
|
||||
assert len(result.handoffs) == 1, "Should have a handoff here"
|
||||
handoff = result.handoffs[0]
|
||||
|
|
@ -213,7 +217,7 @@ async def test_missing_handoff_fails():
|
|||
response=response,
|
||||
output_schema=None,
|
||||
handoffs=Runner._get_handoffs(agent_3),
|
||||
all_tools=await agent_3.get_all_tools(),
|
||||
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -236,7 +240,7 @@ async def test_multiple_handoffs_doesnt_error():
|
|||
response=response,
|
||||
output_schema=None,
|
||||
handoffs=Runner._get_handoffs(agent_3),
|
||||
all_tools=await agent_3.get_all_tools(),
|
||||
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
|
||||
)
|
||||
assert len(result.handoffs) == 2, "Should have multiple handoffs here"
|
||||
|
||||
|
|
@ -262,7 +266,7 @@ async def test_final_output_parsed_correctly():
|
|||
response=response,
|
||||
output_schema=Runner._get_output_schema(agent),
|
||||
handoffs=[],
|
||||
all_tools=await agent.get_all_tools(),
|
||||
all_tools=await agent.get_all_tools(_dummy_ctx()),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -288,7 +292,7 @@ async def test_file_search_tool_call_parsed_correctly():
|
|||
response=response,
|
||||
output_schema=None,
|
||||
handoffs=[],
|
||||
all_tools=await agent.get_all_tools(),
|
||||
all_tools=await agent.get_all_tools(_dummy_ctx()),
|
||||
)
|
||||
# The final item should be a ToolCallItem for the file search call
|
||||
assert any(
|
||||
|
|
@ -313,7 +317,7 @@ async def test_function_web_search_tool_call_parsed_correctly():
|
|||
response=response,
|
||||
output_schema=None,
|
||||
handoffs=[],
|
||||
all_tools=await agent.get_all_tools(),
|
||||
all_tools=await agent.get_all_tools(_dummy_ctx()),
|
||||
)
|
||||
assert any(
|
||||
isinstance(item, ToolCallItem) and item.raw_item is web_search_call
|
||||
|
|
@ -340,7 +344,7 @@ async def test_reasoning_item_parsed_correctly():
|
|||
response=response,
|
||||
output_schema=None,
|
||||
handoffs=[],
|
||||
all_tools=await Agent(name="test").get_all_tools(),
|
||||
all_tools=await Agent(name="test").get_all_tools(_dummy_ctx()),
|
||||
)
|
||||
assert any(
|
||||
isinstance(item, ReasoningItem) and item.raw_item is reasoning for item in result.new_items
|
||||
|
|
@ -409,7 +413,7 @@ async def test_computer_tool_call_without_computer_tool_raises_error():
|
|||
response=response,
|
||||
output_schema=None,
|
||||
handoffs=[],
|
||||
all_tools=await Agent(name="test").get_all_tools(),
|
||||
all_tools=await Agent(name="test").get_all_tools(_dummy_ctx()),
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -437,7 +441,7 @@ async def test_computer_tool_call_with_computer_tool_parsed_correctly():
|
|||
response=response,
|
||||
output_schema=None,
|
||||
handoffs=[],
|
||||
all_tools=await agent.get_all_tools(),
|
||||
all_tools=await agent.get_all_tools(_dummy_ctx()),
|
||||
)
|
||||
assert any(
|
||||
isinstance(item, ToolCallItem) and item.raw_item is computer_call
|
||||
|
|
@ -468,7 +472,7 @@ async def test_tool_and_handoff_parsed_correctly():
|
|||
response=response,
|
||||
output_schema=None,
|
||||
handoffs=Runner._get_handoffs(agent_3),
|
||||
all_tools=await agent_3.get_all_tools(),
|
||||
all_tools=await agent_3.get_all_tools(_dummy_ctx()),
|
||||
)
|
||||
assert result.functions and len(result.functions) == 1
|
||||
assert len(result.handoffs) == 1, "Should have a handoff here"
|
||||
|
|
|
|||
Loading…
Reference in a new issue