Add start and end nodes to graph visualization and update edge generation
This commit is contained in:
parent
29103caba9
commit
5ad53d8000
2 changed files with 48 additions and 2 deletions
|
|
@ -4,6 +4,7 @@ import graphviz # type: ignore
|
||||||
|
|
||||||
from agents import Agent
|
from agents import Agent
|
||||||
from agents.handoffs import Handoff
|
from agents.handoffs import Handoff
|
||||||
|
from agents.tool import Tool
|
||||||
|
|
||||||
|
|
||||||
def get_main_graph(agent: Agent) -> str:
|
def get_main_graph(agent: Agent) -> str:
|
||||||
|
|
@ -41,6 +42,14 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
|
||||||
str: The DOT format string representing the nodes.
|
str: The DOT format string representing the nodes.
|
||||||
"""
|
"""
|
||||||
parts = []
|
parts = []
|
||||||
|
|
||||||
|
# Start and end the graph
|
||||||
|
parts.append(
|
||||||
|
f'"__start__" [label="__start__", shape=ellipse, style=filled, '
|
||||||
|
f"fillcolor=lightblue, width=0.5, height=0.3];"
|
||||||
|
f'"__end__" [label="__end__", shape=ellipse, style=filled, '
|
||||||
|
f"fillcolor=lightblue, width=0.5, height=0.3];"
|
||||||
|
)
|
||||||
# Ensure parent agent node is colored
|
# Ensure parent agent node is colored
|
||||||
if not parent:
|
if not parent:
|
||||||
parts.append(
|
parts.append(
|
||||||
|
|
@ -72,7 +81,7 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
|
||||||
return "".join(parts)
|
return "".join(parts)
|
||||||
|
|
||||||
|
|
||||||
def get_all_edges(agent: Agent) -> str:
|
def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
|
||||||
"""
|
"""
|
||||||
Recursively generates the edges for the given agent and its handoffs in DOT format.
|
Recursively generates the edges for the given agent and its handoffs in DOT format.
|
||||||
|
|
||||||
|
|
@ -85,6 +94,11 @@ def get_all_edges(agent: Agent) -> str:
|
||||||
"""
|
"""
|
||||||
parts = []
|
parts = []
|
||||||
|
|
||||||
|
if not parent:
|
||||||
|
parts.append(
|
||||||
|
f'"__start__" -> "{agent.name}";'
|
||||||
|
)
|
||||||
|
|
||||||
for tool in agent.tools:
|
for tool in agent.tools:
|
||||||
parts.append(f"""
|
parts.append(f"""
|
||||||
"{agent.name}" -> "{tool.name}" [style=dotted, penwidth=1.5];
|
"{agent.name}" -> "{tool.name}" [style=dotted, penwidth=1.5];
|
||||||
|
|
@ -97,7 +111,12 @@ def get_all_edges(agent: Agent) -> str:
|
||||||
if isinstance(handoff, Agent):
|
if isinstance(handoff, Agent):
|
||||||
parts.append(f"""
|
parts.append(f"""
|
||||||
"{agent.name}" -> "{handoff.name}";""")
|
"{agent.name}" -> "{handoff.name}";""")
|
||||||
parts.append(get_all_edges(handoff))
|
parts.append(get_all_edges(handoff, agent))
|
||||||
|
|
||||||
|
if not agent.handoffs and not isinstance(agent, Tool):
|
||||||
|
parts.append(
|
||||||
|
f'"{agent.name}" -> "__end__";'
|
||||||
|
)
|
||||||
|
|
||||||
return "".join(parts)
|
return "".join(parts)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,14 @@ def test_get_main_graph(mock_agent):
|
||||||
assert "graph [splines=true];" in result
|
assert "graph [splines=true];" in result
|
||||||
assert 'node [fontname="Arial"];' in result
|
assert 'node [fontname="Arial"];' in result
|
||||||
assert "edge [penwidth=1.5];" in result
|
assert "edge [penwidth=1.5];" in result
|
||||||
|
assert (
|
||||||
|
'"__start__" [label="__start__", shape=ellipse, style=filled, '
|
||||||
|
"fillcolor=lightblue, width=0.5, height=0.3];" in result
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
'"__end__" [label="__end__", shape=ellipse, style=filled, '
|
||||||
|
"fillcolor=lightblue, width=0.5, height=0.3];" in result
|
||||||
|
)
|
||||||
assert (
|
assert (
|
||||||
'"Agent1" [label="Agent1", shape=box, style=filled, '
|
'"Agent1" [label="Agent1", shape=box, style=filled, '
|
||||||
"fillcolor=lightyellow, width=1.5, height=0.8];" in result
|
"fillcolor=lightyellow, width=1.5, height=0.8];" in result
|
||||||
|
|
@ -58,6 +66,14 @@ def test_get_main_graph(mock_agent):
|
||||||
|
|
||||||
def test_get_all_nodes(mock_agent):
|
def test_get_all_nodes(mock_agent):
|
||||||
result = get_all_nodes(mock_agent)
|
result = get_all_nodes(mock_agent)
|
||||||
|
assert (
|
||||||
|
'"__start__" [label="__start__", shape=ellipse, style=filled, '
|
||||||
|
"fillcolor=lightblue, width=0.5, height=0.3];" in result
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
'"__end__" [label="__end__", shape=ellipse, style=filled, '
|
||||||
|
"fillcolor=lightblue, width=0.5, height=0.3];" in result
|
||||||
|
)
|
||||||
assert (
|
assert (
|
||||||
'"Agent1" [label="Agent1", shape=box, style=filled, '
|
'"Agent1" [label="Agent1", shape=box, style=filled, '
|
||||||
"fillcolor=lightyellow, width=1.5, height=0.8];" in result
|
"fillcolor=lightyellow, width=1.5, height=0.8];" in result
|
||||||
|
|
@ -78,6 +94,8 @@ def test_get_all_nodes(mock_agent):
|
||||||
|
|
||||||
def test_get_all_edges(mock_agent):
|
def test_get_all_edges(mock_agent):
|
||||||
result = get_all_edges(mock_agent)
|
result = get_all_edges(mock_agent)
|
||||||
|
assert '"__start__" -> "Agent1";' in result
|
||||||
|
assert '"Agent1" -> "__end__";'
|
||||||
assert '"Agent1" -> "Tool1" [style=dotted, penwidth=1.5];' in result
|
assert '"Agent1" -> "Tool1" [style=dotted, penwidth=1.5];' in result
|
||||||
assert '"Tool1" -> "Agent1" [style=dotted, penwidth=1.5];' in result
|
assert '"Tool1" -> "Agent1" [style=dotted, penwidth=1.5];' in result
|
||||||
assert '"Agent1" -> "Tool2" [style=dotted, penwidth=1.5];' in result
|
assert '"Agent1" -> "Tool2" [style=dotted, penwidth=1.5];' in result
|
||||||
|
|
@ -85,6 +103,7 @@ def test_get_all_edges(mock_agent):
|
||||||
assert '"Agent1" -> "Handoff1";' in result
|
assert '"Agent1" -> "Handoff1";' in result
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_draw_graph(mock_agent):
|
def test_draw_graph(mock_agent):
|
||||||
graph = draw_graph(mock_agent)
|
graph = draw_graph(mock_agent)
|
||||||
assert isinstance(graph, graphviz.Source)
|
assert isinstance(graph, graphviz.Source)
|
||||||
|
|
@ -92,6 +111,14 @@ def test_draw_graph(mock_agent):
|
||||||
assert "graph [splines=true];" in graph.source
|
assert "graph [splines=true];" in graph.source
|
||||||
assert 'node [fontname="Arial"];' in graph.source
|
assert 'node [fontname="Arial"];' in graph.source
|
||||||
assert "edge [penwidth=1.5];" in graph.source
|
assert "edge [penwidth=1.5];" in graph.source
|
||||||
|
assert (
|
||||||
|
'"__start__" [label="__start__", shape=ellipse, style=filled, '
|
||||||
|
"fillcolor=lightblue, width=0.5, height=0.3];" in graph.source
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
'"__end__" [label="__end__", shape=ellipse, style=filled, '
|
||||||
|
"fillcolor=lightblue, width=0.5, height=0.3];" in graph.source
|
||||||
|
)
|
||||||
assert (
|
assert (
|
||||||
'"Agent1" [label="Agent1", shape=box, style=filled, '
|
'"Agent1" [label="Agent1", shape=box, style=filled, '
|
||||||
"fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source
|
"fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue