Introduce tool_use_behavior on agents
This commit is contained in:
parent
6f7e801da0
commit
10aa5555af
12 changed files with 594 additions and 26 deletions
|
|
@ -130,3 +130,16 @@ robot_agent = pirate_agent.clone(
|
||||||
instructions="Write like a robot",
|
instructions="Write like a robot",
|
||||||
)
|
)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Forcing tool use
|
||||||
|
|
||||||
|
Supplying a list of tools doesn't always mean the LLM will use a tool. You can force tool use by setting [`ModelSettings.tool_choice`][agents.model_settings.ModelSettings.tool_choice]. Valid values are:
|
||||||
|
|
||||||
|
1. `auto`, which allows the LLM to decide whether or not to use a tool.
|
||||||
|
2. `required`, which requires the LLM to use a tool (but it can intelligently decide which tool).
|
||||||
|
3. `none`, which requires the LLM to _not_ use a tool.
|
||||||
|
4. Setting a specific string e.g. `my_tool`, which requires the LLM to use that specific tool.
|
||||||
|
|
||||||
|
!!! note
|
||||||
|
|
||||||
|
If requiring tool use, you should consider setting [`Agent.tool_use_behavior`][agents.agent.Agent.tool_use_behavior] to stop the Agent from running when a tool output is produced. Otherwise, the Agent might run in an infinite loop: the LLM produces a tool call, the tool result is sent back to the LLM, and the cycle repeats forever because the LLM is always forced to use a tool.
|
||||||
|
|
|
||||||
99
examples/agent_patterns/forcing_tool_use.py
Normal file
99
examples/agent_patterns/forcing_tool_use.py
Normal file
|
|
@ -0,0 +1,99 @@
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from agents import (
|
||||||
|
Agent,
|
||||||
|
FunctionToolResult,
|
||||||
|
ModelSettings,
|
||||||
|
RunContextWrapper,
|
||||||
|
Runner,
|
||||||
|
ToolsToFinalOutputFunction,
|
||||||
|
ToolsToFinalOutputResult,
|
||||||
|
function_tool,
|
||||||
|
)
|
||||||
|
|
||||||
|
"""
|
||||||
|
This example shows how to force the agent to use a tool. It uses `ModelSettings(tool_choice="required")`
|
||||||
|
to force the agent to use any tool.
|
||||||
|
|
||||||
|
You can run it with 3 options:
|
||||||
|
1. `default`: The default behavior, which is to send the tool output to the LLM. In this case,
|
||||||
|
`tool_choice` is not set, because otherwise it would result in an infinite loop - the LLM would
|
||||||
|
call the tool, the tool would run and send the results to the LLM, and that would repeat
|
||||||
|
(because the model is forced to use a tool every time.)
|
||||||
|
2. `first_tool`: The first tool result is used as the final output.
|
||||||
|
3. `custom`: A custom tool use behavior function is used. The custom function receives all the tool
|
||||||
|
results, and chooses to use the first tool result to generate the final output.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
python examples/agent_patterns/forcing_tool_use.py -t default
|
||||||
|
python examples/agent_patterns/forcing_tool_use.py -t first_tool
|
||||||
|
python examples/agent_patterns/forcing_tool_use.py -t custom
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# Structured weather report returned by the `get_weather` tool.
# (Documented with comments rather than a docstring so the model class itself is unchanged.)
class Weather(BaseModel):
    # Name of the city the report is for.
    city: str
    # Temperature range as free-form text, e.g. "14-20C".
    temperature_range: str
    # Short free-form description of the conditions.
    conditions: str
|
||||||
|
|
||||||
|
|
||||||
|
# Tool exposed to the agent: returns a fixed, canned report for the requested city.
# The debug print makes it visible on stdout each time the tool is actually invoked.
@function_tool
def get_weather(city: str) -> Weather:
    print("[debug] get_weather called")
    report = Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind")
    return report
|
||||||
|
|
||||||
|
|
||||||
|
async def custom_tool_use_behavior(
    context: RunContextWrapper[Any], results: list[FunctionToolResult]
) -> ToolsToFinalOutputResult:
    """Build the run's final output from the first tool result.

    Args:
        context: The run context wrapper (not used here).
        results: The function tool results for this turn; only the first one is read.

    Returns:
        A result flagged as final whose output is a one-sentence summary of the weather.
    """
    first_report: Weather = results[0].output
    summary = f"{first_report.city} is {first_report.conditions}."
    return ToolsToFinalOutputResult(is_final_output=True, final_output=summary)
|
||||||
|
|
||||||
|
|
||||||
|
async def main(tool_use_behavior: Literal["default", "first_tool", "custom"] = "default"):
    """Run the weather agent once with the selected tool-use behavior and print the output.

    Args:
        tool_use_behavior: Which demo mode to run.
            - "default": tool output is sent back to the LLM ("run_llm_again").
            - "first_tool": the first tool result becomes the final output.
            - "custom": `custom_tool_use_behavior` decides the final output.

    Raises:
        ValueError: If `tool_use_behavior` is not one of the supported options.
    """
    behavior: Literal["run_llm_again", "stop_on_first_tool"] | ToolsToFinalOutputFunction
    if tool_use_behavior == "default":
        behavior = "run_llm_again"
    elif tool_use_behavior == "first_tool":
        behavior = "stop_on_first_tool"
    elif tool_use_behavior == "custom":
        behavior = custom_tool_use_behavior
    else:
        # Fail loudly instead of leaving `behavior` unbound for an unknown option.
        raise ValueError(f"Unknown tool_use_behavior: {tool_use_behavior!r}")

    # Tool use is only forced when the run stops after the tool call; forcing it together with
    # "run_llm_again" would make the model call a tool on every turn, i.e. loop forever.
    tool_choice = "required" if tool_use_behavior != "default" else None

    agent = Agent(
        name="Weather agent",
        instructions="You are a helpful agent.",
        tools=[get_weather],
        tool_use_behavior=behavior,
        model_settings=ModelSettings(tool_choice=tool_choice),
    )

    result = await Runner.run(agent, input="What's the weather in Tokyo?")
    print(result.final_output)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    import argparse

    # Command-line entry point: choose one of the three tool-use behaviors and run the demo.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-t",
        "--tool-use-behavior",
        type=str,
        required=True,
        choices=["default", "first_tool", "custom"],
        # The option names mentioned in the help text must match `choices` above.
        help="The behavior to use for tool use. Default will cause tool outputs to be sent to the model. "
        "first_tool will cause the first tool result to be used as the final output. "
        "custom will use a custom tool use behavior function.",
    )
    args = parser.parse_args()
    asyncio.run(main(args.tool_use_behavior))
|
||||||
34
examples/basic/tools.py
Normal file
34
examples/basic/tools.py
Normal file
|
|
@ -0,0 +1,34 @@
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from agents import Agent, Runner, function_tool
|
||||||
|
|
||||||
|
|
||||||
|
# Structured weather report produced by the `get_weather` tool.
# (Documented with comments rather than a docstring so the model class itself is unchanged.)
class Weather(BaseModel):
    # City the report describes.
    city: str
    # Temperature range as free-form text, e.g. "14-20C".
    temperature_range: str
    # Short free-form description of the conditions.
    conditions: str
|
||||||
|
|
||||||
|
|
||||||
|
# Tool exposed to the agent: returns a fixed, canned report for the requested city.
# The debug print shows on stdout whenever the tool is actually invoked.
@function_tool
def get_weather(city: str) -> Weather:
    print("[debug] get_weather called")
    report = Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")
    return report
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level agent used by `main`; `get_weather` is the only tool it can call.
agent = Agent(
    tools=[get_weather],
    instructions="You are a helpful agent.",
    name="Hello world",
)
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
    """Ask the agent about the weather in Tokyo and print its final output."""
    run_result = await Runner.run(agent, input="What's the weather in Tokyo?")
    final_text = run_result.final_output
    print(final_text)
    # Example output: The weather in Tokyo is sunny.
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Script entry point: drive the async `main` coroutine to completion.
    entry_coroutine = main()
    asyncio.run(entry_coroutine)
|
||||||
|
|
@ -5,7 +5,7 @@ from typing import Literal
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
from . import _config
|
from . import _config
|
||||||
from .agent import Agent
|
from .agent import Agent, ToolsToFinalOutputFunction, ToolsToFinalOutputResult
|
||||||
from .agent_output import AgentOutputSchema
|
from .agent_output import AgentOutputSchema
|
||||||
from .computer import AsyncComputer, Button, Computer, Environment
|
from .computer import AsyncComputer, Button, Computer, Environment
|
||||||
from .exceptions import (
|
from .exceptions import (
|
||||||
|
|
@ -57,6 +57,7 @@ from .tool import (
|
||||||
ComputerTool,
|
ComputerTool,
|
||||||
FileSearchTool,
|
FileSearchTool,
|
||||||
FunctionTool,
|
FunctionTool,
|
||||||
|
FunctionToolResult,
|
||||||
Tool,
|
Tool,
|
||||||
WebSearchTool,
|
WebSearchTool,
|
||||||
default_tool_error_function,
|
default_tool_error_function,
|
||||||
|
|
@ -137,6 +138,8 @@ def enable_verbose_stdout_logging():
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Agent",
|
"Agent",
|
||||||
|
"ToolsToFinalOutputFunction",
|
||||||
|
"ToolsToFinalOutputResult",
|
||||||
"Runner",
|
"Runner",
|
||||||
"Model",
|
"Model",
|
||||||
"ModelProvider",
|
"ModelProvider",
|
||||||
|
|
@ -190,6 +193,7 @@ __all__ = [
|
||||||
"AgentUpdatedStreamEvent",
|
"AgentUpdatedStreamEvent",
|
||||||
"StreamEvent",
|
"StreamEvent",
|
||||||
"FunctionTool",
|
"FunctionTool",
|
||||||
|
"FunctionToolResult",
|
||||||
"ComputerTool",
|
"ComputerTool",
|
||||||
"FileSearchTool",
|
"FileSearchTool",
|
||||||
"Tool",
|
"Tool",
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,10 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import inspect
|
||||||
|
from collections.abc import Awaitable
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any, cast
|
||||||
|
|
||||||
from openai.types.responses import (
|
from openai.types.responses import (
|
||||||
ResponseComputerToolCall,
|
ResponseComputerToolCall,
|
||||||
|
|
@ -25,7 +27,7 @@ from openai.types.responses.response_computer_tool_call import (
|
||||||
from openai.types.responses.response_input_param import ComputerCallOutput
|
from openai.types.responses.response_input_param import ComputerCallOutput
|
||||||
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
|
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
|
||||||
|
|
||||||
from .agent import Agent
|
from .agent import Agent, ToolsToFinalOutputResult
|
||||||
from .agent_output import AgentOutputSchema
|
from .agent_output import AgentOutputSchema
|
||||||
from .computer import AsyncComputer, Computer
|
from .computer import AsyncComputer, Computer
|
||||||
from .exceptions import AgentsException, ModelBehaviorError, UserError
|
from .exceptions import AgentsException, ModelBehaviorError, UserError
|
||||||
|
|
@ -48,7 +50,7 @@ from .logger import logger
|
||||||
from .models.interface import ModelTracing
|
from .models.interface import ModelTracing
|
||||||
from .run_context import RunContextWrapper, TContext
|
from .run_context import RunContextWrapper, TContext
|
||||||
from .stream_events import RunItemStreamEvent, StreamEvent
|
from .stream_events import RunItemStreamEvent, StreamEvent
|
||||||
from .tool import ComputerTool, FunctionTool
|
from .tool import ComputerTool, FunctionTool, FunctionToolResult
|
||||||
from .tracing import (
|
from .tracing import (
|
||||||
SpanError,
|
SpanError,
|
||||||
Trace,
|
Trace,
|
||||||
|
|
@ -70,6 +72,8 @@ class QueueCompleteSentinel:
|
||||||
|
|
||||||
QUEUE_COMPLETE_SENTINEL = QueueCompleteSentinel()
|
QUEUE_COMPLETE_SENTINEL = QueueCompleteSentinel()
|
||||||
|
|
||||||
|
_NOT_FINAL_OUTPUT = ToolsToFinalOutputResult(is_final_output=False, final_output=None)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ToolRunHandoff:
|
class ToolRunHandoff:
|
||||||
|
|
@ -199,7 +203,7 @@ class RunImpl:
|
||||||
config=run_config,
|
config=run_config,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
new_step_items.extend(function_results)
|
new_step_items.extend([result.run_item for result in function_results])
|
||||||
new_step_items.extend(computer_results)
|
new_step_items.extend(computer_results)
|
||||||
|
|
||||||
# Second, check if there are any handoffs
|
# Second, check if there are any handoffs
|
||||||
|
|
@ -216,6 +220,36 @@ class RunImpl:
|
||||||
run_config=run_config,
|
run_config=run_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Third, we'll check if the tool use should result in a final output
|
||||||
|
check_tool_use = await cls._check_for_final_output_from_tools(
|
||||||
|
agent=agent,
|
||||||
|
tool_results=function_results,
|
||||||
|
context_wrapper=context_wrapper,
|
||||||
|
config=run_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
if check_tool_use.is_final_output:
|
||||||
|
# If the output type is str, then let's just stringify it
|
||||||
|
if not agent.output_type or agent.output_type is str:
|
||||||
|
check_tool_use.final_output = str(check_tool_use.final_output)
|
||||||
|
|
||||||
|
if check_tool_use.final_output is None:
|
||||||
|
logger.error(
|
||||||
|
"Model returned a final output of None. Not raising an error because we assume"
|
||||||
|
"you know what you're doing."
|
||||||
|
)
|
||||||
|
|
||||||
|
return await cls.execute_final_output(
|
||||||
|
agent=agent,
|
||||||
|
original_input=original_input,
|
||||||
|
new_response=new_response,
|
||||||
|
pre_step_items=pre_step_items,
|
||||||
|
new_step_items=new_step_items,
|
||||||
|
final_output=check_tool_use.final_output,
|
||||||
|
hooks=hooks,
|
||||||
|
context_wrapper=context_wrapper,
|
||||||
|
)
|
||||||
|
|
||||||
# Now we can check if the model also produced a final output
|
# Now we can check if the model also produced a final output
|
||||||
message_items = [item for item in new_step_items if isinstance(item, MessageOutputItem)]
|
message_items = [item for item in new_step_items if isinstance(item, MessageOutputItem)]
|
||||||
|
|
||||||
|
|
@ -355,10 +389,10 @@ class RunImpl:
|
||||||
hooks: RunHooks[TContext],
|
hooks: RunHooks[TContext],
|
||||||
context_wrapper: RunContextWrapper[TContext],
|
context_wrapper: RunContextWrapper[TContext],
|
||||||
config: RunConfig,
|
config: RunConfig,
|
||||||
) -> list[RunItem]:
|
) -> list[FunctionToolResult]:
|
||||||
async def run_single_tool(
|
async def run_single_tool(
|
||||||
func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
|
func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
|
||||||
) -> str:
|
) -> Any:
|
||||||
with function_span(func_tool.name) as span_fn:
|
with function_span(func_tool.name) as span_fn:
|
||||||
if config.trace_include_sensitive_data:
|
if config.trace_include_sensitive_data:
|
||||||
span_fn.span_data.input = tool_call.arguments
|
span_fn.span_data.input = tool_call.arguments
|
||||||
|
|
@ -404,10 +438,14 @@ class RunImpl:
|
||||||
results = await asyncio.gather(*tasks)
|
results = await asyncio.gather(*tasks)
|
||||||
|
|
||||||
return [
|
return [
|
||||||
ToolCallOutputItem(
|
FunctionToolResult(
|
||||||
output=str(result),
|
tool=tool_run.function_tool,
|
||||||
raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, str(result)),
|
output=result,
|
||||||
agent=agent,
|
run_item=ToolCallOutputItem(
|
||||||
|
output=result,
|
||||||
|
raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, str(result)),
|
||||||
|
agent=agent,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
for tool_run, result in zip(tool_runs, results)
|
for tool_run, result in zip(tool_runs, results)
|
||||||
]
|
]
|
||||||
|
|
@ -646,6 +684,47 @@ class RunImpl:
|
||||||
if event:
|
if event:
|
||||||
queue.put_nowait(event)
|
queue.put_nowait(event)
|
||||||
|
|
||||||
|
@classmethod
async def _check_for_final_output_from_tools(
    cls,
    *,
    agent: Agent[TContext],
    tool_results: list[FunctionToolResult],
    context_wrapper: RunContextWrapper[TContext],
    config: RunConfig,
) -> ToolsToFinalOutputResult:
    """Decide whether this turn's function tool results should end the run.

    The decision is driven by `agent.tool_use_behavior`:

    - "run_llm_again": never a final output.
    - "stop_on_first_tool": the output of the first tool result is the final output.
    - A dict with "stop_at_tool_names": the output of the first result whose tool name is
      listed is the final output; otherwise not final.
    - A callable: called with `(context_wrapper, tool_results)`; its return value is awaited
      if it is awaitable, and used as the verdict.

    Args:
        agent: The agent whose `tool_use_behavior` is consulted.
        tool_results: The function tool results produced in this turn.
        context_wrapper: The run context, forwarded to a custom behavior function.
        config: The run config (currently unused here).

    Returns:
        A `ToolsToFinalOutputResult`. With no tool results the verdict is always "not final".

    Raises:
        UserError: If `agent.tool_use_behavior` is none of the supported kinds.
    """
    if not tool_results:
        return _NOT_FINAL_OUTPUT

    behavior = agent.tool_use_behavior

    if behavior == "run_llm_again":
        return _NOT_FINAL_OUTPUT

    if behavior == "stop_on_first_tool":
        return ToolsToFinalOutputResult(
            is_final_output=True, final_output=tool_results[0].output
        )

    if isinstance(behavior, dict):
        names = behavior.get("stop_at_tool_names", [])
        for tool_result in tool_results:
            if tool_result.tool.name in names:
                return ToolsToFinalOutputResult(
                    is_final_output=True, final_output=tool_result.output
                )
        return ToolsToFinalOutputResult(is_final_output=False, final_output=None)

    if callable(behavior):
        # Call once and inspect the *result*: this also covers functools.partial wrappers,
        # objects with an async __call__, and sync callables that return an awaitable, all of
        # which `inspect.iscoroutinefunction` would misclassify as synchronous.
        outcome = behavior(context_wrapper, tool_results)
        if inspect.isawaitable(outcome):
            return await cast(Awaitable[ToolsToFinalOutputResult], outcome)
        return cast(ToolsToFinalOutputResult, outcome)

    logger.error(f"Invalid tool_use_behavior: {agent.tool_use_behavior}")
    raise UserError(f"Invalid tool_use_behavior: {agent.tool_use_behavior}")
|
||||||
|
|
||||||
|
|
||||||
class TraceCtxManager:
|
class TraceCtxManager:
|
||||||
"""Creates a trace only if there is no current trace, and manages the trace lifecycle."""
|
"""Creates a trace only if there is no current trace, and manages the trace lifecycle."""
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,9 @@ import dataclasses
|
||||||
import inspect
|
import inspect
|
||||||
from collections.abc import Awaitable
|
from collections.abc import Awaitable
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import TYPE_CHECKING, Any, Callable, Generic, cast
|
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast
|
||||||
|
|
||||||
|
from typing_extensions import TypeAlias, TypedDict
|
||||||
|
|
||||||
from .guardrail import InputGuardrail, OutputGuardrail
|
from .guardrail import InputGuardrail, OutputGuardrail
|
||||||
from .handoffs import Handoff
|
from .handoffs import Handoff
|
||||||
|
|
@ -13,7 +15,7 @@ from .logger import logger
|
||||||
from .model_settings import ModelSettings
|
from .model_settings import ModelSettings
|
||||||
from .models.interface import Model
|
from .models.interface import Model
|
||||||
from .run_context import RunContextWrapper, TContext
|
from .run_context import RunContextWrapper, TContext
|
||||||
from .tool import Tool, function_tool
|
from .tool import FunctionToolResult, Tool, function_tool
|
||||||
from .util import _transforms
|
from .util import _transforms
|
||||||
from .util._types import MaybeAwaitable
|
from .util._types import MaybeAwaitable
|
||||||
|
|
||||||
|
|
@ -22,6 +24,33 @@ if TYPE_CHECKING:
|
||||||
from .result import RunResult
|
from .result import RunResult
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ToolsToFinalOutputResult:
    """Verdict on whether a set of tool results should be treated as the agent's final output."""

    is_final_output: bool
    """True if the run should end with `final_output`. When False, the LLM runs again and is
    given the tool call output.
    """

    final_output: Any | None = None
    """The value to use as the final output. May be None when `is_final_output` is False; when
    it is True, the value must match the agent's `output_type`.
    """
