Introduce tool_use_behavior on agents

This commit is contained in:
Rohan Mehta 2025-03-18 21:43:02 -04:00
parent 6f7e801da0
commit 10aa5555af
12 changed files with 594 additions and 26 deletions

View file

@ -130,3 +130,16 @@ robot_agent = pirate_agent.clone(
instructions="Write like a robot",
)
```
## Forcing tool use
Supplying a list of tools doesn't always mean the LLM will use a tool. You can force tool use by setting [`ModelSettings.tool_choice`][agents.model_settings.ModelSettings.tool_choice]. Valid values are:
1. `auto`, which allows the LLM to decide whether or not to use a tool.
2. `required`, which requires the LLM to use a tool (but it can intelligently decide which tool).
3. `none`, which requires the LLM to _not_ use a tool.
4. Setting a specific string e.g. `my_tool`, which requires the LLM to use that specific tool.
!!! note
If requiring tool use, you should consider setting [`Agent.tool_use_behavior`][agents.agent.Agent.tool_use_behavior] to stop the Agent from running when a tool output is produced. Otherwise, the Agent might run in an infinite loop: the LLM produces a tool call, the tool result is sent back to the LLM, and the cycle repeats because the LLM is always forced to use a tool.

View file

@ -0,0 +1,99 @@
from __future__ import annotations
import asyncio
from typing import Any, Literal
from pydantic import BaseModel
from agents import (
Agent,
FunctionToolResult,
ModelSettings,
RunContextWrapper,
Runner,
ToolsToFinalOutputFunction,
ToolsToFinalOutputResult,
function_tool,
)
"""
This example shows how to force the agent to use a tool. It uses `ModelSettings(tool_choice="required")`
to force the agent to use any tool.
You can run it with 3 options:
1. `default`: The default behavior, which is to send the tool output to the LLM. In this case,
`tool_choice` is not set, because otherwise it would result in an infinite loop - the LLM would
call the tool, the tool would run and send the results to the LLM, and that would repeat
(because the model is forced to use a tool every time.)
2. `first_tool`: The first tool result is used as the final output.
3. `custom`: A custom tool use behavior function is used. The custom function receives all the tool
results, and chooses to use the first tool result to generate the final output.
Usage:
python examples/agent_patterns/forcing_tool_use.py -t default
python examples/agent_patterns/forcing_tool_use.py -t first_tool
python examples/agent_patterns/forcing_tool_use.py -t custom
"""
# Structured weather report returned by the `get_weather` tool below.
# (Documented with `#` comments so the pydantic model itself is left untouched.)
class Weather(BaseModel):
    city: str  # city the report is for
    temperature_range: str  # textual range, e.g. "14-20C"
    conditions: str  # free-text description, e.g. "Sunny with wind"
# Example tool: prints a debug line and returns a fixed, hard-coded report for `city`.
# NOTE(review): documented with `#` comments rather than a docstring because the
# `function_tool` decorator may use the docstring as the tool description — confirm
# before converting this to a docstring.
@function_tool
def get_weather(city: str) -> Weather:
    print("[debug] get_weather called")
    return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind")
async def custom_tool_use_behavior(
    context: RunContextWrapper[Any], results: list[FunctionToolResult]
) -> ToolsToFinalOutputResult:
    """Turn the first tool result into a one-sentence final answer.

    Args:
        context: The run context (unused here).
        results: The function tool results for this turn; the first one is read as a
            `Weather` report.

    Returns:
        A result marked as final whose output is "<city> is <conditions>.".
    """
    first_report: Weather = results[0].output
    summary = f"{first_report.city} is {first_report.conditions}."
    return ToolsToFinalOutputResult(is_final_output=True, final_output=summary)
