feat: Add Graphviz-based agent visualization functionality (#147)

This pull request introduces functionality for visualizing agent
structures using Graphviz. The changes include adding a new dependency,
implementing functions to generate and draw graphs, and adding tests for
these functions.

New functionality for visualizing agent structures:

* Added `graphviz` as a new dependency in `pyproject.toml`.
* Implemented functions in `src/agents/visualizations.py` to generate
and draw graphs for agents using Graphviz. These functions include
`get_main_graph`, `get_all_nodes`, `get_all_edges`, and `draw_graph`.

Testing the new visualization functionality:

* Added tests in `tests/test_visualizations.py` to verify the
correctness of the graph generation and drawing functions. The tests
cover `get_main_graph`, `get_all_nodes`, `get_all_edges`, and
`draw_graph`.

For example, given the following code:

```python
from agents import Agent, function_tool
from agents.visualizations import draw_graph


@function_tool
def get_weather(city: str) -> str:
    return f"The weather in {city} is sunny."


spanish_agent = Agent(
    name="Spanish agent",
    instructions="You only speak Spanish.",
)

english_agent = Agent(
    name="English agent",
    instructions="You only speak English",
)

triage_agent = Agent(
    name="Triage agent",
    instructions="Handoff to the appropriate agent based on the language of the request.",
    handoffs=[spanish_agent, english_agent],
    tools=[get_weather],
)


draw_graph(triage_agent)
```

Generates the following image:

<img width="614" alt="Screenshot 2025-03-13 at 18 36 23"
src="https://github.com/user-attachments/assets/d01fe502-6886-4efb-aaf8-c92e4524b0fe"
/>
This commit is contained in:
Rohan Mehta 2025-03-25 19:22:58 -04:00 committed by GitHub
commit dd881eed9a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 381 additions and 1 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 93 KiB

86
docs/visualization.md Normal file
View file

@ -0,0 +1,86 @@
# Agent Visualization
Agent visualization allows you to generate a structured graphical representation of agents and their relationships using **Graphviz**. This is useful for understanding how agents, tools, and handoffs interact within an application.
## Installation
Install the optional `viz` dependency group:
```bash
pip install "openai-agents[viz]"
```
## Generating a Graph
You can generate an agent visualization using the `draw_graph` function. This function creates a directed graph where:
- **Agents** are represented as yellow boxes.
- **Tools** are represented as green ellipses.
- **Handoffs** are directed edges from one agent to another.
### Example Usage
```python
from agents import Agent, function_tool
from agents.extensions.visualization import draw_graph
@function_tool
def get_weather(city: str) -> str:
return f"The weather in {city} is sunny."
spanish_agent = Agent(
name="Spanish agent",
instructions="You only speak Spanish.",
)
english_agent = Agent(
name="English agent",
instructions="You only speak English",
)
triage_agent = Agent(
name="Triage agent",
instructions="Handoff to the appropriate agent based on the language of the request.",
handoffs=[spanish_agent, english_agent],
tools=[get_weather],
)
draw_graph(triage_agent)
```
![Agent Graph](./assets/images/graph.png)
This generates a graph that visually represents the structure of the **triage agent** and its connections to sub-agents and tools.
## Understanding the Visualization
The generated graph includes:
- A **start node** (`__start__`) indicating the entry point.
- Agents represented as **rectangles** with yellow fill.
- Tools represented as **ellipses** with green fill.
- Directed edges indicating interactions:
- **Solid arrows** for agent-to-agent handoffs.
- **Dotted arrows** for tool invocations.
- An **end node** (`__end__`) indicating where execution terminates.
## Customizing the Graph
### Showing the Graph
By default, `draw_graph` displays the graph inline. To show the graph in a separate window, write the following:
```python
draw_graph(triage_agent).view()
```
### Saving the Graph
By default, `draw_graph` displays the graph inline. To save it as a file, specify a filename:
```python
draw_graph(triage_agent, filename="agent_graph.png")
```
This will generate `agent_graph.png` in the working directory.

View file

@ -36,6 +36,7 @@ nav:
- multi_agent.md
- models.md
- config.md
- visualization.md
- Voice agents:
- voice/quickstart.md
- voice/pipeline.md

View file

@ -35,6 +35,7 @@ Repository = "https://github.com/openai/openai-agents-python"
[project.optional-dependencies]
voice = ["numpy>=2.2.0, <3; python_version>='3.10'", "websockets>=15.0, <16"]
viz = ["graphviz>=0.17"]
[dependency-groups]
dev = [
@ -56,7 +57,9 @@ dev = [
"pynput",
"textual",
"websockets",
"graphviz",
]
[tool.uv.workspace]
members = ["agents"]

View file

@ -0,0 +1,137 @@
from typing import Optional
import graphviz # type: ignore
from agents import Agent
from agents.handoffs import Handoff
from agents.tool import Tool
def get_main_graph(agent: Agent) -> str:
"""
Generates the main graph structure in DOT format for the given agent.
Args:
agent (Agent): The agent for which the graph is to be generated.
Returns:
str: The DOT format string representing the graph.
"""
parts = [
"""
digraph G {
graph [splines=true];
node [fontname="Arial"];
edge [penwidth=1.5];
"""
]
parts.append(get_all_nodes(agent))
parts.append(get_all_edges(agent))
parts.append("}")
return "".join(parts)
def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
"""
Recursively generates the nodes for the given agent and its handoffs in DOT format.
Args:
agent (Agent): The agent for which the nodes are to be generated.
Returns:
str: The DOT format string representing the nodes.
"""
parts = []
# Start and end the graph
parts.append(
'"__start__" [label="__start__", shape=ellipse, style=filled, '
"fillcolor=lightblue, width=0.5, height=0.3];"
'"__end__" [label="__end__", shape=ellipse, style=filled, '
"fillcolor=lightblue, width=0.5, height=0.3];"
)
# Ensure parent agent node is colored
if not parent:
parts.append(
f'"{agent.name}" [label="{agent.name}", shape=box, style=filled, '
"fillcolor=lightyellow, width=1.5, height=0.8];"
)
for tool in agent.tools:
parts.append(
f'"{tool.name}" [label="{tool.name}", shape=ellipse, style=filled, '
f"fillcolor=lightgreen, width=0.5, height=0.3];"
)
for handoff in agent.handoffs:
if isinstance(handoff, Handoff):
parts.append(
f'"{handoff.agent_name}" [label="{handoff.agent_name}", '
f"shape=box, style=filled, style=rounded, "
f"fillcolor=lightyellow, width=1.5, height=0.8];"
)
if isinstance(handoff, Agent):
parts.append(
f'"{handoff.name}" [label="{handoff.name}", '
f"shape=box, style=filled, style=rounded, "
f"fillcolor=lightyellow, width=1.5, height=0.8];"
)
parts.append(get_all_nodes(handoff))
return "".join(parts)
def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
"""
Recursively generates the edges for the given agent and its handoffs in DOT format.
Args:
agent (Agent): The agent for which the edges are to be generated.
parent (Agent, optional): The parent agent. Defaults to None.
Returns:
str: The DOT format string representing the edges.
"""
parts = []
if not parent:
parts.append(f'"__start__" -> "{agent.name}";')
for tool in agent.tools:
parts.append(f"""
"{agent.name}" -> "{tool.name}" [style=dotted, penwidth=1.5];
"{tool.name}" -> "{agent.name}" [style=dotted, penwidth=1.5];""")
for handoff in agent.handoffs:
if isinstance(handoff, Handoff):
parts.append(f"""
"{agent.name}" -> "{handoff.agent_name}";""")
if isinstance(handoff, Agent):
parts.append(f"""
"{agent.name}" -> "{handoff.name}";""")
parts.append(get_all_edges(handoff, agent))
if not agent.handoffs and not isinstance(agent, Tool): # type: ignore
parts.append(f'"{agent.name}" -> "__end__";')
return "".join(parts)
def draw_graph(agent: Agent, filename: Optional[str] = None) -> graphviz.Source:
"""
Draws the graph for the given agent and optionally saves it as a PNG file.
Args:
agent (Agent): The agent for which the graph is to be drawn.
filename (str): The name of the file to save the graph as a PNG.
Returns:
graphviz.Source: The graphviz Source object representing the graph.
"""
dot_code = get_main_graph(agent)
graph = graphviz.Source(dot_code)
if filename:
graph.render(filename, format="png")
return graph

136
tests/test_visualization.py Normal file
View file

@ -0,0 +1,136 @@
from unittest.mock import Mock
import graphviz # type: ignore
import pytest
from agents import Agent
from agents.extensions.visualization import (
draw_graph,
get_all_edges,
get_all_nodes,
get_main_graph,
)
from agents.handoffs import Handoff
@pytest.fixture
def mock_agent():
tool1 = Mock()
tool1.name = "Tool1"
tool2 = Mock()
tool2.name = "Tool2"
handoff1 = Mock(spec=Handoff)
handoff1.agent_name = "Handoff1"
agent = Mock(spec=Agent)
agent.name = "Agent1"
agent.tools = [tool1, tool2]
agent.handoffs = [handoff1]
return agent
def test_get_main_graph(mock_agent):
result = get_main_graph(mock_agent)
print(result)
assert "digraph G" in result
assert "graph [splines=true];" in result
assert 'node [fontname="Arial"];' in result
assert "edge [penwidth=1.5];" in result
assert (
'"__start__" [label="__start__", shape=ellipse, style=filled, '
"fillcolor=lightblue, width=0.5, height=0.3];" in result
)
assert (
'"__end__" [label="__end__", shape=ellipse, style=filled, '
"fillcolor=lightblue, width=0.5, height=0.3];" in result
)
assert (
'"Agent1" [label="Agent1", shape=box, style=filled, '
"fillcolor=lightyellow, width=1.5, height=0.8];" in result
)
assert (
'"Tool1" [label="Tool1", shape=ellipse, style=filled, '
"fillcolor=lightgreen, width=0.5, height=0.3];" in result
)
assert (
'"Tool2" [label="Tool2", shape=ellipse, style=filled, '
"fillcolor=lightgreen, width=0.5, height=0.3];" in result
)
assert (
'"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, '
"fillcolor=lightyellow, width=1.5, height=0.8];" in result
)
def test_get_all_nodes(mock_agent):
result = get_all_nodes(mock_agent)
assert (
'"__start__" [label="__start__", shape=ellipse, style=filled, '
"fillcolor=lightblue, width=0.5, height=0.3];" in result
)
assert (
'"__end__" [label="__end__", shape=ellipse, style=filled, '
"fillcolor=lightblue, width=0.5, height=0.3];" in result
)
assert (
'"Agent1" [label="Agent1", shape=box, style=filled, '
"fillcolor=lightyellow, width=1.5, height=0.8];" in result
)
assert (
'"Tool1" [label="Tool1", shape=ellipse, style=filled, '
"fillcolor=lightgreen, width=0.5, height=0.3];" in result
)
assert (
'"Tool2" [label="Tool2", shape=ellipse, style=filled, '
"fillcolor=lightgreen, width=0.5, height=0.3];" in result
)
assert (
'"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, '
"fillcolor=lightyellow, width=1.5, height=0.8];" in result
)
def test_get_all_edges(mock_agent):
result = get_all_edges(mock_agent)
assert '"__start__" -> "Agent1";' in result
assert '"Agent1" -> "__end__";'
assert '"Agent1" -> "Tool1" [style=dotted, penwidth=1.5];' in result
assert '"Tool1" -> "Agent1" [style=dotted, penwidth=1.5];' in result
assert '"Agent1" -> "Tool2" [style=dotted, penwidth=1.5];' in result
assert '"Tool2" -> "Agent1" [style=dotted, penwidth=1.5];' in result
assert '"Agent1" -> "Handoff1";' in result
def test_draw_graph(mock_agent):
graph = draw_graph(mock_agent)
assert isinstance(graph, graphviz.Source)
assert "digraph G" in graph.source
assert "graph [splines=true];" in graph.source
assert 'node [fontname="Arial"];' in graph.source
assert "edge [penwidth=1.5];" in graph.source
assert (
'"__start__" [label="__start__", shape=ellipse, style=filled, '
"fillcolor=lightblue, width=0.5, height=0.3];" in graph.source
)
assert (
'"__end__" [label="__end__", shape=ellipse, style=filled, '
"fillcolor=lightblue, width=0.5, height=0.3];" in graph.source
)
assert (
'"Agent1" [label="Agent1", shape=box, style=filled, '
"fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source
)
assert (
'"Tool1" [label="Tool1", shape=ellipse, style=filled, '
"fillcolor=lightgreen, width=0.5, height=0.3];" in graph.source
)
assert (
'"Tool2" [label="Tool2", shape=ellipse, style=filled, '
"fillcolor=lightgreen, width=0.5, height=0.3];" in graph.source
)
assert (
'"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, '
"fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source
)

17
uv.lock
View file

@ -1,4 +1,5 @@
version = 1
revision = 1
requires-python = ">=3.9"
resolution-markers = [
"python_full_version >= '3.10'",
@ -348,6 +349,15 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/f7/ec/67fbef5d497f86283db54c22eec6f6140243aae73265799baaaa19cd17fb/ghp_import-2.1.0-py3-none-any.whl", hash = "sha256:8337dd7b50877f163d4c0289bc1f1c7f127550241988d568c1db512c4324a619", size = 11034 },
]
[[package]]
name = "graphviz"
version = "0.20.3"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/fa/83/5a40d19b8347f017e417710907f824915fba411a9befd092e52746b63e9f/graphviz-0.20.3.zip", hash = "sha256:09d6bc81e6a9fa392e7ba52135a9d49f1ed62526f96499325930e87ca1b5925d", size = 256455 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/00/be/d59db2d1d52697c6adc9eacaf50e8965b6345cc143f671e1ed068818d5cf/graphviz-0.20.3-py3-none-any.whl", hash = "sha256:81f848f2904515d8cd359cc611faba817598d2feaac4027b266aa3eda7b3dde5", size = 47126 },
]
[[package]]
name = "greenlet"
version = "3.1.1"
@ -1090,6 +1100,9 @@ dependencies = [
]
[package.optional-dependencies]
viz = [
{ name = "graphviz" },
]
voice = [
{ name = "numpy", version = "2.2.4", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.10'" },
{ name = "websockets" },
@ -1098,6 +1111,7 @@ voice = [
[package.dev-dependencies]
dev = [
{ name = "coverage" },
{ name = "graphviz" },
{ name = "inline-snapshot" },
{ name = "mkdocs" },
{ name = "mkdocs-material" },
@ -1118,6 +1132,7 @@ dev = [
[package.metadata]
requires-dist = [
{ name = "graphviz", marker = "extra == 'viz'", specifier = ">=0.17" },
{ name = "griffe", specifier = ">=1.5.6,<2" },
{ name = "mcp", marker = "python_full_version >= '3.10'" },
{ name = "numpy", marker = "python_full_version >= '3.10' and extra == 'voice'", specifier = ">=2.2.0,<3" },
@ -1128,10 +1143,12 @@ requires-dist = [
{ name = "typing-extensions", specifier = ">=4.12.2,<5" },
{ name = "websockets", marker = "extra == 'voice'", specifier = ">=15.0,<16" },
]
provides-extras = ["voice", "viz"]
[package.metadata.requires-dev]
dev = [
{ name = "coverage", specifier = ">=7.6.12" },
{ name = "graphviz" },
{ name = "inline-snapshot", specifier = ">=0.20.7" },
{ name = "mkdocs", specifier = ">=1.6.0" },
{ name = "mkdocs-material", specifier = ">=9.6.0" },