### Summary: Allows a user to do `function_tool(is_enabled=<some_callable>)`; the callable is called when the agent runs. This allows you to dynamically enable/disable a tool based on the context/env. The meta-goal is to allow `Agent` to be effectively immutable. That enables some nice things down the line, and this allows you to dynamically modify the tools list without mutating the agent. ### Test Plan: Unit tests
435 lines
15 KiB
Python
435 lines
15 KiB
Python
from __future__ import annotations
|
|
|
|
import inspect
|
|
import json
|
|
from collections.abc import Awaitable
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, Any, Callable, Literal, Union, overload
|
|
|
|
from openai.types.responses.file_search_tool_param import Filters, RankingOptions
|
|
from openai.types.responses.response_output_item import LocalShellCall, McpApprovalRequest
|
|
from openai.types.responses.tool_param import CodeInterpreter, ImageGeneration, Mcp
|
|
from openai.types.responses.web_search_tool_param import UserLocation
|
|
from pydantic import ValidationError
|
|
from typing_extensions import Concatenate, NotRequired, ParamSpec, TypedDict
|
|
|
|
from . import _debug
|
|
from .computer import AsyncComputer, Computer
|
|
from .exceptions import ModelBehaviorError
|
|
from .function_schema import DocstringStyle, function_schema
|
|
from .items import RunItem
|
|
from .logger import logger
|
|
from .run_context import RunContextWrapper
|
|
from .tracing import SpanError
|
|
from .util import _error_tracing
|
|
from .util._types import MaybeAwaitable
|
|
|
|
if TYPE_CHECKING:
|
|
from .agent import Agent
|
|
|
|
ToolParams = ParamSpec("ToolParams")
|
|
|
|
ToolFunctionWithoutContext = Callable[ToolParams, Any]
|
|
ToolFunctionWithContext = Callable[Concatenate[RunContextWrapper[Any], ToolParams], Any]
|
|
|
|
ToolFunction = Union[ToolFunctionWithoutContext[ToolParams], ToolFunctionWithContext[ToolParams]]
|
|
|
|
|
|
@dataclass
|
|
class FunctionToolResult:
|
|
tool: FunctionTool
|
|
"""The tool that was run."""
|
|
|
|
output: Any
|
|
"""The output of the tool."""
|
|
|
|
run_item: RunItem
|
|
"""The run item that was produced as a result of the tool call."""
|
|
|
|
|
|
@dataclass
|
|
class FunctionTool:
|
|
"""A tool that wraps a function. In most cases, you should use the `function_tool` helpers to
|
|
create a FunctionTool, as they let you easily wrap a Python function.
|
|
"""
|
|
|
|
name: str
|
|
"""The name of the tool, as shown to the LLM. Generally the name of the function."""
|
|
|
|
description: str
|
|
"""A description of the tool, as shown to the LLM."""
|
|
|
|
params_json_schema: dict[str, Any]
|
|
"""The JSON schema for the tool's parameters."""
|
|
|
|
on_invoke_tool: Callable[[RunContextWrapper[Any], str], Awaitable[Any]]
|
|
"""A function that invokes the tool with the given context and parameters. The params passed
|
|
are:
|
|
1. The tool run context.
|
|
2. The arguments from the LLM, as a JSON string.
|
|
|
|
You must return a string representation of the tool output, or something we can call `str()` on.
|
|
In case of errors, you can either raise an Exception (which will cause the run to fail) or
|
|
return a string error message (which will be sent back to the LLM).
|
|
"""
|
|
|
|
strict_json_schema: bool = True
|
|
"""Whether the JSON schema is in strict mode. We **strongly** recommend setting this to True,
|
|
as it increases the likelihood of correct JSON input."""
|
|
|
|
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True
|
|
"""Whether the tool is enabled. Either a bool or a Callable that takes the run context and agent
|
|
and returns whether the tool is enabled. You can use this to dynamically enable/disable a tool
|
|
based on your context/state."""
|
|
|
|
|
|
@dataclass
|
|
class FileSearchTool:
|
|
"""A hosted tool that lets the LLM search through a vector store. Currently only supported with
|
|
OpenAI models, using the Responses API.
|
|
"""
|
|
|
|
vector_store_ids: list[str]
|
|
"""The IDs of the vector stores to search."""
|
|
|
|
max_num_results: int | None = None
|
|
"""The maximum number of results to return."""
|
|
|
|
include_search_results: bool = False
|
|
"""Whether to include the search results in the output produced by the LLM."""
|
|
|
|
ranking_options: RankingOptions | None = None
|
|
"""Ranking options for search."""
|
|
|
|
filters: Filters | None = None
|
|
"""A filter to apply based on file attributes."""
|
|
|
|
@property
|
|
def name(self):
|
|
return "file_search"
|
|
|
|
|
|
@dataclass
|
|
class WebSearchTool:
|
|
"""A hosted tool that lets the LLM search the web. Currently only supported with OpenAI models,
|
|
using the Responses API.
|
|
"""
|
|
|
|
user_location: UserLocation | None = None
|
|
"""Optional location for the search. Lets you customize results to be relevant to a location."""
|
|
|
|
search_context_size: Literal["low", "medium", "high"] = "medium"
|
|
"""The amount of context to use for the search."""
|
|
|
|
@property
|
|
def name(self):
|
|
return "web_search_preview"
|
|
|
|
|
|
@dataclass
|
|
class ComputerTool:
|
|
"""A hosted tool that lets the LLM control a computer."""
|
|
|
|
computer: Computer | AsyncComputer
|
|
"""The computer implementation, which describes the environment and dimensions of the computer,
|
|
as well as implements the computer actions like click, screenshot, etc.
|
|
"""
|
|
|
|
@property
|
|
def name(self):
|
|
return "computer_use_preview"
|
|
|
|
|
|
@dataclass
|
|
class MCPToolApprovalRequest:
|
|
"""A request to approve a tool call."""
|
|
|
|
ctx_wrapper: RunContextWrapper[Any]
|
|
"""The run context."""
|
|
|
|
data: McpApprovalRequest
|
|
"""The data from the MCP tool approval request."""
|
|
|
|
|
|
class MCPToolApprovalFunctionResult(TypedDict):
|
|
"""The result of an MCP tool approval function."""
|
|
|
|
approve: bool
|
|
"""Whether to approve the tool call."""
|
|
|
|
reason: NotRequired[str]
|
|
"""An optional reason, if rejected."""
|
|
|
|
|
|
MCPToolApprovalFunction = Callable[
|
|
[MCPToolApprovalRequest], MaybeAwaitable[MCPToolApprovalFunctionResult]
|
|
]
|
|
"""A function that approves or rejects a tool call."""
|
|
|
|
|
|
@dataclass
|
|
class HostedMCPTool:
|
|
"""A tool that allows the LLM to use a remote MCP server. The LLM will automatically list and
|
|
call tools, without requiring a a round trip back to your code.
|
|
If you want to run MCP servers locally via stdio, in a VPC or other non-publicly-accessible
|
|
environment, or you just prefer to run tool calls locally, then you can instead use the servers
|
|
in `agents.mcp` and pass `Agent(mcp_servers=[...])` to the agent."""
|
|
|
|
tool_config: Mcp
|
|
"""The MCP tool config, which includes the server URL and other settings."""
|
|
|
|
on_approval_request: MCPToolApprovalFunction | None = None
|
|
"""An optional function that will be called if approval is requested for an MCP tool. If not
|
|
provided, you will need to manually add approvals/rejections to the input and call
|
|
`Runner.run(...)` again."""
|
|
|
|
@property
|
|
def name(self):
|
|
return "hosted_mcp"
|
|
|
|
|
|
@dataclass
|
|
class CodeInterpreterTool:
|
|
"""A tool that allows the LLM to execute code in a sandboxed environment."""
|
|
|
|
tool_config: CodeInterpreter
|
|
"""The tool config, which includes the container and other settings."""
|
|
|
|
@property
|
|
def name(self):
|
|
return "code_interpreter"
|
|
|
|
|
|
@dataclass
|
|
class ImageGenerationTool:
|
|
"""A tool that allows the LLM to generate images."""
|
|
|
|
tool_config: ImageGeneration
|
|
"""The tool config, which image generation settings."""
|
|
|
|
@property
|
|
def name(self):
|
|
return "image_generation"
|
|
|
|
|
|
@dataclass
|
|
class LocalShellCommandRequest:
|
|
"""A request to execute a command on a shell."""
|
|
|
|
ctx_wrapper: RunContextWrapper[Any]
|
|
"""The run context."""
|
|
|
|
data: LocalShellCall
|
|
"""The data from the local shell tool call."""
|
|
|
|
|
|
LocalShellExecutor = Callable[[LocalShellCommandRequest], MaybeAwaitable[str]]
|
|
"""A function that executes a command on a shell."""
|
|
|
|
|
|
@dataclass
|
|
class LocalShellTool:
|
|
"""A tool that allows the LLM to execute commands on a shell."""
|
|
|
|
executor: LocalShellExecutor
|
|
"""A function that executes a command on a shell."""
|
|
|
|
@property
|
|
def name(self):
|
|
return "local_shell"
|
|
|
|
|
|
Tool = Union[
|
|
FunctionTool,
|
|
FileSearchTool,
|
|
WebSearchTool,
|
|
ComputerTool,
|
|
HostedMCPTool,
|
|
LocalShellTool,
|
|
ImageGenerationTool,
|
|
CodeInterpreterTool,
|
|
]
|
|
"""A tool that can be used in an agent."""
|
|
|
|
|
|
def default_tool_error_function(ctx: RunContextWrapper[Any], error: Exception) -> str:
|
|
"""The default tool error function, which just returns a generic error message."""
|
|
return f"An error occurred while running the tool. Please try again. Error: {str(error)}"
|
|
|
|
|
|
ToolErrorFunction = Callable[[RunContextWrapper[Any], Exception], MaybeAwaitable[str]]
|
|
|
|
|
|
@overload
|
|
def function_tool(
|
|
func: ToolFunction[...],
|
|
*,
|
|
name_override: str | None = None,
|
|
description_override: str | None = None,
|
|
docstring_style: DocstringStyle | None = None,
|
|
use_docstring_info: bool = True,
|
|
failure_error_function: ToolErrorFunction | None = None,
|
|
strict_mode: bool = True,
|
|
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
|
|
) -> FunctionTool:
|
|
"""Overload for usage as @function_tool (no parentheses)."""
|
|
...
|
|
|
|
|
|
@overload
|
|
def function_tool(
|
|
*,
|
|
name_override: str | None = None,
|
|
description_override: str | None = None,
|
|
docstring_style: DocstringStyle | None = None,
|
|
use_docstring_info: bool = True,
|
|
failure_error_function: ToolErrorFunction | None = None,
|
|
strict_mode: bool = True,
|
|
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
|
|
) -> Callable[[ToolFunction[...]], FunctionTool]:
|
|
"""Overload for usage as @function_tool(...)."""
|
|
...
|
|
|
|
|
|
def function_tool(
|
|
func: ToolFunction[...] | None = None,
|
|
*,
|
|
name_override: str | None = None,
|
|
description_override: str | None = None,
|
|
docstring_style: DocstringStyle | None = None,
|
|
use_docstring_info: bool = True,
|
|
failure_error_function: ToolErrorFunction | None = default_tool_error_function,
|
|
strict_mode: bool = True,
|
|
is_enabled: bool | Callable[[RunContextWrapper[Any], Agent[Any]], MaybeAwaitable[bool]] = True,
|
|
) -> FunctionTool | Callable[[ToolFunction[...]], FunctionTool]:
|
|
"""
|
|
Decorator to create a FunctionTool from a function. By default, we will:
|
|
1. Parse the function signature to create a JSON schema for the tool's parameters.
|
|
2. Use the function's docstring to populate the tool's description.
|
|
3. Use the function's docstring to populate argument descriptions.
|
|
The docstring style is detected automatically, but you can override it.
|
|
|
|
If the function takes a `RunContextWrapper` as the first argument, it *must* match the
|
|
context type of the agent that uses the tool.
|
|
|
|
Args:
|
|
func: The function to wrap.
|
|
name_override: If provided, use this name for the tool instead of the function's name.
|
|
description_override: If provided, use this description for the tool instead of the
|
|
function's docstring.
|
|
docstring_style: If provided, use this style for the tool's docstring. If not provided,
|
|
we will attempt to auto-detect the style.
|
|
use_docstring_info: If True, use the function's docstring to populate the tool's
|
|
description and argument descriptions.
|
|
failure_error_function: If provided, use this function to generate an error message when
|
|
the tool call fails. The error message is sent to the LLM. If you pass None, then no
|
|
error message will be sent and instead an Exception will be raised.
|
|
strict_mode: Whether to enable strict mode for the tool's JSON schema. We *strongly*
|
|
recommend setting this to True, as it increases the likelihood of correct JSON input.
|
|
If False, it allows non-strict JSON schemas. For example, if a parameter has a default
|
|
value, it will be optional, additional properties are allowed, etc. See here for more:
|
|
https://platform.openai.com/docs/guides/structured-outputs?api-mode=responses#supported-schemas
|
|
is_enabled: Whether the tool is enabled. Can be a bool or a callable that takes the run
|
|
context and agent and returns whether the tool is enabled. Disabled tools are hidden
|
|
from the LLM at runtime.
|
|
"""
|
|
|
|
def _create_function_tool(the_func: ToolFunction[...]) -> FunctionTool:
|
|
schema = function_schema(
|
|
func=the_func,
|
|
name_override=name_override,
|
|
description_override=description_override,
|
|
docstring_style=docstring_style,
|
|
use_docstring_info=use_docstring_info,
|
|
strict_json_schema=strict_mode,
|
|
)
|
|
|
|
async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> Any:
|
|
try:
|
|
json_data: dict[str, Any] = json.loads(input) if input else {}
|
|
except Exception as e:
|
|
if _debug.DONT_LOG_TOOL_DATA:
|
|
logger.debug(f"Invalid JSON input for tool {schema.name}")
|
|
else:
|
|
logger.debug(f"Invalid JSON input for tool {schema.name}: {input}")
|
|
raise ModelBehaviorError(
|
|
f"Invalid JSON input for tool {schema.name}: {input}"
|
|
) from e
|
|
|
|
if _debug.DONT_LOG_TOOL_DATA:
|
|
logger.debug(f"Invoking tool {schema.name}")
|
|
else:
|
|
logger.debug(f"Invoking tool {schema.name} with input {input}")
|
|
|
|
try:
|
|
parsed = (
|
|
schema.params_pydantic_model(**json_data)
|
|
if json_data
|
|
else schema.params_pydantic_model()
|
|
)
|
|
except ValidationError as e:
|
|
raise ModelBehaviorError(f"Invalid JSON input for tool {schema.name}: {e}") from e
|
|
|
|
args, kwargs_dict = schema.to_call_args(parsed)
|
|
|
|
if not _debug.DONT_LOG_TOOL_DATA:
|
|
logger.debug(f"Tool call args: {args}, kwargs: {kwargs_dict}")
|
|
|
|
if inspect.iscoroutinefunction(the_func):
|
|
if schema.takes_context:
|
|
result = await the_func(ctx, *args, **kwargs_dict)
|
|
else:
|
|
result = await the_func(*args, **kwargs_dict)
|
|
else:
|
|
if schema.takes_context:
|
|
result = the_func(ctx, *args, **kwargs_dict)
|
|
else:
|
|
result = the_func(*args, **kwargs_dict)
|
|
|
|
if _debug.DONT_LOG_TOOL_DATA:
|
|
logger.debug(f"Tool {schema.name} completed.")
|
|
else:
|
|
logger.debug(f"Tool {schema.name} returned {result}")
|
|
|
|
return result
|
|
|
|
async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> Any:
|
|
try:
|
|
return await _on_invoke_tool_impl(ctx, input)
|
|
except Exception as e:
|
|
if failure_error_function is None:
|
|
raise
|
|
|
|
result = failure_error_function(ctx, e)
|
|
if inspect.isawaitable(result):
|
|
return await result
|
|
|
|
_error_tracing.attach_error_to_current_span(
|
|
SpanError(
|
|
message="Error running tool (non-fatal)",
|
|
data={
|
|
"tool_name": schema.name,
|
|
"error": str(e),
|
|
},
|
|
)
|
|
)
|
|
return result
|
|
|
|
return FunctionTool(
|
|
name=schema.name,
|
|
description=schema.description or "",
|
|
params_json_schema=schema.params_json_schema,
|
|
on_invoke_tool=_on_invoke_tool,
|
|
strict_json_schema=strict_mode,
|
|
is_enabled=is_enabled,
|
|
)
|
|
|
|
# If func is actually a callable, we were used as @function_tool with no parentheses
|
|
if callable(func):
|
|
return _create_function_tool(func)
|
|
|
|
# Otherwise, we were used as @function_tool(...), so return a decorator
|
|
def decorator(real_func: ToolFunction[...]) -> FunctionTool:
|
|
return _create_function_tool(real_func)
|
|
|
|
return decorator
|