async def main(tool_use_behavior: Literal["default", "first_tool", "custom"] = "default"):
    """Run the weather agent once with the selected tool-use behavior and print the output.

    Args:
        tool_use_behavior: Which example mode to run:
            - "default": tool output is sent back to the LLM (`run_llm_again`).
            - "first_tool": the first tool result becomes the final output.
            - "custom": `custom_tool_use_behavior` decides the final output.

    Raises:
        ValueError: If `tool_use_behavior` is not one of the supported options.
    """
    behavior: Literal["run_llm_again", "stop_on_first_tool"] | ToolsToFinalOutputFunction
    if tool_use_behavior == "default":
        behavior = "run_llm_again"
    elif tool_use_behavior == "first_tool":
        behavior = "stop_on_first_tool"
    elif tool_use_behavior == "custom":
        behavior = custom_tool_use_behavior
    else:
        # Fail with a clear message instead of an UnboundLocalError further down.
        raise ValueError(f"Unknown tool_use_behavior: {tool_use_behavior!r}")

    # Only force tool use when the run stops on a tool result; forcing it in "default" mode
    # would make the LLM call a tool on every turn and never finish.
    tool_choice = "required" if tool_use_behavior != "default" else None

    agent = Agent(
        name="Weather agent",
        instructions="You are a helpful agent.",
        tools=[get_weather],
        tool_use_behavior=behavior,
        model_settings=ModelSettings(tool_choice=tool_choice),
    )

    result = await Runner.run(agent, input="What's the weather in Tokyo?")
    print(result.final_output)
# Command-line entry point: pick the tool-use mode with -t/--tool-use-behavior.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-t",
        "--tool-use-behavior",
        type=str,
        required=True,
        choices=["default", "first_tool", "custom"],
        # The help text must name the same values as `choices` above.
        help="The behavior to use for tool use. Default will cause tool outputs to be sent to the model. "
        "first_tool will cause the first tool result to be used as the final output. "
        "custom will use a custom tool use behavior function.",
    )
    args = parser.parse_args()
    asyncio.run(main(args.tool_use_behavior))

34
examples/basic/tools.py Normal file
View file

@ -0,0 +1,34 @@
import asyncio
from pydantic import BaseModel
from agents import Agent, Runner, function_tool
# Structured weather report returned by the `get_weather` tool below.
# (Documented with `#` comments so the pydantic model itself is left untouched.)
class Weather(BaseModel):
    city: str  # city the report is for
    temperature_range: str  # textual range, e.g. "14-20C"
    conditions: str  # free-text description, e.g. "Sunny with wind."
# Example tool: prints a debug line and returns a fixed, hard-coded report for `city`.
# NOTE(review): documented with `#` comments rather than a docstring because the
# `function_tool` decorator may use the docstring as the tool description — confirm
# before converting this to a docstring.
@function_tool
def get_weather(city: str) -> Weather:
    print("[debug] get_weather called")
    return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")
# Module-level agent with a single tool (`get_weather`); `tool_use_behavior` is not set,
# so the agent uses its default.
agent = Agent(
    name="Hello world",
    instructions="You are a helpful agent.",
    tools=[get_weather],
)
async def main():
    """Ask the agent about Tokyo's weather and print its final output."""
    run_result = await Runner.run(agent, input="What's the weather in Tokyo?")
    final_answer = run_result.final_output
    print(final_answer)
    # The weather in Tokyo is sunny.
# Run the example only when this file is executed as a script.
if __name__ == "__main__":
    asyncio.run(main())

View file

@ -5,7 +5,7 @@ from typing import Literal
from openai import AsyncOpenAI
from . import _config
from .agent import Agent
from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
from .agent_output import AgentOutputSchema
from .computer import AsyncComputer, Button, Computer, Environment
from .exceptions import (
@ -57,6 +57,7 @@ from .tool import (
ComputerTool,
FileSearchTool,
FunctionTool,
FunctionToolResult,
Tool,
WebSearchTool,
default_tool_error_function,
@ -137,6 +138,8 @@ def enable_verbose_stdout_logging():
__all__ = [
"Agent",
"ToolsToFinalOutputFunction",
"ToolsToFinalOutputResult",
"Runner",
"Model",
"ModelProvider",
@ -190,6 +193,7 @@ __all__ = [
"AgentUpdatedStreamEvent",
"StreamEvent",
"FunctionTool",
"FunctionToolResult",
"ComputerTool",
"FileSearchTool",
"Tool",

View file

@ -1,8 +1,10 @@
from __future__ import annotations
import asyncio
import inspect
from collections.abc import Awaitable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, cast
from openai.types.responses import (
ResponseComputerToolCall,
@ -25,7 +27,7 @@ from openai.types.responses.response_computer_tool_call import (
from openai.types.responses.response_input_param import ComputerCallOutput
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
from .agent import Agent
from .agent import Agent, ToolsToFinalOutputResult
from .agent_output import AgentOutputSchema
from .computer import AsyncComputer, Computer
from .exceptions import AgentsException, ModelBehaviorError, UserError
@ -48,7 +50,7 @@ from .logger import logger
from .models.interface import ModelTracing
from .run_context import RunContextWrapper, TContext
from .stream_events import RunItemStreamEvent, StreamEvent
from .tool import ComputerTool, FunctionTool
from .tool import ComputerTool, FunctionTool, FunctionToolResult
from .tracing import (
SpanError,
Trace,
@ -70,6 +72,8 @@ class QueueCompleteSentinel:
QUEUE_COMPLETE_SENTINEL = QueueCompleteSentinel()
_NOT_FINAL_OUTPUT = ToolsToFinalOutputResult(is_final_output=False, final_output=None)
@dataclass
class ToolRunHandoff:
@ -199,7 +203,7 @@ class RunImpl:
config=run_config,
),
)
new_step_items.extend(function_results)
new_step_items.extend([result.run_item for result in function_results])
new_step_items.extend(computer_results)
# Second, check if there are any handoffs
@ -216,6 +220,36 @@ class RunImpl:
run_config=run_config,
)
# Third, we'll check if the tool use should result in a final output
check_tool_use = await cls._check_for_final_output_from_tools(
agent=agent,
tool_results=function_results,
context_wrapper=context_wrapper,
config=run_config,
)
if check_tool_use.is_final_output:
# If the output type is str, then let's just stringify it
if not agent.output_type or agent.output_type is str:
check_tool_use.final_output = str(check_tool_use.final_output)
if check_tool_use.final_output is None:
logger.error(
"Model returned a final output of None. Not raising an error because we assume"
"you know what you're doing."
)
return await cls.execute_final_output(
agent=agent,
original_input=original_input,
new_response=new_response,
pre_step_items=pre_step_items,
new_step_items=new_step_items,
final_output=check_tool_use.final_output,
hooks=hooks,
context_wrapper=context_wrapper,
)
# Now we can check if the model also produced a final output
message_items = [item for item in new_step_items if isinstance(item, MessageOutputItem)]
@ -355,10 +389,10 @@ class RunImpl:
hooks: RunHooks[TContext],
context_wrapper: RunContextWrapper[TContext],
config: RunConfig,
) -> list[RunItem]:
) -> list[FunctionToolResult]:
async def run_single_tool(
func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
) -> str:
) -> Any:
with function_span(func_tool.name) as span_fn:
if config.trace_include_sensitive_data:
span_fn.span_data.input = tool_call.arguments
@ -404,10 +438,14 @@ class RunImpl:
results = await asyncio.gather(*tasks)
return [
ToolCallOutputItem(
output=str(result),
raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, str(result)),
agent=agent,
FunctionToolResult(
tool=tool_run.function_tool,
output=result,
run_item=ToolCallOutputItem(
output=result,
raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, str(result)),
agent=agent,
),
)
for tool_run, result in zip(tool_runs, results)
]
@ -646,6 +684,47 @@ class RunImpl:
if event:
queue.put_nowait(event)
@classmethod
async def _check_for_final_output_from_tools(
    cls,
    *,
    agent: Agent[TContext],
    tool_results: list[FunctionToolResult],
    context_wrapper: RunContextWrapper[TContext],
    config: RunConfig,
) -> ToolsToFinalOutputResult:
    """Decide whether this turn's function tool results should end the run.

    The decision is driven by `agent.tool_use_behavior`:
    - "run_llm_again": never final; the LLM sees the tool results.
    - "stop_on_first_tool": the first tool result's output is the final output.
    - a dict with "stop_at_tool_names": the output of the first result whose tool name is
      listed is the final output; otherwise not final.
    - a callable: called with `(context_wrapper, tool_results)`; its return value (awaited
      if it is awaitable) is used as-is.

    Args:
        agent: The agent whose `tool_use_behavior` is consulted.
        tool_results: Function tool results produced this turn.
        context_wrapper: Run context passed through to a custom behavior function.
        config: The run config (not read here).

    Returns:
        A `ToolsToFinalOutputResult`; with no tool results it is always "not final".

    Raises:
        UserError: If `agent.tool_use_behavior` is none of the supported forms.
    """
    if not tool_results:
        return _NOT_FINAL_OUTPUT

    behavior = agent.tool_use_behavior

    if behavior == "run_llm_again":
        return _NOT_FINAL_OUTPUT

    if behavior == "stop_on_first_tool":
        return ToolsToFinalOutputResult(
            is_final_output=True, final_output=tool_results[0].output
        )

    if isinstance(behavior, dict):
        names = behavior.get("stop_at_tool_names", [])
        for tool_result in tool_results:
            if tool_result.tool.name in names:
                return ToolsToFinalOutputResult(
                    is_final_output=True, final_output=tool_result.output
                )
        return ToolsToFinalOutputResult(is_final_output=False, final_output=None)

    if callable(behavior):
        outcome = behavior(context_wrapper, tool_results)
        # Inspect the returned value rather than the callable: this also covers objects
        # with an async __call__ and sync wrappers that hand back a coroutine, which
        # inspect.iscoroutinefunction() does not recognise.
        if inspect.isawaitable(outcome):
            outcome = await cast(Awaitable[ToolsToFinalOutputResult], outcome)
        return cast(ToolsToFinalOutputResult, outcome)

    logger.error(f"Invalid tool_use_behavior: {agent.tool_use_behavior}")
    raise UserError(f"Invalid tool_use_behavior: {agent.tool_use_behavior}")
