refactor: clean up visualization functions by removing unused nodes and improving type hints
This commit is contained in:
parent
a5b7abe8b4
commit
623063b633
2 changed files with 10 additions and 26 deletions
|
|
@ -1,3 +1,5 @@
|
|||
from typing import Optional
|
||||
|
||||
import graphviz
|
||||
|
||||
from agents import Agent
|
||||
|
|
@ -20,8 +22,6 @@ def get_main_graph(agent: Agent) -> str:
|
|||
graph [splines=true];
|
||||
node [fontname="Arial"];
|
||||
edge [penwidth=1.5];
|
||||
"__start__" [shape=ellipse, style=filled, fillcolor=lightblue];
|
||||
"__end__" [shape=ellipse, style=filled, fillcolor=lightblue];
|
||||
"""
|
||||
]
|
||||
parts.append(get_all_nodes(agent))
|
||||
|
|
@ -30,8 +30,6 @@ def get_main_graph(agent: Agent) -> str:
|
|||
return "".join(parts)
|
||||
|
||||
|
||||
from typing import Optional
|
||||
|
||||
def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
|
||||
"""
|
||||
Recursively generates the nodes for the given agent and its handoffs in DOT format.
|
||||
|
|
@ -59,12 +57,14 @@ def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
|
|||
for handoff in agent.handoffs:
|
||||
if isinstance(handoff, Handoff):
|
||||
parts.append(
|
||||
f'"{handoff.agent_name}" [label="{handoff.agent_name}", shape=box, style=filled, style=rounded, '
|
||||
f'"{handoff.agent_name}" [label="{handoff.agent_name}", shape=box, '
|
||||
f"shape=box, style=filled, style=rounded, "
|
||||
f"fillcolor=lightyellow, width=1.5, height=0.8];"
|
||||
)
|
||||
if isinstance(handoff, Agent):
|
||||
parts.append(
|
||||
f'"{handoff.name}" [label="{handoff.name}", shape=box, style=filled, style=rounded, '
|
||||
f'"{handoff.name}" [label="{handoff.name}", '
|
||||
f"shape=box, style=filled, style=rounded, "
|
||||
f"fillcolor=lightyellow, width=1.5, height=0.8];"
|
||||
)
|
||||
parts.append(get_all_nodes(handoff))
|
||||
|
|
@ -85,19 +85,11 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
|
|||
"""
|
||||
parts = []
|
||||
|
||||
if not parent:
|
||||
parts.append(f"""
|
||||
"__start__" -> "{agent.name}";""")
|
||||
|
||||
for tool in agent.tools:
|
||||
parts.append(f"""
|
||||
"{agent.name}" -> "{tool.name}" [style=dotted, penwidth=1.5];
|
||||
"{tool.name}" -> "{agent.name}" [style=dotted, penwidth=1.5];""")
|
||||
|
||||
if not agent.handoffs:
|
||||
parts.append(f"""
|
||||
"{agent.name}" -> "__end__";""")
|
||||
|
||||
for handoff in agent.handoffs:
|
||||
if isinstance(handoff, Handoff):
|
||||
parts.append(f"""
|
||||
|
|
@ -110,8 +102,6 @@ def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
|
|||
return "".join(parts)
|
||||
|
||||
|
||||
from typing import Optional
|
||||
|
||||
def draw_graph(agent: Agent, filename: Optional[str] = None) -> graphviz.Source:
|
||||
"""
|
||||
Draws the graph for the given agent and optionally saves it as a PNG file.
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from agents.extensions.visualization import (
|
|||
get_all_nodes,
|
||||
get_main_graph,
|
||||
)
|
||||
from agents.handoffs import Handoff
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -19,10 +20,8 @@ def mock_agent():
|
|||
tool2 = Mock()
|
||||
tool2.name = "Tool2"
|
||||
|
||||
handoff1 = Mock()
|
||||
handoff1.name = "Handoff1"
|
||||
handoff1.tools = []
|
||||
handoff1.handoffs = []
|
||||
handoff1 = Mock(spec=Handoff)
|
||||
handoff1.agent_name = "Handoff1"
|
||||
|
||||
agent = Mock(spec=Agent)
|
||||
agent.name = "Agent1"
|
||||
|
|
@ -34,12 +33,11 @@ def mock_agent():
|
|||
|
||||
def test_get_main_graph(mock_agent):
|
||||
result = get_main_graph(mock_agent)
|
||||
print(result)
|
||||
assert "digraph G" in result
|
||||
assert "graph [splines=true];" in result
|
||||
assert 'node [fontname="Arial"];' in result
|
||||
assert "edge [penwidth=1.5];" in result
|
||||
assert '"__start__" [shape=ellipse, style=filled, fillcolor=lightblue];' in result
|
||||
assert '"__end__" [shape=ellipse, style=filled, fillcolor=lightblue];' in result
|
||||
assert (
|
||||
'"Agent1" [label="Agent1", shape=box, style=filled, '
|
||||
"fillcolor=lightyellow, width=1.5, height=0.8];" in result
|
||||
|
|
@ -80,13 +78,11 @@ def test_get_all_nodes(mock_agent):
|
|||
|
||||
def test_get_all_edges(mock_agent):
|
||||
result = get_all_edges(mock_agent)
|
||||
assert '"__start__" -> "Agent1";' in result
|
||||
assert '"Agent1" -> "Tool1" [style=dotted, penwidth=1.5];' in result
|
||||
assert '"Tool1" -> "Agent1" [style=dotted, penwidth=1.5];' in result
|
||||
assert '"Agent1" -> "Tool2" [style=dotted, penwidth=1.5];' in result
|
||||
assert '"Tool2" -> "Agent1" [style=dotted, penwidth=1.5];' in result
|
||||
assert '"Agent1" -> "Handoff1";' in result
|
||||
assert '"Handoff1" -> "__end__";' in result
|
||||
|
||||
|
||||
def test_draw_graph(mock_agent):
|
||||
|
|
@ -96,8 +92,6 @@ def test_draw_graph(mock_agent):
|
|||
assert "graph [splines=true];" in graph.source
|
||||
assert 'node [fontname="Arial"];' in graph.source
|
||||
assert "edge [penwidth=1.5];" in graph.source
|
||||
assert '"__start__" [shape=ellipse, style=filled, fillcolor=lightblue];' in graph.source
|
||||
assert '"__end__" [shape=ellipse, style=filled, fillcolor=lightblue];' in graph.source
|
||||
assert (
|
||||
'"Agent1" [label="Agent1", shape=box, style=filled, '
|
||||
"fillcolor=lightyellow, width=1.5, height=0.8];" in graph.source
|
||||
|
|
|
|||
Loading…
Reference in a new issue