Start and finish streaming trace in impl metod (#540)
Closes #435 and closes #538. Unit tests.
This commit is contained in:
parent
e3698f32b1
commit
616d8e7f4b
8 changed files with 87 additions and 11 deletions
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -141,4 +141,5 @@ cython_debug/
|
|||
.ruff_cache/
|
||||
|
||||
# PyPI configuration file
|
||||
.pypirc
|
||||
.pypirc
|
||||
.aider*
|
||||
|
|
|
|||
|
|
@ -61,6 +61,7 @@ dev = [
|
|||
"graphviz",
|
||||
"mkdocs-static-i18n>=1.3.0",
|
||||
"eval-type-backport>=0.2.2",
|
||||
"fastapi >= 0.110.0, <1",
|
||||
]
|
||||
|
||||
[tool.uv.workspace]
|
||||
|
|
|
|||
|
|
@ -126,7 +126,7 @@ class RunResultStreaming(RunResultBase):
|
|||
|
||||
_current_agent_output_schema: AgentOutputSchemaBase | None = field(repr=False)
|
||||
|
||||
_trace: Trace | None = field(repr=False)
|
||||
trace: Trace | None = field(repr=False)
|
||||
|
||||
is_complete: bool = False
|
||||
"""Whether the agent has finished running."""
|
||||
|
|
@ -185,9 +185,6 @@ class RunResultStreaming(RunResultBase):
|
|||
yield item
|
||||
self._event_queue.task_done()
|
||||
|
||||
if self._trace:
|
||||
self._trace.finish(reset_current=True)
|
||||
|
||||
self._cleanup_tasks()
|
||||
|
||||
if self._stored_exception:
|
||||
|
|
|
|||
|
|
@ -404,10 +404,6 @@ class Runner:
|
|||
disabled=run_config.tracing_disabled,
|
||||
)
|
||||
)
|
||||
# Need to start the trace here, because the current trace contextvar is captured at
|
||||
# asyncio.create_task time
|
||||
if new_trace:
|
||||
new_trace.start(mark_as_current=True)
|
||||
|
||||
output_schema = cls._get_output_schema(starting_agent)
|
||||
context_wrapper: RunContextWrapper[TContext] = RunContextWrapper(
|
||||
|
|
@ -426,7 +422,7 @@ class Runner:
|
|||
input_guardrail_results=[],
|
||||
output_guardrail_results=[],
|
||||
_current_agent_output_schema=output_schema,
|
||||
_trace=new_trace,
|
||||
trace=new_trace,
|
||||
)
|
||||
|
||||
# Kick off the actual agent loop in the background and return the streamed result object.
|
||||
|
|
@ -499,6 +495,9 @@ class Runner:
|
|||
run_config: RunConfig,
|
||||
previous_response_id: str | None,
|
||||
):
|
||||
if streamed_result.trace:
|
||||
streamed_result.trace.start(mark_as_current=True)
|
||||
|
||||
current_span: Span[AgentSpanData] | None = None
|
||||
current_agent = starting_agent
|
||||
current_turn = 0
|
||||
|
|
@ -625,6 +624,8 @@ class Runner:
|
|||
finally:
|
||||
if current_span:
|
||||
current_span.finish(reset_current=True)
|
||||
if streamed_result.trace:
|
||||
streamed_result.trace.finish(reset_current=True)
|
||||
|
||||
@classmethod
|
||||
async def _run_single_turn_streamed(
|
||||
|
|
|
|||
0
tests/fastapi/__init__.py
Normal file
0
tests/fastapi/__init__.py
Normal file
30
tests/fastapi/streaming_app.py
Normal file
30
tests/fastapi/streaming_app.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
from collections.abc import AsyncIterator
|
||||
|
||||
from fastapi import FastAPI
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from agents import Agent, Runner, RunResultStreaming
|
||||
|
||||
agent = Agent(
|
||||
name="Assistant",
|
||||
instructions="You are a helpful assistant.",
|
||||
)
|
||||
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.post("/stream")
|
||||
async def stream():
|
||||
result = Runner.run_streamed(agent, input="Tell me a joke")
|
||||
stream_handler = StreamHandler(result)
|
||||
return StreamingResponse(stream_handler.stream_events(), media_type="application/x-ndjson")
|
||||
|
||||
|
||||
class StreamHandler:
|
||||
def __init__(self, result: RunResultStreaming):
|
||||
self.result = result
|
||||
|
||||
async def stream_events(self) -> AsyncIterator[str]:
|
||||
async for event in self.result.stream_events():
|
||||
yield f"{event.type}\n\n"
|
||||
29
tests/fastapi/test_streaming_context.py
Normal file
29
tests/fastapi/test_streaming_context.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
from inline_snapshot import snapshot
|
||||
|
||||
from ..fake_model import FakeModel
|
||||
from ..test_responses import get_text_message
|
||||
from .streaming_app import agent, app
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streaming_context():
|
||||
"""This ensures that FastAPI streaming works. The context for this test is that the Runner
|
||||
method was called in one async context, and the streaming was ended in another context,
|
||||
leading to a tracing error because the context was closed in the wrong context. This test
|
||||
ensures that this actually works.
|
||||
"""
|
||||
model = FakeModel()
|
||||
agent.model = model
|
||||
model.set_next_output([get_text_message("done")])
|
||||
|
||||
transport = ASGITransport(app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
async with ac.stream("POST", "/stream") as r:
|
||||
assert r.status_code == 200
|
||||
body = (await r.aread()).decode("utf-8")
|
||||
lines = [line for line in body.splitlines() if line]
|
||||
assert lines == snapshot(
|
||||
["agent_updated_stream_event", "raw_response_event", "run_item_stream_event"]
|
||||
)
|
||||
19
uv.lock
19
uv.lock
|
|
@ -483,6 +483,20 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/7b/8f/c4d9bafc34ad7ad5d8dc16dd1347ee0e507a52c3adb6bfa8887e1c6a26ba/executing-2.2.0-py2.py3-none-any.whl", hash = "sha256:11387150cad388d62750327a53d3339fad4888b39a6fe233c3afbb54ecffd3aa", size = 26702 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "fastapi"
|
||||
version = "0.115.12"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "pydantic" },
|
||||
{ name = "starlette" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f4/55/ae499352d82338331ca1e28c7f4a63bfd09479b16395dce38cf50a39e2c2/fastapi-0.115.12.tar.gz", hash = "sha256:1e2c2a2646905f9e83d32f04a3f86aff4a286669c6c950ca95b5fd68c2602681", size = 295236 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/50/b3/b51f09c2ba432a576fe63758bddc81f78f0c6309d9e5c10d194313bf021e/fastapi-0.115.12-py3-none-any.whl", hash = "sha256:e94613d6c05e27be7ffebdd6ea5f388112e5e430c8f7d6494a9d1d88d43e814d", size = 95164 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "filelock"
|
||||
version = "3.18.0"
|
||||
|
|
@ -1496,6 +1510,7 @@ voice = [
|
|||
dev = [
|
||||
{ name = "coverage" },
|
||||
{ name = "eval-type-backport" },
|
||||
{ name = "fastapi" },
|
||||
{ name = "graphviz" },
|
||||
{ name = "inline-snapshot" },
|
||||
{ name = "mkdocs" },
|
||||
|
|
@ -1536,6 +1551,7 @@ provides-extras = ["voice", "viz", "litellm"]
|
|||
dev = [
|
||||
{ name = "coverage", specifier = ">=7.6.12" },
|
||||
{ name = "eval-type-backport", specifier = ">=0.2.2" },
|
||||
{ name = "fastapi", specifier = ">=0.110.0,<1" },
|
||||
{ name = "graphviz" },
|
||||
{ name = "inline-snapshot", specifier = ">=0.20.7" },
|
||||
{ name = "mkdocs", specifier = ">=1.6.0" },
|
||||
|
|
@ -2474,7 +2490,8 @@ name = "starlette"
|
|||
version = "0.46.2"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "anyio", marker = "python_full_version >= '3.10'" },
|
||||
{ name = "anyio" },
|
||||
{ name = "typing-extensions", marker = "python_full_version < '3.10'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/ce/20/08dfcd9c983f6a6f4a1000d934b9e6d626cff8d2eeb77a89a68eef20a2b7/starlette-0.46.2.tar.gz", hash = "sha256:7f7361f34eed179294600af672f565727419830b54b7b084efe44bb82d2fccd5", size = 2580846 }
|
||||
wheels = [
|
||||
|
|
|
|||
Loading…
Reference in a new issue