Fix visualization recursion with cycle detection (#737)
## Summary - avoid infinite recursion in visualization by tracking visited agents - test cycle detection in graph utility ## Testing - `make mypy` - `make tests` Resolves #668
This commit is contained in:
parent
1364f4408e
commit
db462e32a3
2 changed files with 50 additions and 18 deletions
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Optional
|
from __future__ import annotations
|
||||||
|
|
||||||
import graphviz # type: ignore
|
import graphviz # type: ignore
|
||||||
|
|
||||||
|
|
@ -31,7 +31,9 @@ def get_main_graph(agent: Agent) -> str:
|
||||||
return "".join(parts)
|
return "".join(parts)
|
||||||
|
|
||||||
|
|
||||||
def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
|
def get_all_nodes(
|
||||||
|
agent: Agent, parent: Agent | None = None, visited: set[str] | None = None
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Recursively generates the nodes for the given agent and its handoffs in DOT format.
|
Recursively generates the nodes for the given agent and its handoffs in DOT format.
|
||||||
|
|
||||||
|
|
@ -41,17 +43,23 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
|
||||||
Returns:
|
Returns:
|
||||||
str: The DOT format string representing the nodes.
|
str: The DOT format string representing the nodes.
|
||||||
"""
|
"""
|
||||||
|
if visited is None:
|
||||||
|
visited = set()
|
||||||
|
if agent.name in visited:
|
||||||
|
return ""
|
||||||
|
visited.add(agent.name)
|
||||||
|
|
||||||
parts = []
|
parts = []
|
||||||
|
|
||||||
# Start and end the graph
|
# Start and end the graph
|
||||||
parts.append(
|
|
||||||
'"__start__" [label="__start__", shape=ellipse, style=filled, '
|
|
||||||
"fillcolor=lightblue, width=0.5, height=0.3];"
|
|
||||||
'"__end__" [label="__end__", shape=ellipse, style=filled, '
|
|
||||||
"fillcolor=lightblue, width=0.5, height=0.3];"
|
|
||||||
)
|
|
||||||
# Ensure parent agent node is colored
|
|
||||||
if not parent:
|
if not parent:
|
||||||
|
parts.append(
|
||||||
|
'"__start__" [label="__start__", shape=ellipse, style=filled, '
|
||||||
|
"fillcolor=lightblue, width=0.5, height=0.3];"
|
||||||
|
'"__end__" [label="__end__", shape=ellipse, style=filled, '
|
||||||
|
"fillcolor=lightblue, width=0.5, height=0.3];"
|
||||||
|
)
|
||||||
|
# Ensure parent agent node is colored
|
||||||
parts.append(
|
parts.append(
|
||||||
f'"{agent.name}" [label="{agent.name}", shape=box, style=filled, '
|
f'"{agent.name}" [label="{agent.name}", shape=box, style=filled, '
|
||||||
"fillcolor=lightyellow, width=1.5, height=0.8];"
|
"fillcolor=lightyellow, width=1.5, height=0.8];"
|
||||||
|
|
@ -71,17 +79,20 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
|
||||||
f"fillcolor=lightyellow, width=1.5, height=0.8];"
|
f"fillcolor=lightyellow, width=1.5, height=0.8];"
|
||||||
)
|
)
|
||||||
if isinstance(handoff, Agent):
|
if isinstance(handoff, Agent):
|
||||||
parts.append(
|
if handoff.name not in visited:
|
||||||
f'"{handoff.name}" [label="{handoff.name}", '
|
parts.append(
|
||||||
f"shape=box, style=filled, style=rounded, "
|
f'"{handoff.name}" [label="{handoff.name}", '
|
||||||
f"fillcolor=lightyellow, width=1.5, height=0.8];"
|
f"shape=box, style=filled, style=rounded, "
|
||||||
)
|
f"fillcolor=lightyellow, width=1.5, height=0.8];"
|
||||||
parts.append(get_all_nodes(handoff))
|
)
|
||||||
|
parts.append(get_all_nodes(handoff, agent, visited))
|
||||||
|
|
||||||
return "".join(parts)
|
return "".join(parts)
|
||||||
|
|
||||||
|
|
||||||
def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
|
def get_all_edges(
|
||||||
|
agent: Agent, parent: Agent | None = None, visited: set[str] | None = None
|
||||||
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Recursively generates the edges for the given agent and its handoffs in DOT format.
|
Recursively generates the edges for the given agent and its handoffs in DOT format.
|
||||||
|
|
||||||
|
|
@ -92,6 +103,12 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
|
||||||
Returns:
|
Returns:
|
||||||
str: The DOT format string representing the edges.
|
str: The DOT format string representing the edges.
|
||||||
"""
|
"""
|
||||||
|
if visited is None:
|
||||||
|
visited = set()
|
||||||
|
if agent.name in visited:
|
||||||
|
return ""
|
||||||
|
visited.add(agent.name)
|
||||||
|
|
||||||
parts = []
|
parts = []
|
||||||
|
|
||||||
if not parent:
|
if not parent:
|
||||||
|
|
@ -109,7 +126,7 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
|
||||||
if isinstance(handoff, Agent):
|
if isinstance(handoff, Agent):
|
||||||
parts.append(f"""
|
parts.append(f"""
|
||||||
"{agent.name}" -> "{handoff.name}";""")
|
"{agent.name}" -> "{handoff.name}";""")
|
||||||
parts.append(get_all_edges(handoff, agent))
|
parts.append(get_all_edges(handoff, agent, visited))
|
||||||
|
|
||||||
if not agent.handoffs and not isinstance(agent, Tool): # type: ignore
|
if not agent.handoffs and not isinstance(agent, Tool): # type: ignore
|
||||||
parts.append(f'"{agent.name}" -> "__end__";')
|
parts.append(f'"{agent.name}" -> "__end__";')
|
||||||
|
|
@ -117,7 +134,7 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
|
||||||
return "".join(parts)
|
return "".join(parts)
|
||||||
|
|
||||||
|
|
||||||
def draw_graph(agent: Agent, filename: Optional[str] = None) -> graphviz.Source:
|
def draw_graph(agent: Agent, filename: str | None = None) -> graphviz.Source:
|
||||||
"""
|
"""
|
||||||
Draws the graph for the given agent and optionally saves it as a PNG file.
|
Draws the graph for the given agent and optionally saves it as a PNG file.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -134,3 +134,18 @@ def test_draw_graph(mock_agent):
|
||||||
'"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, '
|
'"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, '
|
||||||
"fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source
|
"fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_cycle_detection():
|
||||||
|
agent_a = Agent(name="A")
|
||||||
|
agent_b = Agent(name="B")
|
||||||
|
agent_a.handoffs.append(agent_b)
|
||||||
|
agent_b.handoffs.append(agent_a)
|
||||||
|
|
||||||
|
nodes = get_all_nodes(agent_a)
|
||||||
|
edges = get_all_edges(agent_a)
|
||||||
|
|
||||||
|
assert nodes.count('"A" [label="A"') == 1
|
||||||
|
assert nodes.count('"B" [label="B"') == 1
|
||||||
|
assert '"A" -> "B"' in edges
|
||||||
|
assert '"B" -> "A"' in edges
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue