Added support for passing tool_call_id via the RunContextWrapper (#766)

This PR fixes issue:
https://github.com/openai/openai-agents-python/issues/559

By adding the tool_call_id to the RunContextWrapper prior to calling
tools. This gives the ability to access the tool_call_id in the
implementation of the tool.
This commit is contained in:
Niv Hertz 2025-06-09 18:08:50 +03:00 committed by GitHub
parent dcb88e69cd
commit 8dfd6ff35c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 115 additions and 35 deletions

View file

@ -75,6 +75,7 @@ from .tool import (
MCPToolApprovalRequest,
Tool,
)
from .tool_context import ToolContext
from .tracing import (
SpanError,
Trace,
@ -543,23 +544,24 @@ class RunImpl:
func_tool: FunctionTool, tool_call: ResponseFunctionToolCall
) -> Any:
with function_span(func_tool.name) as span_fn:
tool_context = ToolContext.from_agent_context(context_wrapper, tool_call.call_id)
if config.trace_include_sensitive_data:
span_fn.span_data.input = tool_call.arguments
try:
_, _, result = await asyncio.gather(
hooks.on_tool_start(context_wrapper, agent, func_tool),
hooks.on_tool_start(tool_context, agent, func_tool),
(
agent.hooks.on_tool_start(context_wrapper, agent, func_tool)
agent.hooks.on_tool_start(tool_context, agent, func_tool)
if agent.hooks
else _coro.noop_coroutine()
),
func_tool.on_invoke_tool(context_wrapper, tool_call.arguments),
func_tool.on_invoke_tool(tool_context, tool_call.arguments),
)
await asyncio.gather(
hooks.on_tool_end(context_wrapper, agent, func_tool, result),
hooks.on_tool_end(tool_context, agent, func_tool, result),
(
agent.hooks.on_tool_end(context_wrapper, agent, func_tool, result)
agent.hooks.on_tool_end(tool_context, agent, func_tool, result)
if agent.hooks
else _coro.noop_coroutine()
),

View file

@ -13,6 +13,7 @@ from pydantic import BaseModel, Field, create_model
from .exceptions import UserError
from .run_context import RunContextWrapper
from .strict_schema import ensure_strict_json_schema
from .tool_context import ToolContext
@dataclass
@ -237,21 +238,21 @@ def function_schema(
ann = type_hints.get(first_name, first_param.annotation)
if ann != inspect._empty:
origin = get_origin(ann) or ann
if origin is RunContextWrapper:
if origin is RunContextWrapper or origin is ToolContext:
takes_context = True # Mark that the function takes context
else:
filtered_params.append((first_name, first_param))
else:
filtered_params.append((first_name, first_param))
# For parameters other than the first, raise error if any use RunContextWrapper.
# For parameters other than the first, raise error if any use RunContextWrapper or ToolContext.
for name, param in params[1:]:
ann = type_hints.get(name, param.annotation)
if ann != inspect._empty:
origin = get_origin(ann) or ann
if origin is RunContextWrapper:
if origin is RunContextWrapper or origin is ToolContext:
raise UserError(
f"RunContextWrapper param found at non-first position in function"
f"RunContextWrapper/ToolContext param found at non-first position in function"
f" {func.__name__}"
)
filtered_params.append((name, param))

View file

@ -20,6 +20,7 @@ from .function_schema import DocstringStyle, function_schema
from .items import RunItem
from .logger import logger
from .run_context import RunContextWrapper
from .tool_context import ToolContext
from .tracing import SpanError
from .util import _error_tracing
from .util._types import MaybeAwaitable
@ -31,8 +32,13 @@ ToolParams = ParamSpec("ToolParams")
ToolFunctionWithoutContext = Callable[ToolParams, Any]
ToolFunctionWithContext = Callable[Concatenate[RunContextWrapper[Any], ToolParams], Any]
ToolFunctionWithToolContext = Callable[Concatenate[ToolContext, ToolParams], Any]
ToolFunction = Union[ToolFunctionWithoutContext[ToolParams], ToolFunctionWithContext[ToolParams]]
ToolFunction = Union[
ToolFunctionWithoutContext[ToolParams],
ToolFunctionWithContext[ToolParams],
ToolFunctionWithToolContext[ToolParams],
]
@dataclass
@ -62,7 +68,7 @@ class FunctionTool:
params_json_schema: dict[str, Any]
"""The JSON schema for the tool's parameters."""
on_invoke_tool: Callable[[RunContextWrapper[Any], str], Awaitable[Any]]
on_invoke_tool: Callable[[ToolContext[Any], str], Awaitable[Any]]
"""A function that invokes the tool with the given context and parameters. The params passed
are:
1. The tool run context.
@ -344,7 +350,7 @@ def function_tool(
strict_json_schema=strict_mode,
)
async def _on_invoke_tool_impl(ctx: RunContextWrapper[Any], input: str) -> Any:
async def _on_invoke_tool_impl(ctx: ToolContext[Any], input: str) -> Any:
try:
json_data: dict[str, Any] = json.loads(input) if input else {}
except Exception as e:
@ -393,7 +399,7 @@ def function_tool(
return result
async def _on_invoke_tool(ctx: RunContextWrapper[Any], input: str) -> Any:
async def _on_invoke_tool(ctx: ToolContext[Any], input: str) -> Any:
try:
return await _on_invoke_tool_impl(ctx, input)
except Exception as e:

View file

@ -0,0 +1,28 @@
from dataclasses import dataclass, field, fields
from typing import Any
from .run_context import RunContextWrapper, TContext
def _assert_must_pass_tool_call_id() -> str:
raise ValueError("tool_call_id must be passed to ToolContext")
@dataclass
class ToolContext(RunContextWrapper[TContext]):
"""The context of a tool call."""
tool_call_id: str = field(default_factory=_assert_must_pass_tool_call_id)
"""The ID of the tool call."""
@classmethod
def from_agent_context(
cls, context: RunContextWrapper[TContext], tool_call_id: str
) -> "ToolContext":
"""
Create a ToolContext from a RunContextWrapper.
"""
# Grab the names of the RunContextWrapper's init=True fields
base_values: dict[str, Any] = {
f.name: getattr(context, f.name) for f in fields(RunContextWrapper) if f.init
}
return cls(tool_call_id=tool_call_id, **base_values)

View file

@ -7,6 +7,7 @@ from typing_extensions import TypedDict
from agents import Agent, FunctionTool, ModelBehaviorError, RunContextWrapper, function_tool
from agents.tool import default_tool_error_function
from agents.tool_context import ToolContext
def argless_function() -> str:
@ -18,11 +19,11 @@ async def test_argless_function():
tool = function_tool(argless_function)
assert tool.name == "argless_function"
result = await tool.on_invoke_tool(RunContextWrapper(None), "")
result = await tool.on_invoke_tool(ToolContext(context=None, tool_call_id="1"), "")
assert result == "ok"
def argless_with_context(ctx: RunContextWrapper[str]) -> str:
def argless_with_context(ctx: ToolContext[str]) -> str:
return "ok"
@ -31,11 +32,11 @@ async def test_argless_with_context():
tool = function_tool(argless_with_context)
assert tool.name == "argless_with_context"
result = await tool.on_invoke_tool(RunContextWrapper(None), "")
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "")
assert result == "ok"
# Extra JSON should not raise an error
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1}')
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}')
assert result == "ok"
@ -48,15 +49,15 @@ async def test_simple_function():
tool = function_tool(simple_function, failure_error_function=None)
assert tool.name == "simple_function"
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1}')
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1}')
assert result == 6
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"a": 1, "b": 2}')
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"a": 1, "b": 2}')
assert result == 3
# Missing required argument should raise an error
with pytest.raises(ModelBehaviorError):
await tool.on_invoke_tool(RunContextWrapper(None), "")
await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), "")
class Foo(BaseModel):
@ -84,7 +85,7 @@ async def test_complex_args_function():
"bar": Bar(x="hello", y=10),
}
)
result = await tool.on_invoke_tool(RunContextWrapper(None), valid_json)
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
assert result == "6 hello10 hello"
valid_json = json.dumps(
@ -93,7 +94,7 @@ async def test_complex_args_function():
"bar": Bar(x="hello", y=10),
}
)
result = await tool.on_invoke_tool(RunContextWrapper(None), valid_json)
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
assert result == "3 hello10 hello"
valid_json = json.dumps(
@ -103,12 +104,12 @@ async def test_complex_args_function():
"baz": "world",
}
)
result = await tool.on_invoke_tool(RunContextWrapper(None), valid_json)
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), valid_json)
assert result == "3 hello10 world"
# Missing required argument should raise an error
with pytest.raises(ModelBehaviorError):
await tool.on_invoke_tool(RunContextWrapper(None), '{"foo": {"a": 1}}')
await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"foo": {"a": 1}}')
def test_function_config_overrides():
@ -168,7 +169,7 @@ async def test_manual_function_tool_creation_works():
assert tool.params_json_schema[key] == value
assert tool.strict_json_schema
result = await tool.on_invoke_tool(RunContextWrapper(None), '{"data": "hello"}')
result = await tool.on_invoke_tool(ToolContext(None, tool_call_id="1"), '{"data": "hello"}')
assert result == "hello_done"
tool_not_strict = FunctionTool(
@ -183,7 +184,7 @@ async def test_manual_function_tool_creation_works():
assert "additionalProperties" not in tool_not_strict.params_json_schema
result = await tool_not_strict.on_invoke_tool(
RunContextWrapper(None), '{"data": "hello", "bar": "baz"}'
ToolContext(None, tool_call_id="1"), '{"data": "hello", "bar": "baz"}'
)
assert result == "hello_done"
@ -194,7 +195,7 @@ async def test_function_tool_default_error_works():
raise ValueError("test")
tool = function_tool(my_func)
ctx = RunContextWrapper(None)
ctx = ToolContext(None, tool_call_id="1")
result = await tool.on_invoke_tool(ctx, "")
assert "Invalid JSON" in str(result)
@ -218,7 +219,7 @@ async def test_sync_custom_error_function_works():
return f"error_{error.__class__.__name__}"
tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
ctx = RunContextWrapper(None)
ctx = ToolContext(None, tool_call_id="1")
result = await tool.on_invoke_tool(ctx, "")
assert result == "error_ModelBehaviorError"
@ -242,7 +243,7 @@ async def test_async_custom_error_function_works():
return f"error_{error.__class__.__name__}"
tool = function_tool(my_func, failure_error_function=custom_sync_error_function)
ctx = RunContextWrapper(None)
ctx = ToolContext(None, tool_call_id="1")
result = await tool.on_invoke_tool(ctx, "")
assert result == "error_ModelBehaviorError"

View file

@ -7,6 +7,7 @@ from inline_snapshot import snapshot
from agents import function_tool
from agents.run_context import RunContextWrapper
from agents.tool_context import ToolContext
class DummyContext:
@ -14,8 +15,8 @@ class DummyContext:
self.data = "something"
def ctx_wrapper() -> RunContextWrapper[DummyContext]:
return RunContextWrapper(DummyContext())
def ctx_wrapper() -> ToolContext[DummyContext]:
return ToolContext(context=DummyContext(), tool_call_id="1")
@function_tool
@ -44,7 +45,7 @@ async def test_sync_no_context_with_args_invocation():
@function_tool
def sync_with_context(ctx: RunContextWrapper[DummyContext], name: str) -> str:
def sync_with_context(ctx: ToolContext[DummyContext], name: str) -> str:
return f"{name}_{ctx.context.data}"
@ -71,7 +72,7 @@ async def test_async_no_context_invocation():
@function_tool
async def async_with_context(ctx: RunContextWrapper[DummyContext], prefix: str, num: int) -> str:
async def async_with_context(ctx: ToolContext[DummyContext], prefix: str, num: int) -> str:
await asyncio.sleep(0)
return f"{prefix}-{num}-{ctx.context.data}"

View file

@ -49,10 +49,12 @@ def get_function_tool(
)
def get_function_tool_call(name: str, arguments: str | None = None) -> ResponseOutputItem:
def get_function_tool_call(
name: str, arguments: str | None = None, call_id: str | None = None
) -> ResponseOutputItem:
return ResponseFunctionToolCall(
id="1",
call_id="2",
call_id=call_id or "2",
type="function_call",
name=name,
arguments=arguments or "",

View file

@ -1,5 +1,6 @@
from __future__ import annotations
import json
from typing import Any
import pytest
@ -26,6 +27,8 @@ from agents._run_impl import (
RunImpl,
SingleStepResult,
)
from agents.tool import function_tool
from agents.tool_context import ToolContext
from .test_responses import (
get_final_output_message,
@ -158,6 +161,42 @@ async def test_multiple_tool_calls():
assert isinstance(result.next_step, NextStepRunAgain)
@pytest.mark.asyncio
async def test_multiple_tool_calls_with_tool_context():
async def _fake_tool(context: ToolContext[str], value: str) -> str:
return f"{value}-{context.tool_call_id}"
tool = function_tool(_fake_tool, name_override="fake_tool", failure_error_function=None)
agent = Agent(
name="test",
tools=[tool],
)
response = ModelResponse(
output=[
get_function_tool_call("fake_tool", json.dumps({"value": "123"}), call_id="1"),
get_function_tool_call("fake_tool", json.dumps({"value": "456"}), call_id="2"),
],
usage=Usage(),
response_id=None,
)
result = await get_execute_result(agent, response)
assert result.original_input == "hello"
# 4 items: new message, 2 tool calls, 2 tool call outputs
assert len(result.generated_items) == 4
assert isinstance(result.next_step, NextStepRunAgain)
items = result.generated_items
assert_item_is_function_tool_call(items[0], "fake_tool", json.dumps({"value": "123"}))
assert_item_is_function_tool_call(items[1], "fake_tool", json.dumps({"value": "456"}))
assert_item_is_function_tool_call_output(items[2], "123-1")
assert_item_is_function_tool_call_output(items[3], "456-2")
assert isinstance(result.next_step, NextStepRunAgain)
@pytest.mark.asyncio
async def test_handoff_output_leads_to_handoff_next_step():
agent_1 = Agent(name="test_1")