29 lines
941 B
Python
29 lines
941 B
Python
from dataclasses import dataclass, field, fields
|
|
from typing import Any
|
|
|
|
from .run_context import RunContextWrapper, TContext
|
|
|
|
|
|
def _assert_must_pass_tool_call_id() -> str:
|
|
raise ValueError("tool_call_id must be passed to ToolContext")
|
|
|
|
|
|
@dataclass
class ToolContext(RunContextWrapper[TContext]):
    """Run context handed to a single tool invocation.

    Extends ``RunContextWrapper`` with the identifier of the tool call being
    executed.
    """

    # Identifier of the tool call this context belongs to. The field must always
    # be supplied: omitting it invokes the default factory, which raises
    # ValueError. The default presumably exists only to satisfy dataclass
    # field-ordering rules after the inherited fields (confirm against
    # RunContextWrapper).
    tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id)

    @classmethod
    def from_agent_context(
        cls, context: RunContextWrapper[TContext], tool_call_id: str
    ) -> "ToolContext":
        """Build a ToolContext that mirrors an existing RunContextWrapper.

        Args:
            context: Wrapper whose constructor-settable dataclass fields are
                copied by reference into the new instance.
            tool_call_id: ID of the tool call the new context describes.

        Returns:
            A new instance of ``cls`` carrying ``context``'s field values plus
            ``tool_call_id``.
        """
        inherited: dict[str, Any] = {}
        # Walk the base class's declared fields rather than ``context``'s, and
        # skip init=False fields: only parameters accepted by the constructor
        # can be forwarded as keyword arguments.
        for base_field in fields(RunContextWrapper):
            if not base_field.init:
                continue
            inherited[base_field.name] = getattr(context, base_field.name)
        return cls(tool_call_id=tool_call_id, **inherited)
|