Add unit tests for visualization functions in test_visualizations.py
This commit is contained in:
parent
cecdcd0af4
commit
9b972b33fa
1 changed files with 68 additions and 0 deletions
68
tests/test_visualizations.py
Normal file
68
tests/test_visualizations.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
import pytest
|
||||
from unittest.mock import Mock
|
||||
from src.agents.visualizations import get_main_graph, get_all_nodes, get_all_edges, draw_graph
|
||||
from src.agents.agent import Agent
|
||||
import graphviz
|
||||
|
||||
@pytest.fixture
|
||||
def mock_agent():
|
||||
tool1 = Mock()
|
||||
tool1.name = "Tool1"
|
||||
tool2 = Mock()
|
||||
tool2.name = "Tool2"
|
||||
|
||||
handoff1 = Mock()
|
||||
handoff1.name = "Handoff1"
|
||||
handoff1.tools = []
|
||||
handoff1.handoffs = []
|
||||
|
||||
agent = Mock(spec=Agent)
|
||||
agent.name = "Agent1"
|
||||
agent.tools = [tool1, tool2]
|
||||
agent.handoffs = [handoff1]
|
||||
|
||||
return agent
|
||||
|
||||
def test_get_main_graph(mock_agent):
|
||||
result = get_main_graph(mock_agent)
|
||||
assert "digraph G" in result
|
||||
assert 'graph [splines=true];' in result
|
||||
assert 'node [fontname="Arial"];' in result
|
||||
assert 'edge [penwidth=1.5];' in result
|
||||
assert '"__start__" [shape=ellipse, style=filled, fillcolor=lightblue];' in result
|
||||
assert '"__end__" [shape=ellipse, style=filled, fillcolor=lightblue];' in result
|
||||
assert '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' in result
|
||||
assert '"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in result
|
||||
assert '"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in result
|
||||
assert '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];' in result
|
||||
|
||||
def test_get_all_nodes(mock_agent):
|
||||
result = get_all_nodes(mock_agent)
|
||||
assert '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' in result
|
||||
assert '"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in result
|
||||
assert '"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in result
|
||||
assert '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];' in result
|
||||
|
||||
def test_get_all_edges(mock_agent):
|
||||
result = get_all_edges(mock_agent)
|
||||
assert '"__start__" -> "Agent1";' in result
|
||||
assert '"Agent1" -> "Tool1" [style=dotted, penwidth=1.5];' in result
|
||||
assert '"Tool1" -> "Agent1" [style=dotted, penwidth=1.5];' in result
|
||||
assert '"Agent1" -> "Tool2" [style=dotted, penwidth=1.5];' in result
|
||||
assert '"Tool2" -> "Agent1" [style=dotted, penwidth=1.5];' in result
|
||||
assert '"Agent1" -> "Handoff1";' in result
|
||||
assert '"Handoff1" -> "__end__";' in result
|
||||
|
||||
def test_draw_graph(mock_agent):
|
||||
graph = draw_graph(mock_agent)
|
||||
assert isinstance(graph, graphviz.Source)
|
||||
assert "digraph G" in graph.source
|
||||
assert 'graph [splines=true];' in graph.source
|
||||
assert 'node [fontname="Arial"];' in graph.source
|
||||
assert 'edge [penwidth=1.5];' in graph.source
|
||||
assert '"__start__" [shape=ellipse, style=filled, fillcolor=lightblue];' in graph.source
|
||||
assert '"__end__" [shape=ellipse, style=filled, fillcolor=lightblue];' in graph.source
|
||||
assert '"Agent1" [label="Agent1", shape=box, style=filled, fillcolor=lightyellow, width=1.5, height=0.8];' in graph.source
|
||||
assert '"Tool1" [label="Tool1", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in graph.source
|
||||
assert '"Tool2" [label="Tool2", shape=ellipse, style=filled, fillcolor=lightgreen, width=0.5, height=0.3];' in graph.source
|
||||
assert '"Handoff1" [label="Handoff1", shape=box, style=filled, style=rounded, fillcolor=lightyellow, width=1.5, height=0.8];' in graph.source
|
||||
Loading…
Reference in a new issue