class TraceCtxManager:
"""Creates a trace only if there is no current trace, and manages the trace lifecycle."""

View file

@ -4,7 +4,9 @@ import dataclasses
import inspect
from collections.abc import Awaitable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Generic, cast
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast
from typing_extensions import TypeAlias, TypedDict
from .guardrail import InputGuardrail, OutputGuardrail
from .handoffs import Handoff
@ -13,7 +15,7 @@ from .logger import logger
from .model_settings import ModelSettings
from .models.interface import Model
from .run_context import RunContextWrapper, TContext
from .tool import Tool, function_tool
from .tool import FunctionToolResult, Tool, function_tool
from .util import _transforms
from .util._types import MaybeAwaitable
@ -22,6 +24,33 @@ if TYPE_CHECKING:
from .result import RunResult
@dataclass
class ToolsToFinalOutputResult:
    """Outcome of a tool-use-behavior check: whether the tool results end the run and, if
    so, what the final output is.
    """

    is_final_output: bool
    """Whether this is the final output. If False, the LLM will run again and receive the tool call
    output.
    """

    final_output: Any | None = None
    """The final output. Can be None if `is_final_output` is False, otherwise must match the
    `output_type` of the agent.
    """
ToolsToFinalOutputFunction: TypeAlias = Callable[
[RunContextWrapper[TContext], list[FunctionToolResult]],
MaybeAwaitable[ToolsToFinalOutputResult],
]
"""A function that takes a run context and a list of tool results, and returns a
`ToolsToFinalOutputResult`.
"""
class StopAtTools(TypedDict):
    """Dict form of `tool_use_behavior` that names the tools which should stop the agent."""

    stop_at_tool_names: list[str]
    """A list of tool names, any of which will stop the agent from running further."""
