236 lines
8.8 KiB
Python
236 lines
8.8 KiB
Python
from __future__ import annotations
|
|
|
|
import inspect
|
|
from collections.abc import Awaitable
|
|
from dataclasses import dataclass
|
|
from typing import TYPE_CHECKING, Any, Callable, Generic, cast, overload
|
|
|
|
from pydantic import TypeAdapter
|
|
from typing_extensions import TypeAlias, TypeVar
|
|
|
|
from .exceptions import ModelBehaviorError, UserError
|
|
from .items import RunItem, TResponseInputItem
|
|
from .run_context import RunContextWrapper, TContext
|
|
from .strict_schema import ensure_strict_json_schema
|
|
from .tracing.spans import SpanError
|
|
from .util import _error_tracing, _json, _transforms
|
|
|
|
if TYPE_CHECKING:
|
|
from .agent import Agent
|
|
|
|
|
|
# Type of the data passed along when an agent is called via a handoff. Falls back to `Any` when
# the type variable is left unparameterized.
THandoffInput = TypeVar("THandoffInput", default=Any)

# `on_handoff` callback for handoffs that take input: called with the run context wrapper and the
# handoff input.
OnHandoffWithInput: TypeAlias = Callable[[RunContextWrapper[Any], THandoffInput], Any]

# `on_handoff` callback for handoffs without input: called with the run context wrapper only.
OnHandoffWithoutInput: TypeAlias = Callable[[RunContextWrapper[Any]], Any]
|
|
|
|
|
|
@dataclass(frozen=True)
class HandoffInputData:
    """Immutable bundle of conversation data that a `HandoffInputFilter` receives and returns."""

    input_history: str | tuple[TResponseInputItem, ...]
    """The input history as it was before `Runner.run()` was called."""

    pre_handoff_items: tuple[RunItem, ...]
    """Items produced before the agent turn in which the handoff was invoked."""

    new_items: tuple[RunItem, ...]
    """Items produced during the current agent turn.

    This includes the item that triggered the handoff and the tool output message representing
    the response from the handoff output.
    """
|
|
|
|
|
|
HandoffInputFilter: TypeAlias = Callable[[HandoffInputData], HandoffInputData]
"""Callable that takes the `HandoffInputData` of a handoff and returns the (possibly modified)
data that is passed to the next agent."""
|
|
|
|
|
|
@dataclass
class Handoff(Generic[TContext]):
    """Describes one agent delegating a task to another agent.

    A typical case is customer support: a "triage agent" decides which agent should take the
    user's request, and specialised sub-agents cover areas such as billing or account
    management.
    """

    tool_name: str
    """Name of the tool that represents the handoff."""

    tool_description: str
    """Description of the tool that represents the handoff."""

    input_json_schema: dict[str, Any]
    """JSON schema for the handoff input. May be empty when the handoff takes no input."""

    on_invoke_handoff: Callable[[RunContextWrapper[Any], str], Awaitable[Agent[TContext]]]
    """Function that invokes the handoff. It is called with:

    1. The handoff run context.
    2. The arguments from the LLM as a JSON string (empty string if `input_json_schema` is
       empty).

    It must return an agent.
    """

    agent_name: str
    """Name of the agent that is being handed off to."""

    input_filter: HandoffInputFilter | None = None
    """Optional function that filters the inputs passed to the next agent.

    By default the new agent sees the entire conversation history. A filter can be used to, for
    example, drop older inputs or remove tools from existing inputs.

    The function receives the entire conversation history so far, including the input item that
    triggered the handoff and a tool call output item representing the handoff tool's output.
    It may modify the input history or the new items freely; the next agent that runs receives
    the items of the data it returns.

    IMPORTANT: in streaming mode, nothing is streamed as a result of this function. The items
    generated before it runs will already have been streamed.
    """

    strict_json_schema: bool = True
    """Whether the input JSON schema is in strict mode.

    Setting this to True is **strongly** recommended, as it increases the likelihood of correct
    JSON input.
    """

    def get_transfer_message(self, agent: Agent[Any]) -> str:
        """Return the transfer message naming `agent` as the assistant."""
        return f"{{'assistant': '{agent.name}'}}"

    @classmethod
    def default_tool_name(cls, agent: Agent[Any]) -> str:
        """Return the default handoff tool name derived from the agent's name."""
        raw_name = f"transfer_to_{agent.name}"
        return _transforms.transform_string_function_style(raw_name)

    @classmethod
    def default_tool_description(cls, agent: Agent[Any]) -> str:
        """Return the default handoff tool description for `agent`.

        The agent's `handoff_description` is appended when it is set.
        """
        intro = f"Handoff to the {agent.name} agent to handle the request. "
        extra = agent.handoff_description or ""
        return f"{intro}{extra}"
