diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index a60ae1d..1f896d7 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import dataclasses import inspect from collections.abc import Awaitable from dataclasses import dataclass @@ -51,7 +52,7 @@ from .model_settings import ModelSettings from .models.interface import ModelTracing from .run_context import RunContextWrapper, TContext from .stream_events import RunItemStreamEvent, StreamEvent -from .tool import ComputerTool, FunctionTool, FunctionToolResult +from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool from .tracing import ( SpanError, Trace, @@ -208,34 +209,22 @@ class RunImpl: new_step_items.extend(computer_results) # Reset tool_choice to "auto" after tool execution to prevent infinite loops - if (processed_response.functions or processed_response.computer_actions): - # Reset agent's model_settings - if agent.model_settings.tool_choice == "required" or isinstance(agent.model_settings.tool_choice, str): - # Create a new model_settings to avoid modifying the original shared instance - agent.model_settings = ModelSettings( - temperature=agent.model_settings.temperature, - top_p=agent.model_settings.top_p, - frequency_penalty=agent.model_settings.frequency_penalty, - presence_penalty=agent.model_settings.presence_penalty, - tool_choice="auto", # Reset to auto - parallel_tool_calls=agent.model_settings.parallel_tool_calls, - truncation=agent.model_settings.truncation, - max_tokens=agent.model_settings.max_tokens, + if processed_response.functions or processed_response.computer_actions: + tools = agent.tools + # Only reset in the problematic scenarios where loops are likely unintentional + if cls._should_reset_tool_choice(agent.model_settings, tools): + agent.model_settings = dataclasses.replace( + agent.model_settings, + tool_choice="auto" ) - - # Also reset run_config's model_settings if it exists - if run_config.model_settings and (run_config.model_settings.tool_choice == "required" or - isinstance(run_config.model_settings.tool_choice, str)): - # Create a new model_settings for run_config - run_config.model_settings = ModelSettings( - temperature=run_config.model_settings.temperature, - top_p=run_config.model_settings.top_p, - frequency_penalty=run_config.model_settings.frequency_penalty, - presence_penalty=run_config.model_settings.presence_penalty, - tool_choice="auto", # Reset to auto - parallel_tool_calls=run_config.model_settings.parallel_tool_calls, - truncation=run_config.model_settings.truncation, - max_tokens=run_config.model_settings.max_tokens, + + if ( + run_config.model_settings and + cls._should_reset_tool_choice(run_config.model_settings, tools) + ): + run_config.model_settings = dataclasses.replace( + run_config.model_settings, + tool_choice="auto" ) # Second, check if there are any handoffs @@ -328,6 +317,24 @@ class RunImpl: next_step=NextStepRunAgain(), ) + @classmethod + def _should_reset_tool_choice(cls, model_settings: ModelSettings, tools: list[Tool]) -> bool: + if model_settings is None or model_settings.tool_choice is None: + return False + + # for specific tool choices + if ( + isinstance(model_settings.tool_choice, str) and + model_settings.tool_choice not in ["auto", "required", "none"] + ): + return True + + # for one tool and required tool choice + if model_settings.tool_choice == "required": + return len(tools) == 1 + + return False + @classmethod def process_model_response( cls, diff --git a/tests/test_tool_choice_reset.py b/tests/test_tool_choice_reset.py index e01a5f0..b47c4d9 100644 --- a/tests/test_tool_choice_reset.py +++ b/tests/test_tool_choice_reset.py @@ -1,13 +1,15 @@ -from unittest import mock import asyncio +import dataclasses import json -from typing import List +from unittest import mock -from agents import Agent, ModelSettings, RunConfig, function_tool, Runner -from agents.models.interface import ModelResponse -from agents.items import Usage from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall +from agents import Agent, ModelSettings, RunConfig, Runner, function_tool +from agents.items import Usage +from agents.models.interface import ModelResponse +from agents.tool import Tool + @function_tool def echo(text: str) -> str: @@ -15,22 +17,39 @@ def echo(text: str) -> str: return text +def should_reset_tool_choice(model_settings: ModelSettings, tools: list[Tool]) -> bool: + if model_settings is None or model_settings.tool_choice is None: + return False + + # for specific tool choices + if ( + isinstance(model_settings.tool_choice, str) and + model_settings.tool_choice not in ["auto", "required", "none"] + ): + return True + + # for one tool and required tool choice + if model_settings.tool_choice == "required": + return len(tools) == 1 + + return False + # Mock model implementation that always calls tools when tool_choice is set class MockModel: def __init__(self, tool_call_counter): self.tool_call_counter = tool_call_counter - + async def get_response(self, **kwargs): tools = kwargs.get("tools", []) model_settings = kwargs.get("model_settings") - + # Increment the counter to track how many times this model is called self.tool_call_counter["count"] += 1 - + # If we've been called many times, we're likely in an infinite loop if self.tool_call_counter["count"] > 5: self.tool_call_counter["potential_infinite_loop"] = True - + # Always create a tool call if tool_choice is required/specific tool_calls = [] if model_settings and model_settings.tool_choice: @@ -46,7 +65,7 @@ class MockModel: type="function_call", ) ) - + return ModelResponse( output=tool_calls, referenceable_id="123", @@ -60,7 +79,7 @@ class TestToolChoiceReset: # Create an agent with tool_choice="required" agent = Agent( name="Test agent", - tools=[echo], + tools=[echo], # Only one tool model_settings=ModelSettings(tool_choice="required"), ) @@ -77,31 +96,22 @@ class TestToolChoiceReset: # Execute our code under test if processed_response.functions: - # Reset agent's model_settings - if agent.model_settings.tool_choice == "required" or isinstance(agent.model_settings.tool_choice, str): - agent.model_settings = ModelSettings( - temperature=agent.model_settings.temperature, - top_p=agent.model_settings.top_p, - frequency_penalty=agent.model_settings.frequency_penalty, - presence_penalty=agent.model_settings.presence_penalty, - tool_choice="auto", # Reset to auto - parallel_tool_calls=agent.model_settings.parallel_tool_calls, - truncation=agent.model_settings.truncation, - max_tokens=agent.model_settings.max_tokens, + # Apply the targeted reset logic + tools = agent.tools + if should_reset_tool_choice(agent.model_settings, tools): + agent.model_settings = dataclasses.replace( + agent.model_settings, + tool_choice="auto" # Reset to auto ) - + # Also reset run_config's model_settings if it exists - if run_config.model_settings and (run_config.model_settings.tool_choice == "required" or - isinstance(run_config.model_settings.tool_choice, str)): - run_config.model_settings = ModelSettings( - temperature=run_config.model_settings.temperature, - top_p=run_config.model_settings.top_p, - frequency_penalty=run_config.model_settings.frequency_penalty, - presence_penalty=run_config.model_settings.presence_penalty, - tool_choice="auto", # Reset to auto - parallel_tool_calls=run_config.model_settings.parallel_tool_calls, - truncation=run_config.model_settings.truncation, - max_tokens=run_config.model_settings.max_tokens, + if ( + run_config.model_settings and + should_reset_tool_choice(run_config.model_settings, tools) + ): + run_config.model_settings = dataclasses.replace( + run_config.model_settings, + tool_choice="auto" # Reset to auto ) # Check that tool_choice was reset to "auto" @@ -115,7 +125,7 @@ class TestToolChoiceReset: instructions="You are a test agent", tools=[echo], model="gpt-4-0125-preview", - model_settings=ModelSettings(tool_choice="echo"), + model_settings=ModelSettings(tool_choice="echo"), # Specific function name ) # Execute our code under test @@ -129,31 +139,22 @@ class TestToolChoiceReset: # Execute our code under test if processed_response.functions: - # Reset agent's model_settings - if agent.model_settings.tool_choice == "required" or isinstance(agent.model_settings.tool_choice, str): - agent.model_settings = ModelSettings( - temperature=agent.model_settings.temperature, - top_p=agent.model_settings.top_p, - frequency_penalty=agent.model_settings.frequency_penalty, - presence_penalty=agent.model_settings.presence_penalty, - tool_choice="auto", # Reset to auto - parallel_tool_calls=agent.model_settings.parallel_tool_calls, - truncation=agent.model_settings.truncation, - max_tokens=agent.model_settings.max_tokens, + # Apply the targeted reset logic + tools = agent.tools + if should_reset_tool_choice(agent.model_settings, tools): + agent.model_settings = dataclasses.replace( + agent.model_settings, + tool_choice="auto" # Reset to auto ) - + # Also reset run_config's model_settings if it exists - if run_config.model_settings and (run_config.model_settings.tool_choice == "required" or - isinstance(run_config.model_settings.tool_choice, str)): - run_config.model_settings = ModelSettings( - temperature=run_config.model_settings.temperature, - top_p=run_config.model_settings.top_p, - frequency_penalty=run_config.model_settings.frequency_penalty, - presence_penalty=run_config.model_settings.presence_penalty, - tool_choice="auto", # Reset to auto - parallel_tool_calls=run_config.model_settings.parallel_tool_calls, - truncation=run_config.model_settings.truncation, - max_tokens=run_config.model_settings.max_tokens, + if ( + run_config.model_settings and + should_reset_tool_choice(run_config.model_settings, tools) + ): + run_config.model_settings = dataclasses.replace( + run_config.model_settings, + tool_choice="auto" # Reset to auto ) # Check that tool_choice was reset to "auto" @@ -179,49 +180,40 @@ class TestToolChoiceReset: # Execute our code under test if processed_response.functions: - # Reset agent's model_settings - if agent.model_settings.tool_choice == "required" or isinstance(agent.model_settings.tool_choice, str): - agent.model_settings = ModelSettings( - temperature=agent.model_settings.temperature, - top_p=agent.model_settings.top_p, - frequency_penalty=agent.model_settings.frequency_penalty, - presence_penalty=agent.model_settings.presence_penalty, - tool_choice="auto", # Reset to auto - parallel_tool_calls=agent.model_settings.parallel_tool_calls, - truncation=agent.model_settings.truncation, - max_tokens=agent.model_settings.max_tokens, + # Apply the targeted reset logic + tools = agent.tools + if should_reset_tool_choice(agent.model_settings, tools): + agent.model_settings = dataclasses.replace( + agent.model_settings, + tool_choice="auto" # Reset to auto ) - + # Also reset run_config's model_settings if it exists - if run_config.model_settings and (run_config.model_settings.tool_choice == "required" or - isinstance(run_config.model_settings.tool_choice, str)): - run_config.model_settings = ModelSettings( - temperature=run_config.model_settings.temperature, - top_p=run_config.model_settings.top_p, - frequency_penalty=run_config.model_settings.frequency_penalty, - presence_penalty=run_config.model_settings.presence_penalty, - tool_choice="auto", # Reset to auto - parallel_tool_calls=run_config.model_settings.parallel_tool_calls, - truncation=run_config.model_settings.truncation, - max_tokens=run_config.model_settings.max_tokens, + if ( + run_config.model_settings and + should_reset_tool_choice(run_config.model_settings, tools) + ): + run_config.model_settings = dataclasses.replace( + run_config.model_settings, + tool_choice="auto" # Reset to auto ) # Check that tool_choice remains "auto" assert agent.model_settings.tool_choice == "auto" - + async def test_run_config_tool_choice_reset(self): """Test that run_config.model_settings.tool_choice is reset to 'auto'""" # Create an agent with default model_settings agent = Agent( name="Test agent", - tools=[echo], + tools=[echo], # Only one tool model_settings=ModelSettings(tool_choice=None), ) - + # Create a run_config with tool_choice="required" run_config = RunConfig() run_config.model_settings = ModelSettings(tool_choice="required") - + # Execute our code under test processed_response = mock.MagicMock() processed_response.functions = [mock.MagicMock()] # At least one function call @@ -229,47 +221,38 @@ class TestToolChoiceReset: # Execute our code under test if processed_response.functions: - # Reset agent's model_settings - if agent.model_settings.tool_choice == "required" or isinstance(agent.model_settings.tool_choice, str): - agent.model_settings = ModelSettings( - temperature=agent.model_settings.temperature, - top_p=agent.model_settings.top_p, - frequency_penalty=agent.model_settings.frequency_penalty, - presence_penalty=agent.model_settings.presence_penalty, - tool_choice="auto", # Reset to auto - parallel_tool_calls=agent.model_settings.parallel_tool_calls, - truncation=agent.model_settings.truncation, - max_tokens=agent.model_settings.max_tokens, + # Apply the targeted reset logic + tools = agent.tools + if should_reset_tool_choice(agent.model_settings, tools): + agent.model_settings = dataclasses.replace( + agent.model_settings, + tool_choice="auto" # Reset to auto ) - + # Also reset run_config's model_settings if it exists - if run_config.model_settings and (run_config.model_settings.tool_choice == "required" or - isinstance(run_config.model_settings.tool_choice, str)): - run_config.model_settings = ModelSettings( - temperature=run_config.model_settings.temperature, - top_p=run_config.model_settings.top_p, - frequency_penalty=run_config.model_settings.frequency_penalty, - presence_penalty=run_config.model_settings.presence_penalty, - tool_choice="auto", # Reset to auto - parallel_tool_calls=run_config.model_settings.parallel_tool_calls, - truncation=run_config.model_settings.truncation, - max_tokens=run_config.model_settings.max_tokens, + if ( + run_config.model_settings and + should_reset_tool_choice(run_config.model_settings, tools) + ): + run_config.model_settings = dataclasses.replace( + run_config.model_settings, + tool_choice="auto" # Reset to auto ) - + # Check that run_config's tool_choice was reset to "auto" assert run_config.model_settings.tool_choice == "auto" - + @mock.patch("agents.run.Runner._get_model") async def test_integration_prevents_infinite_loop(self, mock_get_model): """Integration test to verify that tool_choice reset prevents infinite loops""" # Create a counter to track model calls and detect potential infinite loops tool_call_counter = {"count": 0, "potential_infinite_loop": False} - + # Set up our mock model that will always use tools when tool_choice is set mock_model_instance = MockModel(tool_call_counter) # Return our mock model directly mock_get_model.return_value = mock_model_instance - + # Create an agent with tool_choice="required" to force tool usage agent = Agent( name="Test agent", @@ -280,24 +263,24 @@ class TestToolChoiceReset: # This would cause infinite loops without the tool_choice reset tool_use_behavior="run_llm_again", ) - + # Set a timeout to catch potential infinite loops that our fix doesn't address try: # Run the agent with a timeout async def run_with_timeout(): return await Runner.run(agent, input="Test input") - + result = await asyncio.wait_for(run_with_timeout(), timeout=2.0) - + # Verify the agent ran successfully assert result is not None - + # Verify the tool was called at least once but not too many times # (indicating no infinite loop) assert tool_call_counter["count"] >= 1 assert tool_call_counter["count"] < 5 assert not tool_call_counter["potential_infinite_loop"] - - except asyncio.TimeoutError: + + except asyncio.TimeoutError as err: # If we hit a timeout, the test failed - we likely have an infinite loop - assert False, "Timeout occurred, potential infinite loop detected" \ No newline at end of file + raise AssertionError("Timeout occurred, potential infinite loop detected") from err