Make the reset behavior on tool use configurable

## Summary:

#263 added this behavior. The goal was to prevent infinite loops when tool choice was set. The key change I'm making is:
1. Making it configurable on the agent.
2. Doing bookkeeping in the Runner to track this, to prevent mutating agents.
3. Not resetting the global tool choice in RunConfig.

## Test Plan:
Unit tests.
.
This commit is contained in:
Rohan Mehta 2025-03-25 13:29:17 -04:00
parent 362a9dc078
commit 6fb5792b77
7 changed files with 172 additions and 96 deletions

View file

@ -142,11 +142,6 @@ Supplying a list of tools doesn't always mean the LLM will use a tool. You can f
!!! note
To prevent infinite loops, the framework automatically resets `tool_choice` to "auto" after a tool call in the following scenarios:
1. When `tool_choice` is set to a specific function name (any string that's not "auto", "required", or "none")
2. When `tool_choice` is set to "required" AND there is only one tool available
This targeted reset mechanism allows the model to decide whether to make additional tool calls in subsequent turns while avoiding infinite loops in these specific cases.
To prevent infinite loops, the framework automatically resets `tool_choice` to "auto" after a tool call. This behavior is configurable via [`agent.reset_tool_choice`][agents.agent.Agent.reset_tool_choice]. The infinite loop is because tool results are sent to the LLM, which then generates another tool call because of `tool_choice`, ad infinitum.
If you want the Agent to completely stop after a tool call (rather than continuing with auto mode), you can set [`Agent.tool_use_behavior="stop_on_first_tool"`] which will directly use the tool output as the final response without further LLM processing.

View file

@ -4,7 +4,7 @@ import asyncio
import dataclasses
import inspect
from collections.abc import Awaitable
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, cast
from openai.types.responses import (
@ -77,6 +77,23 @@ QUEUE_COMPLETE_SENTINEL = QueueCompleteSentinel()
_NOT_FINAL_OUTPUT = ToolsToFinalOutputResult(is_final_output=False, final_output=None)
@dataclass
class AgentToolUseTracker:
agent_to_tools: list[tuple[Agent, list[str]]] = field(default_factory=list)
"""Tuple of (agent, list of tools used). Can't use a dict because agents aren't hashable."""
def add_tool_use(self, agent: Agent[Any], tool_names: list[str]) -> None:
existing_data = next((item for item in self.agent_to_tools if item[0] == agent), None)
if existing_data:
existing_data[1].extend(tool_names)
else:
self.agent_to_tools.append((agent, tool_names))
def has_used_tools(self, agent: Agent[Any]) -> bool:
existing_data = next((item for item in self.agent_to_tools if item[0] == agent), None)
return existing_data is not None and len(existing_data[1]) > 0
@dataclass
class ToolRunHandoff:
handoff: Handoff
@ -101,6 +118,7 @@ class ProcessedResponse:
handoffs: list[ToolRunHandoff]
functions: list[ToolRunFunction]
computer_actions: list[ToolRunComputerAction]
tools_used: list[str] # Names of all tools used, including hosted tools
def has_tools_to_run(self) -> bool:
# Handoffs, functions and computer actions need local processing
@ -208,29 +226,6 @@ class RunImpl:
new_step_items.extend([result.run_item for result in function_results])
new_step_items.extend(computer_results)
# Reset tool_choice to "auto" after tool execution to prevent infinite loops
if processed_response.functions or processed_response.computer_actions:
tools = agent.tools
if (
run_config.model_settings and
cls._should_reset_tool_choice(run_config.model_settings, tools)
):
# update the run_config model settings with a copy
new_run_config_settings = dataclasses.replace(
run_config.model_settings,
tool_choice="auto"
)
run_config = dataclasses.replace(run_config, model_settings=new_run_config_settings)
if cls._should_reset_tool_choice(agent.model_settings, tools):
# Create a modified copy instead of modifying the original agent
new_model_settings = dataclasses.replace(
agent.model_settings,
tool_choice="auto"
)
agent = dataclasses.replace(agent, model_settings=new_model_settings)
# Second, check if there are any handoffs
if run_handoffs := processed_response.handoffs:
return await cls.execute_handoffs(
@ -322,22 +317,16 @@ class RunImpl:
)
@classmethod
def _should_reset_tool_choice(cls, model_settings: ModelSettings, tools: list[Tool]) -> bool:
if model_settings is None or model_settings.tool_choice is None:
return False
def maybe_reset_tool_choice(
cls, agent: Agent[Any], tool_use_tracker: AgentToolUseTracker, model_settings: ModelSettings
) -> ModelSettings:
"""Resets tool choice to None if the agent has used tools and the agent's reset_tool_choice
flag is True."""
# for specific tool choices
if (
isinstance(model_settings.tool_choice, str) and
model_settings.tool_choice not in ["auto", "required", "none"]
):
return True
if agent.reset_tool_choice is True and tool_use_tracker.has_used_tools(agent):
return dataclasses.replace(model_settings, tool_choice=None)
# for one tool and required tool choice
if model_settings.tool_choice == "required":
return len(tools) == 1
return False
return model_settings
@classmethod
def process_model_response(
@ -354,7 +343,7 @@ class RunImpl:
run_handoffs = []
functions = []
computer_actions = []
tools_used: list[str] = []
handoff_map = {handoff.tool_name: handoff for handoff in handoffs}
function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)}
computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None)
@ -364,12 +353,15 @@ class RunImpl:
items.append(MessageOutputItem(raw_item=output, agent=agent))
elif isinstance(output, ResponseFileSearchToolCall):
items.append(ToolCallItem(raw_item=output, agent=agent))
tools_used.append("file_search")
elif isinstance(output, ResponseFunctionWebSearch):
items.append(ToolCallItem(raw_item=output, agent=agent))
tools_used.append("web_search")
elif isinstance(output, ResponseReasoningItem):
items.append(ReasoningItem(raw_item=output, agent=agent))
elif isinstance(output, ResponseComputerToolCall):
items.append(ToolCallItem(raw_item=output, agent=agent))
tools_used.append("computer_use")
if not computer_tool:
_error_tracing.attach_error_to_current_span(
SpanError(
@ -391,6 +383,8 @@ class RunImpl:
if not isinstance(output, ResponseFunctionToolCall):
continue
tools_used.append(output.name)
# Handoffs
if output.name in handoff_map:
items.append(HandoffCallItem(raw_item=output, agent=agent))
@ -422,6 +416,7 @@ class RunImpl:
handoffs=run_handoffs,
functions=functions,
computer_actions=computer_actions,
tools_used=tools_used,
)
@classmethod

View file

@ -155,6 +155,10 @@ class Agent(Generic[TContext]):
web search, etc are always processed by the LLM.
"""
reset_tool_choice: bool = True
"""Whether to reset the tool choice to the default value after a tool has been called. Defaults
to True. This ensures that the agent doesn't enter an infinite loop of tool usage."""
def clone(self, **kwargs: Any) -> Agent[TContext]:
"""Make a copy of the agent, with the given arguments changed. For example, you could do:
```

View file

@ -208,8 +208,10 @@ class OpenAIResponsesModel(Model):
list_input = ItemHelpers.input_to_new_input_list(input)
parallel_tool_calls = (
True if model_settings.parallel_tool_calls and tools and len(tools) > 0
else False if model_settings.parallel_tool_calls is False
True
if model_settings.parallel_tool_calls and tools and len(tools) > 0
else False
if model_settings.parallel_tool_calls is False
else NOT_GIVEN
)

View file

@ -10,6 +10,7 @@ from openai.types.responses import ResponseCompletedEvent
from agents.tool import Tool
from ._run_impl import (
AgentToolUseTracker,
NextStepFinalOutput,
NextStepHandoff,
NextStepRunAgain,
@ -151,6 +152,8 @@ class Runner:
if run_config is None:
run_config = RunConfig()
tool_use_tracker = AgentToolUseTracker()
with TraceCtxManager(
workflow_name=run_config.workflow_name,
trace_id=run_config.trace_id,
@ -227,6 +230,7 @@ class Runner:
context_wrapper=context_wrapper,
run_config=run_config,
should_run_agent_start_hooks=should_run_agent_start_hooks,
tool_use_tracker=tool_use_tracker,
),
)
else:
@ -239,6 +243,7 @@ class Runner:
context_wrapper=context_wrapper,
run_config=run_config,
should_run_agent_start_hooks=should_run_agent_start_hooks,
tool_use_tracker=tool_use_tracker,
)
should_run_agent_start_hooks = False
@ -486,6 +491,7 @@ class Runner:
current_agent = starting_agent
current_turn = 0
should_run_agent_start_hooks = True
tool_use_tracker = AgentToolUseTracker()
streamed_result._event_queue.put_nowait(AgentUpdatedStreamEvent(new_agent=current_agent))
@ -546,6 +552,7 @@ class Runner:
context_wrapper,
run_config,
should_run_agent_start_hooks,
tool_use_tracker,
)
should_run_agent_start_hooks = False
@ -613,6 +620,7 @@ class Runner:
context_wrapper: RunContextWrapper[TContext],
run_config: RunConfig,
should_run_agent_start_hooks: bool,
tool_use_tracker: AgentToolUseTracker,
) -> SingleStepResult:
if should_run_agent_start_hooks:
await asyncio.gather(
@ -635,6 +643,8 @@ class Runner:
all_tools = await cls._get_all_tools(agent)
model = cls._get_model(agent, run_config)
model_settings = agent.model_settings.resolve(run_config.model_settings)
model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings)
final_response: ModelResponse | None = None
input = ItemHelpers.input_to_new_input_list(streamed_result.input)
@ -687,6 +697,7 @@ class Runner:
hooks=hooks,
context_wrapper=context_wrapper,
run_config=run_config,
tool_use_tracker=tool_use_tracker,
)
RunImpl.stream_step_result_to_queue(single_step_result, streamed_result._event_queue)
@ -704,6 +715,7 @@ class Runner:
context_wrapper: RunContextWrapper[TContext],
run_config: RunConfig,
should_run_agent_start_hooks: bool,
tool_use_tracker: AgentToolUseTracker,
) -> SingleStepResult:
# Ensure we run the hooks before anything else
if should_run_agent_start_hooks:
@ -732,6 +744,7 @@ class Runner:
handoffs,
context_wrapper,
run_config,
tool_use_tracker,
)
return await cls._get_single_step_result_from_response(
@ -745,6 +758,7 @@ class Runner:
hooks=hooks,
context_wrapper=context_wrapper,
run_config=run_config,
tool_use_tracker=tool_use_tracker,
)
@classmethod
@ -761,6 +775,7 @@ class Runner:
hooks: RunHooks[TContext],
context_wrapper: RunContextWrapper[TContext],
run_config: RunConfig,
tool_use_tracker: AgentToolUseTracker,
) -> SingleStepResult:
processed_response = RunImpl.process_model_response(
agent=agent,
@ -769,6 +784,9 @@ class Runner:
output_schema=output_schema,
handoffs=handoffs,
)
tool_use_tracker.add_tool_use(agent, processed_response.tools_used)
return await RunImpl.execute_tools_and_side_effects(
agent=agent,
original_input=original_input,
@ -868,9 +886,12 @@ class Runner:
handoffs: list[Handoff],
context_wrapper: RunContextWrapper[TContext],
run_config: RunConfig,
tool_use_tracker: AgentToolUseTracker,
) -> ModelResponse:
model = cls._get_model(agent, run_config)
model_settings = agent.model_settings.resolve(run_config.model_settings)
model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings)
new_response = await model.get_response(
system_instructions=system_prompt,
input=input,

View file

@ -1,6 +1,7 @@
from __future__ import annotations
from collections.abc import AsyncIterator
from typing import Any
from openai.types.responses import Response, ResponseCompletedEvent
@ -31,6 +32,7 @@ class FakeModel(Model):
[initial_output] if initial_output else []
)
self.tracing_enabled = tracing_enabled
self.last_turn_args: dict[str, Any] = {}
def set_next_output(self, output: list[TResponseOutputItem] | Exception):
self.turn_outputs.append(output)
@ -53,6 +55,14 @@ class FakeModel(Model):
handoffs: list[Handoff],
tracing: ModelTracing,
) -> ModelResponse:
self.last_turn_args = {
"system_instructions": system_instructions,
"input": input,
"model_settings": model_settings,
"tools": tools,
"output_schema": output_schema,
}
with generation_span(disabled=not self.tracing_enabled) as span:
output = self.get_next_output()

View file

@ -1,63 +1,78 @@
import pytest
from agents import Agent, ModelSettings, Runner, Tool
from agents._run_impl import RunImpl
from agents import Agent, ModelSettings, Runner
from agents._run_impl import AgentToolUseTracker, RunImpl
from .fake_model import FakeModel
from .test_responses import (
get_function_tool,
get_function_tool_call,
get_text_message,
)
from .test_responses import get_function_tool, get_function_tool_call, get_text_message
class TestToolChoiceReset:
def test_should_reset_tool_choice_direct(self):
"""
Test the _should_reset_tool_choice method directly with various inputs
to ensure it correctly identifies cases where reset is needed.
"""
# Case 1: tool_choice = None should not reset
agent = Agent(name="test_agent")
# Case 1: Empty tool use tracker should not change the "None" tool choice
model_settings = ModelSettings(tool_choice=None)
tools1: list[Tool] = [get_function_tool("tool1")]
# Cast to list[Tool] to fix type checking issues
assert not RunImpl._should_reset_tool_choice(model_settings, tools1)
tracker = AgentToolUseTracker()
new_settings = RunImpl.maybe_reset_tool_choice(agent, tracker, model_settings)
assert new_settings.tool_choice == model_settings.tool_choice
# Case 2: tool_choice = "auto" should not reset
# Case 2: Empty tool use tracker should not change the "auto" tool choice
model_settings = ModelSettings(tool_choice="auto")
assert not RunImpl._should_reset_tool_choice(model_settings, tools1)
tracker = AgentToolUseTracker()
new_settings = RunImpl.maybe_reset_tool_choice(agent, tracker, model_settings)
assert model_settings.tool_choice == new_settings.tool_choice
# Case 3: tool_choice = "none" should not reset
model_settings = ModelSettings(tool_choice="none")
assert not RunImpl._should_reset_tool_choice(model_settings, tools1)
# Case 3: Empty tool use tracker should not change the "required" tool choice
model_settings = ModelSettings(tool_choice="required")
tracker = AgentToolUseTracker()
new_settings = RunImpl.maybe_reset_tool_choice(agent, tracker, model_settings)
assert model_settings.tool_choice == new_settings.tool_choice
# Case 4: tool_choice = "required" with one tool should reset
model_settings = ModelSettings(tool_choice="required")
assert RunImpl._should_reset_tool_choice(model_settings, tools1)
tracker = AgentToolUseTracker()
tracker.add_tool_use(agent, ["tool1"])
new_settings = RunImpl.maybe_reset_tool_choice(agent, tracker, model_settings)
assert new_settings.tool_choice is None
# Case 5: tool_choice = "required" with multiple tools should not reset
# Case 5: tool_choice = "required" with multiple tools should reset
model_settings = ModelSettings(tool_choice="required")
tools2: list[Tool] = [get_function_tool("tool1"), get_function_tool("tool2")]
assert not RunImpl._should_reset_tool_choice(model_settings, tools2)
tracker = AgentToolUseTracker()
tracker.add_tool_use(agent, ["tool1", "tool2"])
new_settings = RunImpl.maybe_reset_tool_choice(agent, tracker, model_settings)
assert new_settings.tool_choice is None
# Case 6: Specific tool choice should reset
model_settings = ModelSettings(tool_choice="specific_tool")
assert RunImpl._should_reset_tool_choice(model_settings, tools1)
# Case 6: Tool usage on a different agent should not affect the tool choice
model_settings = ModelSettings(tool_choice="foo_bar")
tracker = AgentToolUseTracker()
tracker.add_tool_use(Agent(name="other_agent"), ["foo_bar", "baz"])
new_settings = RunImpl.maybe_reset_tool_choice(agent, tracker, model_settings)
assert new_settings.tool_choice == model_settings.tool_choice
# Case 7: tool_choice = "foo_bar" with multiple tools should reset
model_settings = ModelSettings(tool_choice="foo_bar")
tracker = AgentToolUseTracker()
tracker.add_tool_use(agent, ["foo_bar", "baz"])
new_settings = RunImpl.maybe_reset_tool_choice(agent, tracker, model_settings)
assert new_settings.tool_choice is None
@pytest.mark.asyncio
async def test_required_tool_choice_with_multiple_runs(self):
"""
Test scenario 1: When multiple runs are executed with tool_choice="required"
Ensure each run works correctly and doesn't get stuck in infinite loop
Also verify that tool_choice remains "required" between runs
Test scenario 1: When multiple runs are executed with tool_choice="required", ensure each
run works correctly and doesn't get stuck in an infinite loop. Also verify that tool_choice
remains "required" between runs.
"""
# Set up our fake model with responses for two runs
fake_model = FakeModel()
fake_model.add_multiple_turn_outputs([
[get_text_message("First run response")],
[get_text_message("Second run response")]
])
fake_model.add_multiple_turn_outputs(
[[get_text_message("First run response")], [get_text_message("Second run response")]]
)
# Create agent with a custom tool and tool_choice="required"
custom_tool = get_function_tool("custom_tool")
@ -71,24 +86,26 @@ class TestToolChoiceReset:
# First run should work correctly and preserve tool_choice
result1 = await Runner.run(agent, "first run")
assert result1.final_output == "First run response"
assert agent.model_settings.tool_choice == "required", "tool_choice should stay required"
assert fake_model.last_turn_args["model_settings"].tool_choice == "required", (
"tool_choice should stay required"
)
# Second run should also work correctly with tool_choice still required
result2 = await Runner.run(agent, "second run")
assert result2.final_output == "Second run response"
assert agent.model_settings.tool_choice == "required", "tool_choice should stay required"
assert fake_model.last_turn_args["model_settings"].tool_choice == "required", (
"tool_choice should stay required"
)
@pytest.mark.asyncio
async def test_required_with_stop_at_tool_name(self):
"""
Test scenario 2: When using required tool_choice with stop_at_tool_names behavior
Ensure it correctly stops at the specified tool
Test scenario 2: When using required tool_choice with stop_at_tool_names behavior, ensure
it correctly stops at the specified tool
"""
# Set up fake model to return a tool call for second_tool
fake_model = FakeModel()
fake_model.set_next_output([
get_function_tool_call("second_tool", "{}")
])
fake_model.set_next_output([get_function_tool_call("second_tool", "{}")])
# Create agent with two tools and tool_choice="required" and stop_at_tool behavior
first_tool = get_function_tool("first_tool", return_value="first tool result")
@ -109,8 +126,8 @@ class TestToolChoiceReset:
@pytest.mark.asyncio
async def test_specific_tool_choice(self):
"""
Test scenario 3: When using a specific tool choice name
Ensure it doesn't cause infinite loops
Test scenario 3: When using a specific tool choice name, ensure it doesn't cause infinite
loops.
"""
# Set up fake model to return a text message
fake_model = FakeModel()
@ -135,17 +152,19 @@ class TestToolChoiceReset:
@pytest.mark.asyncio
async def test_required_with_single_tool(self):
"""
Test scenario 4: When using required tool_choice with only one tool
Ensure it doesn't cause infinite loops
Test scenario 4: When using required tool_choice with only one tool, ensure it doesn't cause
infinite loops.
"""
# Set up fake model to return a tool call followed by a text message
fake_model = FakeModel()
fake_model.add_multiple_turn_outputs([
# First call returns a tool call
[get_function_tool_call("custom_tool", "{}")],
# Second call returns a text message
[get_text_message("Final response")]
])
fake_model.add_multiple_turn_outputs(
[
# First call returns a tool call
[get_function_tool_call("custom_tool", "{}")],
# Second call returns a text message
[get_text_message("Final response")],
]
)
# Create agent with a single tool and tool_choice="required"
custom_tool = get_function_tool("custom_tool", return_value="tool result")
@ -159,3 +178,33 @@ class TestToolChoiceReset:
# Run should complete without infinite loops
result = await Runner.run(agent, "first run")
assert result.final_output == "Final response"
@pytest.mark.asyncio
async def test_dont_reset_tool_choice_if_not_required(self):
"""
Test scenario 5: When agent.reset_tool_choice is False, ensure tool_choice is not reset.
"""
# Set up fake model to return a tool call followed by a text message
fake_model = FakeModel()
fake_model.add_multiple_turn_outputs(
[
# First call returns a tool call
[get_function_tool_call("custom_tool", "{}")],
# Second call returns a text message
[get_text_message("Final response")],
]
)
# Create agent with a single tool and tool_choice="required" and reset_tool_choice=False
custom_tool = get_function_tool("custom_tool", return_value="tool result")
agent = Agent(
name="test_agent",
model=fake_model,
tools=[custom_tool],
model_settings=ModelSettings(tool_choice="required"),
reset_tool_choice=False,
)
await Runner.run(agent, "test")
assert fake_model.last_turn_args["model_settings"].tool_choice == "required"