From 2e2a4947b3decabc16679788a9e3572a2a99c794 Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Sun, 10 Nov 2024 13:30:03 -0300 Subject: [PATCH] separate source and content graph --- open_notebook/domain/transformation.py | 19 ++++ .../graphs/content_processing/__init__.py | 13 ++- .../graphs/content_processing/audio.py | 4 +- .../graphs/content_processing/office.py | 4 +- .../graphs/content_processing/pdf.py | 4 +- .../graphs/content_processing/state.py | 2 +- .../graphs/content_processing/text.py | 4 +- .../graphs/content_processing/url.py | 6 +- .../graphs/content_processing/video.py | 4 +- .../graphs/content_processing/youtube.py | 4 +- open_notebook/graphs/source.py | 106 ++++++++++++++++++ pages/stream_app/source.py | 53 ++++----- 12 files changed, 167 insertions(+), 56 deletions(-) create mode 100644 open_notebook/domain/transformation.py create mode 100644 open_notebook/graphs/source.py diff --git a/open_notebook/domain/transformation.py b/open_notebook/domain/transformation.py new file mode 100644 index 0000000..6662365 --- /dev/null +++ b/open_notebook/domain/transformation.py @@ -0,0 +1,19 @@ +from typing import ClassVar, List, Optional + +import yaml +from pydantic import Field + +from open_notebook.domain.base import RecordModel + + +class Transformation: + @classmethod + def get_all(cls): + with open("transformations.yaml", "r") as file: + transformations = yaml.safe_load(file) + return transformations + + +class DefaultTransformations(RecordModel): + record_id: ClassVar[str] = "open_notebook:default_transformations" + source_insights: Optional[List[str]] = Field(default_factory=list) diff --git a/open_notebook/graphs/content_processing/__init__.py b/open_notebook/graphs/content_processing/__init__.py index 915da23..270bb4f 100644 --- a/open_notebook/graphs/content_processing/__init__.py +++ b/open_notebook/graphs/content_processing/__init__.py @@ -14,14 +14,14 @@ from open_notebook.graphs.content_processing.pdf import ( SUPPORTED_FITZ_TYPES, extract_pdf, ) -from open_notebook.graphs.content_processing.state import SourceState +from open_notebook.graphs.content_processing.state import ContentState from open_notebook.graphs.content_processing.text import extract_txt from open_notebook.graphs.content_processing.url import extract_url, url_provider from open_notebook.graphs.content_processing.video import extract_best_audio_from_video from open_notebook.graphs.content_processing.youtube import extract_youtube_transcript -def source_identification(state: SourceState): +def source_identification(state: ContentState): """ Identify the content source based on parameters """ @@ -37,7 +37,7 @@ def source_identification(state: SourceState): return {"source_type": doc_type} -def file_type(state: SourceState): +def file_type(state: ContentState): """ Identify the file using python-magic """ @@ -45,10 +45,11 @@ def file_type(state: SourceState): file_path = state.get("file_path") if file_path is not None: return_dict["identified_type"] = magic.from_file(file_path, mime=True) + return_dict["title"] = os.path.basename(file_path) return return_dict -def file_type_edge(data: SourceState): +def file_type_edge(data: ContentState): assert data.get("identified_type"), "Type not identified" identified_type = data["identified_type"] @@ -68,7 +69,7 @@ def file_type_edge(data: SourceState): ) -def delete_file(data: SourceState): +def delete_file(data: ContentState): if data.get("delete_source"): logger.debug(f"Deleting file: {data.get('file_path')}") file_path = data.get("file_path") @@ -82,7 +83,7 @@ def delete_file(data: SourceState): logger.debug("Not deleting file") -workflow = StateGraph(SourceState) +workflow = StateGraph(ContentState) workflow.add_node("source", source_identification) workflow.add_node("url_provider", url_provider) workflow.add_node("file_type", file_type) diff --git a/open_notebook/graphs/content_processing/audio.py b/open_notebook/graphs/content_processing/audio.py index 3f99277..b3d7617 100644 --- a/open_notebook/graphs/content_processing/audio.py +++ b/open_notebook/graphs/content_processing/audio.py @@ -5,7 +5,7 @@ from loguru import logger from pydub import AudioSegment from open_notebook.domain.models import model_manager -from open_notebook.graphs.content_processing.state import SourceState +from open_notebook.graphs.content_processing.state import ContentState # todo: remove reference to model_manager # future: parallelize the transcription process @@ -72,7 +72,7 @@ def split_audio(input_file, segment_length_minutes=15, output_prefix=None): return output_files -def extract_audio(data: SourceState): +def extract_audio(data: ContentState): SPEECH_TO_TEXT_MODEL = model_manager.speech_to_text input_audio_path = data.get("file_path") diff --git a/open_notebook/graphs/content_processing/office.py b/open_notebook/graphs/content_processing/office.py index 4736d8d..f7403a0 100644 --- a/open_notebook/graphs/content_processing/office.py +++ b/open_notebook/graphs/content_processing/office.py @@ -3,7 +3,7 @@ from loguru import logger from openpyxl import load_workbook from pptx import Presentation -from open_notebook.graphs.content_processing.state import SourceState +from open_notebook.graphs.content_processing.state import ContentState SUPPORTED_OFFICE_TYPES = [ "application/vnd.openxmlformats-officedocument.wordprocessingml.document", @@ -251,7 +251,7 @@ def get_xlsx_info(file_path): return None -def extract_office_content(state: SourceState): +def extract_office_content(state: ContentState): """Universal function to extract content from Office files""" assert state.get("file_path"), "No file path provided" assert ( diff --git a/open_notebook/graphs/content_processing/pdf.py b/open_notebook/graphs/content_processing/pdf.py index e842a67..610ee58 100644 --- a/open_notebook/graphs/content_processing/pdf.py +++ b/open_notebook/graphs/content_processing/pdf.py @@ -4,7 +4,7 @@ import unicodedata import fitz # type: ignore from loguru import logger -from open_notebook.graphs.content_processing.state import SourceState +from open_notebook.graphs.content_processing.state import ContentState # todo: find tables - https://pymupdf.readthedocs.io/en/latest/the-basics.html#extracting-tables-from-a-page # todo: what else can we do to make the text more readable? @@ -127,7 +127,7 @@ def _extract_text_from_pdf(pdf_path): doc.close() -def extract_pdf(state: SourceState): +def extract_pdf(state: ContentState): """ Parse the text file and print its content. """ diff --git a/open_notebook/graphs/content_processing/state.py b/open_notebook/graphs/content_processing/state.py index 37bffbf..586ee45 100644 --- a/open_notebook/graphs/content_processing/state.py +++ b/open_notebook/graphs/content_processing/state.py @@ -1,7 +1,7 @@ from typing_extensions import TypedDict -class SourceState(TypedDict): +class ContentState(TypedDict): content: str file_path: str url: str diff --git a/open_notebook/graphs/content_processing/text.py b/open_notebook/graphs/content_processing/text.py index e286e0f..b81ca6c 100644 --- a/open_notebook/graphs/content_processing/text.py +++ b/open_notebook/graphs/content_processing/text.py @@ -1,9 +1,9 @@ from loguru import logger -from open_notebook.graphs.content_processing.state import SourceState +from open_notebook.graphs.content_processing.state import ContentState -def extract_txt(state: SourceState): +def extract_txt(state: ContentState): """ Parse the text file and print its content. """ diff --git a/open_notebook/graphs/content_processing/url.py b/open_notebook/graphs/content_processing/url.py index 05a00fd..c06efbc 100644 --- a/open_notebook/graphs/content_processing/url.py +++ b/open_notebook/graphs/content_processing/url.py @@ -5,14 +5,14 @@ import requests # type: ignore from bs4 import BeautifulSoup, Comment from loguru import logger -from open_notebook.graphs.content_processing.state import SourceState +from open_notebook.graphs.content_processing.state import ContentState # future: better extraction methods # https://github.com/buriy/python-readability # also try readability: from readability import Document -def url_provider(state: SourceState): +def url_provider(state: ContentState): """ Identify the provider """ @@ -173,7 +173,7 @@ def extract_url_jina(url: str): return {"content": text} -def extract_url(state: SourceState): +def extract_url(state: ContentState): assert state.get("url"), "No URL provided" url = state["url"] try: diff --git a/open_notebook/graphs/content_processing/video.py b/open_notebook/graphs/content_processing/video.py index acd23e4..c48e540 100644 --- a/open_notebook/graphs/content_processing/video.py +++ b/open_notebook/graphs/content_processing/video.py @@ -4,7 +4,7 @@ import subprocess from loguru import logger -from open_notebook.graphs.content_processing.state import SourceState +from open_notebook.graphs.content_processing.state import ContentState def extract_audio_from_video(input_file, output_file, stream_index): @@ -102,7 +102,7 @@ def select_best_audio_stream(streams): return max(scored_streams, key=lambda x: x[0])[1] -def extract_best_audio_from_video(data: SourceState): +def extract_best_audio_from_video(data: ContentState): """ Main function to extract the best audio stream from a video file """ diff --git a/open_notebook/graphs/content_processing/youtube.py b/open_notebook/graphs/content_processing/youtube.py index 8e73c51..1e85192 100644 --- a/open_notebook/graphs/content_processing/youtube.py +++ b/open_notebook/graphs/content_processing/youtube.py @@ -9,7 +9,7 @@ from youtube_transcript_api.formatters import TextFormatter # type: ignore from open_notebook.config import CONFIG from open_notebook.exceptions import NoTranscriptFound -from open_notebook.graphs.content_processing.state import SourceState +from open_notebook.graphs.content_processing.state import ContentState ssl._create_default_https_context = ssl._create_unverified_context @@ -129,7 +129,7 @@ def get_best_transcript(video_id, preferred_langs=["en", "es", "pt"]): return None -def extract_youtube_transcript(state: SourceState): +def extract_youtube_transcript(state: ContentState): """ Parse the text file and print its content. """ diff --git a/open_notebook/graphs/source.py b/open_notebook/graphs/source.py new file mode 100644 index 0000000..39813ab --- /dev/null +++ b/open_notebook/graphs/source.py @@ -0,0 +1,106 @@ +import operator +from typing import List + +from langchain_core.runnables import ( + RunnableConfig, +) +from langgraph.graph import END, START, StateGraph +from langgraph.types import Send +from loguru import logger +from typing_extensions import Annotated, TypedDict + +from open_notebook.domain.notebook import Asset, Source +from open_notebook.domain.transformation import Transformation +from open_notebook.graphs.content_processing import ContentState +from open_notebook.graphs.content_processing import graph as content_graph +from open_notebook.graphs.multipattern import graph as transform_graph +from open_notebook.utils import surreal_clean + +# todo: we can make this more efficient + + +class SourceState(TypedDict): + content_state: ContentState + transformations: List[str] + notebook_id: str + source: Source + transformations: Annotated[list, operator.add] + + +class TransformationState(TypedDict): + source: Source + transformation: dict + + +def content_process(state: SourceState): + content_state = state["content_state"] + logger.debug("Content processing started for new content") + return {"content_state": content_graph.invoke(content_state)} + + +def run_patterns(input_text, patterns): + output = transform_graph.invoke(dict(content_stack=[input_text], patterns=patterns)) + return output["output"] + + +def save_source(state: SourceState): + logger.debug("Saving source") + content_state = state["content_state"] + source = Source( + asset=Asset( + url=content_state.get("url"), file_path=content_state.get("file_path") + ), + full_text=surreal_clean(content_state["content"]), + title=content_state.get("title"), + ) + source.save() + + if state["notebook_id"]: + logger.debug(f"Adding source to notebook {state['notebook_id']}") + source.add_to_notebook(state["notebook_id"]) + return {"source": source} + + +def trigger_transformations(state: SourceState, config: RunnableConfig): + if len(state["transformations"]) == 0: + return [] + transformations = Transformation.get_all() + to_apply = [ + t + for t in transformations["source_insights"] + if t["name"] in state["transformations"] + ] + logger.debug(f"Applying transformations {to_apply}") + return [ + Send( + "transform_content", + { + "source": state["source"], + "transformation": t, + }, + ) + for t in to_apply + ] + + +def transform_content(state: TransformationState): + source = state["source"] + content = source.full_text + transformation = state["transformation"] + logger.debug(f"Applying transformation {transformation['name']}") + result = run_patterns(content, patterns=transformation["patterns"]) + source.add_insight(transformation["name"], surreal_clean(result)) + return {"transformations": [{"name": transformation["name"], "content": result}]} + + +workflow = StateGraph(SourceState) +workflow.add_node("content_process", content_process) +workflow.add_node("save_source", save_source) +workflow.add_node("transform_content", transform_content) +workflow.add_edge(START, "content_process") +workflow.add_edge("content_process", "save_source") +workflow.add_conditional_edges( + "save_source", trigger_transformations, ["transform_content"] +) +workflow.add_edge("transform_content", END) +source_graph = workflow.compile() diff --git a/pages/stream_app/source.py b/pages/stream_app/source.py index 629e2d8..6114bd7 100644 --- a/pages/stream_app/source.py +++ b/pages/stream_app/source.py @@ -6,36 +6,15 @@ from humanize import naturaltime from loguru import logger from open_notebook.config import UPLOADS_FOLDER -from open_notebook.domain.notebook import Asset, Source +from open_notebook.domain.notebook import Source +from open_notebook.domain.transformation import DefaultTransformations, Transformation from open_notebook.exceptions import UnsupportedTypeException -from open_notebook.graphs.content_processing import graph -from open_notebook.utils import surreal_clean +from open_notebook.graphs.source import source_graph from pages.components import source_panel -from pages.stream_app.utils import run_patterns from .consts import context_icons -# moved it here to replace it with the pipeline on 0.1.0 -def generate_toc_and_title(source) -> "Source": - try: - patterns = ["patterns/default/toc"] - result = run_patterns(source.full_text, patterns=patterns) - source.add_insight("Table of Contents", surreal_clean(result)) - if not source.title: - patterns = [ - "Based on the Table of Contents below, please provide a Title for this content, with max 15 words" - ] - output = run_patterns(result, patterns=patterns) - source.title = surreal_clean(output) - source.save() - return source - except Exception as e: - logger.error(f"Error summarizing source {source.id}: {str(e)}") - logger.exception(e) - raise - - @st.dialog("Source", width="large") def source_panel_dialog(source_id): source_panel(source_id, modal=True) @@ -48,6 +27,7 @@ def add_source(notebook_id): source_text = None source_type = st.radio("Type", ["Link", "Upload", "Text"]) req = {} + transformations = Transformation.get_all() if source_type == "Link": source_link = st.text_input("Link") req["url"] = source_link @@ -58,6 +38,14 @@ def add_source(notebook_id): else: source_text = st.text_area("Text") req["content"] = source_text + + default_transformations = [t for t in DefaultTransformations().source_insights] + available_transformations = [t["name"] for t in transformations["source_insights"]] + apply_transformations = st.multiselect( + "Apply transformations", + options=available_transformations, + default=default_transformations, + ) if st.button("Process", key="add_source"): logger.debug("Adding source") with st.status("Processing...", expanded=True): @@ -82,17 +70,14 @@ def add_source(notebook_id): with open(new_path, "wb") as f: f.write(source_file.getbuffer()) - result = graph.invoke(req) - st.write("Saving..") - source = Source( - asset=Asset(url=req.get("url"), file_path=req.get("file_path")), - full_text=surreal_clean(result["content"]), - title=result.get("title"), + st.write("Processing content..") + source_graph.invoke( + { + "content_state": req, + "notebook_id": notebook_id, + "transformations": apply_transformations, + } ) - source.save() - source.add_to_notebook(notebook_id) - st.write("Summarizing...") - generate_toc_and_title(source) except UnsupportedTypeException as e: st.warning( "This type of content is not supported yet. If you think it should be, let us know on the project Issues's page"