Linting
This commit is contained in:
parent
2993d26d58
commit
29e9983ae8
2 changed files with 81 additions and 36 deletions
|
|
@ -13,7 +13,8 @@ def get_main_graph(agent: Agent) -> str:
|
|||
Returns:
|
||||
str: The DOT format string representing the graph.
|
||||
"""
|
||||
parts = ["""
|
||||
parts = [
|
||||
"""
|
||||
digraph G {
|
||||
graph [splines=true];
|
||||
node [fontname="Arial"];
|
||||
|
|
@ -21,7 +22,8 @@ def get_main_graph(agent: Agent) -> str:
|
|||
|
||||
"__start__" [shape=ellipse, style=filled, fillcolor=lightblue];
|
||||
"__end__" [shape=ellipse, style=filled, fillcolor=lightblue];
|
||||
"""]
|
||||
"""
|
||||
]
|
||||
parts.append(get_all_nodes(agent))
|
||||
parts.append(get_all_edges(agent))
|
||||
parts.append("}")
|
||||
|
|
@ -39,23 +41,23 @@ def get_all_nodes(agent: Agent, parent: Agent = None) -> str:
|
|||
str: The DOT format string representing the nodes.
|
||||
"""
|
||||
parts = []
|
||||
|
||||
|
||||
# Ensure parent agent node is colored
|
||||
if not parent:
|
||||
parts.append(f"""
|
||||
"{agent.name}" [label="{agent.name}", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];""")
|
||||
|
||||
|
||||
# Smaller tools (ellipse, green)
|
||||
for tool in agent.tools:
|
||||
parts.append(f"""
|
||||
"{tool.name}" [label="{tool.name}", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];""")
|
||||
|
||||
|
||||
# Bigger handoffs (rounded box, yellow)
|
||||
for handoff in agent.handoffs:
|
||||
parts.append(f"""
|
||||
"{handoff.name}" [label="{handoff.name}", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];""")
|
||||
parts.append(get_all_nodes(handoff))
|
||||
|
||||
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
|
|
@ -71,25 +73,25 @@ def get_all_edges(agent: Agent, parent: Agent = None) -> str:
|
|||
str: The DOT format string representing the edges.
|
||||
"""
|
||||
parts = []
|
||||
|
||||
|
||||
if not parent:
|
||||
parts.append(f"""
|
||||
"__start__" -> "{agent.name}";""")
|
||||
|
||||
|
||||
for tool in agent.tools:
|
||||
parts.append(f"""
|
||||
"{agent.name}" -> "{tool.name}" [style=dotted, penwidth=1.5];
|
||||
"{tool.name}" -> "{agent.name}" [style=dotted, penwidth=1.5];""")
|
||||
|
||||
|
||||
if not agent.handoffs:
|
||||
parts.append(f"""
|
||||
"{agent.name}" -> "__end__";""")
|
||||
|
||||
|
||||
for handoff in agent.handoffs:
|
||||
parts.append(f"""
|
||||
"{agent.name}" -> "{handoff.name}";""")
|
||||
parts.append(get_all_edges(handoff, agent))
|
||||
|
||||
|
||||
return "".join(parts)
|
||||
|
||||
|
||||
|
|
@ -106,8 +108,8 @@ def draw_graph(agent: Agent, filename: str = None) -> graphviz.Source:
|
|||
"""
|
||||
dot_code = get_main_graph(agent)
|
||||
graph = graphviz.Source(dot_code)
|
||||
|
||||
|
||||
if filename:
|
||||
graph.render(filename, format='png')
|
||||
|
||||
return graph
|
||||
graph.render(filename, format="png")
|
||||
|
||||
return graph
|
||||
|
|
|
|||
|
|
@ -1,8 +1,11 @@
|
|||
import pytest
|
||||
from unittest.mock import Mock
|
||||
from src.agents.visualizations import get_main_graph, get_all_nodes, get_all_edges, draw_graph
|
||||
from src.agents.agent import Agent
|
||||
|
||||
import graphviz
|
||||
import pytest
|
||||
|
||||
from src.agents.agent import Agent
|
||||
from src.agents.visualizations import draw_graph, get_all_edges, get_all_nodes, get_main_graph
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
|
|
@ -10,7 +13,7 @@ def mock_agent():
|
|||
tool1.name = "Tool1"
|
||||
tool2 = Mock()
|
||||
tool2.name = "Tool2"
|
||||
|
||||
|
||||
handoff1 = Mock()
|
||||
handoff1.name = "Handoff1"
|
||||
handoff1.tools = []
|
||||
|
|
@ -20,28 +23,55 @@ def mock_agent():
|
|||
agent.name = "Agent1"
|
||||
agent.tools = [tool1, tool2]
|
||||
agent.handoffs = [handoff1]
|
||||
|
||||
|
||||
return agent
|
||||
|
||||
|
||||
def test_get_main_graph(mock_agent):
|
||||
result = get_main_graph(mock_agent)
|
||||
assert "digraph G" in result
|
||||
assert 'graph [splines=true];' in result
|
||||
assert "graph [splines=true];" in result
|
||||
assert 'node [fontname="Arial"];' in result
|
||||
assert 'edge [penwidth=1.5];' in result
|
||||
assert "edge [penwidth=1.5];" in result
|
||||
assert '"__start__" [shape=ellipse, style=filled, fillcolor=lightblue];' in result
|
||||
assert '"__end__" [shape=ellipse, style=filled, fillcolor=lightblue];' in result
|
||||
assert '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' in result
|
||||
assert '"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in result
|
||||
assert '"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in result
|
||||
assert '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];' in result
|
||||
assert (
|
||||
'"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];'
|
||||
in result
|
||||
)
|
||||
assert (
|
||||
'"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];'
|
||||
in result
|
||||
)
|
||||
assert (
|
||||
'"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];'
|
||||
in result
|
||||
)
|
||||
assert (
|
||||
'"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];'
|
||||
in result
|
||||
)
|
||||
|
||||
|
||||
def test_get_all_nodes(mock_agent):
|
||||
result = get_all_nodes(mock_agent)
|
||||
assert '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' in result
|
||||
assert '"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in result
|
||||
assert '"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in result
|
||||
assert '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];' in result
|
||||
assert (
|
||||
'"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];'
|
||||
in result
|
||||
)
|
||||
assert (
|
||||
'"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];'
|
||||
in result
|
||||
)
|
||||
assert (
|
||||
'"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];'
|
||||
in result
|
||||
)
|
||||
assert (
|
||||
'"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];'
|
||||
in result
|
||||
)
|
||||
|
||||
|
||||
def test_get_all_edges(mock_agent):
|
||||
result = get_all_edges(mock_agent)
|
||||
|
|
@ -53,16 +83,29 @@ def test_get_all_edges(mock_agent):
|
|||
assert '"Agent1" -> "Handoff1";' in result
|
||||
assert '"Handoff1" -> "__end__";' in result
|
||||
|
||||
|
||||
def test_draw_graph(mock_agent):
|
||||
graph = draw_graph(mock_agent)
|
||||
assert isinstance(graph, graphviz.Source)
|
||||
assert "digraph G" in graph.source
|
||||
assert 'graph [splines=true];' in graph.source
|
||||
assert "graph [splines=true];" in graph.source
|
||||
assert 'node [fontname="Arial"];' in graph.source
|
||||
assert 'edge [penwidth=1.5];' in graph.source
|
||||
assert "edge [penwidth=1.5];" in graph.source
|
||||
assert '"__start__" [shape=ellipse, style=filled, fillcolor=lightblue];' in graph.source
|
||||
assert '"__end__" [shape=ellipse, style=filled, fillcolor=lightblue];' in graph.source
|
||||
assert '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' in graph.source
|
||||
assert '"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in graph.source
|
||||
assert '"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in graph.source
|
||||
assert '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];' in graph.source
|
||||
assert (
|
||||
'"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];'
|
||||
in graph.source
|
||||
)
|
||||
assert (
|
||||
'"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];'
|
||||
in graph.source
|
||||
)
|
||||
assert (
|
||||
'"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];'
|
||||
in graph.source
|
||||
)
|
||||
assert (
|
||||
'"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];'
|
||||
in graph.source
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue