From 8dac6dd2ace451e6964d782fb56b80cd57851250 Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Sat, 26 Oct 2024 18:55:56 -0300 Subject: [PATCH] improve object validation --- open_notebook/domain.py | 24 ++++++++++++++++++----- open_notebook/plugins/podcasts.py | 31 +++++++++++++++++++----------- pages/3_🔍_Search.py | 2 +- pages/5_🎙️_Podcasts.py | 32 +++++++++++++++++++++++++------ 4 files changed, 66 insertions(+), 23 deletions(-) diff --git a/open_notebook/domain.py b/open_notebook/domain.py index 1e800eb..f232e4c 100644 --- a/open_notebook/domain.py +++ b/open_notebook/domain.py @@ -4,7 +4,7 @@ from typing import Any, ClassVar, Dict, List, Literal, Optional, Type, TypeVar from langchain_core.runnables.config import RunnableConfig from loguru import logger -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, ValidationError, field_validator from open_notebook.exceptions import ( DatabaseOperationError, @@ -35,7 +35,13 @@ class ObjectModel(BaseModel): def get_all(cls: Type[T]) -> List[T]: try: result = repo_query(f"SELECT * FROM {cls.table_name}") - objects = [cls(**obj) for obj in result] + objects = [] + for obj in result: + try: + objects.append(cls(**obj)) + except Exception as e: + logger.critical(f"Error creating object: {str(e)}") + return objects except Exception as e: logger.error(f"Error fetching all {cls.table_name}: {str(e)}") @@ -64,6 +70,8 @@ class ObjectModel(BaseModel): def save(self) -> None: try: + logger.debug(f"Validating {self.__class__.__name__}") + self.model_validate(self.model_dump(), strict=True) data = self._prepare_save_data() if self.needs_embedding(): @@ -86,6 +94,13 @@ class ObjectModel(BaseModel): else: setattr(self, key, value) + except ValidationError as e: + logger.error(f"Validation failed: {e}") + raise + except Exception as e: + logger.error(f"Error saving record: {e}") + raise + except Exception as e: logger.error(f"Error saving {self.__class__.table_name}: {str(e)}") logger.exception(e) @@ -139,7 +154,7 @@ class Notebook(ObjectModel): def sources(self) -> List["Source"]: try: srcs = repo_query(f""" - select * from ( + select * OMIT full_text from ( select <- source as source from reference @@ -158,7 +173,7 @@ class Notebook(ObjectModel): def notes(self) -> List["Note"]: try: srcs = repo_query(f""" - select * from ( + select * OMIT content from ( select <- note as note from artifact @@ -322,7 +337,6 @@ class Source(ObjectModel): output = pattern_graph.invoke( dict(content_stack=[result["toc"]], transformations=transformations) ) - logger.warning(output["output"]) self.title = surreal_clean(output["output"]) self.save() return self diff --git a/open_notebook/plugins/podcasts.py b/open_notebook/plugins/podcasts.py index f1167cb..c7a5a9f 100644 --- a/open_notebook/plugins/podcasts.py +++ b/open_notebook/plugins/podcasts.py @@ -1,4 +1,5 @@ -from typing import ClassVar, List, Literal +from datetime import datetime +from typing import ClassVar, List, Literal, Optional from loguru import logger from podcastfy.client import generate_podcast @@ -27,13 +28,15 @@ class PodcastConfig(ObjectModel): conversation_style: List[str] engagement_technique: List[str] dialogue_structure: List[str] - user_instructions: str - wordcount: int = Field(gt=500, lt=10000) + user_instructions: Optional[str] = None + ending_message: Optional[str] = None + wordcount: int = Field(ge=400, le=10000) creativity: float = Field(ge=0, le=1) provider: Literal["openai", "elevenlabs", "edge"] = Field(default="openai") - voice1: str - voice2: str + voice1: Optional[str] = None + voice2: Optional[str] = None model: str + created: Optional[datetime] = Field(default_factory=datetime.now) def generate_episode(self, episode_name, text, instructions=None): self.user_instructions = ( @@ -52,7 +55,7 @@ class PodcastConfig(ObjectModel): "engagement_techniques": self.engagement_technique, "creativity": self.creativity, "text_to_speech": { - # "temp_audio_dir": "./data/audio/tmp", + # "temp_audio_dir": f"{PODCASTS_FOLDER}/tmp", "ending_message": "Thank you for listening to this episode. Don't forget to subscribe to our podcast for more interesting conversations.", "default_tts_model": self.provider, self.provider: { @@ -66,8 +69,6 @@ class PodcastConfig(ObjectModel): }, } - logger.error(conversation_config) - # conversation_config = {} logger.debug( f"Generating episode {episode_name} with config {conversation_config}" ) @@ -75,7 +76,6 @@ class PodcastConfig(ObjectModel): audio_file = generate_podcast( conversation_config=conversation_config, text=text, tts_model=self.provider ) - logger.warning(audio_file) episode = PodcastEpisode( name=episode_name, template=self.name, @@ -85,10 +85,19 @@ class PodcastConfig(ObjectModel): ) episode.save() + @field_validator( + "name", "podcast_name", "podcast_tagline", "output_language", "model" + ) + @classmethod + def validate_required_strings(cls, value: str, field) -> str: + if value is None or value.strip() == "": + raise ValueError(f"{field.field_name} cannot be None or empty string") + return value.strip() + @field_validator("wordcount") def validate_wordcount(cls, value): - if not 500 <= value <= 6000: - raise ValueError("Wordcount must be between 500 and 10000") + if not 400 <= value <= 6000: + raise ValueError("Wordcount must be between 400 and 10000") return value @field_validator("creativity") diff --git a/pages/3_🔍_Search.py b/pages/3_🔍_Search.py index 63fc8bd..39778a0 100644 --- a/pages/3_🔍_Search.py +++ b/pages/3_🔍_Search.py @@ -6,7 +6,7 @@ from stream_app.note import note_list_item from stream_app.source import source_list_item st.set_page_config( - layout="wide", page_title="🔍 Open Notebook", initial_sidebar_state="expanded" + layout="wide", page_title="🔍 Search", initial_sidebar_state="expanded" ) # search_tab, ask_tab = st.tabs(["Search", "Ask"]) diff --git a/pages/5_🎙️_Podcasts.py b/pages/5_🎙️_Podcasts.py index a5efe37..7f17d91 100644 --- a/pages/5_🎙️_Podcasts.py +++ b/pages/5_🎙️_Podcasts.py @@ -10,6 +10,11 @@ from open_notebook.plugins.podcasts import ( participant_roles, ) +st.set_page_config( + layout="wide", page_title="🎙️ Podcasts", initial_sidebar_state="expanded" +) + + episodes_tab, templates_tab = st.tabs(["Episodes", "Templates"]) with episodes_tab: @@ -66,6 +71,9 @@ with templates_tab: pd_cfg["creativity"] = st.slider( "Creativity", min_value=0.0, max_value=1.0, step=0.05 ) + pd_cfg["ending_message"] = st.text_input( + "Ending Message", placeholder="Thank you for listening!" + ) pd_cfg["provider"] = st.selectbox("Provider", ["openai", "elevenlabs", "edge"]) pd_cfg["voice1"] = st.text_input( "Voice 1", help="You can use Elevenlabs voice ID" @@ -81,10 +89,13 @@ with templates_tab: "OpenAI: tts-1 or tts-1-hd, Elevenlabs: eleven_multilingual_v2, eleven_turbo_v2_5" ) if st.button("Save"): - pd = PodcastConfig(**pd_cfg) - pd_cfg = {} - pd.save() - st.rerun() + try: + pd = PodcastConfig(**pd_cfg) + pd_cfg = {} + pd.save() + st.rerun() + except Exception as e: + st.error(e) for pd_config in PodcastConfig.get_all(): with st.expander(pd_config.name): @@ -161,6 +172,12 @@ with templates_tab: value=pd_config.creativity, key=f"creativity_{pd_config.id}", ) + pd_config.ending_message = st.text_input( + "Ending Message", + value=pd_config.ending_message, + placeholder="Thank you for listening!", + key=f"ending_message_{pd_config.id}", + ) pd_config.provider = st.selectbox( "Provider", ["openai", "elevenlabs", "edge"], @@ -190,8 +207,11 @@ with templates_tab: ) if st.button("Save Config", key=f"btn_save{pd_config.id}"): - pd_config.save() - st.rerun() + try: + pd_config.save() + st.toast("Podcast template saved") + except Exception as e: + st.error(e) if st.button("Duplicate Config", key=f"btn_duplicate{pd_config.id}"): pd_config.name = f"{pd_config.name} - Copy"