From 09d70c074daf210fbb1a3acd31bc2ac048f9ba26 Mon Sep 17 00:00:00 2001 From: Rohan Mehta Date: Sun, 16 Mar 2025 18:48:45 -0400 Subject: [PATCH] utils directory --- examples/basic/hello_world_jupyter.py | 2 +- src/agents/_run_impl.py | 24 +++++------ src/agents/_utils.py | 61 --------------------------- src/agents/agent.py | 6 +-- src/agents/agent_output.py | 8 ++-- src/agents/guardrail.py | 2 +- src/agents/handoffs.py | 8 ++-- src/agents/run.py | 20 ++++----- src/agents/tool.py | 7 +-- src/agents/util/__init__.py | 0 src/agents/util/_coro.py | 2 + src/agents/util/_error_tracing.py | 16 +++++++ src/agents/util/_json.py | 31 ++++++++++++++ src/agents/util/_transforms.py | 11 +++++ src/agents/util/_types.py | 7 +++ tests/test_function_tool_decorator.py | 3 +- tests/test_output_tool.py | 6 ++- 17 files changed, 111 insertions(+), 103 deletions(-) delete mode 100644 src/agents/_utils.py create mode 100644 src/agents/util/__init__.py create mode 100644 src/agents/util/_coro.py create mode 100644 src/agents/util/_error_tracing.py create mode 100644 src/agents/util/_json.py create mode 100644 src/agents/util/_transforms.py create mode 100644 src/agents/util/_types.py diff --git a/examples/basic/hello_world_jupyter.py b/examples/basic/hello_world_jupyter.py index bb8f14c..c929a7c 100644 --- a/examples/basic/hello_world_jupyter.py +++ b/examples/basic/hello_world_jupyter.py @@ -3,7 +3,7 @@ from agents import Agent, Runner agent = Agent(name="Assistant", instructions="You are a helpful assistant") # Intended for Jupyter notebooks where there's an existing event loop -result = await Runner.run(agent, "Write a haiku about recursion in programming.") # type: ignore[top-level-await] # noqa: F704 +result = await Runner.run(agent, "Write a haiku about recursion in programming.") # type: ignore[top-level-await] # noqa: F704 print(result.final_output) # Code within code loops, diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 2c84950..c0c0ebd 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -25,7 +25,6 @@ from openai.types.responses.response_computer_tool_call import ( from openai.types.responses.response_input_param import ComputerCallOutput from openai.types.responses.response_reasoning_item import ResponseReasoningItem -from . import _utils from .agent import Agent from .agent_output import AgentOutputSchema from .computer import AsyncComputer, Computer @@ -59,6 +58,7 @@ from .tracing import ( handoff_span, trace, ) +from .util import _coro, _error_tracing if TYPE_CHECKING: from .run import RunConfig @@ -293,7 +293,7 @@ class RunImpl: elif isinstance(output, ResponseComputerToolCall): items.append(ToolCallItem(raw_item=output, agent=agent)) if not computer_tool: - _utils.attach_error_to_current_span( + _error_tracing.attach_error_to_current_span( SpanError( message="Computer tool not found", data={}, @@ -324,7 +324,7 @@ class RunImpl: # Regular function tool call else: if output.name not in function_map: - _utils.attach_error_to_current_span( + _error_tracing.attach_error_to_current_span( SpanError( message="Tool not found", data={"tool_name": output.name}, @@ -368,7 +368,7 @@ class RunImpl: ( agent.hooks.on_tool_start(context_wrapper, agent, func_tool) if agent.hooks - else _utils.noop_coroutine() + else _coro.noop_coroutine() ), func_tool.on_invoke_tool(context_wrapper, tool_call.arguments), ) @@ -378,11 +378,11 @@ class RunImpl: ( agent.hooks.on_tool_end(context_wrapper, agent, func_tool, result) if agent.hooks - else _utils.noop_coroutine() + else _coro.noop_coroutine() ), ) except Exception as e: - _utils.attach_error_to_current_span( + _error_tracing.attach_error_to_current_span( SpanError( message="Error running tool", data={"tool_name": func_tool.name, "error": str(e)}, @@ -502,7 +502,7 @@ class RunImpl: source=agent, ) if agent.hooks - else _utils.noop_coroutine() + else _coro.noop_coroutine() ), ) @@ -520,7 +520,7 @@ class RunImpl: new_items=tuple(new_step_items), ) if not callable(input_filter): - _utils.attach_error_to_span( + _error_tracing.attach_error_to_span( span_handoff, SpanError( message="Invalid input filter", @@ -530,7 +530,7 @@ class RunImpl: raise UserError(f"Invalid input filter: {input_filter}") filtered = input_filter(handoff_input_data) if not isinstance(filtered, HandoffInputData): - _utils.attach_error_to_span( + _error_tracing.attach_error_to_span( span_handoff, SpanError( message="Invalid input filter result", @@ -591,7 +591,7 @@ class RunImpl: hooks.on_agent_end(context_wrapper, agent, final_output), agent.hooks.on_end(context_wrapper, agent, final_output) if agent.hooks - else _utils.noop_coroutine(), + else _coro.noop_coroutine(), ) @classmethod @@ -706,7 +706,7 @@ class ComputerAction: ( agent.hooks.on_tool_start(context_wrapper, agent, action.computer_tool) if agent.hooks - else _utils.noop_coroutine() + else _coro.noop_coroutine() ), output_func, ) @@ -716,7 +716,7 @@ class ComputerAction: ( agent.hooks.on_tool_end(context_wrapper, agent, action.computer_tool, output) if agent.hooks - else _utils.noop_coroutine() + else _coro.noop_coroutine() ), ) diff --git a/src/agents/_utils.py b/src/agents/_utils.py deleted file mode 100644 index 2a0293a..0000000 --- a/src/agents/_utils.py +++ /dev/null @@ -1,61 +0,0 @@ -from __future__ import annotations - -import re -from collections.abc import Awaitable -from typing import Any, Literal, Union - -from pydantic import TypeAdapter, ValidationError -from typing_extensions import TypeVar - -from .exceptions import ModelBehaviorError -from .logger import logger -from .tracing import Span, SpanError, get_current_span - -T = TypeVar("T") - -MaybeAwaitable = Union[Awaitable[T], T] - - -def transform_string_function_style(name: str) -> str: - # Replace spaces with underscores - name = name.replace(" ", "_") - - # Replace non-alphanumeric characters with underscores - name = re.sub(r"[^a-zA-Z0-9]", "_", name) - - return name.lower() - - -def validate_json(json_str: str, type_adapter: TypeAdapter[T], partial: bool) -> T: - partial_setting: bool | Literal["off", "on", "trailing-strings"] = ( - "trailing-strings" if partial else False - ) - try: - validated = type_adapter.validate_json(json_str, experimental_allow_partial=partial_setting) - return validated - except ValidationError as e: - attach_error_to_current_span( - SpanError( - message="Invalid JSON provided", - data={}, - ) - ) - raise ModelBehaviorError( - f"Invalid JSON when parsing {json_str} for {type_adapter}; {e}" - ) from e - - -def attach_error_to_span(span: Span[Any], error: SpanError) -> None: - span.set_error(error) - - -def attach_error_to_current_span(error: SpanError) -> None: - span = get_current_span() - if span: - attach_error_to_span(span, error) - else: - logger.warning(f"No span to add error {error} to") - - -async def noop_coroutine() -> None: - pass diff --git a/src/agents/agent.py b/src/agents/agent.py index 61c0a89..84d0ae9 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -6,8 +6,6 @@ from collections.abc import Awaitable from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Callable, Generic, cast -from . import _utils -from ._utils import MaybeAwaitable from .guardrail import InputGuardrail, OutputGuardrail from .handoffs import Handoff from .items import ItemHelpers @@ -16,6 +14,8 @@ from .model_settings import ModelSettings from .models.interface import Model from .run_context import RunContextWrapper, TContext from .tool import Tool, function_tool +from .util import _transforms +from .util._types import MaybeAwaitable if TYPE_CHECKING: from .lifecycle import AgentHooks @@ -126,7 +126,7 @@ class Agent(Generic[TContext]): """ @function_tool( - name_override=tool_name or _utils.transform_string_function_style(self.name), + name_override=tool_name or _transforms.transform_string_function_style(self.name), description_override=tool_description or "", ) async def run_agent(context: RunContextWrapper, input: str) -> str: diff --git a/src/agents/agent_output.py b/src/agents/agent_output.py index 0c28800..3262c57 100644 --- a/src/agents/agent_output.py +++ b/src/agents/agent_output.py @@ -4,10 +4,10 @@ from typing import Any from pydantic import BaseModel, TypeAdapter from typing_extensions import TypedDict, get_args, get_origin -from . import _utils from .exceptions import ModelBehaviorError, UserError from .strict_schema import ensure_strict_json_schema from .tracing import SpanError +from .util import _error_tracing, _json _WRAPPER_DICT_KEY = "response" @@ -87,10 +87,10 @@ class AgentOutputSchema: """Validate a JSON string against the output type. Returns the validated object, or raises a `ModelBehaviorError` if the JSON is invalid. """ - validated = _utils.validate_json(json_str, self._type_adapter, partial) + validated = _json.validate_json(json_str, self._type_adapter, partial) if self._is_wrapped: if not isinstance(validated, dict): - _utils.attach_error_to_current_span( + _error_tracing.attach_error_to_current_span( SpanError( message="Invalid JSON", data={"details": f"Expected a dict, got {type(validated)}"}, @@ -101,7 +101,7 @@ class AgentOutputSchema: ) if _WRAPPER_DICT_KEY not in validated: - _utils.attach_error_to_current_span( + _error_tracing.attach_error_to_current_span( SpanError( message="Invalid JSON", data={"details": f"Could not find key {_WRAPPER_DICT_KEY} in JSON"}, diff --git a/src/agents/guardrail.py b/src/agents/guardrail.py index 5bebcd6..a96f0f7 100644 --- a/src/agents/guardrail.py +++ b/src/agents/guardrail.py @@ -7,10 +7,10 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, Union, overload from typing_extensions import TypeVar -from ._utils import MaybeAwaitable from .exceptions import UserError from .items import TResponseInputItem from .run_context import RunContextWrapper, TContext +from .util._types import MaybeAwaitable if TYPE_CHECKING: from .agent import Agent diff --git a/src/agents/handoffs.py b/src/agents/handoffs.py index ac15740..686191f 100644 --- a/src/agents/handoffs.py +++ b/src/agents/handoffs.py @@ -8,12 +8,12 @@ from typing import TYPE_CHECKING, Any, Callable, Generic, cast, overload from pydantic import TypeAdapter from typing_extensions import TypeAlias, TypeVar -from . import _utils from .exceptions import ModelBehaviorError, UserError from .items import RunItem, TResponseInputItem from .run_context import RunContextWrapper, TContext from .strict_schema import ensure_strict_json_schema from .tracing.spans import SpanError +from .util import _error_tracing, _json, _transforms if TYPE_CHECKING: from .agent import Agent @@ -104,7 +104,7 @@ class Handoff(Generic[TContext]): @classmethod def default_tool_name(cls, agent: Agent[Any]) -> str: - return _utils.transform_string_function_style(f"transfer_to_{agent.name}") + return _transforms.transform_string_function_style(f"transfer_to_{agent.name}") @classmethod def default_tool_description(cls, agent: Agent[Any]) -> str: @@ -192,7 +192,7 @@ def handoff( ) -> Agent[Any]: if input_type is not None and type_adapter is not None: if input_json is None: - _utils.attach_error_to_current_span( + _error_tracing.attach_error_to_current_span( SpanError( message="Handoff function expected non-null input, but got None", data={"details": "input_json is None"}, @@ -200,7 +200,7 @@ def handoff( ) raise ModelBehaviorError("Handoff function expected non-null input, but got None") - validated_input = _utils.validate_json( + validated_input = _json.validate_json( json_str=input_json, type_adapter=type_adapter, partial=False, diff --git a/src/agents/run.py b/src/agents/run.py index dfff7e3..934400f 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -7,7 +7,6 @@ from typing import Any, cast from openai.types.responses import ResponseCompletedEvent -from . import Model, _utils from ._run_impl import ( NextStepFinalOutput, NextStepHandoff, @@ -33,7 +32,7 @@ from .items import ItemHelpers, ModelResponse, RunItem, TResponseInputItem from .lifecycle import RunHooks from .logger import logger from .model_settings import ModelSettings -from .models.interface import ModelProvider +from .models.interface import Model, ModelProvider from .models.openai_provider import OpenAIProvider from .result import RunResult, RunResultStreaming from .run_context import RunContextWrapper, TContext @@ -41,6 +40,7 @@ from .stream_events import AgentUpdatedStreamEvent, RawResponsesStreamEvent from .tracing import Span, SpanError, agent_span, get_current_trace, trace from .tracing.span_data import AgentSpanData from .usage import Usage +from .util import _coro, _error_tracing DEFAULT_MAX_TURNS = 10 @@ -193,7 +193,7 @@ class Runner: current_turn += 1 if current_turn > max_turns: - _utils.attach_error_to_span( + _error_tracing.attach_error_to_span( current_span, SpanError( message="Max turns exceeded", @@ -447,7 +447,7 @@ class Runner: for done in asyncio.as_completed(guardrail_tasks): result = await done if result.output.tripwire_triggered: - _utils.attach_error_to_span( + _error_tracing.attach_error_to_span( parent_span, SpanError( message="Guardrail tripwire triggered", @@ -511,7 +511,7 @@ class Runner: streamed_result.current_turn = current_turn if current_turn > max_turns: - _utils.attach_error_to_span( + _error_tracing.attach_error_to_span( current_span, SpanError( message="Max turns exceeded", @@ -583,7 +583,7 @@ class Runner: pass except Exception as e: if current_span: - _utils.attach_error_to_span( + _error_tracing.attach_error_to_span( current_span, SpanError( message="Error in agent run", @@ -615,7 +615,7 @@ class Runner: ( agent.hooks.on_start(context_wrapper, agent) if agent.hooks - else _utils.noop_coroutine() + else _coro.noop_coroutine() ), ) @@ -705,7 +705,7 @@ class Runner: ( agent.hooks.on_start(context_wrapper, agent) if agent.hooks - else _utils.noop_coroutine() + else _coro.noop_coroutine() ), ) @@ -796,7 +796,7 @@ class Runner: # Cancel all guardrail tasks if a tripwire is triggered. for t in guardrail_tasks: t.cancel() - _utils.attach_error_to_current_span( + _error_tracing.attach_error_to_current_span( SpanError( message="Guardrail tripwire triggered", data={"guardrail": result.guardrail.get_name()}, @@ -834,7 +834,7 @@ class Runner: # Cancel all guardrail tasks if a tripwire is triggered. for t in guardrail_tasks: t.cancel() - _utils.attach_error_to_current_span( + _error_tracing.attach_error_to_current_span( SpanError( message="Guardrail tripwire triggered", data={"guardrail": result.guardrail.get_name()}, diff --git a/src/agents/tool.py b/src/agents/tool.py index cbe8794..0baf2c0 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -11,14 +11,15 @@ from openai.types.responses.web_search_tool_param import UserLocation from pydantic import ValidationError from typing_extensions import Concatenate, ParamSpec -from . import _debug, _utils -from ._utils import MaybeAwaitable +from . import _debug from .computer import AsyncComputer, Computer from .exceptions import ModelBehaviorError from .function_schema import DocstringStyle, function_schema from .logger import logger from .run_context import RunContextWrapper from .tracing import SpanError +from .util import _error_tracing +from .util._types import MaybeAwaitable ToolParams = ParamSpec("ToolParams") @@ -263,7 +264,7 @@ def function_tool( if inspect.isawaitable(result): return await result - _utils.attach_error_to_current_span( + _error_tracing.attach_error_to_current_span( SpanError( message="Error running tool (non-fatal)", data={ diff --git a/src/agents/util/__init__.py b/src/agents/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/agents/util/_coro.py b/src/agents/util/_coro.py new file mode 100644 index 0000000..647ab86 --- /dev/null +++ b/src/agents/util/_coro.py @@ -0,0 +1,2 @@ +async def noop_coroutine() -> None: + pass diff --git a/src/agents/util/_error_tracing.py b/src/agents/util/_error_tracing.py new file mode 100644 index 0000000..09dbb1d --- /dev/null +++ b/src/agents/util/_error_tracing.py @@ -0,0 +1,16 @@ +from typing import Any + +from ..logger import logger +from ..tracing import Span, SpanError, get_current_span + + +def attach_error_to_span(span: Span[Any], error: SpanError) -> None: + span.set_error(error) + + +def attach_error_to_current_span(error: SpanError) -> None: + span = get_current_span() + if span: + attach_error_to_span(span, error) + else: + logger.warning(f"No span to add error {error} to") diff --git a/src/agents/util/_json.py b/src/agents/util/_json.py new file mode 100644 index 0000000..1e081f6 --- /dev/null +++ b/src/agents/util/_json.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +from typing import Literal + +from pydantic import TypeAdapter, ValidationError +from typing_extensions import TypeVar + +from ..exceptions import ModelBehaviorError +from ..tracing import SpanError +from ._error_tracing import attach_error_to_current_span + +T = TypeVar("T") + + +def validate_json(json_str: str, type_adapter: TypeAdapter[T], partial: bool) -> T: + partial_setting: bool | Literal["off", "on", "trailing-strings"] = ( + "trailing-strings" if partial else False + ) + try: + validated = type_adapter.validate_json(json_str, experimental_allow_partial=partial_setting) + return validated + except ValidationError as e: + attach_error_to_current_span( + SpanError( + message="Invalid JSON provided", + data={}, + ) + ) + raise ModelBehaviorError( + f"Invalid JSON when parsing {json_str} for {type_adapter}; {e}" + ) from e diff --git a/src/agents/util/_transforms.py b/src/agents/util/_transforms.py new file mode 100644 index 0000000..b303074 --- /dev/null +++ b/src/agents/util/_transforms.py @@ -0,0 +1,11 @@ +import re + + +def transform_string_function_style(name: str) -> str: + # Replace spaces with underscores + name = name.replace(" ", "_") + + # Replace non-alphanumeric characters with underscores + name = re.sub(r"[^a-zA-Z0-9]", "_", name) + + return name.lower() diff --git a/src/agents/util/_types.py b/src/agents/util/_types.py new file mode 100644 index 0000000..8571a69 --- /dev/null +++ b/src/agents/util/_types.py @@ -0,0 +1,7 @@ +from collections.abc import Awaitable +from typing import Union + +from typing_extensions import TypeVar + +T = TypeVar("T") +MaybeAwaitable = Union[Awaitable[T], T] diff --git a/tests/test_function_tool_decorator.py b/tests/test_function_tool_decorator.py index b581660..f146ec7 100644 --- a/tests/test_function_tool_decorator.py +++ b/tests/test_function_tool_decorator.py @@ -175,12 +175,11 @@ def multiple_optional_params_function( return f"{x}_{y}_{z}" - @pytest.mark.asyncio async def test_multiple_optional_params_function(): tool = multiple_optional_params_function - input_data: dict[str,Any] = {} + input_data: dict[str, Any] = {} output = await tool.on_invoke_tool(ctx_wrapper(), json.dumps(input_data)) assert output == "42_hello_no_z" diff --git a/tests/test_output_tool.py b/tests/test_output_tool.py index 31ac984..86c4b3b 100644 --- a/tests/test_output_tool.py +++ b/tests/test_output_tool.py @@ -4,8 +4,9 @@ import pytest from pydantic import BaseModel from typing_extensions import TypedDict -from agents import Agent, AgentOutputSchema, ModelBehaviorError, Runner, UserError, _utils +from agents import Agent, AgentOutputSchema, ModelBehaviorError, Runner, UserError from agents.agent_output import _WRAPPER_DICT_KEY +from agents.util import _json def test_plain_text_output(): @@ -77,7 +78,7 @@ def test_bad_json_raises_error(mocker): output_schema = Runner._get_output_schema(agent) assert output_schema, "Should have an output tool config with a structured output type" - mock_validate_json = mocker.patch.object(_utils, "validate_json") + mock_validate_json = mocker.patch.object(_json, "validate_json") mock_validate_json.return_value = ["foo"] with pytest.raises(ModelBehaviorError): @@ -111,3 +112,4 @@ def test_setting_strict_false_works(): output_wrapper = AgentOutputSchema(output_type=Foo, strict_json_schema=False) assert not output_wrapper.strict_json_schema assert output_wrapper.json_schema() == Foo.model_json_schema() + assert output_wrapper.json_schema() == Foo.model_json_schema()