fix: optimize tool_choice reset logic and fix lint errors
- Refactor tool_choice reset to target only problematic edge cases - Replace manual ModelSettings recreation with dataclasses.replace - Fix line length and error handling lint issues in tests
This commit is contained in:
parent
d169d79288
commit
bbcda753df
2 changed files with 137 additions and 147 deletions
|
|
@ -1,6 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import inspect
|
||||
from collections.abc import Awaitable
|
||||
from dataclasses import dataclass
|
||||
|
|
@ -51,7 +52,7 @@ from .model_settings import ModelSettings
|
|||
from .models.interface import ModelTracing
|
||||
from .run_context import RunContextWrapper, TContext
|
||||
from .stream_events import RunItemStreamEvent, StreamEvent
|
||||
from .tool import ComputerTool, FunctionTool, FunctionToolResult
|
||||
from .tool import ComputerTool, FunctionTool, FunctionToolResult, Tool
|
||||
from .tracing import (
|
||||
SpanError,
|
||||
Trace,
|
||||
|
|
@ -208,34 +209,22 @@ class RunImpl:
|
|||
new_step_items.extend(computer_results)
|
||||
|
||||
# Reset tool_choice to "auto" after tool execution to prevent infinite loops
|
||||
if (processed_response.functions or processed_response.computer_actions):
|
||||
# Reset agent's model_settings
|
||||
if agent.model_settings.tool_choice == "required" or isinstance(agent.model_settings.tool_choice, str):
|
||||
# Create a new model_settings to avoid modifying the original shared instance
|
||||
agent.model_settings = ModelSettings(
|
||||
temperature=agent.model_settings.temperature,
|
||||
top_p=agent.model_settings.top_p,
|
||||
frequency_penalty=agent.model_settings.frequency_penalty,
|
||||
presence_penalty=agent.model_settings.presence_penalty,
|
||||
tool_choice="auto", # Reset to auto
|
||||
parallel_tool_calls=agent.model_settings.parallel_tool_calls,
|
||||
truncation=agent.model_settings.truncation,
|
||||
max_tokens=agent.model_settings.max_tokens,
|
||||
if processed_response.functions or processed_response.computer_actions:
|
||||
tools = agent.tools
|
||||
# Only reset in the problematic scenarios where loops are likely unintentional
|
||||
if cls._should_reset_tool_choice(agent.model_settings, tools):
|
||||
agent.model_settings = dataclasses.replace(
|
||||
agent.model_settings,
|
||||
tool_choice="auto"
|
||||
)
|
||||
|
||||
# Also reset run_config's model_settings if it exists
|
||||
if run_config.model_settings and (run_config.model_settings.tool_choice == "required" or
|
||||
isinstance(run_config.model_settings.tool_choice, str)):
|
||||
# Create a new model_settings for run_config
|
||||
run_config.model_settings = ModelSettings(
|
||||
temperature=run_config.model_settings.temperature,
|
||||
top_p=run_config.model_settings.top_p,
|
||||
frequency_penalty=run_config.model_settings.frequency_penalty,
|
||||
presence_penalty=run_config.model_settings.presence_penalty,
|
||||
tool_choice="auto", # Reset to auto
|
||||
parallel_tool_calls=run_config.model_settings.parallel_tool_calls,
|
||||
truncation=run_config.model_settings.truncation,
|
||||
max_tokens=run_config.model_settings.max_tokens,
|
||||
if (
|
||||
run_config.model_settings and
|
||||
cls._should_reset_tool_choice(run_config.model_settings, tools)
|
||||
):
|
||||
run_config.model_settings = dataclasses.replace(
|
||||
run_config.model_settings,
|
||||
tool_choice="auto"
|
||||
)
|
||||
|
||||
# Second, check if there are any handoffs
|
||||
|
|
@ -328,6 +317,24 @@ class RunImpl:
|
|||
next_step=NextStepRunAgain(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _should_reset_tool_choice(cls, model_settings: ModelSettings, tools: list[Tool]) -> bool:
|
||||
if model_settings is None or model_settings.tool_choice is None:
|
||||
return False
|
||||
|
||||
# for specific tool choices
|
||||
if (
|
||||
isinstance(model_settings.tool_choice, str) and
|
||||
model_settings.tool_choice not in ["auto", "required", "none"]
|
||||
):
|
||||
return True
|
||||
|
||||
# for one tool and required tool choice
|
||||
if model_settings.tool_choice == "required":
|
||||
return len(tools) == 1
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def process_model_response(
|
||||
cls,
|
||||
|
|
|
|||
|
|
@ -1,13 +1,15 @@
|
|||
from unittest import mock
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import json
|
||||
from typing import List
|
||||
from unittest import mock
|
||||
|
||||
from agents import Agent, ModelSettings, RunConfig, function_tool, Runner
|
||||
from agents.models.interface import ModelResponse
|
||||
from agents.items import Usage
|
||||
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
|
||||
|
||||
from agents import Agent, ModelSettings, RunConfig, Runner, function_tool
|
||||
from agents.items import Usage
|
||||
from agents.models.interface import ModelResponse
|
||||
from agents.tool import Tool
|
||||
|
||||
|
||||
@function_tool
|
||||
def echo(text: str) -> str:
|
||||
|
|
@ -15,6 +17,23 @@ def echo(text: str) -> str:
|
|||
return text
|
||||
|
||||
|
||||
def should_reset_tool_choice(model_settings: ModelSettings, tools: list[Tool]) -> bool:
|
||||
if model_settings is None or model_settings.tool_choice is None:
|
||||
return False
|
||||
|
||||
# for specific tool choices
|
||||
if (
|
||||
isinstance(model_settings.tool_choice, str) and
|
||||
model_settings.tool_choice not in ["auto", "required", "none"]
|
||||
):
|
||||
return True
|
||||
|
||||
# for one tool and required tool choice
|
||||
if model_settings.tool_choice == "required":
|
||||
return len(tools) == 1
|
||||
|
||||
return False
|
||||
|
||||
# Mock model implementation that always calls tools when tool_choice is set
|
||||
class MockModel:
|
||||
def __init__(self, tool_call_counter):
|
||||
|
|
@ -60,7 +79,7 @@ class TestToolChoiceReset:
|
|||
# Create an agent with tool_choice="required"
|
||||
agent = Agent(
|
||||
name="Test agent",
|
||||
tools=[echo],
|
||||
tools=[echo], # Only one tool
|
||||
model_settings=ModelSettings(tool_choice="required"),
|
||||
)
|
||||
|
||||
|
|
@ -77,31 +96,22 @@ class TestToolChoiceReset:
|
|||
|
||||
# Execute our code under test
|
||||
if processed_response.functions:
|
||||
# Reset agent's model_settings
|
||||
if agent.model_settings.tool_choice == "required" or isinstance(agent.model_settings.tool_choice, str):
|
||||
agent.model_settings = ModelSettings(
|
||||
temperature=agent.model_settings.temperature,
|
||||
top_p=agent.model_settings.top_p,
|
||||
frequency_penalty=agent.model_settings.frequency_penalty,
|
||||
presence_penalty=agent.model_settings.presence_penalty,
|
||||
tool_choice="auto", # Reset to auto
|
||||
parallel_tool_calls=agent.model_settings.parallel_tool_calls,
|
||||
truncation=agent.model_settings.truncation,
|
||||
max_tokens=agent.model_settings.max_tokens,
|
||||
# Apply the targeted reset logic
|
||||
tools = agent.tools
|
||||
if should_reset_tool_choice(agent.model_settings, tools):
|
||||
agent.model_settings = dataclasses.replace(
|
||||
agent.model_settings,
|
||||
tool_choice="auto" # Reset to auto
|
||||
)
|
||||
|
||||
# Also reset run_config's model_settings if it exists
|
||||
if run_config.model_settings and (run_config.model_settings.tool_choice == "required" or
|
||||
isinstance(run_config.model_settings.tool_choice, str)):
|
||||
run_config.model_settings = ModelSettings(
|
||||
temperature=run_config.model_settings.temperature,
|
||||
top_p=run_config.model_settings.top_p,
|
||||
frequency_penalty=run_config.model_settings.frequency_penalty,
|
||||
presence_penalty=run_config.model_settings.presence_penalty,
|
||||
tool_choice="auto", # Reset to auto
|
||||
parallel_tool_calls=run_config.model_settings.parallel_tool_calls,
|
||||
truncation=run_config.model_settings.truncation,
|
||||
max_tokens=run_config.model_settings.max_tokens,
|
||||
if (
|
||||
run_config.model_settings and
|
||||
should_reset_tool_choice(run_config.model_settings, tools)
|
||||
):
|
||||
run_config.model_settings = dataclasses.replace(
|
||||
run_config.model_settings,
|
||||
tool_choice="auto" # Reset to auto
|
||||
)
|
||||
|
||||
# Check that tool_choice was reset to "auto"
|
||||
|
|
@ -115,7 +125,7 @@ class TestToolChoiceReset:
|
|||
instructions="You are a test agent",
|
||||
tools=[echo],
|
||||
model="gpt-4-0125-preview",
|
||||
model_settings=ModelSettings(tool_choice="echo"),
|
||||
model_settings=ModelSettings(tool_choice="echo"), # Specific function name
|
||||
)
|
||||
|
||||
# Execute our code under test
|
||||
|
|
@ -129,31 +139,22 @@ class TestToolChoiceReset:
|
|||
|
||||
# Execute our code under test
|
||||
if processed_response.functions:
|
||||
# Reset agent's model_settings
|
||||
if agent.model_settings.tool_choice == "required" or isinstance(agent.model_settings.tool_choice, str):
|
||||
agent.model_settings = ModelSettings(
|
||||
temperature=agent.model_settings.temperature,
|
||||
top_p=agent.model_settings.top_p,
|
||||
frequency_penalty=agent.model_settings.frequency_penalty,
|
||||
presence_penalty=agent.model_settings.presence_penalty,
|
||||
tool_choice="auto", # Reset to auto
|
||||
parallel_tool_calls=agent.model_settings.parallel_tool_calls,
|
||||
truncation=agent.model_settings.truncation,
|
||||
max_tokens=agent.model_settings.max_tokens,
|
||||
# Apply the targeted reset logic
|
||||
tools = agent.tools
|
||||
if should_reset_tool_choice(agent.model_settings, tools):
|
||||
agent.model_settings = dataclasses.replace(
|
||||
agent.model_settings,
|
||||
tool_choice="auto" # Reset to auto
|
||||
)
|
||||
|
||||
# Also reset run_config's model_settings if it exists
|
||||
if run_config.model_settings and (run_config.model_settings.tool_choice == "required" or
|
||||
isinstance(run_config.model_settings.tool_choice, str)):
|
||||
run_config.model_settings = ModelSettings(
|
||||
temperature=run_config.model_settings.temperature,
|
||||
top_p=run_config.model_settings.top_p,
|
||||
frequency_penalty=run_config.model_settings.frequency_penalty,
|
||||
presence_penalty=run_config.model_settings.presence_penalty,
|
||||
tool_choice="auto", # Reset to auto
|
||||
parallel_tool_calls=run_config.model_settings.parallel_tool_calls,
|
||||
truncation=run_config.model_settings.truncation,
|
||||
max_tokens=run_config.model_settings.max_tokens,
|
||||
if (
|
||||
run_config.model_settings and
|
||||
should_reset_tool_choice(run_config.model_settings, tools)
|
||||
):
|
||||
run_config.model_settings = dataclasses.replace(
|
||||
run_config.model_settings,
|
||||
tool_choice="auto" # Reset to auto
|
||||
)
|
||||
|
||||
# Check that tool_choice was reset to "auto"
|
||||
|
|
@ -179,31 +180,22 @@ class TestToolChoiceReset:
|
|||
|
||||
# Execute our code under test
|
||||
if processed_response.functions:
|
||||
# Reset agent's model_settings
|
||||
if agent.model_settings.tool_choice == "required" or isinstance(agent.model_settings.tool_choice, str):
|
||||
agent.model_settings = ModelSettings(
|
||||
temperature=agent.model_settings.temperature,
|
||||
top_p=agent.model_settings.top_p,
|
||||
frequency_penalty=agent.model_settings.frequency_penalty,
|
||||
presence_penalty=agent.model_settings.presence_penalty,
|
||||
tool_choice="auto", # Reset to auto
|
||||
parallel_tool_calls=agent.model_settings.parallel_tool_calls,
|
||||
truncation=agent.model_settings.truncation,
|
||||
max_tokens=agent.model_settings.max_tokens,
|
||||
# Apply the targeted reset logic
|
||||
tools = agent.tools
|
||||
if should_reset_tool_choice(agent.model_settings, tools):
|
||||
agent.model_settings = dataclasses.replace(
|
||||
agent.model_settings,
|
||||
tool_choice="auto" # Reset to auto
|
||||
)
|
||||
|
||||
# Also reset run_config's model_settings if it exists
|
||||
if run_config.model_settings and (run_config.model_settings.tool_choice == "required" or
|
||||
isinstance(run_config.model_settings.tool_choice, str)):
|
||||
run_config.model_settings = ModelSettings(
|
||||
temperature=run_config.model_settings.temperature,
|
||||
top_p=run_config.model_settings.top_p,
|
||||
frequency_penalty=run_config.model_settings.frequency_penalty,
|
||||
presence_penalty=run_config.model_settings.presence_penalty,
|
||||
tool_choice="auto", # Reset to auto
|
||||
parallel_tool_calls=run_config.model_settings.parallel_tool_calls,
|
||||
truncation=run_config.model_settings.truncation,
|
||||
max_tokens=run_config.model_settings.max_tokens,
|
||||
if (
|
||||
run_config.model_settings and
|
||||
should_reset_tool_choice(run_config.model_settings, tools)
|
||||
):
|
||||
run_config.model_settings = dataclasses.replace(
|
||||
run_config.model_settings,
|
||||
tool_choice="auto" # Reset to auto
|
||||
)
|
||||
|
||||
# Check that tool_choice remains "auto"
|
||||
|
|
@ -214,7 +206,7 @@ class TestToolChoiceReset:
|
|||
# Create an agent with default model_settings
|
||||
agent = Agent(
|
||||
name="Test agent",
|
||||
tools=[echo],
|
||||
tools=[echo], # Only one tool
|
||||
model_settings=ModelSettings(tool_choice=None),
|
||||
)
|
||||
|
||||
|
|
@ -229,31 +221,22 @@ class TestToolChoiceReset:
|
|||
|
||||
# Execute our code under test
|
||||
if processed_response.functions:
|
||||
# Reset agent's model_settings
|
||||
if agent.model_settings.tool_choice == "required" or isinstance(agent.model_settings.tool_choice, str):
|
||||
agent.model_settings = ModelSettings(
|
||||
temperature=agent.model_settings.temperature,
|
||||
top_p=agent.model_settings.top_p,
|
||||
frequency_penalty=agent.model_settings.frequency_penalty,
|
||||
presence_penalty=agent.model_settings.presence_penalty,
|
||||
tool_choice="auto", # Reset to auto
|
||||
parallel_tool_calls=agent.model_settings.parallel_tool_calls,
|
||||
truncation=agent.model_settings.truncation,
|
||||
max_tokens=agent.model_settings.max_tokens,
|
||||
# Apply the targeted reset logic
|
||||
tools = agent.tools
|
||||
if should_reset_tool_choice(agent.model_settings, tools):
|
||||
agent.model_settings = dataclasses.replace(
|
||||
agent.model_settings,
|
||||
tool_choice="auto" # Reset to auto
|
||||
)
|
||||
|
||||
# Also reset run_config's model_settings if it exists
|
||||
if run_config.model_settings and (run_config.model_settings.tool_choice == "required" or
|
||||
isinstance(run_config.model_settings.tool_choice, str)):
|
||||
run_config.model_settings = ModelSettings(
|
||||
temperature=run_config.model_settings.temperature,
|
||||
top_p=run_config.model_settings.top_p,
|
||||
frequency_penalty=run_config.model_settings.frequency_penalty,
|
||||
presence_penalty=run_config.model_settings.presence_penalty,
|
||||
tool_choice="auto", # Reset to auto
|
||||
parallel_tool_calls=run_config.model_settings.parallel_tool_calls,
|
||||
truncation=run_config.model_settings.truncation,
|
||||
max_tokens=run_config.model_settings.max_tokens,
|
||||
if (
|
||||
run_config.model_settings and
|
||||
should_reset_tool_choice(run_config.model_settings, tools)
|
||||
):
|
||||
run_config.model_settings = dataclasses.replace(
|
||||
run_config.model_settings,
|
||||
tool_choice="auto" # Reset to auto
|
||||
)
|
||||
|
||||
# Check that run_config's tool_choice was reset to "auto"
|
||||
|
|
@ -298,6 +281,6 @@ class TestToolChoiceReset:
|
|||
assert tool_call_counter["count"] < 5
|
||||
assert not tool_call_counter["potential_infinite_loop"]
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
except asyncio.TimeoutError as err:
|
||||
# If we hit a timeout, the test failed - we likely have an infinite loop
|
||||
assert False, "Timeout occurred, potential infinite loop detected"
|
||||
raise AssertionError("Timeout occurred, potential infinite loop detected") from err
|
||||
|
|
|
|||
Loading…
Reference in a new issue