improve search
This commit is contained in:
parent
0524eddb0b
commit
53da255801
3 changed files with 54 additions and 138 deletions
|
|
@ -309,7 +309,8 @@ def text_search(keyword: str, results: int, source: bool = True, note: bool = Tr
|
|||
try:
|
||||
results = repo_query(
|
||||
"""
|
||||
SELECT * FROM fn::text_search($keyword, $results, $source, $note);
|
||||
select *
|
||||
from fn::text_search($keyword, $results, $source, $note)
|
||||
""",
|
||||
{"keyword": keyword, "results": results, "source": source, "note": note},
|
||||
)
|
||||
|
|
@ -320,7 +321,13 @@ def text_search(keyword: str, results: int, source: bool = True, note: bool = Tr
|
|||
raise DatabaseOperationError(e)
|
||||
|
||||
|
||||
def vector_search(keyword: str, results: int, source: bool = True, note: bool = True):
|
||||
def vector_search(
|
||||
keyword: str,
|
||||
results: int,
|
||||
source: bool = True,
|
||||
note: bool = True,
|
||||
minimum_score=0.2,
|
||||
):
|
||||
if not keyword:
|
||||
raise InvalidInputError("Search keyword cannot be empty")
|
||||
try:
|
||||
|
|
@ -328,131 +335,18 @@ def vector_search(keyword: str, results: int, source: bool = True, note: bool =
|
|||
embed = EMBEDDING_MODEL.embed(keyword)
|
||||
results = repo_query(
|
||||
"""
|
||||
SELECT * FROM fn::vector_search($embed, $results, $source, $note, 0.15);
|
||||
SELECT * FROM fn::vector_search($embed, $results, $source, $note, $minimum_score);
|
||||
""",
|
||||
{"embed": embed, "results": results, "source": source, "note": note},
|
||||
{
|
||||
"embed": embed,
|
||||
"results": results,
|
||||
"source": source,
|
||||
"note": note,
|
||||
"minimum_score": minimum_score,
|
||||
},
|
||||
)
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"Error performing vector search: {str(e)}")
|
||||
logger.exception(e)
|
||||
raise DatabaseOperationError(e)
|
||||
|
||||
|
||||
def _limit_chunks_per_doc(ranked: List[Dict], max_chunks_per_doc: int) -> List[Dict]:
    """Keep at most ``max_chunks_per_doc`` hits per ``parent_id``, preserving order.

    The first hit for a given ``parent_id`` is always kept, even when
    ``max_chunks_per_doc`` is zero or negative.
    """
    kept: List[Dict] = []
    chunks_seen: Dict = {}
    for hit in ranked:
        parent_id = hit["parent_id"]
        count = chunks_seen.get(parent_id, 0)
        if count == 0 or count < max_chunks_per_doc:
            chunks_seen[parent_id] = count + 1
            kept.append(hit)
    return kept


def _search_terms(
    terms: List[str],
    search_fn,
    label: str,
    score_key: str,
    results: int,
    source: bool,
    note: bool,
    max_chunks_per_doc: int,
) -> Dict[str, List[Dict]]:
    """Run ``search_fn`` once per term and return the ranked, per-document-capped hits.

    ``search_fn`` is called as ``search_fn(term, results, source, note)``.
    A term whose search (or post-processing) raises is logged as a warning and
    left out of the returned mapping, so one failing term does not abort the
    whole hybrid search.
    """
    per_term: Dict[str, List[Dict]] = {}
    for term in terms:
        try:
            hits = search_fn(term, results, source, note)
            # Missing scores are treated as 0 so unscored hits sort last.
            ranked = sorted(hits, key=lambda x: x.get(score_key, 0), reverse=True)
            per_term[term] = _limit_chunks_per_doc(ranked, max_chunks_per_doc)
        except Exception as e:
            # Deliberate best-effort: skip this term, keep the others.
            logger.warning(f"Error in {label} search for term '{term}': {str(e)}")
            continue
    return per_term


def _merge_query_results(
    per_term: Dict[str, List[Dict]],
    score_key: str,
    results: int,
    min_results_per_query: int,
) -> List[Dict]:
    """Merge per-term hit lists into one list without repeating an ``id``.

    First pass: take the top ``min_results_per_query`` hits of every term, so
    each query is represented. These guaranteed hits are never trimmed, so
    the merged list can exceed ``results`` when there are many terms.
    Second pass: fill whatever is left of the ``results`` budget with the
    best-scoring remaining hits across all terms.
    """
    merged: List[Dict] = []
    seen_ids = set()

    for query_results in per_term.values():
        for hit in query_results[:min_results_per_query]:
            # A hit returned by several queries is emitted (and counted) once.
            if hit["id"] not in seen_ids:
                seen_ids.add(hit["id"])
                merged.append(hit)

    remaining_slots = results - len(merged)

    leftovers = [
        hit
        for query_results in per_term.values()
        for hit in query_results[min_results_per_query:]
    ]
    leftovers.sort(key=lambda x: x.get(score_key, 0), reverse=True)

    for hit in leftovers:
        if remaining_slots <= 0:
            break
        if hit["id"] not in seen_ids:
            seen_ids.add(hit["id"])
            merged.append(hit)
            remaining_slots -= 1

    return merged


def hybrid_search(
    keyword_search: List[str],
    embed_search: List[str],
    results: int = 50,
    source: bool = True,
    note: bool = True,
    max_chunks_per_doc: int = 3,
    min_results_per_query: int = 3,
) -> Dict[str, List[Dict]]:
    """Run several keyword and vector searches and merge each family of hits.

    Args:
        keyword_search: Terms passed one by one to ``text_search``.
        embed_search: Terms passed one by one to ``vector_search``.
        results: Per-query result limit forwarded to the search functions,
            and the target size of each merged list.
        source: Forwarded to the search functions.
        note: Forwarded to the search functions.
        max_chunks_per_doc: Maximum hits kept per ``parent_id`` within a
            single query (the first hit per ``parent_id`` is always kept).
        min_results_per_query: Number of top hits guaranteed from each query
            before the remaining budget is filled by score.

    Returns:
        A dict with two lists, ``"keyword_results"`` (ranked by the
        ``"relevance"`` field) and ``"vector_results"`` (ranked by the
        ``"similarity"`` field). Within each list every ``id`` appears once.

    Raises:
        InvalidInputError: If both term lists are empty.
    """
    if not keyword_search and not embed_search:
        raise InvalidInputError("At least one search term required")

    # The guard above accepts one empty side; tolerate None there as well
    # instead of failing in the loop.
    keyword_terms = keyword_search or []
    embed_terms = embed_search or []

    all_keyword_results = _search_terms(
        keyword_terms,
        text_search,
        "keyword",
        "relevance",
        results,
        source,
        note,
        max_chunks_per_doc,
    )
    keyword_results = _merge_query_results(
        all_keyword_results, "relevance", results, min_results_per_query
    )

    all_vector_results = _search_terms(
        embed_terms,
        vector_search,
        "vector",
        "similarity",
        results,
        source,
        note,
        max_chunks_per_doc,
    )
    vector_results = _merge_query_results(
        all_vector_results, "similarity", results, min_results_per_query
    )

    return {"keyword_results": keyword_results, "vector_results": vector_results}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import streamlit as st
|
|||
|
||||
from open_notebook.domain.models import Model
|
||||
from open_notebook.domain.notebook import text_search, vector_search
|
||||
from open_notebook.graphs.rag import graph as rag_graph
|
||||
from open_notebook.graphs.ask import graph as ask_graph
|
||||
from pages.stream_app.utils import convert_source_references, setup_page
|
||||
|
||||
setup_page("🔍 Search")
|
||||
|
|
@ -15,10 +15,13 @@ if "search_results" not in st.session_state:
|
|||
|
||||
def results_card(item):
|
||||
score = item.get("relevance", item.get("similarity", item.get("score", 0)))
|
||||
with st.expander(f"[{score:.2f}] **{item['title']}**"):
|
||||
st.markdown(f"**{item['content']}**")
|
||||
st.write(item["id"])
|
||||
st.write(item["parent_id"])
|
||||
with st.container(border=True):
|
||||
st.markdown(
|
||||
f"[{score:.2f}] **[{item['title']}](/?object_id={item['parent_id']})**"
|
||||
)
|
||||
with st.expander("Matches"):
|
||||
for match in item["matches"]:
|
||||
st.markdown(match)
|
||||
|
||||
|
||||
with ask_tab:
|
||||
|
|
@ -26,22 +29,41 @@ with ask_tab:
|
|||
st.caption(
|
||||
"The LLM will answer your query based on the documents in your knowledge base. "
|
||||
)
|
||||
st.warning(
|
||||
"This functionality requires the use of Tools and, at this moment, works well with Open AI and Anthropic models only."
|
||||
)
|
||||
question = st.text_input("Question", "")
|
||||
models = Model.get_models_by_type("language")
|
||||
model: Model = st.selectbox("Model", models, format_func=lambda x: x.name)
|
||||
strategy_model: Model = st.selectbox(
|
||||
"Query Strategy Model",
|
||||
models,
|
||||
format_func=lambda x: x.name,
|
||||
help="This is the LLM that will be responsible for strategizing the search",
|
||||
)
|
||||
answer_model: Model = st.selectbox(
|
||||
"Indivual Answer Model",
|
||||
models,
|
||||
format_func=lambda x: x.name,
|
||||
help="This is the LLM that will be responsible for processing individual subqueries",
|
||||
)
|
||||
final_answer_model: Model = st.selectbox(
|
||||
"Final Answer Model",
|
||||
models,
|
||||
format_func=lambda x: x.name,
|
||||
help="This is the LLM that will be responsible for processing the final answer",
|
||||
)
|
||||
if st.button("Ask"):
|
||||
st.write(f"Searching for {question}")
|
||||
messages = [question]
|
||||
rag_results = rag_graph.invoke(
|
||||
rag_results = ask_graph.invoke(
|
||||
dict(
|
||||
messages=messages,
|
||||
question=question,
|
||||
),
|
||||
config=dict(
|
||||
configurable=dict(
|
||||
strategy_model=strategy_model.id,
|
||||
answer_model=answer_model.id,
|
||||
final_answer_model=final_answer_model.id,
|
||||
)
|
||||
),
|
||||
config=dict(configurable=dict(model_id=model.id)),
|
||||
)
|
||||
st.markdown(convert_source_references(rag_results["messages"][-1].content))
|
||||
st.markdown(convert_source_references(rag_results["final_answer"]))
|
||||
with st.expander("Details (for debugging)"):
|
||||
st.json(rag_results)
|
||||
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ def generate_toc_and_title(source) -> "Source":
|
|||
|
||||
@st.dialog("Source", width="large")
|
||||
def source_panel_dialog(source_id):
|
||||
source_panel(source_id)
|
||||
source_panel(source_id, modal=True)
|
||||
|
||||
|
||||
@st.dialog("Add a Source", width="large")
|
||||
|
|
|
|||
Loading…
Reference in a new issue