@dataclass
class Agent(Generic[TContext]):
"""An agent is an AI model configured with instructions, tools, guardrails, handoffs and more.
@ -95,6 +124,25 @@ class Agent(Generic[TContext]):
"""A class that receives callbacks on various lifecycle events for this agent.
"""
tool_use_behavior: (
Literal["run_llm_again", "stop_on_first_tool"] | StopAtTools | ToolsToFinalOutputFunction
) = "run_llm_again"
"""This lets you configure how tool use is handled.
- "run_llm_again": The default behavior. Tools are run, and then the LLM receives the results
and gets to respond.
- "stop_on_first_tool": The output of the first tool call is used as the final output. This
means that the LLM does not process the result of the tool call.
- A `StopAtTools` dict (`{"stop_at_tool_names": [...]}`): The agent will stop running if any of
the tools named in the list are called. The final output will be the output of the first
matching tool call. The LLM does not process the result of the tool call.
- A function: If you pass a function, it will be called with the run context and the list of
tool results. It must return a `ToolsToFinalOutputResult`, which determines whether the tool
calls result in a final output.
NOTE: This configuration is specific to FunctionTools. Hosted tools, such as file search,
web search, etc are always processed by the LLM.
"""
def clone(self, **kwargs: Any) -> Agent[TContext]:
"""Make a copy of the agent, with the given arguments changed. For example, you could do:
```

View file

@ -129,8 +129,10 @@ class ToolCallOutputItem(RunItemBase[Union[FunctionCallOutput, ComputerCallOutpu
raw_item: FunctionCallOutput | ComputerCallOutput
"""The raw item from the model."""
output: str
"""The output of the tool call."""
output: Any
"""The output of the tool call. This is whatever the tool call returned; the `raw_item`
contains a string representation of the output.
"""
type: Literal["tool_call_output_item"] = "tool_call_output_item"

View file

@ -15,6 +15,7 @@ from . import _debug
from .computer import AsyncComputer, Computer
from .exceptions import ModelBehaviorError
from .function_schema import DocstringStyle, function_schema
from .items import RunItem
from .logger import logger
from .run_context import RunContextWrapper
from .tracing import SpanError
@ -29,6 +30,18 @@ ToolFunctionWithContext = Callable[Concatenate[RunContextWrapper[Any], ToolParam
ToolFunction = Union[ToolFunctionWithoutContext[ToolParams], ToolFunctionWithContext[ToolParams]]
@dataclass
class FunctionToolResult:
    """Bundles one function tool invocation: the tool, its output, and the run item created
    for it.
    """

    tool: FunctionTool
    """The tool that was run."""

    output: Any
    """The output of the tool."""

    run_item: RunItem
    """The run item that was produced as a result of the tool call."""
@dataclass
class FunctionTool:
"""A tool that wraps a function. In most cases, you should use the `function_tool` helpers to
@ -44,15 +57,15 @@ class FunctionTool:
params_json_schema: dict[str, Any]
"""The JSON schema for the tool's parameters."""
on_invoke_tool: Callable[[RunContextWrapper[Any], str], Awaitable[str]]
on_invoke_tool: Callable[[RunContextWrapper[Any], str], Awaitable[Any]]
"""A function that invokes the tool with the given context and parameters. The params passed
are:
1. The tool run context.
2. The arguments from the LLM, as a JSON string.
You must return a string representation of the tool output. In case of errors, you can either
raise an Exception (which will cause the run to fail) or return a string error message (which
will be sent back to the LLM).
You must return a string representation of the tool output, or something we can call `str()` on.
In case of errors, you can either raise an Exception (which will cause the run to fail) or
return a string error message (which will be sent back to the LLM).
"""
strict_json_schema: bool = True
@ -207,7 +220,7 @@ def function_tool(
strict_json_schema=strict_mode,
)
async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> str:
async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> Any:
try:
json_data: dict[str, Any] = json.loads(input) if input else {}
except Exception as e:
@ -254,9 +267,9 @@ def function_tool(
else:
logger.debug(f"Tool {schema.name} returned {result}")
return str(result)
return result
async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str:
async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> Any:
try:
return await _on_invoke_tool_impl(ctx, input)
except Exception as e:

View file

@ -51,7 +51,7 @@ class AgentSpanData(SpanData):
class FunctionSpanData(SpanData):
__slots__ = ("name", "input", "output")
def __init__(self, name: str, input: str | None, output: str | None):
def __init__(self, name: str, input: str | None, output: Any | None):
self.name = name
self.input = input
self.output = output
@ -65,7 +65,7 @@ class FunctionSpanData(SpanData):
"type": self.type,
"name": self.name,
"input": self.input,
"output": self.output,
"output": str(self.output) if self.output else None,
}

View file

@ -21,6 +21,8 @@ from agents import (
UserError,
handoff,
)
from agents.agent import ToolsToFinalOutputResult
from agents.tool import FunctionToolResult, function_tool
from .fake_model import FakeModel
from .test_responses import (
@ -552,3 +554,83 @@ async def test_output_guardrail_tripwire_triggered_causes_exception():
with pytest.raises(OutputGuardrailTripwireTriggered):
await Runner.run(agent, input="user_message")
# Tool fixture (not a test case): returns a structured `Foo` so tests can check that
# non-string tool outputs are passed through as the final output.
@function_tool
def test_tool_one():
    return Foo(bar="tool_one_result")
# Tool fixture (not a test case): returns a plain string result.
@function_tool
def test_tool_two():
    return "tool_two_result"
@pytest.mark.asyncio
async def test_tool_use_behavior_first_output():
    """With `stop_on_first_tool`, the first tool call's output becomes the final output."""
    fake_model = FakeModel()
    # Single turn: a message followed by two tool calls, test_tool_one first.
    fake_model.add_multiple_turn_outputs(
        [
            [
                get_text_message("a_message"),
                get_function_tool_call("test_tool_one", None),
                get_function_tool_call("test_tool_two", None),
            ],
        ]
    )
    available_tools = [get_function_tool("foo", "tool_result"), test_tool_one, test_tool_two]
    agent = Agent(
        name="test",
        model=fake_model,
        tools=available_tools,
        tool_use_behavior="stop_on_first_tool",
        output_type=Foo,
    )

    run_result = await Runner.run(agent, input="user_message")

    expected_output = Foo(bar="tool_one_result")
    assert run_result.final_output == expected_output, "should have used the first tool result"
def custom_tool_use_behavior(
    context: RunContextWrapper[Any], results: list[FunctionToolResult]
) -> ToolsToFinalOutputResult:
    """Finish the run with a fixed output once `test_tool_one` is among the called tools."""
    tool_one_was_called = any(item.tool.name == "test_tool_one" for item in results)
    if not tool_one_was_called:
        return ToolsToFinalOutputResult(is_final_output=False, final_output=None)
    return ToolsToFinalOutputResult(is_final_output=True, final_output="the_final_output")
@pytest.mark.asyncio
async def test_tool_use_behavior_custom_function():
    """A custom behavior function lets the run continue until it reports a final output."""
    fake_model = FakeModel()
    # Turn 1 only calls test_tool_two, so the custom function keeps the run going.
    first_turn = [
        get_text_message("a_message"),
        get_function_tool_call("test_tool_two", None),
    ]
    # Turn 2 calls test_tool_one as well, which makes the custom function stop the run.
    second_turn = [
        get_text_message("a_message"),
        get_function_tool_call("test_tool_one", None),
        get_function_tool_call("test_tool_two", None),
    ]
    fake_model.add_multiple_turn_outputs([first_turn, second_turn])

    agent = Agent(
        name="test",
        model=fake_model,
        tools=[get_function_tool("foo", "tool_result"), test_tool_one, test_tool_two],
        tool_use_behavior=custom_tool_use_behavior,
    )

    run_result = await Runner.run(agent, input="user_message")

    assert len(run_result.raw_responses) == 2, "should have two model responses"
    assert run_result.final_output == "the_final_output", "should have used the custom function"

View file

@ -49,10 +49,10 @@ async def test_simple_function():
assert tool.name == "simple_function"
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1}')
assert result == "6"
assert result == 6
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1, "b": 2}')
assert result == "3"
assert result == 3
# Missing required argument should raise an error
with pytest.raises(ModelBehaviorError):

View file

@ -0,0 +1,194 @@
# Copyright
from __future__ import annotations
from typing import cast
import pytest
from openai.types.responses.response_input_item_param import FunctionCallOutput
from agents import (
Agent,
FunctionToolResult,
RunConfig,
RunContextWrapper,
ToolCallOutputItem,
ToolsToFinalOutputResult,
UserError,
)
from agents._run_impl import RunImpl
from .test_responses import get_function_tool
def _make_function_tool_result(
    agent: Agent, output: str, tool_name: str | None = None
) -> FunctionToolResult:
    """Build a `FunctionToolResult` carrying `output`, backed by a simple function tool.

    Args:
        agent: Agent recorded on the generated run item.
        output: Value used as the tool's return value, the raw item output and the result
            output.
        tool_name: Name for the tool; falls back to "dummy" when omitted or empty.
    """
    function_tool = get_function_tool(tool_name or "dummy", return_value=output)
    payload = {
        "call_id": "1",
        "output": output,
        "type": "function_call_output",
    }
    # Only the output field matters to these tests, not the exact RunItem subclass.
    output_item = ToolCallOutputItem(
        agent=agent,
        raw_item=cast(FunctionCallOutput, payload),
        output=output,
    )
    return FunctionToolResult(tool=function_tool, output=output, run_item=output_item)
@pytest.mark.asyncio
async def test_no_tool_results_returns_not_final_output() -> None:
    """Without any tool results, no final output is produced."""
    ctx = RunContextWrapper(context=None)
    outcome = await RunImpl._check_for_final_output_from_tools(
        agent=Agent(name="test"),
        tool_results=[],
        context_wrapper=ctx,
        config=RunConfig(),
    )
    assert outcome.is_final_output is False
    assert outcome.final_output is None
@pytest.mark.asyncio
async def test_run_llm_again_behavior() -> None:
    """`run_llm_again` keeps the run going even when tool results are present."""
    agent = Agent(name="test", tool_use_behavior="run_llm_again")
    single_result = _make_function_tool_result(agent, "ignored")
    ctx = RunContextWrapper(context=None)

    outcome = await RunImpl._check_for_final_output_from_tools(
        agent=agent,
        tool_results=[single_result],
        context_wrapper=ctx,
        config=RunConfig(),
    )

    assert outcome.is_final_output is False
    assert outcome.final_output is None
@pytest.mark.asyncio
async def test_stop_on_first_tool_behavior() -> None:
    """`stop_on_first_tool` surfaces the first tool output as the final output."""
    agent = Agent(name="test", tool_use_behavior="stop_on_first_tool")
    first = _make_function_tool_result(agent, "first_tool_output")
    second = _make_function_tool_result(agent, "ignored")
    ctx = RunContextWrapper(context=None)

    outcome = await RunImpl._check_for_final_output_from_tools(
        agent=agent,
        tool_results=[first, second],
        context_wrapper=ctx,
        config=RunConfig(),
    )

    assert outcome.is_final_output is True
    assert outcome.final_output == "first_tool_output"
@pytest.mark.asyncio
async def test_custom_tool_use_behavior_sync() -> None:
    """A synchronous behavior function is called and its result is returned unchanged."""

    def behavior(
        context: RunContextWrapper, results: list[FunctionToolResult]
    ) -> ToolsToFinalOutputResult:
        # All three tool results must reach the custom function.
        assert len(results) == 3
        return ToolsToFinalOutputResult(is_final_output=True, final_output="custom")

    agent = Agent(name="test", tool_use_behavior=behavior)
    first = _make_function_tool_result(agent, "ignored1")
    second = _make_function_tool_result(agent, "ignored2")
    third = _make_function_tool_result(agent, "ignored3")
    ctx = RunContextWrapper(context=None)

    outcome = await RunImpl._check_for_final_output_from_tools(
        agent=agent,
        tool_results=[first, second, third],
        context_wrapper=ctx,
        config=RunConfig(),
    )

    assert outcome.is_final_output is True
    assert outcome.final_output == "custom"
@pytest.mark.asyncio
async def test_custom_tool_use_behavior_async() -> None:
    """An async behavior function is awaited and its result is returned unchanged."""

    async def behavior(
        context: RunContextWrapper, results: list[FunctionToolResult]
    ) -> ToolsToFinalOutputResult:
        # All three tool results must reach the custom function.
        assert len(results) == 3
        return ToolsToFinalOutputResult(is_final_output=True, final_output="async_custom")

    agent = Agent(name="test", tool_use_behavior=behavior)
    first = _make_function_tool_result(agent, "ignored1")
    second = _make_function_tool_result(agent, "ignored2")
    third = _make_function_tool_result(agent, "ignored3")
    ctx = RunContextWrapper(context=None)

    outcome = await RunImpl._check_for_final_output_from_tools(
        agent=agent,
        tool_results=[first, second, third],
        context_wrapper=ctx,
        config=RunConfig(),
    )

    assert outcome.is_final_output is True
    assert outcome.final_output == "async_custom"
@pytest.mark.asyncio
async def test_invalid_tool_use_behavior_raises() -> None:
    """An unsupported `tool_use_behavior` value results in a `UserError`."""
    agent = Agent(name="test")
    # Deliberately assign an unsupported value; the type checker objects, hence the ignore.
    agent.tool_use_behavior = "bad_value"  # type: ignore[assignment]
    single_result = _make_function_tool_result(agent, "ignored")
    ctx = RunContextWrapper(context=None)

    with pytest.raises(UserError):
        await RunImpl._check_for_final_output_from_tools(
            agent=agent,
            tool_results=[single_result],
            context_wrapper=ctx,
            config=RunConfig(),
        )
@pytest.mark.asyncio
async def test_tool_names_to_stop_at_behavior() -> None:
    """The `stop_at_tool_names` dict stops the run only when a listed tool was called."""
    registered_tools = [
        get_function_tool("tool1", return_value="tool1_output"),
        get_function_tool("tool2", return_value="tool2_output"),
        get_function_tool("tool3", return_value="tool3_output"),
    ]
    agent = Agent(
        name="test",
        tools=registered_tools,
        tool_use_behavior={"stop_at_tool_names": ["tool1"]},
    )

    # Case 1: none of the called tools is listed, so the run must continue.
    non_matching = [
        _make_function_tool_result(agent, "ignored1", "tool2"),
        _make_function_tool_result(agent, "ignored3", "tool3"),
    ]
    outcome = await RunImpl._check_for_final_output_from_tools(
        agent=agent,
        tool_results=non_matching,
        context_wrapper=RunContextWrapper(context=None),
        config=RunConfig(),
    )
    assert outcome.is_final_output is False, "We should not have stopped at tool1"

    # Case 2: tool1 is among the called tools, so its output becomes the final output.
    matching = [
        _make_function_tool_result(agent, "output1", "tool1"),
        _make_function_tool_result(agent, "ignored2", "tool2"),
        _make_function_tool_result(agent, "ignored3", "tool3"),
    ]
    outcome = await RunImpl._check_for_final_output_from_tools(
        agent=agent,
        tool_results=matching,
        context_wrapper=RunContextWrapper(context=None),
        config=RunConfig(),
    )
    assert outcome.is_final_output is True, "We should have stopped at tool1"
    assert outcome.final_output == "output1"