Add support for local shell, image generator, code interpreter tools (#732)

This commit is contained in:
Rohan Mehta 2025-05-21 15:26:22 -04:00 committed by GitHub
parent 9fa5c39d69
commit 079764f0ab
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 334 additions and 20 deletions

View file

@ -0,0 +1,34 @@
import asyncio
from agents import Agent, CodeInterpreterTool, Runner, trace
async def main():
agent = Agent(
name="Code interpreter",
instructions="You love doing math.",
tools=[
CodeInterpreterTool(
tool_config={"type": "code_interpreter", "container": {"type": "auto"}},
)
],
)
with trace("Code interpreter example"):
print("Solving math problem...")
result = Runner.run_streamed(agent, "What is the square root of273 * 312821 plus 1782?")
async for event in result.stream_events():
if (
event.type == "run_item_stream_event"
and event.item.type == "tool_call_item"
and event.item.raw_item.type == "code_interpreter_call"
):
print(f"Code interpreter code:\n```\n{event.item.raw_item.code}\n```\n")
elif event.type == "run_item_stream_event":
print(f"Other event: {event.item.type}")
print(f"Final output: {result.final_output}")
if __name__ == "__main__":
asyncio.run(main())

View file

@ -0,0 +1,54 @@
import asyncio
import base64
import os
import subprocess
import sys
import tempfile
from agents import Agent, ImageGenerationTool, Runner, trace
def open_file(path: str) -> None:
if sys.platform.startswith("darwin"):
subprocess.run(["open", path], check=False) # macOS
elif os.name == "nt": # Windows
os.astartfile(path) # type: ignore
elif os.name == "posix":
subprocess.run(["xdg-open", path], check=False) # Linux/Unix
else:
print(f"Don't know how to open files on this platform: {sys.platform}")
async def main():
agent = Agent(
name="Image generator",
instructions="You are a helpful agent.",
tools=[
ImageGenerationTool(
tool_config={"type": "image_generation", "quality": "low"},
)
],
)
with trace("Image generation example"):
print("Generating image, this may take a while...")
result = await Runner.run(
agent, "Create an image of a frog eating a pizza, comic book style."
)
print(result.final_output)
for item in result.new_items:
if (
item.type == "tool_call_item"
and item.raw_item.type == "image_generation_call"
and (img_result := item.raw_item.result)
):
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
tmp.write(base64.b64decode(img_result))
temp_path = tmp.name
# Open the image
open_file(temp_path)
if __name__ == "__main__":
asyncio.run(main())

View file

@ -54,11 +54,16 @@ from .stream_events import (
StreamEvent,
)
from .tool import (
CodeInterpreterTool,
ComputerTool,
FileSearchTool,
FunctionTool,
FunctionToolResult,
HostedMCPTool,
ImageGenerationTool,
LocalShellCommandRequest,
LocalShellExecutor,
LocalShellTool,
MCPToolApprovalFunction,
MCPToolApprovalFunctionResult,
MCPToolApprovalRequest,
@ -210,6 +215,11 @@ __all__ = [
"FunctionToolResult",
"ComputerTool",
"FileSearchTool",
"CodeInterpreterTool",
"ImageGenerationTool",
"LocalShellCommandRequest",
"LocalShellExecutor",
"LocalShellTool",
"Tool",
"WebSearchTool",
"HostedMCPTool",

View file

@ -14,6 +14,9 @@ from openai.types.responses import (
ResponseFunctionWebSearch,
ResponseOutputMessage,
)
from openai.types.responses.response_code_interpreter_tool_call import (
ResponseCodeInterpreterToolCall,
)
from openai.types.responses.response_computer_tool_call import (
ActionClick,
ActionDoubleClick,
@ -26,7 +29,12 @@ from openai.types.responses.response_computer_tool_call import (
ActionWait,
)
from openai.types.responses.response_input_param import ComputerCallOutput, McpApprovalResponse
from openai.types.responses.response_output_item import McpApprovalRequest, McpCall, McpListTools
from openai.types.responses.response_output_item import (
ImageGenerationCall,
LocalShellCall,
McpApprovalRequest,
McpListTools,
)
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
from .agent import Agent, ToolsToFinalOutputResult
@ -61,6 +69,8 @@ from .tool import (
FunctionTool,
FunctionToolResult,
HostedMCPTool,
LocalShellCommandRequest,
LocalShellTool,
MCPToolApprovalRequest,
Tool,
)
@ -129,12 +139,19 @@ class ToolRunMCPApprovalRequest:
mcp_tool: HostedMCPTool
@dataclass
class ToolRunLocalShellCall:
tool_call: LocalShellCall
local_shell_tool: LocalShellTool
@dataclass
class ProcessedResponse:
new_items: list[RunItem]
handoffs: list[ToolRunHandoff]
functions: list[ToolRunFunction]
computer_actions: list[ToolRunComputerAction]
local_shell_calls: list[ToolRunLocalShellCall]
tools_used: list[str] # Names of all tools used, including hosted tools
mcp_approval_requests: list[ToolRunMCPApprovalRequest] # Only requests with callbacks
@ -146,6 +163,7 @@ class ProcessedResponse:
self.handoffs,
self.functions,
self.computer_actions,
self.local_shell_calls,
self.mcp_approval_requests,
]
)
@ -371,11 +389,15 @@ class RunImpl:
run_handoffs = []
functions = []
computer_actions = []
local_shell_calls = []
mcp_approval_requests = []
tools_used: list[str] = []
handoff_map = {handoff.tool_name: handoff for handoff in handoffs}
function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)}
computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None)
local_shell_tool = next(
(tool for tool in all_tools if isinstance(tool, LocalShellTool)), None
)
hosted_mcp_server_map = {
tool.tool_config["server_label"]: tool
for tool in all_tools
@ -434,9 +456,29 @@ class RunImpl:
)
elif isinstance(output, McpListTools):
items.append(MCPListToolsItem(raw_item=output, agent=agent))
elif isinstance(output, McpCall):
elif isinstance(output, ImageGenerationCall):
items.append(ToolCallItem(raw_item=output, agent=agent))
tools_used.append(output.name)
tools_used.append("image_generation")
elif isinstance(output, ResponseCodeInterpreterToolCall):
items.append(ToolCallItem(raw_item=output, agent=agent))
tools_used.append("code_interpreter")
elif isinstance(output, LocalShellCall):
items.append(ToolCallItem(raw_item=output, agent=agent))
tools_used.append("local_shell")
if not local_shell_tool:
_error_tracing.attach_error_to_current_span(
SpanError(
message="Local shell tool not found",
data={},
)
)
raise ModelBehaviorError(
"Model produced local shell call without a local shell tool."
)
local_shell_calls.append(
ToolRunLocalShellCall(tool_call=output, local_shell_tool=local_shell_tool)
)
elif not isinstance(output, ResponseFunctionToolCall):
logger.warning(f"Unexpected output type, ignoring: {type(output)}")
continue
@ -478,6 +520,7 @@ class RunImpl:
handoffs=run_handoffs,
functions=functions,
computer_actions=computer_actions,
local_shell_calls=local_shell_calls,
tools_used=tools_used,
mcp_approval_requests=mcp_approval_requests,
)
@ -552,6 +595,30 @@ class RunImpl:
for tool_run, result in zip(tool_runs, results)
]
@classmethod
async def execute_local_shell_calls(
cls,
*,
agent: Agent[TContext],
calls: list[ToolRunLocalShellCall],
context_wrapper: RunContextWrapper[TContext],
hooks: RunHooks[TContext],
config: RunConfig,
) -> list[RunItem]:
results: list[RunItem] = []
# Need to run these serially, because each call can affect the local shell state
for call in calls:
results.append(
await LocalShellAction.execute(
agent=agent,
call=call,
hooks=hooks,
context_wrapper=context_wrapper,
config=config,
)
)
return results
@classmethod
async def execute_computer_actions(
cls,
@ -1021,3 +1088,54 @@ class ComputerAction:
await computer.wait()
return await computer.screenshot()
class LocalShellAction:
@classmethod
async def execute(
cls,
*,
agent: Agent[TContext],
call: ToolRunLocalShellCall,
hooks: RunHooks[TContext],
context_wrapper: RunContextWrapper[TContext],
config: RunConfig,
) -> RunItem:
await asyncio.gather(
hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool),
(
agent.hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool)
if agent.hooks
else _coro.noop_coroutine()
),
)
request = LocalShellCommandRequest(
ctx_wrapper=context_wrapper,
data=call.tool_call,
)
output = call.local_shell_tool.executor(request)
if inspect.isawaitable(output):
result = await output
else:
result = output
await asyncio.gather(
hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result),
(
agent.hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result)
if agent.hooks
else _coro.noop_coroutine()
),
)
return ToolCallOutputItem(
agent=agent,
output=output,
raw_item={
"type": "local_shell_call_output",
"id": call.tool_call.call_id,
"output": result,
# "id": "out" + call.tool_call.id, # TODO remove this, it should be optional
},
)

