feat: enhance visualization functions with optional type hints and improved handling of agents and handoffs

This commit is contained in:
Martín Bravo 2025-03-24 09:30:13 +01:00
parent 9f7d596d14
commit a5b7abe8b4

View file

@ -1,6 +1,7 @@
import graphviz
from agents import Agent
from agents.handoffs import Handoff
def get_main_graph(agent: Agent) -> str:
@ -29,7 +30,9 @@ def get_main_graph(agent: Agent) -> str:
return "".join(parts)
def get_all_nodes(agent: Agent, parent: Agent = None) -> str:
from typing import Optional
def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
"""
Recursively generates the nodes for the given agent and its handoffs in DOT format.
@ -47,25 +50,29 @@ def get_all_nodes(agent: Agent, parent: Agent = None) -> str:
"fillcolor=lightyellow, width=1.5, height=0.8];"
)
# Smaller tools (ellipse, green)
for tool in agent.tools:
parts.append(
f'"{tool.name}" [label="{tool.name}", shape=ellipse, style=filled, '
f"fillcolor=lightgreen, width=0.5, height=0.3];"
)
# Bigger handoffs (rounded box, yellow)
for handoff in agent.handoffs:
parts.append(
f'"{handoff.name}" [label="{handoff.name}", shape=box, style=filled, style=rounded, '
f"fillcolor=lightyellow, width=1.5, height=0.8];"
)
parts.append(get_all_nodes(handoff))
if isinstance(handoff, Handoff):
parts.append(
f'"{handoff.agent_name}" [label="{handoff.agent_name}", shape=box, style=filled, style=rounded, '
f"fillcolor=lightyellow, width=1.5, height=0.8];"
)
if isinstance(handoff, Agent):
parts.append(
f'"{handoff.name}" [label="{handoff.name}", shape=box, style=filled, style=rounded, '
f"fillcolor=lightyellow, width=1.5, height=0.8];"
)
parts.append(get_all_nodes(handoff))
return "".join(parts)
def get_all_edges(agent: Agent, parent: Agent = None) -> str:
def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
"""
Recursively generates the edges for the given agent and its handoffs in DOT format.
@ -92,14 +99,20 @@ def get_all_edges(agent: Agent, parent: Agent = None) -> str:
"{agent.name}" -> "__end__";""")
for handoff in agent.handoffs:
parts.append(f"""
"{agent.name}" -> "{handoff.name}";""")
parts.append(get_all_edges(handoff, agent))
if isinstance(handoff, Handoff):
parts.append(f"""
"{agent.name}" -> "{handoff.agent_name}";""")
if isinstance(handoff, Agent):
parts.append(f"""
"{agent.name}" -> "{handoff.name}";""")
parts.append(get_all_edges(handoff, agent))
return "".join(parts)
def draw_graph(agent: Agent, filename: str = None) -> graphviz.Source:
from typing import Optional
def draw_graph(agent: Agent, filename: Optional[str] = None) -> graphviz.Source:
"""
Draws the graph for the given agent and optionally saves it as a PNG file.