Prompts support (#876)

Add support for the new openai prompts feature.
This commit is contained in:
Rohan Mehta 2025-06-16 15:47:48 -04:00 committed by GitHub
parent 281a7b2bb6
commit 2b9b8f7e73
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 342 additions and 9 deletions

View file

@ -0,0 +1,79 @@
import argparse
import asyncio
import random
from agents import Agent, GenerateDynamicPromptData, Runner
"""
NOTE: This example will not work out of the box, because the default prompt ID will not be available
in your project.
To use it, please:
1. Go to https://platform.openai.com/playground/prompts
2. Create a new prompt variable, `poem_style`.
3. Create a system prompt with the content:
```
Write a poem in {{poem_style}}
```
4. Run the example with the `--prompt-id` flag.
"""
DEFAULT_PROMPT_ID = "pmpt_6850729e8ba481939fd439e058c69ee004afaa19c520b78b"
class DynamicContext:
def __init__(self, prompt_id: str):
self.prompt_id = prompt_id
self.poem_style = random.choice(["limerick", "haiku", "ballad"])
print(f"[debug] DynamicContext initialized with poem_style: {self.poem_style}")
async def _get_dynamic_prompt(data: GenerateDynamicPromptData):
ctx: DynamicContext = data.context.context
return {
"id": ctx.prompt_id,
"version": "1",
"variables": {
"poem_style": ctx.poem_style,
},
}
async def dynamic_prompt(prompt_id: str):
context = DynamicContext(prompt_id)
agent = Agent(
name="Assistant",
prompt=_get_dynamic_prompt,
)
result = await Runner.run(agent, "Tell me about recursion in programming.", context=context)
print(result.final_output)
async def static_prompt(prompt_id: str):
agent = Agent(
name="Assistant",
prompt={
"id": prompt_id,
"version": "1",
"variables": {
"poem_style": "limerick",
},
},
)
result = await Runner.run(agent, "Tell me about recursion in programming.")
print(result.final_output)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--dynamic", action="store_true")
parser.add_argument("--prompt-id", type=str, default=DEFAULT_PROMPT_ID)
args = parser.parse_args()
if args.dynamic:
asyncio.run(dynamic_prompt(args.prompt_id))
else:
asyncio.run(static_prompt(args.prompt_id))

View file

@ -7,7 +7,7 @@ requires-python = ">=3.9"
license = "MIT" license = "MIT"
authors = [{ name = "OpenAI", email = "support@openai.com" }] authors = [{ name = "OpenAI", email = "support@openai.com" }]
dependencies = [ dependencies = [
"openai>=1.81.0", "openai>=1.87.0",
"pydantic>=2.10, <3", "pydantic>=2.10, <3",
"griffe>=1.5.6, <2", "griffe>=1.5.6, <2",
"typing-extensions>=4.12.2, <5", "typing-extensions>=4.12.2, <5",

View file

@ -45,6 +45,7 @@ from .models.interface import Model, ModelProvider, ModelTracing
from .models.openai_chatcompletions import OpenAIChatCompletionsModel from .models.openai_chatcompletions import OpenAIChatCompletionsModel
from .models.openai_provider import OpenAIProvider from .models.openai_provider import OpenAIProvider
from .models.openai_responses import OpenAIResponsesModel from .models.openai_responses import OpenAIResponsesModel
from .prompts import DynamicPromptFunction, GenerateDynamicPromptData, Prompt
from .repl import run_demo_loop from .repl import run_demo_loop
from .result import RunResult, RunResultStreaming from .result import RunResult, RunResultStreaming
from .run import RunConfig, Runner from .run import RunConfig, Runner
@ -178,6 +179,9 @@ __all__ = [
"AgentsException", "AgentsException",
"InputGuardrailTripwireTriggered", "InputGuardrailTripwireTriggered",
"OutputGuardrailTripwireTriggered", "OutputGuardrailTripwireTriggered",
"DynamicPromptFunction",
"GenerateDynamicPromptData",
"Prompt",
"MaxTurnsExceeded", "MaxTurnsExceeded",
"ModelBehaviorError", "ModelBehaviorError",
"UserError", "UserError",

View file

@ -7,6 +7,7 @@ from collections.abc import Awaitable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, cast
from openai.types.responses.response_prompt_param import ResponsePromptParam
from typing_extensions import NotRequired, TypeAlias, TypedDict from typing_extensions import NotRequired, TypeAlias, TypedDict
from .agent_output import AgentOutputSchemaBase from .agent_output import AgentOutputSchemaBase
@ -17,6 +18,7 @@ from .logger import logger
from .mcp import MCPUtil from .mcp import MCPUtil
from .model_settings import ModelSettings from .model_settings import ModelSettings
from .models.interface import Model from .models.interface import Model
from .prompts import DynamicPromptFunction, Prompt, PromptUtil
from .run_context import RunContextWrapper, TContext from .run_context import RunContextWrapper, TContext
from .tool import FunctionTool, FunctionToolResult, Tool, function_tool from .tool import FunctionTool, FunctionToolResult, Tool, function_tool
from .util import _transforms from .util import _transforms
@ -95,6 +97,12 @@ class Agent(Generic[TContext]):
return a string. return a string.
""" """
prompt: Prompt | DynamicPromptFunction | None = None
"""A prompt object (or a function that returns a Prompt). Prompts allow you to dynamically
configure the instructions, tools and other config for an agent outside of your code. Only
usable with OpenAI models, using the Responses API.
"""
handoff_description: str | None = None handoff_description: str | None = None
"""A description of the agent. This is used when the agent is used as a handoff, so that an """A description of the agent. This is used when the agent is used as a handoff, so that an
LLM knows what it does and when to invoke it. LLM knows what it does and when to invoke it.
@ -242,6 +250,12 @@ class Agent(Generic[TContext]):
return None return None
async def get_prompt(
self, run_context: RunContextWrapper[TContext]
) -> ResponsePromptParam | None:
"""Get the prompt for the agent."""
return await PromptUtil.to_model_input(self.prompt, run_context, self)
async def get_mcp_tools(self) -> list[Tool]: async def get_mcp_tools(self) -> list[Tool]:
"""Fetches the available tools from the MCP servers.""" """Fetches the available tools from the MCP servers."""
convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False) convert_schemas_to_strict = self.mcp_config.get("convert_schemas_to_strict", False)

View file

@ -71,6 +71,7 @@ class LitellmModel(Model):
handoffs: list[Handoff], handoffs: list[Handoff],
tracing: ModelTracing, tracing: ModelTracing,
previous_response_id: str | None, previous_response_id: str | None,
prompt: Any | None = None,
) -> ModelResponse: ) -> ModelResponse:
with generation_span( with generation_span(
model=str(self.model), model=str(self.model),
@ -88,6 +89,7 @@ class LitellmModel(Model):
span_generation, span_generation,
tracing, tracing,
stream=False, stream=False,
prompt=prompt,
) )
assert isinstance(response.choices[0], litellm.types.utils.Choices) assert isinstance(response.choices[0], litellm.types.utils.Choices)
@ -153,8 +155,8 @@ class LitellmModel(Model):
output_schema: AgentOutputSchemaBase | None, output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff], handoffs: list[Handoff],
tracing: ModelTracing, tracing: ModelTracing,
*,
previous_response_id: str | None, previous_response_id: str | None,
prompt: Any | None = None,
) -> AsyncIterator[TResponseStreamEvent]: ) -> AsyncIterator[TResponseStreamEvent]:
with generation_span( with generation_span(
model=str(self.model), model=str(self.model),
@ -172,6 +174,7 @@ class LitellmModel(Model):
span_generation, span_generation,
tracing, tracing,
stream=True, stream=True,
prompt=prompt,
) )
final_response: Response | None = None final_response: Response | None = None
@ -202,6 +205,7 @@ class LitellmModel(Model):
span: Span[GenerationSpanData], span: Span[GenerationSpanData],
tracing: ModelTracing, tracing: ModelTracing,
stream: Literal[True], stream: Literal[True],
prompt: Any | None = None,
) -> tuple[Response, AsyncStream[ChatCompletionChunk]]: ... ) -> tuple[Response, AsyncStream[ChatCompletionChunk]]: ...
@overload @overload
@ -216,6 +220,7 @@ class LitellmModel(Model):
span: Span[GenerationSpanData], span: Span[GenerationSpanData],
tracing: ModelTracing, tracing: ModelTracing,
stream: Literal[False], stream: Literal[False],
prompt: Any | None = None,
) -> litellm.types.utils.ModelResponse: ... ) -> litellm.types.utils.ModelResponse: ...
async def _fetch_response( async def _fetch_response(
@ -229,6 +234,7 @@ class LitellmModel(Model):
span: Span[GenerationSpanData], span: Span[GenerationSpanData],
tracing: ModelTracing, tracing: ModelTracing,
stream: bool = False, stream: bool = False,
prompt: Any | None = None,
) -> litellm.types.utils.ModelResponse | tuple[Response, AsyncStream[ChatCompletionChunk]]: ) -> litellm.types.utils.ModelResponse | tuple[Response, AsyncStream[ChatCompletionChunk]]:
converted_messages = Converter.items_to_messages(input) converted_messages = Converter.items_to_messages(input)

View file

@ -5,6 +5,8 @@ import enum
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from openai.types.responses.response_prompt_param import ResponsePromptParam
from ..agent_output import AgentOutputSchemaBase from ..agent_output import AgentOutputSchemaBase
from ..handoffs import Handoff from ..handoffs import Handoff
from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent from ..items import ModelResponse, TResponseInputItem, TResponseStreamEvent
@ -46,6 +48,7 @@ class Model(abc.ABC):
tracing: ModelTracing, tracing: ModelTracing,
*, *,
previous_response_id: str | None, previous_response_id: str | None,
prompt: ResponsePromptParam | None,
) -> ModelResponse: ) -> ModelResponse:
"""Get a response from the model. """Get a response from the model.
@ -59,6 +62,7 @@ class Model(abc.ABC):
tracing: Tracing configuration. tracing: Tracing configuration.
previous_response_id: the ID of the previous response. Generally not used by the model, previous_response_id: the ID of the previous response. Generally not used by the model,
except for the OpenAI Responses API. except for the OpenAI Responses API.
prompt: The prompt config to use for the model.
Returns: Returns:
The full model response. The full model response.
@ -77,6 +81,7 @@ class Model(abc.ABC):
tracing: ModelTracing, tracing: ModelTracing,
*, *,
previous_response_id: str | None, previous_response_id: str | None,
prompt: ResponsePromptParam | None,
) -> AsyncIterator[TResponseStreamEvent]: ) -> AsyncIterator[TResponseStreamEvent]:
"""Stream a response from the model. """Stream a response from the model.
@ -90,6 +95,7 @@ class Model(abc.ABC):
tracing: Tracing configuration. tracing: Tracing configuration.
previous_response_id: the ID of the previous response. Generally not used by the model, previous_response_id: the ID of the previous response. Generally not used by the model,
except for the OpenAI Responses API. except for the OpenAI Responses API.
prompt: The prompt config to use for the model.
Returns: Returns:
An iterator of response stream events, in OpenAI Responses format. An iterator of response stream events, in OpenAI Responses format.

View file

@ -9,6 +9,7 @@ from openai import NOT_GIVEN, AsyncOpenAI, AsyncStream
from openai.types import ChatModel from openai.types import ChatModel
from openai.types.chat import ChatCompletion, ChatCompletionChunk from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.responses import Response from openai.types.responses import Response
from openai.types.responses.response_prompt_param import ResponsePromptParam
from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails from openai.types.responses.response_usage import InputTokensDetails, OutputTokensDetails
from .. import _debug from .. import _debug
@ -53,6 +54,7 @@ class OpenAIChatCompletionsModel(Model):
handoffs: list[Handoff], handoffs: list[Handoff],
tracing: ModelTracing, tracing: ModelTracing,
previous_response_id: str | None, previous_response_id: str | None,
prompt: ResponsePromptParam | None = None,
) -> ModelResponse: ) -> ModelResponse:
with generation_span( with generation_span(
model=str(self.model), model=str(self.model),
@ -69,6 +71,7 @@ class OpenAIChatCompletionsModel(Model):
span_generation, span_generation,
tracing, tracing,
stream=False, stream=False,
prompt=prompt,
) )
first_choice = response.choices[0] first_choice = response.choices[0]
@ -136,8 +139,8 @@ class OpenAIChatCompletionsModel(Model):
output_schema: AgentOutputSchemaBase | None, output_schema: AgentOutputSchemaBase | None,
handoffs: list[Handoff], handoffs: list[Handoff],
tracing: ModelTracing, tracing: ModelTracing,
*,
previous_response_id: str | None, previous_response_id: str | None,
prompt: ResponsePromptParam | None = None,
) -> AsyncIterator[TResponseStreamEvent]: ) -> AsyncIterator[TResponseStreamEvent]:
""" """
Yields a partial message as it is generated, as well as the usage information. Yields a partial message as it is generated, as well as the usage information.
@ -157,6 +160,7 @@ class OpenAIChatCompletionsModel(Model):
span_generation, span_generation,
tracing, tracing,
stream=True, stream=True,
prompt=prompt,
) )
final_response: Response | None = None final_response: Response | None = None
@ -187,6 +191,7 @@ class OpenAIChatCompletionsModel(Model):
span: Span[GenerationSpanData], span: Span[GenerationSpanData],
tracing: ModelTracing, tracing: ModelTracing,
stream: Literal[True], stream: Literal[True],
prompt: ResponsePromptParam | None = None,
) -> tuple[Response, AsyncStream[ChatCompletionChunk]]: ... ) -> tuple[Response, AsyncStream[ChatCompletionChunk]]: ...
@overload @overload
@ -201,6 +206,7 @@ class OpenAIChatCompletionsModel(Model):
span: Span[GenerationSpanData], span: Span[GenerationSpanData],
tracing: ModelTracing, tracing: ModelTracing,
stream: Literal[False], stream: Literal[False],
prompt: ResponsePromptParam | None = None,
) -> ChatCompletion: ... ) -> ChatCompletion: ...
async def _fetch_response( async def _fetch_response(
@ -214,6 +220,7 @@ class OpenAIChatCompletionsModel(Model):
span: Span[GenerationSpanData], span: Span[GenerationSpanData],
tracing: ModelTracing, tracing: ModelTracing,
stream: bool = False, stream: bool = False,
prompt: ResponsePromptParam | None = None,
) -> ChatCompletion | tuple[Response, AsyncStream[ChatCompletionChunk]]: ) -> ChatCompletion | tuple[Response, AsyncStream[ChatCompletionChunk]]:
converted_messages = Converter.items_to_messages(input) converted_messages = Converter.items_to_messages(input)

