Example for streaming guardrails (#505)
An example for the question in the issue attached - how to run guardrails during streaming. Towards #495.
This commit is contained in:
parent
5183f528f4
commit
5727a1c73a
3 changed files with 99 additions and 4 deletions
93
examples/agent_patterns/streaming_guardrails.py
Normal file
93
examples/agent_patterns/streaming_guardrails.py
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from openai.types.responses import ResponseTextDeltaEvent
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from agents import Agent, Runner
|
||||
|
||||
"""
|
||||
This example shows how to use guardrails as the model is streaming. Output guardrails run after the
|
||||
final output has been generated; this example runs guardrails every N tokens, allowing for early
|
||||
termination if bad output is detected.
|
||||
|
||||
The expected output is that you'll see a bunch of tokens stream in, then the guardrail will trigger
|
||||
and stop the streaming.
|
||||
"""
|
||||
|
||||
|
||||
# The agent whose streamed output is being guarded. Its instructions deliberately ask for
# long, verbose answers so there is plenty of text for the guardrail to inspect.
agent = Agent(
    instructions=(
        "You are a helpful assistant. "
        "You ALWAYS write long responses, "
        "making sure to be verbose and detailed."
    ),
    name="Assistant",
)
|
||||
|
||||
|
||||
class GuardrailOutput(BaseModel):
    # Structured verdict returned by the checker agent. Documented with comments rather
    # than a class docstring so that the model's generated schema is left untouched.

    # Free-text justification for the verdict.
    reasoning: str = Field(
        ...,
        description="Reasoning about whether the response could be understood by a ten year old.",
    )
    # The verdict itself; False means the guardrail should trigger.
    is_readable_by_ten_year_old: bool = Field(
        ...,
        description="Whether the response is understandable by a ten year old.",
    )
|
||||
|
||||
|
||||
# The checker agent: judges the readability of a piece of text and answers with a
# structured GuardrailOutput.
guardrail_agent = Agent(
    name="Checker",
    model="gpt-4o-mini",
    output_type=GuardrailOutput,
    instructions=(
        "You will be given a question and a response. "
        "Your goal is to judge whether the response is simple enough "
        "to be understood by a ten year old."
    ),
)
|
||||
|
||||
|
||||
async def check_guardrail(text: str) -> GuardrailOutput:
    """Run the checker agent on ``text`` and return its structured verdict.

    Args:
        text: The text to evaluate (in this example, the response streamed so far).

    Returns:
        The checker agent's final output, as a ``GuardrailOutput``.
    """
    checker_run = await Runner.run(guardrail_agent, text)
    verdict = checker_run.final_output_as(GuardrailOutput)
    return verdict
|
||||
|
||||
|
||||
async def main():
    """Stream the assistant's answer and run the readability guardrail while it streams.

    Text deltas are printed as they arrive and accumulated. Whenever roughly another
    300 characters have arrived and no check is in flight, a background guardrail
    check is started on the text received so far. If a check reports that the text is
    not readable by a ten year old, the reasoning is printed and the function returns,
    which stops consuming the stream. If the stream finishes without a trigger, one
    final check is run on the complete text.
    """
    question = "What is a black hole, and how does it behave?"
    result = Runner.run_streamed(agent, question)
    current_text = ""

    # We will check the guardrail every N characters
    guardrail_check_interval = 300
    next_guardrail_check_len = guardrail_check_interval
    guardrail_task = None

    async for event in result.stream_events():
        if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):
            print(event.data.delta, end="", flush=True)
            current_text += event.data.delta

            # Check if it's time to run the guardrail check
            # Note that we don't run the guardrail check if there's already a task running. An
            # alternate implementation is to have N guardrails running, or cancel the previous
            # one.
            if len(current_text) >= next_guardrail_check_len and guardrail_task is None:
                print("Running guardrail check")
                guardrail_task = asyncio.create_task(check_guardrail(current_text))
                # Schedule the next check relative to what has been received so far, so a
                # slow check is not followed by a burst of back-to-back catch-up checks.
                next_guardrail_check_len = len(current_text) + guardrail_check_interval

        # Every iteration of the loop, check if the guardrail has been triggered
        if guardrail_task is not None and guardrail_task.done():
            guardrail_result = guardrail_task.result()
            # Free the slot so later chunks of the stream get checked too; without this
            # the guardrail would only ever run once.
            guardrail_task = None
            if not guardrail_result.is_readable_by_ten_year_old:
                print("\n\n================\n\n")
                print(f"Guardrail triggered. Reasoning:\n{guardrail_result.reasoning}")
                # Return rather than break: falling through to the final check would
                # re-run the guardrail on the truncated text and could report it twice.
                # NOTE(review): if the installed SDK exposes a way to cancel the streamed
                # run, calling it here would also stop generation upstream - confirm.
                return

    # The stream is complete; a check still in flight only covers a prefix of the text and
    # is superseded by the final check below.
    if guardrail_task is not None and not guardrail_task.done():
        guardrail_task.cancel()

    # Do one final check on the final output
    guardrail_result = await check_guardrail(current_text)
    if not guardrail_result.is_readable_by_ten_year_old:
        print("\n\n================\n\n")
        print(f"Guardrail triggered. Reasoning:\n{guardrail_result.reasoning}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: run the async demo to completion.
    asyncio.run(main())
|
||||
|
|
@ -572,7 +572,6 @@ class OpenAIChatCompletionsModel(Model):
|
|||
|
||||
|
||||
class _Converter:
|
||||
|
||||
@classmethod
def is_openai(cls, client: AsyncOpenAI):
    """Return True when the client's base URL starts with ``https://api.openai.com``."""
    base_url = str(client.base_url)
    return base_url.startswith("https://api.openai.com")
|
||||
|
|
@ -585,11 +584,14 @@ class _Converter:
|
|||
|
||||
@classmethod
def get_stream_options_param(
    cls, client: AsyncOpenAI, model_settings: ModelSettings
) -> dict[str, bool] | None:
    """Build the ``stream_options`` value for a streaming request.

    An explicit ``model_settings.include_usage`` always wins. When it is ``None``,
    usage reporting defaults to ``True`` for clients for which ``is_openai`` returns
    true, and is left unset otherwise.

    Returns:
        ``{"include_usage": <bool>}`` when a value was resolved, otherwise ``None``.
    """
    default_include_usage = True if cls.is_openai(client) else None
    include_usage = (
        model_settings.include_usage
        if model_settings.include_usage is not None
        else default_include_usage
    )
    stream_options = {"include_usage": include_usage} if include_usage is not None else None
    return stream_options
|
||||
|
||||
|
|
|
|||
|
|
@ -250,7 +250,7 @@ class OpenAIResponsesModel(Model):
|
|||
text=response_format,
|
||||
store=self._non_null_or_not_given(model_settings.store),
|
||||
reasoning=self._non_null_or_not_given(model_settings.reasoning),
|
||||
metadata=self._non_null_or_not_given(model_settings.metadata)
|
||||
metadata=self._non_null_or_not_given(model_settings.metadata),
|
||||
)
|
||||
|
||||
def _get_client(self) -> AsyncOpenAI:
|
||||
|
|
|
|||
Loading…
Reference in a new issue