Make the reset behavior on tool use configurable
## Summary: #263 added this behavior. The goal was to prevent infinite loops when tool choice was set. The key change I'm making is: 1. Making it configurable on the agent. 2. Doing bookkeeping in the Runner to track this, to prevent mutating agents. 3. Not resetting the global tool choice in RunConfig. ## Test Plan: Unit tests. .
This commit is contained in:
parent
362a9dc078
commit
6fb5792b77
7 changed files with 172 additions and 96 deletions
|
|
@ -142,11 +142,6 @@ Supplying a list of tools doesn't always mean the LLM will use a tool. You can f
|
|||
|
||||
!!! note
|
||||
|
||||
To prevent infinite loops, the framework automatically resets `tool_choice` to "auto" after a tool call in the following scenarios:
|
||||
|
||||
1. When `tool_choice` is set to a specific function name (any string that's not "auto", "required", or "none")
|
||||
2. When `tool_choice` is set to "required" AND there is only one tool available
|
||||
|
||||
This targeted reset mechanism allows the model to decide whether to make additional tool calls in subsequent turns while avoiding infinite loops in these specific cases.
|
||||
|
||||
To prevent infinite loops, the framework automatically resets `tool_choice` to "auto" after a tool call. This behavior is configurable via [`agent.reset_tool_choice`][agents.agent.Agent.reset_tool_choice]. The infinite loop is because tool results are sent to the LLM, which then generates another tool call because of `tool_choice`, ad infinitum.
|
||||
|
||||
If you want the Agent to completely stop after a tool call (rather than continuing with auto mode), you can set [`Agent.tool_use_behavior="stop_on_first_tool"`] which will directly use the tool output as the final response without further LLM processing.
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ import asyncio
|
|||
import dataclasses
|
||||
import inspect
|
||||
from collections.abc import Awaitable
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from openai.types.responses import (
|
||||
|
|
@ -77,6 +77,23 @@ QUEUE_COMPLETE_SENTINEL = QueueCompleteSentinel()
|
|||
_NOT_FINAL_OUTPUT = ToolsToFinalOutputResult(is_final_output=False, final_output=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentToolUseTracker:
|
||||
agent_to_tools: list[tuple[Agent, list[str]]] = field(default_factory=list)
|
||||
"""Tuple of (agent, list of tools used). Can't use a dict because agents aren't hashable."""
|
||||
|
||||
def add_tool_use(self, agent: Agent[Any], tool_names: list[str]) -> None:
|
||||
existing_data = next((item for item in self.agent_to_tools if item[0] == agent), None)
|
||||
if existing_data:
|
||||
existing_data[1].extend(tool_names)
|
||||
else:
|
||||
self.agent_to_tools.append((agent, tool_names))
|
||||
|
||||
def has_used_tools(self, agent: Agent[Any]) -> bool:
|
||||
existing_data = next((item for item in self.agent_to_tools if item[0] == agent), None)
|
||||
return existing_data is not None and len(existing_data[1]) > 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolRunHandoff:
|
||||
handoff: Handoff
|
||||
|
|
@ -101,6 +118,7 @@ class ProcessedResponse:
|
|||
handoffs: list[ToolRunHandoff]
|
||||
functions: list[ToolRunFunction]
|
||||
computer_actions: list[ToolRunComputerAction]
|
||||
tools_used: list[str] # Names of all tools used, including hosted tools
|
||||
|
||||
def has_tools_to_run(self) -> bool:
|
||||
# Handoffs, functions and computer actions need local processing
|
||||
|
|
@ -208,29 +226,6 @@ class RunImpl:
|
|||
new_step_items.extend([result.run_item for result in function_results])
|
||||
new_step_items.extend(computer_results)
|
||||
|
||||
# Reset tool_choice to "auto" after tool execution to prevent infinite loops
|
||||
if processed_response.functions or processed_response.computer_actions:
|
||||
tools = agent.tools
|
||||
|
||||
if (
|
||||
run_config.model_settings and
|
||||
cls._should_reset_tool_choice(run_config.model_settings, tools)
|
||||
):
|
||||
# update the run_config model settings with a copy
|
||||
new_run_config_settings = dataclasses.replace(
|
||||
run_config.model_settings,
|
||||
tool_choice="auto"
|
||||
)
|
||||
run_config = dataclasses.replace(run_config, model_settings=new_run_config_settings)
|
||||
|
||||
if cls._should_reset_tool_choice(agent.model_settings, tools):
|
||||
# Create a modified copy instead of modifying the original agent
|
||||
new_model_settings = dataclasses.replace(
|
||||
agent.model_settings,
|
||||
tool_choice="auto"
|
||||
)
|
||||
agent = dataclasses.replace(agent, model_settings=new_model_settings)
|
||||
|
||||
# Second, check if there are any handoffs
|
||||
if run_handoffs := processed_response.handoffs:
|
||||
return await cls.execute_handoffs(
|
||||
|
|
@ -322,22 +317,16 @@ class RunImpl:
|
|||
)
|
||||
|
||||
@classmethod
|
||||
def _should_reset_tool_choice(cls, model_settings: ModelSettings, tools: list[Tool]) -> bool:
|
||||
if model_settings is None or model_settings.tool_choice is None:
|
||||
return False
|
||||
def maybe_reset_tool_choice(
|
||||
cls, agent: Agent[Any], tool_use_tracker: AgentToolUseTracker, model_settings: ModelSettings
|
||||
) -> ModelSettings:
|
||||
"""Resets tool choice to None if the agent has used tools and the agent's reset_tool_choice
|
||||
flag is True."""
|
||||
|
||||
# for specific tool choices
|
||||
if (
|
||||
isinstance(model_settings.tool_choice, str) and
|
||||
model_settings.tool_choice not in ["auto", "required", "none"]
|
||||
):
|
||||
return True
|
||||
if agent.reset_tool_choice is True and tool_use_tracker.has_used_tools(agent):
|
||||
return dataclasses.replace(model_settings, tool_choice=None)
|
||||
|
||||
# for one tool and required tool choice
|
||||
if model_settings.tool_choice == "required":
|
||||
return len(tools) == 1
|
||||
|
||||
return False
|
||||
return model_settings
|
||||
|
||||
@classmethod
|
||||
def process_model_response(
|
||||
|
|
@ -354,7 +343,7 @@ class RunImpl:
|
|||
run_handoffs = []
|
||||
functions = []
|
||||
computer_actions = []
|
||||
|
||||
tools_used: list[str] = []
|
||||
handoff_map = {handoff.tool_name: handoff for handoff in handoffs}
|
||||
function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)}
|
||||
computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None)
|
||||
|
|
@ -364,12 +353,15 @@ class RunImpl:
|
|||
items.append(MessageOutputItem(raw_item=output, agent=agent))
|
||||
elif isinstance(output, ResponseFileSearchToolCall):
|
||||
items.append(ToolCallItem(raw_item=output, agent=agent))
|
||||
tools_used.append("file_search")
|
||||
elif isinstance(output, ResponseFunctionWebSearch):
|
||||
items.append(ToolCallItem(raw_item=output, agent=agent))
|
||||
tools_used.append("web_search")
|
||||
elif isinstance(output, ResponseReasoningItem):
|
||||
items.append(ReasoningItem(raw_item=output, agent=agent))
|
||||
elif isinstance(output, ResponseComputerToolCall):
|
||||
items.append(ToolCallItem(raw_item=output, agent=agent))
|
||||
tools_used.append("computer_use")
|
||||
if not computer_tool:
|
||||
_error_tracing.attach_error_to_current_span(
|
||||
SpanError(
|
||||
|
|
@ -391,6 +383,8 @@ class RunImpl:
|
|||
if not isinstance(output, ResponseFunctionToolCall):
|
||||
continue
|
||||
|
||||
tools_used.append(output.name)
|
||||
|
||||
# Handoffs
|
||||
if output.name in handoff_map:
|
||||
items.append(HandoffCallItem(raw_item=output, agent=agent))
|
||||
|
|
@ -422,6 +416,7 @@ class RunImpl:
|
|||
handoffs=run_handoffs,
|
||||
functions=functions,
|
||||
computer_actions=computer_actions,
|
||||
tools_used=tools_used,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -155,6 +155,10 @@ class Agent(Generic[TContext]):
|
|||
web search, etc are always processed by the LLM.
|
||||
"""
|
||||
|
||||
reset_tool_choice: bool = True
|
||||
"""Whether to reset the tool choice to the default value after a tool has been called. Defaults
|
||||
to True. This ensures that the agent doesn't enter an infinite loop of tool usage."""
|
||||
|
||||
def clone(self, **kwargs: Any) -> Agent[TContext]:
|
||||
"""Make a copy of the agent, with the given arguments changed. For example, you could do:
|
||||
```
|
||||
|
|
|
|||
|
|
@ -208,8 +208,10 @@ class OpenAIResponsesModel(Model):
|
|||
list_input = ItemHelpers.input_to_new_input_list(input)
|
||||
|
||||
parallel_tool_calls = (
|
||||
True if model_settings.parallel_tool_calls and tools and len(tools) > 0
|
||||
else False if model_settings.parallel_tool_calls is False
|
||||
True
|
||||
if model_settings.parallel_tool_calls and tools and len(tools) > 0
|
||||
else False
|
||||
if model_settings.parallel_tool_calls is False
|
||||
else NOT_GIVEN
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from openai.types.responses import ResponseCompletedEvent
|
|||
from agents.tool import Tool
|
||||
|
||||
from ._run_impl import (
|
||||
AgentToolUseTracker,
|
||||
NextStepFinalOutput,
|
||||
NextStepHandoff,
|
||||
NextStepRunAgain,
|
||||
|
|
@ -151,6 +152,8 @@ class Runner:
|
|||
if run_config is None:
|
||||
run_config = RunConfig()
|
||||
|
||||
tool_use_tracker = AgentToolUseTracker()
|
||||
|
||||
with TraceCtxManager(
|
||||
workflow_name=run_config.workflow_name,
|
||||
trace_id=run_config.trace_id,
|
||||
|
|
@ -227,6 +230,7 @@ class Runner:
|
|||
context_wrapper=context_wrapper,
|
||||
run_config=run_config,
|
||||
should_run_agent_start_hooks=should_run_agent_start_hooks,
|
||||
tool_use_tracker=tool_use_tracker,
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
|
@ -239,6 +243,7 @@ class Runner:
|
|||
context_wrapper=context_wrapper,
|
||||
run_config=run_config,
|
||||
should_run_agent_start_hooks=should_run_agent_start_hooks,
|
||||
tool_use_tracker=tool_use_tracker,
|
||||
)
|
||||
should_run_agent_start_hooks = False
|
||||
|
||||
|
|
@ -486,6 +491,7 @@ class Runner:
|
|||
current_agent = starting_agent
|
||||
current_turn = 0
|
||||
should_run_agent_start_hooks = True
|
||||
tool_use_tracker = AgentToolUseTracker()
|
||||
|
||||
streamed_result._event_queue.put_nowait(AgentUpdatedStreamEvent(new_agent=current_agent))
|
||||
|
||||
|
|
@ -546,6 +552,7 @@ class Runner:
|
|||
context_wrapper,
|
||||
run_config,
|
||||
should_run_agent_start_hooks,
|
||||
tool_use_tracker,
|
||||
)
|
||||
should_run_agent_start_hooks = False
|
||||
|
||||
|
|
@ -613,6 +620,7 @@ class Runner:
|
|||
context_wrapper: RunContextWrapper[TContext],
|
||||
run_config: RunConfig,
|
||||
should_run_agent_start_hooks: bool,
|
||||
tool_use_tracker: AgentToolUseTracker,
|
||||
) -> SingleStepResult:
|
||||
if should_run_agent_start_hooks:
|
||||
await asyncio.gather(
|
||||
|
|
@ -635,6 +643,8 @@ class Runner:
|
|||
all_tools = await cls._get_all_tools(agent)
|
||||
model = cls._get_model(agent, run_config)
|
||||
model_settings = agent.model_settings.resolve(run_config.model_settings)
|
||||
model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings)
|
||||
|
||||
final_response: ModelResponse | None = None
|
||||
|
||||
input = ItemHelpers.input_to_new_input_list(streamed_result.input)
|
||||
|
|
@ -687,6 +697,7 @@ class Runner:
|
|||
hooks=hooks,
|
||||
context_wrapper=context_wrapper,
|
||||
run_config=run_config,
|
||||
tool_use_tracker=tool_use_tracker,
|
||||
)
|
||||
|
||||
RunImpl.stream_step_result_to_queue(single_step_result, streamed_result._event_queue)
|
||||
|
|
@ -704,6 +715,7 @@ class Runner:
|
|||
context_wrapper: RunContextWrapper[TContext],
|
||||
run_config: RunConfig,
|
||||
should_run_agent_start_hooks: bool,
|
||||
tool_use_tracker: AgentToolUseTracker,
|
||||
) -> SingleStepResult:
|
||||
# Ensure we run the hooks before anything else
|
||||
if should_run_agent_start_hooks:
|
||||
|
|
@ -732,6 +744,7 @@ class Runner:
|
|||
handoffs,
|
||||
context_wrapper,
|
||||
run_config,
|
||||
tool_use_tracker,
|
||||
)
|
||||
|
||||
return await cls._get_single_step_result_from_response(
|
||||
|
|
@ -745,6 +758,7 @@ class Runner:
|
|||
hooks=hooks,
|
||||
context_wrapper=context_wrapper,
|
||||
run_config=run_config,
|
||||
tool_use_tracker=tool_use_tracker,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
@ -761,6 +775,7 @@ class Runner:
|
|||
hooks: RunHooks[TContext],
|
||||
context_wrapper: RunContextWrapper[TContext],
|
||||
run_config: RunConfig,
|
||||
tool_use_tracker: AgentToolUseTracker,
|
||||
) -> SingleStepResult:
|
||||
processed_response = RunImpl.process_model_response(
|
||||
agent=agent,
|
||||
|
|
@ -769,6 +784,9 @@ class Runner:
|
|||
output_schema=output_schema,
|
||||
handoffs=handoffs,
|
||||
)
|
||||
|
||||
tool_use_tracker.add_tool_use(agent, processed_response.tools_used)
|
||||
|
||||
return await RunImpl.execute_tools_and_side_effects(
|
||||
agent=agent,
|
||||
original_input=original_input,
|
||||
|
|
@ -868,9 +886,12 @@ class Runner:
|
|||
handoffs: list[Handoff],
|
||||
context_wrapper: RunContextWrapper[TContext],
|
||||
run_config: RunConfig,
|
||||
tool_use_tracker: AgentToolUseTracker,
|
||||
) -> ModelResponse:
|
||||
model = cls._get_model(agent, run_config)
|
||||
model_settings = agent.model_settings.resolve(run_config.model_settings)
|
||||
model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings)
|
||||
|
||||
new_response = await model.get_response(
|
||||
system_instructions=system_prompt,
|
||||
input=input,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from openai.types.responses import Response, ResponseCompletedEvent
|
||||
|
||||
|
|
@ -31,6 +32,7 @@ class FakeModel(Model):
|
|||
[initial_output] if initial_output else []
|
||||
)
|
||||
self.tracing_enabled = tracing_enabled
|
||||
self.last_turn_args: dict[str, Any] = {}
|
||||
|
||||
def set_next_output(self, output: list[TResponseOutputItem] | Exception):
|
||||
self.turn_outputs.append(output)
|
||||
|
|
@ -53,6 +55,14 @@ class FakeModel(Model):
|
|||
handoffs: list[Handoff],
|
||||
tracing: ModelTracing,
|
||||
) -> ModelResponse:
|
||||
self.last_turn_args = {
|
||||
"system_instructions": system_instructions,
|
||||
"input": input,
|
||||
"model_settings": model_settings,
|
||||
"tools": tools,
|
||||
"output_schema": output_schema,
|
||||
}
|
||||
|
||||
with generation_span(disabled=not self.tracing_enabled) as span:
|
||||
output = self.get_next_output()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,63 +1,78 @@
|
|||
import pytest
|
||||
|
||||
from agents import Agent, ModelSettings, Runner, Tool
|
||||
from agents._run_impl import RunImpl
|
||||
from agents import Agent, ModelSettings, Runner
|
||||
from agents._run_impl import AgentToolUseTracker, RunImpl
|
||||
|
||||
from .fake_model import FakeModel
|
||||
from .test_responses import (
|
||||
get_function_tool,
|
||||
get_function_tool_call,
|
||||
get_text_message,
|
||||
)
|
||||
from .test_responses import get_function_tool, get_function_tool_call, get_text_message
|
||||
|
||||
|
||||
class TestToolChoiceReset:
|
||||
|
||||
def test_should_reset_tool_choice_direct(self):
|
||||
"""
|
||||
Test the _should_reset_tool_choice method directly with various inputs
|
||||
to ensure it correctly identifies cases where reset is needed.
|
||||
"""
|
||||
# Case 1: tool_choice = None should not reset
|
||||
agent = Agent(name="test_agent")
|
||||
|
||||
# Case 1: Empty tool use tracker should not change the "None" tool choice
|
||||
model_settings = ModelSettings(tool_choice=None)
|
||||
tools1: list[Tool] = [get_function_tool("tool1")]
|
||||
# Cast to list[Tool] to fix type checking issues
|
||||
assert not RunImpl._should_reset_tool_choice(model_settings, tools1)
|
||||
tracker = AgentToolUseTracker()
|
||||
new_settings = RunImpl.maybe_reset_tool_choice(agent, tracker, model_settings)
|
||||
assert new_settings.tool_choice == model_settings.tool_choice
|
||||
|
||||
# Case 2: tool_choice = "auto" should not reset
|
||||
# Case 2: Empty tool use tracker should not change the "auto" tool choice
|
||||
model_settings = ModelSettings(tool_choice="auto")
|
||||
assert not RunImpl._should_reset_tool_choice(model_settings, tools1)
|
||||
tracker = AgentToolUseTracker()
|
||||
new_settings = RunImpl.maybe_reset_tool_choice(agent, tracker, model_settings)
|
||||
assert model_settings.tool_choice == new_settings.tool_choice
|
||||
|
||||
# Case 3: tool_choice = "none" should not reset
|
||||
model_settings = ModelSettings(tool_choice="none")
|
||||
assert not RunImpl._should_reset_tool_choice(model_settings, tools1)
|
||||
# Case 3: Empty tool use tracker should not change the "required" tool choice
|
||||
model_settings = ModelSettings(tool_choice="required")
|
||||
tracker = AgentToolUseTracker()
|
||||
new_settings = RunImpl.maybe_reset_tool_choice(agent, tracker, model_settings)
|
||||
assert model_settings.tool_choice == new_settings.tool_choice
|
||||
|
||||
# Case 4: tool_choice = "required" with one tool should reset
|
||||
model_settings = ModelSettings(tool_choice="required")
|
||||
assert RunImpl._should_reset_tool_choice(model_settings, tools1)
|
||||
tracker = AgentToolUseTracker()
|
||||
tracker.add_tool_use(agent, ["tool1"])
|
||||
new_settings = RunImpl.maybe_reset_tool_choice(agent, tracker, model_settings)
|
||||
assert new_settings.tool_choice is None
|
||||
|
||||
# Case 5: tool_choice = "required" with multiple tools should not reset
|
||||
# Case 5: tool_choice = "required" with multiple tools should reset
|
||||
model_settings = ModelSettings(tool_choice="required")
|
||||
tools2: list[Tool] = [get_function_tool("tool1"), get_function_tool("tool2")]
|
||||
assert not RunImpl._should_reset_tool_choice(model_settings, tools2)
|
||||
tracker = AgentToolUseTracker()
|
||||
tracker.add_tool_use(agent, ["tool1", "tool2"])
|
||||
new_settings = RunImpl.maybe_reset_tool_choice(agent, tracker, model_settings)
|
||||
assert new_settings.tool_choice is None
|
||||
|
||||
# Case 6: Specific tool choice should reset
|
||||
model_settings = ModelSettings(tool_choice="specific_tool")
|
||||
assert RunImpl._should_reset_tool_choice(model_settings, tools1)
|
||||
# Case 6: Tool usage on a different agent should not affect the tool choice
|
||||
model_settings = ModelSettings(tool_choice="foo_bar")
|
||||
tracker = AgentToolUseTracker()
|
||||
tracker.add_tool_use(Agent(name="other_agent"), ["foo_bar", "baz"])
|
||||
new_settings = RunImpl.maybe_reset_tool_choice(agent, tracker, model_settings)
|
||||
assert new_settings.tool_choice == model_settings.tool_choice
|
||||
|
||||
# Case 7: tool_choice = "foo_bar" with multiple tools should reset
|
||||
model_settings = ModelSettings(tool_choice="foo_bar")
|
||||
tracker = AgentToolUseTracker()
|
||||
tracker.add_tool_use(agent, ["foo_bar", "baz"])
|
||||
new_settings = RunImpl.maybe_reset_tool_choice(agent, tracker, model_settings)
|
||||
assert new_settings.tool_choice is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_required_tool_choice_with_multiple_runs(self):
|
||||
"""
|
||||
Test scenario 1: When multiple runs are executed with tool_choice="required"
|
||||
Ensure each run works correctly and doesn't get stuck in infinite loop
|
||||
Also verify that tool_choice remains "required" between runs
|
||||
Test scenario 1: When multiple runs are executed with tool_choice="required", ensure each
|
||||
run works correctly and doesn't get stuck in an infinite loop. Also verify that tool_choice
|
||||
remains "required" between runs.
|
||||
"""
|
||||
# Set up our fake model with responses for two runs
|
||||
fake_model = FakeModel()
|
||||
fake_model.add_multiple_turn_outputs([
|
||||
[get_text_message("First run response")],
|
||||
[get_text_message("Second run response")]
|
||||
])
|
||||
fake_model.add_multiple_turn_outputs(
|
||||
[[get_text_message("First run response")], [get_text_message("Second run response")]]
|
||||
)
|
||||
|
||||
# Create agent with a custom tool and tool_choice="required"
|
||||
custom_tool = get_function_tool("custom_tool")
|
||||
|
|
@ -71,24 +86,26 @@ class TestToolChoiceReset:
|
|||
# First run should work correctly and preserve tool_choice
|
||||
result1 = await Runner.run(agent, "first run")
|
||||
assert result1.final_output == "First run response"
|
||||
assert agent.model_settings.tool_choice == "required", "tool_choice should stay required"
|
||||
assert fake_model.last_turn_args["model_settings"].tool_choice == "required", (
|
||||
"tool_choice should stay required"
|
||||
)
|
||||
|
||||
# Second run should also work correctly with tool_choice still required
|
||||
result2 = await Runner.run(agent, "second run")
|
||||
assert result2.final_output == "Second run response"
|
||||
assert agent.model_settings.tool_choice == "required", "tool_choice should stay required"
|
||||
assert fake_model.last_turn_args["model_settings"].tool_choice == "required", (
|
||||
"tool_choice should stay required"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_required_with_stop_at_tool_name(self):
|
||||
"""
|
||||
Test scenario 2: When using required tool_choice with stop_at_tool_names behavior
|
||||
Ensure it correctly stops at the specified tool
|
||||
Test scenario 2: When using required tool_choice with stop_at_tool_names behavior, ensure
|
||||
it correctly stops at the specified tool
|
||||
"""
|
||||
# Set up fake model to return a tool call for second_tool
|
||||
fake_model = FakeModel()
|
||||
fake_model.set_next_output([
|
||||
get_function_tool_call("second_tool", "{}")
|
||||
])
|
||||
fake_model.set_next_output([get_function_tool_call("second_tool", "{}")])
|
||||
|
||||
# Create agent with two tools and tool_choice="required" and stop_at_tool behavior
|
||||
first_tool = get_function_tool("first_tool", return_value="first tool result")
|
||||
|
|
@ -109,8 +126,8 @@ class TestToolChoiceReset:
|
|||
@pytest.mark.asyncio
|
||||
async def test_specific_tool_choice(self):
|
||||
"""
|
||||
Test scenario 3: When using a specific tool choice name
|
||||
Ensure it doesn't cause infinite loops
|
||||
Test scenario 3: When using a specific tool choice name, ensure it doesn't cause infinite
|
||||
loops.
|
||||
"""
|
||||
# Set up fake model to return a text message
|
||||
fake_model = FakeModel()
|
||||
|
|
@ -135,17 +152,19 @@ class TestToolChoiceReset:
|
|||
@pytest.mark.asyncio
|
||||
async def test_required_with_single_tool(self):
|
||||
"""
|
||||
Test scenario 4: When using required tool_choice with only one tool
|
||||
Ensure it doesn't cause infinite loops
|
||||
Test scenario 4: When using required tool_choice with only one tool, ensure it doesn't cause
|
||||
infinite loops.
|
||||
"""
|
||||
# Set up fake model to return a tool call followed by a text message
|
||||
fake_model = FakeModel()
|
||||
fake_model.add_multiple_turn_outputs([
|
||||
# First call returns a tool call
|
||||
[get_function_tool_call("custom_tool", "{}")],
|
||||
# Second call returns a text message
|
||||
[get_text_message("Final response")]
|
||||
])
|
||||
fake_model.add_multiple_turn_outputs(
|
||||
[
|
||||
# First call returns a tool call
|
||||
[get_function_tool_call("custom_tool", "{}")],
|
||||
# Second call returns a text message
|
||||
[get_text_message("Final response")],
|
||||
]
|
||||
)
|
||||
|
||||
# Create agent with a single tool and tool_choice="required"
|
||||
custom_tool = get_function_tool("custom_tool", return_value="tool result")
|
||||
|
|
@ -159,3 +178,33 @@ class TestToolChoiceReset:
|
|||
# Run should complete without infinite loops
|
||||
result = await Runner.run(agent, "first run")
|
||||
assert result.final_output == "Final response"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dont_reset_tool_choice_if_not_required(self):
|
||||
"""
|
||||
Test scenario 5: When agent.reset_tool_choice is False, ensure tool_choice is not reset.
|
||||
"""
|
||||
# Set up fake model to return a tool call followed by a text message
|
||||
fake_model = FakeModel()
|
||||
fake_model.add_multiple_turn_outputs(
|
||||
[
|
||||
# First call returns a tool call
|
||||
[get_function_tool_call("custom_tool", "{}")],
|
||||
# Second call returns a text message
|
||||
[get_text_message("Final response")],
|
||||
]
|
||||
)
|
||||
|
||||
# Create agent with a single tool and tool_choice="required" and reset_tool_choice=False
|
||||
custom_tool = get_function_tool("custom_tool", return_value="tool result")
|
||||
agent = Agent(
|
||||
name="test_agent",
|
||||
model=fake_model,
|
||||
tools=[custom_tool],
|
||||
model_settings=ModelSettings(tool_choice="required"),
|
||||
reset_tool_choice=False,
|
||||
)
|
||||
|
||||
await Runner.run(agent, "test")
|
||||
|
||||
assert fake_model.last_turn_args["model_settings"].tool_choice == "required"
|
||||
|
|
|
|||
Loading…
Reference in a new issue