View file

@ -18,12 +18,22 @@ from openai.types.responses import (
ResponseOutputText,
ResponseStreamEvent,
)
from openai.types.responses.response_code_interpreter_tool_call import (
ResponseCodeInterpreterToolCall,
)
from openai.types.responses.response_input_item_param import (
ComputerCallOutput,
FunctionCallOutput,
LocalShellCallOutput,
McpApprovalResponse,
)
from openai.types.responses.response_output_item import McpApprovalRequest, McpCall, McpListTools
from openai.types.responses.response_output_item import (
ImageGenerationCall,
LocalShellCall,
McpApprovalRequest,
McpCall,
McpListTools,
)
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
from pydantic import BaseModel
from typing_extensions import TypeAlias
@ -114,6 +124,9 @@ ToolCallItemTypes: TypeAlias = Union[
ResponseFileSearchToolCall,
ResponseFunctionWebSearch,
McpCall,
ResponseCodeInterpreterToolCall,
ImageGenerationCall,
LocalShellCall,
]
"""A type that represents a tool call item."""
@ -129,10 +142,12 @@ class ToolCallItem(RunItemBase[ToolCallItemTypes]):
@dataclass
class ToolCallOutputItem(RunItemBase[Union[FunctionCallOutput, ComputerCallOutput]]):
class ToolCallOutputItem(
RunItemBase[Union[FunctionCallOutput, ComputerCallOutput, LocalShellCallOutput]]
):
"""Represents the output of a tool call."""
raw_item: FunctionCallOutput | ComputerCallOutput
raw_item: FunctionCallOutput | ComputerCallOutput | LocalShellCallOutput
"""The raw item from the model."""
output: Any

