diff --git a/open_notebook/domain/notebook.py b/open_notebook/domain/notebook.py index 21c9965..2b8e422 100644 --- a/open_notebook/domain/notebook.py +++ b/open_notebook/domain/notebook.py @@ -213,27 +213,6 @@ class Source(ObjectModel): logger.exception(e) raise DatabaseOperationError(e) - # @classmethod - # def search(cls, query: str) -> List[Dict[str, Any]]: - # if not query: - # raise InvalidInputError("Search query cannot be empty") - # try: - # result = repo_query( - # """ - # SELECT * omit full_text - # FROM source - # WHERE string::lowercase(title) CONTAINS $query or title @@ $query - # OR string::lowercase(summary) CONTAINS $query or summary @@ $query - # OR string::lowercase(full_text) CONTAINS $query or full_text @@ $query - # """, - # {"query": query}, - # ) - # return result - # except Exception as e: - # logger.error(f"Error searching sources: {str(e)}") - # logger.exception(e) - # raise DatabaseOperationError("Failed to search sources") - def add_insight(self, insight_type: str, content: str) -> Any: EMBEDDING_MODEL = model_manager.embedding_model diff --git a/open_notebook/graphs/source.py b/open_notebook/graphs/source.py index 6582553..dbce6da 100644 --- a/open_notebook/graphs/source.py +++ b/open_notebook/graphs/source.py @@ -23,6 +23,7 @@ class SourceState(TypedDict): notebook_id: str source: Source transformations: Annotated[list, operator.add] + embed: bool = False class TransformationState(TypedDict): @@ -102,6 +103,14 @@ async def transform_content(state: TransformationState) -> dict: return {"transformations": [{"name": transformation["name"], "content": result}]} +async def embed_content(state: SourceState) -> dict: + source: Source = state["source"] + if state["embed"]: + logger.debug("Embedding content for vector search") + source.vectorize() + return {"source": source} + + # Create and compile the workflow workflow = StateGraph(SourceState) @@ -109,14 +118,15 @@ workflow = StateGraph(SourceState) workflow.add_node("content_process", content_process) workflow.add_node("save_source", save_source) workflow.add_node("transform_content", transform_content) - +workflow.add_node("embed_content", embed_content) # Define the graph edges workflow.add_edge(START, "content_process") workflow.add_edge("content_process", "save_source") workflow.add_conditional_edges( "save_source", trigger_transformations, ["transform_content"] ) -workflow.add_edge("transform_content", END) +workflow.add_edge("transform_content", "embed_content") +workflow.add_edge("embed_content", END) # Compile the graph source_graph = workflow.compile() diff --git a/pages/stream_app/source.py b/pages/stream_app/source.py index 550553a..d37efcc 100644 --- a/pages/stream_app/source.py +++ b/pages/stream_app/source.py @@ -47,6 +47,7 @@ def add_source(notebook_id): options=available_transformations, default=default_transformations, ) + embed = st.checkbox("Embed content for vector search", value=False) if st.button("Process", key="add_source"): logger.debug("Adding source") with st.status("Processing...", expanded=True): @@ -71,13 +72,13 @@ def add_source(notebook_id): with open(new_path, "wb") as f: f.write(source_file.getbuffer()) - st.write("Processing content..") asyncio.run( source_graph.ainvoke( { "content_state": req, "notebook_id": notebook_id, "transformations": apply_transformations, + "embed": embed, } ) )