Introduce tool_use_behavior on agents
This commit is contained in:
parent
6f7e801da0
commit
10aa5555af
12 changed files with 594 additions and 26 deletions
|
|
@ -130,3 +130,16 @@ robot_agent = pirate_agent.clone(
|
|||
instructions="Write like a robot",
|
||||
)
|
||||
```
|
||||
|
||||
## Forcing tool use
|
||||
|
||||
Supplying a list of tools doesn't always mean the LLM will use a tool. You can force tool use by setting [`ModelSettings.tool_choice`][agents.model_settings.ModelSettings.tool_choice]. Valid values are:
|
||||
|
||||
1. `auto`, which allows the LLM to decide whether or not to use a tool.
|
||||
2. `required`, which requires the LLM to use a tool (but it can intelligently decide which tool).
|
||||
3. `none`, which requires the LLM to _not_ use a tool.
|
||||
4. Setting a specific string e.g. `my_tool`, which requires the LLM to use that specific tool.
|
||||
|
||||
!!! note
|
||||
|
||||
If requiring tool use, you should consider setting [`Agent.tool_use_behavior`][agents.agent.Agent.tool_use_behavior] to stop the Agent from running when a tool output is produced. Otherwise, the Agent might run in an infinite loop: the LLM produces a tool call, the tool result is sent back to the LLM, and the cycle repeats because the LLM is always forced to use a tool.
|
||||
|
|
|
|||
99
examples/agent_patterns/forcing_tool_use.py
Normal file
99
examples/agent_patterns/forcing_tool_use.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agents import (
|
||||
Agent,
|
||||
FunctionToolResult,
|
||||
ModelSettings,
|
||||
RunContextWrapper,
|
||||
Runner,
|
||||
ToolsToFinalOutputFunction,
|
||||
ToolsToFinalOutputResult,
|
||||
function_tool,
|
||||
)
|
||||
|
||||
"""
|
||||
This example shows how to force the agent to use a tool. It uses `ModelSettings(tool_choice="required")`
|
||||
to force the agent to use any tool.
|
||||
|
||||
You can run it with 3 options:
|
||||
1. `default`: The default behavior, which is to send the tool output to the LLM. In this case,
|
||||
`tool_choice` is not set, because otherwise it would result in an infinite loop - the LLM would
|
||||
call the tool, the tool would run and send the results to the LLM, and that would repeat
|
||||
(because the model is forced to use a tool every time.)
|
||||
2. `first_tool`: The first tool result is used as the final output.
|
||||
3. `custom`: A custom tool use behavior function is used. The custom function receives all the tool
|
||||
results, and chooses to use the first tool result to generate the final output.
|
||||
|
||||
Usage:
|
||||
python examples/agent_patterns/forcing_tool_use.py -t default
|
||||
python examples/agent_patterns/forcing_tool_use.py -t first_tool
|
||||
python examples/agent_patterns/forcing_tool_use.py -t custom
|
||||
"""
|
||||
|
||||
|
||||
# Structured weather report returned by the `get_weather` tool.
class Weather(BaseModel):
    city: str  # city the report is for
    temperature_range: str  # e.g. "14-20C"
    conditions: str  # e.g. "Sunny with wind"
|
||||
|
||||
|
||||
@function_tool
def get_weather(city: str) -> Weather:
    # Demo tool: log the call, then return a canned report for the requested city.
    print("[debug] get_weather called")
    report = Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind")
    return report
|
||||
|
||||
|
||||
async def custom_tool_use_behavior(
    context: RunContextWrapper[Any], results: list[FunctionToolResult]
) -> ToolsToFinalOutputResult:
    """Turn the first tool result into the final output.

    Args:
        context: The run context (unused here).
        results: The function tool results produced in this turn; the first one is used.

    Returns:
        A final-output result containing a one-sentence summary of the weather.
    """
    first_result = results[0]
    weather: Weather = first_result.output
    summary = f"{weather.city} is {weather.conditions}."
    return ToolsToFinalOutputResult(is_final_output=True, final_output=summary)
|
||||
|
||||
|
||||
async def main(tool_use_behavior: Literal["default", "first_tool", "custom"] = "default"):
    """Run the weather agent once with the selected tool-use behavior and print the result.

    Args:
        tool_use_behavior: Which strategy to demonstrate:
            - "default": tool output is sent back to the LLM ("run_llm_again").
            - "first_tool": the first tool result is the final output ("stop_on_first_tool").
            - "custom": `custom_tool_use_behavior` decides the final output.

    Raises:
        ValueError: If `tool_use_behavior` is not one of the supported options.
    """
    behavior: Literal["run_llm_again", "stop_on_first_tool"] | ToolsToFinalOutputFunction
    if tool_use_behavior == "default":
        behavior = "run_llm_again"
    elif tool_use_behavior == "first_tool":
        behavior = "stop_on_first_tool"
    elif tool_use_behavior == "custom":
        behavior = custom_tool_use_behavior
    else:
        # Fail loudly instead of hitting an unbound `behavior` further down.
        raise ValueError(f"Unknown tool_use_behavior: {tool_use_behavior!r}")

    agent = Agent(
        name="Weather agent",
        instructions="You are a helpful agent.",
        tools=[get_weather],
        tool_use_behavior=behavior,
        model_settings=ModelSettings(
            # Only force tool use when the run stops on a tool result; forcing it while also
            # sending results back to the LLM would loop forever.
            tool_choice="required" if tool_use_behavior != "default" else None
        ),
    )

    result = await Runner.run(agent, input="What's the weather in Tokyo?")
    print(result.final_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: parse the desired behavior from the command line and run the demo.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-t",
        "--tool-use-behavior",
        type=str,
        required=True,
        choices=["default", "first_tool", "custom"],
        # The option names in the help text must match `choices` above.
        help="The behavior to use for tool use. Default will cause tool outputs to be sent to the model. "
        "first_tool will cause the first tool result to be used as the final output. "
        "custom will use a custom tool use behavior function.",
    )
    args = parser.parse_args()
    asyncio.run(main(args.tool_use_behavior))
|
||||
34
examples/basic/tools.py
Normal file
34
examples/basic/tools.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
import asyncio
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from agents import Agent, Runner, function_tool
|
||||
|
||||
|
||||
# Structured weather report returned by the `get_weather` tool.
class Weather(BaseModel):
    city: str  # city the report is for
    temperature_range: str  # e.g. "14-20C"
    conditions: str  # e.g. "Sunny with wind."
|
||||
|
||||
|
||||
@function_tool
def get_weather(city: str) -> Weather:
    # Demo tool: log the call, then return a canned report for the requested city.
    print("[debug] get_weather called")
    report = Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")
    return report
|
||||
|
||||
|
||||
# Module-level agent whose only tool is `get_weather`.
agent = Agent(
    name="Hello world",
    instructions="You are a helpful agent.",
    tools=[get_weather],
)
|
||||
|
||||
|
||||
async def main():
    """Ask the agent about the weather in Tokyo and print its final output."""
    run_result = await Runner.run(agent, input="What's the weather in Tokyo?")
    answer = run_result.final_output
    print(answer)
    # Example output: The weather in Tokyo is sunny.
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Run the async entry point when this file is executed as a script.
    asyncio.run(main())
|
||||
|
|
@ -5,7 +5,7 @@ from typing import Literal
|
|||
from openai import AsyncOpenAI
|
||||
|
||||
from . import _config
|
||||
from .agent import Agent
|
||||
from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
|
||||
from .agent_output import AgentOutputSchema
|
||||
from .computer import AsyncComputer, Button, Computer, Environment
|
||||
from .exceptions import (
|
||||
|
|
@ -57,6 +57,7 @@ from .tool import (
|
|||
ComputerTool,
|
||||
FileSearchTool,
|
||||
FunctionTool,
|
||||
FunctionToolResult,
|
||||
Tool,
|
||||
WebSearchTool,
|
||||
default_tool_error_function,
|
||||
|
|
@ -137,6 +138,8 @@ def enable_verbose_stdout_logging():
|
|||
|
||||
__all__ = [
|
||||
"Agent",
|
||||
"ToolsToFinalOutputFunction",
|
||||
"ToolsToFinalOutputResult",
|
||||
"Runner",
|
||||
"Model",
|
||||
"ModelProvider",
|
||||
|
|
@ -190,6 +193,7 @@ __all__ = [
|
|||
"AgentUpdatedStreamEvent",
|
||||
"StreamEvent",
|
||||
"FunctionTool",
|
||||
"FunctionToolResult",
|
||||
"ComputerTool",
|
||||
"FileSearchTool",
|
||||
"Tool",
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
from collections.abc import Awaitable
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from openai.types.responses import (
|
||||
ResponseComputerToolCall,
|
||||
|
|
@ -25,7 +27,7 @@ from openai.types.responses.response_computer_tool_call import (
|
|||
from openai.types.responses.response_input_param import ComputerCallOutput
|
||||
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
|
||||
|
||||
from .agent import Agent
|
||||
from .agent import Agent, ToolsToFinalOutputResult
|
||||
from .agent_output import AgentOutputSchema
|
||||
from .computer import AsyncComputer, Computer
|
||||
from .exceptions import AgentsException, ModelBehaviorError, UserError
|
||||
|
|
@ -48,7 +50,7 @@ from .logger import logger
|
|||
from .models.interface import ModelTracing
|
||||
from .run_context import RunContextWrapper, TContext
|
||||
from .stream_events import RunItemStreamEvent, StreamEvent
|
||||
from .tool import ComputerTool, FunctionTool
|
||||
from .tool import ComputerTool, FunctionTool, FunctionToolResult
|
||||
from .tracing import (
|
||||
SpanError,
|
||||
Trace,
|
||||
|
|
@ -70,6 +72,8 @@ class QueueCompleteSentinel:
|
|||
|
||||
QUEUE_COMPLETE_SENTINEL = QueueCompleteSentinel()
|
||||
|
||||
_NOT_FINAL_OUTPUT = ToolsToFinalOutputResult(is_final_output=False, final_output=None)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolRunHandoff:
|
||||
|
|
@ -199,7 +203,7 @@ class RunImpl:
|
|||
config=run_config,
|
||||
),
|
||||
)
|
||||
new_step_items.extend(function_results)
|
||||
new_step_items.extend([result.run_item for result in function_results])
|
||||
new_step_items.extend(computer_results)
|
||||
|
||||
# Second, check if there are any handoffs
|
||||
|
|
@ -216,6 +220,36 @@ class RunImpl:
|
|||
run_config=run_config,
|
||||
)
|
||||
|
||||
# Third, we'll check if the tool use should result in a final output
|
||||
check_tool_use = await cls._check_for_final_output_from_tools(
|
||||
agent=agent,
|
||||
tool_results=function_results,
|
||||
context_wrapper=context_wrapper,
|
||||
config=run_config,
|
||||
)
|
||||
|
||||
if check_tool_use.is_final_output:
|
||||
# If the output type is str, then let's just stringify it
|
||||
if not agent.output_type or agent.output_type is str:
|
||||
check_tool_use.final_output = str(check_tool_use.final_output)
|
||||
|
||||
if check_tool_use.final_output is None:
|
||||
logger.error(
|
||||
"Model returned a final output of None. Not raising an error because we assume"
|
||||
"you know what you're doing."
|
||||
)
|
||||
|
||||
return await cls.execute_final_output(
|
||||
agent=agent,
|
||||
original_input=original_input,
|
||||
new_response=new_response,
|
||||
pre_step_items=pre_step_items,
|
||||
new_step_items=new_step_items,
|
||||
final_output=check_tool_use.final_output,
|
||||
hooks=hooks,
|
||||
context_wrapper=context_wrapper,
|
||||
)
|
||||
|
||||
# Now we can check if the model also produced a final output
|
||||
message_items = [item for item in new_step_items if isinstance(item, MessageOutputItem)]
|
||||
|
||||
|
|
@ -355,10 +389,10 @@ class RunImpl:
|
|||
hooks: RunHooks[TContext],
|
||||
context_wrapper: RunContextWrapper[TContext],
|
||||
config: RunConfig,
|
||||
) -> list[RunItem]:
|
||||
) -> list[FunctionToolResult]:
|
||||
async def run_single_tool(
|
||||
func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
|
||||
) -> str:
|
||||
) -> Any:
|
||||
with function_span(func_tool.name) as span_fn:
|
||||
if config.trace_include_sensitive_data:
|
||||
span_fn.span_data.input = tool_call.arguments
|
||||
|
|
@ -404,10 +438,14 @@ class RunImpl:
|
|||
results = await asyncio.gather(*tasks)
|
||||
|
||||
return [
|
||||
ToolCallOutputItem(
|
||||
output=str(result),
|
||||
raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, str(result)),
|
||||
agent=agent,
|
||||
FunctionToolResult(
|
||||
tool=tool_run.function_tool,
|
||||
output=result,
|
||||
run_item=ToolCallOutputItem(
|
||||
output=result,
|
||||
raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, str(result)),
|
||||
agent=agent,
|
||||
),
|
||||
)
|
||||
for tool_run, result in zip(tool_runs, results)
|
||||
]
|
||||
|
|
@ -646,6 +684,47 @@ class RunImpl:
|
|||
if event:
|
||||
queue.put_nowait(event)
|
||||
|
||||
@classmethod
async def _check_for_final_output_from_tools(
    cls,
    *,
    agent: Agent[TContext],
    tool_results: list[FunctionToolResult],
    context_wrapper: RunContextWrapper[TContext],
    config: RunConfig,
) -> ToolsToFinalOutputResult:
    """Decide whether this turn's function tool results should become the final output.

    The decision is driven by `agent.tool_use_behavior`:
    - "run_llm_again": never a final output.
    - "stop_on_first_tool": the first tool result's output is the final output.
    - a dict with "stop_at_tool_names": the output of the first result whose tool name is
      listed is the final output; otherwise not final.
    - a callable: called with `(context_wrapper, tool_results)`; its (possibly awaited)
      return value is used as-is.

    Args:
        agent: The agent whose `tool_use_behavior` is consulted.
        tool_results: The function tool results produced in this turn.
        context_wrapper: The run context passed through to a custom behavior function.
        config: The run configuration (not used by this check).

    Returns:
        A `ToolsToFinalOutputResult`. When there are no tool results, the result is always
        "not final".

    Raises:
        UserError: If `agent.tool_use_behavior` is not one of the supported kinds.
    """
    if not tool_results:
        return _NOT_FINAL_OUTPUT

    behavior = agent.tool_use_behavior

    if behavior == "run_llm_again":
        return _NOT_FINAL_OUTPUT

    if behavior == "stop_on_first_tool":
        return ToolsToFinalOutputResult(is_final_output=True, final_output=tool_results[0].output)

    if isinstance(behavior, dict):
        names = behavior.get("stop_at_tool_names", [])
        for tool_result in tool_results:
            if tool_result.tool.name in names:
                return ToolsToFinalOutputResult(
                    is_final_output=True, final_output=tool_result.output
                )
        return ToolsToFinalOutputResult(is_final_output=False, final_output=None)

    if callable(behavior):
        outcome = behavior(context_wrapper, tool_results)
        # Inspect the returned value rather than the callable: this also covers callables
        # that return an awaitable without being coroutine functions themselves.
        if inspect.isawaitable(outcome):
            return await cast(Awaitable[ToolsToFinalOutputResult], outcome)
        return cast(ToolsToFinalOutputResult, outcome)

    logger.error(f"Invalid tool_use_behavior: {agent.tool_use_behavior}")
    raise UserError(f"Invalid tool_use_behavior: {agent.tool_use_behavior}")
|
||||
|
||||
|
||||
class TraceCtxManager:
|
||||
"""Creates a trace only if there is no current trace, and manages the trace lifecycle."""
|
||||
|
|
|
|||
|
|
@ -4,7 +4,9 @@ import dataclasses
|
|||
import inspect
|
||||
from collections.abc import Awaitable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Callable, Generic, cast
|
||||
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast
|
||||
|
||||
from typing_extensions import TypeAlias, TypedDict
|
||||
|
||||
from .guardrail import InputGuardrail, OutputGuardrail
|
||||
from .handoffs import Handoff
|
||||
|
|
@ -13,7 +15,7 @@ from .logger import logger
|
|||
from .model_settings import ModelSettings
|
||||
from .models.interface import Model
|
||||
from .run_context import RunContextWrapper, TContext
|
||||
from .tool import Tool, function_tool
|
||||
from .tool import FunctionToolResult, Tool, function_tool
|
||||
from .util import _transforms
|
||||
from .util._types import MaybeAwaitable
|
||||
|
||||
|
|
@ -22,6 +24,33 @@ if TYPE_CHECKING:
|
|||
from .result import RunResult
|
||||
|
||||
|
||||
@dataclass
class ToolsToFinalOutputResult:
    """Result of deciding whether tool results should be used as the agent's final output."""

    is_final_output: bool
    """Whether this is the final output. If False, the LLM will run again and receive the tool call
    output.
    """

    final_output: Any | None = None
    """The final output. Can be None if `is_final_output` is False, otherwise must match the
    `output_type` of the agent.
    """
|
||||
|
||||
|
||||
ToolsToFinalOutputFunction: TypeAlias = Callable[
|
||||
[RunContextWrapper[TContext], list[FunctionToolResult]],
|
||||
MaybeAwaitable[ToolsToFinalOutputResult],
|
||||
]
|
||||
"""A function that takes a run context and a list of tool results, and returns a
|
||||
`ToolsToFinalOutputResult`.
|
||||
"""
|
||||
|
||||
|
||||
class StopAtTools(TypedDict):
    """A `tool_use_behavior` value that names the tools whose results should end the run."""

    stop_at_tool_names: list[str]
    """A list of tool names, any of which will stop the agent from running further."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class Agent(Generic[TContext]):
|
||||
"""An agent is an AI model configured with instructions, tools, guardrails, handoffs and more.
|
||||
|
|
@ -95,6 +124,25 @@ class Agent(Generic[TContext]):
|
|||
"""A class that receives callbacks on various lifecycle events for this agent.
|
||||
"""
|
||||
|
||||
tool_use_behavior: (
|
||||
Literal["run_llm_again", "stop_on_first_tool"] | StopAtTools | ToolsToFinalOutputFunction
|
||||
) = "run_llm_again"
|
||||
"""This lets you configure how tool use is handled.
|
||||
- "run_llm_again": The default behavior. Tools are run, and then the LLM receives the results
|
||||
and gets to respond.
|
||||
- "stop_on_first_tool": The output of the first tool call is used as the final output. This
|
||||
means that the LLM does not process the result of the tool call.
|
||||
- A `StopAtTools` dict (`{"stop_at_tool_names": [...]}`): The agent will stop running if any of
the listed tools are called.
|
||||
The final output will be the output of the first matching tool call. The LLM does not
|
||||
process the result of the tool call.
|
||||
- A function: If you pass a function, it will be called with the run context and the list of
|
||||
tool results. It must return a `ToolsToFinalOutputResult`, which determines whether the tool
|
||||
calls result in a final output.
|
||||
|
||||
NOTE: This configuration is specific to FunctionTools. Hosted tools, such as file search,
|
||||
web search, etc are always processed by the LLM.
|
||||
"""
|
||||
|
||||
def clone(self, **kwargs: Any) -> Agent[TContext]:
|
||||
"""Make a copy of the agent, with the given arguments changed. For example, you could do:
|
||||
```
|
||||
|
|
|
|||
|
|
@ -129,8 +129,10 @@ class ToolCallOutputItem(RunItemBase[Union[FunctionCallOutput, ComputerCallOutpu
|
|||
raw_item: FunctionCallOutput | ComputerCallOutput
|
||||
"""The raw item from the model."""
|
||||
|
||||
output: str
|
||||
"""The output of the tool call."""
|
||||
output: Any
|
||||
"""The output of the tool call. This is whatever the tool call returned; the `raw_item`
|
||||
contains a string representation of the output.
|
||||
"""
|
||||
|
||||
type: Literal["tool_call_output_item"] = "tool_call_output_item"
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from . import _debug
|
|||
from .computer import AsyncComputer, Computer
|
||||
from .exceptions import ModelBehaviorError
|
||||
from .function_schema import DocstringStyle, function_schema
|
||||
from .items import RunItem
|
||||
from .logger import logger
|
||||
from .run_context import RunContextWrapper
|
||||
from .tracing import SpanError
|
||||
|
|
@ -29,6 +30,18 @@ ToolFunctionWithContext = Callable[Concatenate[RunContextWrapper[Any], ToolParam
|
|||
ToolFunction = Union[ToolFunctionWithoutContext[ToolParams], ToolFunctionWithContext[ToolParams]]
|
||||
|
||||
|
||||
@dataclass
class FunctionToolResult:
    """Bundles a function tool with the output it produced and the resulting run item."""

    tool: FunctionTool
    """The tool that was run."""

    output: Any
    """The output of the tool."""

    run_item: RunItem
    """The run item that was produced as a result of the tool call."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionTool:
|
||||
"""A tool that wraps a function. In most cases, you should use the `function_tool` helpers to
|
||||
|
|
@ -44,15 +57,15 @@ class FunctionTool:
|
|||
params_json_schema: dict[str, Any]
|
||||
"""The JSON schema for the tool's parameters."""
|
||||
|
||||
on_invoke_tool: Callable[[RunContextWrapper[Any], str], Awaitable[str]]
|
||||
on_invoke_tool: Callable[[RunContextWrapper[Any], str], Awaitable[Any]]
|
||||
"""A function that invokes the tool with the given context and parameters. The params passed
|
||||
are:
|
||||
1. The tool run context.
|
||||
2. The arguments from the LLM, as a JSON string.
|
||||
|
||||
You must return a string representation of the tool output. In case of errors, you can either
|
||||
raise an Exception (which will cause the run to fail) or return a string error message (which
|
||||
will be sent back to the LLM).
|
||||
You must return a string representation of the tool output, or something we can call `str()` on.
|
||||
In case of errors, you can either raise an Exception (which will cause the run to fail) or
|
||||
return a string error message (which will be sent back to the LLM).
|
||||
"""
|
||||
|
||||
strict_json_schema: bool = True
|
||||
|
|
@ -207,7 +220,7 @@ def function_tool(
|
|||
strict_json_schema=strict_mode,
|
||||
)
|
||||
|
||||
async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> str:
|
||||
async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> Any:
|
||||
try:
|
||||
json_data: dict[str, Any] = json.loads(input) if input else {}
|
||||
except Exception as e:
|
||||
|
|
@ -254,9 +267,9 @@ def function_tool(
|
|||
else:
|
||||
logger.debug(f"Tool {schema.name} returned {result}")
|
||||
|
||||
return str(result)
|
||||
return result
|
||||
|
||||
async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str:
|
||||
async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> Any:
|
||||
try:
|
||||
return await _on_invoke_tool_impl(ctx, input)
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ class AgentSpanData(SpanData):
|
|||
class FunctionSpanData(SpanData):
|
||||
__slots__ = ("name", "input", "output")
|
||||
|
||||
def __init__(self, name: str, input: str | None, output: str | None):
|
||||
def __init__(self, name: str, input: str | None, output: Any | None):
|
||||
self.name = name
|
||||
self.input = input
|
||||
self.output = output
|
||||
|
|
@ -65,7 +65,7 @@ class FunctionSpanData(SpanData):
|
|||
"type": self.type,
|
||||
"name": self.name,
|
||||
"input": self.input,
|
||||
"output": self.output,
|
||||
"output": str(self.output) if self.output else None,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -21,6 +21,8 @@ from agents import (
|
|||
UserError,
|
||||
handoff,
|
||||
)
|
||||
from agents.agent import ToolsToFinalOutputResult
|
||||
from agents.tool import FunctionToolResult, function_tool
|
||||
|
||||
from .fake_model import FakeModel
|
||||
from .test_responses import (
|
||||
|
|
@ -552,3 +554,83 @@ async def test_output_guardrail_tripwire_triggered_causes_exception():
|
|||
|
||||
with pytest.raises(OutputGuardrailTripwireTriggered):
|
||||
await Runner.run(agent, input="user_message")
|
||||
|
||||
|
||||
# Function tool fixture that returns a structured `Foo` value.
@function_tool
def test_tool_one():
    return Foo(bar="tool_one_result")
|
||||
|
||||
|
||||
# Function tool fixture that returns a plain string.
@function_tool
def test_tool_two():
    return "tool_two_result"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tool_use_behavior_first_output():
    """With "stop_on_first_tool", the first tool call's output becomes the final output."""
    fake_model = FakeModel()
    available_tools = [get_function_tool("foo", "tool_result"), test_tool_one, test_tool_two]
    agent = Agent(
        name="test",
        model=fake_model,
        tools=available_tools,
        tool_use_behavior="stop_on_first_tool",
        output_type=Foo,
    )

    # Single turn: a message followed by two tool calls.
    only_turn = [
        get_text_message("a_message"),
        get_function_tool_call("test_tool_one", None),
        get_function_tool_call("test_tool_two", None),
    ]
    fake_model.add_multiple_turn_outputs([only_turn])

    result = await Runner.run(agent, input="user_message")

    expected = Foo(bar="tool_one_result")
    assert result.final_output == expected, "should have used the first tool result"
|
||||
|
||||
|
||||
def custom_tool_use_behavior(
    context: RunContextWrapper[Any], results: list[FunctionToolResult]
) -> ToolsToFinalOutputResult:
    """Finish the run with a fixed output once `test_tool_one` is among the called tools."""
    tool_one_called = any(result.tool.name == "test_tool_one" for result in results)
    if not tool_one_called:
        return ToolsToFinalOutputResult(is_final_output=False, final_output=None)
    return ToolsToFinalOutputResult(is_final_output=True, final_output="the_final_output")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tool_use_behavior_custom_function():
    """A custom behavior function decides when tool results end the run."""
    fake_model = FakeModel()
    agent = Agent(
        name="test",
        model=fake_model,
        tools=[get_function_tool("foo", "tool_result"), test_tool_one, test_tool_two],
        tool_use_behavior=custom_tool_use_behavior,
    )

    # First turn: a message and a call to test_tool_two only.
    first_turn = [
        get_text_message("a_message"),
        get_function_tool_call("test_tool_two", None),
    ]
    # Second turn: a message and calls to both tools.
    second_turn = [
        get_text_message("a_message"),
        get_function_tool_call("test_tool_one", None),
        get_function_tool_call("test_tool_two", None),
    ]
    fake_model.add_multiple_turn_outputs([first_turn, second_turn])

    result = await Runner.run(agent, input="user_message")

    response_count = len(result.raw_responses)
    assert response_count == 2, "should have two model responses"
    assert result.final_output == "the_final_output", "should have used the custom function"
|
||||
|
|
|
|||
|
|
@ -49,10 +49,10 @@ async def test_simple_function():
|
|||
assert tool.name == "simple_function"
|
||||
|
||||
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1}')
|
||||
assert result == "6"
|
||||
assert result == 6
|
||||
|
||||
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1, "b": 2}')
|
||||
assert result == "3"
|
||||
assert result == 3
|
||||
|
||||
# Missing required argument should raise an error
|
||||
with pytest.raises(ModelBehaviorError):
|
||||
|
|
|
|||
194
tests/test_tool_use_behavior.py
Normal file
194
tests/test_tool_use_behavior.py
Normal file
|
|
@ -0,0 +1,194 @@
|
|||
# Copyright
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from openai.types.responses.response_input_item_param import FunctionCallOutput
|
||||
|
||||
from agents import (
|
||||
Agent,
|
||||
FunctionToolResult,
|
||||
RunConfig,
|
||||
RunContextWrapper,
|
||||
ToolCallOutputItem,
|
||||
ToolsToFinalOutputResult,
|
||||
UserError,
|
||||
)
|
||||
from agents._run_impl import RunImpl
|
||||
|
||||
from .test_responses import get_function_tool
|
||||
|
||||
|
||||
def _make_function_tool_result(
    agent: Agent, output: str, tool_name: str | None = None
) -> FunctionToolResult:
    """Build a FunctionToolResult carrying `output` for a tool made via `get_function_tool`.

    The tool name falls back to "dummy" when `tool_name` is not given.
    """
    stub_tool = get_function_tool(tool_name or "dummy", return_value=output)
    payload = {
        "call_id": "1",
        "output": output,
        "type": "function_call_output",
    }
    # The tests only look at the output field, so a minimal raw item is enough here.
    item = ToolCallOutputItem(
        agent=agent, raw_item=cast(FunctionCallOutput, payload), output=output
    )
    return FunctionToolResult(tool=stub_tool, output=output, run_item=item)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_no_tool_results_returns_not_final_output() -> None:
    """Without any tool results, the check must not produce a final output."""
    outcome = await RunImpl._check_for_final_output_from_tools(
        agent=Agent(name="test"),
        tool_results=[],
        context_wrapper=RunContextWrapper(context=None),
        config=RunConfig(),
    )
    assert outcome.is_final_output is False
    assert outcome.final_output is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_llm_again_behavior() -> None:
    """With "run_llm_again", tool results never end the run."""
    agent = Agent(name="test", tool_use_behavior="run_llm_again")
    single_result = _make_function_tool_result(agent, "ignored")
    outcome = await RunImpl._check_for_final_output_from_tools(
        agent=agent,
        tool_results=[single_result],
        context_wrapper=RunContextWrapper(context=None),
        config=RunConfig(),
    )
    assert outcome.is_final_output is False
    assert outcome.final_output is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_stop_on_first_tool_behavior() -> None:
    """With "stop_on_first_tool", the first tool's output is surfaced as the final output."""
    agent = Agent(name="test", tool_use_behavior="stop_on_first_tool")
    first = _make_function_tool_result(agent, "first_tool_output")
    second = _make_function_tool_result(agent, "ignored")
    outcome = await RunImpl._check_for_final_output_from_tools(
        agent=agent,
        tool_results=[first, second],
        context_wrapper=RunContextWrapper(context=None),
        config=RunConfig(),
    )
    assert outcome.is_final_output is True
    assert outcome.final_output == "first_tool_output"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_custom_tool_use_behavior_sync() -> None:
    """A synchronous behavior function is called and its result is returned unchanged."""

    def decide(
        context: RunContextWrapper, results: list[FunctionToolResult]
    ) -> ToolsToFinalOutputResult:
        assert len(results) == 3
        return ToolsToFinalOutputResult(is_final_output=True, final_output="custom")

    agent = Agent(name="test", tool_use_behavior=decide)
    three_results = [
        _make_function_tool_result(agent, text) for text in ("ignored1", "ignored2", "ignored3")
    ]
    outcome = await RunImpl._check_for_final_output_from_tools(
        agent=agent,
        tool_results=three_results,
        context_wrapper=RunContextWrapper(context=None),
        config=RunConfig(),
    )
    assert outcome.is_final_output is True
    assert outcome.final_output == "custom"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_custom_tool_use_behavior_async() -> None:
    """An async behavior function is awaited and its result is returned unchanged."""

    async def decide(
        context: RunContextWrapper, results: list[FunctionToolResult]
    ) -> ToolsToFinalOutputResult:
        assert len(results) == 3
        return ToolsToFinalOutputResult(is_final_output=True, final_output="async_custom")

    agent = Agent(name="test", tool_use_behavior=decide)
    three_results = [
        _make_function_tool_result(agent, text) for text in ("ignored1", "ignored2", "ignored3")
    ]
    outcome = await RunImpl._check_for_final_output_from_tools(
        agent=agent,
        tool_results=three_results,
        context_wrapper=RunContextWrapper(context=None),
        config=RunConfig(),
    )
    assert outcome.is_final_output is True
    assert outcome.final_output == "async_custom"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_invalid_tool_use_behavior_raises() -> None:
    """An unsupported tool_use_behavior value must raise a UserError."""
    agent = Agent(name="test")
    # Deliberately assign an unsupported value; the type checker objects, hence the ignore.
    agent.tool_use_behavior = "bad_value"  # type: ignore[assignment]
    single_result = _make_function_tool_result(agent, "ignored")
    with pytest.raises(UserError):
        await RunImpl._check_for_final_output_from_tools(
            agent=agent,
            tool_results=[single_result],
            context_wrapper=RunContextWrapper(context=None),
            config=RunConfig(),
        )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_tool_names_to_stop_at_behavior() -> None:
    """A stop_at_tool_names dict ends the run only when a listed tool produced a result."""
    agent = Agent(
        name="test",
        tools=[
            get_function_tool("tool1", return_value="tool1_output"),
            get_function_tool("tool2", return_value="tool2_output"),
            get_function_tool("tool3", return_value="tool3_output"),
        ],
        tool_use_behavior={"stop_at_tool_names": ["tool1"]},
    )

    async def check(results: list[FunctionToolResult]) -> ToolsToFinalOutputResult:
        # Run the check under test with a fresh context and default run config.
        return await RunImpl._check_for_final_output_from_tools(
            agent=agent,
            tool_results=results,
            context_wrapper=RunContextWrapper(context=None),
            config=RunConfig(),
        )

    # None of the called tools is listed, so the run must continue.
    unmatched = [
        _make_function_tool_result(agent, "ignored1", "tool2"),
        _make_function_tool_result(agent, "ignored3", "tool3"),
    ]
    outcome = await check(unmatched)
    assert outcome.is_final_output is False, "We should not have stopped at tool1"

    # Now test with a tool that matches the list
    matched = [
        _make_function_tool_result(agent, "output1", "tool1"),
        _make_function_tool_result(agent, "ignored2", "tool2"),
        _make_function_tool_result(agent, "ignored3", "tool3"),
    ]
    outcome = await check(matched)
    assert outcome.is_final_output is True, "We should have stopped at tool1"
    assert outcome.final_output == "output1"
|
||||
Loading…
Reference in a new issue