feat: enhance visualization functions with optional type hints and improved handling of agents and handoffs
This commit is contained in:
parent
9f7d596d14
commit
a5b7abe8b4
1 changed files with 26 additions and 13 deletions
|
|
@ -1,6 +1,7 @@
|
|||
import graphviz
|
||||
|
||||
from agents import Agent
|
||||
from agents.handoffs import Handoff
|
||||
|
||||
|
||||
def get_main_graph(agent: Agent) -> str:
|
||||
|
|
@ -29,7 +30,9 @@ def get_main_graph(agent: Agent) -> str:
|
|||
return "".join(parts)
|
||||
|
||||
|
||||
def get_all_nodes(agent: Agent, parent: Agent = None) -> str:
|
||||
from typing import Optional
|
||||
|
||||
def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str:
|
||||
"""
|
||||
Recursively generates the nodes for the given agent and its handoffs in DOT format.
|
||||
|
||||
|
|
@ -47,25 +50,29 @@ def get_all_nodes(agent: Agent, parent: Agent = None) -> str:
|
|||
"fillcolor=lightyellow, width=1.5, height=0.8];"
|
||||
)
|
||||
|
||||
# Smaller tools (ellipse, green)
|
||||
for tool in agent.tools:
|
||||
parts.append(
|
||||
f'"{tool.name}" [label="{tool.name}", shape=ellipse, style=filled, '
|
||||
f"fillcolor=lightgreen, width=0.5, height=0.3];"
|
||||
)
|
||||
|
||||
# Bigger handoffs (rounded box, yellow)
|
||||
for handoff in agent.handoffs:
|
||||
parts.append(
|
||||
f'"{handoff.name}" [label="{handoff.name}", shape=box, style=filled, style=rounded, '
|
||||
f"fillcolor=lightyellow, width=1.5, height=0.8];"
|
||||
)
|
||||
parts.append(get_all_nodes(handoff))
|
||||
if isinstance(handoff, Handoff):
|
||||
parts.append(
|
||||
f'"{handoff.agent_name}" [label="{handoff.agent_name}", shape=box, style=filled, style=rounded, '
|
||||
f"fillcolor=lightyellow, width=1.5, height=0.8];"
|
||||
)
|
||||
if isinstance(handoff, Agent):
|
||||
parts.append(
|
||||
f'"{handoff.name}" [label="{handoff.name}", shape=box, style=filled, style=rounded, '
|
||||
f"fillcolor=lightyellow, width=1.5, height=0.8];"
|
||||
)
|
||||
parts.append(get_all_nodes(handoff))
|
||||
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def get_all_edges(agent: Agent, parent: Agent = None) -> str:
|
||||
def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str:
|
||||
"""
|
||||
Recursively generates the edges for the given agent and its handoffs in DOT format.
|
||||
|
||||
|
|
@ -92,14 +99,20 @@ def get_all_edges(agent: Agent, parent: Agent = None) -> str:
|
|||
"{agent.name}" -> "__end__";""")
|
||||
|
||||
for handoff in agent.handoffs:
|
||||
parts.append(f"""
|
||||
"{agent.name}" -> "{handoff.name}";""")
|
||||
parts.append(get_all_edges(handoff, agent))
|
||||
if isinstance(handoff, Handoff):
|
||||
parts.append(f"""
|
||||
"{agent.name}" -> "{handoff.agent_name}";""")
|
||||
if isinstance(handoff, Agent):
|
||||
parts.append(f"""
|
||||
"{agent.name}" -> "{handoff.name}";""")
|
||||
parts.append(get_all_edges(handoff, agent))
|
||||
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
def draw_graph(agent: Agent, filename: str = None) -> graphviz.Source:
|
||||
from typing import Optional
|
||||
|
||||
def draw_graph(agent: Agent, filename: Optional[str] = None) -> graphviz.Source:
|
||||
"""
|
||||
Draws the graph for the given agent and optionally saves it as a PNG file.
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue