More tests for cancelling streamed run (#590)
This commit is contained in:
parent
e11b822d5f
commit
45eb41f1e6
2 changed files with 100 additions and 3 deletions
|
|
@ -127,7 +127,10 @@ class FakeModel(Model):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_response_obj(output: list[TResponseOutputItem], response_id: str | None = None) -> Response:
|
def get_response_obj(
|
||||||
|
output: list[TResponseOutputItem],
|
||||||
|
response_id: str | None = None,
|
||||||
|
) -> Response:
|
||||||
return Response(
|
return Response(
|
||||||
id=response_id or "123",
|
id=response_id or "123",
|
||||||
created_at=123,
|
created_at=123,
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,15 @@
|
||||||
|
import json
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from agents import Agent, Runner
|
from agents import Agent, Runner
|
||||||
|
|
||||||
from .fake_model import FakeModel
|
from .fake_model import FakeModel
|
||||||
|
from .test_responses import get_function_tool, get_function_tool_call, get_text_message
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_joker_streamed_jokes_with_cancel():
|
async def test_simple_streaming_with_cancel():
|
||||||
model = FakeModel()
|
model = FakeModel()
|
||||||
agent = Agent(name="Joker", model=model)
|
agent = Agent(name="Joker", model=model)
|
||||||
|
|
||||||
|
|
@ -16,7 +19,98 @@ async def test_joker_streamed_jokes_with_cancel():
|
||||||
|
|
||||||
async for _event in result.stream_events():
|
async for _event in result.stream_events():
|
||||||
num_events += 1
|
num_events += 1
|
||||||
if num_events == 1:
|
if num_events == stop_after:
|
||||||
result.cancel()
|
result.cancel()
|
||||||
|
|
||||||
assert num_events == 1, f"Expected {stop_after} visible events, but got {num_events}"
|
assert num_events == 1, f"Expected {stop_after} visible events, but got {num_events}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_multiple_events_streaming_with_cancel():
|
||||||
|
model = FakeModel()
|
||||||
|
agent = Agent(
|
||||||
|
name="Joker",
|
||||||
|
model=model,
|
||||||
|
tools=[get_function_tool("foo", "tool_result")],
|
||||||
|
)
|
||||||
|
|
||||||
|
model.add_multiple_turn_outputs(
|
||||||
|
[
|
||||||
|
# First turn: a message and tool call
|
||||||
|
[
|
||||||
|
get_text_message("a_message"),
|
||||||
|
get_function_tool_call("foo", json.dumps({"a": "b"})),
|
||||||
|
],
|
||||||
|
# Second turn: text message
|
||||||
|
[get_text_message("done")],
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
result = Runner.run_streamed(agent, input="Please tell me 5 jokes.")
|
||||||
|
num_events = 0
|
||||||
|
stop_after = 2
|
||||||
|
|
||||||
|
async for _ in result.stream_events():
|
||||||
|
num_events += 1
|
||||||
|
if num_events == stop_after:
|
||||||
|
result.cancel()
|
||||||
|
|
||||||
|
assert num_events == stop_after, f"Expected {stop_after} visible events, but got {num_events}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_prevents_further_events():
|
||||||
|
model = FakeModel()
|
||||||
|
agent = Agent(name="Joker", model=model)
|
||||||
|
result = Runner.run_streamed(agent, input="Please tell me 5 jokes.")
|
||||||
|
events = []
|
||||||
|
async for event in result.stream_events():
|
||||||
|
events.append(event)
|
||||||
|
result.cancel()
|
||||||
|
break # Cancel after first event
|
||||||
|
# Try to get more events after cancel
|
||||||
|
more_events = [e async for e in result.stream_events()]
|
||||||
|
assert len(events) == 1
|
||||||
|
assert more_events == [], "No events should be yielded after cancel()"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_is_idempotent():
|
||||||
|
model = FakeModel()
|
||||||
|
agent = Agent(name="Joker", model=model)
|
||||||
|
result = Runner.run_streamed(agent, input="Please tell me 5 jokes.")
|
||||||
|
events = []
|
||||||
|
async for event in result.stream_events():
|
||||||
|
events.append(event)
|
||||||
|
result.cancel()
|
||||||
|
result.cancel() # Call cancel again
|
||||||
|
break
|
||||||
|
# Should not raise or misbehave
|
||||||
|
assert len(events) == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_before_streaming():
|
||||||
|
model = FakeModel()
|
||||||
|
agent = Agent(name="Joker", model=model)
|
||||||
|
result = Runner.run_streamed(agent, input="Please tell me 5 jokes.")
|
||||||
|
result.cancel() # Cancel before streaming
|
||||||
|
events = [e async for e in result.stream_events()]
|
||||||
|
assert events == [], "No events should be yielded if cancel() is called before streaming."
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_cancel_cleans_up_resources():
|
||||||
|
model = FakeModel()
|
||||||
|
agent = Agent(name="Joker", model=model)
|
||||||
|
result = Runner.run_streamed(agent, input="Please tell me 5 jokes.")
|
||||||
|
# Start streaming, then cancel
|
||||||
|
async for _ in result.stream_events():
|
||||||
|
result.cancel()
|
||||||
|
break
|
||||||
|
# After cancel, queues should be empty and is_complete True
|
||||||
|
assert result.is_complete, "Result should be marked complete after cancel."
|
||||||
|
assert result._event_queue.empty(), "Event queue should be empty after cancel."
|
||||||
|
assert result._input_guardrail_queue.empty(), (
|
||||||
|
"Input guardrail queue should be empty after cancel."
|
||||||
|
)
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue