262 lines
8.9 KiB
Python
262 lines
8.9 KiB
Python
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
import pytest
|
|
|
|
from agents import (
|
|
Agent,
|
|
GuardrailFunctionOutput,
|
|
InputGuardrail,
|
|
OutputGuardrail,
|
|
RunContextWrapper,
|
|
TResponseInputItem,
|
|
UserError,
|
|
)
|
|
from agents.guardrail import input_guardrail, output_guardrail
|
|
|
|
|
|
def get_sync_guardrail(triggers: bool, output_info: Any | None = None):
    """Build a synchronous input-guardrail function with a fixed verdict.

    Args:
        triggers: Value placed in ``tripwire_triggered`` of every result.
        output_info: Value placed in ``output_info`` of every result.

    Returns:
        A plain (non-async) callable with the input-guardrail signature
        ``(context, agent, input)``. It ignores its arguments and returns a
        ``GuardrailFunctionOutput`` built from the two values above.
    """

    def sync_guardrail(
        context: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
    ):
        verdict = GuardrailFunctionOutput(tripwire_triggered=triggers, output_info=output_info)
        return verdict

    return sync_guardrail
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_sync_input_guardrail():
    """A sync input-guardrail function's verdict and info pass through ``run`` unchanged."""
    # (tripwire flag, output_info) pairs, exercised in the original order.
    cases = [(False, None), (True, None), (True, "test")]
    for triggers, info in cases:
        guardrail = InputGuardrail(
            guardrail_function=get_sync_guardrail(triggers=triggers, output_info=info)
        )
        result = await guardrail.run(
            agent=Agent(name="test"), input="test", context=RunContextWrapper(context=None)
        )
        assert result.output.tripwire_triggered == triggers
        assert result.output.output_info == info
|
|
|
|
|
|
def get_async_input_guardrail(triggers: bool, output_info: Any | None = None):
    """Build a coroutine input-guardrail function with a fixed verdict.

    Args:
        triggers: Value placed in ``tripwire_triggered`` of every result.
        output_info: Value placed in ``output_info`` of every result.

    Returns:
        An ``async`` callable with the input-guardrail signature
        ``(context, agent, input)``. It ignores its arguments and resolves to a
        ``GuardrailFunctionOutput`` built from the two values above.
    """

    async def async_guardrail(
        context: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
    ):
        verdict = GuardrailFunctionOutput(tripwire_triggered=triggers, output_info=output_info)
        return verdict

    return async_guardrail
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_async_input_guardrail():
    """An async input-guardrail function's verdict and info pass through ``run`` unchanged."""
    # (tripwire flag, output_info) pairs, exercised in the original order.
    cases = [(False, None), (True, None), (True, "test")]
    for triggers, info in cases:
        guardrail = InputGuardrail(
            guardrail_function=get_async_input_guardrail(triggers=triggers, output_info=info)
        )
        result = await guardrail.run(
            agent=Agent(name="test"), input="test", context=RunContextWrapper(context=None)
        )
        assert result.output.tripwire_triggered == triggers
        assert result.output.output_info == info
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_invalid_input_guardrail_raises_user_error():
    """A non-callable InputGuardrail function leads to UserError.

    Construction and ``run`` both sit inside the ``pytest.raises`` block, so the
    error may come from either step.
    """
    with pytest.raises(UserError):
        # A str is deliberately passed where a callable is declared; silence the type checker.
        bad_guardrail = InputGuardrail(guardrail_function="foo")  # type: ignore
        await bad_guardrail.run(
            agent=Agent(name="test"), input="test", context=RunContextWrapper(context=None)
        )
|
|
|
|
|
|
def get_sync_output_guardrail(triggers: bool, output_info: Any | None = None):
    """Build a synchronous output-guardrail function with a fixed verdict.

    Args:
        triggers: Value placed in ``tripwire_triggered`` of every result.
        output_info: Value placed in ``output_info`` of every result.

    Returns:
        A plain (non-async) callable with the output-guardrail signature
        ``(context, agent, agent_output)``. It ignores its arguments and returns
        a ``GuardrailFunctionOutput`` built from the two values above.
    """

    def sync_guardrail(context: RunContextWrapper[Any], agent: Agent[Any], agent_output: Any):
        verdict = GuardrailFunctionOutput(tripwire_triggered=triggers, output_info=output_info)
        return verdict

    return sync_guardrail
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_sync_output_guardrail():
    """A sync output-guardrail function's verdict and info pass through ``run`` unchanged."""
    # (tripwire flag, output_info) pairs, exercised in the original order.
    cases = [(False, None), (True, None), (True, "test")]
    for triggers, info in cases:
        guardrail = OutputGuardrail(
            guardrail_function=get_sync_output_guardrail(triggers=triggers, output_info=info)
        )
        result = await guardrail.run(
            agent=Agent(name="test"), agent_output="test", context=RunContextWrapper(context=None)
        )
        assert result.output.tripwire_triggered == triggers
        assert result.output.output_info == info
