From e5b253b11dd1490b1554ea32d9f1534398f74a22 Mon Sep 17 00:00:00 2001 From: Luis Novo Date: Tue, 7 Apr 2026 07:51:25 -0300 Subject: [PATCH] fix: prevent SurrealDB injection via order_by and unparameterized queries - Add allowlist validation for order_by param in notebooks endpoint - Parameterize session_id query in source_chat router - Add regex validation in base.py get_all() order_by parameter - Convert async_migrate bump/lower_version to parameterized queries --- api/routers/notebooks.py | 29 +++++++++++++++++++- api/routers/source_chat.py | 4 ++- open_notebook/database/async_migrate.py | 8 ++++-- open_notebook/domain/base.py | 35 ++++++++++++++++++++++++- 4 files changed, 71 insertions(+), 5 deletions(-) diff --git a/api/routers/notebooks.py b/api/routers/notebooks.py index 8a891c5..fd08e44 100644 --- a/api/routers/notebooks.py +++ b/api/routers/notebooks.py @@ -24,13 +24,38 @@ async def get_notebooks( ): """Get all notebooks with optional filtering and ordering.""" try: + # Validate order_by against allowlist to prevent SurrealQL injection + allowed_fields = {"name", "created", "updated"} + allowed_directions = {"asc", "desc"} + + parts = order_by.strip().lower().split() + if len(parts) == 1: + if parts[0] not in allowed_fields: + raise HTTPException( + status_code=400, + detail=f"Invalid order_by field: '{order_by}'. Allowed fields: {', '.join(sorted(allowed_fields))}", + ) + validated_order_by = parts[0] + elif len(parts) == 2: + if parts[0] not in allowed_fields or parts[1] not in allowed_directions: + raise HTTPException( + status_code=400, + detail=f"Invalid order_by: '{order_by}'. Allowed fields: {', '.join(sorted(allowed_fields))}. Allowed directions: asc, desc", + ) + validated_order_by = f"{parts[0]} {parts[1]}" + else: + raise HTTPException( + status_code=400, + detail=f"Invalid order_by format: '{order_by}'. Expected 'field' or 'field direction'", + ) + # Build the query with counts query = f""" SELECT *, count(<-reference.in) as source_count, count(<-artifact.in) as note_count FROM notebook - ORDER BY {order_by} + ORDER BY {validated_order_by} """ result = await repo_query(query) @@ -52,6 +77,8 @@ async def get_notebooks( ) for nb in result ] + except HTTPException: + raise except Exception as e: logger.error(f"Error fetching notebooks: {str(e)}") raise HTTPException( diff --git a/api/routers/source_chat.py b/api/routers/source_chat.py index 5fdde19..ca304de 100644 --- a/api/routers/source_chat.py +++ b/api/routers/source_chat.py @@ -155,7 +155,9 @@ async def get_source_chat_sessions(source_id: str = Path(..., description="Sourc if session_id_raw: session_id = str(session_id_raw) - session_result = await repo_query(f"SELECT * FROM {session_id_raw}") + session_result = await repo_query( + "SELECT * FROM $id", {"id": ensure_record_id(session_id)} + ) if session_result and len(session_result) > 0: session_data = session_result[0] diff --git a/open_notebook/database/async_migrate.py b/open_notebook/database/async_migrate.py index 611fd6c..ef8cac8 100644 --- a/open_notebook/database/async_migrate.py +++ b/open_notebook/database/async_migrate.py @@ -223,7 +223,8 @@ async def bump_version() -> None: new_version = current_version + 1 await repo_query( - f"CREATE _sbl_migrations:{new_version} SET version = {new_version}, applied_at = time::now();", + "CREATE type::thing('_sbl_migrations', $version) SET version = $version, applied_at = time::now();", + {"version": new_version}, ) @@ -231,4 +232,7 @@ async def lower_version() -> None: """Lower the version by removing the latest entry from migrations table.""" current_version = await get_latest_version() if current_version > 0: - await repo_query(f"DELETE _sbl_migrations:{current_version};") + await repo_query( + "DELETE type::thing('_sbl_migrations', $version);", + {"version": current_version}, + ) diff --git a/open_notebook/domain/base.py b/open_notebook/domain/base.py index 08b1f1f..45459d4 100644 --- a/open_notebook/domain/base.py +++ b/open_notebook/domain/base.py @@ -48,7 +48,40 @@ class ObjectModel(BaseModel): "get_all() must be called from a specific model class" ) if order_by: - query = f"SELECT * FROM {table_name} ORDER BY {order_by}" + # Validate order_by to prevent SurrealQL injection + # Supports: "field", "field direction", "field1 direction, field2 direction" + import re + + allowed_field_pattern = re.compile(r"^[a-z_][a-z0-9_]*$") + allowed_directions = {"asc", "desc"} + + clauses = [c.strip() for c in order_by.split(",")] + validated_clauses = [] + for clause in clauses: + parts = clause.strip().split() + if len(parts) == 1: + if not allowed_field_pattern.match(parts[0].lower()): + raise InvalidInputError( + f"Invalid order_by field: '{parts[0]}'" + ) + validated_clauses.append(parts[0].lower()) + elif len(parts) == 2: + if not allowed_field_pattern.match( + parts[0].lower() + ) or parts[1].lower() not in allowed_directions: + raise InvalidInputError( + f"Invalid order_by clause: '{clause.strip()}'" + ) + validated_clauses.append( + f"{parts[0].lower()} {parts[1].lower()}" + ) + else: + raise InvalidInputError( + f"Invalid order_by clause: '{clause.strip()}'" + ) + + validated_order_by = ", ".join(validated_clauses) + query = f"SELECT * FROM {table_name} ORDER BY {validated_order_by}" else: query = f"SELECT * FROM {table_name}"