Allow cancel out of the streaming result (#579)
Fix for #574 @rm-openai I'm not sure how to add a test within the repo but I have pasted a test script below that seems to work ```python import asyncio from openai.types.responses import ResponseTextDeltaEvent from agents import Agent, Runner async def main(): agent = Agent( name="Joker", instructions="You are a helpful assistant.", ) result = Runner.run_streamed(agent, input="Please tell me 5 jokes.") num_visible_event = 0 async for event in result.stream_events(): if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent): print(event.data.delta, end="", flush=True) num_visible_event += 1 print(num_visible_event) if num_visible_event == 3: result.cancel() if __name__ == "__main__": asyncio.run(main()) ````
This commit is contained in:
parent
178020ea33
commit
a113fea0ee
2 changed files with 43 additions and 3 deletions
|
|
@ -75,7 +75,9 @@ class RunResultBase(abc.ABC):
|
|||
|
||||
def to_input_list(self) -> list[TResponseInputItem]:
|
||||
"""Creates a new input list, merging the original input with all the new items generated."""
|
||||
original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(self.input)
|
||||
original_items: list[TResponseInputItem] = ItemHelpers.input_to_new_input_list(
|
||||
self.input
|
||||
)
|
||||
new_items = [item.to_input_item() for item in self.new_items]
|
||||
|
||||
return original_items + new_items
|
||||
|
|
@ -152,6 +154,18 @@ class RunResultStreaming(RunResultBase):
|
|||
"""
|
||||
return self.current_agent
|
||||
|
||||
def cancel(self) -> None:
|
||||
"""Cancels the streaming run, stopping all background tasks and marking the run as
|
||||
complete."""
|
||||
self._cleanup_tasks() # Cancel all running tasks
|
||||
self.is_complete = True # Mark the run as complete to stop event streaming
|
||||
|
||||
# Optionally, clear the event queue to prevent processing stale events
|
||||
while not self._event_queue.empty():
|
||||
self._event_queue.get_nowait()
|
||||
while not self._input_guardrail_queue.empty():
|
||||
self._input_guardrail_queue.get_nowait()
|
||||
|
||||
async def stream_events(self) -> AsyncIterator[StreamEvent]:
|
||||
"""Stream deltas for new items as they are generated. We're using the types from the
|
||||
OpenAI Responses API, so these are semantic events: each event has a `type` field that
|
||||
|
|
@ -192,13 +206,17 @@ class RunResultStreaming(RunResultBase):
|
|||
|
||||
def _check_errors(self):
|
||||
if self.current_turn > self.max_turns:
|
||||
self._stored_exception = MaxTurnsExceeded(f"Max turns ({self.max_turns}) exceeded")
|
||||
self._stored_exception = MaxTurnsExceeded(
|
||||
f"Max turns ({self.max_turns}) exceeded"
|
||||
)
|
||||
|
||||
# Fetch all the completed guardrail results from the queue and raise if needed
|
||||
while not self._input_guardrail_queue.empty():
|
||||
guardrail_result = self._input_guardrail_queue.get_nowait()
|
||||
if guardrail_result.output.tripwire_triggered:
|
||||
self._stored_exception = InputGuardrailTripwireTriggered(guardrail_result)
|
||||
self._stored_exception = InputGuardrailTripwireTriggered(
|
||||
guardrail_result
|
||||
)
|
||||
|
||||
# Check the tasks for any exceptions
|
||||
if self._run_impl_task and self._run_impl_task.done():
|
||||
|
|
|
|||
22
tests/test_cancel_streaming.py
Normal file
22
tests/test_cancel_streaming.py
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
import pytest
|
||||
|
||||
from agents import Agent, Runner
|
||||
|
||||
from .fake_model import FakeModel
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_joker_streamed_jokes_with_cancel():
|
||||
model = FakeModel()
|
||||
agent = Agent(name="Joker", model=model)
|
||||
|
||||
result = Runner.run_streamed(agent, input="Please tell me 5 jokes.")
|
||||
num_events = 0
|
||||
stop_after = 1 # There are two that the model gives back.
|
||||
|
||||
async for _event in result.stream_events():
|
||||
num_events += 1
|
||||
if num_events == 1:
|
||||
result.cancel()
|
||||
|
||||
assert num_events == 1, f"Expected {stop_after} visible events, but got {num_events}"
|
||||
Loading…
Reference in a new issue