|
|
|
|
|
|
def get_async_output_guardrail(triggers: bool, output_info: Any | None = None):
    """Build a coroutine output-guardrail function with a fixed verdict.

    Args:
        triggers: Value placed in ``tripwire_triggered`` of every result.
        output_info: Value placed in ``output_info`` of every result.

    Returns:
        An ``async`` callable with the output-guardrail signature
        ``(context, agent, agent_output)``. It ignores its arguments and
        resolves to a ``GuardrailFunctionOutput`` built from the two values above.
    """

    async def async_guardrail(
        context: RunContextWrapper[Any], agent: Agent[Any], agent_output: Any
    ):
        verdict = GuardrailFunctionOutput(tripwire_triggered=triggers, output_info=output_info)
        return verdict

    return async_guardrail
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_async_output_guardrail():
    """An async output-guardrail function's verdict and info pass through ``run`` unchanged."""
    # (tripwire flag, output_info) pairs, exercised in the original order.
    cases = [(False, None), (True, None), (True, "test")]
    for triggers, info in cases:
        guardrail = OutputGuardrail(
            guardrail_function=get_async_output_guardrail(triggers=triggers, output_info=info)
        )
        result = await guardrail.run(
            agent=Agent(name="test"), agent_output="test", context=RunContextWrapper(context=None)
        )
        assert result.output.tripwire_triggered == triggers
        assert result.output.output_info == info
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_invalid_output_guardrail_raises_user_error():
    """A non-callable OutputGuardrail function leads to UserError.

    Construction and ``run`` both sit inside the ``pytest.raises`` block, so the
    error may come from either step.
    """
    with pytest.raises(UserError):
        # A str is deliberately passed where a callable is declared; silence the type checker.
        bad_guardrail = OutputGuardrail(guardrail_function="foo")  # type: ignore
        await bad_guardrail.run(
            agent=Agent(name="test"), agent_output="test", context=RunContextWrapper(context=None)
        )
|
|
|
|
|
|
@input_guardrail
def decorated_input_guardrail(
    context: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
) -> GuardrailFunctionOutput:
    # Bare-decorator form: never trips and reports "test_1" as its info.
    return GuardrailFunctionOutput(tripwire_triggered=False, output_info="test_1")
|
|
|
|
|
|
@input_guardrail(name="Custom name")
def decorated_named_input_guardrail(
    context: RunContextWrapper[Any], agent: Agent[Any], input: str | list[TResponseInputItem]
) -> GuardrailFunctionOutput:
    # Named-decorator form: never trips and reports "test_2" as its info.
    return GuardrailFunctionOutput(tripwire_triggered=False, output_info="test_2")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_input_guardrail_decorators():
    """Both ``@input_guardrail`` forms run; the ``name=`` form is reported by get_name()."""
    # (decorated guardrail, expected output_info), checked in the original order.
    expectations = [
        (decorated_input_guardrail, "test_1"),
        (decorated_named_input_guardrail, "test_2"),
    ]
    for decorated, expected_info in expectations:
        outcome = await decorated.run(
            agent=Agent(name="test"), input="test", context=RunContextWrapper(context=None)
        )
        assert not outcome.output.tripwire_triggered
        assert outcome.output.output_info == expected_info
    # Only the explicitly named variant is checked for its name.
    assert decorated_named_input_guardrail.get_name() == "Custom name"
|
|
|
|
|
|
@output_guardrail
def decorated_output_guardrail(
    context: RunContextWrapper[Any], agent: Agent[Any], agent_output: Any
) -> GuardrailFunctionOutput:
    # Bare-decorator form: never trips and reports "test_3" as its info.
    return GuardrailFunctionOutput(tripwire_triggered=False, output_info="test_3")
|
|
|
|
|
|
@output_guardrail(name="Custom name")
def decorated_named_output_guardrail(
    context: RunContextWrapper[Any], agent: Agent[Any], agent_output: Any
) -> GuardrailFunctionOutput:
    # Named-decorator form: never trips and reports "test_4" as its info.
    return GuardrailFunctionOutput(tripwire_triggered=False, output_info="test_4")
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_output_guardrail_decorators():
    """Both ``@output_guardrail`` forms run; the ``name=`` form is reported by get_name()."""
    # (decorated guardrail, expected output_info), checked in the original order.
    expectations = [
        (decorated_output_guardrail, "test_3"),
        (decorated_named_output_guardrail, "test_4"),
    ]
    for decorated, expected_info in expectations:
        outcome = await decorated.run(
            agent=Agent(name="test"), agent_output="test", context=RunContextWrapper(context=None)
        )
        assert not outcome.output.tripwire_triggered
        assert outcome.output.output_info == expected_info
    # Only the explicitly named variant is checked for its name.
    assert decorated_named_output_guardrail.get_name() == "Custom name"
|