Allow cancel out of the streaming result (#579)

Fix for #574 

@rm-openai I'm not sure how to add a test within the repo but I have
pasted a test script below that seems to work

```python
import asyncio
from openai.types.responses import ResponseTextDeltaEvent
from agents import Agent, Runner

async def main():
    agent = Agent(
        name="Joker",
        instructions="You are a helpful assistant.",
    )

    result = Runner.run_streamed(agent, input="Please tell me 5 jokes.")
    num_visible_event = 0
    async for event in result.stream_events():
        if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):
            print(event.data.delta, end="", flush=True)
            num_visible_event += 1
            print(num_visible_event)
            if num_visible_event == 3:
                result.cancel()


if __name__ == "__main__":
    asyncio.run(main())
````
This commit is contained in:
Andrew Han 2025-04-23 16:51:10 -07:00 committed by GitHub
parent 178020ea33
commit a113fea0ee
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 43 additions and 3 deletions

View file

@ -75,7 +75,9 @@ class RunResultBase(abc.ABC):
def to_input_list(self) -> list[TResponseInputItem]:
"""Creates a new input list, merging the original input with all the new items generated."""
original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(self.input)
original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(
self.input
)
new_items = [item.to_input_item() for item in self.new_items]
return original_items + new_items
@ -152,6 +154,18 @@ class RunResultStreaming(RunResultBase):
"""
return self.current_agent
def cancel(self) -> None:
"""Cancels the streaming run, stopping all background tasks and marking the run as
complete."""
self._cleanup_tasks() # Cancel all running tasks
self.is_complete = True # Mark the run as complete to stop event streaming
# Optionally, clear the event queue to prevent processing stale events
while not self._event_queue.empty():
self._event_queue.get_nowait()
while not self._input_guardrail_queue.empty():
self._input_guardrail_queue.get_nowait()
async def stream_events(self) -> AsyncIterator[StreamEvent]:
"""Stream deltas for new items as they are generated. We're using the types from the
OpenAI Responses API, so these are semantic events: each event has a `type` field that
@ -192,13 +206,17 @@ class RunResultStreaming(RunResultBase):
def _check_errors(self):
if self.current_turn > self.max_turns:
self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")
self._stored_exception = MaxTurnsExceeded(
f"Max turns ({self.max_turns}) exceeded"
)
# Fetch all the completed guardrail results from the queue and raise if needed
while not self._input_guardrail_queue.empty():
guardrail_result = self._input_guardrail_queue.get_nowait()
if guardrail_result.output.tripwire_triggered:
self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result)
self._stored_exception = InputGuardrailTripwireTriggered(
guardrail_result
)
# Check the tasks for any exceptions
if self._run_impl_task and self._run_impl_task.done():

View file

@ -0,0 +1,22 @@
import pytest
from agents import Agent, Runner
from .fake_model import FakeModel
@pytest.mark.asyncio
async def test_joker_streamed_jokes_with_cancel():
model = FakeModel()
agent = Agent(name="Joker", model=model)
result = Runner.run_streamed(agent, input="Please tell me 5 jokes.")
num_events = 0
stop_after = 1 # There are two that the model gives back.
async for _event in result.stream_events():
num_events += 1
if num_events == 1:
result.cancel()
assert num_events == 1, f"Expected {stop_after} visible events, but got {num_events}"