diff --git a/src/agents/extensions/visualization.py b/src/agents/extensions/visualization.py index e2de019..08b185e 100644 --- a/src/agents/extensions/visualization.py +++ b/src/agents/extensions/visualization.py @@ -4,6 +4,7 @@ import graphviz # type: ignore from agents import Agent from agents.handoffs import Handoff +from agents.tool import Tool def get_main_graph(agent: Agent) -> str: @@ -41,6 +42,14 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str: str: The DOT format string representing the nodes. """ parts = [] + + # Start and end the graph + parts.append( + f'"__start__" [label="__start__", shape=ellipse, style=filled, ' + f"fillcolor=lightblue, width=0.5, height=0.3];" + f'"__end__" [label="__end__", shape=ellipse, style=filled, ' + f"fillcolor=lightblue, width=0.5, height=0.3];" + ) # Ensure parent agent node is colored if not parent: parts.append( @@ -72,7 +81,7 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str: return "".join(parts) -def get_all_edges(agent: Agent) -> str: +def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str: """ Recursively generates the edges for the given agent and its handoffs in DOT format. @@ -85,6 +94,11 @@ def get_all_edges(agent: Agent) -> str: """ parts = [] + if not parent: + parts.append( + f'"__start__" -> "{agent.name}";' + ) + for tool in agent.tools: parts.append(f""" "{agent.name}" -> "{tool.name}" [style=dotted, penwidth=1.5]; @@ -97,7 +111,12 @@ def get_all_edges(agent: Agent) -> str: if isinstance(handoff, Agent): parts.append(f""" "{agent.name}" -> "{handoff.name}";""") - parts.append(get_all_edges(handoff)) + parts.append(get_all_edges(handoff, agent)) + + if not agent.handoffs and not isinstance(agent, Tool): + parts.append( + f'"{agent.name}" -> "__end__";' + ) return "".join(parts) diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 046cdd6..b530f50 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -38,6 +38,14 @@ def test_get_main_graph(mock_agent): assert "graph [splines=true];" in result assert 'node [fontname="Arial"];' in result assert "edge [penwidth=1.5];" in result + assert ( + '"__start__" [label="__start__", shape=ellipse, style=filled, ' + "fillcolor=lightblue, width=0.5, height=0.3];" in result + ) + assert ( + '"__end__" [label="__end__", shape=ellipse, style=filled, ' + "fillcolor=lightblue, width=0.5, height=0.3];" in result + ) assert ( '"Agent1" [label="Agent1", shape=box, style=filled, ' "fillcolor=lightyellow, width=1.5, height=0.8];" in result @@ -58,6 +66,14 @@ def test_get_main_graph(mock_agent): def test_get_all_nodes(mock_agent): result = get_all_nodes(mock_agent) + assert ( + '"__start__" [label="__start__", shape=ellipse, style=filled, ' + "fillcolor=lightblue, width=0.5, height=0.3];" in result + ) + assert ( + '"__end__" [label="__end__", shape=ellipse, style=filled, ' + "fillcolor=lightblue, width=0.5, height=0.3];" in result + ) assert ( '"Agent1" [label="Agent1", shape=box, style=filled, ' "fillcolor=lightyellow, width=1.5, height=0.8];" in result @@ -78,6 +94,8 @@ def test_get_all_nodes(mock_agent): def test_get_all_edges(mock_agent): result = get_all_edges(mock_agent) + assert '"__start__" -> "Agent1";' in result + assert '"Agent1" -> "__end__";' assert '"Agent1" -> "Tool1" [style=dotted, penwidth=1.5];' in result assert '"Tool1" -> "Agent1" [style=dotted, penwidth=1.5];' in result assert '"Agent1" -> "Tool2" [style=dotted, penwidth=1.5];' in result @@ -85,6 +103,7 @@ def test_get_all_edges(mock_agent): assert '"Agent1" -> "Handoff1";' in result + def test_draw_graph(mock_agent): graph = draw_graph(mock_agent) assert isinstance(graph, graphviz.Source) @@ -92,6 +111,14 @@ def test_draw_graph(mock_agent): assert "graph [splines=true];" in graph.source assert 'node [fontname="Arial"];' in graph.source assert "edge [penwidth=1.5];" in graph.source + assert ( + '"__start__" [label="__start__", shape=ellipse, style=filled, ' + "fillcolor=lightblue, width=0.5, height=0.3];" in graph.source + ) + assert ( + '"__end__" [label="__end__", shape=ellipse, style=filled, ' + "fillcolor=lightblue, width=0.5, height=0.3];" in graph.source + ) assert ( '"Agent1" [label="Agent1", shape=box, style=filled, ' "fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source