Add support for local shell, image generator, code interpreter tools (#732)
This commit is contained in:
parent
9fa5c39d69
commit
079764f0ab
7 changed files with 334 additions and 20 deletions
34
examples/tools/code_interpreter.py
Normal file
34
examples/tools/code_interpreter.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
import asyncio
|
||||
|
||||
from agents import Agent, CodeInterpreterTool, Runner, trace
|
||||
|
||||
|
||||
async def main():
|
||||
agent = Agent(
|
||||
name="Code interpreter",
|
||||
instructions="You love doing math.",
|
||||
tools=[
|
||||
CodeInterpreterTool(
|
||||
tool_config={"type": "code_interpreter", "container": {"type": "auto"}},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
with trace("Code interpreter example"):
|
||||
print("Solving math problem...")
|
||||
result = Runner.run_streamed(agent, "What is the square root of273 * 312821 plus 1782?")
|
||||
async for event in result.stream_events():
|
||||
if (
|
||||
event.type == "run_item_stream_event"
|
||||
and event.item.type == "tool_call_item"
|
||||
and event.item.raw_item.type == "code_interpreter_call"
|
||||
):
|
||||
print(f"Code interpreter code:\n```\n{event.item.raw_item.code}\n```\n")
|
||||
elif event.type == "run_item_stream_event":
|
||||
print(f"Other event: {event.item.type}")
|
||||
|
||||
print(f"Final output: {result.final_output}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
54
examples/tools/image_generator.py
Normal file
54
examples/tools/image_generator.py
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
import asyncio
|
||||
import base64
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
|
||||
from agents import Agent, ImageGenerationTool, Runner, trace
|
||||
|
||||
|
||||
def open_file(path: str) -> None:
|
||||
if sys.platform.startswith("darwin"):
|
||||
subprocess.run(["open", path], check=False) # macOS
|
||||
elif os.name == "nt": # Windows
|
||||
os.astartfile(path) # type: ignore
|
||||
elif os.name == "posix":
|
||||
subprocess.run(["xdg-open", path], check=False) # Linux/Unix
|
||||
else:
|
||||
print(f"Don't know how to open files on this platform: {sys.platform}")
|
||||
|
||||
|
||||
async def main():
|
||||
agent = Agent(
|
||||
name="Image generator",
|
||||
instructions="You are a helpful agent.",
|
||||
tools=[
|
||||
ImageGenerationTool(
|
||||
tool_config={"type": "image_generation", "quality": "low"},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
with trace("Image generation example"):
|
||||
print("Generating image, this may take a while...")
|
||||
result = await Runner.run(
|
||||
agent, "Create an image of a frog eating a pizza, comic book style."
|
||||
)
|
||||
print(result.final_output)
|
||||
for item in result.new_items:
|
||||
if (
|
||||
item.type == "tool_call_item"
|
||||
and item.raw_item.type == "image_generation_call"
|
||||
and (img_result := item.raw_item.result)
|
||||
):
|
||||
with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
|
||||
tmp.write(base64.b64decode(img_result))
|
||||
temp_path = tmp.name
|
||||
|
||||
# Open the image
|
||||
open_file(temp_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
|
|
@ -54,11 +54,16 @@ from .stream_events import (
|
|||
StreamEvent,
|
||||
)
|
||||
from .tool import (
|
||||
CodeInterpreterTool,
|
||||
ComputerTool,
|
||||
FileSearchTool,
|
||||
FunctionTool,
|
||||
FunctionToolResult,
|
||||
HostedMCPTool,
|
||||
ImageGenerationTool,
|
||||
LocalShellCommandRequest,
|
||||
LocalShellExecutor,
|
||||
LocalShellTool,
|
||||
MCPToolApprovalFunction,
|
||||
MCPToolApprovalFunctionResult,
|
||||
MCPToolApprovalRequest,
|
||||
|
|
@ -210,6 +215,11 @@ __all__ = [
|
|||
"FunctionToolResult",
|
||||
"ComputerTool",
|
||||
"FileSearchTool",
|
||||
"CodeInterpreterTool",
|
||||
"ImageGenerationTool",
|
||||
"LocalShellCommandRequest",
|
||||
"LocalShellExecutor",
|
||||
"LocalShellTool",
|
||||
"Tool",
|
||||
"WebSearchTool",
|
||||
"HostedMCPTool",
|
||||
|
|
|
|||
|
|
@ -14,6 +14,9 @@ from openai.types.responses import (
|
|||
ResponseFunctionWebSearch,
|
||||
ResponseOutputMessage,
|
||||
)
|
||||
from openai.types.responses.response_code_interpreter_tool_call import (
|
||||
ResponseCodeInterpreterToolCall,
|
||||
)
|
||||
from openai.types.responses.response_computer_tool_call import (
|
||||
ActionClick,
|
||||
ActionDoubleClick,
|
||||
|
|
@ -26,7 +29,12 @@ from openai.types.responses.response_computer_tool_call import (
|
|||
ActionWait,
|
||||
)
|
||||
from openai.types.responses.response_input_param import ComputerCallOutput, McpApprovalResponse
|
||||
from openai.types.responses.response_output_item import McpApprovalRequest, McpCall, McpListTools
|
||||
from openai.types.responses.response_output_item import (
|
||||
ImageGenerationCall,
|
||||
LocalShellCall,
|
||||
McpApprovalRequest,
|
||||
McpListTools,
|
||||
)
|
||||
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
|
||||
|
||||
from .agent import Agent, ToolsToFinalOutputResult
|
||||
|
|
@ -61,6 +69,8 @@ from .tool import (
|
|||
FunctionTool,
|
||||
FunctionToolResult,
|
||||
HostedMCPTool,
|
||||
LocalShellCommandRequest,
|
||||
LocalShellTool,
|
||||
MCPToolApprovalRequest,
|
||||
Tool,
|
||||
)
|
||||
|
|
@ -129,12 +139,19 @@ class ToolRunMCPApprovalRequest:
|
|||
mcp_tool: HostedMCPTool
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolRunLocalShellCall:
|
||||
tool_call: LocalShellCall
|
||||
local_shell_tool: LocalShellTool
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProcessedResponse:
|
||||
new_items: list[RunItem]
|
||||
handoffs: list[ToolRunHandoff]
|
||||
functions: list[ToolRunFunction]
|
||||
computer_actions: list[ToolRunComputerAction]
|
||||
local_shell_calls: list[ToolRunLocalShellCall]
|
||||
tools_used: list[str] # Names of all tools used, including hosted tools
|
||||
mcp_approval_requests: list[ToolRunMCPApprovalRequest] # Only requests with callbacks
|
||||
|
||||
|
|
@ -146,6 +163,7 @@ class ProcessedResponse:
|
|||
self.handoffs,
|
||||
self.functions,
|
||||
self.computer_actions,
|
||||
self.local_shell_calls,
|
||||
self.mcp_approval_requests,
|
||||
]
|
||||
)
|
||||
|
|
@ -371,11 +389,15 @@ class RunImpl:
|
|||
run_handoffs = []
|
||||
functions = []
|
||||
computer_actions = []
|
||||
local_shell_calls = []
|
||||
mcp_approval_requests = []
|
||||
tools_used: list[str] = []
|
||||
handoff_map = {handoff.tool_name: handoff for handoff in handoffs}
|
||||
function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)}
|
||||
computer_tool = next((tool for tool in all_tools if isinstance(tool, ComputerTool)), None)
|
||||
local_shell_tool = next(
|
||||
(tool for tool in all_tools if isinstance(tool, LocalShellTool)), None
|
||||
)
|
||||
hosted_mcp_server_map = {
|
||||
tool.tool_config["server_label"]: tool
|
||||
for tool in all_tools
|
||||
|
|
@ -434,9 +456,29 @@ class RunImpl:
|
|||
)
|
||||
elif isinstance(output, McpListTools):
|
||||
items.append(MCPListToolsItem(raw_item=output, agent=agent))
|
||||
elif isinstance(output, McpCall):
|
||||
elif isinstance(output, ImageGenerationCall):
|
||||
items.append(ToolCallItem(raw_item=output, agent=agent))
|
||||
tools_used.append(output.name)
|
||||
tools_used.append("image_generation")
|
||||
elif isinstance(output, ResponseCodeInterpreterToolCall):
|
||||
items.append(ToolCallItem(raw_item=output, agent=agent))
|
||||
tools_used.append("code_interpreter")
|
||||
elif isinstance(output, LocalShellCall):
|
||||
items.append(ToolCallItem(raw_item=output, agent=agent))
|
||||
tools_used.append("local_shell")
|
||||
if not local_shell_tool:
|
||||
_error_tracing.attach_error_to_current_span(
|
||||
SpanError(
|
||||
message="Local shell tool not found",
|
||||
data={},
|
||||
)
|
||||
)
|
||||
raise ModelBehaviorError(
|
||||
"Model produced local shell call without a local shell tool."
|
||||
)
|
||||
local_shell_calls.append(
|
||||
ToolRunLocalShellCall(tool_call=output, local_shell_tool=local_shell_tool)
|
||||
)
|
||||
|
||||
elif not isinstance(output, ResponseFunctionToolCall):
|
||||
logger.warning(f"Unexpected output type, ignoring: {type(output)}")
|
||||
continue
|
||||
|
|
@ -478,6 +520,7 @@ class RunImpl:
|
|||
handoffs=run_handoffs,
|
||||
functions=functions,
|
||||
computer_actions=computer_actions,
|
||||
local_shell_calls=local_shell_calls,
|
||||
tools_used=tools_used,
|
||||
mcp_approval_requests=mcp_approval_requests,
|
||||
)
|
||||
|
|
@ -552,6 +595,30 @@ class RunImpl:
|
|||
for tool_run, result in zip(tool_runs, results)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
async def execute_local_shell_calls(
|
||||
cls,
|
||||
*,
|
||||
agent: Agent[TContext],
|
||||
calls: list[ToolRunLocalShellCall],
|
||||
context_wrapper: RunContextWrapper[TContext],
|
||||
hooks: RunHooks[TContext],
|
||||
config: RunConfig,
|
||||
) -> list[RunItem]:
|
||||
results: list[RunItem] = []
|
||||
# Need to run these serially, because each call can affect the local shell state
|
||||
for call in calls:
|
||||
results.append(
|
||||
await LocalShellAction.execute(
|
||||
agent=agent,
|
||||
call=call,
|
||||
hooks=hooks,
|
||||
context_wrapper=context_wrapper,
|
||||
config=config,
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
@classmethod
|
||||
async def execute_computer_actions(
|
||||
cls,
|
||||
|
|
@ -1021,3 +1088,54 @@ class ComputerAction:
|
|||
await computer.wait()
|
||||
|
||||
return await computer.screenshot()
|
||||
|
||||
|
||||
class LocalShellAction:
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
*,
|
||||
agent: Agent[TContext],
|
||||
call: ToolRunLocalShellCall,
|
||||
hooks: RunHooks[TContext],
|
||||
context_wrapper: RunContextWrapper[TContext],
|
||||
config: RunConfig,
|
||||
) -> RunItem:
|
||||
await asyncio.gather(
|
||||
hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool),
|
||||
(
|
||||
agent.hooks.on_tool_start(context_wrapper, agent, call.local_shell_tool)
|
||||
if agent.hooks
|
||||
else _coro.noop_coroutine()
|
||||
),
|
||||
)
|
||||
|
||||
request = LocalShellCommandRequest(
|
||||
ctx_wrapper=context_wrapper,
|
||||
data=call.tool_call,
|
||||
)
|
||||
output = call.local_shell_tool.executor(request)
|
||||
if inspect.isawaitable(output):
|
||||
result = await output
|
||||
else:
|
||||
result = output
|
||||
|
||||
await asyncio.gather(
|
||||
hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result),
|
||||
(
|
||||
agent.hooks.on_tool_end(context_wrapper, agent, call.local_shell_tool, result)
|
||||
if agent.hooks
|
||||
else _coro.noop_coroutine()
|
||||
),
|
||||
)
|
||||
|
||||
return ToolCallOutputItem(
|
||||
agent=agent,
|
||||
output=output,
|
||||
raw_item={
|
||||
"type": "local_shell_call_output",
|
||||
"id": call.tool_call.call_id,
|
||||
"output": result,
|
||||
# "id": "out" + call.tool_call.id, # TODO remove this, it should be optional
|
||||
},
|
||||
)
|
||||
|
|
|
|||
|
|
@ -18,12 +18,22 @@ from openai.types.responses import (
|
|||
ResponseOutputText,
|
||||
ResponseStreamEvent,
|
||||
)
|
||||
from openai.types.responses.response_code_interpreter_tool_call import (
|
||||
ResponseCodeInterpreterToolCall,
|
||||
)
|
||||
from openai.types.responses.response_input_item_param import (
|
||||
ComputerCallOutput,
|
||||
FunctionCallOutput,
|
||||
LocalShellCallOutput,
|
||||
McpApprovalResponse,
|
||||
)
|
||||
from openai.types.responses.response_output_item import McpApprovalRequest, McpCall, McpListTools
|
||||
from openai.types.responses.response_output_item import (
|
||||
ImageGenerationCall,
|
||||
LocalShellCall,
|
||||
McpApprovalRequest,
|
||||
McpCall,
|
||||
McpListTools,
|
||||
)
|
||||
from openai.types.responses.response_reasoning_item import ResponseReasoningItem
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypeAlias
|
||||
|
|
@ -114,6 +124,9 @@ ToolCallItemTypes: TypeAlias = Union[
|
|||
ResponseFileSearchToolCall,
|
||||
ResponseFunctionWebSearch,
|
||||
McpCall,
|
||||
ResponseCodeInterpreterToolCall,
|
||||
ImageGenerationCall,
|
||||
LocalShellCall,
|
||||
]
|
||||
"""A type that represents a tool call item."""
|
||||
|
||||
|
|
@ -129,10 +142,12 @@ class ToolCallItem(RunItemBase[ToolCallItemTypes]):
|
|||
|
||||
|
||||
@dataclass
|
||||
class ToolCallOutputItem(RunItemBase[Union[FunctionCallOutput, ComputerCallOutput]]):
|
||||
class ToolCallOutputItem(
|
||||
RunItemBase[Union[FunctionCallOutput, ComputerCallOutput, LocalShellCallOutput]]
|
||||
):
|
||||
"""Represents the output of a tool call."""
|
||||
|
||||
raw_item: FunctionCallOutput | ComputerCallOutput
|
||||
raw_item: FunctionCallOutput | ComputerCallOutput | LocalShellCallOutput
|
||||
"""The raw item from the model."""
|
||||
|
||||
output: Any
|
||||
|
|
|
|||
|
|
@ -24,7 +24,17 @@ from ..exceptions import UserError
|
|||
from ..handoffs import Handoff
|
||||
from ..items import ItemHelpers, ModelResponse, TResponseInputItem
|
||||
from ..logger import logger
|
||||
from ..tool import ComputerTool, FileSearchTool, FunctionTool, HostedMCPTool, Tool, WebSearchTool
|
||||
from ..tool import (
|
||||
CodeInterpreterTool,
|
||||
ComputerTool,
|
||||
FileSearchTool,
|
||||
FunctionTool,
|
||||
HostedMCPTool,
|
||||
ImageGenerationTool,
|
||||
LocalShellTool,
|
||||
Tool,
|
||||
WebSearchTool,
|
||||
)
|
||||
from ..tracing import SpanError, response_span
|
||||
from ..usage import Usage
|
||||
from ..version import __version__
|
||||
|
|
@ -295,6 +305,18 @@ class Converter:
|
|||
return {
|
||||
"type": "computer_use_preview",
|
||||
}
|
||||
elif tool_choice == "image_generation":
|
||||
return {
|
||||
"type": "image_generation",
|
||||
}
|
||||
elif tool_choice == "code_interpreter":
|
||||
return {
|
||||
"type": "code_interpreter",
|
||||
}
|
||||
elif tool_choice == "mcp":
|
||||
return {
|
||||
"type": "mcp",
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"type": "function",
|
||||
|
|
@ -386,6 +408,17 @@ class Converter:
|
|||
elif isinstance(tool, HostedMCPTool):
|
||||
converted_tool = tool.tool_config
|
||||
includes = None
|
||||
elif isinstance(tool, ImageGenerationTool):
|
||||
converted_tool = tool.tool_config
|
||||
includes = None
|
||||
elif isinstance(tool, CodeInterpreterTool):
|
||||
converted_tool = tool.tool_config
|
||||
includes = None
|
||||
elif isinstance(tool, LocalShellTool):
|
||||
converted_tool = {
|
||||
"type": "local_shell",
|
||||
}
|
||||
includes = None
|
||||
else:
|
||||
raise UserError(f"Unknown tool type: {type(tool)}, tool")
|
||||
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@ from dataclasses import dataclass
|
|||
from typing import Any, Callable, Literal, Union, overload
|
||||
|
||||
from openai.types.responses.file_search_tool_param import Filters, RankingOptions
|
||||
from openai.types.responses.response_output_item import McpApprovalRequest
|
||||
from openai.types.responses.tool_param import Mcp
|
||||
from openai.types.responses.response_output_item import LocalShellCall, McpApprovalRequest
|
||||
from openai.types.responses.tool_param import CodeInterpreter, ImageGeneration, Mcp
|
||||
from openai.types.responses.web_search_tool_param import UserLocation
|
||||
from pydantic import ValidationError
|
||||
from typing_extensions import Concatenate, NotRequired, ParamSpec, TypedDict
|
||||
|
|
@ -180,7 +180,67 @@ class HostedMCPTool:
|
|||
return "hosted_mcp"
|
||||
|
||||
|
||||
Tool = Union[FunctionTool, FileSearchTool, WebSearchTool, ComputerTool, HostedMCPTool]
|
||||
@dataclass
|
||||
class CodeInterpreterTool:
|
||||
"""A tool that allows the LLM to execute code in a sandboxed environment."""
|
||||
|
||||
tool_config: CodeInterpreter
|
||||
"""The tool config, which includes the container and other settings."""
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return "code_interpreter"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ImageGenerationTool:
|
||||
"""A tool that allows the LLM to generate images."""
|
||||
|
||||
tool_config: ImageGeneration
|
||||
"""The tool config, which image generation settings."""
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return "image_generation"
|
||||
|
||||
|
||||
@dataclass
|
||||
class LocalShellCommandRequest:
|
||||
"""A request to execute a command on a shell."""
|
||||
|
||||
ctx_wrapper: RunContextWrapper[Any]
|
||||
"""The run context."""
|
||||
|
||||
data: LocalShellCall
|
||||
"""The data from the local shell tool call."""
|
||||
|
||||
|
||||
LocalShellExecutor = Callable[[LocalShellCommandRequest], MaybeAwaitable[str]]
|
||||
"""A function that executes a command on a shell."""
|
||||
|
||||
|
||||
@dataclass
|
||||
class LocalShellTool:
|
||||
"""A tool that allows the LLM to execute commands on a shell."""
|
||||
|
||||
executor: LocalShellExecutor
|
||||
"""A function that executes a command on a shell."""
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return "local_shell"
|
||||
|
||||
|
||||
Tool = Union[
|
||||
FunctionTool,
|
||||
FileSearchTool,
|
||||
WebSearchTool,
|
||||
ComputerTool,
|
||||
HostedMCPTool,
|
||||
LocalShellTool,
|
||||
ImageGenerationTool,
|
||||
CodeInterpreterTool,
|
||||
]
|
||||
"""A tool that can be used in an agent."""
|
||||
|
||||
|
||||
|
|
@ -358,13 +418,3 @@ def function_tool(
|
|||
return _create_function_tool(real_func)
|
||||
|
||||
return decorator
|
||||
return decorator
|
||||
return decorator
|
||||
return decorator
|
||||
return decorator
|
||||
return decorator
|
||||
return decorator
|
||||
return decorator
|
||||
return decorator
|
||||
return decorator
|
||||
return decorator
|
||||
|
|
|
|||
Loading…
Reference in a new issue