View file

@ -17,6 +17,7 @@ from openai.types.responses import (
WebSearchToolParam, WebSearchToolParam,
response_create_params, response_create_params,
) )
from openai.types.responses.response_prompt_param import ResponsePromptParam
from .. import _debug from .. import _debug
from ..agent_output import AgentOutputSchemaBase from ..agent_output import AgentOutputSchemaBase
@ -74,6 +75,7 @@ class OpenAIResponsesModel(Model):
handoffs: list[Handoff], handoffs: list[Handoff],
tracing: ModelTracing, tracing: ModelTracing,
previous_response_id: str | None, previous_response_id: str | None,
prompt: ResponsePromptParam | None = None,
) -> ModelResponse: ) -> ModelResponse:
with response_span(disabled=tracing.is_disabled()) as span_response: with response_span(disabled=tracing.is_disabled()) as span_response:
try: try:
@ -86,6 +88,7 @@ class OpenAIResponsesModel(Model):
handoffs, handoffs,
previous_response_id, previous_response_id,
stream=False, stream=False,
prompt=prompt,
) )
if _debug.DONT_LOG_MODEL_DATA: if _debug.DONT_LOG_MODEL_DATA:
@ -141,6 +144,7 @@ class OpenAIResponsesModel(Model):
handoffs: list[Handoff], handoffs: list[Handoff],
tracing: ModelTracing, tracing: ModelTracing,
previous_response_id: str | None, previous_response_id: str | None,
prompt: ResponsePromptParam | None = None,
) -> AsyncIterator[ResponseStreamEvent]: ) -> AsyncIterator[ResponseStreamEvent]:
""" """
Yields a partial message as it is generated, as well as the usage information. Yields a partial message as it is generated, as well as the usage information.
@ -156,6 +160,7 @@ class OpenAIResponsesModel(Model):
handoffs, handoffs,
previous_response_id, previous_response_id,
stream=True, stream=True,
prompt=prompt,
) )
final_response: Response | None = None final_response: Response | None = None
@ -192,6 +197,7 @@ class OpenAIResponsesModel(Model):
handoffs: list[Handoff], handoffs: list[Handoff],
previous_response_id: str | None, previous_response_id: str | None,
stream: Literal[True], stream: Literal[True],
prompt: ResponsePromptParam | None = None,
) -> AsyncStream[ResponseStreamEvent]: ... ) -> AsyncStream[ResponseStreamEvent]: ...
@overload @overload
@ -205,6 +211,7 @@ class OpenAIResponsesModel(Model):
handoffs: list[Handoff], handoffs: list[Handoff],
previous_response_id: str | None, previous_response_id: str | None,
stream: Literal[False], stream: Literal[False],
prompt: ResponsePromptParam | None = None,
) -> Response: ... ) -> Response: ...
async def _fetch_response( async def _fetch_response(
@ -217,6 +224,7 @@ class OpenAIResponsesModel(Model):
handoffs: list[Handoff], handoffs: list[Handoff],
previous_response_id: str | None, previous_response_id: str | None,
stream: Literal[True] | Literal[False] = False, stream: Literal[True] | Literal[False] = False,
prompt: ResponsePromptParam | None = None,
) -> Response | AsyncStream[ResponseStreamEvent]: ) -> Response | AsyncStream[ResponseStreamEvent]:
list_input = ItemHelpers.input_to_new_input_list(input) list_input = ItemHelpers.input_to_new_input_list(input)
@ -252,6 +260,7 @@ class OpenAIResponsesModel(Model):
input=list_input, input=list_input,
include=converted_tools.includes, include=converted_tools.includes,
tools=converted_tools.tools, tools=converted_tools.tools,
prompt=self._non_null_or_not_given(prompt),
temperature=self._non_null_or_not_given(model_settings.temperature), temperature=self._non_null_or_not_given(model_settings.temperature),
top_p=self._non_null_or_not_given(model_settings.top_p), top_p=self._non_null_or_not_given(model_settings.top_p),
truncation=self._non_null_or_not_given(model_settings.truncation), truncation=self._non_null_or_not_given(model_settings.truncation),

76
src/agents/prompts.py Normal file
View file

