From 351b6074e5e69c29e791291e0b2a38e52981d2ac Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=ADn=20Bravo?= <123977407+MartinEBravo@users.noreply.github.com> Date: Tue, 25 Mar 2025 19:12:40 +0100 Subject: [PATCH] Refactor visualization functions to improve formatting and streamline edge generation --- src/agents/extensions/visualization.py | 16 ++++++---------- tests/test_visualization.py | 1 - 2 files changed, 6 insertions(+), 11 deletions(-) diff --git a/src/agents/extensions/visualization.py b/src/agents/extensions/visualization.py index 08b185e..013a21e 100644 --- a/src/agents/extensions/visualization.py +++ b/src/agents/extensions/visualization.py @@ -45,10 +45,10 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str: # Start and end the graph parts.append( - f'"__start__" [label="__start__", shape=ellipse, style=filled, ' - f"fillcolor=lightblue, width=0.5, height=0.3];" - f'"__end__" [label="__end__", shape=ellipse, style=filled, ' - f"fillcolor=lightblue, width=0.5, height=0.3];" + '"__start__" [label="__start__", shape=ellipse, style=filled, ' + "fillcolor=lightblue, width=0.5, height=0.3];" + '"__end__" [label="__end__", shape=ellipse, style=filled, ' + "fillcolor=lightblue, width=0.5, height=0.3];" ) # Ensure parent agent node is colored if not parent: @@ -95,9 +95,7 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str: parts = [] if not parent: - parts.append( - f'"__start__" -> "{agent.name}";' - ) + parts.append(f'"__start__" -> "{agent.name}";') for tool in agent.tools: parts.append(f""" @@ -114,9 +112,7 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str: parts.append(get_all_edges(handoff, agent)) if not agent.handoffs and not isinstance(agent, Tool): - parts.append( - f'"{agent.name}" -> "__end__";' - ) + parts.append(f'"{agent.name}" -> "__end__";') return "".join(parts) diff --git a/tests/test_visualization.py b/tests/test_visualization.py index b530f50..6aa8677 100644 --- a/tests/test_visualization.py +++ b/tests/test_visualization.py @@ -103,7 +103,6 @@ def test_get_all_edges(mock_agent): assert '"Agent1" -> "Handoff1";' in result - def test_draw_graph(mock_agent): graph = draw_graph(mock_agent) assert isinstance(graph, graphviz.Source)