parent
281a7b2bb6
commit
2b9b8f7e73
19 changed files with 342 additions and 9 deletions
79
examples/basic/prompt_template.py
Normal file
79
examples/basic/prompt_template.py
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
import argparse
|
||||
import asyncio
|
||||
import random
|
||||
|
||||
from agents import Agent, GenerateDynamicPromptData, Runner
|
||||
|
||||
"""
|
||||
NOTE: This example will not work out of the box, because the default prompt ID will not be available
|
||||
in your project.
|
||||
|
||||
To use it, please:
|
||||
1. Go to https://platform.openai.com/playground/prompts
|
||||
2. Create a new prompt variable, `poem_style`.
|
||||
3. Create a system prompt with the content:
|
||||
```
|
||||
Write a poem in {{poem_style}}
|
||||
```
|
||||
4. Run the example with the `--prompt-id` flag.
|
||||
"""
|
||||
|
||||
DEFAULT_PROMPT_ID = "pmpt_6850729e8ba481939fd439e058c69ee004afaa19c520b78b"
|
||||
|
||||
|
||||
class DynamicContext:
|
||||
def __init__(self, prompt_id: str):
|
||||
self.prompt_id = prompt_id
|
||||
self.poem_style = random.choice(["limerick", "haiku", "ballad"])
|
||||
print(f"[debug] DynamicContext initialized with poem_style: {self.poem_style}")
|
||||
|
||||
|
||||
async def _get_dynamic_prompt(data: GenerateDynamicPromptData):
|
||||
ctx: DynamicContext = data.context.context
|
||||
return {
|
||||
"id": ctx.prompt_id,
|
||||
"version": "1",
|
||||
"variables": {
|
||||
"poem_style": ctx.poem_style,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
async def dynamic_prompt(prompt_id: str):
|
||||
context = DynamicContext(prompt_id)
|
||||
|
||||
agent = Agent(
|
||||
name="Assistant",
|
||||
prompt=_get_dynamic_prompt,
|
||||
)
|
||||
|
||||
result = await Runner.run(agent, "Tell me about recursion in programming.", context=context)
|
||||
print(result.final_output)
|
||||
|
||||
|
||||
async def static_prompt(prompt_id: str):
|
||||
agent = Agent(
|
||||
name="Assistant",
|
||||
prompt={
|
||||
"id": prompt_id,
|
||||
"version": "1",
|
||||
"variables": {
|
||||
"poem_style": "limerick",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
result = await Runner.run(agent, "Tell me about recursion in programming.")
|
||||
print(result.final_output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--dynamic", action="store_true")
|
||||
parser.add_argument("--prompt-id", type=str, default=DEFAULT_PROMPT_ID)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.dynamic:
|
||||
asyncio.run(dynamic_prompt(args.prompt_id))
|
||||
else:
|
||||
asyncio.run(static_prompt(args.prompt_id))
|
||||
|
|
@ -7,7 +7,7 @@ requires-python = ">=3.9"
|
|||
license = "MIT"
|
||||
authors = [{ name = "OpenAI", email = "support@openai.com" }]
|
||||
dependencies = [
|
||||
"openai>=1.81.0",
|
||||
"openai>=1.87.0",
|
||||
"pydantic>=2.10, <3",
|
||||
"griffe>=1.5.6, <2",
|
||||
"typing-extensions>=4.12.2, <5",
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ from .models.interface import Model, ModelProvider, ModelTracing
|
|||
from .models.openai_chatcompletions import OpenAIChatCompletionsModel
|
||||
from .models.openai_provider import OpenAIProvider
|
||||
from .models.openai_responses import OpenAIResponsesModel
|
||||
from .prompts import DynamicPromptFunction, GenerateDynamicPromptData, Prompt
|
||||
from .repl import run_demo_loop
|
||||
from .result import RunResult, RunResultStreaming
|
||||
from .run import RunConfig, Runner
|
||||
|
|
@ -178,6 +179,9 @@ __all__ = [
|
|||
"AgentsException",
|
||||
"InputGuardrailTripwireTriggered",
|
||||
"OutputGuardrailTripwireTriggered",
|
||||
"DynamicPromptFunction",
|
||||
"GenerateDynamicPromptData",
|
||||
"Prompt",
|
||||
"MaxTurnsExceeded",
|
||||
"ModelBehaviorError",
|
||||
"UserError",
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from collections.abc import Awaitable
|
|||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast
|
||||
|
||||
from openai.types.responses.response_prompt_param import ResponsePromptParam
|
||||
from typing_extensions import NotRequired, TypeAlias, TypedDict
|
||||
|
||||
from .agent_output import AgentOutputSchemaBase
|
||||
|
|
@ -17,6 +18,7 @@ from .logger import logger
|
|||
from .mcp import MCPUtil
|
||||
from .model_settings import ModelSettings
|
||||
from .models.interface import Model
|
||||
from .prompts import DynamicPromptFunction, Prompt, PromptUtil
|
||||
from .run_context import RunContextWrapper, TContext
|
||||
from .tool import FunctionTool, FunctionToolResult, Tool, function_tool
|
||||
from .util import _transforms
|
||||
|
|
@ -95,6 +97,12 @@ class Agent(Generic[TContext]):
|
|||
return a string.
|
||||
"""
|
||||
|
||||
prompt: Prompt | DynamicPromptFunction | None = None
|
||||
"""A prompt object (or a function that returns a Prompt). Prompts allow you to dynamically
|
||||
configure the instructions, tools and other config for an agent outside of your code. Only
|
||||
usable with OpenAI models, using the Responses API.
|
||||
"""
|
||||
|
||||
handoff_description: str | None = None
|
||||
"""A description of the agent. This is used when the agent is used as a handoff, so that an
|
||||
LLM knows what it does and when to invoke it.
|
||||
|
|
@ -242,6 +250,12 @@ class Agent(Generic[TContext]):
|
|||
|
||||
return None
|
||||
|
||||
async def get_prompt(
|
||||
self, run_context: RunContextWrapper[TContext]
|
||||
) -> ResponsePromptParam | None:
|
||||
"""Get the prompt for the agent."""
|
||||
return await PromptUtil.to_model_input(self.prompt, run_context, self)
|
||||
|
||||
async def get_mcp_tools(self) -> list[Tool]:
|
||||
"""Fetches the available tools from the MCP servers."""
|
||||
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)
|
||||
|
|
|
|||
|
|
@ -71,6 +71,7 @@ class LitellmModel(Model):
|
|||
handoffs: list[Handoff],
|
||||
tracing: ModelTracing,
|
||||
previous_response_id: str | None,
|
||||
prompt: Any | None = None,
|
||||
) -> ModelResponse:
|
||||
with generation_span(
|
||||
model=str(self.model),
|
||||
|
|
@ -88,6 +89,7 @@ class LitellmModel(Model):
|
|||
span_generation,
|
||||
tracing,
|
||||
stream=False,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
assert isinstance(response.choices[0], litellm.types.utils.Choices)
|
||||
|
|
@ -153,8 +155,8 @@ class LitellmModel(Model):
|
|||
output_schema: AgentOutputSchemaBase | None,
|
||||
handoffs: list[Handoff],
|
||||
tracing: ModelTracing,
|
||||
*,
|
||||
previous_response_id: str | None,
|
||||
prompt: Any | None = None,
|
||||
) -> AsyncIterator[TResponseStreamEvent]:
|
||||
with generation_span(
|
||||
model=str(self.model),
|
||||
|
|
@ -172,6 +174,7 @@ class LitellmModel(Model):
|
|||
span_generation,
|
||||
tracing,
|
||||
stream=True,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
final_response: Response | None = None
|
||||
|
|
@ -202,6 +205,7 @@ class LitellmModel(Model):
|
|||
span: Span[GenerationSpanData],
|
||||
tracing: ModelTracing,
|
||||
stream: Literal[True],
|
||||
prompt: Any | None = None,
|
||||
) -> tuple[Response, AsyncStream[ChatCompletionChunk]]: ...
|
||||
|
||||
@overload
|
||||
|
|
@ -216,6 +220,7 @@ class LitellmModel(Model):
|
|||
span: Span[GenerationSpanData],
|
||||
tracing: ModelTracing,
|
||||
stream: Literal[False],
|
||||
prompt: Any | None = None,
|
||||
) -> litellm.types.utils.ModelResponse: ...
|
||||
|
||||
async def _fetch_response(
|
||||
|
|
@ -229,6 +234,7 @@ class LitellmModel(Model):
|
|||
span: Span[GenerationSpanData],
|
||||
tracing: ModelTracing,
|
||||
stream: bool = False,
|
||||
prompt: Any | None = None,
|
||||
) -> litellm.types.utils.ModelResponse | tuple[Response, AsyncStream[ChatCompletionChunk]]:
|
||||
converted_messages = Converter.items_to_messages(input)
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,8 @@ import enum
|
|||
from collections.abc import AsyncIterator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from openai.types.responses.response_prompt_param import ResponsePromptParam
|
||||
|
||||
from ..agent_output import AgentOutputSchemaBase
|
||||
from ..handoffs import Handoff
|
||||
from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent
|
||||
|
|
@ -46,6 +48,7 @@ class Model(abc.ABC):
|
|||
tracing: ModelTracing,
|
||||
*,
|
||||
previous_response_id: str | None,
|
||||
prompt: ResponsePromptParam | None,
|
||||
) -> ModelResponse:
|
||||
"""Get a response from the model.
|
||||
|
||||
|
|
@ -59,6 +62,7 @@ class Model(abc.ABC):
|
|||
tracing: Tracing configuration.
|
||||
previous_response_id: the ID of the previous response. Generally not used by the model,
|
||||
except for the OpenAI Responses API.
|
||||
prompt: The prompt config to use for the model.
|
||||
|
||||
Returns:
|
||||
The full model response.
|
||||
|
|
@ -77,6 +81,7 @@ class Model(abc.ABC):
|
|||
tracing: ModelTracing,
|
||||
*,
|
||||
previous_response_id: str | None,
|
||||
prompt: ResponsePromptParam | None,
|
||||
) -> AsyncIterator[TResponseStreamEvent]:
|
||||
"""Stream a response from the model.
|
||||
|
||||
|
|
@ -90,6 +95,7 @@ class Model(abc.ABC):
|
|||
tracing: Tracing configuration.
|
||||
previous_response_id: the ID of the previous response. Generally not used by the model,
|
||||
except for the OpenAI Responses API.
|
||||
prompt: The prompt config to use for the model.
|
||||
|
||||
Returns:
|
||||
An iterator of response stream events, in OpenAI Responses format.
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from openai import NOT_GIVEN, AsyncOpenAI, AsyncStream
|
|||
from openai.types import ChatModel
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
from openai.types.responses import Response
|
||||
from openai.types.responses.response_prompt_param import ResponsePromptParam
|
||||
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
|
||||
|
||||
from .. import _debug
|
||||
|
|
@ -53,6 +54,7 @@ class OpenAIChatCompletionsModel(Model):
|
|||
handoffs: list[Handoff],
|
||||
tracing: ModelTracing,
|
||||
previous_response_id: str | None,
|
||||
prompt: ResponsePromptParam | None = None,
|
||||
) -> ModelResponse:
|
||||
with generation_span(
|
||||
model=str(self.model),
|
||||
|
|
@ -69,6 +71,7 @@ class OpenAIChatCompletionsModel(Model):
|
|||
span_generation,
|
||||
tracing,
|
||||
stream=False,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
first_choice = response.choices[0]
|
||||
|
|
@ -136,8 +139,8 @@ class OpenAIChatCompletionsModel(Model):
|
|||
output_schema: AgentOutputSchemaBase | None,
|
||||
handoffs: list[Handoff],
|
||||
tracing: ModelTracing,
|
||||
*,
|
||||
previous_response_id: str | None,
|
||||
prompt: ResponsePromptParam | None = None,
|
||||
) -> AsyncIterator[TResponseStreamEvent]:
|
||||
"""
|
||||
Yields a partial message as it is generated, as well as the usage information.
|
||||
|
|
@ -157,6 +160,7 @@ class OpenAIChatCompletionsModel(Model):
|
|||
span_generation,
|
||||
tracing,
|
||||
stream=True,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
final_response: Response | None = None
|
||||
|
|
@ -187,6 +191,7 @@ class OpenAIChatCompletionsModel(Model):
|
|||
span: Span[GenerationSpanData],
|
||||
tracing: ModelTracing,
|
||||
stream: Literal[True],
|
||||
prompt: ResponsePromptParam | None = None,
|
||||
) -> tuple[Response, AsyncStream[ChatCompletionChunk]]: ...
|
||||
|
||||
@overload
|
||||
|
|
@ -201,6 +206,7 @@ class OpenAIChatCompletionsModel(Model):
|
|||
span: Span[GenerationSpanData],
|
||||
tracing: ModelTracing,
|
||||
stream: Literal[False],
|
||||
prompt: ResponsePromptParam | None = None,
|
||||
) -> ChatCompletion: ...
|
||||
|
||||
async def _fetch_response(
|
||||
|
|
@ -214,6 +220,7 @@ class OpenAIChatCompletionsModel(Model):
|
|||
span: Span[GenerationSpanData],
|
||||
tracing: ModelTracing,
|
||||
stream: bool = False,
|
||||
prompt: ResponsePromptParam | None = None,
|
||||
) -> ChatCompletion | tuple[Response, AsyncStream[ChatCompletionChunk]]:
|
||||
converted_messages = Converter.items_to_messages(input)
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from openai.types.responses import (
|
|||
WebSearchToolParam,
|
||||
response_create_params,
|
||||
)
|
||||
from openai.types.responses.response_prompt_param import ResponsePromptParam
|
||||
|
||||
from .. import _debug
|
||||
from ..agent_output import AgentOutputSchemaBase
|
||||
|
|
@ -74,6 +75,7 @@ class OpenAIResponsesModel(Model):
|
|||
handoffs: list[Handoff],
|
||||
tracing: ModelTracing,
|
||||
previous_response_id: str | None,
|
||||
prompt: ResponsePromptParam | None = None,
|
||||
) -> ModelResponse:
|
||||
with response_span(disabled=tracing.is_disabled()) as span_response:
|
||||
try:
|
||||
|
|
@ -86,6 +88,7 @@ class OpenAIResponsesModel(Model):
|
|||
handoffs,
|
||||
previous_response_id,
|
||||
stream=False,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
if _debug.DONT_LOG_MODEL_DATA:
|
||||
|
|
@ -141,6 +144,7 @@ class OpenAIResponsesModel(Model):
|
|||
handoffs: list[Handoff],
|
||||
tracing: ModelTracing,
|
||||
previous_response_id: str | None,
|
||||
prompt: ResponsePromptParam | None = None,
|
||||
) -> AsyncIterator[ResponseStreamEvent]:
|
||||
"""
|
||||
Yields a partial message as it is generated, as well as the usage information.
|
||||
|
|
@ -156,6 +160,7 @@ class OpenAIResponsesModel(Model):
|
|||
handoffs,
|
||||
previous_response_id,
|
||||
stream=True,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
final_response: Response | None = None
|
||||
|
|
@ -192,6 +197,7 @@ class OpenAIResponsesModel(Model):
|
|||
handoffs: list[Handoff],
|
||||
previous_response_id: str | None,
|
||||
stream: Literal[True],
|
||||
prompt: ResponsePromptParam | None = None,
|
||||
) -> AsyncStream[ResponseStreamEvent]: ...
|
||||
|
||||
@overload
|
||||
|
|
@ -205,6 +211,7 @@ class OpenAIResponsesModel(Model):
|
|||
handoffs: list[Handoff],
|
||||
previous_response_id: str | None,
|
||||
stream: Literal[False],
|
||||
prompt: ResponsePromptParam | None = None,
|
||||
) -> Response: ...
|
||||
|
||||
async def _fetch_response(
|
||||
|
|
@ -217,6 +224,7 @@ class OpenAIResponsesModel(Model):
|
|||
handoffs: list[Handoff],
|
||||
previous_response_id: str | None,
|
||||
stream: Literal[True] | Literal[False] = False,
|
||||
prompt: ResponsePromptParam | None = None,
|
||||
) -> Response | AsyncStream[ResponseStreamEvent]:
|
||||
list_input = ItemHelpers.input_to_new_input_list(input)
|
||||
|
||||
|
|
@ -252,6 +260,7 @@ class OpenAIResponsesModel(Model):
|
|||
input=list_input,
|
||||
include=converted_tools.includes,
|
||||
tools=converted_tools.tools,
|
||||
prompt=self._non_null_or_not_given(prompt),
|
||||
temperature=self._non_null_or_not_given(model_settings.temperature),
|
||||
top_p=self._non_null_or_not_given(model_settings.top_p),
|
||||
truncation=self._non_null_or_not_given(model_settings.truncation),
|
||||
|
|
|
|||
76
src/agents/prompts.py
Normal file
76
src/agents/prompts.py
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, Callable
|
||||
|
||||
from openai.types.responses.response_prompt_param import (
|
||||
ResponsePromptParam,
|
||||
Variables as ResponsesPromptVariables,
|
||||
)
|
||||
from typing_extensions import NotRequired, TypedDict
|
||||
|
||||
from agents.util._types import MaybeAwaitable
|
||||
|
||||
from .exceptions import UserError
|
||||
from .run_context import RunContextWrapper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .agent import Agent
|
||||
|
||||
|
||||
class Prompt(TypedDict):
|
||||
"""Prompt configuration to use for interacting with an OpenAI model."""
|
||||
|
||||
id: str
|
||||
"""The unique ID of the prompt."""
|
||||
|
||||
version: NotRequired[str]
|
||||
"""Optional version of the prompt."""
|
||||
|
||||
variables: NotRequired[dict[str, ResponsesPromptVariables]]
|
||||
"""Optional variables to substitute into the prompt."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class GenerateDynamicPromptData:
|
||||
"""Inputs to a function that allows you to dynamically generate a prompt."""
|
||||
|
||||
context: RunContextWrapper[Any]
|
||||
"""The run context."""
|
||||
|
||||
agent: Agent[Any]
|
||||
"""The agent for which the prompt is being generated."""
|
||||
|
||||
|
||||
DynamicPromptFunction = Callable[[GenerateDynamicPromptData], MaybeAwaitable[Prompt]]
|
||||
"""A function that dynamically generates a prompt."""
|
||||
|
||||
|
||||
class PromptUtil:
|
||||
@staticmethod
|
||||
async def to_model_input(
|
||||
prompt: Prompt | DynamicPromptFunction | None,
|
||||
context: RunContextWrapper[Any],
|
||||
agent: Agent[Any],
|
||||
) -> ResponsePromptParam | None:
|
||||
if prompt is None:
|
||||
return None
|
||||
|
||||
resolved_prompt: Prompt
|
||||
if isinstance(prompt, dict):
|
||||
resolved_prompt = prompt
|
||||
else:
|
||||
func_result = prompt(GenerateDynamicPromptData(context=context, agent=agent))
|
||||
if inspect.isawaitable(func_result):
|
||||
resolved_prompt = await func_result
|
||||
else:
|
||||
resolved_prompt = func_result
|
||||
if not isinstance(resolved_prompt, dict):
|
||||
raise UserError("Dynamic prompt function must return a Prompt")
|
||||
|
||||
return {
|
||||
"id": resolved_prompt["id"],
|
||||
"version": resolved_prompt.get("version"),
|
||||
"variables": resolved_prompt.get("variables"),
|
||||
}
|
||||
|
|
@ -6,6 +6,9 @@ from dataclasses import dataclass, field
|
|||
from typing import Any, cast
|
||||
|
||||
from openai.types.responses import ResponseCompletedEvent
|
||||
from openai.types.responses.response_prompt_param import (
|
||||
ResponsePromptParam,
|
||||
)
|
||||
|
||||
from ._run_impl import (
|
||||
AgentToolUseTracker,
|
||||
|
|
@ -682,7 +685,10 @@ class Runner:
|
|||
streamed_result.current_agent = agent
|
||||
streamed_result._current_agent_output_schema = output_schema
|
||||
|
||||
system_prompt = await agent.get_system_prompt(context_wrapper)
|
||||
system_prompt, prompt_config = await asyncio.gather(
|
||||
agent.get_system_prompt(context_wrapper),
|
||||
agent.get_prompt(context_wrapper),
|
||||
)
|
||||
|
||||
handoffs = cls._get_handoffs(agent)
|
||||
model = cls._get_model(agent, run_config)
|
||||
|
|
@ -706,6 +712,7 @@ class Runner:
|
|||
run_config.tracing_disabled, run_config.trace_include_sensitive_data
|
||||
),
|
||||
previous_response_id=previous_response_id,
|
||||
prompt=prompt_config,
|
||||
):
|
||||
if isinstance(event, ResponseCompletedEvent):
|
||||
usage = (
|
||||
|
|
@ -777,7 +784,10 @@ class Runner:
|
|||
),
|
||||
)
|
||||
|
||||
system_prompt = await agent.get_system_prompt(context_wrapper)
|
||||
system_prompt, prompt_config = await asyncio.gather(
|
||||
agent.get_system_prompt(context_wrapper),
|
||||
agent.get_prompt(context_wrapper),
|
||||
)
|
||||
|
||||
output_schema = cls._get_output_schema(agent)
|
||||
handoffs = cls._get_handoffs(agent)
|
||||
|
|
@ -795,6 +805,7 @@ class Runner:
|
|||
run_config,
|
||||
tool_use_tracker,
|
||||
previous_response_id,
|
||||
prompt_config,
|
||||
)
|
||||
|
||||
return await cls._get_single_step_result_from_response(
|
||||
|
|
@ -938,6 +949,7 @@ class Runner:
|
|||
run_config: RunConfig,
|
||||
tool_use_tracker: AgentToolUseTracker,
|
||||
previous_response_id: str | None,
|
||||
prompt_config: ResponsePromptParam | None,
|
||||
) -> ModelResponse:
|
||||
model = cls._get_model(agent, run_config)
|
||||
model_settings = agent.model_settings.resolve(run_config.model_settings)
|
||||
|
|
@ -954,6 +966,7 @@ class Runner:
|
|||
run_config.tracing_disabled, run_config.trace_include_sensitive_data
|
||||
),
|
||||
previous_response_id=previous_response_id,
|
||||
prompt=prompt_config,
|
||||
)
|
||||
|
||||
context_wrapper.usage.add(new_response.usage)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from .run_context import RunContextWrapper, TContext
|
|||
def _assert_must_pass_tool_call_id() -> str:
|
||||
raise ValueError("tool_call_id must be passed to ToolContext")
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolContext(RunContextWrapper[TContext]):
|
||||
"""The context of a tool call."""
|
||||
|
|
|
|||
|
|
@ -61,6 +61,7 @@ class FakeModel(Model):
|
|||
tracing: ModelTracing,
|
||||
*,
|
||||
previous_response_id: str | None,
|
||||
prompt: Any | None,
|
||||
) -> ModelResponse:
|
||||
self.last_turn_args = {
|
||||
"system_instructions": system_instructions,
|
||||
|
|
@ -103,6 +104,7 @@ class FakeModel(Model):
|
|||
tracing: ModelTracing,
|
||||
*,
|
||||
previous_response_id: str | None,
|
||||
prompt: Any | None,
|
||||
) -> AsyncIterator[TResponseStreamEvent]:
|
||||
self.last_turn_args = {
|
||||
"system_instructions": system_instructions,
|
||||
|
|
|
|||
|
|
@ -90,6 +90,7 @@ async def test_stream_response_yields_events_for_text_content(monkeypatch) -> No
|
|||
handoffs=[],
|
||||
tracing=ModelTracing.DISABLED,
|
||||
previous_response_id=None,
|
||||
prompt=None,
|
||||
):
|
||||
output_events.append(event)
|
||||
# We expect a response.created, then a response.output_item.added, content part added,
|
||||
|
|
@ -182,6 +183,7 @@ async def test_stream_response_yields_events_for_refusal_content(monkeypatch) ->
|
|||
handoffs=[],
|
||||
tracing=ModelTracing.DISABLED,
|
||||
previous_response_id=None,
|
||||
prompt=None,
|
||||
):
|
||||
output_events.append(event)
|
||||
# Expect sequence similar to text: created, output_item.added, content part added,
|
||||
|
|
@ -270,6 +272,7 @@ async def test_stream_response_yields_events_for_tool_call(monkeypatch) -> None:
|
|||
handoffs=[],
|
||||
tracing=ModelTracing.DISABLED,
|
||||
previous_response_id=None,
|
||||
prompt=None,
|
||||
):
|
||||
output_events.append(event)
|
||||
# Sequence should be: response.created, then after loop we expect function call-related events:
|
||||
|
|
|
|||
97
tests/test_agent_prompt.py
Normal file
97
tests/test_agent_prompt.py
Normal file
|
|
@ -0,0 +1,97 @@
|
|||
import pytest
|
||||
|
||||
from agents import Agent, Prompt, RunContextWrapper, Runner
|
||||
|
||||
from .fake_model import FakeModel
|
||||
from .test_responses import get_text_message
|
||||
|
||||
|
||||
class PromptCaptureFakeModel(FakeModel):
|
||||
"""Subclass of FakeModel that records the prompt passed to the model."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.last_prompt = None
|
||||
|
||||
async def get_response(
|
||||
self,
|
||||
system_instructions,
|
||||
input,
|
||||
model_settings,
|
||||
tools,
|
||||
output_schema,
|
||||
handoffs,
|
||||
tracing,
|
||||
*,
|
||||
previous_response_id,
|
||||
prompt,
|
||||
):
|
||||
# Record the prompt that the agent resolved and passed in.
|
||||
self.last_prompt = prompt
|
||||
return await super().get_response(
|
||||
system_instructions,
|
||||
input,
|
||||
model_settings,
|
||||
tools,
|
||||
output_schema,
|
||||
handoffs,
|
||||
tracing,
|
||||
previous_response_id=previous_response_id,
|
||||
prompt=prompt,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_static_prompt_is_resolved_correctly():
|
||||
static_prompt: Prompt = {
|
||||
"id": "my_prompt",
|
||||
"version": "1",
|
||||
"variables": {"some_var": "some_value"},
|
||||
}
|
||||
|
||||
agent = Agent(name="test", prompt=static_prompt)
|
||||
context_wrapper = RunContextWrapper(context=None)
|
||||
|
||||
resolved = await agent.get_prompt(context_wrapper)
|
||||
|
||||
assert resolved == {
|
||||
"id": "my_prompt",
|
||||
"version": "1",
|
||||
"variables": {"some_var": "some_value"},
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dynamic_prompt_is_resolved_correctly():
|
||||
dynamic_prompt_value: Prompt = {"id": "dyn_prompt", "version": "2"}
|
||||
|
||||
def dynamic_prompt_fn(_data):
|
||||
return dynamic_prompt_value
|
||||
|
||||
agent = Agent(name="test", prompt=dynamic_prompt_fn)
|
||||
context_wrapper = RunContextWrapper(context=None)
|
||||
|
||||
resolved = await agent.get_prompt(context_wrapper)
|
||||
|
||||
assert resolved == {"id": "dyn_prompt", "version": "2", "variables": None}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_prompt_is_passed_to_model():
|
||||
static_prompt: Prompt = {"id": "model_prompt"}
|
||||
|
||||
model = PromptCaptureFakeModel()
|
||||
agent = Agent(name="test", model=model, prompt=static_prompt)
|
||||
|
||||
# Ensure the model returns a simple message so the run completes in one turn.
|
||||
model.set_next_output([get_text_message("done")])
|
||||
|
||||
await Runner.run(agent, input="hello")
|
||||
|
||||
# The model should have received the prompt resolved by the agent.
|
||||
expected_prompt = {
|
||||
"id": "model_prompt",
|
||||
"version": None,
|
||||
"variables": None,
|
||||
}
|
||||
assert model.last_prompt == expected_prompt
|
||||
|
|
@ -77,6 +77,7 @@ async def test_get_response_with_text_message(monkeypatch) -> None:
|
|||
handoffs=[],
|
||||
tracing=ModelTracing.DISABLED,
|
||||
previous_response_id=None,
|
||||
prompt=None,
|
||||
)
|
||||
# Should have produced exactly one output message with one text part
|
||||
assert isinstance(resp, ModelResponse)
|
||||
|
|
@ -128,6 +129,7 @@ async def test_get_response_with_refusal(monkeypatch) -> None:
|
|||
handoffs=[],
|
||||
tracing=ModelTracing.DISABLED,
|
||||
previous_response_id=None,
|
||||
prompt=None,
|
||||
)
|
||||
assert len(resp.output) == 1
|
||||
assert isinstance(resp.output[0], ResponseOutputMessage)
|
||||
|
|
@ -180,6 +182,7 @@ async def test_get_response_with_tool_call(monkeypatch) -> None:
|
|||
handoffs=[],
|
||||
tracing=ModelTracing.DISABLED,
|
||||
previous_response_id=None,
|
||||
prompt=None,
|
||||
)
|
||||
# Expect a message item followed by a function tool call item.
|
||||
assert len(resp.output) == 2
|
||||
|
|
@ -221,6 +224,7 @@ async def test_get_response_with_no_message(monkeypatch) -> None:
|
|||
handoffs=[],
|
||||
tracing=ModelTracing.DISABLED,
|
||||
previous_response_id=None,
|
||||
prompt=None,
|
||||
)
|
||||
assert resp.output == []
|
||||
|
||||
|
|
|
|||
|
|
@ -90,6 +90,7 @@ async def test_stream_response_yields_events_for_text_content(monkeypatch) -> No
|
|||
handoffs=[],
|
||||
tracing=ModelTracing.DISABLED,
|
||||
previous_response_id=None,
|
||||
prompt=None,
|
||||
):
|
||||
output_events.append(event)
|
||||
# We expect a response.created, then a response.output_item.added, content part added,
|
||||
|
|
@ -182,6 +183,7 @@ async def test_stream_response_yields_events_for_refusal_content(monkeypatch) ->
|
|||
handoffs=[],
|
||||
tracing=ModelTracing.DISABLED,
|
||||
previous_response_id=None,
|
||||
prompt=None,
|
||||
):
|
||||
output_events.append(event)
|
||||
# Expect sequence similar to text: created, output_item.added, content part added,
|
||||
|
|
@ -270,6 +272,7 @@ async def test_stream_response_yields_events_for_tool_call(monkeypatch) -> None:
|
|||
handoffs=[],
|
||||
tracing=ModelTracing.DISABLED,
|
||||
previous_response_id=None,
|
||||
prompt=None,
|
||||
):
|
||||
output_events.append(event)
|
||||
# Sequence should be: response.created, then after loop we expect function call-related events:
|
||||
|
|
|
|||
|
|
@ -71,6 +71,7 @@ async def test_get_response_creates_trace(monkeypatch):
|
|||
handoffs,
|
||||
prev_response_id,
|
||||
stream,
|
||||
prompt,
|
||||
):
|
||||
return DummyResponse()
|
||||
|
||||
|
|
@ -115,6 +116,7 @@ async def test_non_data_tracing_doesnt_set_response_id(monkeypatch):
|
|||
handoffs,
|
||||
prev_response_id,
|
||||
stream,
|
||||
prompt,
|
||||
):
|
||||
return DummyResponse()
|
||||
|
||||
|
|
@ -157,6 +159,7 @@ async def test_disable_tracing_does_not_create_span(monkeypatch):
|
|||
handoffs,
|
||||
prev_response_id,
|
||||
stream,
|
||||
prompt,
|
||||
):
|
||||
return DummyResponse()
|
||||
|
||||
|
|
@ -196,6 +199,7 @@ async def test_stream_response_creates_trace(monkeypatch):
|
|||
handoffs,
|
||||
prev_response_id,
|
||||
stream,
|
||||
prompt,
|
||||
):
|
||||
class DummyStream:
|
||||
async def __aiter__(self):
|
||||
|
|
@ -249,6 +253,7 @@ async def test_stream_non_data_tracing_doesnt_set_response_id(monkeypatch):
|
|||
handoffs,
|
||||
prev_response_id,
|
||||
stream,
|
||||
prompt,
|
||||
):
|
||||
class DummyStream:
|
||||
async def __aiter__(self):
|
||||
|
|
@ -301,6 +306,7 @@ async def test_stream_disabled_tracing_doesnt_create_span(monkeypatch):
|
|||
handoffs,
|
||||
prev_response_id,
|
||||
stream,
|
||||
prompt,
|
||||
):
|
||||
class DummyStream:
|
||||
async def __aiter__(self):
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ from __future__ import annotations
|
|||
|
||||
import json
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from inline_snapshot import snapshot
|
||||
|
|
@ -53,6 +54,7 @@ class FakeStreamingModel(Model):
|
|||
tracing: ModelTracing,
|
||||
*,
|
||||
previous_response_id: str | None,
|
||||
prompt: Any | None,
|
||||
) -> ModelResponse:
|
||||
raise NotImplementedError("Not implemented")
|
||||
|
||||
|
|
@ -67,6 +69,7 @@ class FakeStreamingModel(Model):
|
|||
tracing: ModelTracing,
|
||||
*,
|
||||
previous_response_id: str | None,
|
||||
prompt: Any | None,
|
||||
) -> AsyncIterator[TResponseStreamEvent]:
|
||||
output = self.get_next_output()
|
||||
for item in output:
|
||||
|
|
|
|||
8
uv.lock
8
uv.lock
|
|
@ -1461,7 +1461,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "openai"
|
||||
version = "1.81.0"
|
||||
version = "1.87.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "anyio" },
|
||||
|
|
@ -1473,9 +1473,9 @@ dependencies = [
|
|||
{ name = "tqdm" },
|
||||
{ name = "typing-extensions" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/1c/89/a1e4f3fa7ca4f7fec90dbf47d93b7cd5ff65924926733af15044e302a192/openai-1.81.0.tar.gz", hash = "sha256:349567a8607e0bcffd28e02f96b5c2397d0d25d06732d90ab3ecbf97abf030f9", size = 456861 }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/47/ed/2b3f6c7e950784e9442115ab8ebeff514d543fb33da10607b39364645a75/openai-1.87.0.tar.gz", hash = "sha256:5c69764171e0db9ef993e7a4d8a01fd8ff1026b66f8bdd005b9461782b6e7dfc", size = 470880 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/02/66/bcc7f9bf48e8610a33e3b5c96a5a644dad032d92404ea2a5e8b43ba067e8/openai-1.81.0-py3-none-any.whl", hash = "sha256:1c71572e22b43876c5d7d65ade0b7b516bb527c3d44ae94111267a09125f7bae", size = 717529 },
|
||||
{ url = "https://files.pythonhosted.org/packages/36/ac/313ded47ce1d5bc2ec02ed5dd5506bf5718678a4655ac20f337231d9aae3/openai-1.87.0-py3-none-any.whl", hash = "sha256:f9bcae02ac4fff6522276eee85d33047335cfb692b863bd8261353ce4ada5692", size = 734368 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -1536,7 +1536,7 @@ requires-dist = [
|
|||
{ name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.67.4.post1,<2" },
|
||||
{ name = "mcp", marker = "python_full_version >= '3.10'", specifier = ">=1.9.4,<2" },
|
||||
{ name = "numpy", marker = "python_full_version >= '3.10' and extra == 'voice'", specifier = ">=2.2.0,<3" },
|
||||
{ name = "openai", specifier = ">=1.81.0" },
|
||||
{ name = "openai", specifier = ">=1.87.0" },
|
||||
{ name = "pydantic", specifier = ">=2.10,<3" },
|
||||
{ name = "requests", specifier = ">=2.0,<3" },
|
||||
{ name = "types-requests", specifier = ">=2.0,<3" },
|
||||
|
|
|
|||
Loading…
Reference in a new issue