@ -0,0 +1,76 @@
from __future__ import annotations
import inspect
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Callable
from openai.types.responses.response_prompt_param import (
ResponsePromptParam,
Variables as ResponsesPromptVariables,
)
from typing_extensions import NotRequired, TypedDict
from agents.util._types import MaybeAwaitable
from .exceptions import UserError
from .run_context import RunContextWrapper
if TYPE_CHECKING:
from .agent import Agent
class Prompt(TypedDict):
"""Prompt configuration to use for interacting with an OpenAI model."""
id: str
"""The unique ID of the prompt."""
version: NotRequired[str]
"""Optional version of the prompt."""
variables: NotRequired[dict[str, ResponsesPromptVariables]]
"""Optional variables to substitute into the prompt."""
@dataclass
class GenerateDynamicPromptData:
"""Inputs to a function that allows you to dynamically generate a prompt."""
context: RunContextWrapper[Any]
"""The run context."""
agent: Agent[Any]
"""The agent for which the prompt is being generated."""
DynamicPromptFunction = Callable[[GenerateDynamicPromptData], MaybeAwaitable[Prompt]]
"""A function that dynamically generates a prompt."""
class PromptUtil:
@staticmethod
async def to_model_input(
prompt: Prompt | DynamicPromptFunction | None,
context: RunContextWrapper[Any],
agent: Agent[Any],
) -> ResponsePromptParam | None:
if prompt is None:
return None
resolved_prompt: Prompt
if isinstance(prompt, dict):
resolved_prompt = prompt
else:
func_result = prompt(GenerateDynamicPromptData(context=context, agent=agent))
if inspect.isawaitable(func_result):
resolved_prompt = await func_result
else:
resolved_prompt = func_result
if not isinstance(resolved_prompt, dict):
raise UserError("Dynamic prompt function must return a Prompt")
return {
"id": resolved_prompt["id"],
"version": resolved_prompt.get("version"),
"variables": resolved_prompt.get("variables"),
}

View file

@ -6,6 +6,9 @@ from dataclasses import dataclass, field
from typing import Any, cast from typing import Any, cast
from openai.types.responses import ResponseCompletedEvent from openai.types.responses import ResponseCompletedEvent
from openai.types.responses.response_prompt_param import (
ResponsePromptParam,
)
from ._run_impl import ( from ._run_impl import (
AgentToolUseTracker, AgentToolUseTracker,
@ -682,7 +685,10 @@ class Runner:
streamed_result.current_agent = agent streamed_result.current_agent = agent
streamed_result._current_agent_output_schema = output_schema streamed_result._current_agent_output_schema = output_schema
system_prompt = await agent.get_system_prompt(context_wrapper) system_prompt, prompt_config = await asyncio.gather(
agent.get_system_prompt(context_wrapper),
agent.get_prompt(context_wrapper),
)
handoffs = cls._get_handoffs(agent) handoffs = cls._get_handoffs(agent)
model = cls._get_model(agent, run_config) model = cls._get_model(agent, run_config)
@ -706,6 +712,7 @@ class Runner:
run_config.tracing_disabled, run_config.trace_include_sensitive_data run_config.tracing_disabled, run_config.trace_include_sensitive_data
), ),
previous_response_id=previous_response_id, previous_response_id=previous_response_id,
prompt=prompt_config,
): ):
if isinstance(event, ResponseCompletedEvent): if isinstance(event, ResponseCompletedEvent):
usage = ( usage = (
@ -777,7 +784,10 @@ class Runner:
), ),
) )
system_prompt = await agent.get_system_prompt(context_wrapper) system_prompt, prompt_config = await asyncio.gather(
agent.get_system_prompt(context_wrapper),
agent.get_prompt(context_wrapper),
)
output_schema = cls._get_output_schema(agent) output_schema = cls._get_output_schema(agent)
handoffs = cls._get_handoffs(agent) handoffs = cls._get_handoffs(agent)
@ -795,6 +805,7 @@ class Runner:
run_config, run_config,
tool_use_tracker, tool_use_tracker,
previous_response_id, previous_response_id,
prompt_config,
) )
return await cls._get_single_step_result_from_response( return await cls._get_single_step_result_from_response(
@ -938,6 +949,7 @@ class Runner:
run_config: RunConfig, run_config: RunConfig,
tool_use_tracker: AgentToolUseTracker, tool_use_tracker: AgentToolUseTracker,
previous_response_id: str | None, previous_response_id: str | None,
prompt_config: ResponsePromptParam | None,
) -> ModelResponse: ) -> ModelResponse:
model = cls._get_model(agent, run_config) model = cls._get_model(agent, run_config)
model_settings = agent.model_settings.resolve(run_config.model_settings) model_settings = agent.model_settings.resolve(run_config.model_settings)
@ -954,6 +966,7 @@ class Runner:
run_config.tracing_disabled, run_config.trace_include_sensitive_data run_config.tracing_disabled, run_config.trace_include_sensitive_data
), ),
previous_response_id=previous_response_id, previous_response_id=previous_response_id,
prompt=prompt_config,
) )
context_wrapper.usage.add(new_response.usage) context_wrapper.usage.add(new_response.usage)