View file

@ -24,7 +24,17 @@ from ..exceptions import UserError
from ..handoffs import Handoff
from ..items import ItemHelpers, ModelResponse, TResponseInputItem
from ..logger import logger
from ..tool import ComputerTool, FileSearchTool, FunctionTool, HostedMCPTool, Tool, WebSearchTool
from ..tool import (
CodeInterpreterTool,
ComputerTool,
FileSearchTool,
FunctionTool,
HostedMCPTool,
ImageGenerationTool,
LocalShellTool,
Tool,
WebSearchTool,
)
from ..tracing import SpanError, response_span
from ..usage import Usage
from ..version import __version__
@ -295,6 +305,18 @@ class Converter:
return {
"type": "computer_use_preview",
}
elif tool_choice == "image_generation":
return {
"type": "image_generation",
}
elif tool_choice == "code_interpreter":
return {
"type": "code_interpreter",
}
elif tool_choice == "mcp":
return {
"type": "mcp",
}
else:
return {
"type": "function",
@ -386,6 +408,17 @@ class Converter:
elif isinstance(tool, HostedMCPTool):
converted_tool = tool.tool_config
includes = None
elif isinstance(tool, ImageGenerationTool):
converted_tool = tool.tool_config
includes = None
elif isinstance(tool, CodeInterpreterTool):
converted_tool = tool.tool_config
includes = None
elif isinstance(tool, LocalShellTool):
converted_tool = {
"type": "local_shell",
}
includes = None
else:
raise UserError(f"Unknown tool type: {type(tool)}, tool")

View file

@ -7,8 +7,8 @@ from dataclasses import dataclass
from typing import Any, Callable, Literal, Union, overload
from openai.types.responses.file_search_tool_param import Filters, RankingOptions
from openai.types.responses.response_output_item import McpApprovalRequest
from openai.types.responses.tool_param import Mcp
from openai.types.responses.response_output_item import LocalShellCall, McpApprovalRequest
from openai.types.responses.tool_param import CodeInterpreter, ImageGeneration, Mcp
from openai.types.responses.web_search_tool_param import UserLocation
from pydantic import ValidationError
from typing_extensions import Concatenate, NotRequired, ParamSpec, TypedDict
@ -180,7 +180,67 @@ class HostedMCPTool:
return "hosted_mcp"
Tool = Union[FunctionTool, FileSearchTool, WebSearchTool, ComputerTool, HostedMCPTool]
@dataclass
class CodeInterpreterTool:
"""A tool that allows the LLM to execute code in a sandboxed environment."""
tool_config: CodeInterpreter
"""The tool config, which includes the container and other settings."""
@property
def name(self):
return "code_interpreter"
@dataclass
class ImageGenerationTool:
"""A tool that allows the LLM to generate images."""
tool_config: ImageGeneration
"""The tool config, which image generation settings."""
@property
def name(self):
return "image_generation"
@dataclass
class LocalShellCommandRequest:
"""A request to execute a command on a shell."""
ctx_wrapper: RunContextWrapper[Any]
"""The run context."""
data: LocalShellCall
"""The data from the local shell tool call."""
LocalShellExecutor = Callable[[LocalShellCommandRequest], MaybeAwaitable[str]]
"""A function that executes a command on a shell."""
@dataclass
class LocalShellTool:
"""A tool that allows the LLM to execute commands on a shell."""
executor: LocalShellExecutor
"""A function that executes a command on a shell."""
@property
def name(self):
return "local_shell"
Tool = Union[
FunctionTool,
FileSearchTool,
WebSearchTool,
ComputerTool,
HostedMCPTool,
LocalShellTool,
ImageGenerationTool,
CodeInterpreterTool,
]
"""A tool that can be used in an agent."""
@ -358,13 +418,3 @@ def function_tool(
return _create_function_tool(real_func)
return decorator
return decorator
return decorator
return decorator
return decorator
return decorator
return decorator
return decorator
return decorator
return decorator
return decorator