|
||||||
|
|
||||||
|
|
||||||
|
ToolsToFinalOutputFunction: TypeAlias = Callable[
|
||||||
|
[RunContextWrapper[TContext], list[FunctionToolResult]],
|
||||||
|
MaybeAwaitable[ToolsToFinalOutputResult],
|
||||||
|
]
|
||||||
|
"""A function that takes a run context and a list of tool results, and returns a
|
||||||
|
`ToolsToFinalOutputResult`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class StopAtTools(TypedDict):
    """Dict form of `tool_use_behavior` that names the tools which end the run when called."""

    stop_at_tool_names: list[str]
    """Tool names; a call to any one of them stops the agent from running further."""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Agent(Generic[TContext]):
|
class Agent(Generic[TContext]):
|
||||||
"""An agent is an AI model configured with instructions, tools, guardrails, handoffs and more.
|
"""An agent is an AI model configured with instructions, tools, guardrails, handoffs and more.
|
||||||
|
|
@ -95,6 +124,25 @@ class Agent(Generic[TContext]):
|
||||||
"""A class that receives callbacks on various lifecycle events for this agent.
|
"""A class that receives callbacks on various lifecycle events for this agent.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
tool_use_behavior: (
|
||||||
|
Literal["run_llm_again", "stop_on_first_tool"] | StopAtTools | ToolsToFinalOutputFunction
|
||||||
|
) = "run_llm_again"
|
||||||
|
"""This lets you configure how tool use is handled.
|
||||||
|
- "run_llm_again": The default behavior. Tools are run, and then the LLM receives the results
|
||||||
|
and gets to respond.
|
||||||
|
- "stop_on_first_tool": The output of the first tool call is used as the final output. This
|
||||||
|
means that the LLM does not process the result of the tool call.
|
||||||
|
- A `StopAtTools` dict (`{"stop_at_tool_names": [...]}`): The agent will stop running if any
of the listed tools are called.
|
||||||
|
The final output will be the output of the first matching tool call. The LLM does not
|
||||||
|
process the result of the tool call.
|
||||||
|
- A function: If you pass a function, it will be called with the run context and the list of
|
||||||
|
tool results. It must return a `ToolsToFinalOutputResult`, which determines whether the tool
|
||||||
|
calls result in a final output.
|
||||||
|
|
||||||
|
NOTE: This configuration is specific to FunctionTools. Hosted tools, such as file search,
|
||||||
|
web search, etc are always processed by the LLM.
|
||||||
|
"""
|
||||||
|
|
||||||
def clone(self, **kwargs: Any) -> Agent[TContext]:
|
def clone(self, **kwargs: Any) -> Agent[TContext]:
|
||||||
"""Make a copy of the agent, with the given arguments changed. For example, you could do:
|
"""Make a copy of the agent, with the given arguments changed. For example, you could do:
|
||||||
```
|
```
|
||||||
|
|
|
||||||
|
|
@ -129,8 +129,10 @@ class ToolCallOutputItem(RunItemBase[Union[FunctionCallOutput, ComputerCallOutpu
|
||||||
raw_item: FunctionCallOutput | ComputerCallOutput
|
raw_item: FunctionCallOutput | ComputerCallOutput
|
||||||
"""The raw item from the model."""
|
"""The raw item from the model."""
|
||||||
|
|
||||||
output: str
|
output: Any
|
||||||
"""The output of the tool call."""
|
"""The output of the tool call. This is whatever the tool call returned; the `raw_item`
|
||||||
|
contains a string representation of the output.
|
||||||
|
"""
|
||||||
|
|
||||||
type: Literal["tool_call_output_item"] = "tool_call_output_item"
|
type: Literal["tool_call_output_item"] = "tool_call_output_item"
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,7 @@ from . import _debug
|
||||||
from .computer import AsyncComputer, Computer
|
from .computer import AsyncComputer, Computer
|
||||||
from .exceptions import ModelBehaviorError
|
from .exceptions import ModelBehaviorError
|
||||||
from .function_schema import DocstringStyle, function_schema
|
from .function_schema import DocstringStyle, function_schema
|
||||||
|
from .items import RunItem
|
||||||
from .logger import logger
|
from .logger import logger
|
||||||
from .run_context import RunContextWrapper
|
from .run_context import RunContextWrapper
|
||||||
from .tracing import SpanError
|
from .tracing import SpanError
|
||||||
|
|
@ -29,6 +30,18 @@ ToolFunctionWithContext = Callable[Concatenate[RunContextWrapper[Any], ToolParam
|
||||||
ToolFunction = Union[ToolFunctionWithoutContext[ToolParams], ToolFunctionWithContext[ToolParams]]
|
ToolFunction = Union[ToolFunctionWithoutContext[ToolParams], ToolFunctionWithContext[ToolParams]]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class FunctionToolResult:
    """Bundles a function tool with the output it produced and the resulting run item."""

    tool: FunctionTool
    """The function tool that was executed."""

    output: Any
    """Whatever the tool returned."""

    run_item: RunItem
    """The run item generated for this tool call."""
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class FunctionTool:
|
class FunctionTool:
|
||||||
"""A tool that wraps a function. In most cases, you should use the `function_tool` helpers to
|
"""A tool that wraps a function. In most cases, you should use the `function_tool` helpers to
|
||||||
|
|
@ -44,15 +57,15 @@ class FunctionTool:
|
||||||
params_json_schema: dict[str, Any]
|
params_json_schema: dict[str, Any]
|
||||||
"""The JSON schema for the tool's parameters."""
|
"""The JSON schema for the tool's parameters."""
|
||||||
|
|
||||||
on_invoke_tool: Callable[[RunContextWrapper[Any], str], Awaitable[str]]
|
on_invoke_tool: Callable[[RunContextWrapper[Any], str], Awaitable[Any]]
|
||||||
"""A function that invokes the tool with the given context and parameters. The params passed
|
"""A function that invokes the tool with the given context and parameters. The params passed
|
||||||
are:
|
are:
|
||||||
1. The tool run context.
|
1. The tool run context.
|
||||||
2. The arguments from the LLM, as a JSON string.
|
2. The arguments from the LLM, as a JSON string.
|
||||||
|
|
||||||
You must return a string representation of the tool output. In case of errors, you can either
|
You must return a string representation of the tool output, or something we can call `str()` on.
|
||||||
raise an Exception (which will cause the run to fail) or return a string error message (which
|
In case of errors, you can either raise an Exception (which will cause the run to fail) or
|
||||||
will be sent back to the LLM).
|
return a string error message (which will be sent back to the LLM).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
strict_json_schema: bool = True
|
strict_json_schema: bool = True
|
||||||
|
|
@ -207,7 +220,7 @@ def function_tool(
|
||||||
strict_json_schema=strict_mode,
|
strict_json_schema=strict_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> str:
|
async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> Any:
|
||||||
try:
|
try:
|
||||||
json_data: dict[str, Any] = json.loads(input) if input else {}
|
json_data: dict[str, Any] = json.loads(input) if input else {}
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -254,9 +267,9 @@ def function_tool(
|
||||||
else:
|
else:
|
||||||
logger.debug(f"Tool {schema.name} returned {result}")
|
logger.debug(f"Tool {schema.name} returned {result}")
|
||||||
|
|
||||||
return str(result)
|
return result
|
||||||
|
|
||||||
async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> str:
|
async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> Any:
|
||||||
try:
|
try:
|
||||||
return await _on_invoke_tool_impl(ctx, input)
|
return await _on_invoke_tool_impl(ctx, input)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
|
|
@ -51,7 +51,7 @@ class AgentSpanData(SpanData):
|
||||||
class FunctionSpanData(SpanData):
|
class FunctionSpanData(SpanData):
|
||||||
__slots__ = ("name", "input", "output")
|
__slots__ = ("name", "input", "output")
|
||||||
|
|
||||||
def __init__(self, name: str, input: str | None, output: str | None):
|
def __init__(self, name: str, input: str | None, output: Any | None):
|
||||||
self.name = name
|
self.name = name
|
||||||
self.input = input
|
self.input = input
|
||||||
self.output = output
|
self.output = output
|
||||||
|
|
@ -65,7 +65,7 @@ class FunctionSpanData(SpanData):
|
||||||
"type": self.type,
|
"type": self.type,
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
"input": self.input,
|
"input": self.input,
|
||||||
"output": self.output,
|
"output": str(self.output) if self.output else None,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,8 @@ from agents import (
|
||||||
UserError,
|
UserError,
|
||||||
handoff,
|
handoff,
|
||||||
)
|
)
|
||||||
|
from agents.agent import ToolsToFinalOutputResult
|
||||||
|
from agents.tool import FunctionToolResult, function_tool
|
||||||
|
|
||||||
from .fake_model import FakeModel
|
from .fake_model import FakeModel
|
||||||
from .test_responses import (
|
from .test_responses import (
|
||||||
|
|
@ -552,3 +554,83 @@ async def test_output_guardrail_tripwire_triggered_causes_exception():
|
||||||
|
|
||||||
with pytest.raises(OutputGuardrailTripwireTriggered):
|
with pytest.raises(OutputGuardrailTripwireTriggered):
|
||||||
await Runner.run(agent, input="user_message")
|
await Runner.run(agent, input="user_message")
|
||||||
|
|
||||||
|
|
||||||
|
# Tool fixture whose output is a `Foo` instance rather than a plain string.
@function_tool
def test_tool_one():
    tool_output = Foo(bar="tool_one_result")
    return tool_output
|
||||||
|
|
||||||
|
|
||||||
|
# Tool fixture whose output is a plain string.
@function_tool
def test_tool_two():
    tool_output = "tool_two_result"
    return tool_output
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_tool_use_behavior_first_output():
    """With "stop_on_first_tool", the first tool call's output becomes the final output."""
    fake_model = FakeModel()
    available_tools = [get_function_tool("foo", "tool_result"), test_tool_one, test_tool_two]
    agent = Agent(
        name="test",
        model=fake_model,
        tools=available_tools,
        tool_use_behavior="stop_on_first_tool",
        output_type=Foo,
    )

    # Single turn: a message plus two tool calls, with `test_tool_one` called first.
    first_turn = [
        get_text_message("a_message"),
        get_function_tool_call("test_tool_one", None),
        get_function_tool_call("test_tool_two", None),
    ]
    fake_model.add_multiple_turn_outputs([first_turn])

    result = await Runner.run(agent, input="user_message")

    expected_output = Foo(bar="tool_one_result")
    assert result.final_output == expected_output, "should have used the first tool result"
|
||||||
|
|
||||||
|
|
||||||
|
def custom_tool_use_behavior(
    context: RunContextWrapper[Any], results: list[FunctionToolResult]
) -> ToolsToFinalOutputResult:
    """Finish the run with a fixed output only when `test_tool_one` is among the tools run."""
    tool_one_was_called = any(result.tool.name == "test_tool_one" for result in results)
    if not tool_one_was_called:
        return ToolsToFinalOutputResult(is_final_output=False, final_output=None)
    return ToolsToFinalOutputResult(is_final_output=True, final_output="the_final_output")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_tool_use_behavior_custom_function():
    """A custom behavior function keeps the run going until it reports a final output."""
    fake_model = FakeModel()
    available_tools = [get_function_tool("foo", "tool_result"), test_tool_one, test_tool_two]
    agent = Agent(
        name="test",
        model=fake_model,
        tools=available_tools,
        tool_use_behavior=custom_tool_use_behavior,
    )

    # First turn: only `test_tool_two` is called, so the custom function says "not final".
    first_turn = [
        get_text_message("a_message"),
        get_function_tool_call("test_tool_two", None),
    ]
    # Second turn: `test_tool_one` is called, so the custom function ends the run.
    second_turn = [
        get_text_message("a_message"),
        get_function_tool_call("test_tool_one", None),
        get_function_tool_call("test_tool_two", None),
    ]
    fake_model.add_multiple_turn_outputs([first_turn, second_turn])

    result = await Runner.run(agent, input="user_message")

    response_count = len(result.raw_responses)
    assert response_count == 2, "should have two model responses"
    assert result.final_output == "the_final_output", "should have used the custom function"
|
||||||
|
|
|
||||||
|
|
@ -49,10 +49,10 @@ async def test_simple_function():
|
||||||
assert tool.name == "simple_function"
|
assert tool.name == "simple_function"
|
||||||
|
|
||||||
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1}')
|
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1}')
|
||||||
assert result == "6"
|
assert result == 6
|
||||||
|
|
||||||
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1, "b": 2}')
|
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1, "b": 2}')
|
||||||
assert result == "3"
|
assert result == 3
|
||||||
|
|
||||||
# Missing required argument should raise an error
|
# Missing required argument should raise an error
|
||||||
with pytest.raises(ModelBehaviorError):
|
with pytest.raises(ModelBehaviorError):
|
||||||
|
|
|
||||||
194
tests/test_tool_use_behavior.py
Normal file
194
tests/test_tool_use_behavior.py
Normal file
|
|
@ -0,0 +1,194 @@
|
||||||
|
# Copyright
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from openai.types.responses.response_input_item_param import FunctionCallOutput
|
||||||
|
|
||||||
|
from agents import (
|
||||||
|
Agent,
|
||||||
|
FunctionToolResult,
|
||||||
|
RunConfig,
|
||||||
|
RunContextWrapper,
|
||||||
|
ToolCallOutputItem,
|
||||||
|
ToolsToFinalOutputResult,
|
||||||
|
UserError,
|
||||||
|
)
|
||||||
|
from agents._run_impl import RunImpl
|
||||||
|
|
||||||
|
from .test_responses import get_function_tool
|
||||||
|
|
||||||
|
|
||||||
|
def _make_function_tool_result(
    agent: Agent, output: str, tool_name: str | None = None
) -> FunctionToolResult:
    """Build a FunctionToolResult that carries ``output``, for use by the tests below.

    Args:
        agent: Agent recorded on the generated ToolCallOutputItem.
        output: Used as the function tool's return value, as the raw item's
            ``output`` field, and as the result's output.
        tool_name: Name for the function tool; "dummy" is used when this is falsy.

    Returns:
        A FunctionToolResult wrapping the tool, the output and the run item.
    """
    name = tool_name or "dummy"
    function_tool = get_function_tool(name, return_value=output)

    # Only the output field matters to these tests, so a minimal raw payload is enough.
    payload = {"call_id": "1", "output": output, "type": "function_call_output"}
    output_item = ToolCallOutputItem(
        agent=agent,
        raw_item=cast(FunctionCallOutput, payload),
        output=output,
    )
    return FunctionToolResult(tool=function_tool, output=output, run_item=output_item)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_no_tool_results_returns_not_final_output() -> None:
    """An empty list of tool results must not yield a final output."""
    outcome = await RunImpl._check_for_final_output_from_tools(
        agent=Agent(name="test"),
        tool_results=[],
        context_wrapper=RunContextWrapper(context=None),
        config=RunConfig(),
    )

    assert outcome.is_final_output is False
    assert outcome.final_output is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_run_llm_again_behavior() -> None:
    """With "run_llm_again", a tool result is not surfaced as the final output."""
    agent = Agent(name="test", tool_use_behavior="run_llm_again")
    single_result = _make_function_tool_result(agent, "ignored")

    outcome = await RunImpl._check_for_final_output_from_tools(
        agent=agent,
        tool_results=[single_result],
        context_wrapper=RunContextWrapper(context=None),
        config=RunConfig(),
    )

    assert outcome.is_final_output is False
    assert outcome.final_output is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_stop_on_first_tool_behavior() -> None:
    """With "stop_on_first_tool", the first tool's output becomes the final output."""
    agent = Agent(name="test", tool_use_behavior="stop_on_first_tool")
    first = _make_function_tool_result(agent, "first_tool_output")
    second = _make_function_tool_result(agent, "ignored")

    outcome = await RunImpl._check_for_final_output_from_tools(
        agent=agent,
        tool_results=[first, second],
        context_wrapper=RunContextWrapper(context=None),
        config=RunConfig(),
    )

    assert outcome.is_final_output is True
    assert outcome.final_output == "first_tool_output"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_custom_tool_use_behavior_sync() -> None:
    """A synchronous tool_use_behavior callable is invoked and its result is returned."""

    def behavior(
        context: RunContextWrapper, results: list[FunctionToolResult]
    ) -> ToolsToFinalOutputResult:
        # The callable should receive every tool result that was passed in.
        assert len(results) == 3
        return ToolsToFinalOutputResult(is_final_output=True, final_output="custom")

    agent = Agent(name="test", tool_use_behavior=behavior)
    three_results = [
        _make_function_tool_result(agent, "ignored1"),
        _make_function_tool_result(agent, "ignored2"),
        _make_function_tool_result(agent, "ignored3"),
    ]

    outcome = await RunImpl._check_for_final_output_from_tools(
        agent=agent,
        tool_results=three_results,
        context_wrapper=RunContextWrapper(context=None),
        config=RunConfig(),
    )

    assert outcome.is_final_output is True
    assert outcome.final_output == "custom"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_custom_tool_use_behavior_async() -> None:
    """An async tool_use_behavior callable is awaited and its result is returned."""

    async def behavior(
        context: RunContextWrapper, results: list[FunctionToolResult]
    ) -> ToolsToFinalOutputResult:
        # The coroutine should receive every tool result that was passed in.
        assert len(results) == 3
        return ToolsToFinalOutputResult(is_final_output=True, final_output="async_custom")

    agent = Agent(name="test", tool_use_behavior=behavior)
    three_results = [
        _make_function_tool_result(agent, "ignored1"),
        _make_function_tool_result(agent, "ignored2"),
        _make_function_tool_result(agent, "ignored3"),
    ]

    outcome = await RunImpl._check_for_final_output_from_tools(
        agent=agent,
        tool_results=three_results,
        context_wrapper=RunContextWrapper(context=None),
        config=RunConfig(),
    )

    assert outcome.is_final_output is True
    assert outcome.final_output == "async_custom"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_invalid_tool_use_behavior_raises() -> None:
    """An unrecognised tool_use_behavior value is expected to raise UserError."""
    agent = Agent(name="test")
    # Deliberately assign an unsupported value after construction; the type
    # checker would reject it, hence the ignore.
    agent.tool_use_behavior = "bad_value"  # type: ignore[assignment]

    # Built outside the raises block so only the check itself is under test.
    results = [_make_function_tool_result(agent, "ignored")]
    wrapper = RunContextWrapper(context=None)
    run_config = RunConfig()

    with pytest.raises(UserError):
        await RunImpl._check_for_final_output_from_tools(
            agent=agent,
            tool_results=results,
            context_wrapper=wrapper,
            config=run_config,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_tool_names_to_stop_at_behavior() -> None:
    """Check the ``stop_at_tool_names`` form of tool_use_behavior.

    The run should only be treated as finished when one of the tool results
    comes from a tool named in the list; that tool's output is then the final output.
    """
    available_tools = [
        get_function_tool("tool1", return_value="tool1_output"),
        get_function_tool("tool2", return_value="tool2_output"),
        get_function_tool("tool3", return_value="tool3_output"),
    ]
    agent = Agent(
        name="test",
        tools=available_tools,
        tool_use_behavior={"stop_at_tool_names": ["tool1"]},
    )

    async def evaluate(results: list[FunctionToolResult]) -> ToolsToFinalOutputResult:
        # Same call for both scenarios; a fresh wrapper and config are built each time.
        return await RunImpl._check_for_final_output_from_tools(
            agent=agent,
            tool_results=results,
            context_wrapper=RunContextWrapper(context=None),
            config=RunConfig(),
        )

    # Scenario 1: none of the results come from a tool in the stop list.
    without_stop_tool = [
        _make_function_tool_result(agent, "ignored1", "tool2"),
        _make_function_tool_result(agent, "ignored3", "tool3"),
    ]
    result = await evaluate(without_stop_tool)
    assert result.is_final_output is False, "We should not have stopped at tool1"

    # Scenario 2: one of the results comes from tool1, which is in the stop list.
    with_stop_tool = [
        _make_function_tool_result(agent, "output1", "tool1"),
        _make_function_tool_result(agent, "ignored2", "tool2"),
        _make_function_tool_result(agent, "ignored3", "tool3"),
    ]
    result = await evaluate(with_stop_tool)
    assert result.is_final_output is True, "We should have stopped at tool1"
    assert result.final_output == "output1"
|
||||||
Loading…
Reference in a new issue