Merge pull request #11 from lfnovo/audiovideo

Audio, Video and many fixes
This commit is contained in:
Luis Novo 2024-10-28 10:46:18 -03:00 committed by GitHub
commit fa325ed594
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 553 additions and 81 deletions

View file

@ -9,4 +9,7 @@ sqlite-db/
temp/
google-credentials.json
docker-compose*
.docker_data
.docker_data/
docs/
surreal-data/
temp/

View file

@ -2,7 +2,6 @@
# DEFAULT MODEL_CONFIGURATIONS
DEFAULT_MODEL="openai/gpt-4o-mini"
SUMMARIZATION_MODEL="openai/gpt-4o-mini"
RETRIEVAL_MODEL="openai/gpt-4o-mini"
# OPENAI
# USE MODEL NAMES AS "openai/<modelname>"
@ -12,30 +11,29 @@ OPENAI_API_KEY=
# ANTHROPIC
# USE MODEL NAMES AS "anthropic/<modelname>"
# EXAMPLE - anthropic/claude-3-5-sonnet-20240620
ANTHROPIC_API_KEY=
# ANTHROPIC_API_KEY=
# GEMINI
# USE MODEL NAMES AS "gemini/<modelname>"
# EXAMPLE - gemini/gemini-1.5-pro-002
GEMINI_API_KEY=
# GEMINI_API_KEY=
# VERTEXAI
# USE MODEL NAMES AS "vertexai/<modelname>"
# EXAMPLE - vertexai/gemini-1.5-pro-002
VERTEX_PROJECT=my-google-cloud-project-name
GOOGLE_APPLICATION_CREDENTIALS=./google-credentials.json
# VERTEX_PROJECT=my-google-cloud-project-name
# GOOGLE_APPLICATION_CREDENTIALS=./google-credentials.json
# OLLAMA
# USE MODEL NAMES AS "ollama/<modelname>"
# EXAMPLE - ollama/gemma2
OLLAMA_API_BASE="http://10.20.30.20:11434"
# OLLAMA_API_BASE="http://10.20.30.20:11434"
# OPEN ROUTER
# USE MODEL NAMES AS "openrouter/<modelname>"
# EXAMPLE - openrouter/nvidia/llama-3.1-nemotron-70b-instruct
OPENROUTER_BASE_URL="https://openrouter.ai/api/v1"
OPENROUTER_API_KEY=
# OPENROUTER_BASE_URL="https://openrouter.ai/api/v1"
# OPENROUTER_API_KEY=
# ELEVENLABS
# Used only by the podcast feature
@ -49,7 +47,8 @@ ELEVENLABS_API_KEY=
# LANGCHAIN_PROJECT="Open Notebook"
# CONNECTION DETAILS FOR YOUR SURREAL DB
SURREAL_ADDRESS="localhost"
# Use "surrealdb" when running with docker-compose; otherwise set this to the IP address or hostname of your SurrealDB server
SURREAL_ADDRESS="surrealdb"
SURREAL_PORT=8000
SURREAL_USER="root"
SURREAL_PASS="root"

View file

@ -22,7 +22,6 @@ WORKDIR /app
EXPOSE 8502
RUN mkdir -p /app/sqlite-db
RUN mkdir -p /app/data
CMD ["poetry", "run", "streamlit", "run", "app_home.py"]

View file

@ -27,7 +27,7 @@ services:
user: root
open_notebook:
image: lfnovo/open_notebook:latest
image: lfnovo/open-notebook:latest
ports:
- "8502:8502"
env_file:

View file

@ -1,10 +1,12 @@
import streamlit as st
from open_notebook.exceptions import InvalidDatabaseSchema, NoSchemaFound
from open_notebook.repository import check_version, execute_migration
from open_notebook.repository import check_database_version, execute_migration
from stream_app.utils import version_sidebar
try:
check_version()
version_sidebar()
check_database_version()
st.switch_page("pages/2_📒_Notebooks.py")
except NoSchemaFound as e:
st.warning(e)

View file

@ -12,7 +12,7 @@ services:
user: root
open_notebook:
image: lfnovo/open_notebook:latest
image: lfnovo/open-notebook:latest
ports:
- "8080:8502"
env_file:

View file

@ -25,7 +25,7 @@ services:
user: root
open_notebook:
image: lfnovo/open_notebook:latest
image: lfnovo/open-notebook:latest
ports:
- "8080:8502"
env_file:
@ -58,7 +58,7 @@ services:
user: root
open_notebook:
image: lfnovo/open_notebook:latest
image: lfnovo/open-notebook:latest
ports:
- "8080:8502"
environment:
@ -158,3 +158,13 @@ After the app is running, you can access it at http://localhost:8080.
The first time you connect, it will check for the database and see if the schema is ready. If not, it will create the database for you.
Go to the [Usage](USAGE.md) page to learn how to use all features.
## Upgrading Open Notebook
### Running from source
Run `git pull` in the project root folder, then run `poetry install` to update the dependencies.
### Running from docker
Just pull the latest image with `docker pull lfnovo/open-notebook:latest` and restart your containers with `docker-compose up -d`

