style: improve code formatting and readability in visualization functions

This commit is contained in:
Martín Bravo 2025-03-18 10:09:44 +01:00
parent 39ff00dd9d
commit 0079bca717
2 changed files with 36 additions and 30 deletions

View file

@ -19,7 +19,6 @@ def get_main_graph(agent: Agent) -> str:
graph [splines=true];
node [fontname="Arial"];
edge [penwidth=1.5];
"__start__" [shape=ellipse, style=filled, fillcolor=lightblue];
"__end__" [shape=ellipse, style=filled, fillcolor=lightblue];
"""
@ -41,21 +40,23 @@ def get_all_nodes(agent: Agent, parent: Agent = None) -> str:
str: The DOT format string representing the nodes.
"""
parts = []
# Ensure parent agent node is colored
if not parent:
parts.append(f"""
"{agent.name}" [label="{agent.name}", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];""")
"{agent.name}" [label="{agent.name}", shape=box, style=filled,
fillcolor=lightyellow, width=1.5, height=0.8];""")
# Smaller tools (ellipse, green)
for tool in agent.tools:
parts.append(f"""
"{tool.name}" [label="{tool.name}", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];""")
"{tool.name}" [label="{tool.name}", shape=ellipse, style=filled,
fillcolor=lightgreen, width=0.5, height=0.3];""")
# Bigger handoffs (rounded box, yellow)
for handoff in agent.handoffs:
parts.append(f"""
"{handoff.name}" [label="{handoff.name}", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];""")
"{handoff.name}" [label="{handoff.name}", shape=box, style=filled,
style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];""")
parts.append(get_all_nodes(handoff))
return "".join(parts)

View file

@ -4,7 +4,12 @@ import graphviz
import pytest
from agents import Agent
from agents.extensions.visualization import draw_graph, get_all_edges, get_all_nodes, get_main_graph
from agents.extensions.visualization import (
draw_graph,
get_all_edges,
get_all_nodes,
get_main_graph,
)
@pytest.fixture
@ -36,40 +41,40 @@ def test_get_main_graph(mock_agent):
assert '"__start__" [shape=ellipse, style=filled, fillcolor=lightblue];' in result
assert '"__end__" [shape=ellipse, style=filled, fillcolor=lightblue];' in result
assert (
'"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];'
in result
'"Agent1" [label="Agent1", shape=box, style=filled, '
"fillcolor=lightyellow, width=1.5, height=0.8];" in result
)
assert (
'"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];'
in result
'"Tool1" [label="Tool1", shape=ellipse, style=filled, '
"fillcolor=lightgreen, width=0.5, height=0.3];" in result
)
assert (
'"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];'
in result
'"Tool2" [label="Tool2", shape=ellipse, style=filled, '
"fillcolor=lightgreen, width=0.5, height=0.3];" in result
)
assert (
'"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];'
in result
'"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, '
"fillcolor=lightyellow, width=1.5, height=0.8];" in result
)
def test_get_all_nodes(mock_agent):
result = get_all_nodes(mock_agent)
assert (
'"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];'
in result
'"Agent1" [label="Agent1", shape=box, style=filled, '
"fillcolor=lightyellow, width=1.5, height=0.8];" in result
)
assert (
'"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];'
in result
'"Tool1" [label="Tool1", shape=ellipse, style=filled, '
"fillcolor=lightgreen, width=0.5, height=0.3];" in result
)
assert (
'"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];'
in result
'"Tool2" [label="Tool2", shape=ellipse, style=filled, '
"fillcolor=lightgreen, width=0.5, height=0.3];" in result
)
assert (
'"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];'
in result
'"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, '
"fillcolor=lightyellow, width=1.5, height=0.8];" in result
)
@ -94,18 +99,18 @@ def test_draw_graph(mock_agent):
assert '"__start__" [shape=ellipse, style=filled, fillcolor=lightblue];' in graph.source
assert '"__end__" [shape=ellipse, style=filled, fillcolor=lightblue];' in graph.source
assert (
'"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];'
in graph.source
'"Agent1" [label="Agent1", shape=box, style=filled, '
"fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source
)
assert (
'"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];'
in graph.source
'"Tool1" [label="Tool1", shape=ellipse, style=filled, '
"fillcolor=lightgreen, width=0.5, height=0.3];" in graph.source
)
assert (
'"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];'
in graph.source
'"Tool2" [label="Tool2", shape=ellipse, style=filled, '
"fillcolor=lightgreen, width=0.5, height=0.3];" in graph.source
)
assert (
'"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];'
in graph.source
'"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, '
"fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source
)