Merge pull request #11 from lfnovo/audiovideo
Audio, Video and many fixes
This commit is contained in:
commit
fa325ed594
25 changed files with 553 additions and 81 deletions
|
|
@ -9,4 +9,7 @@ sqlite-db/
|
|||
temp/
|
||||
google-credentials.json
|
||||
docker-compose*
|
||||
.docker_data
|
||||
.docker_data/
|
||||
docs/
|
||||
surreal-data/
|
||||
temp/
|
||||
|
|
|
|||
19
.env.example
19
.env.example
|
|
@ -2,7 +2,6 @@
|
|||
# DEFAULT MODEL_CONFIGURATIONS
|
||||
DEFAULT_MODEL="openai/gpt-4o-mini"
|
||||
SUMMARIZATION_MODEL="openai/gpt-4o-mini"
|
||||
RETRIEVAL_MODEL="openai/gpt-4o-mini"
|
||||
|
||||
# OPENAI
|
||||
# USE MODEL NAMES AS "openai/<modelname>"
|
||||
|
|
@ -12,30 +11,29 @@ OPENAI_API_KEY=
|
|||
# ANTHROPIC
|
||||
# USE MODEL NAMES AS "anthropic/<modelname>"
|
||||
# EXAMPLE - anthropic/claude-3-5-sonnet-20240620
|
||||
ANTHROPIC_API_KEY=
|
||||
|
||||
# ANTHROPIC_API_KEY=
|
||||
|
||||
# GEMINI
|
||||
# USE MODEL NAMES AS "gemini/<modelname>"
|
||||
# EXAMPLE - gemini/gemini-1.5-pro-002
|
||||
GEMINI_API_KEY=
|
||||
# GEMINI_API_KEY=
|
||||
|
||||
# VERTEXAI
|
||||
# USE MODEL NAMES AS "vertexai/<modelname>"
|
||||
# EXAMPLE - vertexai/gemini-1.5-pro-002
|
||||
VERTEX_PROJECT=my-google-cloud-project-name
|
||||
GOOGLE_APPLICATION_CREDENTIALS=./google-credentials.json
|
||||
# VERTEX_PROJECT=my-google-cloud-project-name
|
||||
# GOOGLE_APPLICATION_CREDENTIALS=./google-credentials.json
|
||||
|
||||
# OLLAMA
|
||||
# USE MODEL NAMES AS "ollama/<modelname>"
|
||||
# EXAMPLE - ollama/gemma2
|
||||
OLLAMA_API_BASE="http://10.20.30.20:11434"
|
||||
# OLLAMA_API_BASE="http://10.20.30.20:11434"
|
||||
|
||||
# OPEN ROUTER
|
||||
# USE MODEL NAMES AS "openrouter/<modelname>"
|
||||
# EXAMPLE - openrouter/nvidia/llama-3.1-nemotron-70b-instruct
|
||||
OPENROUTER_BASE_URL="https://openrouter.ai/api/v1"
|
||||
OPENROUTER_API_KEY=
|
||||
# OPENROUTER_BASE_URL="https://openrouter.ai/api/v1"
|
||||
# OPENROUTER_API_KEY=
|
||||
|
||||
# ELEVENLABS
|
||||
# Used only by the podcast feature
|
||||
|
|
@ -49,7 +47,8 @@ ELEVENLABS_API_KEY=
|
|||
# LANGCHAIN_PROJECT="Open Notebook"
|
||||
|
||||
# CONNECTION DETAILS FOR YOUR SURREAL DB
|
||||
SURREAL_ADDRESS="localhost"
|
||||
# Use "surrealdb" when running with docker-compose; otherwise set the IP address or hostname of your SurrealDB server
|
||||
SURREAL_ADDRESS="surrealdb"
|
||||
SURREAL_PORT=8000
|
||||
SURREAL_USER="root"
|
||||
SURREAL_PASS="root"
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ WORKDIR /app
|
|||
|
||||
EXPOSE 8502
|
||||
|
||||
RUN mkdir -p /app/sqlite-db
|
||||
RUN mkdir -p /app/data
|
||||
|
||||
CMD ["poetry", "run", "streamlit", "run", "app_home.py"]
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ services:
|
|||
user: root
|
||||
|
||||
open_notebook:
|
||||
image: lfnovo/open_notebook:latest
|
||||
image: lfnovo/open-notebook:latest
|
||||
ports:
|
||||
- "8502:8502"
|
||||
env_file:
|
||||
|
|
|
|||
|
|
@ -1,10 +1,12 @@
|
|||
import streamlit as st
|
||||
|
||||
from open_notebook.exceptions import InvalidDatabaseSchema, NoSchemaFound
|
||||
from open_notebook.repository import check_version, execute_migration
|
||||
from open_notebook.repository import check_database_version, execute_migration
|
||||
from stream_app.utils import version_sidebar
|
||||
|
||||
try:
|
||||
check_version()
|
||||
version_sidebar()
|
||||
check_database_version()
|
||||
st.switch_page("pages/2_📒_Notebooks.py")
|
||||
except NoSchemaFound as e:
|
||||
st.warning(e)
|
||||
|
|
|
|||
|
|
@ -12,7 +12,7 @@ services:
|
|||
user: root
|
||||
|
||||
open_notebook:
|
||||
image: lfnovo/open_notebook:latest
|
||||
image: lfnovo/open-notebook:latest
|
||||
ports:
|
||||
- "8080:8502"
|
||||
env_file:
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ services:
|
|||
user: root
|
||||
|
||||
open_notebook:
|
||||
image: lfnovo/open_notebook:latest
|
||||
image: lfnovo/open-notebook:latest
|
||||
ports:
|
||||
- "8080:8502"
|
||||
env_file:
|
||||
|
|
@ -58,7 +58,7 @@ services:
|
|||
user: root
|
||||
|
||||
open_notebook:
|
||||
image: lfnovo/open_notebook:latest
|
||||
image: lfnovo/open-notebook:latest
|
||||
ports:
|
||||
- "8080:8502"
|
||||
environment:
|
||||
|
|
@ -158,3 +158,13 @@ After the app is running, you can access it at http://localhost:8080.
|
|||
The first time you connect, it will check for the database and see if the schema is ready. If not, it will create the database for you.
|
||||
|
||||
Go to the [Usage](USAGE.md) page to learn how to use all features.
|
||||
|
||||
## Upgrading Open Notebook
|
||||
|
||||
### Running from source
|
||||
|
||||
Run `git pull` in the project's root folder, then run `poetry install` to update the dependencies.
|
||||
|
||||
### Running from docker
|
||||
|
||||
Pull the latest image with `docker pull lfnovo/open-notebook:latest`, then restart your containers with `docker-compose up -d`.
|
||||
|
|
@ -15,3 +15,20 @@ except Exception:
|
|||
logger.critical("Config file not found, using empty defaults")
|
||||
logger.debug(f"Looked in {config_path}")
|
||||
CONFIG = {}
|
||||
|
||||
# ROOT DATA FOLDER
|
||||
# todo: make this configurable once podcastfy supports it
|
||||
DATA_FOLDER = "./data"
|
||||
|
||||
# LANGGRAPH CHECKPOINT FILE
|
||||
sqlite_folder = f"{DATA_FOLDER}/sqlite-db"
|
||||
os.makedirs(sqlite_folder, exist_ok=True)
|
||||
LANGGRAPH_CHECKPOINT_FILE = f"{sqlite_folder}/checkpoints.sqlite"
|
||||
|
||||
# UPLOADS FOLDER
|
||||
UPLOADS_FOLDER = f"{DATA_FOLDER}/uploads"
|
||||
os.makedirs(UPLOADS_FOLDER, exist_ok=True)
|
||||
|
||||
# PODCASTS FOLDER
|
||||
PODCASTS_FOLDER = f"{DATA_FOLDER}/podcasts"
|
||||
os.makedirs(PODCASTS_FOLDER, exist_ok=True)
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from typing import Any, ClassVar, Dict, List, Literal, Optional, Type, TypeVar
|
|||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||
|
||||
from open_notebook.exceptions import (
|
||||
DatabaseOperationError,
|
||||
|
|
@ -35,7 +35,13 @@ class ObjectModel(BaseModel):
|
|||
def get_all(cls: Type[T]) -> List[T]:
|
||||
try:
|
||||
result = repo_query(f"SELECT * FROM {cls.table_name}")
|
||||
objects = [cls(**obj) for obj in result]
|
||||
objects = []
|
||||
for obj in result:
|
||||
try:
|
||||
objects.append(cls(**obj))
|
||||
except Exception as e:
|
||||
logger.critical(f"Error creating object: {str(e)}")
|
||||
|
||||
return objects
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching all {cls.table_name}: {str(e)}")
|
||||
|
|
@ -64,6 +70,8 @@ class ObjectModel(BaseModel):
|
|||
|
||||
def save(self) -> None:
|
||||
try:
|
||||
logger.debug(f"Validating {self.__class__.__name__}")
|
||||
self.model_validate(self.model_dump(), strict=True)
|
||||
data = self._prepare_save_data()
|
||||
|
||||
if self.needs_embedding():
|
||||
|
|
@ -86,6 +94,13 @@ class ObjectModel(BaseModel):
|
|||
else:
|
||||
setattr(self, key, value)
|
||||
|
||||
except ValidationError as e:
|
||||
logger.error(f"Validation failed: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving record: {e}")
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving {self.__class__.table_name}: {str(e)}")
|
||||
logger.exception(e)
|
||||
|
|
@ -121,6 +136,13 @@ class ObjectModel(BaseModel):
|
|||
logger.exception(e)
|
||||
raise DatabaseOperationError(e)
|
||||
|
||||
@field_validator("created", "updated", mode="before")
@classmethod
def parse_datetime(cls, value):
    """
    Convert ISO-8601 strings for ``created``/``updated`` into datetimes.

    Runs before field validation. A trailing "Z" designator is rewritten
    to "+00:00" before the string is handed to ``datetime.fromisoformat``.
    Values that are not strings are passed through untouched.
    """
    if not isinstance(value, str):
        return value
    iso_text = value.replace("Z", "+00:00")
    return datetime.fromisoformat(iso_text)
|
||||
|
||||
|
||||
class Notebook(ObjectModel):
|
||||
table_name: ClassVar[str] = "notebook"
|
||||
|
|
@ -139,7 +161,7 @@ class Notebook(ObjectModel):
|
|||
def sources(self) -> List["Source"]:
|
||||
try:
|
||||
srcs = repo_query(f"""
|
||||
select * from (
|
||||
select * OMIT full_text from (
|
||||
select
|
||||
<- source as source
|
||||
from reference
|
||||
|
|
@ -158,7 +180,7 @@ class Notebook(ObjectModel):
|
|||
def notes(self) -> List["Note"]:
|
||||
try:
|
||||
srcs = repo_query(f"""
|
||||
select * from (
|
||||
select * OMIT content from (
|
||||
select
|
||||
<- note as note
|
||||
from artifact
|
||||
|
|
@ -322,7 +344,6 @@ class Source(ObjectModel):
|
|||
output = pattern_graph.invoke(
|
||||
dict(content_stack=[result["toc"]], transformations=transformations)
|
||||
)
|
||||
logger.warning(output["output"])
|
||||
self.title = surreal_clean(output["output"])
|
||||
self.save()
|
||||
return self
|
||||
|
|
|
|||
|
|
@ -10,6 +10,12 @@ class DatabaseOperationError(OpenNotebookError):
|
|||
pass
|
||||
|
||||
|
||||
class UnsupportedTypeException(OpenNotebookError):
|
||||
"""Raised when an unsupported type is provided."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class NoSchemaFound(OpenNotebookError):
|
||||
"""Raised when a database schema is not found."""
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
import os
|
||||
import sqlite3
|
||||
from typing import Annotated, Optional
|
||||
|
||||
|
|
@ -10,6 +9,7 @@ from langgraph.graph import END, START, StateGraph
|
|||
from langgraph.graph.message import add_messages
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from open_notebook.config import LANGGRAPH_CHECKPOINT_FILE
|
||||
from open_notebook.domain import Notebook
|
||||
from open_notebook.graphs.utils import run_pattern
|
||||
|
||||
|
|
@ -33,7 +33,7 @@ def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict
|
|||
|
||||
|
||||
conn = sqlite3.connect(
|
||||
os.environ.get("CHECKPOINT_DATA_PATH", "sqlite-db/checkpoints.sqlite"),
|
||||
LANGGRAPH_CHECKPOINT_FILE,
|
||||
check_same_thread=False,
|
||||
)
|
||||
memory = SqliteSaver(conn)
|
||||
|
|
|
|||
|
|
@ -1,16 +1,22 @@
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import unicodedata
|
||||
from math import ceil
|
||||
|
||||
import fitz # type: ignore
|
||||
import magic
|
||||
import requests # type: ignore
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from loguru import logger
|
||||
from pydub import AudioSegment
|
||||
from typing_extensions import TypedDict
|
||||
from youtube_transcript_api import YouTubeTranscriptApi # type: ignore
|
||||
from youtube_transcript_api.formatters import TextFormatter # type: ignore
|
||||
|
||||
from open_notebook.config import CONFIG
|
||||
from open_notebook.exceptions import UnsupportedTypeException
|
||||
|
||||
|
||||
class SourceState(TypedDict):
|
||||
|
|
@ -291,6 +297,228 @@ def should_continue(data: SourceState):
|
|||
return "end"
|
||||
|
||||
|
||||
def split_audio(input_file, segment_length_minutes=15, output_prefix=None):
    """
    Split an audio file into consecutive MP3 segments of a fixed length.

    The segments are written to the same directory as the input file and
    are named ``<prefix>_001.mp3``, ``<prefix>_002.mp3``, and so on.

    Args:
        input_file (str): Path to the input audio file
        segment_length_minutes (int): Length of each segment in minutes,
            must be greater than zero
        output_prefix (str): Prefix for output files (defaults to the input
            file name without its extension)

    Returns:
        list: Absolute paths of the created segment files, in order

    Raises:
        ValueError: If segment_length_minutes is not greater than zero
    """
    # Validate before touching the file system or decoding the audio: a zero
    # length would otherwise fail with ZeroDivisionError further down, and a
    # negative one would silently produce no segments.
    if segment_length_minutes <= 0:
        raise ValueError("segment_length_minutes must be greater than zero")

    # Work with an absolute path so the segments land next to the input file
    input_file = os.path.abspath(input_file)
    output_dir = os.path.dirname(input_file)
    os.makedirs(output_dir, exist_ok=True)

    if output_prefix is None:
        output_prefix = os.path.splitext(os.path.basename(input_file))[0]

    audio = AudioSegment.from_file(input_file)

    # len(audio) is used as the total duration in the same unit as the
    # slice bounds below (milliseconds)
    segment_length_ms = segment_length_minutes * 60 * 1000
    total_segments = ceil(len(audio) / segment_length_ms)

    output_files = []
    for i in range(total_segments):
        start_time = i * segment_length_ms
        end_time = min((i + 1) * segment_length_ms, len(audio))
        segment = audio[start_time:end_time]

        # Zero-padded counter keeps the files in order when sorted by name
        output_filename = f"{output_prefix}_{str(i + 1).zfill(3)}.mp3"
        output_path = os.path.join(output_dir, output_filename)

        segment.export(output_path, format="mp3")
        output_files.append(output_path)

        logger.debug(f"Exported segment {i + 1}/{total_segments}: {output_filename}")

    return output_files
|
||||
|
||||
|
||||
# todo: add a speechtotext model to the config
|
||||
def extract_audio(data: SourceState):
    """
    Transcribe the audio file referenced by the state.

    The file at ``data["file_path"]`` is split with ``split_audio`` (default
    segment length), each segment is sent to the OpenAI "whisper-1"
    transcription model, and the segment texts are joined with single spaces.

    Args:
        data (SourceState): Graph state; only ``file_path`` is read

    Returns:
        dict: State update of the form ``{"content": <transcription text>}``
    """
    # Imported here, as in the original, so the OpenAI client is only loaded
    # when an audio source is actually processed
    from openai import OpenAI

    input_audio_path = data.get("file_path")
    client = OpenAI()

    transcriptions = []
    for segment_path in split_audio(input_audio_path):
        # The context manager closes each segment once it has been uploaded,
        # so long recordings do not accumulate open file handles
        with open(segment_path, "rb") as audio_file:
            transcription = client.audio.transcriptions.create(
                model="whisper-1", file=audio_file
            )
        transcriptions.append(transcription.text)

    return {"content": " ".join(transcriptions)}
|
||||
|
||||
|
||||
def get_audio_streams(input_file):
    """
    Probe a media file with ffprobe and list its audio streams.

    Returns the "streams" list from ffprobe's JSON output. Any failure
    (ffprobe missing, non-zero exit code, unparsable output) is printed and
    reported as an empty list.
    """
    # -select_streams a restricts the report to audio streams only
    probe_command = [
        "ffprobe",
        "-v",
        "quiet",
        "-print_format",
        "json",
        "-show_streams",
        "-select_streams",
        "a",
        input_file,
    ]

    try:
        completed = subprocess.run(probe_command, capture_output=True, text=True)
        if completed.returncode != 0:
            raise Exception(f"FFprobe failed: {completed.stderr}")
        probe_report = json.loads(completed.stdout)
        return probe_report.get("streams", [])
    except Exception as e:
        print(f"Error analyzing file: {str(e)}")
        return []
|
||||
|
||||
|
||||
def _numeric_stream_field(stream, key):
    """Return stream[key] as a float, or 0.0 when it is missing or not numeric."""
    try:
        return float(stream.get(key) or 0)
    except (TypeError, ValueError):
        return 0.0


def select_best_audio_stream(streams):
    """
    Select the best audio stream based on simple quality metrics.

    Each stream is scored from its "bit_rate", "channels" and "sample_rate"
    entries; entries that are missing, null or not numeric count as zero
    instead of raising. The channel count carries by far the largest weight.

    Args:
        streams (list): Stream description dicts, as returned by ffprobe

    Returns:
        dict: The highest scoring stream (the first one on a tie), or None
        when no streams are given
    """
    if not streams:
        return None

    def quality_score(stream):
        score = 0
        # Prefer higher bit rates (converted to Mbps)
        score += _numeric_stream_field(stream, "bit_rate") / 1000000
        # Prefer more channels (stereo over mono)
        score += _numeric_stream_field(stream, "channels") * 10
        # Prefer higher sample rates
        score += _numeric_stream_field(stream, "sample_rate") / 48000
        return score

    return max(streams, key=quality_score)
|
||||
|
||||
|
||||
def extract_audio_from_video(input_file, output_file, stream_index):
    """
    Run ffmpeg to write one audio stream of a video file as an MP3.

    ``stream_index`` is the position of the stream among the input's audio
    streams. Returns True when ffmpeg exits successfully; any failure is
    printed and reported as False.
    """
    try:
        ffmpeg_command = [
            "ffmpeg",
            "-i",
            input_file,
            "-map",
            f"0:a:{stream_index}",  # Select specific audio stream
            "-codec:a",
            "libmp3lame",  # Use MP3 codec
            "-q:a",
            "2",  # High quality setting
            "-y",  # Overwrite output file if exists
            output_file,
        ]
        completed = subprocess.run(ffmpeg_command, capture_output=True, text=True)
        if completed.returncode != 0:
            raise Exception(f"FFmpeg failed: {completed.stderr}")
    except Exception as e:
        print(f"Error extracting audio: {str(e)}")
        return False

    return True
|
||||
|
||||
|
||||
def extract_best_audio_from_video(data: SourceState):
    """
    Extract the best audio stream from a video file into an MP3 file.

    The video at ``data["file_path"]`` is probed for audio streams, the best
    one is selected, and it is written to ``<path without extension>_audio.mp3``.
    This function is registered as a graph node, so every failure is raised
    instead of being returned as ``False``, which is not a state update.

    Args:
        data (SourceState): Graph state; only ``file_path`` is read

    Returns:
        dict: State update pointing at the extracted audio:
        ``{"file_path": <mp3 path>, "identified_type": "audio/mp3"}``

    Raises:
        FileNotFoundError: If ``file_path`` is missing or does not exist
        ValueError: If no usable audio stream is found in the file
        RuntimeError: If the audio extraction itself fails
    """
    input_file = data.get("file_path")
    if not input_file or not os.path.exists(input_file):
        raise FileNotFoundError(f"Input file not found: {input_file}")

    base_name = os.path.splitext(input_file)[0]
    output_file = f"{base_name}_audio.mp3"

    streams = get_audio_streams(input_file)
    if not streams:
        raise ValueError("No audio streams found in the file")

    best_stream = select_best_audio_stream(streams)
    if not best_stream:
        raise ValueError("Could not determine best audio stream")

    # The probe lists audio streams only, so the position in that list is the
    # audio-relative index that the extraction step maps with "0:a:<index>"
    stream_index = streams.index(best_stream)
    if not extract_audio_from_video(input_file, output_file, stream_index):
        # Do not hand a path to a file that was never written to the next node
        raise RuntimeError(f"Could not extract audio from: {input_file}")

    logger.debug(
        f"Successfully extracted audio to: {output_file} "
        f"(channels: {best_stream.get('channels', 'unknown')}, "
        f"sample rate: {best_stream.get('sample_rate', 'unknown')} Hz, "
        f"bit rate: {best_stream.get('bit_rate', 'unknown')} bits/s)"
    )

    return {"file_path": output_file, "identified_type": "audio/mp3"}
|
||||
|
||||
|
||||
def file_type_edge(data: SourceState):
    """
    Choose the extraction node for a file from its identified MIME type.

    Args:
        data (SourceState): Graph state; only ``identified_type`` is read

    Returns:
        str: Name of the node that should process the file

    Raises:
        UnsupportedTypeException: If the type is missing or not one of
            text/plain, application/pdf, video/* or audio/*
    """
    # A missing type must end in UnsupportedTypeException like any other
    # unknown type, not in an AttributeError on None.startswith
    identified_type = data.get("identified_type") or ""

    if identified_type == "text/plain":
        return "extract_txt"
    if identified_type == "application/pdf":
        return "extract_pdf"
    if identified_type.startswith("video"):
        return "extract_best_audio_from_video"
    if identified_type.startswith("audio"):
        return "extract_audio"

    raise UnsupportedTypeException(
        f"Unsupported file type: {data.get('identified_type')}"
    )
|
||||
|
||||
|
||||
workflow = StateGraph(SourceState)
|
||||
workflow.add_node("source", source_identification)
|
||||
workflow.add_node("url_provider", url_provider)
|
||||
|
|
@ -298,6 +526,8 @@ workflow.add_node("file_type", file_type)
|
|||
workflow.add_node("extract_txt", extract_txt)
|
||||
workflow.add_node("extract_pdf", extract_pdf)
|
||||
workflow.add_node("extract_url", extract_url)
|
||||
workflow.add_node("extract_best_audio_from_video", extract_best_audio_from_video)
|
||||
workflow.add_node("extract_audio", extract_audio)
|
||||
workflow.add_node("extract_youtube_transcript", extract_youtube_transcript)
|
||||
|
||||
workflow.add_edge(START, "source")
|
||||
|
|
@ -312,11 +542,7 @@ workflow.add_conditional_edges(
|
|||
)
|
||||
workflow.add_conditional_edges(
|
||||
"file_type",
|
||||
lambda x: x.get("identified_type"),
|
||||
{
|
||||
"text/plain": "extract_txt",
|
||||
"application/pdf": "extract_pdf",
|
||||
},
|
||||
file_type_edge,
|
||||
)
|
||||
workflow.add_conditional_edges(
|
||||
"url_provider",
|
||||
|
|
@ -328,5 +554,7 @@ workflow.add_edge("file_type", END)
|
|||
workflow.add_edge("extract_txt", END)
|
||||
workflow.add_edge("extract_pdf", END)
|
||||
workflow.add_edge("extract_url", END)
|
||||
workflow.add_edge("extract_best_audio_from_video", "extract_audio")
|
||||
workflow.add_edge("extract_audio", END)
|
||||
workflow.add_edge("extract_youtube_transcript", END)
|
||||
graph = workflow.compile()
|
||||
|
|
|
|||
|
|
@ -33,7 +33,6 @@ def call_model(state: dict, config: RunnableConfig) -> dict:
|
|||
}
|
||||
current_transformation = "patterns/custom"
|
||||
|
||||
logger.warning(f"Processing transformation: {current_transformation}")
|
||||
logger.debug(f"Using input: {input_args}")
|
||||
transformation_result = run_pattern(
|
||||
pattern_name=current_transformation,
|
||||
|
|
|
|||
|
|
@ -32,10 +32,9 @@ def run_pattern(
|
|||
system_prompt = Prompter(prompt_template=pattern_name, parser=parser).render(
|
||||
data=state
|
||||
)
|
||||
# logger.debug(f"System prompt: {system_prompt}")
|
||||
logger.debug(f"System prompt: {system_prompt}")
|
||||
|
||||
if len(messages) > 0:
|
||||
logger.warning(messages)
|
||||
response = chain.invoke([system_prompt] + messages)
|
||||
else:
|
||||
response = chain.invoke(system_prompt)
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from typing import ClassVar, List, Literal
|
||||
from datetime import datetime
|
||||
from typing import ClassVar, List, Literal, Optional
|
||||
|
||||
from loguru import logger
|
||||
from podcastfy.client import generate_podcast
|
||||
|
|
@ -27,13 +28,15 @@ class PodcastConfig(ObjectModel):
|
|||
conversation_style: List[str]
|
||||
engagement_technique: List[str]
|
||||
dialogue_structure: List[str]
|
||||
user_instructions: str
|
||||
wordcount: int = Field(gt=500, lt=10000)
|
||||
user_instructions: Optional[str] = None
|
||||
ending_message: Optional[str] = None
|
||||
wordcount: int = Field(ge=400, le=10000)
|
||||
creativity: float = Field(ge=0, le=1)
|
||||
provider: Literal["openai", "elevenlabs", "edge"] = Field(default="openai")
|
||||
voice1: str
|
||||
voice2: str
|
||||
voice1: Optional[str] = None
|
||||
voice2: Optional[str] = None
|
||||
model: str
|
||||
created: Optional[datetime] = Field(default_factory=datetime.now)
|
||||
|
||||
def generate_episode(self, episode_name, text, instructions=None):
|
||||
self.user_instructions = (
|
||||
|
|
@ -52,7 +55,7 @@ class PodcastConfig(ObjectModel):
|
|||
"engagement_techniques": self.engagement_technique,
|
||||
"creativity": self.creativity,
|
||||
"text_to_speech": {
|
||||
# "temp_audio_dir": "./data/audio/tmp",
|
||||
# "temp_audio_dir": f"{PODCASTS_FOLDER}/tmp",
|
||||
"ending_message": "Thank you for listening to this episode. Don't forget to subscribe to our podcast for more interesting conversations.",
|
||||
"default_tts_model": self.provider,
|
||||
self.provider: {
|
||||
|
|
@ -66,8 +69,6 @@ class PodcastConfig(ObjectModel):
|
|||
},
|
||||
}
|
||||
|
||||
logger.error(conversation_config)
|
||||
# conversation_config = {}
|
||||
logger.debug(
|
||||
f"Generating episode {episode_name} with config {conversation_config}"
|
||||
)
|
||||
|
|
@ -75,7 +76,6 @@ class PodcastConfig(ObjectModel):
|
|||
audio_file = generate_podcast(
|
||||
conversation_config=conversation_config, text=text, tts_model=self.provider
|
||||
)
|
||||
logger.warning(audio_file)
|
||||
episode = PodcastEpisode(
|
||||
name=episode_name,
|
||||
template=self.name,
|
||||
|
|
@ -85,10 +85,19 @@ class PodcastConfig(ObjectModel):
|
|||
)
|
||||
episode.save()
|
||||
|
||||
@field_validator(
    "name", "podcast_name", "podcast_tagline", "output_language", "model"
)
@classmethod
def validate_required_strings(cls, value: str, field) -> str:
    """
    Reject missing or blank values for the required text fields.

    Returns the value with surrounding whitespace removed; raises
    ValueError naming the field when the value is None or only whitespace.
    """
    cleaned = None if value is None else value.strip()
    if not cleaned:
        raise ValueError(f"{field.field_name} cannot be None or empty string")
    return cleaned
|
||||
|
||||
@field_validator("wordcount")
@classmethod
def validate_wordcount(cls, value):
    """
    Ensure the requested word count is within the accepted range.

    Returns the value unchanged when it lies between 400 and 6000
    (inclusive); raises ValueError otherwise.
    """
    # NOTE(review): the field itself is declared with le=10000 while this
    # validator enforces 6000. The message now states the bound that is
    # actually enforced; confirm which limit is intended.
    if not 400 <= value <= 6000:
        raise ValueError("Wordcount must be between 400 and 6000")
    return value
|
||||
|
||||
@field_validator("creativity")
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ def db_connection():
|
|||
password=os.environ["SURREAL_PASS"],
|
||||
namespace=os.environ["SURREAL_NAMESPACE"],
|
||||
database=os.environ["SURREAL_DATABASE"],
|
||||
max_size=2.2**20,
|
||||
encrypted=False, # Set to True if using SSL
|
||||
)
|
||||
try:
|
||||
|
|
@ -38,7 +39,7 @@ def repo_query(query_str: str, vars: Optional[Dict[str, Any]] = None):
|
|||
raise
|
||||
|
||||
|
||||
def check_version():
|
||||
def check_database_version():
|
||||
try:
|
||||
result = repo_query("SELECT * FROM open_notebook:database_info;")
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,13 @@
|
|||
import re
|
||||
import unicodedata
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
import tomli
|
||||
from langchain_text_splitters import CharacterTextSplitter
|
||||
from openai import OpenAI
|
||||
from packaging.version import parse as parse_version
|
||||
|
||||
client = OpenAI()
|
||||
|
||||
|
|
@ -107,3 +112,100 @@ def surreal_clean(text):
|
|||
text = text.replace(":", "\:", 1)
|
||||
|
||||
return text
|
||||
|
||||
|
||||
def get_version_from_github(
    repo_url: str, branch: str = "main", timeout: float = 10.0
) -> str:
    """
    Fetch and parse the version from pyproject.toml in a public GitHub repository.

    Args:
        repo_url (str): URL of the GitHub repository
        branch (str): Branch name to fetch from (defaults to "main")
        timeout (float): Seconds to wait for GitHub before giving up

    Returns:
        str: Version string from pyproject.toml

    Raises:
        ValueError: If the URL is not a valid GitHub repository URL
        requests.RequestException: If there's an error fetching the file
        KeyError: If version information is not found in pyproject.toml
    """
    parsed_url = urlparse(repo_url)

    # Compare the exact host name: a substring test would also accept hosts
    # such as "notgithub.com.example.org"
    host = (parsed_url.hostname or "").lower()
    if host not in ("github.com", "www.github.com"):
        raise ValueError("Not a GitHub URL")

    # The first two path segments are the owner and the repository name
    path_parts = parsed_url.path.strip("/").split("/")
    if len(path_parts) < 2 or not all(path_parts[:2]):
        raise ValueError("Invalid GitHub repository URL")

    owner = path_parts[0]
    # Clone URLs end in ".git"; the raw-content URL uses the bare repo name
    repo = path_parts[1].removesuffix(".git")

    raw_url = (
        f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/pyproject.toml"
    )

    # Without a timeout an unresponsive server would block the caller forever
    response = requests.get(raw_url, timeout=timeout)
    response.raise_for_status()

    pyproject_data = tomli.loads(response.text)

    try:
        # Poetry style: [tool.poetry] version
        return pyproject_data["tool"]["poetry"]["version"]
    except KeyError:
        pass

    try:
        # Standard style: [project] version
        return pyproject_data["project"]["version"]
    except KeyError:
        raise KeyError("Version not found in pyproject.toml") from None
|
||||
|
||||
|
||||
def get_installed_version(package_name: str) -> str:
    """
    Get the version of an installed package.

    Args:
        package_name (str): Name of the installed package

    Returns:
        str: Version string of the installed package

    Raises:
        PackageNotFoundError: If the package is not installed
    """
    # The lookup already raises PackageNotFoundError carrying the package
    # name. Re-raising it with a sentence as the argument garbled the message,
    # because the exception formats its argument as the package name.
    return version(package_name)
|
||||
|
||||
|
||||
def compare_versions(version1: str, version2: str) -> int:
    """
    Order two version strings after parsing them with ``parse_version``.

    Args:
        version1 (str): First version string
        version2 (str): Second version string

    Returns:
        int: -1 if version1 < version2
              0 if version1 == version2
              1 if version1 > version2
    """
    left = parse_version(version1)
    right = parse_version(version2)

    # Boolean arithmetic: exactly one of the comparisons can be true, giving
    # 1 - 0, 0 - 1 or 0 - 0
    return int(left > right) - int(left < right)
|
||||
|
|
|
|||
|
|
@ -5,13 +5,16 @@ from open_notebook.domain import Notebook
|
|||
from stream_app.chat import chat_sidebar
|
||||
from stream_app.note import add_note, note_card
|
||||
from stream_app.source import add_source, source_card
|
||||
from stream_app.utils import setup_stream_state
|
||||
from stream_app.utils import setup_stream_state, version_sidebar
|
||||
|
||||
st.set_page_config(
|
||||
layout="wide", page_title="📒 Open Notebook", initial_sidebar_state="expanded"
|
||||
)
|
||||
|
||||
|
||||
version_sidebar()
|
||||
|
||||
|
||||
def notebook_header(current_notebook):
|
||||
c1, c2, c3 = st.columns([8, 2, 2])
|
||||
c1.header(current_notebook.name)
|
||||
|
|
|
|||
|
|
@ -4,10 +4,12 @@ from open_notebook.domain import text_search, vector_search
|
|||
from open_notebook.utils import get_embedding
|
||||
from stream_app.note import note_list_item
|
||||
from stream_app.source import source_list_item
|
||||
from stream_app.utils import version_sidebar
|
||||
|
||||
st.set_page_config(
|
||||
layout="wide", page_title="🔍 Open Notebook", initial_sidebar_state="expanded"
|
||||
layout="wide", page_title="🔍 Search", initial_sidebar_state="expanded"
|
||||
)
|
||||
version_sidebar()
|
||||
|
||||
# search_tab, ask_tab = st.tabs(["Search", "Ask"])
|
||||
# notebooks = Notebook.get_all()
|
||||
|
|
|
|||
|
|
@ -9,6 +9,13 @@ from open_notebook.plugins.podcasts import (
|
|||
engagement_techniques,
|
||||
participant_roles,
|
||||
)
|
||||
from stream_app.utils import version_sidebar
|
||||
|
||||
st.set_page_config(
|
||||
layout="wide", page_title="🎙️ Podcasts", initial_sidebar_state="expanded"
|
||||
)
|
||||
|
||||
version_sidebar()
|
||||
|
||||
episodes_tab, templates_tab = st.tabs(["Episodes", "Templates"])
|
||||
|
||||
|
|
@ -66,6 +73,9 @@ with templates_tab:
|
|||
pd_cfg["creativity"] = st.slider(
|
||||
"Creativity", min_value=0.0, max_value=1.0, step=0.05
|
||||
)
|
||||
pd_cfg["ending_message"] = st.text_input(
|
||||
"Ending Message", placeholder="Thank you for listening!"
|
||||
)
|
||||
pd_cfg["provider"] = st.selectbox("Provider", ["openai", "elevenlabs", "edge"])
|
||||
pd_cfg["voice1"] = st.text_input(
|
||||
"Voice 1", help="You can use Elevenlabs voice ID"
|
||||
|
|
@ -81,10 +91,13 @@ with templates_tab:
|
|||
"OpenAI: tts-1 or tts-1-hd, Elevenlabs: eleven_multilingual_v2, eleven_turbo_v2_5"
|
||||
)
|
||||
if st.button("Save"):
|
||||
pd = PodcastConfig(**pd_cfg)
|
||||
pd_cfg = {}
|
||||
pd.save()
|
||||
st.rerun()
|
||||
try:
|
||||
pd = PodcastConfig(**pd_cfg)
|
||||
pd_cfg = {}
|
||||
pd.save()
|
||||
st.rerun()
|
||||
except Exception as e:
|
||||
st.error(e)
|
||||
|
||||
for pd_config in PodcastConfig.get_all():
|
||||
with st.expander(pd_config.name):
|
||||
|
|
@ -161,6 +174,12 @@ with templates_tab:
|
|||
value=pd_config.creativity,
|
||||
key=f"creativity_{pd_config.id}",
|
||||
)
|
||||
pd_config.ending_message = st.text_input(
|
||||
"Ending Message",
|
||||
value=pd_config.ending_message,
|
||||
placeholder="Thank you for listening!",
|
||||
key=f"ending_message_{pd_config.id}",
|
||||
)
|
||||
pd_config.provider = st.selectbox(
|
||||
"Provider",
|
||||
["openai", "elevenlabs", "edge"],
|
||||
|
|
@ -190,8 +209,11 @@ with templates_tab:
|
|||
)
|
||||
|
||||
if st.button("Save Config", key=f"btn_save{pd_config.id}"):
|
||||
pd_config.save()
|
||||
st.rerun()
|
||||
try:
|
||||
pd_config.save()
|
||||
st.toast("Podcast template saved")
|
||||
except Exception as e:
|
||||
st.error(e)
|
||||
|
||||
if st.button("Duplicate Config", key=f"btn_duplicate{pd_config.id}"):
|
||||
pd_config.name = f"{pd_config.name} - Copy"
|
||||
|
|
|
|||
13
poetry.lock
generated
13
poetry.lock
generated
|
|
@ -5576,6 +5576,17 @@ files = [
|
|||
{file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tomli"
|
||||
version = "2.0.2"
|
||||
description = "A lil' TOML parser"
|
||||
optional = false
|
||||
python-versions = ">=3.8"
|
||||
files = [
|
||||
{file = "tomli-2.0.2-py3-none-any.whl", hash = "sha256:2ebe24485c53d303f690b0ec092806a085f07af5a5aa1464f3931eec36caaa38"},
|
||||
{file = "tomli-2.0.2.tar.gz", hash = "sha256:d46d457a85337051c36524bc5349dd91b1877838e2979ac5ced3e710ed8a60ed"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tornado"
|
||||
version = "6.4.1"
|
||||
|
|
@ -6063,4 +6074,4 @@ type = ["pytest-mypy"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "^3.11"
|
||||
content-hash = "5f7bdea405c6c6433fa805b3321ac1550e13deee0d3a3c04e38136cd6992f5b1"
|
||||
content-hash = "f6f2373a3c5f63afd6c2746ed87bb2e84dcea36fac6006b023cc7f6b2f7221c8"
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
[tool.poetry]
|
||||
name = "open-notebook"
|
||||
version = "0.0.4"
|
||||
version = "0.0.5"
|
||||
description = "An open source implementation of a research assistant, inspired by Google Notebook LM"
|
||||
authors = ["Luis Novo <lfnovo@gmail.com>"]
|
||||
license = "MIT"
|
||||
|
|
@ -41,6 +41,7 @@ langchain-google-vertexai = "^2.0.5"
|
|||
sdblpy = "^0.3.0"
|
||||
langchain-google-genai = "^2.0.1"
|
||||
podcastfy = "^0.2.8"
|
||||
tomli = "^2.0.2"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
ipykernel = "^6.29.5"
|
||||
|
|
|
|||
|
|
@ -77,14 +77,19 @@ def chat_sidebar(session_id):
|
|||
instructions = st.text_area(
|
||||
"Instructions", value=selected_template.user_instructions
|
||||
)
|
||||
if st.button("Generate"):
|
||||
with st.spinner("Go grab a coffee, almost there..."):
|
||||
selected_template.generate_episode(
|
||||
episode_name=episode_name,
|
||||
text=context,
|
||||
instructions=instructions,
|
||||
)
|
||||
st.success("Episode generated successfully")
|
||||
if len(context.get("note", [])) + len(context.get("source", [])) == 0:
|
||||
st.warning(
|
||||
"No notes or sources found in context. You don't want a boring podcast, right? So, add some context first."
|
||||
)
|
||||
else:
|
||||
if st.button("Generate"):
|
||||
with st.spinner("Go grab a coffee, almost there..."):
|
||||
selected_template.generate_episode(
|
||||
episode_name=episode_name,
|
||||
text=context,
|
||||
instructions=instructions,
|
||||
)
|
||||
st.success("Episode generated successfully")
|
||||
st.page_link("pages/5_🎙️_Podcasts.py", label="🎙️ Go to Podcasts")
|
||||
with chat_tab:
|
||||
with st.container(border=True):
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import streamlit as st
|
||||
|
|
@ -6,16 +7,15 @@ import yaml
|
|||
from humanize import naturaltime
|
||||
from loguru import logger
|
||||
|
||||
from open_notebook.config import UPLOADS_FOLDER
|
||||
from open_notebook.domain import Asset, Source
|
||||
from open_notebook.exceptions import UnsupportedTypeException
|
||||
from open_notebook.graphs.content_process import graph
|
||||
from open_notebook.graphs.multipattern import graph as transform_graph
|
||||
from open_notebook.utils import surreal_clean
|
||||
|
||||
from .consts import context_icons
|
||||
|
||||
uploads_dir = Path("./.uploads")
|
||||
uploads_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
|
||||
def run_transformations(input_text, transformations):
|
||||
output = transform_graph.invoke(
|
||||
|
|
@ -121,10 +121,10 @@ def add_source(session_id):
|
|||
# Generate a unique file name
|
||||
base_name = Path(file_name).stem
|
||||
counter = 1
|
||||
new_path = uploads_dir / file_name
|
||||
while new_path.exists():
|
||||
new_path = os.path.join(UPLOADS_FOLDER, file_name)
|
||||
while os.path.exists(new_path):
|
||||
new_file_name = f"{base_name}_{counter}{file_extension}"
|
||||
new_path = uploads_dir / new_file_name
|
||||
new_path = os.path.join(UPLOADS_FOLDER, new_file_name)
|
||||
counter += 1
|
||||
|
||||
req["file_path"] = str(new_path)
|
||||
|
|
@ -139,17 +139,31 @@ def add_source(session_id):
|
|||
logger.debug("Adding source")
|
||||
with st.status("Processing...", expanded=True):
|
||||
st.write("Processing document...")
|
||||
result = graph.invoke(req)
|
||||
st.write("Saving..")
|
||||
source = Source(
|
||||
asset=Asset(url=req.get("url"), file_path=req.get("file_path")),
|
||||
full_text=surreal_clean(result["content"]),
|
||||
title=result.get("title"),
|
||||
)
|
||||
source.save()
|
||||
source.add_to_notebook(st.session_state[session_id]["notebook"].id)
|
||||
st.write("Summarizing...")
|
||||
source.generate_toc_and_title()
|
||||
try:
|
||||
result = graph.invoke(req)
|
||||
st.write("Saving..")
|
||||
source = Source(
|
||||
asset=Asset(url=req.get("url"), file_path=req.get("file_path")),
|
||||
full_text=surreal_clean(result["content"]),
|
||||
title=result.get("title"),
|
||||
)
|
||||
source.save()
|
||||
source.add_to_notebook(st.session_state[session_id]["notebook"].id)
|
||||
st.write("Summarizing...")
|
||||
source.generate_toc_and_title()
|
||||
except UnsupportedTypeException:
|
||||
st.warning(
|
||||
"This type of content is not supported yet. If you think it should be, let us know on the project Issues's page"
|
||||
)
|
||||
st.link_button(
|
||||
"Go to Github Issues",
|
||||
url="https://www.github.com/lfnovo/open-notebook/issues",
|
||||
)
|
||||
st.stop()
|
||||
|
||||
except Exception as e:
|
||||
st.error(e)
|
||||
return
|
||||
|
||||
st.rerun()
|
||||
|
||||
|
|
@ -159,7 +173,8 @@ def source_card(session_id, source):
|
|||
icon = "🔗"
|
||||
|
||||
with st.container(border=True):
|
||||
st.markdown((f"{icon} **{source.title if source.title else 'No Title'}**"))
|
||||
title = (source.title if source.title else "No Title").strip()
|
||||
st.markdown((f"{icon}**{title}**"))
|
||||
context_state = st.selectbox(
|
||||
"Context",
|
||||
label_visibility="collapsed",
|
||||
|
|
|
|||
|
|
@ -1,6 +1,24 @@
|
|||
import streamlit as st
|
||||
|
||||
from open_notebook.graphs.chat import ThreadState, graph
|
||||
from open_notebook.utils import (
|
||||
compare_versions,
|
||||
get_installed_version,
|
||||
get_version_from_github,
|
||||
)
|
||||
|
||||
|
||||
def version_sidebar():
    """Show the installed Open Notebook version in the sidebar and flag upgrades.

    The installed version is written to the sidebar first so it stays visible
    even when the update check cannot be completed. The update check asks
    ``get_version_from_github`` for the version published on the project's
    ``main`` ref and, when ``compare_versions`` reports the installed version
    as older (negative result), renders a warning linking to the upgrade
    instructions.

    The update check is best-effort: a failure there must never prevent the
    page from rendering.
    """
    with st.sidebar:
        current_version = get_installed_version("open-notebook")
        st.write(f"Open Notebook: {current_version}")

        try:
            latest_version = get_version_from_github(
                "https://www.github.com/lfnovo/open-notebook", "main"
            )
            update_available = compare_versions(current_version, latest_version) < 0
        except Exception:
            # UI boundary: the lookup presumably goes over the network and
            # parses a remote file (helper internals are not visible here --
            # TODO confirm which exceptions it raises and narrow this clause).
            # Whatever goes wrong, degrade to "no update info" instead of
            # crashing every page that renders this sidebar.
            st.caption("Could not check for a newer version.")
            return

        if update_available:
            st.warning(
                f"New version {latest_version} available. [Click here for upgrade instructions](https://github.com/lfnovo/open-notebook/blob/main/docs/SETUP.md#upgrading-open-notebook)"
            )
|
||||
|
||||
|
||||
def setup_stream_state(session_id) -> None:
|
||||
|
|
|
|||
Loading…
Reference in a new issue