View file

@ -7,6 +7,7 @@ from .run_context import RunContextWrapper, TContext
def _assert_must_pass_tool_call_id() -> str: def _assert_must_pass_tool_call_id() -> str:
raise ValueError("tool_call_id must be passed to ToolContext") raise ValueError("tool_call_id must be passed to ToolContext")
@dataclass @dataclass
class ToolContext(RunContextWrapper[TContext]): class ToolContext(RunContextWrapper[TContext]):
"""The context of a tool call.""" """The context of a tool call."""

View file

@ -61,6 +61,7 @@ class FakeModel(Model):
tracing: ModelTracing, tracing: ModelTracing,
*, *,
previous_response_id: str | None, previous_response_id: str | None,
prompt: Any | None,
) -> ModelResponse: ) -> ModelResponse:
self.last_turn_args = { self.last_turn_args = {
"system_instructions": system_instructions, "system_instructions": system_instructions,
@ -103,6 +104,7 @@ class FakeModel(Model):
tracing: ModelTracing, tracing: ModelTracing,
*, *,
previous_response_id: str | None, previous_response_id: str | None,
prompt: Any | None,
) -> AsyncIterator[TResponseStreamEvent]: ) -> AsyncIterator[TResponseStreamEvent]:
self.last_turn_args = { self.last_turn_args = {
"system_instructions": system_instructions, "system_instructions": system_instructions,

View file

@ -90,6 +90,7 @@ async def test_stream_response_yields_events_for_text_content(monkeypatch) -> No
handoffs=[], handoffs=[],
tracing=ModelTracing.DISABLED, tracing=ModelTracing.DISABLED,
previous_response_id=None, previous_response_id=None,
prompt=None,
): ):
output_events.append(event) output_events.append(event)
# We expect a response.created, then a response.output_item.added, content part added, # We expect a response.created, then a response.output_item.added, content part added,
@ -182,6 +183,7 @@ async def test_stream_response_yields_events_for_refusal_content(monkeypatch) ->
handoffs=[], handoffs=[],
tracing=ModelTracing.DISABLED, tracing=ModelTracing.DISABLED,
previous_response_id=None, previous_response_id=None,
prompt=None,
): ):
output_events.append(event) output_events.append(event)
# Expect sequence similar to text: created, output_item.added, content part added, # Expect sequence similar to text: created, output_item.added, content part added,
@ -270,6 +272,7 @@ async def test_stream_response_yields_events_for_tool_call(monkeypatch) -> None:
handoffs=[], handoffs=[],
tracing=ModelTracing.DISABLED, tracing=ModelTracing.DISABLED,
previous_response_id=None, previous_response_id=None,
prompt=None,
): ):
output_events.append(event) output_events.append(event)
# Sequence should be: response.created, then after loop we expect function call-related events: # Sequence should be: response.created, then after loop we expect function call-related events:

View file

@ -0,0 +1,97 @@
import pytest
from agents import Agent, Prompt, RunContextWrapper, Runner
from .fake_model import FakeModel
from .test_responses import get_text_message
class PromptCaptureFakeModel(FakeModel):
"""Subclass of FakeModel that records the prompt passed to the model."""
def __init__(self):
super().__init__()
self.last_prompt = None
async def get_response(
self,
system_instructions,
input,
model_settings,
tools,
output_schema,
handoffs,
tracing,
*,
previous_response_id,
prompt,
):
# Record the prompt that the agent resolved and passed in.
self.last_prompt = prompt
return await super().get_response(
system_instructions,
input,
model_settings,
tools,
output_schema,
handoffs,
tracing,
previous_response_id=previous_response_id,
prompt=prompt,
)
@pytest.mark.asyncio
async def test_static_prompt_is_resolved_correctly():
static_prompt: Prompt = {
"id": "my_prompt",
"version": "1",
"variables": {"some_var": "some_value"},
}
agent = Agent(name="test", prompt=static_prompt)
context_wrapper = RunContextWrapper(context=None)
resolved = await agent.get_prompt(context_wrapper)
assert resolved == {
"id": "my_prompt",
"version": "1",
"variables": {"some_var": "some_value"},
}
@pytest.mark.asyncio
async def test_dynamic_prompt_is_resolved_correctly():
dynamic_prompt_value: Prompt = {"id": "dyn_prompt", "version": "2"}
def dynamic_prompt_fn(_data):
return dynamic_prompt_value
agent = Agent(name="test", prompt=dynamic_prompt_fn)
context_wrapper = RunContextWrapper(context=None)
resolved = await agent.get_prompt(context_wrapper)
assert resolved == {"id": "dyn_prompt", "version": "2", "variables": None}
@pytest.mark.asyncio
async def test_prompt_is_passed_to_model():
static_prompt: Prompt = {"id": "model_prompt"}
model = PromptCaptureFakeModel()
agent = Agent(name="test", model=model, prompt=static_prompt)
# Ensure the model returns a simple message so the run completes in one turn.
model.set_next_output([get_text_message("done")])
await Runner.run(agent, input="hello")
# The model should have received the prompt resolved by the agent.
expected_prompt = {
"id": "model_prompt",
"version": None,
"variables": None,
}
assert model.last_prompt == expected_prompt