|
|
|
|
|
|
# Overload: plain handoff, with neither an `on_handoff` callback nor an `input_type`.
@overload
def handoff(
    agent: Agent[TContext],
    *,
    tool_name_override: str | None = None,
    tool_description_override: str | None = None,
    input_filter: HandoffInputFilter | None = None,
) -> Handoff[TContext]: ...
|
|
|
|
|
|
# Overload: handoff with input; `on_handoff` takes the context and a value of `input_type`.
@overload
def handoff(
    agent: Agent[TContext],
    *,
    on_handoff: OnHandoffWithInput[THandoffInput],
    input_type: type[THandoffInput],
    tool_description_override: str | None = None,
    tool_name_override: str | None = None,
    input_filter: HandoffInputFilter | None = None,
) -> Handoff[TContext]: ...
|
|
|
|
|
|
# Overload: handoff without input; `on_handoff` takes the context only.
@overload
def handoff(
    agent: Agent[TContext],
    *,
    on_handoff: OnHandoffWithoutInput,
    tool_description_override: str | None = None,
    tool_name_override: str | None = None,
    input_filter: HandoffInputFilter | None = None,
) -> Handoff[TContext]: ...
|
|
|
|
|
|
def handoff(
    agent: Agent[TContext],
    tool_name_override: str | None = None,
    tool_description_override: str | None = None,
    on_handoff: OnHandoffWithInput[THandoffInput] | OnHandoffWithoutInput | None = None,
    input_type: type[THandoffInput] | None = None,
    input_filter: Callable[[HandoffInputData], HandoffInputData] | None = None,
) -> Handoff[TContext]:
    """Create a handoff from an agent.

    Args:
        agent: The agent to hand off to.
        tool_name_override: Optional override for the name of the tool that represents the handoff.
        tool_description_override: Optional override for the description of the tool that
            represents the handoff.
        on_handoff: A function that runs when the handoff is invoked. It takes the run context
            and, when `input_type` is given, the validated input as a second argument. It may
            be a regular or a coroutine function.
        input_type: The type of the input to the handoff. If provided, the input will be validated
            against this type, and `on_handoff` is required.
        input_filter: A function that filters the inputs that are passed to the next agent.

    Returns:
        A `Handoff` whose `on_invoke_handoff` validates the input (when `input_type` is set),
        runs `on_handoff` (when given) and then returns `agent`.

    Raises:
        UserError: If `input_type` is given without a callable `on_handoff`, or if `on_handoff`
            does not take the expected number of parameters.
    """
    type_adapter: TypeAdapter[Any] | None
    if input_type is not None:
        # `input_type` only makes sense together with a callback that consumes the parsed value.
        # (`on_handoff` on its own is valid: it is then called with the context only.)
        # These are explicit raises rather than asserts so they still run under `python -O`.
        if on_handoff is None:
            raise UserError("on_handoff must be provided when input_type is set")
        if not callable(on_handoff):
            raise UserError("on_handoff must be callable")
        sig = inspect.signature(on_handoff)
        if len(sig.parameters) != 2:
            raise UserError("on_handoff must take two arguments: context and input")

        type_adapter = TypeAdapter(input_type)
        input_json_schema = type_adapter.json_schema()
    else:
        type_adapter = None
        input_json_schema = {}
        if on_handoff is not None:
            sig = inspect.signature(on_handoff)
            if len(sig.parameters) != 1:
                raise UserError("on_handoff must take one argument: context")

    async def _invoke_handoff(
        ctx: RunContextWrapper[Any], input_json: str | None = None
    ) -> Agent[Any]:
        # Runs the user's callback (if any) and hands back the target agent.
        if input_type is not None and type_adapter is not None:
            if input_json is None:
                # Record the failure on the current tracing span before raising.
                _error_tracing.attach_error_to_current_span(
                    SpanError(
                        message="Handoff function expected non-null input, but got None",
                        data={"details": "input_json is None"},
                    )
                )
                raise ModelBehaviorError("Handoff function expected non-null input, but got None")

            validated_input = _json.validate_json(
                json_str=input_json,
                type_adapter=type_adapter,
                partial=False,
            )
            input_func = cast(OnHandoffWithInput[THandoffInput], on_handoff)
            if inspect.iscoroutinefunction(input_func):
                await input_func(ctx, validated_input)
            else:
                input_func(ctx, validated_input)
        elif on_handoff is not None:
            no_input_func = cast(OnHandoffWithoutInput, on_handoff)
            if inspect.iscoroutinefunction(no_input_func):
                await no_input_func(ctx)
            else:
                no_input_func(ctx)

        return agent

    tool_name = tool_name_override or Handoff.default_tool_name(agent)
    tool_description = tool_description_override or Handoff.default_tool_description(agent)

    # Always ensure the input JSON schema is in strict mode
    # If there is a need, we can make this configurable in the future
    input_json_schema = ensure_strict_json_schema(input_json_schema)

    return Handoff(
        tool_name=tool_name,
        tool_description=tool_description,
        input_json_schema=input_json_schema,
        on_invoke_handoff=_invoke_handoff,
        input_filter=input_filter,
        agent_name=agent.name,
    )
|