Added support for passing tool_call_id via the RunContextWrapper (#766)
This PR fixes issue: https://github.com/openai/openai-agents-python/issues/559 By adding the tool_call_id to the RunContextWrapper prior to calling tools. This gives the ability to access the tool_call_id in the implementation of the tool.
This commit is contained in:
parent
dcb88e69cd
commit
8dfd6ff35c
8 changed files with 115 additions and 35 deletions
|
|
@ -75,6 +75,7 @@ from .tool import (
|
|||
MCPToolApprovalRequest,
|
||||
Tool,
|
||||
)
|
||||
from .tool_context import ToolContext
|
||||
from .tracing import (
|
||||
SpanError,
|
||||
Trace,
|
||||
|
|
@ -543,23 +544,24 @@ class RunImpl:
|
|||
func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
|
||||
) -> Any:
|
||||
with function_span(func_tool.name) as span_fn:
|
||||
tool_context = ToolContext.from_agent_context(context_wrapper, tool_call.call_id)
|
||||
if config.trace_include_sensitive_data:
|
||||
span_fn.span_data.input = tool_call.arguments
|
||||
try:
|
||||
_, _, result = await asyncio.gather(
|
||||
hooks.on_tool_start(context_wrapper, agent, func_tool),
|
||||
hooks.on_tool_start(tool_context, agent, func_tool),
|
||||
(
|
||||
agent.hooks.on_tool_start(context_wrapper, agent, func_tool)
|
||||
agent.hooks.on_tool_start(tool_context, agent, func_tool)
|
||||
if agent.hooks
|
||||
else _coro.noop_coroutine()
|
||||
),
|
||||
func_tool.on_invoke_tool(context_wrapper, tool_call.arguments),
|
||||
func_tool.on_invoke_tool(tool_context, tool_call.arguments),
|
||||
)
|
||||
|
||||
await asyncio.gather(
|
||||
hooks.on_tool_end(context_wrapper, agent, func_tool, result),
|
||||
hooks.on_tool_end(tool_context, agent, func_tool, result),
|
||||
(
|
||||
agent.hooks.on_tool_end(context_wrapper, agent, func_tool, result)
|
||||
agent.hooks.on_tool_end(tool_context, agent, func_tool, result)
|
||||
if agent.hooks
|
||||
else _coro.noop_coroutine()
|
||||
),
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from pydantic import BaseModel, Field, create_model
|
|||
from .exceptions import UserError
|
||||
from .run_context import RunContextWrapper
|
||||
from .strict_schema import ensure_strict_json_schema
|
||||
from .tool_context import ToolContext
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -237,21 +238,21 @@ def function_schema(
|
|||
ann = type_hints.get(first_name, first_param.annotation)
|
||||
if ann != inspect._empty:
|
||||
origin = get_origin(ann) or ann
|
||||
if origin is RunContextWrapper:
|
||||
if origin is RunContextWrapper or origin is ToolContext:
|
||||
takes_context = True # Mark that the function takes context
|
||||
else:
|
||||
filtered_params.append((first_name, first_param))
|
||||
else:
|
||||
filtered_params.append((first_name, first_param))
|
||||
|
||||
# For parameters other than the first, raise error if any use RunContextWrapper.
|
||||
# For parameters other than the first, raise error if any use RunContextWrapper or ToolContext.
|
||||
for name, param in params[1:]:
|
||||
ann = type_hints.get(name, param.annotation)
|
||||
if ann != inspect._empty:
|
||||
origin = get_origin(ann) or ann
|
||||
if origin is RunContextWrapper:
|
||||
if origin is RunContextWrapper or origin is ToolContext:
|
||||
raise UserError(
|
||||
f"RunContextWrapper param found at non-first position in function"
|
||||
f"RunContextWrapper/ToolContext param found at non-first position in function"
|
||||
f" {func.__name__}"
|
||||
)
|
||||
filtered_params.append((name, param))
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ from .function_schema import DocstringStyle, function_schema
|
|||
from .items import RunItem
|
||||
from .logger import logger
|
||||
from .run_context import RunContextWrapper
|
||||
from .tool_context import ToolContext
|
||||
from .tracing import SpanError
|
||||
from .util import _error_tracing
|
||||
from .util._types import MaybeAwaitable
|
||||
|
|
@ -31,8 +32,13 @@ ToolParams = ParamSpec("ToolParams")
|
|||
|
||||
ToolFunctionWithoutContext = Callable[ToolParams, Any]
|
||||
ToolFunctionWithContext = Callable[Concatenate[RunContextWrapper[Any], ToolParams], Any]
|
||||
ToolFunctionWithToolContext = Callable[Concatenate[ToolContext, ToolParams], Any]
|
||||
|
||||
ToolFunction = Union[ToolFunctionWithoutContext[ToolParams], ToolFunctionWithContext[ToolParams]]
|
||||
ToolFunction = Union[
|
||||
ToolFunctionWithoutContext[ToolParams],
|
||||
ToolFunctionWithContext[ToolParams],
|
||||
ToolFunctionWithToolContext[ToolParams],
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -62,7 +68,7 @@ class FunctionTool:
|
|||
params_json_schema: dict[str, Any]
|
||||
"""The JSON schema for the tool's parameters."""
|
||||
|
||||
on_invoke_tool: Callable[[RunContextWrapper[Any], str], Awaitable[Any]]
|
||||
on_invoke_tool: Callable[[ToolContext[Any], str], Awaitable[Any]]
|
||||
"""A function that invokes the tool with the given context and parameters. The params passed
|
||||
are:
|
||||
1. The tool run context.
|
||||
|
|
@ -344,7 +350,7 @@ def function_tool(
|
|||
strict_json_schema=strict_mode,
|
||||
)
|
||||
|
||||
async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> Any:
|
||||
async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any:
|
||||
try:
|
||||
json_data: dict[str, Any] = json.loads(input) if input else {}
|
||||
except Exception as e:
|
||||
|
|
@ -393,7 +399,7 @@ def function_tool(
|
|||
|
||||
return result
|
||||
|
||||
async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> Any:
|
||||
async def _on_invoke_tool(ctx: ToolContext[Any], input: str) -> Any:
|
||||
try:
|
||||
return await _on_invoke_tool_impl(ctx, input)
|
||||
except Exception as e:
|
||||
|
|
|
|||
28
src/agents/tool_context.py
Normal file
28
src/agents/tool_context.py
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
from dataclasses import dataclass, field, fields
|
||||
from typing import Any
|
||||
|
||||
from .run_context import RunContextWrapper, TContext
|
||||
|
||||
|
||||
def _assert_must_pass_tool_call_id() -> str:
|
||||
raise ValueError("tool_call_id must be passed to ToolContext")
|
||||
|
||||
@dataclass
|
||||
class ToolContext(RunContextWrapper[TContext]):
|
||||
"""The context of a tool call."""
|
||||
|
||||
tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id)
|
||||
"""The ID of the tool call."""
|
||||
|
||||
@classmethod
|
||||
def from_agent_context(
|
||||
cls, context: RunContextWrapper[TContext], tool_call_id: str
|
||||
) -> "ToolContext":
|
||||
"""
|
||||
Create a ToolContext from a RunContextWrapper.
|
||||
"""
|
||||
# Grab the names of the RunContextWrapper's init=True fields
|
||||
base_values: dict[str, Any] = {
|
||||
f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init
|
||||
}
|
||||
return cls(tool_call_id=tool_call_id, **base_values)
|
||||
|
|
@ -7,6 +7,7 @@ from typing_extensions import TypedDict
|
|||
|
||||
from agents import Agent, FunctionTool, ModelBehaviorError, RunContextWrapper, function_tool
|
||||
from agents.tool import default_tool_error_function
|
||||
from agents.tool_context import ToolContext
|
||||
|
||||
|
||||
def argless_function() -> str:
|
||||
|
|
@ -18,11 +19,11 @@ async def test_argless_function():
|
|||
tool = function_tool(argless_function)
|
||||
assert tool.name == "argless_function"
|
||||
|
||||
result = await tool.on_invoke_tool(RunContextWrapper(None), "")
|
||||
result = await tool.on_invoke_tool(ToolContext(context=None, tool_call_id="1"), "")
|
||||
assert result == "ok"
|
||||
|
||||
|
||||
def argless_with_context(ctx: RunContextWrapper[str]) -> str:
|
||||
def argless_with_context(ctx: ToolContext[str]) -> str:
|
||||
return "ok"
|
||||
|
||||
|
||||
|
|
@ -31,11 +32,11 @@ async def test_argless_with_context():
|
|||
tool = function_tool(argless_with_context)
|
||||
assert tool.name == "argless_with_context"
|
||||
|
||||
result = await tool.on_invoke_tool(RunContextWrapper(None), "")
|
||||
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "")
|
||||
assert result == "ok"
|
||||
|
||||
# Extra JSON should not raise an error
|
||||
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1}')
|
||||
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}')
|
||||
assert result == "ok"
|
||||
|
||||
|
||||
|
|
@ -48,15 +49,15 @@ async def test_simple_function():
|
|||
tool = function_tool(simple_function, failure_error_function=None)
|
||||
assert tool.name == "simple_function"
|
||||
|
||||
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1}')
|
||||
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}')
|
||||
assert result == 6
|
||||
|
||||
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1, "b": 2}')
|
||||
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1, "b": 2}')
|
||||
assert result == 3
|
||||
|
||||
# Missing required argument should raise an error
|
||||
with pytest.raises(ModelBehaviorError):
|
||||
await tool.on_invoke_tool(RunContextWrapper(None), "")
|
||||
await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "")
|
||||
|
||||
|
||||
class Foo(BaseModel):
|
||||
|
|
@ -84,7 +85,7 @@ async def test_complex_args_function():
|
|||
"bar": Bar(x="hello", y=10),
|
||||
}
|
||||
)
|
||||
result = await tool.on_invoke_tool(RunContextWrapper(None), valid_json)
|
||||
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
|
||||
assert result == "6 hello10 hello"
|
||||
|
||||
valid_json = json.dumps(
|
||||
|
|
@ -93,7 +94,7 @@ async def test_complex_args_function():
|
|||
"bar": Bar(x="hello", y=10),
|
||||
}
|
||||
)
|
||||
result = await tool.on_invoke_tool(RunContextWrapper(None), valid_json)
|
||||
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
|
||||
assert result == "3 hello10 hello"
|
||||
|
||||
valid_json = json.dumps(
|
||||
|
|
@ -103,12 +104,12 @@ async def test_complex_args_function():
|
|||
"baz": "world",
|
||||
}
|
||||
)
|
||||
result = await tool.on_invoke_tool(RunContextWrapper(None), valid_json)
|
||||
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
|
||||
assert result == "3 hello10 world"
|
||||
|
||||
# Missing required argument should raise an error
|
||||
with pytest.raises(ModelBehaviorError):
|
||||
await tool.on_invoke_tool(RunContextWrapper(None), '{"foo": {"a": 1}}')
|
||||
await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"foo": {"a": 1}}')
|
||||
|
||||
|
||||
def test_function_config_overrides():
|
||||
|
|
@ -168,7 +169,7 @@ async def test_manual_function_tool_creation_works():
|
|||
assert tool.params_json_schema[key] == value
|
||||
assert tool.strict_json_schema
|
||||
|
||||
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"data": "hello"}')
|
||||
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"data": "hello"}')
|
||||
assert result == "hello_done"
|
||||
|
||||
tool_not_strict = FunctionTool(
|
||||
|
|
@ -183,7 +184,7 @@ async def test_manual_function_tool_creation_works():
|
|||
assert "additionalProperties" not in tool_not_strict.params_json_schema
|
||||
|
||||
result = await tool_not_strict.on_invoke_tool(
|
||||
RunContextWrapper(None), '{"data": "hello", "bar": "baz"}'
|
||||
ToolContext(None, tool_call_id="1"), '{"data": "hello", "bar": "baz"}'
|
||||
)
|
||||
assert result == "hello_done"
|
||||
|
||||
|
|
@ -194,7 +195,7 @@ async def test_function_tool_default_error_works():
|
|||
raise ValueError("test")
|
||||
|
||||
tool = function_tool(my_func)
|
||||
ctx = RunContextWrapper(None)
|
||||
ctx = ToolContext(None, tool_call_id="1")
|
||||
|
||||
result = await tool.on_invoke_tool(ctx, "")
|
||||
assert "Invalid JSON" in str(result)
|
||||
|
|
@ -218,7 +219,7 @@ async def test_sync_custom_error_function_works():
|
|||
return f"error_{error.__class__.__name__}"
|
||||
|
||||
tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
|
||||
ctx = RunContextWrapper(None)
|
||||
ctx = ToolContext(None, tool_call_id="1")
|
||||
|
||||
result = await tool.on_invoke_tool(ctx, "")
|
||||
assert result == "error_ModelBehaviorError"
|
||||
|
|
@ -242,7 +243,7 @@ async def test_async_custom_error_function_works():
|
|||
return f"error_{error.__class__.__name__}"
|
||||
|
||||
tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
|
||||
ctx = RunContextWrapper(None)
|
||||
ctx = ToolContext(None, tool_call_id="1")
|
||||
|
||||
result = await tool.on_invoke_tool(ctx, "")
|
||||
assert result == "error_ModelBehaviorError"
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from inline_snapshot import snapshot
|
|||
|
||||
from agents import function_tool
|
||||
from agents.run_context import RunContextWrapper
|
||||
from agents.tool_context import ToolContext
|
||||
|
||||
|
||||
class DummyContext:
|
||||
|
|
@ -14,8 +15,8 @@ class DummyContext:
|
|||
self.data = "something"
|
||||
|
||||
|
||||
def ctx_wrapper() -> RunContextWrapper[DummyContext]:
|
||||
return RunContextWrapper(DummyContext())
|
||||
def ctx_wrapper() -> ToolContext[DummyContext]:
|
||||
return ToolContext(context=DummyContext(), tool_call_id="1")
|
||||
|
||||
|
||||
@function_tool
|
||||
|
|
@ -44,7 +45,7 @@ async def test_sync_no_context_with_args_invocation():
|
|||
|
||||
|
||||
@function_tool
|
||||
def sync_with_context(ctx: RunContextWrapper[DummyContext], name: str) -> str:
|
||||
def sync_with_context(ctx: ToolContext[DummyContext], name: str) -> str:
|
||||
return f"{name}_{ctx.context.data}"
|
||||
|
||||
|
||||
|
|
@ -71,7 +72,7 @@ async def test_async_no_context_invocation():
|
|||
|
||||
|
||||
@function_tool
|
||||
async def async_with_context(ctx: RunContextWrapper[DummyContext], prefix: str, num: int) -> str:
|
||||
async def async_with_context(ctx: ToolContext[DummyContext], prefix: str, num: int) -> str:
|
||||
await asyncio.sleep(0)
|
||||
return f"{prefix}-{num}-{ctx.context.data}"
|
||||
|
||||
|
|
|
|||
|
|
@ -49,10 +49,12 @@ def get_function_tool(
|
|||
)
|
||||
|
||||
|
||||
def get_function_tool_call(name: str, arguments: str | None = None) -> ResponseOutputItem:
|
||||
def get_function_tool_call(
|
||||
name: str, arguments: str | None = None, call_id: str | None = None
|
||||
) -> ResponseOutputItem:
|
||||
return ResponseFunctionToolCall(
|
||||
id="1",
|
||||
call_id="2",
|
||||
call_id=call_id or "2",
|
||||
type="function_call",
|
||||
name=name,
|
||||
arguments=arguments or "",
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
|
@ -26,6 +27,8 @@ from agents._run_impl import (
|
|||
RunImpl,
|
||||
SingleStepResult,
|
||||
)
|
||||
from agents.tool import function_tool
|
||||
from agents.tool_context import ToolContext
|
||||
|
||||
from .test_responses import (
|
||||
get_final_output_message,
|
||||
|
|
@ -158,6 +161,42 @@ async def test_multiple_tool_calls():
|
|||
assert isinstance(result.next_step, NextStepRunAgain)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_tool_calls_with_tool_context():
|
||||
async def _fake_tool(context: ToolContext[str], value: str) -> str:
|
||||
return f"{value}-{context.tool_call_id}"
|
||||
|
||||
tool = function_tool(_fake_tool, name_override="fake_tool", failure_error_function=None)
|
||||
|
||||
agent = Agent(
|
||||
name="test",
|
||||
tools=[tool],
|
||||
)
|
||||
response = ModelResponse(
|
||||
output=[
|
||||
get_function_tool_call("fake_tool", json.dumps({"value": "123"}), call_id="1"),
|
||||
get_function_tool_call("fake_tool", json.dumps({"value": "456"}), call_id="2"),
|
||||
],
|
||||
usage=Usage(),
|
||||
response_id=None,
|
||||
)
|
||||
|
||||
result = await get_execute_result(agent, response)
|
||||
assert result.original_input == "hello"
|
||||
|
||||
# 4 items: new message, 2 tool calls, 2 tool call outputs
|
||||
assert len(result.generated_items) == 4
|
||||
assert isinstance(result.next_step, NextStepRunAgain)
|
||||
|
||||
items = result.generated_items
|
||||
assert_item_is_function_tool_call(items[0], "fake_tool", json.dumps({"value": "123"}))
|
||||
assert_item_is_function_tool_call(items[1], "fake_tool", json.dumps({"value": "456"}))
|
||||
assert_item_is_function_tool_call_output(items[2], "123-1")
|
||||
assert_item_is_function_tool_call_output(items[3], "456-2")
|
||||
|
||||
assert isinstance(result.next_step, NextStepRunAgain)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handoff_output_leads_to_handoff_next_step():
|
||||
agent_1 = Agent(name="test_1")
|
||||
|
|
|
|||
Loading…
Reference in a new issue