diff --git a/src/agents/extensions/visualization.py b/src/agents/extensions/visualization.py index a05dcd3..dd75acf 100644 --- a/src/agents/extensions/visualization.py +++ b/src/agents/extensions/visualization.py @@ -1,3 +1,5 @@ +from typing import Optional + import graphviz from agents import Agent @@ -20,8 +22,6 @@ def get_main_graph(agent: Agent) -> str: graph [splines=true]; node [fontname="Arial"]; edge [penwidth=1.5]; - "__start__" [shape=ellipse, style=filled, fillcolor=lightblue]; - "__end__" [shape=ellipse, style=filled, fillcolor=lightblue]; """ ] parts.append(get_all_nodes(agent)) @@ -30,8 +30,6 @@ def get_main_graph(agent: Agent) -> str: return "".join(parts) -from typing import Optional - def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str: """ Recursively generates the nodes for the given agent and its handoffs in DOT format. @@ -59,12 +57,14 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str: for handoff in agent.handoffs: if isinstance(handoff, Handoff): parts.append( - f'"{handoff.agent_name}" [label="{handoff.agent_name}", shape=box, style=filled, style=rounded, ' + f'"{handoff.agent_name}" [label="{handoff.agent_name}", shape=box, ' + f"shape=box, style=filled, style=rounded, " f"fillcolor=lightyellow, width=1.5, height=0.8];" ) if isinstance(handoff, Agent): parts.append( - f'"{handoff.name}" [label="{handoff.name}", shape=box, style=filled, style=rounded, ' + f'"{handoff.name}" [label="{handoff.name}", ' + f"shape=box, style=filled, style=rounded, " f"fillcolor=lightyellow, width=1.5, height=0.8];" ) parts.append(get_all_nodes(handoff)) @@ -85,19 +85,11 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str: """ parts = [] - if not parent: - parts.append(f""" - "__start__" -> "{agent.name}";""") - for tool in agent.tools: parts.append(f""" "{agent.name}" -> "{tool.name}" [style=dotted, penwidth=1.5]; "{tool.name}" -> "{agent.name}" [style=dotted, penwidth=1.5];""") - if not agent.handoffs: - parts.append(f""" - "{agent.name}" -> "__end__";""") - for handoff in agent.handoffs: if isinstance(handoff, Handoff): parts.append(f""" @@ -110,8 +102,6 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str: return "".join(parts) -from typing import Optional - def draw_graph(agent: Agent, filename: Optional[str] = None) -> graphviz.Source: """ Draws the graph for the given agent and optionally saves it as a PNG file. diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 12d67fb..1f83d85 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -10,6 +10,7 @@ from agents.extensions.visualization import ( get_all_nodes, get_main_graph, ) +from agents.handoffs import Handoff @pytest.fixture @@ -19,10 +20,8 @@ def mock_agent(): tool2 = Mock() tool2.name = "Tool2" - handoff1 = Mock() - handoff1.name = "Handoff1" - handoff1.tools = [] - handoff1.handoffs = [] + handoff1 = Mock(spec=Handoff) + handoff1.agent_name = "Handoff1" agent = Mock(spec=Agent) agent.name = "Agent1" @@ -34,12 +33,11 @@ def mock_agent(): def test_get_main_graph(mock_agent): result = get_main_graph(mock_agent) + print(result) assert "digraph G" in result assert "graph [splines=true];" in result assert 'node [fontname="Arial"];' in result assert "edge [penwidth=1.5];" in result - assert '"__start__" [shape=ellipse, style=filled, fillcolor=lightblue];' in result - assert '"__end__" [shape=ellipse, style=filled, fillcolor=lightblue];' in result assert ( '"Agent1" [label="Agent1", shape=box, style=filled, ' "fillcolor=lightyellow, width=1.5, height=0.8];" in result @@ -80,13 +78,11 @@ def test_get_all_nodes(mock_agent): def test_get_all_edges(mock_agent): result = get_all_edges(mock_agent) - assert '"__start__" -> "Agent1";' in result assert '"Agent1" -> "Tool1" [style=dotted, penwidth=1.5];' in result assert '"Tool1" -> "Agent1" [style=dotted, penwidth=1.5];' in result assert '"Agent1" -> "Tool2" [style=dotted, penwidth=1.5];' in result assert '"Tool2" -> "Agent1" [style=dotted, penwidth=1.5];' in result assert '"Agent1" -> "Handoff1";' in result - assert '"Handoff1" -> "__end__";' in result def test_draw_graph(mock_agent): @@ -96,8 +92,6 @@ def test_draw_graph(mock_agent): assert "graph [splines=true];" in graph.source assert 'node [fontname="Arial"];' in graph.source assert "edge [penwidth=1.5];" in graph.source - assert '"__start__" [shape=ellipse, style=filled, fillcolor=lightblue];' in graph.source - assert '"__end__" [shape=ellipse, style=filled, fillcolor=lightblue];' in graph.source assert ( '"Agent1" [label="Agent1", shape=box, style=filled, ' "fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source