From 8398539df86118ded48fdab7ab1ef49b1e6ea9b1 Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Mon, 4 Nov 2024 09:51:20 -0300 Subject: [PATCH] add hybrid search to combine text and vector searches --- open_notebook/domain/notebook.py | 145 +++++++++++++++++++++++++++++-- 1 file changed, 138 insertions(+), 7 deletions(-) diff --git a/open_notebook/domain/notebook.py b/open_notebook/domain/notebook.py index 1eb2edd..ee2035e 100644 --- a/open_notebook/domain/notebook.py +++ b/open_notebook/domain/notebook.py @@ -260,23 +260,154 @@ def text_search(keyword: str, results: int, source: bool = True, note: bool = Tr except Exception as e: logger.error(f"Error performing text search: {str(e)}") logger.exception(e) - raise DatabaseOperationError("Failed to perform text search") + raise DatabaseOperationError(e) -def vector_search( - keyword: List[float], results: int, source: bool = True, note: bool = True -): +# def hybrid_search( +# keyword_search: List[str], +# embed_search: List[str], +# results: int = 50, +# source: bool = True, +# note: bool = True, +# ): +# EMBEDDING_MODEL = model_manager.embedding_model +# embed1_vector = EMBEDDING_MODEL.embed(embed1) if embed1 else None + + +# todo: mover o embedding pra ca +def vector_search(keyword: str, results: int, source: bool = True, note: bool = True): if not keyword: raise InvalidInputError("Search keyword cannot be empty") try: + EMBEDDING_MODEL = model_manager.embedding_model + embed = EMBEDDING_MODEL.embed(keyword) results = repo_query( """ - SELECT * FROM fn::vector_search($keyword, $results, $source, $note); + SELECT * FROM fn::vector_search($embed, $results, $source, $note, 0.15); """, - {"keyword": keyword, "results": results, "source": source, "note": note}, + {"embed": embed, "results": results, "source": source, "note": note}, ) return results except Exception as e: logger.error(f"Error performing vector search: {str(e)}") logger.exception(e) - raise DatabaseOperationError("Failed to perform vector search") + raise DatabaseOperationError(e) + + +def hybrid_search( + keyword_search: List[str], + embed_search: List[str], + results: int = 50, + source: bool = True, + note: bool = True, + max_chunks_per_doc: int = 3, + min_results_per_query: int = 3, +) -> Dict[str, List[Dict]]: + if not keyword_search and not embed_search: + raise InvalidInputError("At least one search term required") + + # Process keyword searches + all_keyword_results = {} # Dictionary to store results per keyword + for keyword in keyword_search: + try: + search_results = text_search(keyword, results, source, note) + # Sort results by relevance + sorted_results = sorted( + search_results, key=lambda x: x.get("relevance", 0), reverse=True + ) + # Group by parent_id and limit chunks per document + seen_parent_ids = {} + filtered_results = [] + for result in sorted_results: + parent_id = result["parent_id"] + if parent_id not in seen_parent_ids: + seen_parent_ids[parent_id] = 1 + filtered_results.append(result) + elif seen_parent_ids[parent_id] < max_chunks_per_doc: + seen_parent_ids[parent_id] += 1 + filtered_results.append(result) + all_keyword_results[keyword] = filtered_results + except Exception as e: + logger.warning(f"Error in keyword search for term '{keyword}': {str(e)}") + continue + + # Ensure minimum results from each keyword query + keyword_results = [] + remaining_slots = results + + # First pass: add minimum results from each query + for keyword, query_results in all_keyword_results.items(): + keyword_results.extend(query_results[:min_results_per_query]) + remaining_slots -= min(len(query_results), min_results_per_query) + + # Second pass: fill remaining slots with best results + all_remaining = [] + for keyword, query_results in all_keyword_results.items(): + all_remaining.extend(query_results[min_results_per_query:]) + + # Sort remaining by relevance and add until we hit the limit + all_remaining = sorted( + all_remaining, key=lambda x: x.get("relevance", 0), reverse=True + ) + seen_ids = {r["id"] for r in keyword_results} + for result in all_remaining: + if remaining_slots <= 0: + break + if result["id"] not in seen_ids: + keyword_results.append(result) + seen_ids.add(result["id"]) + remaining_slots -= 1 + + # Process vector searches with the same approach + all_vector_results = {} # Dictionary to store results per embedding + for embed in embed_search: + try: + search_results = vector_search(embed, results, source, note) + # Sort results by similarity + sorted_results = sorted( + search_results, key=lambda x: x.get("similarity", 0), reverse=True + ) + # Group by parent_id and limit chunks per document + seen_parent_ids = {} + filtered_results = [] + for result in sorted_results: + parent_id = result["parent_id"] + if parent_id not in seen_parent_ids: + seen_parent_ids[parent_id] = 1 + filtered_results.append(result) + elif seen_parent_ids[parent_id] < max_chunks_per_doc: + seen_parent_ids[parent_id] += 1 + filtered_results.append(result) + all_vector_results[embed] = filtered_results + except Exception as e: + logger.warning(f"Error in vector search for term '{embed}': {str(e)}") + continue + + # Ensure minimum results from each vector query + vector_results = [] + remaining_slots = results + + # First pass: add minimum results from each query + for embed, query_results in all_vector_results.items(): + vector_results.extend(query_results[:min_results_per_query]) + remaining_slots -= min(len(query_results), min_results_per_query) + + # Second pass: fill remaining slots with best results + all_remaining = [] + for embed, query_results in all_vector_results.items(): + all_remaining.extend(query_results[min_results_per_query:]) + + # Sort remaining by similarity and add until we hit the limit + all_remaining = sorted( + all_remaining, key=lambda x: x.get("similarity", 0), reverse=True + ) + seen_ids = {r["id"] for r in vector_results} + for result in all_remaining: + if remaining_slots <= 0: + break + if result["id"] not in seen_ids: + vector_results.append(result) + seen_ids.add(result["id"]) + remaining_slots -= 1 + + return {"keyword_results": keyword_results, "vector_results": vector_results}