diff --git a/src/agents/extensions/visualization.py b/src/agents/extensions/visualization.py index c020708..985731b 100644 --- a/src/agents/extensions/visualization.py +++ b/src/agents/extensions/visualization.py @@ -19,7 +19,6 @@ def get_main_graph(agent: Agent) -> str: graph [splines=true]; node [fontname="Arial"]; edge [penwidth=1.5]; - "__start__" [shape=ellipse, style=filled, fillcolor=lightblue]; "__end__" [shape=ellipse, style=filled, fillcolor=lightblue]; """ @@ -41,21 +40,23 @@ def get_all_nodes(agent: Agent, parent: Agent = None) -> str: str: The DOT format string representing the nodes. """ parts = [] - # Ensure parent agent node is colored if not parent: parts.append(f""" - "{agent.name}" [label="{agent.name}", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];""") + "{agent.name}" [label="{agent.name}", shape=box, style=filled, + fillcolor=lightyellow, width=1.5, height=0.8];""") # Smaller tools (ellipse, green) for tool in agent.tools: parts.append(f""" - "{tool.name}" [label="{tool.name}", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];""") + "{tool.name}" [label="{tool.name}", shape=ellipse, style=filled, + fillcolor=lightgreen, width=0.5, height=0.3];""") # Bigger handoffs (rounded box, yellow) for handoff in agent.handoffs: parts.append(f""" - "{handoff.name}" [label="{handoff.name}", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];""") + "{handoff.name}" [label="{handoff.name}", shape=box, style=filled, + style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];""") parts.append(get_all_nodes(handoff)) return "".join(parts) diff --git a/tests/test_visualization.py b/tests/test_visualization.py index 8505b61..12d67fb 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -4,7 +4,12 @@ import graphviz import pytest from agents import Agent -from agents.extensions.visualization import draw_graph, get_all_edges, get_all_nodes, get_main_graph +from agents.extensions.visualization import ( + draw_graph, + get_all_edges, + get_all_nodes, + get_main_graph, +) @pytest.fixture @@ -36,40 +41,40 @@ def test_get_main_graph(mock_agent): assert '"__start__" [shape=ellipse, style=filled, fillcolor=lightblue];' in result assert '"__end__" [shape=ellipse, style=filled, fillcolor=lightblue];' in result assert ( - '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' - in result + '"Agent1" [label="Agent1", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" in result ) assert ( - '"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' - in result + '"Tool1" [label="Tool1", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" in result ) assert ( - '"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' - in result + '"Tool2" [label="Tool2", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" in result ) assert ( - '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];' - in result + '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" in result ) def test_get_all_nodes(mock_agent): result = get_all_nodes(mock_agent) assert ( - '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' - in result + '"Agent1" [label="Agent1", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" in result ) assert ( - '"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' - in result + '"Tool1" [label="Tool1", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" in result ) assert ( - '"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' - in result + '"Tool2" [label="Tool2", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" in result ) assert ( - '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];' - in result + '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" in result ) @@ -94,18 +99,18 @@ def test_draw_graph(mock_agent): assert '"__start__" [shape=ellipse, style=filled, fillcolor=lightblue];' in graph.source assert '"__end__" [shape=ellipse, style=filled, fillcolor=lightblue];' in graph.source assert ( - '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' - in graph.source + '"Agent1" [label="Agent1", shape=box, style=filled, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source ) assert ( - '"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' - in graph.source + '"Tool1" [label="Tool1", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" in graph.source ) assert ( - '"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' - in graph.source + '"Tool2" [label="Tool2", shape=ellipse, style=filled, ' + "fillcolor=lightgreen, width=0.5, height=0.3];" in graph.source ) assert ( - '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];' - in graph.source + '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, ' + "fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source )