From a5b7abe8b4978815ad3694bb3555a4c61e8a844b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mart=C3=ADn=20Bravo?= <123977407+MartinEBravo@users.noreply.github.com> Date: Mon, 24 Mar 2025 09:30:13 +0100 Subject: [PATCH] feat: enhance visualization functions with optional type hints and improved handling of agents and handoffs --- src/agents/extensions/visualization.py | 39 +++++++++++++++++--------- 1 file changed, 26 insertions(+), 13 deletions(-) diff --git a/src/agents/extensions/visualization.py b/src/agents/extensions/visualization.py index 81b6338..a05dcd3 100644 --- a/src/agents/extensions/visualization.py +++ b/src/agents/extensions/visualization.py @@ -1,6 +1,7 @@ import graphviz from agents import Agent +from agents.handoffs import Handoff def get_main_graph(agent: Agent) -> str: @@ -29,7 +30,9 @@ def get_main_graph(agent: Agent) -> str: return "".join(parts) -def get_all_nodes(agent: Agent, parent: Agent = None) -> str: +from typing import Optional + +def get_all_nodes(agent: Agent, parent: Optional[Agent] = None) -> str: """ Recursively generates the nodes for the given agent and its handoffs in DOT format. @@ -47,25 +50,29 @@ def get_all_nodes(agent: Agent, parent: Agent = None) -> str: "fillcolor=lightyellow, width=1.5, height=0.8];" ) - # Smaller tools (ellipse, green) for tool in agent.tools: parts.append( f'"{tool.name}" [label="{tool.name}", shape=ellipse, style=filled, ' f"fillcolor=lightgreen, width=0.5, height=0.3];" ) - # Bigger handoffs (rounded box, yellow) for handoff in agent.handoffs: - parts.append( - f'"{handoff.name}" [label="{handoff.name}", shape=box, style=filled, style=rounded, ' - f"fillcolor=lightyellow, width=1.5, height=0.8];" - ) - parts.append(get_all_nodes(handoff)) + if isinstance(handoff, Handoff): + parts.append( + f'"{handoff.agent_name}" [label="{handoff.agent_name}", shape=box, style=filled, style=rounded, ' + f"fillcolor=lightyellow, width=1.5, height=0.8];" + ) + if isinstance(handoff, Agent): + parts.append( + f'"{handoff.name}" [label="{handoff.name}", shape=box, style=filled, style=rounded, ' + f"fillcolor=lightyellow, width=1.5, height=0.8];" + ) + parts.append(get_all_nodes(handoff)) return "".join(parts) -def get_all_edges(agent: Agent, parent: Agent = None) -> str: +def get_all_edges(agent: Agent, parent: Optional[Agent] = None) -> str: """ Recursively generates the edges for the given agent and its handoffs in DOT format. @@ -92,14 +99,20 @@ def get_all_edges(agent: Agent, parent: Agent = None) -> str: "{agent.name}" -> "__end__";""") for handoff in agent.handoffs: - parts.append(f""" - "{agent.name}" -> "{handoff.name}";""") - parts.append(get_all_edges(handoff, agent)) + if isinstance(handoff, Handoff): + parts.append(f""" + "{agent.name}" -> "{handoff.agent_name}";""") + if isinstance(handoff, Agent): + parts.append(f""" + "{agent.name}" -> "{handoff.name}";""") + parts.append(get_all_edges(handoff, agent)) return "".join(parts) -def draw_graph(agent: Agent, filename: str = None) -> graphviz.Source: +from typing import Optional + +def draw_graph(agent: Agent, filename: Optional[str] = None) -> graphviz.Source: """ Draws the graph for the given agent and optionally saves it as a PNG file.