utils directory
This commit is contained in:
parent
951193bd21
commit
09d70c074d
17 changed files with 111 additions and 103 deletions
|
|
@ -3,7 +3,7 @@ from agents import Agent, Runner
|
|||
agent = Agent(name="Assistant", instructions="You are a helpful assistant")
|
||||
|
||||
# Intended for Jupyter notebooks where there's an existing event loop
|
||||
result = await Runner.run(agent, "Write a haiku about recursion in programming.") # type: ignore[top-level-await] # noqa: F704
|
||||
result = await Runner.run(agent, "Write a haiku about recursion in programming.") # type: ignore[top-level-await] # noqa: F704
|
||||
print(result.final_output)
|
||||
|
||||
# Code within code loops,
|
||||
|
|
|
|||
|
|
@ -25,7 +25,6 @@ from openai.types.responses.response_computer_tool_call import (
|
|||
from openai.types.responses.response_input_param import ComputerCallOutput
|
||||
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
|
||||
|
||||
from . import _utils
|
||||
from .agent import Agent
|
||||
from .agent_output import AgentOutputSchema
|
||||
from .computer import AsyncComputer, Computer
|
||||
|
|
@ -59,6 +58,7 @@ from .tracing import (
|
|||
handoff_span,
|
||||
trace,
|
||||
)
|
||||
from .util import _coro, _error_tracing
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .run import RunConfig
|
||||
|
|
@ -293,7 +293,7 @@ class RunImpl:
|
|||
elif isinstance(output, ResponseComputerToolCall):
|
||||
items.append(ToolCallItem(raw_item=output, agent=agent))
|
||||
if not computer_tool:
|
||||
_utils.attach_error_to_current_span(
|
||||
_error_tracing.attach_error_to_current_span(
|
||||
SpanError(
|
||||
message="Computer tool not found",
|
||||
data={},
|
||||
|
|
@ -324,7 +324,7 @@ class RunImpl:
|
|||
# Regular function tool call
|
||||
else:
|
||||
if output.name not in function_map:
|
||||
_utils.attach_error_to_current_span(
|
||||
_error_tracing.attach_error_to_current_span(
|
||||
SpanError(
|
||||
message="Tool not found",
|
||||
data={"tool_name": output.name},
|
||||
|
|
@ -368,7 +368,7 @@ class RunImpl:
|
|||
(
|
||||
agent.hooks.on_tool_start(context_wrapper, agent, func_tool)
|
||||
if agent.hooks
|
||||
else _utils.noop_coroutine()
|
||||
else _coro.noop_coroutine()
|
||||
),
|
||||
func_tool.on_invoke_tool(context_wrapper, tool_call.arguments),
|
||||
)
|
||||
|
|
@ -378,11 +378,11 @@ class RunImpl:
|
|||
(
|
||||
agent.hooks.on_tool_end(context_wrapper, agent, func_tool, result)
|
||||
if agent.hooks
|
||||
else _utils.noop_coroutine()
|
||||
else _coro.noop_coroutine()
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
_utils.attach_error_to_current_span(
|
||||
_error_tracing.attach_error_to_current_span(
|
||||
SpanError(
|
||||
message="Error running tool",
|
||||
data={"tool_name": func_tool.name, "error": str(e)},
|
||||
|
|
@ -502,7 +502,7 @@ class RunImpl:
|
|||
source=agent,
|
||||
)
|
||||
if agent.hooks
|
||||
else _utils.noop_coroutine()
|
||||
else _coro.noop_coroutine()
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -520,7 +520,7 @@ class RunImpl:
|
|||
new_items=tuple(new_step_items),
|
||||
)
|
||||
if not callable(input_filter):
|
||||
_utils.attach_error_to_span(
|
||||
_error_tracing.attach_error_to_span(
|
||||
span_handoff,
|
||||
SpanError(
|
||||
message="Invalid input filter",
|
||||
|
|
@ -530,7 +530,7 @@ class RunImpl:
|
|||
raise UserError(f"Invalid input filter: {input_filter}")
|
||||
filtered = input_filter(handoff_input_data)
|
||||
if not isinstance(filtered, HandoffInputData):
|
||||
_utils.attach_error_to_span(
|
||||
_error_tracing.attach_error_to_span(
|
||||
span_handoff,
|
||||
SpanError(
|
||||
message="Invalid input filter result",
|
||||
|
|
@ -591,7 +591,7 @@ class RunImpl:
|
|||
hooks.on_agent_end(context_wrapper, agent, final_output),
|
||||
agent.hooks.on_end(context_wrapper, agent, final_output)
|
||||
if agent.hooks
|
||||
else _utils.noop_coroutine(),
|
||||
else _coro.noop_coroutine(),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
|
@ -706,7 +706,7 @@ class ComputerAction:
|
|||
(
|
||||
agent.hooks.on_tool_start(context_wrapper, agent, action.computer_tool)
|
||||
if agent.hooks
|
||||
else _utils.noop_coroutine()
|
||||
else _coro.noop_coroutine()
|
||||
),
|
||||
output_func,
|
||||
)
|
||||
|
|
@ -716,7 +716,7 @@ class ComputerAction:
|
|||
(
|
||||
agent.hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output)
|
||||
if agent.hooks
|
||||
else _utils.noop_coroutine()
|
||||
else _coro.noop_coroutine()
|
||||
),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,61 +0,0 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections.abc import Awaitable
|
||||
from typing import Any, Literal, Union
|
||||
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from .exceptions import ModelBehaviorError
|
||||
from .logger import logger
|
||||
from .tracing import Span, SpanError, get_current_span
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
MaybeAwaitable = Union[Awaitable[T], T]
|
||||
|
||||
|
||||
def transform_string_function_style(name: str) -> str:
|
||||
# Replace spaces with underscores
|
||||
name = name.replace(" ", "_")
|
||||
|
||||
# Replace non-alphanumeric characters with underscores
|
||||
name = re.sub(r"[^a-zA-Z0-9]", "_", name)
|
||||
|
||||
return name.lower()
|
||||
|
||||
|
||||
def validate_json(json_str: str, type_adapter: TypeAdapter[T], partial: bool) -> T:
|
||||
partial_setting: bool | Literal["off", "on", "trailing-strings"] = (
|
||||
"trailing-strings" if partial else False
|
||||
)
|
||||
try:
|
||||
validated = type_adapter.validate_json(json_str, experimental_allow_partial=partial_setting)
|
||||
return validated
|
||||
except ValidationError as e:
|
||||
attach_error_to_current_span(
|
||||
SpanError(
|
||||
message="Invalid JSON provided",
|
||||
data={},
|
||||
)
|
||||
)
|
||||
raise ModelBehaviorError(
|
||||
f"Invalid JSON when parsing {json_str} for {type_adapter}; {e}"
|
||||
) from e
|
||||
|
||||
|
||||
def attach_error_to_span(span: Span[Any], error: SpanError) -> None:
|
||||
span.set_error(error)
|
||||
|
||||
|
||||
def attach_error_to_current_span(error: SpanError) -> None:
|
||||
span = get_current_span()
|
||||
if span:
|
||||
attach_error_to_span(span, error)
|
||||
else:
|
||||
logger.warning(f"No span to add error {error} to")
|
||||
|
||||
|
||||
async def noop_coroutine() -> None:
|
||||
pass
|
||||
|
|
@ -6,8 +6,6 @@ from collections.abc import Awaitable
|
|||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Callable, Generic, cast
|
||||
|
||||
from . import _utils
|
||||
from ._utils import MaybeAwaitable
|
||||
from .guardrail import InputGuardrail, OutputGuardrail
|
||||
from .handoffs import Handoff
|
||||
from .items import ItemHelpers
|
||||
|
|
@ -16,6 +14,8 @@ from .model_settings import ModelSettings
|
|||
from .models.interface import Model
|
||||
from .run_context import RunContextWrapper, TContext
|
||||
from .tool import Tool, function_tool
|
||||
from .util import _transforms
|
||||
from .util._types import MaybeAwaitable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .lifecycle import AgentHooks
|
||||
|
|
@ -126,7 +126,7 @@ class Agent(Generic[TContext]):
|
|||
"""
|
||||
|
||||
@function_tool(
|
||||
name_override=tool_name or _utils.transform_string_function_style(self.name),
|
||||
name_override=tool_name or _transforms.transform_string_function_style(self.name),
|
||||
description_override=tool_description or "",
|
||||
)
|
||||
async def run_agent(context: RunContextWrapper, input: str) -> str:
|
||||
|
|
|
|||
|
|
@ -4,10 +4,10 @@ from typing import Any
|
|||
from pydantic import BaseModel, TypeAdapter
|
||||
from typing_extensions import TypedDict, get_args, get_origin
|
||||
|
||||
from . import _utils
|
||||
from .exceptions import ModelBehaviorError, UserError
|
||||
from .strict_schema import ensure_strict_json_schema
|
||||
from .tracing import SpanError
|
||||
from .util import _error_tracing, _json
|
||||
|
||||
_WRAPPER_DICT_KEY = "response"
|
||||
|
||||
|
|
@ -87,10 +87,10 @@ class AgentOutputSchema:
|
|||
"""Validate a JSON string against the output type. Returns the validated object, or raises
|
||||
a `ModelBehaviorError` if the JSON is invalid.
|
||||
"""
|
||||
validated = _utils.validate_json(json_str, self._type_adapter, partial)
|
||||
validated = _json.validate_json(json_str, self._type_adapter, partial)
|
||||
if self._is_wrapped:
|
||||
if not isinstance(validated, dict):
|
||||
_utils.attach_error_to_current_span(
|
||||
_error_tracing.attach_error_to_current_span(
|
||||
SpanError(
|
||||
message="Invalid JSON",
|
||||
data={"details": f"Expected a dict, got {type(validated)}"},
|
||||
|
|
@ -101,7 +101,7 @@ class AgentOutputSchema:
|
|||
)
|
||||
|
||||
if _WRAPPER_DICT_KEY not in validated:
|
||||
_utils.attach_error_to_current_span(
|
||||
_error_tracing.attach_error_to_current_span(
|
||||
SpanError(
|
||||
message="Invalid JSON",
|
||||
data={"details": f"Could not find key {_WRAPPER_DICT_KEY} in JSON"},
|
||||
|
|
|
|||
|
|
@ -7,10 +7,10 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, Union, overload
|
|||
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from ._utils import MaybeAwaitable
|
||||
from .exceptions import UserError
|
||||
from .items import TResponseInputItem
|
||||
from .run_context import RunContextWrapper, TContext
|
||||
from .util._types import MaybeAwaitable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .agent import Agent
|
||||
|
|
|
|||
|
|
@ -8,12 +8,12 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, cast, overload
|
|||
from pydantic import TypeAdapter
|
||||
from typing_extensions import TypeAlias, TypeVar
|
||||
|
||||
from . import _utils
|
||||
from .exceptions import ModelBehaviorError, UserError
|
||||
from .items import RunItem, TResponseInputItem
|
||||
from .run_context import RunContextWrapper, TContext
|
||||
from .strict_schema import ensure_strict_json_schema
|
||||
from .tracing.spans import SpanError
|
||||
from .util import _error_tracing, _json, _transforms
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .agent import Agent
|
||||
|
|
@ -104,7 +104,7 @@ class Handoff(Generic[TContext]):
|
|||
|
||||
@classmethod
|
||||
def default_tool_name(cls, agent: Agent[Any]) -> str:
|
||||
return _utils.transform_string_function_style(f"transfer_to_{agent.name}")
|
||||
return _transforms.transform_string_function_style(f"transfer_to_{agent.name}")
|
||||
|
||||
@classmethod
|
||||
def default_tool_description(cls, agent: Agent[Any]) -> str:
|
||||
|
|
@ -192,7 +192,7 @@ def handoff(
|
|||
) -> Agent[Any]:
|
||||
if input_type is not None and type_adapter is not None:
|
||||
if input_json is None:
|
||||
_utils.attach_error_to_current_span(
|
||||
_error_tracing.attach_error_to_current_span(
|
||||
SpanError(
|
||||
message="Handoff function expected non-null input, but got None",
|
||||
data={"details": "input_json is None"},
|
||||
|
|
@ -200,7 +200,7 @@ def handoff(
|
|||
)
|
||||
raise ModelBehaviorError("Handoff function expected non-null input, but got None")
|
||||
|
||||
validated_input = _utils.validate_json(
|
||||
validated_input = _json.validate_json(
|
||||
json_str=input_json,
|
||||
type_adapter=type_adapter,
|
||||
partial=False,
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ from typing import Any, cast
|
|||
|
||||
from openai.types.responses import ResponseCompletedEvent
|
||||
|
||||
from . import Model, _utils
|
||||
from ._run_impl import (
|
||||
NextStepFinalOutput,
|
||||
NextStepHandoff,
|
||||
|
|
@ -33,7 +32,7 @@ from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem
|
|||
from .lifecycle import RunHooks
|
||||
from .logger import logger
|
||||
from .model_settings import ModelSettings
|
||||
from .models.interface import ModelProvider
|
||||
from .models.interface import Model, ModelProvider
|
||||
from .models.openai_provider import OpenAIProvider
|
||||
from .result import RunResult, RunResultStreaming
|
||||
from .run_context import RunContextWrapper, TContext
|
||||
|
|
@ -41,6 +40,7 @@ from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent
|
|||
from .tracing import Span, SpanError, agent_span, get_current_trace, trace
|
||||
from .tracing.span_data import AgentSpanData
|
||||
from .usage import Usage
|
||||
from .util import _coro, _error_tracing
|
||||
|
||||
DEFAULT_MAX_TURNS = 10
|
||||
|
||||
|
|
@ -193,7 +193,7 @@ class Runner:
|
|||
|
||||
current_turn += 1
|
||||
if current_turn > max_turns:
|
||||
_utils.attach_error_to_span(
|
||||
_error_tracing.attach_error_to_span(
|
||||
current_span,
|
||||
SpanError(
|
||||
message="Max turns exceeded",
|
||||
|
|
@ -447,7 +447,7 @@ class Runner:
|
|||
for done in asyncio.as_completed(guardrail_tasks):
|
||||
result = await done
|
||||
if result.output.tripwire_triggered:
|
||||
_utils.attach_error_to_span(
|
||||
_error_tracing.attach_error_to_span(
|
||||
parent_span,
|
||||
SpanError(
|
||||
message="Guardrail tripwire triggered",
|
||||
|
|
@ -511,7 +511,7 @@ class Runner:
|
|||
streamed_result.current_turn = current_turn
|
||||
|
||||
if current_turn > max_turns:
|
||||
_utils.attach_error_to_span(
|
||||
_error_tracing.attach_error_to_span(
|
||||
current_span,
|
||||
SpanError(
|
||||
message="Max turns exceeded",
|
||||
|
|
@ -583,7 +583,7 @@ class Runner:
|
|||
pass
|
||||
except Exception as e:
|
||||
if current_span:
|
||||
_utils.attach_error_to_span(
|
||||
_error_tracing.attach_error_to_span(
|
||||
current_span,
|
||||
SpanError(
|
||||
message="Error in agent run",
|
||||
|
|
@ -615,7 +615,7 @@ class Runner:
|
|||
(
|
||||
agent.hooks.on_start(context_wrapper, agent)
|
||||
if agent.hooks
|
||||
else _utils.noop_coroutine()
|
||||
else _coro.noop_coroutine()
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -705,7 +705,7 @@ class Runner:
|
|||
(
|
||||
agent.hooks.on_start(context_wrapper, agent)
|
||||
if agent.hooks
|
||||
else _utils.noop_coroutine()
|
||||
else _coro.noop_coroutine()
|
||||
),
|
||||
)
|
||||
|
||||
|
|
@ -796,7 +796,7 @@ class Runner:
|
|||
# Cancel all guardrail tasks if a tripwire is triggered.
|
||||
for t in guardrail_tasks:
|
||||
t.cancel()
|
||||
_utils.attach_error_to_current_span(
|
||||
_error_tracing.attach_error_to_current_span(
|
||||
SpanError(
|
||||
message="Guardrail tripwire triggered",
|
||||
data={"guardrail": result.guardrail.get_name()},
|
||||
|
|
@ -834,7 +834,7 @@ class Runner:
|
|||
# Cancel all guardrail tasks if a tripwire is triggered.
|
||||
for t in guardrail_tasks:
|
||||
t.cancel()
|
||||
_utils.attach_error_to_current_span(
|
||||
_error_tracing.attach_error_to_current_span(
|
||||
SpanError(
|
||||
message="Guardrail tripwire triggered",
|
||||
data={"guardrail": result.guardrail.get_name()},
|
||||
|
|
|
|||
|
|
@ -11,14 +11,15 @@ from openai.types.responses.web_search_tool_param import UserLocation
|
|||
from pydantic import ValidationError
|
||||
from typing_extensions import Concatenate, ParamSpec
|
||||
|
||||
from . import _debug, _utils
|
||||
from ._utils import MaybeAwaitable
|
||||
from . import _debug
|
||||
from .computer import AsyncComputer, Computer
|
||||
from .exceptions import ModelBehaviorError
|
||||
from .function_schema import DocstringStyle, function_schema
|
||||
from .logger import logger
|
||||
from .run_context import RunContextWrapper
|
||||
from .tracing import SpanError
|
||||
from .util import _error_tracing
|
||||
from .util._types import MaybeAwaitable
|
||||
|
||||
ToolParams = ParamSpec("ToolParams")
|
||||
|
||||
|
|
@ -263,7 +264,7 @@ def function_tool(
|
|||
if inspect.isawaitable(result):
|
||||
return await result
|
||||
|
||||
_utils.attach_error_to_current_span(
|
||||
_error_tracing.attach_error_to_current_span(
|
||||
SpanError(
|
||||
message="Error running tool (non-fatal)",
|
||||
data={
|
||||
|
|
|
|||
0
src/agents/util/__init__.py
Normal file
0
src/agents/util/__init__.py
Normal file
2
src/agents/util/_coro.py
Normal file
2
src/agents/util/_coro.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
async def noop_coroutine() -> None:
|
||||
pass
|
||||
16
src/agents/util/_error_tracing.py
Normal file
16
src/agents/util/_error_tracing.py
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
from typing import Any
|
||||
|
||||
from ..logger import logger
|
||||
from ..tracing import Span, SpanError, get_current_span
|
||||
|
||||
|
||||
def attach_error_to_span(span: Span[Any], error: SpanError) -> None:
|
||||
span.set_error(error)
|
||||
|
||||
|
||||
def attach_error_to_current_span(error: SpanError) -> None:
|
||||
span = get_current_span()
|
||||
if span:
|
||||
attach_error_to_span(span, error)
|
||||
else:
|
||||
logger.warning(f"No span to add error {error} to")
|
||||
31
src/agents/util/_json.py
Normal file
31
src/agents/util/_json.py
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from ..exceptions import ModelBehaviorError
|
||||
from ..tracing import SpanError
|
||||
from ._error_tracing import attach_error_to_current_span
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def validate_json(json_str: str, type_adapter: TypeAdapter[T], partial: bool) -> T:
|
||||
partial_setting: bool | Literal["off", "on", "trailing-strings"] = (
|
||||
"trailing-strings" if partial else False
|
||||
)
|
||||
try:
|
||||
validated = type_adapter.validate_json(json_str, experimental_allow_partial=partial_setting)
|
||||
return validated
|
||||
except ValidationError as e:
|
||||
attach_error_to_current_span(
|
||||
SpanError(
|
||||
message="Invalid JSON provided",
|
||||
data={},
|
||||
)
|
||||
)
|
||||
raise ModelBehaviorError(
|
||||
f"Invalid JSON when parsing {json_str} for {type_adapter}; {e}"
|
||||
) from e
|
||||
11
src/agents/util/_transforms.py
Normal file
11
src/agents/util/_transforms.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
import re
|
||||
|
||||
|
||||
def transform_string_function_style(name: str) -> str:
|
||||
# Replace spaces with underscores
|
||||
name = name.replace(" ", "_")
|
||||
|
||||
# Replace non-alphanumeric characters with underscores
|
||||
name = re.sub(r"[^a-zA-Z0-9]", "_", name)
|
||||
|
||||
return name.lower()
|
||||
7
src/agents/util/_types.py
Normal file
7
src/agents/util/_types.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
from collections.abc import Awaitable
|
||||
from typing import Union
|
||||
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
MaybeAwaitable = Union[Awaitable[T], T]
|
||||
|
|
@ -175,12 +175,11 @@ def multiple_optional_params_function(
|
|||
return f"{x}_{y}_{z}"
|
||||
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_multiple_optional_params_function():
|
||||
tool = multiple_optional_params_function
|
||||
|
||||
input_data: dict[str,Any] = {}
|
||||
input_data: dict[str, Any] = {}
|
||||
output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data))
|
||||
assert output == "42_hello_no_z"
|
||||
|
||||
|
|
|
|||
|
|
@ -4,8 +4,9 @@ import pytest
|
|||
from pydantic import BaseModel
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from agents import Agent, AgentOutputSchema, ModelBehaviorError, Runner, UserError, _utils
|
||||
from agents import Agent, AgentOutputSchema, ModelBehaviorError, Runner, UserError
|
||||
from agents.agent_output import _WRAPPER_DICT_KEY
|
||||
from agents.util import _json
|
||||
|
||||
|
||||
def test_plain_text_output():
|
||||
|
|
@ -77,7 +78,7 @@ def test_bad_json_raises_error(mocker):
|
|||
output_schema = Runner._get_output_schema(agent)
|
||||
assert output_schema, "Should have an output tool config with a structured output type"
|
||||
|
||||
mock_validate_json = mocker.patch.object(_utils, "validate_json")
|
||||
mock_validate_json = mocker.patch.object(_json, "validate_json")
|
||||
mock_validate_json.return_value = ["foo"]
|
||||
|
||||
with pytest.raises(ModelBehaviorError):
|
||||
|
|
@ -111,3 +112,4 @@ def test_setting_strict_false_works():
|
|||
output_wrapper = AgentOutputSchema(output_type=Foo, strict_json_schema=False)
|
||||
assert not output_wrapper.strict_json_schema
|
||||
assert output_wrapper.json_schema() == Foo.model_json_schema()
|
||||
assert output_wrapper.json_schema() == Foo.model_json_schema()
|
||||
|
|
|
|||
Loading…
Reference in a new issue