diff --git a/src/agents/visualizations.py b/src/agents/visualizations.py index 934647f..42019d4 100644 --- a/src/agents/visualizations.py +++ b/src/agents/visualizations.py @@ -13,7 +13,8 @@ def get_main_graph(agent: Agent) -> str: Returns: str: The DOT format string representing the graph. """ - parts = [""" + parts = [ + """ digraph G { graph [splines=true]; node [fontname="Arial"]; @@ -21,7 +22,8 @@ def get_main_graph(agent: Agent) -> str: "__start__" [shape=ellipse, style=filled, fillcolor=lightblue]; "__end__" [shape=ellipse, style=filled, fillcolor=lightblue]; - """] + """ + ] parts.append(get_all_nodes(agent)) parts.append(get_all_edges(agent)) parts.append("}") @@ -39,23 +41,23 @@ def get_all_nodes(agent: Agent, parent: Agent = None) -> str: str: The DOT format string representing the nodes. """ parts = [] - + # Ensure parent agent node is colored if not parent: parts.append(f""" "{agent.name}" [label="{agent.name}", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];""") - + # Smaller tools (ellipse, green) for tool in agent.tools: parts.append(f""" "{tool.name}" [label="{tool.name}", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];""") - + # Bigger handoffs (rounded box, yellow) for handoff in agent.handoffs: parts.append(f""" "{handoff.name}" [label="{handoff.name}", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];""") parts.append(get_all_nodes(handoff)) - + return "".join(parts) @@ -71,25 +73,25 @@ def get_all_edges(agent: Agent, parent: Agent = None) -> str: str: The DOT format string representing the edges. """ parts = [] - + if not parent: parts.append(f""" "__start__" -> "{agent.name}";""") - + for tool in agent.tools: parts.append(f""" "{agent.name}" -> "{tool.name}" [style=dotted, penwidth=1.5]; "{tool.name}" -> "{agent.name}" [style=dotted, penwidth=1.5];""") - + if not agent.handoffs: parts.append(f""" "{agent.name}" -> "__end__";""") - + for handoff in agent.handoffs: parts.append(f""" "{agent.name}" -> "{handoff.name}";""") parts.append(get_all_edges(handoff, agent)) - + return "".join(parts) @@ -106,8 +108,8 @@ def draw_graph(agent: Agent, filename: str = None) -> graphviz.Source: """ dot_code = get_main_graph(agent) graph = graphviz.Source(dot_code) - + if filename: - graph.render(filename, format='png') - - return graph \ No newline at end of file + graph.render(filename, format="png") + + return graph diff --git a/tests/test_visualizations.py b/tests/test_visualizations.py index 0062f85..d813978 100644 --- a/tests/test_visualizations.py +++ b/tests/test_visualizations.py @@ -1,8 +1,11 @@ -import pytest from unittest.mock import Mock -from src.agents.visualizations import get_main_graph, get_all_nodes, get_all_edges, draw_graph -from src.agents.agent import Agent + import graphviz +import pytest + +from src.agents.agent import Agent +from src.agents.visualizations import draw_graph, get_all_edges, get_all_nodes, get_main_graph + @pytest.fixture def mock_agent(): @@ -10,7 +13,7 @@ def mock_agent(): tool1.name = "Tool1" tool2 = Mock() tool2.name = "Tool2" - + handoff1 = Mock() handoff1.name = "Handoff1" handoff1.tools = [] @@ -20,28 +23,55 @@ def mock_agent(): agent.name = "Agent1" agent.tools = [tool1, tool2] agent.handoffs = [handoff1] - + return agent + def test_get_main_graph(mock_agent): result = get_main_graph(mock_agent) assert "digraph G" in result - assert 'graph [splines=true];' in result + assert "graph [splines=true];" in result assert 'node [fontname="Arial"];' in result - assert 'edge [penwidth=1.5];' in result + assert "edge [penwidth=1.5];" in result assert '"__start__" [shape=ellipse, style=filled, fillcolor=lightblue];' in result assert '"__end__" [shape=ellipse, style=filled, fillcolor=lightblue];' in result - assert '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' in result - assert '"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in result - assert '"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in result - assert '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];' in result + assert ( + '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' + in result + ) + assert ( + '"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' + in result + ) + assert ( + '"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' + in result + ) + assert ( + '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];' + in result + ) + def test_get_all_nodes(mock_agent): result = get_all_nodes(mock_agent) - assert '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' in result - assert '"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in result - assert '"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in result - assert '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];' in result + assert ( + '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' + in result + ) + assert ( + '"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' + in result + ) + assert ( + '"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' + in result + ) + assert ( + '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];' + in result + ) + def test_get_all_edges(mock_agent): result = get_all_edges(mock_agent) @@ -53,16 +83,29 @@ def test_get_all_edges(mock_agent): assert '"Agent1" -> "Handoff1";' in result assert '"Handoff1" -> "__end__";' in result + def test_draw_graph(mock_agent): graph = draw_graph(mock_agent) assert isinstance(graph, graphviz.Source) assert "digraph G" in graph.source - assert 'graph [splines=true];' in graph.source + assert "graph [splines=true];" in graph.source assert 'node [fontname="Arial"];' in graph.source - assert 'edge [penwidth=1.5];' in graph.source + assert "edge [penwidth=1.5];" in graph.source assert '"__start__" [shape=ellipse, style=filled, fillcolor=lightblue];' in graph.source assert '"__end__" [shape=ellipse, style=filled, fillcolor=lightblue];' in graph.source - assert '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' in graph.source - assert '"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in graph.source - assert '"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in graph.source - assert '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];' in graph.source + assert ( + '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' + in graph.source + ) + assert ( + '"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' + in graph.source + ) + assert ( + '"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' + in graph.source + ) + assert ( + '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];' + in graph.source + )