Claude/add initial tests 011 cukte9g4 qwj hjw7 g3ny rf (#190)

* test: add comprehensive unit tests for domain module

Add 24 comprehensive unit tests covering the open_notebook.domain module:

**ObjectModel Base (5 tests)**
- Create and update operations with timestamps
- Get by ID with class resolution
- Delete validation
- Relationship creation

**RecordModel Singleton (3 tests)**
- Singleton pattern behavior
- Async database loading
- Update persistence

**ModelManager (3 tests)**
- Singleton pattern
- Model instance caching
- Default model retrieval

**Notebook Domain (3 tests)**
- Name validation (empty/whitespace)
- Source relationship queries
- Archived flag defaults

**Source Domain (3 tests)**
- Text vectorization and chunking
- Insight validation and creation
- RecordID command field parsing

**Note Domain (2 tests)**
- Content validation
- Embedding configuration

**Podcast Domain (2 tests)**
- Speaker profile validation
- Episode profile segment validation

**Additional Tests (3 tests)**
- ChatSession relationships
- Transformation creation
- ContentSettings defaults

All tests use proper mocking to avoid database dependencies and validate
both business logic and error handling. Tests follow pytest best practices
with async support, fixtures, and comprehensive assertions.
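
A minimal sketch of the mocking approach described above, assuming a hypothetical `FakeNotebook` class and an injected `repo_create` coroutine (neither name comes from the `open_notebook` codebase). It uses `unittest.mock.AsyncMock` so the async save path runs without a database:

```python
import asyncio
from unittest.mock import AsyncMock


class FakeNotebook:
    """Hypothetical stand-in for a domain model; not the real Notebook class."""

    def __init__(self, name: str, repo_create) -> None:
        # Business rule checked without touching any storage layer.
        if not name.strip():
            raise ValueError("Notebook name cannot be empty")
        self.name = name
        self.id = None
        self._repo_create = repo_create  # injected async persistence call

    async def save(self) -> None:
        record = await self._repo_create("notebook", {"name": self.name})
        self.id = record["id"]


def run_save_with_mock() -> tuple:
    # AsyncMock returns an awaitable, so no database connection is needed.
    repo_create = AsyncMock(return_value={"id": "notebook:1"})
    notebook = FakeNotebook("Research", repo_create)
    asyncio.run(notebook.save())
    return notebook.id, repo_create.await_count
```

Injecting the persistence call keeps the validation rule and the save path testable in isolation; the real tests may patch module-level functions instead.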

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* test: add comprehensive tests for utils and graphs modules

Add 56 new unit tests covering utils and graphs modules:

**Utils Module Tests (36 tests)**

Text Utilities (13 tests):
- Text splitting with various chunk sizes
- ASCII and non-printable character removal
- Thinking tag parsing and cleaning (single/multiple tags)
- Edge cases (empty strings, invalid input, large content)

Token Utilities (4 tests):
- Token counting with tiktoken
- Cost calculation
- Fallback behavior when tiktoken unavailable

Version Utilities (7 tests):
- Semantic version comparison (equal, less, greater, prerelease)
- Installed package version retrieval
- GitHub version fetching with URL validation

Context Builder (12 tests):
- ContextItem and ContextConfig creation
- Builder initialization with various parameters
- Priority sorting and deduplication
- Token-based truncation
- Response formatting
- Source and notebook context building
- Convenience functions

**Graphs Module Tests (20 tests)**

Model Provisioning (4 tests):
- Default model selection
- Large context model triggering (>105k tokens)
- Specific model ID selection
- Kwargs pass-through

Tools (3 tests):
- Current timestamp format validation
- Timestamp validity checking
- Tool decoration verification

Prompt Graph (5 tests):
- PatternChainState structure
- Model calling with/without parser
- Graph compilation and execution

Transformation Graph (8 tests):
- TransformationState structure
- Transformation with source objects
- Transformation with direct input text
- Thinking content cleaning
- Content validation
- Graph compilation and execution
- Default prompt integration

All tests use proper mocking to avoid external dependencies (network,
database) and validate both success paths and error handling.
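
As an illustration of covering both a success path and an error path without network access, here is a small sketch. The function `fetch_latest_version` and its injected `http_get` parameter are hypothetical and are not the actual `open_notebook.utils` API:

```python
from unittest.mock import Mock
from urllib.parse import urlparse


def fetch_latest_version(repo_url: str, http_get) -> str:
    """Hypothetical helper: validate the URL, then read a version string."""
    parsed = urlparse(repo_url)
    if parsed.scheme != "https" or parsed.netloc != "github.com":
        raise ValueError(f"Not a GitHub URL: {repo_url}")
    # http_get is injected so a test can replace the network call.
    return http_get(repo_url).strip()


def check_success_and_error_paths() -> tuple:
    http_get = Mock(return_value="1.2.3\n")
    version = fetch_latest_version("https://github.com/example/repo", http_get)

    rejected = False
    try:
        fetch_latest_version("http://example.com/repo", http_get)
    except ValueError:
        rejected = True
    # The invalid URL is rejected before the mocked fetch is reached.
    return version, rejected, http_get.call_count
```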

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* improve tests

---------

Co-authored-by: Claude <noreply@anthropic.com>
Luis Novo 2025-10-21 16:54:59 -03:00 committed by GitHub
parent fc8a4a0c64
commit 18b4dfdb77
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 758 additions and 520 deletions


@@ -65,7 +65,8 @@ build-backend = "setuptools.build_meta"
 [dependency-groups]
 dev = [
     "pre-commit>=4.1.0",
-    "types-requests>=2.32.4.20250913"
+    "pytest-asyncio>=1.2.0",
+    "types-requests>=2.32.4.20250913",
 ]

 [tool.isort]

19
tests/conftest.py Normal file

@@ -0,0 +1,19 @@
"""
Pytest configuration file.

This file ensures that the project root is in the Python path,
allowing tests to import from the api and open_notebook modules.
"""

import os
import sys
from pathlib import Path

# Add the project root to the Python path
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

# Ensure password auth is disabled for tests
# The PasswordAuthMiddleware skips auth when this env var is not set
if "OPEN_NOTEBOOK_PASSWORD" in os.environ:
    del os.environ["OPEN_NOTEBOOK_PASSWORD"]

308
tests/test_domain.py Normal file

@@ -0,0 +1,308 @@
"""
Unit tests for the open_notebook.domain module.

This test suite focuses on validation logic, business rules, and data structures
that can be tested without database mocking.
"""

import pytest
from pydantic import ValidationError

from open_notebook.domain.base import RecordModel
from open_notebook.domain.content_settings import ContentSettings
from open_notebook.domain.models import ModelManager
from open_notebook.domain.notebook import Note, Notebook, Source
from open_notebook.domain.podcast import EpisodeProfile, SpeakerProfile
from open_notebook.domain.transformation import Transformation
from open_notebook.exceptions import InvalidInputError

# ============================================================================
# TEST SUITE 1: RecordModel Singleton Pattern
# ============================================================================


class TestRecordModelSingleton:
    """Test suite for RecordModel singleton behavior."""

    def test_recordmodel_singleton_behavior(self):
        """Test that same instance is returned for same record_id."""

        class TestRecord(RecordModel):
            record_id = "test:singleton"
            value: int = 0

        # Clear any existing instance
        TestRecord.clear_instance()

        # Create first instance
        instance1 = TestRecord(value=42)
        assert instance1.value == 42

        # Create second instance - should return same object
        instance2 = TestRecord(value=99)
        assert instance1 is instance2
        assert instance2.value == 99  # Value was updated

        # Cleanup
        TestRecord.clear_instance()


# ============================================================================
# TEST SUITE 2: ModelManager Singleton
# ============================================================================


class TestModelManager:
    """Test suite for ModelManager singleton pattern."""

    def test_model_manager_singleton(self):
        """Test ModelManager implements singleton pattern correctly."""
        manager1 = ModelManager()
        manager2 = ModelManager()

        assert manager1 is manager2
        assert id(manager1) == id(manager2)

# ============================================================================
# TEST SUITE 3: Notebook Domain Logic
# ============================================================================


class TestNotebookDomain:
    """Test suite for Notebook validation and business rules."""

    def test_notebook_name_validation(self):
        """Test empty/whitespace names are rejected."""
        # Empty name should raise error
        with pytest.raises(InvalidInputError, match="Notebook name cannot be empty"):
            Notebook(name="", description="Test")

        # Whitespace-only name should raise error
        with pytest.raises(InvalidInputError, match="Notebook name cannot be empty"):
            Notebook(name=" ", description="Test")

        # Valid name should work
        notebook = Notebook(name="Valid Name", description="Test")
        assert notebook.name == "Valid Name"

    def test_notebook_archived_flag(self):
        """Test archived flag defaults to False."""
        notebook = Notebook(name="Test", description="Test")
        assert notebook.archived is False

        notebook_archived = Notebook(name="Test", description="Test", archived=True)
        assert notebook_archived.archived is True


# ============================================================================
# TEST SUITE 4: Source Domain
# ============================================================================


class TestSourceDomain:
    """Test suite for Source domain model."""

    def test_source_command_field_parsing(self):
        """Test RecordID parsing for command field."""
        # Test with string command
        source = Source(title="Test", command="command:123")
        assert source.command is not None

        # Test with None command
        source2 = Source(title="Test", command=None)
        assert source2.command is None

        # Test command is included in save data prep
        source3 = Source(id="source:123", title="Test", command="command:456")
        save_data = source3._prepare_save_data()
        assert "command" in save_data

# ============================================================================
# TEST SUITE 5: Note Domain
# ============================================================================


class TestNoteDomain:
    """Test suite for Note validation."""

    def test_note_content_validation(self):
        """Test empty content is rejected."""
        # None content is allowed
        note = Note(title="Test", content=None)
        assert note.content is None

        # Non-empty content is valid
        note2 = Note(title="Test", content="Valid content")
        assert note2.content == "Valid content"

        # Empty string should raise error
        with pytest.raises(InvalidInputError, match="Note content cannot be empty"):
            Note(title="Test", content="")

        # Whitespace-only should raise error
        with pytest.raises(InvalidInputError, match="Note content cannot be empty"):
            Note(title="Test", content=" ")

    def test_note_embedding_enabled(self):
        """Test notes have embedding enabled by default."""
        note = Note(title="Test", content="Test content")
        assert note.needs_embedding() is True
        assert note.get_embedding_content() == "Test content"

        # Test with None content
        note2 = Note(title="Test", content=None)
        assert note2.get_embedding_content() is None

# ============================================================================
# TEST SUITE 6: Podcast Domain Validation
# ============================================================================


class TestPodcastDomain:
    """Test suite for Podcast domain validation."""

    def test_speaker_profile_validation(self):
        """Test speaker profile validates count and required fields."""
        # Test invalid - no speakers
        with pytest.raises(ValidationError):
            SpeakerProfile(
                name="Test",
                tts_provider="openai",
                tts_model="tts-1",
                speakers=[],
            )

        # Test invalid - too many speakers (> 4)
        with pytest.raises(ValidationError):
            SpeakerProfile(
                name="Test",
                tts_provider="openai",
                tts_model="tts-1",
                speakers=[{"name": f"Speaker{i}"} for i in range(5)],
            )

        # Test invalid - missing required fields
        with pytest.raises(ValidationError):
            SpeakerProfile(
                name="Test",
                tts_provider="openai",
                tts_model="tts-1",
                speakers=[{"name": "Speaker 1"}],  # Missing voice_id, backstory, personality
            )

        # Test valid - single speaker with all fields
        profile = SpeakerProfile(
            name="Test",
            tts_provider="openai",
            tts_model="tts-1",
            speakers=[
                {
                    "name": "Host",
                    "voice_id": "voice123",
                    "backstory": "A friendly host",
                    "personality": "Enthusiastic and welcoming",
                }
            ],
        )
        assert len(profile.speakers) == 1
        assert profile.speakers[0]["name"] == "Host"

# ============================================================================
# TEST SUITE 7: Transformation Domain
# ============================================================================


class TestTransformationDomain:
    """Test suite for Transformation domain model."""

    def test_transformation_creation(self):
        """Test transformation model creation."""
        transform = Transformation(
            name="summarize",
            title="Summarize Content",
            description="Creates a summary",
            prompt="Summarize the following text: {content}",
            apply_default=True,
        )

        assert transform.name == "summarize"
        assert transform.apply_default is True


# ============================================================================
# TEST SUITE 8: Content Settings
# ============================================================================


class TestContentSettings:
    """Test suite for ContentSettings defaults."""

    def test_content_settings_defaults(self):
        """Test ContentSettings has proper defaults."""
        settings = ContentSettings()

        assert settings.record_id == "open_notebook:content_settings"
        assert settings.default_content_processing_engine_doc == "auto"
        assert settings.default_embedding_option == "ask"
        assert settings.auto_delete_files == "yes"
        assert len(settings.youtube_preferred_languages) > 0

# ============================================================================
# TEST SUITE 9: Episode Profile Validation
# ============================================================================


class TestEpisodeProfile:
    """Test suite for EpisodeProfile validation."""

    def test_episode_profile_segment_validation(self):
        """Test segment count validation (3-20)."""
        # Test invalid - too few segments
        with pytest.raises(ValidationError, match="Number of segments must be between 3 and 20"):
            EpisodeProfile(
                name="Test",
                speaker_config="default",
                outline_provider="openai",
                outline_model="gpt-4",
                transcript_provider="openai",
                transcript_model="gpt-4",
                default_briefing="Test briefing",
                num_segments=2,
            )

        # Test invalid - too many segments
        with pytest.raises(ValidationError, match="Number of segments must be between 3 and 20"):
            EpisodeProfile(
                name="Test",
                speaker_config="default",
                outline_provider="openai",
                outline_model="gpt-4",
                transcript_provider="openai",
                transcript_model="gpt-4",
                default_briefing="Test briefing",
                num_segments=21,
            )

        # Test valid segment count
        profile = EpisodeProfile(
            name="Test",
            speaker_config="default",
            outline_provider="openai",
            outline_model="gpt-4",
            transcript_provider="openai",
            transcript_model="gpt-4",
            default_briefing="Test briefing",
            num_segments=5,
        )
        assert profile.num_segments == 5


if __name__ == "__main__":
    pytest.main([__file__, "-v"])

155
tests/test_graphs.py Normal file

@@ -0,0 +1,155 @@
"""
Unit tests for the open_notebook.graphs module.

This test suite focuses on testing graph structures, tools, and validation
without heavy mocking of the actual processing logic.
"""

from datetime import datetime

import pytest

from open_notebook.graphs.prompt import PatternChainState, graph
from open_notebook.graphs.tools import get_current_timestamp
from open_notebook.graphs.transformation import (
    TransformationState,
    run_transformation,
    graph as transformation_graph,
)

# ============================================================================
# TEST SUITE 1: Graph Tools
# ============================================================================


class TestGraphTools:
    """Test suite for graph tool definitions."""

    def test_get_current_timestamp_format(self):
        """Test timestamp tool returns correct format."""
        timestamp = get_current_timestamp.func()

        assert isinstance(timestamp, str)
        assert len(timestamp) == 14  # YYYYMMDDHHmmss format
        assert timestamp.isdigit()

    def test_get_current_timestamp_validity(self):
        """Test timestamp represents valid datetime."""
        timestamp = get_current_timestamp.func()

        # Parse it back to datetime to verify validity
        year = int(timestamp[0:4])
        month = int(timestamp[4:6])
        day = int(timestamp[6:8])
        hour = int(timestamp[8:10])
        minute = int(timestamp[10:12])
        second = int(timestamp[12:14])

        # Should be valid date components
        assert 2020 <= year <= 2100
        assert 1 <= month <= 12
        assert 1 <= day <= 31
        assert 0 <= hour <= 23
        assert 0 <= minute <= 59
        assert 0 <= second <= 59

        # Should parse as datetime
        dt = datetime.strptime(timestamp, "%Y%m%d%H%M%S")
        assert isinstance(dt, datetime)

    def test_get_current_timestamp_is_tool(self):
        """Test that function is properly decorated as a tool."""
        # Check it has tool attributes
        assert hasattr(get_current_timestamp, "name")
        assert hasattr(get_current_timestamp, "description")

# ============================================================================
# TEST SUITE 2: Prompt Graph State
# ============================================================================


class TestPromptGraph:
    """Test suite for prompt pattern chain graph."""

    def test_pattern_chain_state_structure(self):
        """Test PatternChainState structure and fields."""
        state = PatternChainState(
            prompt="Test prompt",
            parser=None,
            input_text="Test input",
            output=""
        )

        assert state["prompt"] == "Test prompt"
        assert state["parser"] is None
        assert state["input_text"] == "Test input"
        assert state["output"] == ""

    def test_prompt_graph_compilation(self):
        """Test that prompt graph compiles correctly."""
        assert graph is not None

        # Graph should have the expected structure
        assert hasattr(graph, "invoke")
        assert hasattr(graph, "ainvoke")

# ============================================================================
# TEST SUITE 3: Transformation Graph
# ============================================================================


class TestTransformationGraph:
    """Test suite for transformation graph workflows."""

    def test_transformation_state_structure(self):
        """Test TransformationState structure and fields."""
        from unittest.mock import MagicMock

        from open_notebook.domain.notebook import Source
        from open_notebook.domain.transformation import Transformation

        mock_source = MagicMock(spec=Source)
        mock_transformation = MagicMock(spec=Transformation)

        state = TransformationState(
            input_text="Test text",
            source=mock_source,
            transformation=mock_transformation,
            output=""
        )

        assert state["input_text"] == "Test text"
        assert state["source"] == mock_source
        assert state["transformation"] == mock_transformation
        assert state["output"] == ""

    @pytest.mark.asyncio
    async def test_run_transformation_assertion_no_content(self):
        """Test transformation raises assertion with no content."""
        from unittest.mock import MagicMock

        from open_notebook.domain.transformation import Transformation

        mock_transformation = MagicMock(spec=Transformation)

        state = {
            "input_text": None,
            "transformation": mock_transformation,
            "source": None
        }

        config = {"configurable": {"model_id": None}}

        with pytest.raises(AssertionError, match="No content to transform"):
            await run_transformation(state, config)

    def test_transformation_graph_compilation(self):
        """Test that transformation graph compiles correctly."""
        assert transformation_graph is not None
        assert hasattr(transformation_graph, "invoke")
        assert hasattr(transformation_graph, "ainvoke")


if __name__ == "__main__":
    pytest.main([__file__, "-v"])


@@ -1,296 +0,0 @@
"""
Integration tests for Source Chat Langgraph.

These tests verify that the Source Chat Langgraph integrates correctly
with the existing Open Notebook infrastructure.
"""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from langchain_core.messages import AIMessage, HumanMessage

from open_notebook.domain.notebook import Source, SourceInsight
from open_notebook.graphs.source_chat import (
    SourceChatState,
    _format_source_context,
    call_model_with_source_context,
    source_chat_graph,
)


@pytest.fixture
def mock_source():
    """Create a mock Source object for testing."""
    source = MagicMock(spec=Source)
    source.id = "source:test123"
    source.title = "Test Source"
    source.topics = ["AI", "Machine Learning"]
    source.full_text = "This is test content for the source."
    source.model_dump.return_value = {
        "id": "source:test123",
        "title": "Test Source",
        "topics": ["AI", "Machine Learning"],
        "full_text": "This is test content for the source."
    }
    return source


@pytest.fixture
def mock_insight():
    """Create a mock SourceInsight object for testing."""
    insight = MagicMock(spec=SourceInsight)
    insight.id = "insight:test456"
    insight.insight_type = "summary"
    insight.content = "This is a test insight about the source."
    insight.model_dump.return_value = {
        "id": "insight:test456",
        "insight_type": "summary",
        "content": "This is a test insight about the source."
    }
    return insight


@pytest.fixture
def sample_state():
    """Create a sample SourceChatState for testing."""
    return SourceChatState(
        messages=[HumanMessage(content="What are the main topics in this source?")],
        source_id="source:test123",
        source=None,
        insights=None,
        context=None,
        model_override=None,
        context_indicators=None
    )


@pytest.fixture
def sample_config():
    """Create a sample configuration for testing."""
    return {
        "configurable": {
            "thread_id": "test_thread",
            "model_id": "test_model"
        }
    }


class TestSourceChatState:
    """Test the SourceChatState TypedDict structure."""

    def test_source_chat_state_creation(self, sample_state):
        """Test that SourceChatState can be created with required fields."""
        assert sample_state["source_id"] == "source:test123"
        assert len(sample_state["messages"]) == 1
        assert sample_state["source"] is None
        assert sample_state["insights"] is None


class TestContextFormatting:
    """Test the context formatting functionality."""

    def test_format_source_context_with_sources(self):
        """Test formatting context data containing sources."""
        context_data = {
            "sources": [{
                "id": "source:test123",
                "title": "Test Source",
                "full_text": "This is test content."
            }],
            "insights": [],
            "metadata": {
                "source_count": 1,
                "insight_count": 0
            },
            "total_tokens": 100
        }

        result = _format_source_context(context_data)

        assert "## SOURCE CONTENT" in result
        assert "source:test123" in result
        assert "Test Source" in result
        assert "This is test content." in result
        assert "## CONTEXT METADATA" in result

    def test_format_source_context_with_insights(self):
        """Test formatting context data containing insights."""
        context_data = {
            "sources": [],
            "insights": [{
                "id": "insight:test456",
                "insight_type": "summary",
                "content": "Test insight content."
            }],
            "metadata": {
                "source_count": 0,
                "insight_count": 1
            },
            "total_tokens": 50
        }

        result = _format_source_context(context_data)

        assert "## SOURCE INSIGHTS" in result
        assert "insight:test456" in result
        assert "summary" in result
        assert "Test insight content." in result

    def test_format_source_context_empty(self):
        """Test formatting empty context data."""
        context_data = {
            "sources": [],
            "insights": [],
            "metadata": {
                "source_count": 0,
                "insight_count": 0
            },
            "total_tokens": 0
        }

        result = _format_source_context(context_data)

        assert "## CONTEXT METADATA" in result
        assert "Source count: 0" in result
        assert "Insight count: 0" in result


class TestSourceChatIntegration:
    """Test the integration of source chat components."""

    @patch('open_notebook.graphs.source_chat.ContextBuilder')
    @patch('open_notebook.graphs.source_chat.provision_langchain_model')
    @patch('open_notebook.graphs.source_chat.Prompter')
    async def test_call_model_with_source_context(
        self,
        mock_prompter,
        mock_provision_model,
        mock_context_builder,
        sample_state,
        sample_config,
        mock_source,
        mock_insight
    ):
        """Test the main model calling function with mocked dependencies."""
        # Mock the ContextBuilder
        mock_builder_instance = AsyncMock()
        mock_builder_instance.build.return_value = {
            "sources": [mock_source.model_dump()],
            "insights": [mock_insight.model_dump()],
            "metadata": {"source_count": 1, "insight_count": 1},
            "total_tokens": 150
        }
        mock_context_builder.return_value = mock_builder_instance

        # Mock the Prompter
        mock_prompter_instance = MagicMock()
        mock_prompter_instance.render.return_value = "Rendered prompt"
        mock_prompter.return_value = mock_prompter_instance

        # Mock the model
        mock_model = AsyncMock()
        mock_ai_message = AIMessage(content="Test response from AI")
        mock_model.invoke.return_value = mock_ai_message
        mock_provision_model.return_value = mock_model

        # Call the function
        result = await call_model_with_source_context(sample_state, sample_config)  # type: ignore[misc]

        # Verify the result
        assert "messages" in result
        assert result["messages"] == mock_ai_message
        assert "source" in result
        assert "insights" in result
        assert "context" in result
        assert "context_indicators" in result

        # Verify mocks were called correctly
        mock_context_builder.assert_called_once()
        mock_builder_instance.build.assert_called_once()
        mock_prompter.assert_called_once_with(prompt_template="source_chat")
        mock_provision_model.assert_called_once()

    def test_source_chat_graph_structure(self):
        """Test that the source chat graph is properly structured."""
        # Verify the graph has the expected structure
        assert source_chat_graph is not None

        # Check that the graph has nodes
        nodes = source_chat_graph.get_graph().nodes
        assert "source_chat_agent" in [node for node in nodes]

        # Check that the graph has the checkpointer
        assert source_chat_graph.checkpointer is not None

    @pytest.mark.asyncio
    async def test_source_chat_state_validation(self):
        """Test that the source chat validates required state fields."""
        # Test with missing source_id
        invalid_state = SourceChatState(
            messages=[HumanMessage(content="Test")],
            source_id="",  # Empty source_id should cause error
            source=None,
            insights=None,
            context=None,
            model_override=None,
            context_indicators=None
        )

        config = {"configurable": {"thread_id": "test"}}

        # This should raise an error due to missing source_id
        with pytest.raises(ValueError, match="source_id is required"):
            await call_model_with_source_context(invalid_state, config)  # type: ignore[misc, arg-type]


class TestSourceChatGraphExecution:
    """Test the execution of the source chat graph."""

    @patch('open_notebook.graphs.source_chat.Source')
    @patch('open_notebook.graphs.source_chat.ContextBuilder')
    @patch('open_notebook.graphs.source_chat.provision_langchain_model')
    @patch('open_notebook.graphs.source_chat.Prompter')
    @pytest.mark.asyncio
    async def test_graph_execution_flow(
        self,
        mock_prompter,
        mock_provision_model,
        mock_context_builder,
        mock_source_class,
        sample_state,
        sample_config
    ):
        """Test the complete graph execution flow with mocked dependencies."""
        # Setup mocks (similar to previous test but for full graph execution)
        mock_builder_instance = AsyncMock()
        mock_builder_instance.build.return_value = {
            "sources": [{"id": "source:test123", "title": "Test"}],
            "insights": [{"id": "insight:test456", "content": "Test insight"}],
            "metadata": {"source_count": 1, "insight_count": 1},
            "total_tokens": 100
        }
        mock_context_builder.return_value = mock_builder_instance

        mock_prompter_instance = MagicMock()
        mock_prompter_instance.render.return_value = "Test prompt"
        mock_prompter.return_value = mock_prompter_instance

        mock_model = AsyncMock()
        mock_model.invoke.return_value = AIMessage(content="AI response")
        mock_provision_model.return_value = mock_model

        # Execute the graph
        result = await source_chat_graph.ainvoke(sample_state, sample_config)

        # Verify the result structure
        assert "messages" in result
        assert "source_id" in result
        assert result["source_id"] == "source:test123"


if __name__ == "__main__":
    # Run the tests
    pytest.main([__file__, "-v"])


@@ -1,223 +0,0 @@
from unittest.mock import AsyncMock, patch

import pytest
from fastapi.testclient import TestClient

from api.main import app

client = TestClient(app)


class TestSourceChatAPI:
    """Test suite for Source Chat API endpoints."""

    @pytest.fixture
    def sample_source_id(self):
        return "test_source_123"

    @pytest.fixture
    def sample_session_id(self):
        return "test_session_456"

    @patch('api.routers.source_chat.Source.get')
    @patch('api.routers.source_chat.ChatSession.save')
    @patch('api.routers.source_chat.ChatSession.relate')
    def test_create_source_chat_session(self, mock_relate, mock_save, mock_source_get, sample_source_id):
        """Test creating a new source chat session."""
        # Mock source exists
        mock_source = AsyncMock()
        mock_source.id = f"source:{sample_source_id}"
        mock_source_get.return_value = mock_source

        # Mock session save and relate
        mock_save.return_value = None
        mock_relate.return_value = None

        # Create session request
        request_data = {
            "source_id": sample_source_id,
            "title": "Test Chat Session",
            "model_override": "gpt-4"
        }

        response = client.post(
            f"/api/sources/{sample_source_id}/chat/sessions",
            json=request_data
        )

        assert response.status_code == 200
        data = response.json()
        assert data["title"] == "Test Chat Session"
        assert data["source_id"] == sample_source_id
        assert data["model_override"] == "gpt-4"
        assert "id" in data
        assert "created" in data

    @patch('api.routers.source_chat.Source.get')
    def test_create_session_source_not_found(self, mock_source_get, sample_source_id):
        """Test creating session for non-existent source."""
        mock_source_get.return_value = None

        request_data = {
            "source_id": sample_source_id,
            "title": "Test Chat Session"
        }

        response = client.post(
            f"/api/sources/{sample_source_id}/chat/sessions",
            json=request_data
        )

        assert response.status_code == 404
        assert "Source not found" in response.json()["detail"]

    @patch('api.routers.source_chat.Source.get')
    @patch('api.routers.source_chat.repo_query')
    def test_get_source_chat_sessions(self, mock_repo_query, mock_source_get, sample_source_id):
        """Test getting all chat sessions for a source."""
        # Mock source exists
        mock_source = AsyncMock()
        mock_source.id = f"source:{sample_source_id}"
        mock_source_get.return_value = mock_source

        # Mock query returns sessions
        mock_repo_query.return_value = [
            {"in": "chat_session:session1"},
            {"in": "chat_session:session2"}
        ]

        # Mock ChatSession.get for each session
        with patch('api.routers.source_chat.ChatSession.get') as mock_session_get:
            mock_session1 = AsyncMock()
            mock_session1.id = "chat_session:session1"
            mock_session1.title = "Session 1"
            mock_session1.created = "2024-01-01T00:00:00Z"
            mock_session1.updated = "2024-01-01T00:00:00Z"
            mock_session1.model_override = None

            mock_session2 = AsyncMock()
            mock_session2.id = "chat_session:session2"
            mock_session2.title = "Session 2"
            mock_session2.created = "2024-01-01T00:00:00Z"
            mock_session2.updated = "2024-01-01T00:00:00Z"
            mock_session2.model_override = "gpt-4"

            mock_session_get.side_effect = [mock_session1, mock_session2]

            response = client.get(f"/api/sources/{sample_source_id}/chat/sessions")

            assert response.status_code == 200
            data = response.json()
            assert len(data) == 2
            assert data[0]["title"] == "Session 1"
            assert data[1]["title"] == "Session 2"
            assert data[1]["model_override"] == "gpt-4"

    @patch('api.routers.source_chat.Source.get')
    @patch('api.routers.source_chat.ChatSession.get')
    @patch('api.routers.source_chat.repo_query')
    @patch('api.routers.source_chat.source_chat_graph.get_state')
    def test_get_source_chat_session_with_messages(
        self, mock_get_state, mock_repo_query, mock_session_get, mock_source_get,
        sample_source_id, sample_session_id
    ):
        """Test getting a specific chat session with messages."""
        # Mock source exists
        mock_source = AsyncMock()
        mock_source.id = f"source:{sample_source_id}"
        mock_source_get.return_value = mock_source

        # Mock session exists
        mock_session = AsyncMock()
        mock_session.id = f"chat_session:{sample_session_id}"
        mock_session.title = "Test Session"
        mock_session.created = "2024-01-01T00:00:00Z"
        mock_session.updated = "2024-01-01T00:00:00Z"
        mock_session.model_override = "gpt-4"
        mock_session_get.return_value = mock_session

        # Mock relation exists
        mock_repo_query.return_value = [{"relation": "exists"}]

        # Mock graph state with messages
        mock_message = AsyncMock()
        mock_message.type = "human"
        mock_message.content = "Hello"
        mock_message.id = "msg_1"

        mock_state = AsyncMock()
        mock_state.values = {
            "messages": [mock_message],
            "context_indicators": {"sources": ["source:123"], "insights": ["insight:456"], "notes": []}
        }
        mock_get_state.return_value = mock_state

        response = client.get(f"/api/sources/{sample_source_id}/chat/sessions/{sample_session_id}")

        assert response.status_code == 200
        data = response.json()
        assert data["title"] == "Test Session"
        assert data["model_override"] == "gpt-4"
        assert len(data["messages"]) == 1
        assert data["messages"][0]["content"] == "Hello"
        assert data["context_indicators"]["sources"] == ["source:123"]

    @patch('api.routers.source_chat.Source.get')
    @patch('api.routers.source_chat.ChatSession.get')
    @patch('api.routers.source_chat.repo_query')
    @patch('api.routers.source_chat.ChatSession.save')
    def test_update_source_chat_session(
        self, mock_save, mock_repo_query, mock_session_get, mock_source_get,
        sample_source_id, sample_session_id
    ):
        """Test updating a source chat session."""
        # Mock source exists
        mock_source = AsyncMock()
        mock_source.id = f"source:{sample_source_id}"
        mock_source_get.return_value = mock_source

        # Mock session exists
        mock_session = AsyncMock()
        mock_session.id = f"chat_session:{sample_session_id}"
        mock_session.title = "Old Title"
        mock_session.created = "2024-01-01T00:00:00Z"
        mock_session.updated = "2024-01-01T00:00:00Z"
        mock_session.model_override = None
        mock_session_get.return_value = mock_session

        # Mock relation exists
        mock_repo_query.return_value = [{"relation": "exists"}]

        # Mock save
        mock_save.return_value = None

        request_data = {
            "title": "New Title",
            "model_override": "gpt-4"
        }

        response = client.put(
            f"/api/sources/{sample_source_id}/chat/sessions/{sample_session_id}",
            json=request_data
        )

        assert response.status_code == 200
        data = response.json()
        assert data["title"] == "New Title"
        # Note: The mock will still return the original values unless we update them
        # In a real test, we'd want to verify the session was updated properly

    def test_api_endpoints_structure(self):
        """Test that all expected endpoints are properly structured."""
        # Test endpoint paths are correctly formed
        from api.routers.source_chat import router

        routes = [route.path for route in router.routes]  # type: ignore[attr-defined]

        expected_routes = [
            "/sources/{source_id}/chat/sessions",
            "/sources/{source_id}/chat/sessions/{session_id}",
            "/sources/{source_id}/chat/sessions/{session_id}/messages"
        ]

        for expected_route in expected_routes:
            assert any(expected_route in route for route in routes), f"Route {expected_route} not found"

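The note at the end of test_update_source_chat_session leaves verification of the update open. A minimal sketch of one way to check it, using only unittest.mock. `update_session` is a hypothetical stand-in for the route handler, not the project's code, and it is assumed to set fields on the session and then await `save()`:

```python
import asyncio
from unittest.mock import AsyncMock


async def update_session(session, title=None, model_override=None):
    """Hypothetical stand-in for the handler: mutate the session, then persist it."""
    if title is not None:
        session.title = title
    if model_override is not None:
        session.model_override = model_override
    await session.save()
    return session


def check_update_is_persisted():
    # Child attributes of an AsyncMock (such as .save) are awaitable mocks.
    mock_session = AsyncMock()
    mock_session.title = "Old Title"
    mock_session.model_override = None

    asyncio.run(update_session(mock_session, title="New Title", model_override="gpt-4"))

    # Verify both the mutation and that save() was awaited exactly once.
    assert mock_session.title == "New Title"
    assert mock_session.model_override == "gpt-4"
    mock_session.save.assert_awaited_once()
    return mock_session


check_update_is_persisted()
```

In the real test the same idea would likely mean asserting on `mock_save` and on the attributes of `mock_session` after the PUT request, in addition to checking the response body.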
259
tests/test_utils.py Normal file

@ -0,0 +1,259 @@
"""
Unit tests for the open_notebook.utils module.
This test suite focuses on testing utility functions that perform actual logic
without heavy mocking - string processing, validation, and algorithms.
"""

import pytest

from open_notebook.utils import (
    clean_thinking_content,
    compare_versions,
    get_installed_version,
    parse_thinking_content,
    remove_non_ascii,
    remove_non_printable,
    split_text,
    token_count,
)
from open_notebook.utils.context_builder import ContextBuilder, ContextConfig

# ============================================================================
# TEST SUITE 1: Text Utilities
# ============================================================================
class TestTextUtilities:
    """Test suite for text utility functions."""

    def test_split_text_empty_string(self):
        """Test splitting empty or very short strings."""
        assert split_text("") == []
        assert split_text("short") == ["short"]

    def test_remove_non_ascii(self):
        """Test removal of non-ASCII characters."""
        # Text with various non-ASCII characters
        text_with_unicode = "Hello 世界 café naïve émoji 🎉"
        result = remove_non_ascii(text_with_unicode)
        # Should only contain ASCII characters
        assert result == "Hello  caf nave moji "
        # All characters should be in ASCII range
        assert all(ord(char) < 128 for char in result)

    def test_remove_non_ascii_pure_ascii(self):
        """Test that pure ASCII text is unchanged."""
        text = "Hello World 123 !@#"
        result = remove_non_ascii(text)
        assert result == text

    def test_remove_non_printable(self):
        """Test removal of non-printable characters."""
        # Text with various Unicode whitespace and control chars
        text = "Hello\u2000World\u200B\u202FTest"
        result = remove_non_printable(text)
        # Should have regular spaces and printable chars only
        assert "Hello" in result
        assert "World" in result
        assert "Test" in result

    def test_remove_non_printable_preserves_newlines(self):
        """Test that newlines and tabs are preserved."""
        text = "Line1\nLine2\tTabbed"
        result = remove_non_printable(text)
        assert "\n" in result
        assert "\t" in result

    def test_parse_thinking_content_basic(self):
        """Test parsing single thinking block."""
        content = "<think>This is my thinking</think>Here is my answer"
        thinking, cleaned = parse_thinking_content(content)
        assert thinking == "This is my thinking"
        assert cleaned == "Here is my answer"

    def test_parse_thinking_content_multiple_tags(self):
        """Test parsing multiple thinking blocks."""
        content = "<think>First thought</think>Answer<think>Second thought</think>More"
        thinking, cleaned = parse_thinking_content(content)
        assert "First thought" in thinking
        assert "Second thought" in thinking
        assert "<think>" not in cleaned
        assert "Answer" in cleaned
        assert "More" in cleaned

    def test_parse_thinking_content_no_tags(self):
        """Test parsing content without thinking tags."""
        content = "Just regular content"
        thinking, cleaned = parse_thinking_content(content)
        assert thinking == ""
        assert cleaned == "Just regular content"

    def test_parse_thinking_content_invalid_input(self):
        """Test parsing with invalid input types."""
        # Non-string input
        thinking, cleaned = parse_thinking_content(None)
        assert thinking == ""
        assert cleaned == ""
        # Integer input
        thinking, cleaned = parse_thinking_content(123)
        assert thinking == ""
        assert cleaned == "123"

    def test_parse_thinking_content_large_content(self):
        """Test that very large content is not processed."""
        large_content = "x" * 200000  # > 100KB limit
        thinking, cleaned = parse_thinking_content(large_content)
        # Should return unchanged due to size limit
        assert thinking == ""
        assert cleaned == large_content

    def test_clean_thinking_content(self):
        """Test convenience function for cleaning thinking content."""
        content = "<think>Internal thoughts</think>Public response"
        result = clean_thinking_content(content)
        assert "<think>" not in result
        assert "Public response" in result
        assert "Internal thoughts" not in result

# ============================================================================
# TEST SUITE 2: Token Utilities
# ============================================================================
class TestTokenUtilities:
    """Test suite for token counting fallback behavior."""

    def test_token_count_fallback(self):
        """Test fallback when tiktoken raises an error."""
        from unittest.mock import patch

        # Make tiktoken raise an ImportError to trigger fallback
        with patch("tiktoken.get_encoding", side_effect=ImportError("tiktoken not available")):
            text = "one two three four five"
            count = token_count(text)
            # Fallback uses word count * 1.3
            # 5 words * 1.3 = 6.5 -> 6
            assert isinstance(count, int)
            assert count > 0

# ============================================================================
# TEST SUITE 3: Version Utilities
# ============================================================================
class TestVersionUtilities:
    """Test suite for version management functions."""

    def test_compare_versions_equal(self):
        """Test comparing equal versions."""
        result = compare_versions("1.0.0", "1.0.0")
        assert result == 0

    def test_compare_versions_less_than(self):
        """Test comparing when first version is less."""
        result = compare_versions("1.0.0", "2.0.0")
        assert result == -1
        result = compare_versions("1.0.0", "1.1.0")
        assert result == -1
        result = compare_versions("1.0.0", "1.0.1")
        assert result == -1

    def test_compare_versions_greater_than(self):
        """Test comparing when first version is greater."""
        result = compare_versions("2.0.0", "1.0.0")
        assert result == 1
        result = compare_versions("1.1.0", "1.0.0")
        assert result == 1
        result = compare_versions("1.0.1", "1.0.0")
        assert result == 1

    def test_compare_versions_prerelease(self):
        """Test comparing versions with pre-release tags."""
        result = compare_versions("1.0.0", "1.0.0-alpha")
        assert result == 1  # Release > pre-release
        result = compare_versions("1.0.0-beta", "1.0.0-alpha")
        assert result == 1  # beta > alpha

    def test_get_installed_version_success(self):
        """Test getting installed package version."""
        # Test with a known installed package
        version = get_installed_version("pytest")
        assert isinstance(version, str)
        assert len(version) > 0
        # Should look like a version (has dots)
        assert "." in version

    def test_get_installed_version_not_found(self):
        """Test getting version of non-existent package."""
        from importlib.metadata import PackageNotFoundError

        with pytest.raises(PackageNotFoundError):
            get_installed_version("this-package-does-not-exist-12345")

    def test_get_version_from_github_invalid_url(self):
        """Test GitHub version fetch with invalid URL."""
        from open_notebook.utils.version_utils import get_version_from_github

        with pytest.raises(ValueError, match="Not a GitHub URL"):
            get_version_from_github("https://example.com/repo")
        with pytest.raises(ValueError, match="Invalid GitHub repository URL"):
            get_version_from_github("https://github.com/")

# ============================================================================
# TEST SUITE 4: Context Builder Configuration
# ============================================================================
class TestContextBuilder:
    """Test suite for ContextBuilder initialization and configuration."""

    def test_context_config_defaults(self):
        """Test ContextConfig default values."""
        config = ContextConfig()
        assert config.sources == {}
        assert config.notes == {}
        assert config.include_insights is True
        assert config.include_notes is True
        assert config.priority_weights is not None
        assert "source" in config.priority_weights
        assert "note" in config.priority_weights
        assert "insight" in config.priority_weights

    def test_context_builder_initialization(self):
        """Test ContextBuilder initialization with various params."""
        builder = ContextBuilder(
            source_id="source:123",
            notebook_id="notebook:456",
            max_tokens=1000,
            include_insights=False
        )
        assert builder.source_id == "source:123"
        assert builder.notebook_id == "notebook:456"
        assert builder.max_tokens == 1000
        assert builder.include_insights is False


if __name__ == "__main__":
    pytest.main([__file__, "-v"])

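The thinking-tag tests in tests/test_utils.py pin down a fairly specific contract: a `(thinking, cleaned)` tuple, empty strings for `None`, stringified output for other non-string input, and no processing above a size limit. The following is a minimal sketch of a function that satisfies that contract. It is not the project's `parse_thinking_content`; the name `parse_thinking_sketch`, the 100,000-character limit, and the blank-line separator used to join multiple blocks are assumptions made for illustration:

```python
import re

# Non-greedy match so each <think>...</think> pair is captured separately.
THINK_PATTERN = re.compile(r"<think>(.*?)</think>", re.DOTALL)
MAX_CONTENT_LENGTH = 100_000  # assumed limit, based on the "> 100KB limit" test comment


def parse_thinking_sketch(content):
    """Return (thinking, cleaned) in the shape the tests above expect."""
    if content is None:
        return "", ""
    if not isinstance(content, str):
        # Non-string input is stringified and treated as having no thinking.
        return "", str(content)
    if len(content) > MAX_CONTENT_LENGTH:
        # Oversized content is returned untouched.
        return "", content

    blocks = [block.strip() for block in THINK_PATTERN.findall(content)]
    thinking = "\n\n".join(blocks)
    cleaned = THINK_PATTERN.sub("", content).strip()
    return thinking, cleaned


print(parse_thinking_sketch("<think>This is my thinking</think>Here is my answer"))
```

A non-greedy pattern with `re.DOTALL` keeps multiple blocks separate and lets a block span several lines. Unclosed tags are left in place by this sketch, which may differ from what the real function does.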
15
uv.lock

@ -2256,6 +2256,7 @@ dev = [
[package.dev-dependencies]
dev = [
{ name = "pre-commit" },
{ name = "pytest-asyncio" },
{ name = "types-requests" },
]
@ -2302,6 +2303,7 @@ provides-extras = ["dev"]
[package.metadata.requires-dev]
dev = [
{ name = "pre-commit", specifier = ">=4.1.0" },
{ name = "pytest-asyncio", specifier = ">=1.2.0" },
{ name = "types-requests", specifier = ">=2.32.4.20250913" },
]
@ -2973,6 +2975,19 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/a8/a4/20da314d277121d6534b3a980b29035dcd51e6744bd79075a6ce8fa4eb8d/pytest-8.4.2-py3-none-any.whl", hash = "sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79", size = 365750, upload-time = "2025-09-04T14:34:20.226Z" },
]
[[package]]
name = "pytest-asyncio"
version = "1.2.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pytest" },
{ name = "typing-extensions" },
]
sdist = { url = "https://files.pythonhosted.org/packages/42/86/9e3c5f48f7b7b638b216e4b9e645f54d199d7abbbab7a64a13b4e12ba10f/pytest_asyncio-1.2.0.tar.gz", hash = "sha256:c609a64a2a8768462d0c99811ddb8bd2583c33fd33cf7f21af1c142e824ffb57", size = 50119, upload-time = "2025-09-12T07:33:53.816Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/04/93/2fa34714b7a4ae72f2f8dad66ba17dd9a2c793220719e736dda28b7aec27/pytest_asyncio-1.2.0-py3-none-any.whl", hash = "sha256:8e17ae5e46d8e7efe51ab6494dd2010f4ca8dae51652aa3c8d55acf50bfb2e99", size = 15095, upload-time = "2025-09-12T07:33:52.639Z" },
]
[[package]]
name = "python-dateutil"
version = "2.9.0.post0"