feat: Add Graphviz-based agent visualization functionality (#147)
This pull request introduces functionality for visualizing agent
structures using Graphviz. The changes include adding a new dependency,
implementing functions to generate and draw graphs, and adding tests for
these functions.
New functionality for visualizing agent structures:
* Added `graphviz` as a new dependency in `pyproject.toml`.
* Implemented functions in `src/agents/visualizations.py` to generate
and draw graphs for agents using Graphviz. These functions include
`get_main_graph`, `get_all_nodes`, `get_all_edges`, and `draw_graph`.
Testing the new visualization functionality:
* Added tests in `tests/test_visualizations.py` to verify the
correctness of the graph generation and drawing functions. The tests
cover `get_main_graph`, `get_all_nodes`, `get_all_edges`, and
`draw_graph`.
For example, given the following code:
```python
from agents import Agent, function_tool
from agents.visualizations import draw_graph
@function_tool
def get_weather(city: str) -> str:
return f"The weather in {city} is sunny."
spanish_agent = Agent(
name="Spanish agent",
instructions="You only speak Spanish.",
)
english_agent = Agent(
name="English agent",
instructions="You only speak English",
)
triage_agent = Agent(
name="Triage agent",
instructions="Handoff to the appropriate agent based on the language of the request.",
handoffs=[spanish_agent, english_agent],
tools=[get_weather],
)
draw_graph(triage_agent)
```
Generates the following image:
<img width="614" alt="Screenshot 2025-03-13 at 18 36 23"
src="https://github.com/user-attachments/assets/d01fe502-6886-4efb-aaf8-c92e4524b0fe"
/>
This commit is contained in:
commit
dd881eed9a
8 changed files with 381 additions and 1 deletions
BIN
docs/assets/images/graph.png
Normal file
BIN
docs/assets/images/graph.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 93 KiB |
86
docs/visualization.md
Normal file
86
docs/visualization.md
Normal file
|
|
@ -0,0 +1,86 @@
|
|||
# Agent Visualization
|
||||
|
||||
Agent visualization allows you to generate a structured graphical representation of agents and their relationships using **Graphviz**. This is useful for understanding how agents, tools, and handoffs interact within an application.
|
||||
|
||||
## Installation
|
||||
|
||||
Install the optional `viz` dependency group:
|
||||
|
||||
```bash
|
||||
pip install "openai-agents[viz]"
|
||||
```
|
||||
|
||||
## Generating a Graph
|
||||
|
||||
You can generate an agent visualization using the `draw_graph` function. This function creates a directed graph where:
|
||||
|
||||
- **Agents** are represented as yellow boxes.
|
||||
- **Tools** are represented as green ellipses.
|
||||
- **Handoffs** are directed edges from one agent to another.
|
||||
|
||||
### Example Usage
|
||||
|
||||
```python
|
||||
from agents import Agent, function_tool
|
||||
from agents.extensions.visualization import draw_graph
|
||||
|
||||
@function_tool
|
||||
def get_weather(city: str) -> str:
|
||||
return f"The weather in {city} is sunny."
|
||||
|
||||
spanish_agent = Agent(
|
||||
name="Spanish agent",
|
||||
instructions="You only speak Spanish.",
|
||||
)
|
||||
|
||||
english_agent = Agent(
|
||||
name="English agent",
|
||||
instructions="You only speak English",
|
||||
)
|
||||
|
||||
triage_agent = Agent(
|
||||
name="Triage agent",
|
||||
instructions="Handoff to the appropriate agent based on the language of the request.",
|
||||
handoffs=[spanish_agent, english_agent],
|
||||
tools=[get_weather],
|
||||
)
|
||||
|
||||
draw_graph(triage_agent)
|
||||
```
|
||||
|
||||

|
||||
|
||||
This generates a graph that visually represents the structure of the **triage agent** and its connections to sub-agents and tools.
|
||||
|
||||
|
||||
## Understanding the Visualization
|
||||
|
||||
The generated graph includes:
|
||||
|
||||
- A **start node** (`__start__`) indicating the entry point.
|
||||
- Agents represented as **rectangles** with yellow fill.
|
||||
- Tools represented as **ellipses** with green fill.
|
||||
- Directed edges indicating interactions:
|
||||
- **Solid arrows** for agent-to-agent handoffs.
|
||||
- **Dotted arrows** for tool invocations.
|
||||
- An **end node** (`__end__`) indicating where execution terminates.
|
||||
|
||||
## Customizing the Graph
|
||||
|
||||
### Showing the Graph
|
||||
By default, `draw_graph` displays the graph inline. To show the graph in a separate window, write the following:
|
||||
|
||||
```python
|
||||
draw_graph(triage_agent).view()
|
||||
```
|
||||
|
||||
### Saving the Graph
|
||||
By default, `draw_graph` displays the graph inline. To save it as a file, specify a filename:
|
||||
|
||||
```python
|
||||
draw_graph(triage_agent, filename="agent_graph.png")
|
||||
```
|
||||
|
||||
This will generate `agent_graph.png` in the working directory.
|
||||
|
||||
|
||||
|
|
@ -36,6 +36,7 @@ nav:
|
|||
- multi_agent.md
|
||||
- models.md
|
||||
- config.md
|
||||
- visualization.md
|
||||
- Voice agents:
|
||||
- voice/quickstart.md
|
||||
- voice/pipeline.md
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ Repository = "https://github.com/openai/openai-agents-python"
|
|||
|
||||
[project.optional-dependencies]
|
||||
voice = ["numpy>=2.2.0, <3; python_version>='3.10'", "websockets>=15.0, <16"]
|
||||
viz = ["graphviz>=0.17"]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
|
|
@ -56,7 +57,9 @@ dev = [
|
|||
"pynput",
|
||||
"textual",
|
||||
"websockets",
|
||||
"graphviz",
|
||||
]
|
||||
|
||||
[tool.uv.workspace]
|
||||
members = ["agents"]
|
||||
|
||||
|
|
|
|||
137
src/agents/extensions/visualization.py
Normal file
137
src/agents/extensions/visualization.py
Normal file
|
|
@ -0,0 +1,137 @@
|
|||
from typing import Optional
|
||||
|
||||
import graphviz # type: ignore
|
||||
|
||||
from agents import Agent
|
||||
from agents.handoffs import Handoff
|
||||
from agents.tool import Tool
|
||||
|
||||
|
||||
def get_main_graph(agent: Agent) -> str:
|
||||
"""
|
||||
Generates the main graph structure in DOT format for the given agent.
|
||||
|
||||
Args:
|
||||
agent (Agent): The agent for which the graph is to be generated.
|
||||
|
||||
Returns:
|
||||
str: The DOT format string representing the graph.
|
||||
"""
|
||||
parts = [
|
||||
"""
|
||||
digraph G {
|
||||
graph [splines=true];
|
||||
node [fontname="Arial"];
|
||||
edge [penwidth=1.5];
|
||||
"""
|
||||
]
|
||||
parts.append(get_all_nodes(agent))
|
||||
parts.append(get_all_edges(agent))
|
||||
parts.append("}")
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
|
||||
"""
|
||||
Recursively generates the nodes for the given agent and its handoffs in DOT format.
|
||||
|
||||
Args:
|
||||
agent (Agent): The agent for which the nodes are to be generated.
|
||||
|
||||
Returns:
|
||||
str: The DOT format string representing the nodes.
|
||||
"""
|
||||
parts = []
|
||||
|
||||
# Start and end the graph
|
||||
parts.append(
|
||||
'"__start__" [label="__start__", shape=ellipse, style=filled, '
|
||||
"fillcolor=lightblue, width=0.5, height=0.3];"
|
||||
'"__end__" [label="__end__", shape=ellipse, style=filled, '
|
||||
"fillcolor=lightblue, width=0.5, height=0.3];"
|
||||
)
|
||||
# Ensure parent agent node is colored
|
||||
if not parent:
|
||||
parts.append(
|
||||
f'"{agent.name}" [label="{agent.name}", shape=box, style=filled, '
|
||||
"fillcolor=lightyellow, width=1.5, height=0.8];"
|
||||
)
|
||||
|
||||
for tool in agent.tools:
|
||||
parts.append(
|
||||
f'"{tool.name}" [label="{tool.name}", shape=ellipse, style=filled, '
|
||||
f"fillcolor=lightgreen, width=0.5, height=0.3];"
|
||||
)
|
||||
|
||||
for handoff in agent.handoffs:
|
||||
if isinstance(handoff, Handoff):
|
||||
parts.append(
|
||||
f'"{handoff.agent_name}" [label="{handoff.agent_name}", '
|
||||
f"shape=box, style=filled, style=rounded, "
|
||||
f"fillcolor=lightyellow, width=1.5, height=0.8];"
|
||||
)
|
||||
if isinstance(handoff, Agent):
|
||||
parts.append(
|
||||
f'"{handoff.name}" [label="{handoff.name}", '
|
||||
f"shape=box, style=filled, style=rounded, "
|
||||
f"fillcolor=lightyellow, width=1.5, height=0.8];"
|
||||
)
|
||||
parts.append(get_all_nodes(handoff))
|
||||
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
|
||||
"""
|
||||
Recursively generates the edges for the given agent and its handoffs in DOT format.
|
||||
|
||||
Args:
|
||||
agent (Agent): The agent for which the edges are to be generated.
|
||||
parent (Agent, optional): The parent agent. Defaults to None.
|
||||
|
||||
Returns:
|
||||
str: The DOT format string representing the edges.
|
||||
"""
|
||||
parts = []
|
||||
|
||||
if not parent:
|
||||
parts.append(f'"__start__" -> "{agent.name}";')
|
||||
|
||||
for tool in agent.tools:
|
||||
parts.append(f"""
|
||||
"{agent.name}" -> "{tool.name}" [style=dotted, penwidth=1.5];
|
||||
"{tool.name}" -> "{agent.name}" [style=dotted, penwidth=1.5];""")
|
||||
|
||||
for handoff in agent.handoffs:
|
||||
if isinstance(handoff, Handoff):
|
||||
parts.append(f"""
|
||||
"{agent.name}" -> "{handoff.agent_name}";""")
|
||||
if isinstance(handoff, Agent):
|
||||
parts.append(f"""
|
||||
"{agent.name}" -> "{handoff.name}";""")
|
||||
parts.append(get_all_edges(handoff, agent))
|
||||
|
||||
if not agent.handoffs and not isinstance(agent, Tool): # type: ignore
|
||||
parts.append(f'"{agent.name}" -> "__end__";')
|
||||
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def draw_graph(agent: Agent, filename: Optional[str] = None) -> graphviz.Source:
|
||||
"""
|
||||
Draws the graph for the given agent and optionally saves it as a PNG file.
|
||||
|
||||
Args:
|
||||
agent (Agent): The agent for which the graph is to be drawn.
|
||||
filename (str): The name of the file to save the graph as a PNG.
|
||||
|
||||
Returns:
|
||||
graphviz.Source: The graphviz Source object representing the graph.
|
||||
"""
|
||||
dot_code = get_main_graph(agent)
|
||||
graph = graphviz.Source(dot_code)
|
||||
|
||||
if filename:
|
||||
graph.render(filename, format="png")
|
||||
|
||||
return graph
|
||||
136
tests/test_visualization.py
Normal file
136
tests/test_visualization.py
Normal file
|
|
@ -0,0 +1,136 @@
|
|||
from unittest.mock import Mock
|
||||
|
||||
import graphviz # type: ignore
|
||||
import pytest
|
||||
|
||||
from agents import Agent
|
||||
from agents.extensions.visualization import (
|
||||
draw_graph,
|
||||
get_all_edges,
|
||||
get_all_nodes,
|
||||
get_main_graph,
|
||||
)
|
||||
from agents.handoffs import Handoff
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
tool1 = Mock()
|
||||
tool1.name = "Tool1"
|
||||
tool2 = Mock()
|
||||
tool2.name = "Tool2"
|
||||
|
||||
handoff1 = Mock(spec=Handoff)
|
||||
handoff1.agent_name = "Handoff1"
|
||||
|
||||
agent = Mock(spec=Agent)
|
||||
agent.name = "Agent1"
|
||||
agent.tools = [tool1, tool2]
|
||||
agent.handoffs = [handoff1]
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def test_get_main_graph(mock_agent):
|
||||
result = get_main_graph(mock_agent)
|
||||
print(result)
|
||||
assert "digraph G" in result
|
||||
assert "graph [splines=true];" in result
|
||||
assert 'node [fontname="Arial"];' in result
|
||||
assert "edge [penwidth=1.5];" in result
|
||||
assert (
|
||||
'"__start__" [label="__start__", shape=ellipse, style=filled, '
|
||||
"fillcolor=lightblue, width=0.5, height=0.3];" in result
|
||||
)
|
||||
assert (
|
||||
'"__end__" [label="__end__", shape=ellipse, style=filled, '
|
||||
"fillcolor=lightblue, width=0.5, height=0.3];" in result
|
||||
)
|
||||
assert (
|
||||
'"Agent1" [label="Agent1", shape=box, style=filled, '
|
||||
"fillcolor=lightyellow, width=1.5, height=0.8];" in result
|
||||
)
|
||||
assert (
|
||||
'"Tool1" [label="Tool1", shape=ellipse, style=filled, '
|
||||
"fillcolor=lightgreen, width=0.5, height=0.3];" in result
|
||||
)
|
||||
assert (
|
||||
'"Tool2" [label="Tool2", shape=ellipse, style=filled, '
|
||||
"fillcolor=lightgreen, width=0.5, height=0.3];" in result
|
||||
)
|
||||
assert (
|
||||
'"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, '
|
||||
"fillcolor=lightyellow, width=1.5, height=0.8];" in result
|
||||
)
|
||||
|
||||
|
||||
def test_get_all_nodes(mock_agent):
|
||||
result = get_all_nodes(mock_agent)
|
||||
assert (
|
||||
'"__start__" [label="__start__", shape=ellipse, style=filled, '
|
||||
"fillcolor=lightblue, width=0.5, height=0.3];" in result
|
||||
)
|
||||
assert (
|
||||
'"__end__" [label="__end__", shape=ellipse, style=filled, '
|
||||
"fillcolor=lightblue, width=0.5, height=0.3];" in result
|
||||
)
|
||||
assert (
|
||||
'"Agent1" [label="Agent1", shape=box, style=filled, '
|
||||
"fillcolor=lightyellow, width=1.5, height=0.8];" in result
|
||||
)
|
||||
assert (
|
||||
'"Tool1" [label="Tool1", shape=ellipse, style=filled, '
|
||||
"fillcolor=lightgreen, width=0.5, height=0.3];" in result
|
||||
)
|
||||
assert (
|
||||
'"Tool2" [label="Tool2", shape=ellipse, style=filled, '
|
||||
"fillcolor=lightgreen, width=0.5, height=0.3];" in result
|
||||
)
|
||||
assert (
|
||||
'"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, '
|
||||
"fillcolor=lightyellow, width=1.5, height=0.8];" in result
|
||||
)
|
||||
|
||||
|
||||
def test_get_all_edges(mock_agent):
|
||||
result = get_all_edges(mock_agent)
|
||||
assert '"__start__" -> "Agent1";' in result
|
||||
assert '"Agent1" -> "__end__";'
|
||||
assert '"Agent1" -> "Tool1" [style=dotted, penwidth=1.5];' in result
|
||||
assert '"Tool1" -> "Agent1" [style=dotted, penwidth=1.5];' in result
|
||||
assert '"Agent1" -> "Tool2" [style=dotted, penwidth=1.5];' in result
|
||||
assert '"Tool2" -> "Agent1" [style=dotted, penwidth=1.5];' in result
|
||||
assert '"Agent1" -> "Handoff1";' in result
|
||||
|
||||
|
||||
def test_draw_graph(mock_agent):
|
||||
graph = draw_graph(mock_agent)
|
||||
assert isinstance(graph, graphviz.Source)
|
||||
assert "digraph G" in graph.source
|
||||
assert "graph [splines=true];" in graph.source
|
||||
assert 'node [fontname="Arial"];' in graph.source
|
||||
assert "edge [penwidth=1.5];" in graph.source
|
||||
assert (
|
||||
'"__start__" [label="__start__", shape=ellipse, style=filled, '
|
||||
"fillcolor=lightblue, width=0.5, height=0.3];" in graph.source
|
||||
)
|
||||
assert (
|
||||
'"__end__" [label="__end__", shape=ellipse, style=filled, '
|
||||
"fillcolor=lightblue, width=0.5, height=0.3];" in graph.source
|
||||
)
|
||||
assert (
|
||||
'"Agent1" [label="Agent1", shape=box, style=filled, '
|
||||
"fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source
|
||||
)
|
||||
assert (
|
||||
'"Tool1" [label="Tool1", shape=ellipse, style=filled, '
|
||||
"fillcolor=lightgreen, width=0.5, height=0.3];" in graph.source
|
||||
)
|
||||
assert (
|
||||
'"Tool2" [label="Tool2", shape=ellipse, style=filled, '
|
||||
"fillcolor=lightgreen, width=0.5, height=0.3];" in graph.source
|
||||
)
|
||||
assert (
|
||||
'"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, '
|
||||
"fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source
|
||||
)
|
||||
17
uv.lock
17
uv.lock
|
|
@ -1,4 +1,5 @@
|
|||
version = 1
|
||||
revision = 1
|
||||
requires-python = ">=3.9"
|
||||
resolution-markers = [
|
||||
"python_full_version >= '3.10'",
|
||||
|
|
@ -348,6 +349,15 @@ wheels = [
|
|||
{ url = "https://files.pythonhosted.org/packages/f7/ec/67fbef5d497f86283db54c22eec6f6140243aae73265799baaaa19cd17fb/ghp_import-2.1.0-py3-none-any.whl", hash = "sha256:8337dd7b50877f163d4c0289bc1f1c7f127550241988d568c1db512c4324a619", size = 11034 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "graphviz"
|
||||
version = "0.20.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/fa/83/5a40d19b8347f017e417710907f824915fba411a9befd092e52746b63e9f/graphviz-0.20.3.zip", hash = "sha256:09d6bc81e6a9fa392e7ba52135a9d49f1ed62526f96499325930e87ca1b5925d", size = 256455 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/00/be/d59db2d1d52697c6adc9eacaf50e8965b6345cc143f671e1ed068818d5cf/graphviz-0.20.3-py3-none-any.whl", hash = "sha256:81f848f2904515d8cd359cc611faba817598d2feaac4027b266aa3eda7b3dde5", size = 47126 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "greenlet"
|
||||
version = "3.1.1"
|
||||
|
|
@ -1090,6 +1100,9 @@ dependencies = [
|
|||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
viz = [
|
||||
{ name = "graphviz" },
|
||||
]
|
||||
voice = [
|
||||
{ name = "numpy", version = "2.2.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" },
|
||||
{ name = "websockets" },
|
||||
|
|
@ -1098,6 +1111,7 @@ voice = [
|
|||
[package.dev-dependencies]
|
||||
dev = [
|
||||
{ name = "coverage" },
|
||||
{ name = "graphviz" },
|
||||
{ name = "inline-snapshot" },
|
||||
{ name = "mkdocs" },
|
||||
{ name = "mkdocs-material" },
|
||||
|
|
@ -1118,6 +1132,7 @@ dev = [
|
|||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "graphviz", marker = "extra == 'viz'", specifier = ">=0.17" },
|
||||
{ name = "griffe", specifier = ">=1.5.6,<2" },
|
||||
{ name = "mcp", marker = "python_full_version >= '3.10'" },
|
||||
{ name = "numpy", marker = "python_full_version >= '3.10' and extra == 'voice'", specifier = ">=2.2.0,<3" },
|
||||
|
|
@ -1128,10 +1143,12 @@ requires-dist = [
|
|||
{ name = "typing-extensions", specifier = ">=4.12.2,<5" },
|
||||
{ name = "websockets", marker = "extra == 'voice'", specifier = ">=15.0,<16" },
|
||||
]
|
||||
provides-extras = ["voice", "viz"]
|
||||
|
||||
[package.metadata.requires-dev]
|
||||
dev = [
|
||||
{ name = "coverage", specifier = ">=7.6.12" },
|
||||
{ name = "graphviz" },
|
||||
{ name = "inline-snapshot", specifier = ">=0.20.7" },
|
||||
{ name = "mkdocs", specifier = ">=1.6.0" },
|
||||
{ name = "mkdocs-material", specifier = ">=9.6.0" },
|
||||
|
|
|
|||
Loading…
Reference in a new issue