Start and finish streaming trace in impl metod (#540)

Closes #435 and closes #538.

Unit tests.
This commit is contained in:
Rohan Mehta 2025-04-21 13:08:38 -04:00 committed by GitHub
parent e3698f32b1
commit 616d8e7f4b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 87 additions and 11 deletions

3
.gitignore vendored
View file

@ -141,4 +141,5 @@ cython_debug/
.ruff_cache/ .ruff_cache/
# PyPI configuration file # PyPI configuration file
.pypirc .pypirc
.aider*

View file

@ -61,6 +61,7 @@ dev = [
"graphviz", "graphviz",
"mkdocs-static-i18n>=1.3.0", "mkdocs-static-i18n>=1.3.0",
"eval-type-backport>=0.2.2", "eval-type-backport>=0.2.2",
"fastapi >= 0.110.0, <1",
] ]
[tool.uv.workspace] [tool.uv.workspace]

View file

@ -126,7 +126,7 @@ class RunResultStreaming(RunResultBase):
_current_agent_output_schema: AgentOutputSchemaBase | None = field(repr=False) _current_agent_output_schema: AgentOutputSchemaBase | None = field(repr=False)
_trace: Trace | None = field(repr=False) trace: Trace | None = field(repr=False)
is_complete: bool = False is_complete: bool = False
"""Whether the agent has finished running.""" """Whether the agent has finished running."""
@ -185,9 +185,6 @@ class RunResultStreaming(RunResultBase):
yield item yield item
self._event_queue.task_done() self._event_queue.task_done()
if self._trace:
self._trace.finish(reset_current=True)
self._cleanup_tasks() self._cleanup_tasks()
if self._stored_exception: if self._stored_exception:

View file

@ -404,10 +404,6 @@ class Runner:
disabled=run_config.tracing_disabled, disabled=run_config.tracing_disabled,
) )
) )
# Need to start the trace here, because the current trace contextvar is captured at
# asyncio.create_task time
if new_trace:
new_trace.start(mark_as_current=True)
output_schema = cls._get_output_schema(starting_agent) output_schema = cls._get_output_schema(starting_agent)
context_wrapper: RunContextWrapper[TContext] = RunContextWrapper( context_wrapper: RunContextWrapper[TContext] = RunContextWrapper(
@ -426,7 +422,7 @@ class Runner:
input_guardrail_results=[], input_guardrail_results=[],
output_guardrail_results=[], output_guardrail_results=[],
_current_agent_output_schema=output_schema, _current_agent_output_schema=output_schema,
_trace=new_trace, trace=new_trace,
) )
# Kick off the actual agent loop in the background and return the streamed result object. # Kick off the actual agent loop in the background and return the streamed result object.
@ -499,6 +495,9 @@ class Runner:
run_config: RunConfig, run_config: RunConfig,
previous_response_id: str | None, previous_response_id: str | None,
): ):
if streamed_result.trace:
streamed_result.trace.start(mark_as_current=True)
current_span: Span[AgentSpanData] | None = None current_span: Span[AgentSpanData] | None = None
current_agent = starting_agent current_agent = starting_agent
current_turn = 0 current_turn = 0
@ -625,6 +624,8 @@ class Runner:
finally: finally:
if current_span: if current_span:
current_span.finish(reset_current=True) current_span.finish(reset_current=True)
if streamed_result.trace:
streamed_result.trace.finish(reset_current=True)
@classmethod @classmethod
async def _run_single_turn_streamed( async def _run_single_turn_streamed(

View file

View file

@ -0,0 +1,30 @@
from collections.abc import AsyncIterator
from fastapi import FastAPI
from starlette.responses import StreamingResponse
from agents import Agent, Runner, RunResultStreaming
agent = Agent(
name="Assistant",
instructions="You are a helpful assistant.",
)
app = FastAPI()
@app.post("/stream")
async def stream():
result = Runner.run_streamed(agent, input="Tell me a joke")
stream_handler = StreamHandler(result)
return StreamingResponse(stream_handler.stream_events(), media_type="application/x-ndjson")
class StreamHandler:
def __init__(self, result: RunResultStreaming):
self.result = result
async def stream_events(self) -> AsyncIterator[str]:
async for event in self.result.stream_events():
yield f"{event.type}\n\n"

View file

@ -0,0 +1,29 @@
import pytest
from httpx import ASGITransport, AsyncClient
from inline_snapshot import snapshot
from ..fake_model import FakeModel
from ..test_responses import get_text_message
from .streaming_app import agent, app
@pytest.mark.asyncio
async def test_streaming_context():
"""This ensures that FastAPI streaming works. The context for this test is that the Runner
method was called in one async context, and the streaming was ended in another context,
leading to a tracing error because the context was closed in the wrong context. This test
ensures that this actually works.
"""
model = FakeModel()
agent.model = model
model.set_next_output([get_text_message("done")])
transport = ASGITransport(app)
async with AsyncClient(transport=transport, base_url="http://test") as ac:
async with ac.stream("POST", "/stream") as r:
assert r.status_code == 200
body = (await r.aread()).decode("utf-8")
lines = [line for line in body.splitlines() if line]
assert lines == snapshot(
["agent_updated_stream_event", "raw_response_event", "run_item_stream_event"]
)

19
uv.lock
View file

@ -483,6 +483,20 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/7b/8f/c4d9bafc34ad7ad5d8dc16dd1347ee0e507a52c3adb6bfa8887e1c6a26ba/executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa", size = 26702 }, { url = "https://files.pythonhosted.org/packages/7b/8f/c4d9bafc34ad7ad5d8dc16dd1347ee0e507a52c3adb6bfa8887e1c6a26ba/executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa", size = 26702 },
] ]
[[package]]
name = "fastapi"
version = "0.115.12"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pydantic" },
{ name = "starlette" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/f4/55/ae499352d82338331ca1e28c7f4a63bfd09479b16395dce38cf50a39e2c2/fastapi-0.115.12.tar.gz", hash = "sha256:1e2c2a2646905f9e83d32f04a3f86aff4a286669c6c950ca95b5fd68c2602681", size = 295236 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/50/b3/b51f09c2ba432a576fe63758bddc81f78f0c6309d9e5c10d194313bf021e/fastapi-0.115.12-py3-none-any.whl", hash = "sha256:e94613d6c05e27be7ffebdd6ea5f388112e5e430c8f7d6494a9d1d88d43e814d", size = 95164 },
]
[[package]] [[package]]
name = "filelock" name = "filelock"
version = "3.18.0" version = "3.18.0"
@ -1496,6 +1510,7 @@ voice = [
dev = [ dev = [
{ name = "coverage" }, { name = "coverage" },
{ name = "eval-type-backport" }, { name = "eval-type-backport" },
{ name = "fastapi" },
{ name = "graphviz" }, { name = "graphviz" },
{ name = "inline-snapshot" }, { name = "inline-snapshot" },
{ name = "mkdocs" }, { name = "mkdocs" },
@ -1536,6 +1551,7 @@ provides-extras = ["voice", "viz", "litellm"]
dev = [ dev = [
{ name = "coverage", specifier = ">=7.6.12" }, { name = "coverage", specifier = ">=7.6.12" },
{ name = "eval-type-backport", specifier = ">=0.2.2" }, { name = "eval-type-backport", specifier = ">=0.2.2" },
{ name = "fastapi", specifier = ">=0.110.0,<1" },
{ name = "graphviz" }, { name = "graphviz" },
{ name = "inline-snapshot", specifier = ">=0.20.7" }, { name = "inline-snapshot", specifier = ">=0.20.7" },
{ name = "mkdocs", specifier = ">=1.6.0" }, { name = "mkdocs", specifier = ">=1.6.0" },
@ -2474,7 +2490,8 @@ name = "starlette"
version = "0.46.2" version = "0.46.2"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "anyio", marker = "python_full_version >= '3.10'" }, { name = "anyio" },
{ name = "typing-extensions", marker = "python_full_version < '3.10'" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/ce/20/08dfcd9c983f6a6f4a1000d934b9e6d626cff8d2eeb77a89a68eef20a2b7/starlette-0.46.2.tar.gz", hash = "sha256:7f7361f34eed179294600af672f565727419830b54b7b084efe44bb82d2fccd5", size = 2580846 } sdist = { url = "https://files.pythonhosted.org/packages/ce/20/08dfcd9c983f6a6f4a1000d934b9e6d626cff8d2eeb77a89a68eef20a2b7/starlette-0.46.2.tar.gz", hash = "sha256:7f7361f34eed179294600af672f565727419830b54b7b084efe44bb82d2fccd5", size = 2580846 }
wheels = [ wheels = [