add initial embedding to the content graph
This commit is contained in:
parent
01cf15e7d1
commit
817b1bc7f9
3 changed files with 14 additions and 24 deletions
|
|
@ -213,27 +213,6 @@ class Source(ObjectModel):
|
|||
logger.exception(e)
|
||||
raise DatabaseOperationError(e)
|
||||
|
||||
# @classmethod
|
||||
# def search(cls, query: str) -> List[Dict[str, Any]]:
|
||||
# if not query:
|
||||
# raise InvalidInputError("Search query cannot be empty")
|
||||
# try:
|
||||
# result = repo_query(
|
||||
# """
|
||||
# SELECT * omit full_text
|
||||
# FROM source
|
||||
# WHERE string::lowercase(title) CONTAINS $query or title @@ $query
|
||||
# OR string::lowercase(summary) CONTAINS $query or summary @@ $query
|
||||
# OR string::lowercase(full_text) CONTAINS $query or full_text @@ $query
|
||||
# """,
|
||||
# {"query": query},
|
||||
# )
|
||||
# return result
|
||||
# except Exception as e:
|
||||
# logger.error(f"Error searching sources: {str(e)}")
|
||||
# logger.exception(e)
|
||||
# raise DatabaseOperationError("Failed to search sources")
|
||||
|
||||
def add_insight(self, insight_type: str, content: str) -> Any:
|
||||
EMBEDDING_MODEL = model_manager.embedding_model
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ class SourceState(TypedDict):
|
|||
notebook_id: str
|
||||
source: Source
|
||||
transformations: Annotated[list, operator.add]
|
||||
embed: bool = False
|
||||
|
||||
|
||||
class TransformationState(TypedDict):
|
||||
|
|
@ -102,6 +103,14 @@ async def transform_content(state: TransformationState) -> dict:
|
|||
return {"transformations": [{"name": transformation["name"], "content": result}]}
|
||||
|
||||
|
||||
async def embed_content(state: SourceState) -> dict:
    """Vectorize the source when the caller asked for embeddings.

    Args:
        state: Graph state. Reads ``state["source"]`` and the optional
            ``state["embed"]`` flag.

    Returns:
        A state update of the form ``{"source": source}``. The same source
        object is returned whether or not it was vectorized.
    """
    source: Source = state["source"]
    # The class-level ``embed: bool = False`` on SourceState is not applied
    # to runtime dicts, so the key can be absent. Treat a missing flag as
    # "do not embed" instead of raising KeyError.
    if state.get("embed", False):
        logger.debug("Embedding content for vector search")
        source.vectorize()
    return {"source": source}
|
||||
|
||||
|
||||
# Create and compile the workflow
|
||||
workflow = StateGraph(SourceState)
|
||||
|
||||
|
|
@ -109,14 +118,15 @@ workflow = StateGraph(SourceState)
|
|||
workflow.add_node("content_process", content_process)
|
||||
workflow.add_node("save_source", save_source)
|
||||
workflow.add_node("transform_content", transform_content)
|
||||
|
||||
workflow.add_node("embed_content", embed_content)
|
||||
# Define the graph edges
|
||||
workflow.add_edge(START, "content_process")
|
||||
workflow.add_edge("content_process", "save_source")
|
||||
workflow.add_conditional_edges(
|
||||
"save_source", trigger_transformations, ["transform_content"]
|
||||
)
|
||||
workflow.add_edge("transform_content", END)
|
||||
workflow.add_edge("transform_content", "embed_content")
|
||||
workflow.add_edge("embed_content", END)
|
||||
|
||||
# Compile the graph
|
||||
source_graph = workflow.compile()
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ def add_source(notebook_id):
|
|||
options=available_transformations,
|
||||
default=default_transformations,
|
||||
)
|
||||
embed = st.checkbox("Embed content for vector search", value=False)
|
||||
if st.button("Process", key="add_source"):
|
||||
logger.debug("Adding source")
|
||||
with st.status("Processing...", expanded=True):
|
||||
|
|
@ -71,13 +72,13 @@ def add_source(notebook_id):
|
|||
with open(new_path, "wb") as f:
|
||||
f.write(source_file.getbuffer())
|
||||
|
||||
st.write("Processing content..")
|
||||
asyncio.run(
|
||||
source_graph.ainvoke(
|
||||
{
|
||||
"content_state": req,
|
||||
"notebook_id": notebook_id,
|
||||
"transformations": apply_transformations,
|
||||
"embed": embed,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in a new issue