fix: optimize tool_choice reset logic and fix lint errors

- Refactor tool_choice reset to target only problematic edge cases
- Replace manual ModelSettings recreation with dataclasses.replace
- Fix line length and error handling lint issues in tests
This commit is contained in:
xianghuijin 2025-03-22 14:10:09 +08:00
parent d169d79288
commit bbcda753df
2 changed files with 137 additions and 147 deletions

View file

@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import dataclasses
import inspect
from collections.abc import Awaitable
from dataclasses import dataclass
@ -51,7 +52,7 @@ from .model_settings import ModelSettings
from .models.interface import ModelTracing
from .run_context import RunContextWrapper, TContext
from .stream_events import RunItemStreamEvent, StreamEvent
from .tool import ComputerTool, FunctionTool, FunctionToolResult
from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool
from .tracing import (
SpanError,
Trace,
@ -208,34 +209,22 @@ class RunImpl:
new_step_items.extend(computer_results)
# Reset tool_choice to "auto" after tool execution to prevent infinite loops
if (processed_response.functions or processed_response.computer_actions):
# Reset agent's model_settings
if agent.model_settings.tool_choice == "required" or isinstance(agent.model_settings.tool_choice, str):
# Create a new model_settings to avoid modifying the original shared instance
agent.model_settings = ModelSettings(
temperature=agent.model_settings.temperature,
top_p=agent.model_settings.top_p,
frequency_penalty=agent.model_settings.frequency_penalty,
presence_penalty=agent.model_settings.presence_penalty,
tool_choice="auto", # Reset to auto
parallel_tool_calls=agent.model_settings.parallel_tool_calls,
truncation=agent.model_settings.truncation,
max_tokens=agent.model_settings.max_tokens,
if processed_response.functions or processed_response.computer_actions:
tools = agent.tools
# Only reset in the problematic scenarios where loops are likely unintentional
if cls._should_reset_tool_choice(agent.model_settings, tools):
agent.model_settings = dataclasses.replace(
agent.model_settings,
tool_choice="auto"
)
# Also reset run_config's model_settings if it exists
if run_config.model_settings and (run_config.model_settings.tool_choice == "required" or
isinstance(run_config.model_settings.tool_choice, str)):
# Create a new model_settings for run_config
run_config.model_settings = ModelSettings(
temperature=run_config.model_settings.temperature,
top_p=run_config.model_settings.top_p,
frequency_penalty=run_config.model_settings.frequency_penalty,
presence_penalty=run_config.model_settings.presence_penalty,
tool_choice="auto", # Reset to auto
parallel_tool_calls=run_config.model_settings.parallel_tool_calls,
truncation=run_config.model_settings.truncation,
max_tokens=run_config.model_settings.max_tokens,
if (
run_config.model_settings and
cls._should_reset_tool_choice(run_config.model_settings, tools)
):
run_config.model_settings = dataclasses.replace(
run_config.model_settings,
tool_choice="auto"
)
# Second, check if there are any handoffs
@ -328,6 +317,24 @@ class RunImpl:
next_step=NextStepRunAgain(),
)
@classmethod
def _should_reset_tool_choice(
    cls, model_settings: ModelSettings | None, tools: list[Tool]
) -> bool:
    """Decide whether ``tool_choice`` should be reset to ``"auto"`` after a tool ran.

    The call site resets ``tool_choice`` after tool execution to prevent
    infinite loops, and only wants to do so where a loop is likely
    unintentional. This predicate encodes those cases.

    Args:
        model_settings: Settings whose ``tool_choice`` is inspected. ``None`` is
            accepted and treated as "nothing to reset".
        tools: The tools available to the agent; only its length is used.

    Returns:
        ``True`` when ``tool_choice`` is a string naming a specific tool (any
        string other than ``"auto"``, ``"required"`` or ``"none"``), or when it
        is ``"required"`` and exactly one tool is available. ``False`` in every
        other case, including an unset ``tool_choice``.
    """
    if model_settings is None or model_settings.tool_choice is None:
        return False

    choice = model_settings.tool_choice

    # "required" only pins the model to a single tool when there is exactly
    # one tool to pick from; with several tools it is left alone.
    if choice == "required":
        return len(tools) == 1

    # Any remaining string other than the generic modes names a specific tool,
    # which the model would be forced to call again on every turn.
    return isinstance(choice, str) and choice not in ("auto", "none")
@classmethod
def process_model_response(
cls,

View file

@ -1,13 +1,15 @@
from unittest import mock
import asyncio
import dataclasses
import json
from typing import List
from unittest import mock
from agents import Agent, ModelSettings, RunConfig, function_tool, Runner
from agents.models.interface import ModelResponse
from agents.items import Usage
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
from agents import Agent, ModelSettings, RunConfig, Runner, function_tool
from agents.items import Usage
from agents.models.interface import ModelResponse
from agents.tool import Tool
@function_tool
def echo(text: str) -> str:
@ -15,22 +17,39 @@ def echo(text: str) -> str:
return text
def should_reset_tool_choice(
    model_settings: "ModelSettings | None", tools: list[Tool]
) -> bool:
    """Return whether ``tool_choice`` should be reset to ``"auto"`` in these tests.

    Args:
        model_settings: Settings whose ``tool_choice`` is inspected. ``None`` is
            accepted and yields ``False``.
        tools: The agent's tools; only the number of entries matters.

    Returns:
        ``True`` if ``tool_choice`` is a string other than ``"auto"``,
        ``"required"`` or ``"none"`` (i.e. a specific tool name), or if it is
        ``"required"`` while exactly one tool is available; otherwise ``False``.
    """
    if model_settings is None or model_settings.tool_choice is None:
        return False

    choice = model_settings.tool_choice

    # With a single tool, "required" forces that same tool on every turn.
    if choice == "required":
        return len(tools) == 1

    # A string that is not one of the generic modes names a specific tool.
    return isinstance(choice, str) and choice not in ("auto", "none")
# Mock model implementation that always calls tools when tool_choice is set
class MockModel:
def __init__(self, tool_call_counter):
self.tool_call_counter = tool_call_counter
async def get_response(self, **kwargs):
tools = kwargs.get("tools", [])
model_settings = kwargs.get("model_settings")
# Increment the counter to track how many times this model is called
self.tool_call_counter["count"] += 1
# If we've been called many times, we're likely in an infinite loop
if self.tool_call_counter["count"] > 5:
self.tool_call_counter["potential_infinite_loop"] = True
# Always create a tool call if tool_choice is required/specific
tool_calls = []
if model_settings and model_settings.tool_choice:
@ -46,7 +65,7 @@ class MockModel:
type="function_call",
)
)
return ModelResponse(
output=tool_calls,
referenceable_id="123",
@ -60,7 +79,7 @@ class TestToolChoiceReset:
# Create an agent with tool_choice="required"
agent = Agent(
name="Test agent",
tools=[echo],
tools=[echo], # Only one tool
model_settings=ModelSettings(tool_choice="required"),
)
@ -77,31 +96,22 @@ class TestToolChoiceReset:
# Execute our code under test
if processed_response.functions:
# Reset agent's model_settings
if agent.model_settings.tool_choice == "required" or isinstance(agent.model_settings.tool_choice, str):
agent.model_settings = ModelSettings(
temperature=agent.model_settings.temperature,
top_p=agent.model_settings.top_p,
frequency_penalty=agent.model_settings.frequency_penalty,
presence_penalty=agent.model_settings.presence_penalty,
tool_choice="auto", # Reset to auto
parallel_tool_calls=agent.model_settings.parallel_tool_calls,
truncation=agent.model_settings.truncation,
max_tokens=agent.model_settings.max_tokens,
# Apply the targeted reset logic
tools = agent.tools
if should_reset_tool_choice(agent.model_settings, tools):
agent.model_settings = dataclasses.replace(
agent.model_settings,
tool_choice="auto" # Reset to auto
)
# Also reset run_config's model_settings if it exists
if run_config.model_settings and (run_config.model_settings.tool_choice == "required" or
isinstance(run_config.model_settings.tool_choice, str)):
run_config.model_settings = ModelSettings(
temperature=run_config.model_settings.temperature,
top_p=run_config.model_settings.top_p,
frequency_penalty=run_config.model_settings.frequency_penalty,
presence_penalty=run_config.model_settings.presence_penalty,
tool_choice="auto", # Reset to auto
parallel_tool_calls=run_config.model_settings.parallel_tool_calls,
truncation=run_config.model_settings.truncation,
max_tokens=run_config.model_settings.max_tokens,
if (
run_config.model_settings and
should_reset_tool_choice(run_config.model_settings, tools)
):
run_config.model_settings = dataclasses.replace(
run_config.model_settings,
tool_choice="auto" # Reset to auto
)
# Check that tool_choice was reset to "auto"
@ -115,7 +125,7 @@ class TestToolChoiceReset:
instructions="You are a test agent",
tools=[echo],
model="gpt-4-0125-preview",
model_settings=ModelSettings(tool_choice="echo"),
model_settings=ModelSettings(tool_choice="echo"), # Specific function name
)
# Execute our code under test
@ -129,31 +139,22 @@ class TestToolChoiceReset:
# Execute our code under test
if processed_response.functions:
# Reset agent's model_settings
if agent.model_settings.tool_choice == "required" or isinstance(agent.model_settings.tool_choice, str):
agent.model_settings = ModelSettings(
temperature=agent.model_settings.temperature,
top_p=agent.model_settings.top_p,
frequency_penalty=agent.model_settings.frequency_penalty,
presence_penalty=agent.model_settings.presence_penalty,
tool_choice="auto", # Reset to auto
parallel_tool_calls=agent.model_settings.parallel_tool_calls,
truncation=agent.model_settings.truncation,
max_tokens=agent.model_settings.max_tokens,
# Apply the targeted reset logic
tools = agent.tools
if should_reset_tool_choice(agent.model_settings, tools):
agent.model_settings = dataclasses.replace(
agent.model_settings,
tool_choice="auto" # Reset to auto
)
# Also reset run_config's model_settings if it exists
if run_config.model_settings and (run_config.model_settings.tool_choice == "required" or
isinstance(run_config.model_settings.tool_choice, str)):
run_config.model_settings = ModelSettings(
temperature=run_config.model_settings.temperature,
top_p=run_config.model_settings.top_p,
frequency_penalty=run_config.model_settings.frequency_penalty,
presence_penalty=run_config.model_settings.presence_penalty,
tool_choice="auto", # Reset to auto
parallel_tool_calls=run_config.model_settings.parallel_tool_calls,
truncation=run_config.model_settings.truncation,
max_tokens=run_config.model_settings.max_tokens,
if (
run_config.model_settings and
should_reset_tool_choice(run_config.model_settings, tools)
):
run_config.model_settings = dataclasses.replace(
run_config.model_settings,
tool_choice="auto" # Reset to auto
)
# Check that tool_choice was reset to "auto"
@ -179,49 +180,40 @@ class TestToolChoiceReset:
# Execute our code under test
if processed_response.functions:
# Reset agent's model_settings
if agent.model_settings.tool_choice == "required" or isinstance(agent.model_settings.tool_choice, str):
agent.model_settings = ModelSettings(
temperature=agent.model_settings.temperature,
top_p=agent.model_settings.top_p,
frequency_penalty=agent.model_settings.frequency_penalty,
presence_penalty=agent.model_settings.presence_penalty,
tool_choice="auto", # Reset to auto
parallel_tool_calls=agent.model_settings.parallel_tool_calls,
truncation=agent.model_settings.truncation,
max_tokens=agent.model_settings.max_tokens,
# Apply the targeted reset logic
tools = agent.tools
if should_reset_tool_choice(agent.model_settings, tools):
agent.model_settings = dataclasses.replace(
agent.model_settings,
tool_choice="auto" # Reset to auto
)
# Also reset run_config's model_settings if it exists
if run_config.model_settings and (run_config.model_settings.tool_choice == "required" or
isinstance(run_config.model_settings.tool_choice, str)):
run_config.model_settings = ModelSettings(
temperature=run_config.model_settings.temperature,
top_p=run_config.model_settings.top_p,
frequency_penalty=run_config.model_settings.frequency_penalty,
presence_penalty=run_config.model_settings.presence_penalty,
tool_choice="auto", # Reset to auto
parallel_tool_calls=run_config.model_settings.parallel_tool_calls,
truncation=run_config.model_settings.truncation,
max_tokens=run_config.model_settings.max_tokens,
if (
run_config.model_settings and
should_reset_tool_choice(run_config.model_settings, tools)
):
run_config.model_settings = dataclasses.replace(
run_config.model_settings,
tool_choice="auto" # Reset to auto
)
# Check that tool_choice remains "auto"
assert agent.model_settings.tool_choice == "auto"
async def test_run_config_tool_choice_reset(self):
"""Test that run_config.model_settings.tool_choice is reset to 'auto'"""
# Create an agent with default model_settings
agent = Agent(
name="Test agent",
tools=[echo],
tools=[echo], # Only one tool
model_settings=ModelSettings(tool_choice=None),
)
# Create a run_config with tool_choice="required"
run_config = RunConfig()
run_config.model_settings = ModelSettings(tool_choice="required")
# Execute our code under test
processed_response = mock.MagicMock()
processed_response.functions = [mock.MagicMock()] # At least one function call
@ -229,47 +221,38 @@ class TestToolChoiceReset:
# Execute our code under test
if processed_response.functions:
# Reset agent's model_settings
if agent.model_settings.tool_choice == "required" or isinstance(agent.model_settings.tool_choice, str):
agent.model_settings = ModelSettings(
temperature=agent.model_settings.temperature,
top_p=agent.model_settings.top_p,
frequency_penalty=agent.model_settings.frequency_penalty,
presence_penalty=agent.model_settings.presence_penalty,
tool_choice="auto", # Reset to auto
parallel_tool_calls=agent.model_settings.parallel_tool_calls,
truncation=agent.model_settings.truncation,
max_tokens=agent.model_settings.max_tokens,
# Apply the targeted reset logic
tools = agent.tools
if should_reset_tool_choice(agent.model_settings, tools):
agent.model_settings = dataclasses.replace(
agent.model_settings,
tool_choice="auto" # Reset to auto
)
# Also reset run_config's model_settings if it exists
if run_config.model_settings and (run_config.model_settings.tool_choice == "required" or
isinstance(run_config.model_settings.tool_choice, str)):
run_config.model_settings = ModelSettings(
temperature=run_config.model_settings.temperature,
top_p=run_config.model_settings.top_p,
frequency_penalty=run_config.model_settings.frequency_penalty,
presence_penalty=run_config.model_settings.presence_penalty,
tool_choice="auto", # Reset to auto
parallel_tool_calls=run_config.model_settings.parallel_tool_calls,
truncation=run_config.model_settings.truncation,
max_tokens=run_config.model_settings.max_tokens,
if (
run_config.model_settings and
should_reset_tool_choice(run_config.model_settings, tools)
):
run_config.model_settings = dataclasses.replace(
run_config.model_settings,
tool_choice="auto" # Reset to auto
)
# Check that run_config's tool_choice was reset to "auto"
assert run_config.model_settings.tool_choice == "auto"
@mock.patch("agents.run.Runner._get_model")
async def test_integration_prevents_infinite_loop(self, mock_get_model):
"""Integration test to verify that tool_choice reset prevents infinite loops"""
# Create a counter to track model calls and detect potential infinite loops
tool_call_counter = {"count": 0, "potential_infinite_loop": False}
# Set up our mock model that will always use tools when tool_choice is set
mock_model_instance = MockModel(tool_call_counter)
# Return our mock model directly
mock_get_model.return_value = mock_model_instance
# Create an agent with tool_choice="required" to force tool usage
agent = Agent(
name="Test agent",
@ -280,24 +263,24 @@ class TestToolChoiceReset:
# This would cause infinite loops without the tool_choice reset
tool_use_behavior="run_llm_again",
)
# Set a timeout to catch potential infinite loops that our fix doesn't address
try:
# Run the agent with a timeout
async def run_with_timeout():
return await Runner.run(agent, input="Test input")
result = await asyncio.wait_for(run_with_timeout(), timeout=2.0)
# Verify the agent ran successfully
assert result is not None
# Verify the tool was called at least once but not too many times
# (indicating no infinite loop)
assert tool_call_counter["count"] >= 1
assert tool_call_counter["count"] < 5
assert not tool_call_counter["potential_infinite_loop"]
except asyncio.TimeoutError:
except asyncio.TimeoutError as err:
# If we hit a timeout, the test failed - we likely have an infinite loop
assert False, "Timeout occurred, potential infinite loop detected"
raise AssertionError("Timeout occurred, potential infinite loop detected") from err