View file

@ -77,6 +77,7 @@ async def test_get_response_with_text_message(monkeypatch) -> None:
handoffs=[], handoffs=[],
tracing=ModelTracing.DISABLED, tracing=ModelTracing.DISABLED,
previous_response_id=None, previous_response_id=None,
prompt=None,
) )
# Should have produced exactly one output message with one text part # Should have produced exactly one output message with one text part
assert isinstance(resp, ModelResponse) assert isinstance(resp, ModelResponse)
@ -128,6 +129,7 @@ async def test_get_response_with_refusal(monkeypatch) -> None:
handoffs=[], handoffs=[],
tracing=ModelTracing.DISABLED, tracing=ModelTracing.DISABLED,
previous_response_id=None, previous_response_id=None,
prompt=None,
) )
assert len(resp.output) == 1 assert len(resp.output) == 1
assert isinstance(resp.output[0], ResponseOutputMessage) assert isinstance(resp.output[0], ResponseOutputMessage)
@ -180,6 +182,7 @@ async def test_get_response_with_tool_call(monkeypatch) -> None:
handoffs=[], handoffs=[],
tracing=ModelTracing.DISABLED, tracing=ModelTracing.DISABLED,
previous_response_id=None, previous_response_id=None,
prompt=None,
) )
# Expect a message item followed by a function tool call item. # Expect a message item followed by a function tool call item.
assert len(resp.output) == 2 assert len(resp.output) == 2
@ -221,6 +224,7 @@ async def test_get_response_with_no_message(monkeypatch) -> None:
handoffs=[], handoffs=[],
tracing=ModelTracing.DISABLED, tracing=ModelTracing.DISABLED,
previous_response_id=None, previous_response_id=None,
prompt=None,
) )
assert resp.output == [] assert resp.output == []

View file

@ -90,6 +90,7 @@ async def test_stream_response_yields_events_for_text_content(monkeypatch) -> No
handoffs=[], handoffs=[],
tracing=ModelTracing.DISABLED, tracing=ModelTracing.DISABLED,
previous_response_id=None, previous_response_id=None,
prompt=None,
): ):
output_events.append(event) output_events.append(event)
# We expect a response.created, then a response.output_item.added, content part added, # We expect a response.created, then a response.output_item.added, content part added,
@ -182,6 +183,7 @@ async def test_stream_response_yields_events_for_refusal_content(monkeypatch) ->
handoffs=[], handoffs=[],
tracing=ModelTracing.DISABLED, tracing=ModelTracing.DISABLED,
previous_response_id=None, previous_response_id=None,
prompt=None,
): ):
output_events.append(event) output_events.append(event)
# Expect sequence similar to text: created, output_item.added, content part added, # Expect sequence similar to text: created, output_item.added, content part added,
@ -270,6 +272,7 @@ async def test_stream_response_yields_events_for_tool_call(monkeypatch) -> None:
handoffs=[], handoffs=[],
tracing=ModelTracing.DISABLED, tracing=ModelTracing.DISABLED,
previous_response_id=None, previous_response_id=None,
prompt=None,
): ):
output_events.append(event) output_events.append(event)
# Sequence should be: response.created, then after loop we expect function call-related events: # Sequence should be: response.created, then after loop we expect function call-related events:

View file

@ -71,6 +71,7 @@ async def test_get_response_creates_trace(monkeypatch):
handoffs, handoffs,
prev_response_id, prev_response_id,
stream, stream,
prompt,
): ):
return DummyResponse() return DummyResponse()
@ -115,6 +116,7 @@ async def test_non_data_tracing_doesnt_set_response_id(monkeypatch):
handoffs, handoffs,
prev_response_id, prev_response_id,
stream, stream,
prompt,
): ):
return DummyResponse() return DummyResponse()
@ -157,6 +159,7 @@ async def test_disable_tracing_does_not_create_span(monkeypatch):
handoffs, handoffs,
prev_response_id, prev_response_id,
stream, stream,
prompt,
): ):
return DummyResponse() return DummyResponse()
@ -196,6 +199,7 @@ async def test_stream_response_creates_trace(monkeypatch):
handoffs, handoffs,
prev_response_id, prev_response_id,
stream, stream,
prompt,
): ):
class DummyStream: class DummyStream:
async def __aiter__(self): async def __aiter__(self):
@ -249,6 +253,7 @@ async def test_stream_non_data_tracing_doesnt_set_response_id(monkeypatch):
handoffs, handoffs,
prev_response_id, prev_response_id,
stream, stream,
prompt,
): ):
class DummyStream: class DummyStream:
async def __aiter__(self): async def __aiter__(self):
@ -301,6 +306,7 @@ async def test_stream_disabled_tracing_doesnt_create_span(monkeypatch):
handoffs, handoffs,
prev_response_id, prev_response_id,
stream, stream,
prompt,
): ):
class DummyStream: class DummyStream:
async def __aiter__(self): async def __aiter__(self):

View file

@ -2,6 +2,7 @@ from __future__ import annotations
import json import json
from collections.abc import AsyncIterator from collections.abc import AsyncIterator
from typing import Any
import pytest import pytest
from inline_snapshot import snapshot from inline_snapshot import snapshot
@ -53,6 +54,7 @@ class FakeStreamingModel(Model):
tracing: ModelTracing, tracing: ModelTracing,
*, *,
previous_response_id: str | None, previous_response_id: str | None,
prompt: Any | None,
) -> ModelResponse: ) -> ModelResponse:
raise NotImplementedError("Not implemented") raise NotImplementedError("Not implemented")
@ -67,6 +69,7 @@ class FakeStreamingModel(Model):
tracing: ModelTracing, tracing: ModelTracing,
*, *,
previous_response_id: str | None, previous_response_id: str | None,
prompt: Any | None,
) -> AsyncIterator[TResponseStreamEvent]: ) -> AsyncIterator[TResponseStreamEvent]:
output = self.get_next_output() output = self.get_next_output()
for item in output: for item in output:

View file

@ -1461,7 +1461,7 @@ wheels = [
[[package]] [[package]]
name = "openai" name = "openai"
version = "1.81.0" version = "1.87.0"
source = { registry = "https://pypi.org/simple" } source = { registry = "https://pypi.org/simple" }
dependencies = [ dependencies = [
{ name = "anyio" }, { name = "anyio" },
@ -1473,9 +1473,9 @@ dependencies = [
{ name = "tqdm" }, { name = "tqdm" },
{ name = "typing-extensions" }, { name = "typing-extensions" },
] ]
sdist = { url = "https://files.pythonhosted.org/packages/1c/89/a1e4f3fa7ca4f7fec90dbf47d93b7cd5ff65924926733af15044e302a192/openai-1.81.0.tar.gz", hash = "sha256:349567a8607e0bcffd28e02f96b5c2397d0d25d06732d90ab3ecbf97abf030f9", size = 456861 } sdist = { url = "https://files.pythonhosted.org/packages/47/ed/2b3f6c7e950784e9442115ab8ebeff514d543fb33da10607b39364645a75/openai-1.87.0.tar.gz", hash = "sha256:5c69764171e0db9ef993e7a4d8a01fd8ff1026b66f8bdd005b9461782b6e7dfc", size = 470880 }
wheels = [ wheels = [
{ url = "https://files.pythonhosted.org/packages/02/66/bcc7f9bf48e8610a33e3b5c96a5a644dad032d92404ea2a5e8b43ba067e8/openai-1.81.0-py3-none-any.whl", hash = "sha256:1c71572e22b43876c5d7d65ade0b7b516bb527c3d44ae94111267a09125f7bae", size = 717529 }, { url = "https://files.pythonhosted.org/packages/36/ac/313ded47ce1d5bc2ec02ed5dd5506bf5718678a4655ac20f337231d9aae3/openai-1.87.0-py3-none-any.whl", hash = "sha256:f9bcae02ac4fff6522276eee85d33047335cfb692b863bd8261353ce4ada5692", size = 734368 },
] ]
[[package]] [[package]]
@ -1536,7 +1536,7 @@ requires-dist = [
{ name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.67.4.post1,<2" }, { name = "litellm", marker = "extra == 'litellm'", specifier = ">=1.67.4.post1,<2" },
{ name = "mcp", marker = "python_full_version >= '3.10'", specifier = ">=1.9.4,<2" }, { name = "mcp", marker = "python_full_version >= '3.10'", specifier = ">=1.9.4,<2" },
{ name = "numpy", marker = "python_full_version >= '3.10' and extra == 'voice'", specifier = ">=2.2.0,<3" }, { name = "numpy", marker = "python_full_version >= '3.10' and extra == 'voice'", specifier = ">=2.2.0,<3" },
{ name = "openai", specifier = ">=1.81.0" }, { name = "openai", specifier = ">=1.87.0" },
{ name = "pydantic", specifier = ">=2.10,<3" }, { name = "pydantic", specifier = ">=2.10,<3" },
{ name = "requests", specifier = ">=2.0,<3" }, { name = "requests", specifier = ">=2.0,<3" },
{ name = "types-requests", specifier = ">=2.0,<3" }, { name = "types-requests", specifier = ">=2.0,<3" },