Merge pull request #722 from lfnovo/fix/source-and-credential-bugs

fix: source asset persistence, title preservation, credential cascade delete
This commit is contained in:
Luis Novo 2026-04-06 08:01:56 -03:00 committed by GitHub
commit 33920285ca
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 389 additions and 21 deletions

5
.gitignore vendored
View file

@ -138,4 +138,7 @@ specs/
*.local.yml
**/*.local.md
**/*.local.md
.harness/
.mcp.json

View file

@ -248,7 +248,6 @@ async def update_credential(credential_id: str, request: UpdateCredentialRequest
@router.delete("/{credential_id}", response_model=CredentialDeleteResponse)
async def delete_credential(
credential_id: str,
delete_models: bool = Query(False, description="Also delete linked models"),
migrate_to: Optional[str] = Query(
None, description="Migrate linked models to this credential ID"
),
@ -257,24 +256,13 @@ async def delete_credential(
Delete a credential.
If the credential has linked models:
- Pass delete_models=true to delete them
- Pass migrate_to=<credential_id> to reassign them
- Without either, returns 409 with linked model info
- Pass migrate_to=<credential_id> to reassign them to another credential
- Otherwise, linked models are cascade-deleted automatically
"""
try:
cred = await Credential.get(credential_id)
linked_models = await cred.get_linked_models()
if linked_models and not delete_models and not migrate_to:
raise HTTPException(
status_code=409,
detail={
"message": f"Credential has {len(linked_models)} linked model(s)",
"model_ids": [m.id for m in linked_models],
"model_names": [f"{m.provider}/{m.name}" for m in linked_models],
},
)
deleted_models = 0
if linked_models and migrate_to:
@ -284,8 +272,8 @@ async def delete_credential(
model.credential = target_cred.id
await model.save()
elif linked_models and delete_models:
# Delete linked models
elif linked_models:
# Cascade-delete linked models (default behavior when no migrate_to)
for model in linked_models:
await model.delete()
deleted_models += 1

View file

@ -31,7 +31,7 @@ from api.models import (
from commands.source_commands import SourceProcessingInput
from open_notebook.config import UPLOADS_FOLDER
from open_notebook.database.repository import ensure_record_id, repo_query
from open_notebook.domain.notebook import Notebook, Source
from open_notebook.domain.notebook import Asset, Notebook, Source
from open_notebook.domain.transformation import Transformation
from open_notebook.exceptions import InvalidInputError
@ -353,10 +353,19 @@ async def create_source(
# ASYNC PATH: Create source record first, then queue command
logger.info("Using async processing path")
# Create minimal source record - let SurrealDB generate the ID
# Create source record with asset - let SurrealDB generate the ID
# Persist asset before save so it's available for retry if processing fails
if source_data.type == "link":
source_asset = Asset(url=source_data.url)
elif source_data.type == "upload":
source_asset = Asset(file_path=file_path or source_data.file_path)
else:
source_asset = None
source = Source(
title=source_data.title or "Processing...",
topics=[],
asset=source_asset,
)
await source.save()

View file

@ -106,8 +106,8 @@ async def save_source(state: SourceState) -> dict:
source.asset = Asset(url=content_state.url, file_path=content_state.file_path)
source.full_text = content_state.content
# Preserve existing title if none provided in processed content
if content_state.title:
# Preserve user-set title; only overwrite placeholder or empty titles
if content_state.title and (not source.title or source.title == "Processing..."):
source.title = content_state.title
await source.save()

View file

@ -0,0 +1,100 @@
"""Tests for the credentials API endpoint."""
from unittest.mock import AsyncMock, patch
import pytest
from fastapi.testclient import TestClient
@pytest.fixture
def client():
    """Return a TestClient bound to the API application.

    The application is imported lazily, at fixture time, so that it is loaded
    only after conftest has cleared the environment variables.
    """
    from api.main import app as api_app

    test_client = TestClient(api_app)
    return test_client
class TestCredentialCascadeDelete:
    """Tests for #651 - deleting credential cascade-deletes linked models."""

    @staticmethod
    def _credential_with(linked):
        """Return a mocked credential whose get_linked_models() yields *linked*."""
        credential = AsyncMock()
        credential.get_linked_models = AsyncMock(return_value=linked)
        credential.delete = AsyncMock()
        return credential

    @pytest.mark.asyncio
    @patch("api.routers.credentials.Credential.get")
    async def test_cascade_delete_linked_models(self, mock_get, client):
        """A plain DELETE removes every linked model along with the credential."""
        # Two linked models from the same provider.
        linked = []
        for model_id, model_name in (
            ("model:1", "gpt-4"),
            ("model:2", "gpt-3.5-turbo"),
        ):
            model = AsyncMock()
            model.id = model_id
            model.provider = "openai"
            model.name = model_name
            linked.append(model)

        credential = self._credential_with(linked)
        mock_get.return_value = credential

        response = client.delete("/api/credentials/cred:123")

        assert response.status_code == 200
        payload = response.json()
        assert payload["deleted_models"] == 2
        assert payload["message"] == "Credential deleted successfully"

        for model in linked:
            model.delete.assert_awaited_once()
        credential.delete.assert_awaited_once()

    @pytest.mark.asyncio
    @patch("api.routers.credentials.Credential.get")
    async def test_delete_credential_no_linked_models(self, mock_get, client):
        """A credential without linked models is deleted and reports zero models."""
        credential = self._credential_with([])
        mock_get.return_value = credential

        response = client.delete("/api/credentials/cred:123")

        assert response.status_code == 200
        payload = response.json()
        assert payload["deleted_models"] == 0
        credential.delete.assert_awaited_once()

    @pytest.mark.asyncio
    @patch("api.routers.credentials.Credential.get")
    async def test_migrate_models_instead_of_delete(self, mock_get, client):
        """With migrate_to set, linked models are re-pointed rather than deleted."""
        model = AsyncMock()
        model.id = "model:1"
        model.credential = "cred:123"
        model.save = AsyncMock()

        doomed_cred = self._credential_with([model])

        target_cred = AsyncMock()
        target_cred.id = "cred:456"

        # Credential.get is called twice: first for the credential being
        # deleted, then for the migration target.
        mock_get.side_effect = [doomed_cred, target_cred]

        response = client.delete("/api/credentials/cred:123?migrate_to=cred:456")

        assert response.status_code == 200
        payload = response.json()
        # Migrated models must not be counted as deleted.
        assert payload["deleted_models"] == 0
        model.save.assert_awaited_once()
        assert model.credential == "cred:456"
        doomed_cred.delete.assert_awaited_once()
if __name__ == "__main__":
    # Running this file directly hands it to pytest in verbose mode.
    pytest_args = [__file__, "-v"]
    pytest.main(pytest_args)

View file

@ -6,9 +6,12 @@ without heavy mocking of the actual processing logic.
"""
from datetime import datetime
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from open_notebook.domain.notebook import Source
from open_notebook.graphs.prompt import PatternChainState, graph
from open_notebook.graphs.tools import get_current_timestamp
from open_notebook.graphs.transformation import (
@ -151,5 +154,130 @@ class TestTransformationGraph:
assert hasattr(transformation_graph, "ainvoke")
# ============================================================================
# TEST SUITE 4: Source Graph - Title Preservation
# ============================================================================
class TestSaveSourceTitlePreservation:
    """Test save_source node preserves user-set titles (#670)."""

    @staticmethod
    async def _run_save_source(
        mock_get, *, stored_title, extracted_title, url, file_path, content
    ):
        """Run save_source against a mocked Source and return that mock.

        ``mock_get`` is the patched ``Source.get``; it is configured to return
        a Source mock whose title is ``stored_title``. The remaining keyword
        arguments populate the mocked ``content_state`` handed to the node.
        """
        from open_notebook.graphs.source import save_source

        stored = MagicMock(spec=Source)
        stored.title = stored_title
        stored.save = AsyncMock()
        mock_get.return_value = stored

        extracted = MagicMock()
        extracted.title = extracted_title
        extracted.url = url
        extracted.file_path = file_path
        extracted.content = content

        await save_source(
            {
                "source_id": "source:123",
                "content_state": extracted,
                "embed": False,
                "apply_transformations": [],
            }
        )
        return stored

    @pytest.mark.asyncio
    @patch("open_notebook.graphs.source.Source.get")
    async def test_custom_title_preserved(self, mock_get):
        """A title the user chose must survive processing untouched."""
        stored = await self._run_save_source(
            mock_get,
            stored_title="My Custom Research Title",
            extracted_title="video.mp4",
            url="https://example.com",
            file_path=None,
            content="Some content",
        )

        assert stored.title == "My Custom Research Title"
        stored.save.assert_awaited_once()

    @pytest.mark.asyncio
    @patch("open_notebook.graphs.source.Source.get")
    async def test_placeholder_title_replaced(self, mock_get):
        """The 'Processing...' placeholder gives way to the extracted title."""
        stored = await self._run_save_source(
            mock_get,
            stored_title="Processing...",
            extracted_title="Extracted Article Title",
            url="https://example.com",
            file_path=None,
            content="Some content",
        )

        assert stored.title == "Extracted Article Title"
        stored.save.assert_awaited_once()

    @pytest.mark.asyncio
    @patch("open_notebook.graphs.source.Source.get")
    async def test_none_title_replaced(self, mock_get):
        """A missing (None) title is filled in from the extracted title."""
        stored = await self._run_save_source(
            mock_get,
            stored_title=None,
            extracted_title="Extracted Title",
            url=None,
            file_path="/tmp/file.pdf",
            content="Content",
        )

        assert stored.title == "Extracted Title"
        stored.save.assert_awaited_once()

    @pytest.mark.asyncio
    @patch("open_notebook.graphs.source.Source.get")
    async def test_empty_title_replaced(self, mock_get):
        """An empty-string title is filled in from the extracted title."""
        stored = await self._run_save_source(
            mock_get,
            stored_title="",
            extracted_title="Extracted Title",
            url=None,
            file_path=None,
            content="Content",
        )

        assert stored.title == "Extracted Title"
        stored.save.assert_awaited_once()
if __name__ == "__main__":
    # Running this file directly hands it to pytest in verbose mode.
    pytest_args = [__file__, "-v"]
    pytest.main(pytest_args)

140
tests/test_sources_api.py Normal file
View file

@ -0,0 +1,140 @@
"""Tests for the sources API endpoint."""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from open_notebook.domain.notebook import Source
@pytest.fixture
def client():
    """Return a TestClient bound to the API application.

    The application is imported lazily, at fixture time, so that it is loaded
    only after conftest has cleared the environment variables.
    """
    from api.main import app as api_app

    test_client = TestClient(api_app)
    return test_client
class TestAsyncSourceAssetPersistence:
    """Tests for #627 - asset is persisted before async processing.

    These tests hit the real create_source endpoint with mocked DB/command
    calls, verifying that the Source saved to the database has the correct
    asset set *before* async processing begins.
    """

    @staticmethod
    def _recording_save(bucket):
        """Build an async stand-in for ``Source.save``.

        Each Source instance it is called with is appended to *bucket*, then
        given a fake id and a cleared command.
        """

        async def fake_save(instance):
            bucket.append(instance)
            instance.id = "source:fake"
            instance.command = None

        return fake_save

    @pytest.mark.asyncio
    @patch("api.routers.sources.CommandService.submit_command_job", new_callable=AsyncMock)
    @patch("api.routers.sources.Source.add_to_notebook", new_callable=AsyncMock)
    @patch("api.routers.sources.Notebook.get", new_callable=AsyncMock)
    async def test_async_link_source_persists_url_asset(
        self, mock_nb_get, mock_add_nb, mock_submit, client
    ):
        """POST /sources with type=link and async_processing=true persists Asset(url=...)."""
        mock_submit.return_value = "command:123"
        mock_nb_get.return_value = MagicMock()

        recorded = []
        form_fields = {
            "type": "link",
            "url": "https://example.com/article",
            "notebooks": '["notebook:1"]',
            "async_processing": "true",
        }
        with patch.object(
            Source, "save", autospec=True, side_effect=self._recording_save(recorded)
        ):
            response = client.post("/api/sources", data=form_fields)

        assert response.status_code == 200
        assert recorded
        # The first save is the one that happens before processing is queued.
        first_saved = recorded[0]
        assert first_saved.asset is not None
        assert first_saved.asset.url == "https://example.com/article"
        assert first_saved.asset.file_path is None

    @pytest.mark.asyncio
    @patch("api.routers.sources.CommandService.submit_command_job", new_callable=AsyncMock)
    @patch("api.routers.sources.Source.add_to_notebook", new_callable=AsyncMock)
    @patch("api.routers.sources.Notebook.get", new_callable=AsyncMock)
    @patch("api.routers.sources.save_uploaded_file", new_callable=AsyncMock)
    async def test_async_upload_source_persists_file_asset(
        self, mock_upload, mock_nb_get, mock_add_nb, mock_submit, client
    ):
        """POST /sources with type=upload and async_processing=true persists Asset(file_path=...)."""
        mock_submit.return_value = "command:123"
        mock_upload.return_value = "/tmp/uploads/video.mp4"
        mock_nb_get.return_value = MagicMock()

        recorded = []
        form_fields = {
            "type": "upload",
            "notebooks": '["notebook:1"]',
            "async_processing": "true",
        }
        upload = {"file": ("video.mp4", b"fake content", "video/mp4")}
        with patch.object(
            Source, "save", autospec=True, side_effect=self._recording_save(recorded)
        ):
            response = client.post("/api/sources", data=form_fields, files=upload)

        assert response.status_code == 200
        assert recorded
        first_saved = recorded[0]
        assert first_saved.asset is not None
        assert first_saved.asset.file_path == "/tmp/uploads/video.mp4"
        assert first_saved.asset.url is None

    @pytest.mark.asyncio
    @patch("api.routers.sources.CommandService.submit_command_job", new_callable=AsyncMock)
    @patch("api.routers.sources.Source.add_to_notebook", new_callable=AsyncMock)
    @patch("api.routers.sources.Notebook.get", new_callable=AsyncMock)
    async def test_async_text_source_has_no_asset(
        self, mock_nb_get, mock_add_nb, mock_submit, client
    ):
        """POST /sources with type=text and async_processing=true has asset=None."""
        mock_submit.return_value = "command:123"
        mock_nb_get.return_value = MagicMock()

        recorded = []
        form_fields = {
            "type": "text",
            "content": "Some text content",
            "notebooks": '["notebook:1"]',
            "async_processing": "true",
        }
        with patch.object(
            Source, "save", autospec=True, side_effect=self._recording_save(recorded)
        ):
            response = client.post("/api/sources", data=form_fields)

        assert response.status_code == 200
        assert recorded
        first_saved = recorded[0]
        assert first_saved.asset is None
if __name__ == "__main__":
    # Running this file directly hands it to pytest in verbose mode.
    pytest_args = [__file__, "-v"]
    pytest.main(pytest_args)