View file

@ -15,3 +15,20 @@ except Exception:
logger.critical("Config file not found, using empty defaults")
logger.debug(f"Looked in {config_path}")
CONFIG = {}
# ROOT DATA FOLDER
# todo: make this configurable once podcastfy supports it
DATA_FOLDER = "./data"
# LANGGRAPH CHECKPOINT FILE
sqlite_folder = f"{DATA_FOLDER}/sqlite-db"
os.makedirs(sqlite_folder, exist_ok=True)
LANGGRAPH_CHECKPOINT_FILE = f"{sqlite_folder}/checkpoints.sqlite"
# UPLOADS FOLDER
UPLOADS_FOLDER = f"{DATA_FOLDER}/uploads"
os.makedirs(UPLOADS_FOLDER, exist_ok=True)
# PODCASTS FOLDER
PODCASTS_FOLDER = f"{DATA_FOLDER}/podcasts"
os.makedirs(PODCASTS_FOLDER, exist_ok=True)

View file

@ -4,7 +4,7 @@ from typing import Any, ClassVar, Dict, List, Literal, Optional, Type, TypeVar
from langchain_core.runnables.config import RunnableConfig
from loguru import logger
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, ValidationError, field_validator
from open_notebook.exceptions import (
DatabaseOperationError,
@ -35,7 +35,13 @@ class ObjectModel(BaseModel):
def get_all(cls: Type[T]) -> List[T]:
try:
result = repo_query(f"SELECT * FROM {cls.table_name}")
objects = [cls(**obj) for obj in result]
objects = []
for obj in result:
try:
objects.append(cls(**obj))
except Exception as e:
logger.critical(f"Error creating object: {str(e)}")
return objects
except Exception as e:
logger.error(f"Error fetching all {cls.table_name}: {str(e)}")
@ -64,6 +70,8 @@ class ObjectModel(BaseModel):
def save(self) -> None:
try:
logger.debug(f"Validating {self.__class__.__name__}")
self.model_validate(self.model_dump(), strict=True)
data = self._prepare_save_data()
if self.needs_embedding():
@ -86,6 +94,13 @@ class ObjectModel(BaseModel):
else:
setattr(self, key, value)
except ValidationError as e:
logger.error(f"Validation failed: {e}")
raise
except Exception as e:
logger.error(f"Error saving record: {e}")
raise
except Exception as e:
logger.error(f"Error saving {self.__class__.table_name}: {str(e)}")
logger.exception(e)
@ -121,6 +136,13 @@ class ObjectModel(BaseModel):
logger.exception(e)
raise DatabaseOperationError(e)
@field_validator("created", "updated", mode="before")
@classmethod
def parse_datetime(cls, value):
    """
    Convert string values of ``created`` / ``updated`` into ``datetime`` objects.

    Runs before pydantic's own validation (``mode="before"``). Every "Z" in the
    string is replaced with "+00:00" before it is handed to
    ``datetime.fromisoformat``. Non-string values are returned untouched.
    """
    if not isinstance(value, str):
        return value
    iso_text = value.replace("Z", "+00:00")
    return datetime.fromisoformat(iso_text)
class Notebook(ObjectModel):
table_name: ClassVar[str] = "notebook"
@ -139,7 +161,7 @@ class Notebook(ObjectModel):
def sources(self) -> List["Source"]:
try:
srcs = repo_query(f"""
select * from (
select * OMIT full_text from (
select
<- source as source
from reference
@ -158,7 +180,7 @@ class Notebook(ObjectModel):
def notes(self) -> List["Note"]:
try:
srcs = repo_query(f"""
select * from (
select * OMIT content from (
select
<- note as note
from artifact
@ -322,7 +344,6 @@ class Source(ObjectModel):
output = pattern_graph.invoke(
dict(content_stack=[result["toc"]], transformations=transformations)
)
logger.warning(output["output"])
self.title = surreal_clean(output["output"])
self.save()
return self

View file

@ -10,6 +10,12 @@ class DatabaseOperationError(OpenNotebookError):
pass
class UnsupportedTypeException(OpenNotebookError):
"""Raised when an unsupported type is provided."""
pass
class NoSchemaFound(OpenNotebookError):
"""Raised when a database schema is not found."""

View file

@ -1,4 +1,3 @@
import os
import sqlite3
from typing import Annotated, Optional
@ -10,6 +9,7 @@ from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from typing_extensions import TypedDict
from open_notebook.config import LANGGRAPH_CHECKPOINT_FILE
from open_notebook.domain import Notebook
from open_notebook.graphs.utils import run_pattern
@ -33,7 +33,7 @@ def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict
conn = sqlite3.connect(
os.environ.get("CHECKPOINT_DATA_PATH", "sqlite-db/checkpoints.sqlite"),
LANGGRAPH_CHECKPOINT_FILE,
check_same_thread=False,
)
memory = SqliteSaver(conn)

View file

@ -1,16 +1,22 @@
import json
import os
import re
import subprocess
import unicodedata
from math import ceil
import fitz # type: ignore
import magic
import requests # type: ignore
from langgraph.graph import END, START, StateGraph
from loguru import logger
from pydub import AudioSegment
from typing_extensions import TypedDict
from youtube_transcript_api import YouTubeTranscriptApi # type: ignore
from youtube_transcript_api.formatters import TextFormatter # type: ignore
from open_notebook.config import CONFIG
from open_notebook.exceptions import UnsupportedTypeException
class SourceState(TypedDict):
@ -291,6 +297,228 @@ def should_continue(data: SourceState):
return "end"
def split_audio(input_file, segment_length_minutes=15, output_prefix=None):
    """
    Split an audio file into consecutive MP3 segments of a fixed maximum length.

    The segments are written to the same directory as the input file.

    Args:
        input_file (str): Path to the input audio file
        segment_length_minutes (int): Length of each segment in minutes (must be > 0)
        output_prefix (str): Prefix for output files (defaults to the input
            filename without its extension)

    Returns:
        list: Paths to the created segment files, in playback order

    Raises:
        ValueError: If segment_length_minutes is zero or negative
    """
    # Fail early with a clear message: a zero length would otherwise raise
    # ZeroDivisionError below and a negative one would silently produce nothing.
    if segment_length_minutes <= 0:
        raise ValueError("segment_length_minutes must be greater than zero")

    # Segments are stored next to the (absolute) input path
    input_file = os.path.abspath(input_file)
    output_dir = os.path.dirname(input_file)
    os.makedirs(output_dir, exist_ok=True)

    if output_prefix is None:
        output_prefix = os.path.splitext(os.path.basename(input_file))[0]

    audio = AudioSegment.from_file(input_file)

    # len(audio) and the slice bounds below are used in milliseconds
    audio_length_ms = len(audio)
    segment_length_ms = segment_length_minutes * 60 * 1000
    total_segments = ceil(audio_length_ms / segment_length_ms)

    output_files = []
    for i in range(total_segments):
        start_time = i * segment_length_ms
        end_time = min((i + 1) * segment_length_ms, audio_length_ms)
        segment = audio[start_time:end_time]

        # Format: prefix_001.mp3 (padding with zeros ensures correct ordering)
        output_filename = f"{output_prefix}_{str(i+1).zfill(3)}.mp3"
        output_path = os.path.join(output_dir, output_filename)

        segment.export(output_path, format="mp3")
        output_files.append(output_path)

        # Optional progress indication
        print(f"Exported segment {i+1}/{total_segments}: {output_filename}")

    return output_files
# todo: add a speechtotext model to the config
def extract_audio(data: SourceState):
    """
    Transcribe the audio file referenced by ``data["file_path"]``.

    The file is cut into segments with ``split_audio`` and every segment is sent
    to the OpenAI transcription endpoint using the ``whisper-1`` model. The
    per-segment texts are joined with single spaces.

    Args:
        data (SourceState): Graph state; only the "file_path" key is read

    Returns:
        dict: State update of the form ``{"content": <full transcript>}``
    """
    # Imported lazily so the OpenAI SDK is only loaded when audio is processed
    from openai import OpenAI

    input_audio_path = data.get("file_path")
    client = OpenAI()

    # NOTE(review): splitting is presumably done to stay under the API's upload
    # size limit - confirm before changing the segment length.
    segment_paths = split_audio(input_audio_path)

    transcriptions = []
    for segment_path in segment_paths:
        # The context manager closes the handle even if the API call raises;
        # the previous version left every segment file open.
        with open(segment_path, "rb") as audio_file:
            transcription = client.audio.transcriptions.create(
                model="whisper-1", file=audio_file
            )
        transcriptions.append(transcription.text)

    return {"content": " ".join(transcriptions)}
def get_audio_streams(input_file):
    """
    Probe a media file with ffprobe and return its audio stream descriptions.

    Returns the "streams" list from ffprobe's JSON output. Any failure
    (ffprobe cannot be started, exits with a non-zero status, or prints
    something that is not the expected JSON) is printed and results in an
    empty list.
    """
    # "-select_streams a" restricts the report to audio streams only
    ffprobe_cmd = [
        "ffprobe",
        "-v",
        "quiet",
        "-print_format",
        "json",
        "-show_streams",
        "-select_streams",
        "a",
        input_file,
    ]

    try:
        probe = subprocess.run(ffprobe_cmd, capture_output=True, text=True)
        if probe.returncode != 0:
            raise Exception(f"FFprobe failed: {probe.stderr}")
        return json.loads(probe.stdout).get("streams", [])
    except Exception as e:
        print(f"Error analyzing file: {str(e)}")
        return []
def select_best_audio_stream(streams):
    """
    Pick the highest-scoring audio stream from a list of stream dicts.

    The score of a stream is the sum of:
      - its "bit_rate" in Mbps (skipped when missing or falsy),
      - ten points per channel ("channels", default 0),
      - its "sample_rate" divided by 48000 (default "0").

    Returns None for an empty/None input. On a tie the earliest stream wins.
    """

    def quality_score(stream):
        bit_rate = stream.get("bit_rate")
        bit_rate_points = int(bit_rate) / 1000000 if bit_rate else 0
        channel_points = stream.get("channels", 0) * 10
        sample_rate_points = int(stream.get("sample_rate", "0")) / 48000
        return bit_rate_points + channel_points + sample_rate_points

    if not streams:
        return None

    # max() keeps the first of several equally scored streams
    return max(streams, key=quality_score)
def extract_audio_from_video(input_file, output_file, stream_index):
    """
    Run ffmpeg to write one audio stream of ``input_file`` to ``output_file`` as MP3.

    ``stream_index`` is used in the ``-map 0:a:<index>`` argument. Returns True
    when ffmpeg exits with status 0. Any failure is printed and reported as False.
    """
    try:
        ffmpeg_cmd = [
            "ffmpeg",
            "-i",
            input_file,
            # Pick exactly one audio stream of the first input
            "-map",
            f"0:a:{stream_index}",
            # Encode with the LAME MP3 encoder at quality level 2
            "-codec:a",
            "libmp3lame",
            "-q:a",
            "2",
            # Overwrite the output file if it already exists
            "-y",
            output_file,
        ]
        completed = subprocess.run(ffmpeg_cmd, capture_output=True, text=True)
        if completed.returncode != 0:
            raise Exception(f"FFmpeg failed: {completed.stderr}")
    except Exception as e:
        print(f"Error extracting audio: {str(e)}")
        return False

    return True
def extract_best_audio_from_video(data: SourceState):
    """
    Extract the best audio stream of the video at ``data["file_path"]`` to MP3.

    The MP3 is written next to the video as ``<video name without extension>_audio.mp3``.

    Args:
        data (SourceState): Graph state; only the "file_path" key is read

    Returns:
        dict | bool: On success, a state update with the new "file_path" and
        "identified_type" set to "audio/mp3". False when the input file is
        missing, has no audio streams, no stream could be selected, or the
        ffmpeg extraction failed.
    """
    input_file = data.get("file_path")

    # A missing "file_path" gives None, which os.path.exists() rejects with a
    # TypeError, so treat it the same as a file that is not there.
    if not input_file or not os.path.exists(input_file):
        print(f"Input file not found: {input_file}")
        return False

    base_name = os.path.splitext(input_file)[0]
    output_file = f"{base_name}_audio.mp3"

    # Get all audio streams
    streams = get_audio_streams(input_file)
    if not streams:
        print("No audio streams found in the file")
        return False

    # Select best stream
    best_stream = select_best_audio_stream(streams)
    if not best_stream:
        print("Could not determine best audio stream")
        return False

    # The list holds audio streams only, so the position in the list is the
    # index expected by ffmpeg's "0:a:<index>" mapping.
    stream_index = streams.index(best_stream)
    success = extract_audio_from_video(input_file, output_file, stream_index)

    if not success:
        # Previously this fell through and returned a path to an MP3 that was
        # never written; report the failure like the other error paths do.
        return False

    print(f"Successfully extracted audio to: {output_file}")
    print("Selected stream details:")
    print(f"- Channels: {best_stream.get('channels', 'unknown')}")
    print(f"- Sample rate: {best_stream.get('sample_rate', 'unknown')} Hz")
    print(f"- Bit rate: {best_stream.get('bit_rate', 'unknown')} bits/s")

    return {"file_path": output_file, "identified_type": "audio/mp3"}
def file_type_edge(data: SourceState):
    """
    Choose the extraction node for a file based on ``data["identified_type"]``.

    Args:
        data (SourceState): Graph state; only the "identified_type" key is read

    Returns:
        str: Name of the next graph node ("extract_txt", "extract_pdf",
        "extract_best_audio_from_video" or "extract_audio")

    Raises:
        UnsupportedTypeException: If the type is missing or has no extractor
    """
    identified_type = data.get("identified_type")

    if identified_type == "text/plain":
        return "extract_txt"
    if identified_type == "application/pdf":
        return "extract_pdf"

    # Prefix checks only make sense for strings; a missing type (None) used to
    # crash here with AttributeError instead of the documented exception.
    if isinstance(identified_type, str):
        if identified_type.startswith("video"):
            return "extract_best_audio_from_video"
        if identified_type.startswith("audio"):
            return "extract_audio"

    raise UnsupportedTypeException(f"Unsupported file type: {identified_type}")
workflow = StateGraph(SourceState)
workflow.add_node("source", source_identification)
workflow.add_node("url_provider", url_provider)
@ -298,6 +526,8 @@ workflow.add_node("file_type", file_type)
workflow.add_node("extract_txt", extract_txt)
workflow.add_node("extract_pdf", extract_pdf)
workflow.add_node("extract_url", extract_url)
workflow.add_node("extract_best_audio_from_video", extract_best_audio_from_video)
workflow.add_node("extract_audio", extract_audio)
workflow.add_node("extract_youtube_transcript", extract_youtube_transcript)
workflow.add_edge(START, "source")
@ -312,11 +542,7 @@ workflow.add_conditional_edges(
)
workflow.add_conditional_edges(
"file_type",
lambda x: x.get("identified_type"),
{
"text/plain": "extract_txt",
"application/pdf": "extract_pdf",
},
file_type_edge,
)
workflow.add_conditional_edges(
"url_provider",
@ -328,5 +554,7 @@ workflow.add_edge("file_type", END)
workflow.add_edge("extract_txt", END)
workflow.add_edge("extract_pdf", END)
workflow.add_edge("extract_url", END)
workflow.add_edge("extract_best_audio_from_video", "extract_audio")
workflow.add_edge("extract_audio", END)
workflow.add_edge("extract_youtube_transcript", END)
graph = workflow.compile()

View file

@ -33,7 +33,6 @@ def call_model(state: dict, config: RunnableConfig) -> dict:
}
current_transformation = "patterns/custom"
logger.warning(f"Processing transformation: {current_transformation}")
logger.debug(f"Using input: {input_args}")
transformation_result = run_pattern(
pattern_name=current_transformation,

View file

@ -32,10 +32,9 @@ def run_pattern(
system_prompt = Prompter(prompt_template=pattern_name, parser=parser).render(
data=state
)
# logger.debug(f"System prompt: {system_prompt}")
logger.debug(f"System prompt: {system_prompt}")
if len(messages) > 0:
logger.warning(messages)
response = chain.invoke([system_prompt] + messages)
else:
response = chain.invoke(system_prompt)

View file

@ -1,4 +1,5 @@
from typing import ClassVar, List, Literal
from datetime import datetime
from typing import ClassVar, List, Literal, Optional
from loguru import logger
from podcastfy.client import generate_podcast
@ -27,13 +28,15 @@ class PodcastConfig(ObjectModel):
conversation_style: List[str]
engagement_technique: List[str]
dialogue_structure: List[str]
user_instructions: str
wordcount: int = Field(gt=500, lt=10000)
user_instructions: Optional[str] = None
ending_message: Optional[str] = None
wordcount: int = Field(ge=400, le=10000)
creativity: float = Field(ge=0, le=1)
provider: Literal["openai", "elevenlabs", "edge"] = Field(default="openai")
voice1: str
voice2: str
voice1: Optional[str] = None
voice2: Optional[str] = None
model: str
created: Optional[datetime] = Field(default_factory=datetime.now)
def generate_episode(self, episode_name, text, instructions=None):
self.user_instructions = (
@ -52,7 +55,7 @@ class PodcastConfig(ObjectModel):
"engagement_techniques": self.engagement_technique,
"creativity": self.creativity,
"text_to_speech": {
# "temp_audio_dir": "./data/audio/tmp",
# "temp_audio_dir": f"{PODCASTS_FOLDER}/tmp",
"ending_message": "Thank you for listening to this episode. Don't forget to subscribe to our podcast for more interesting conversations.",
"default_tts_model": self.provider,
self.provider: {
@ -66,8 +69,6 @@ class PodcastConfig(ObjectModel):
},
}
logger.error(conversation_config)
# conversation_config = {}
logger.debug(
f"Generating episode {episode_name} with config {conversation_config}"
)
@ -75,7 +76,6 @@ class PodcastConfig(ObjectModel):
audio_file = generate_podcast(
conversation_config=conversation_config, text=text, tts_model=self.provider
)
logger.warning(audio_file)
episode = PodcastEpisode(
name=episode_name,
template=self.name,
@ -85,10 +85,19 @@ class PodcastConfig(ObjectModel):
)
episode.save()
@field_validator(
    "name", "podcast_name", "podcast_tagline", "output_language", "model"
)
@classmethod
def validate_required_strings(cls, value: str, field) -> str:
    """
    Reject None / blank values for the required text fields and strip the rest.

    Returns the value with surrounding whitespace removed. Raises ValueError
    naming the field (``field.field_name``) when the value is None or contains
    only whitespace.
    """
    cleaned = value.strip() if value is not None else ""
    if cleaned == "":
        raise ValueError(f"{field.field_name} cannot be None or empty string")
    return cleaned
@field_validator("wordcount")
def validate_wordcount(cls, value):
if not 500 <= value <= 6000:
raise ValueError("Wordcount must be between 500 and 10000")
if not 400 <= value <= 6000:
raise ValueError("Wordcount must be between 400 and 10000")
return value
@field_validator("creativity")

View file

@ -19,6 +19,7 @@ def db_connection():
password=os.environ["SURREAL_PASS"],
namespace=os.environ["SURREAL_NAMESPACE"],
database=os.environ["SURREAL_DATABASE"],
max_size=2.2**20,
encrypted=False, # Set to True if using SSL
)
try:
@ -38,7 +39,7 @@ def repo_query(query_str: str, vars: Optional[Dict[str, Any]] = None):
raise
def check_version():
def check_database_version():
try:
result = repo_query("SELECT * FROM open_notebook:database_info;")

View file

@ -1,8 +1,13 @@
import re
import unicodedata
from importlib.metadata import PackageNotFoundError, version
from urllib.parse import urlparse
import requests
import tomli
from langchain_text_splitters import CharacterTextSplitter
from openai import OpenAI
from packaging.version import parse as parse_version
client = OpenAI()
@ -107,3 +112,100 @@ def surreal_clean(text):
text = text.replace(":", "\:", 1)
return text
def get_version_from_github(
    repo_url: str, branch: str = "main", timeout: float = 10.0
) -> str:
    """
    Fetch and parse the version from pyproject.toml in a public GitHub repository.

    Args:
        repo_url (str): URL of the GitHub repository
        branch (str): Branch name to fetch from (defaults to "main")
        timeout (float): Seconds to wait for GitHub before giving up, so an
            unreachable network cannot block the caller indefinitely

    Returns:
        str: Version string from pyproject.toml

    Raises:
        ValueError: If the URL is not a valid GitHub repository URL
        requests.RequestException: If there's an error fetching the file
        KeyError: If version information is not found in pyproject.toml
    """
    parsed_url = urlparse(repo_url)

    # Compare the host itself (not a substring of netloc) so that hosts such as
    # "notgithub.com" or "github.com.example.org" are rejected.
    host = (parsed_url.hostname or "").lower()
    if host != "github.com" and not host.endswith(".github.com"):
        raise ValueError("Not a GitHub URL")

    # Extract owner and repo name from the path, ignoring empty segments
    path_parts = [part for part in parsed_url.path.split("/") if part]
    if len(path_parts) < 2:
        raise ValueError("Invalid GitHub repository URL")
    owner, repo = path_parts[0], path_parts[1]

    # Construct raw content URL for pyproject.toml
    raw_url = (
        f"https://raw.githubusercontent.com/{owner}/{repo}/{branch}/pyproject.toml"
    )

    response = requests.get(raw_url, timeout=timeout)
    response.raise_for_status()

    pyproject_data = tomli.loads(response.text)

    # Poetry keeps the version under [tool.poetry]
    try:
        return pyproject_data["tool"]["poetry"]["version"]
    except KeyError:
        pass

    # Standard location: the [project] table
    try:
        return pyproject_data["project"]["version"]
    except KeyError:
        raise KeyError("Version not found in pyproject.toml") from None
def get_installed_version(package_name: str) -> str:
    """
    Get the version of an installed package.

    Args:
        package_name (str): Name of the installed package

    Returns:
        str: Version string of the installed package

    Raises:
        PackageNotFoundError: If the package is not installed
    """
    # version() already raises PackageNotFoundError(package_name). Re-raising it
    # with a full sentence as argument produced a garbled message ("No package
    # metadata was found for Package 'x' not found") and a wrong `.name`.
    return version(package_name)
def compare_versions(version1: str, version2: str) -> int:
    """
    Three-way comparison of two version strings.

    Both strings are parsed with ``parse_version`` before they are compared.

    Args:
        version1 (str): First version string
        version2 (str): Second version string

    Returns:
        int: -1 if version1 < version2
        0 if version1 == version2
        1 if version1 > version2
    """
    left = parse_version(version1)
    right = parse_version(version2)
    # bool arithmetic: (True - False) == 1, (False - True) == -1, otherwise 0
    return (left > right) - (left < right)

View file

@ -5,13 +5,16 @@ from open_notebook.domain import Notebook
from stream_app.chat import chat_sidebar
from stream_app.note import add_note, note_card
from stream_app.source import add_source, source_card
from stream_app.utils import setup_stream_state
from stream_app.utils import setup_stream_state, version_sidebar
st.set_page_config(
layout="wide", page_title="📒 Open Notebook", initial_sidebar_state="expanded"
)
version_sidebar()
def notebook_header(current_notebook):
c1, c2, c3 = st.columns([8, 2, 2])
c1.header(current_notebook.name)

View file

@ -4,10 +4,12 @@ from open_notebook.domain import text_search, vector_search
from open_notebook.utils import get_embedding
from stream_app.note import note_list_item
from stream_app.source import source_list_item
from stream_app.utils import version_sidebar
st.set_page_config(
layout="wide", page_title="🔍 Open Notebook", initial_sidebar_state="expanded"
layout="wide", page_title="🔍 Search", initial_sidebar_state="expanded"
)
version_sidebar()
# search_tab, ask_tab = st.tabs(["Search", "Ask"])
# notebooks = Notebook.get_all()

View file

@ -9,6 +9,13 @@ from open_notebook.plugins.podcasts import (
engagement_techniques,
participant_roles,
)
from stream_app.utils import version_sidebar
st.set_page_config(
layout="wide", page_title="🎙️ Podcasts", initial_sidebar_state="expanded"
)
version_sidebar()
episodes_tab, templates_tab = st.tabs(["Episodes", "Templates"])
@ -66,6 +73,9 @@ with templates_tab:
pd_cfg["creativity"] = st.slider(
"Creativity", min_value=0.0, max_value=1.0, step=0.05
)
pd_cfg["ending_message"] = st.text_input(
"Ending Message", placeholder="Thank you for listening!"
)
pd_cfg["provider"] = st.selectbox("Provider", ["openai", "elevenlabs", "edge"])
pd_cfg["voice1"] = st.text_input(
"Voice 1", help="You can use Elevenlabs voice ID"
@ -81,10 +91,13 @@ with templates_tab:
"OpenAI: tts-1 or tts-1-hd, Elevenlabs: eleven_multilingual_v2, eleven_turbo_v2_5"
)
if st.button("Save"):
pd = PodcastConfig(**pd_cfg)
pd_cfg = {}
pd.save()
st.rerun()
try:
pd = PodcastConfig(**pd_cfg)
pd_cfg = {}
pd.save()
st.rerun()
except Exception as e:
st.error(e)
for pd_config in PodcastConfig.get_all():
with st.expander(pd_config.name):
@ -161,6 +174,12 @@ with templates_tab:
value=pd_config.creativity,
key=f"creativity_{pd_config.id}",
)
pd_config.ending_message = st.text_input(
"Ending Message",
value=pd_config.ending_message,
placeholder="Thank you for listening!",
key=f"ending_message_{pd_config.id}",
)
pd_config.provider = st.selectbox(
"Provider",
["openai", "elevenlabs", "edge"],
@ -190,8 +209,11 @@ with templates_tab:
)
if st.button("Save Config", key=f"btn_save{pd_config.id}"):
pd_config.save()
st.rerun()
try:
pd_config.save()
st.toast("Podcast template saved")
except Exception as e:
st.error(e)
if st.button("Duplicate Config", key=f"btn_duplicate{pd_config.id}"):
pd_config.name = f"{pd_config.name} - Copy"

13
poetry.lock generated
View file

@ -5576,6 +5576,17 @@ files = [
{file = "toml-0.10.2.tar.gz", hash = "sha256:b3bda1d108d5dd99f4a20d24d9c348e91c4db7ab1b749200bded2f839ccbe68f"},
]
[[package]]
name = "tomli"
version = "2.0.2"
description = "A lil' TOML parser"
optional = false
python-versions = ">=3.8"
files = [
{file = "tomli-2.0.2-py3-none-any.whl", hash = "sha256:2ebe24485c53d303f690b0ec092806a085f07af5a5aa1464f3931eec36caaa38"},
{file = "tomli-2.0.2.tar.gz", hash = "sha256:d46d457a85337051c36524bc5349dd91b1877838e2979ac5ced3e710ed8a60ed"},
]
[[package]]
name = "tornado"
version = "6.4.1"
@ -6063,4 +6074,4 @@ type = ["pytest-mypy"]
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
content-hash = "5f7bdea405c6c6433fa805b3321ac1550e13deee0d3a3c04e38136cd6992f5b1"
content-hash = "f6f2373a3c5f63afd6c2746ed87bb2e84dcea36fac6006b023cc7f6b2f7221c8"

View file

@ -1,6 +1,6 @@
[tool.poetry]
name = "open-notebook"
version = "0.0.4"
version = "0.0.5"
description = "An open source implementation of a research assistant, inspired by Google Notebook LM"
authors = ["Luis Novo <lfnovo@gmail.com>"]
license = "MIT"
@ -41,6 +41,7 @@ langchain-google-vertexai = "^2.0.5"
sdblpy = "^0.3.0"
langchain-google-genai = "^2.0.1"
podcastfy = "^0.2.8"
tomli = "^2.0.2"
[tool.poetry.group.dev.dependencies]
ipykernel = "^6.29.5"

View file

@ -77,14 +77,19 @@ def chat_sidebar(session_id):
instructions = st.text_area(
"Instructions", value=selected_template.user_instructions
)
if st.button("Generate"):
with st.spinner("Go grab a coffee, almost there..."):
selected_template.generate_episode(
episode_name=episode_name,
text=context,
instructions=instructions,
)
st.success("Episode generated successfully")
if len(context.get("note", [])) + len(context.get("source", [])) == 0:
st.warning(
"No notes or sources found in context. You don't want a boring podcast, right? So, add some context first."
)
else:
if st.button("Generate"):
with st.spinner("Go grab a coffee, almost there..."):
selected_template.generate_episode(
episode_name=episode_name,
text=context,
instructions=instructions,
)
st.success("Episode generated successfully")
st.page_link("pages/5_🎙_Podcasts.py", label="🎙️ Go to Podcasts")
with chat_tab:
with st.container(border=True):

View file

@ -1,3 +1,4 @@
import os
from pathlib import Path
import streamlit as st
@ -6,16 +7,15 @@ import yaml
from humanize import naturaltime
from loguru import logger
from open_notebook.config import UPLOADS_FOLDER
from open_notebook.domain import Asset, Source
from open_notebook.exceptions import UnsupportedTypeException
from open_notebook.graphs.content_process import graph
from open_notebook.graphs.multipattern import graph as transform_graph
from open_notebook.utils import surreal_clean
from .consts import context_icons
uploads_dir = Path("./.uploads")
uploads_dir.mkdir(parents=True, exist_ok=True)
def run_transformations(input_text, transformations):
output = transform_graph.invoke(
@ -121,10 +121,10 @@ def add_source(session_id):
# Generate a unique file name
base_name = Path(file_name).stem
counter = 1
new_path = uploads_dir / file_name
while new_path.exists():
new_path = os.path.join(UPLOADS_FOLDER, file_name)
while os.path.exists(new_path):
new_file_name = f"{base_name}_{counter}{file_extension}"
new_path = uploads_dir / new_file_name
new_path = os.path.join(UPLOADS_FOLDER, new_file_name)
counter += 1
req["file_path"] = str(new_path)
@ -139,17 +139,31 @@ def add_source(session_id):
logger.debug("Adding source")
with st.status("Processing...", expanded=True):
st.write("Processing document...")
result = graph.invoke(req)
st.write("Saving..")
source = Source(
asset=Asset(url=req.get("url"), file_path=req.get("file_path")),
full_text=surreal_clean(result["content"]),
title=result.get("title"),
)
source.save()
source.add_to_notebook(st.session_state[session_id]["notebook"].id)
st.write("Summarizing...")
source.generate_toc_and_title()
try:
result = graph.invoke(req)
st.write("Saving..")
source = Source(
asset=Asset(url=req.get("url"), file_path=req.get("file_path")),
full_text=surreal_clean(result["content"]),
title=result.get("title"),
)
source.save()
source.add_to_notebook(st.session_state[session_id]["notebook"].id)
st.write("Summarizing...")
source.generate_toc_and_title()
except UnsupportedTypeException:
st.warning(
"This type of content is not supported yet. If you think it should be, let us know on the project Issues's page"
)
st.link_button(
"Go to Github Issues",
url="https://www.github.com/lfnovo/open-notebook/issues",
)
st.stop()
except Exception as e:
st.error(e)
return
st.rerun()
@ -159,7 +173,8 @@ def source_card(session_id, source):
icon = "🔗"
with st.container(border=True):
st.markdown((f"{icon} **{source.title if source.title else 'No Title'}**"))
title = (source.title if source.title else "No Title").strip()
st.markdown((f"{icon}**{title}**"))
context_state = st.selectbox(
"Context",
label_visibility="collapsed",

View file

@ -1,6 +1,24 @@
import streamlit as st
from open_notebook.graphs.chat import ThreadState, graph
from open_notebook.utils import (
compare_versions,
get_installed_version,
get_version_from_github,
)
def version_sidebar():
    """
    Show the installed Open Notebook version in the sidebar, plus an upgrade
    notice when the version on the GitHub ``main`` branch is newer.

    The check is best-effort: if the installed version cannot be determined or
    GitHub cannot be queried, a short note is shown instead of raising, so a
    failed version lookup does not stop the page from rendering.
    """
    # NOTE(review): this runs on every rerun of each page that calls it, so the
    # GitHub request is repeated each time - consider caching the result.
    with st.sidebar:
        try:
            current_version = get_installed_version("open-notebook")
        except Exception:
            # e.g. the package metadata is not available in this environment
            st.write("Open Notebook: unknown version")
            return

        st.write(f"Open Notebook: {current_version}")

        try:
            latest_version = get_version_from_github(
                "https://www.github.com/lfnovo/open-notebook", "main"
            )
            update_available = compare_versions(current_version, latest_version) < 0
        except Exception:
            # Network problems or an unparsable version must not break the UI
            st.caption("Could not check for updates")
            return

        if update_available:
            st.warning(
                f"New version {latest_version} available. [Click here for upgrade instructions](https://github.com/lfnovo/open-notebook/blob/main/docs/SETUP.md#upgrading-open-notebook)"
            )
def setup_stream_state(session_id) -> None: