from unittest.mock import Mock import graphviz import pytest from agents import Agent from agents.extensions.visualization import draw_graph, get_all_edges, get_all_nodes, get_main_graph @pytest.fixture def mock_agent(): tool1 = Mock() tool1.name = "Tool1" tool2 = Mock() tool2.name = "Tool2" handoff1 = Mock() handoff1.name = "Handoff1" handoff1.tools = [] handoff1.handoffs = [] agent = Mock(spec=Agent) agent.name = "Agent1" agent.tools = [tool1, tool2] agent.handoffs = [handoff1] return agent def test_get_main_graph(mock_agent): result = get_main_graph(mock_agent) assert "digraph G" in result assert "graph [splines=true];" in result assert 'node [fontname="Arial"];' in result assert "edge [penwidth=1.5];" in result assert '"__start__" [shape=ellipse, style=filled, fillcolor=lightblue];' in result assert '"__end__" [shape=ellipse, style=filled, fillcolor=lightblue];' in result assert ( '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' in result ) assert ( '"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in result ) assert ( '"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in result ) assert ( '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];' in result ) def test_get_all_nodes(mock_agent): result = get_all_nodes(mock_agent) assert ( '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' in result ) assert ( '"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in result ) assert ( '"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in result ) assert ( '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];' in result ) def test_get_all_edges(mock_agent): result = get_all_edges(mock_agent) assert '"__start__" -> "Agent1";' in result assert '"Agent1" -> "Tool1" [style=dotted, penwidth=1.5];' in result assert '"Tool1" -> "Agent1" [style=dotted, penwidth=1.5];' in result assert '"Agent1" -> "Tool2" [style=dotted, penwidth=1.5];' in result assert '"Tool2" -> "Agent1" [style=dotted, penwidth=1.5];' in result assert '"Agent1" -> "Handoff1";' in result assert '"Handoff1" -> "__end__";' in result def test_draw_graph(mock_agent): graph = draw_graph(mock_agent) assert isinstance(graph, graphviz.Source) assert "digraph G" in graph.source assert "graph [splines=true];" in graph.source assert 'node [fontname="Arial"];' in graph.source assert "edge [penwidth=1.5];" in graph.source assert '"__start__" [shape=ellipse, style=filled, fillcolor=lightblue];' in graph.source assert '"__end__" [shape=ellipse, style=filled, fillcolor=lightblue];' in graph.source assert ( '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' in graph.source ) assert ( '"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in graph.source ) assert ( '"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in graph.source ) assert ( '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];' in graph.source )