open-notebook/open_notebook/domain/base.py
Luis Novo 45a99831a9
Hide sources notes (#273)
* fix: add missing overflow wrapper to notebooks list page

Adds flex-1 overflow-y-auto wrapper to enable proper scrolling
when notebook list exceeds viewport height. Matches the layout
pattern used by all other dashboard pages.

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: reorder transformation routes to prevent dynamic route interception

Moved static routes (/transformations/execute and /transformations/default-prompt)
before dynamic routes (/transformations/{transformation_id}) to ensure FastAPI
matches them correctly. Previously, requests to static routes were incorrectly
captured by the dynamic route handler.

Fixes #250

Co-Authored-By: Claude <noreply@anthropic.com>

* chore: bump to 1.2.1

* hide source and notes panel - fixes #193

* feat: improve layout for mobile views

* bump version to 1.2.2

* fix: address PR review feedback for collapsible columns

- Remove unused CollapseButton component from CollapsibleColumn.tsx
- Rename useCollapseButton to createCollapseButton (not a React hook)
- Move dialogs outside Card in SourcesColumn.tsx for consistency
- Add useMemo for collapseButton in both columns to prevent re-renders

* feat: support multiple sources

* fix: prevent ChatColumn double mounting on desktop

Add useIsDesktop hook to conditionally render mobile view only on
mobile screens. Previously, the mobile ChatColumn was hidden via CSS
on desktop but still mounted, causing duplicate hooks initialization
and redundant network requests.

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-11-25 16:59:26 -03:00

329 lines
12 KiB
Python

from datetime import datetime
from typing import Any, ClassVar, Dict, List, Optional, Type, TypeVar, Union, cast
from loguru import logger
from pydantic import BaseModel, ValidationError, field_validator, model_validator
from open_notebook.database.repository import (
ensure_record_id,
repo_create,
repo_delete,
repo_query,
repo_relate,
repo_update,
repo_upsert,
)
from open_notebook.exceptions import (
DatabaseOperationError,
InvalidInputError,
NotFoundError,
)
T = TypeVar("T", bound="ObjectModel")
class ObjectModel(BaseModel):
id: Optional[str] = None
table_name: ClassVar[str] = ""
created: Optional[datetime] = None
updated: Optional[datetime] = None
@classmethod
async def get_all(cls: Type[T], order_by=None) -> List[T]:
try:
# If called from a specific subclass, use its table_name
if cls.table_name:
target_class = cls
table_name = cls.table_name
else:
# This path is taken if called directly from ObjectModel
raise InvalidInputError(
"get_all() must be called from a specific model class"
)
if order_by:
query = f"SELECT * FROM {table_name} ORDER BY {order_by}"
else:
query = f"SELECT * FROM {table_name}"
result = await repo_query(query)
objects = []
for obj in result:
try:
objects.append(target_class(**obj))
except Exception as e:
logger.critical(f"Error creating object: {str(e)}")
return objects
except Exception as e:
logger.error(f"Error fetching all {cls.table_name}: {str(e)}")
logger.exception(e)
raise DatabaseOperationError(e)
@classmethod
async def get(cls: Type[T], id: str) -> T:
if not id:
raise InvalidInputError("ID cannot be empty")
try:
# Get the table name from the ID (everything before the first colon)
table_name = id.split(":")[0] if ":" in id else id
# If we're calling from a specific subclass and IDs match, use that class
if cls.table_name and cls.table_name == table_name:
target_class: Type[T] = cls
else:
# Otherwise, find the appropriate subclass based on table_name
found_class = cls._get_class_by_table_name(table_name)
if not found_class:
raise InvalidInputError(f"No class found for table {table_name}")
target_class = cast(Type[T], found_class)
result = await repo_query("SELECT * FROM $id", {"id": ensure_record_id(id)})
if result:
return target_class(**result[0])
else:
raise NotFoundError(f"{table_name} with id {id} not found")
except Exception as e:
logger.error(f"Error fetching object with id {id}: {str(e)}")
logger.exception(e)
raise NotFoundError(f"Object with id {id} not found - {str(e)}")
@classmethod
def _get_class_by_table_name(cls, table_name: str) -> Optional[Type["ObjectModel"]]:
"""Find the appropriate subclass based on table_name."""
def get_all_subclasses(c: Type["ObjectModel"]) -> List[Type["ObjectModel"]]:
all_subclasses: List[Type["ObjectModel"]] = []
for subclass in c.__subclasses__():
all_subclasses.append(subclass)
all_subclasses.extend(get_all_subclasses(subclass))
return all_subclasses
for subclass in get_all_subclasses(ObjectModel):
if hasattr(subclass, "table_name") and subclass.table_name == table_name:
return subclass
return None
def needs_embedding(self) -> bool:
return False
def get_embedding_content(self) -> Optional[str]:
return None
async def save(self) -> None:
from open_notebook.domain.models import model_manager
try:
self.model_validate(self.model_dump(), strict=True)
data = self._prepare_save_data()
data["updated"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
if self.needs_embedding():
embedding_content = self.get_embedding_content()
if embedding_content:
EMBEDDING_MODEL = await model_manager.get_embedding_model()
if not EMBEDDING_MODEL:
logger.warning(
"No embedding model found. Content will not be searchable."
)
data["embedding"] = (
(await EMBEDDING_MODEL.aembed([embedding_content]))[0]
if EMBEDDING_MODEL
else []
)
repo_result: Union[List[Dict[str, Any]], Dict[str, Any]]
if self.id is None:
data["created"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
repo_result = await repo_create(self.__class__.table_name, data)
else:
data["created"] = (
self.created.strftime("%Y-%m-%d %H:%M:%S")
if isinstance(self.created, datetime)
else self.created
)
logger.debug(f"Updating record with id {self.id}")
repo_result = await repo_update(
self.__class__.table_name, self.id, data
)
# Update the current instance with the result
# repo_result is a list of dictionaries
result_list: List[Dict[str, Any]] = repo_result if isinstance(repo_result, list) else [repo_result]
for key, value in result_list[0].items():
if hasattr(self, key):
if isinstance(getattr(self, key), BaseModel):
setattr(self, key, type(getattr(self, key))(**value))
else:
setattr(self, key, value)
except ValidationError as e:
logger.error(f"Validation failed: {e}")
raise
except RuntimeError:
# Transaction conflicts should propagate for retry
raise
except Exception as e:
logger.error(f"Error saving record: {e}")
raise DatabaseOperationError(e)
def _prepare_save_data(self) -> Dict[str, Any]:
data = self.model_dump()
return {key: value for key, value in data.items() if value is not None}
async def delete(self) -> bool:
if self.id is None:
raise InvalidInputError("Cannot delete object without an ID")
try:
logger.debug(f"Deleting record with id {self.id}")
return await repo_delete(self.id)
except Exception as e:
logger.error(
f"Error deleting {self.__class__.table_name} with id {self.id}: {str(e)}"
)
raise DatabaseOperationError(
f"Failed to delete {self.__class__.table_name}"
)
async def relate(
self, relationship: str, target_id: str, data: Optional[Dict] = {}
) -> Any:
if not relationship or not target_id or not self.id:
raise InvalidInputError("Relationship and target ID must be provided")
try:
return await repo_relate(
source=self.id, relationship=relationship, target=target_id, data=data
)
except Exception as e:
logger.error(f"Error creating relationship: {str(e)}")
logger.exception(e)
raise DatabaseOperationError(e)
@field_validator("created", "updated", mode="before")
@classmethod
def parse_datetime(cls, value):
if isinstance(value, str):
return datetime.fromisoformat(value.replace("Z", "+00:00"))
return value
class RecordModel(BaseModel):
record_id: ClassVar[str]
auto_save: ClassVar[bool] = (
False # Default to False, can be overridden in subclasses
)
_instances: ClassVar[Dict[str, "RecordModel"]] = {} # Store instances by record_id
class Config:
validate_assignment = True
arbitrary_types_allowed = True
extra = "allow"
from_attributes = True
defer_build = True
def __new__(cls, **kwargs):
# If an instance already exists for this record_id, return it
if cls.record_id in cls._instances:
instance = cls._instances[cls.record_id]
# Update instance with any new kwargs if provided
if kwargs:
for key, value in kwargs.items():
setattr(instance, key, value)
return instance
# If no instance exists, create a new one
instance = super().__new__(cls)
cls._instances[cls.record_id] = instance
return instance
def __init__(self, **kwargs):
# Only initialize if this is a new instance
if not hasattr(self, "_initialized"):
object.__setattr__(self, "__dict__", {})
# For RecordModel, we need to handle async initialization differently
# Initialize with provided kwargs only for now
super().__init__(**kwargs)
# Mark as initialized but not loaded from DB yet
object.__setattr__(self, "_initialized", True)
object.__setattr__(self, "_db_loaded", False)
async def _load_from_db(self):
"""Load data from database if not already loaded"""
if not getattr(self, "_db_loaded", False):
result = await repo_query(
"SELECT * FROM ONLY $record_id",
{"record_id": ensure_record_id(self.record_id)},
)
# Handle case where record doesn't exist yet
if result:
if isinstance(result, list) and len(result) > 0:
# Standard list response
row = result[0]
if isinstance(row, dict):
for key, value in row.items():
if hasattr(self, key):
object.__setattr__(self, key, value)
elif isinstance(result, dict):
# Direct dict response
for key, value in result.items():
if hasattr(self, key):
object.__setattr__(self, key, value)
object.__setattr__(self, "_db_loaded", True)
@classmethod
async def get_instance(cls) -> "RecordModel":
"""Get or create the singleton instance and load from DB"""
instance = cls()
await instance._load_from_db()
return instance
@model_validator(mode="after")
def auto_save_validator(self):
if self.__class__.auto_save:
# Auto-save can't work with async - log warning
logger.warning(
f"Auto-save is enabled for {self.__class__.__name__} but update() is now async. Call await instance.update() manually."
)
return self
async def update(self):
# Get all non-ClassVar fields and their values
data = {
field_name: getattr(self, field_name)
for field_name, field_info in self.model_fields.items()
if not str(field_info.annotation).startswith("typing.ClassVar")
}
await repo_upsert(
self.__class__.table_name
if hasattr(self.__class__, "table_name")
else "record",
self.record_id,
data,
)
result = await repo_query(
"SELECT * FROM $record_id", {"record_id": ensure_record_id(self.record_id)}
)
if result:
for key, value in result[0].items():
if hasattr(self, key):
object.__setattr__(
self, key, value
) # Use object.__setattr__ to avoid triggering validation again
return self
@classmethod
def clear_instance(cls):
"""Clear the singleton instance (useful for testing)"""
if cls.record_id in cls._instances:
del cls._instances[cls.record_id]
async def patch(self, model_dict: dict):
"""Update model attributes from dictionary and save"""
for key, value in model_dict.items():
setattr(self, key, value)
await self.update()