improve object validation

This commit is contained in:
LUIS NOVO 2024-10-26 18:55:56 -03:00
parent 9719146210
commit 8dac6dd2ac
4 changed files with 66 additions and 23 deletions

View file

@ -4,7 +4,7 @@ from typing import Any, ClassVar, Dict, List, Literal, Optional, Type, TypeVar
from langchain_core.runnables.config import RunnableConfig
from loguru import logger
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, ValidationError, field_validator
from open_notebook.exceptions import (
DatabaseOperationError,
@ -35,7 +35,13 @@ class ObjectModel(BaseModel):
def get_all(cls: Type[T]) -> List[T]:
try:
result = repo_query(f"SELECT * FROM {cls.table_name}")
objects = [cls(**obj) for obj in result]
objects = []
for obj in result:
try:
objects.append(cls(**obj))
except Exception as e:
logger.critical(f"Error creating object: {str(e)}")
return objects
except Exception as e:
logger.error(f"Error fetching all {cls.table_name}: {str(e)}")
@ -64,6 +70,8 @@ class ObjectModel(BaseModel):
def save(self) -> None:
try:
logger.debug(f"Validating {self.__class__.__name__}")
self.model_validate(self.model_dump(), strict=True)
data = self._prepare_save_data()
if self.needs_embedding():
@ -86,6 +94,13 @@ class ObjectModel(BaseModel):
else:
setattr(self, key, value)
except ValidationError as e:
logger.error(f"Validation failed: {e}")
raise
except Exception as e:
logger.error(f"Error saving record: {e}")
raise
except Exception as e:
logger.error(f"Error saving {self.__class__.table_name}: {str(e)}")
logger.exception(e)
@ -139,7 +154,7 @@ class Notebook(ObjectModel):
def sources(self) -> List["Source"]:
try:
srcs = repo_query(f"""
select * from (
select * OMIT full_text from (
select
<- source as source
from reference
@ -158,7 +173,7 @@ class Notebook(ObjectModel):
def notes(self) -> List["Note"]:
try:
srcs = repo_query(f"""
select * from (
select * OMIT content from (
select
<- note as note
from artifact
@ -322,7 +337,6 @@ class Source(ObjectModel):
output = pattern_graph.invoke(
dict(content_stack=[result["toc"]], transformations=transformations)
)
logger.warning(output["output"])
self.title = surreal_clean(output["output"])
self.save()
return self

View file

@ -1,4 +1,5 @@
from typing import ClassVar, List, Literal
from datetime import datetime
from typing import ClassVar, List, Literal, Optional
from loguru import logger
from podcastfy.client import generate_podcast
@ -27,13 +28,15 @@ class PodcastConfig(ObjectModel):
conversation_style: List[str]
engagement_technique: List[str]
dialogue_structure: List[str]
user_instructions: str
wordcount: int = Field(gt=500, lt=10000)
user_instructions: Optional[str] = None
ending_message: Optional[str] = None
wordcount: int = Field(ge=400, le=10000)
creativity: float = Field(ge=0, le=1)
provider: Literal["openai", "elevenlabs", "edge"] = Field(default="openai")
voice1: str
voice2: str
voice1: Optional[str] = None
voice2: Optional[str] = None
model: str
created: Optional[datetime] = Field(default_factory=datetime.now)
def generate_episode(self, episode_name, text, instructions=None):
self.user_instructions = (
@ -52,7 +55,7 @@ class PodcastConfig(ObjectModel):
"engagement_techniques": self.engagement_technique,
"creativity": self.creativity,
"text_to_speech": {
# "temp_audio_dir": "./data/audio/tmp",
# "temp_audio_dir": f"{PODCASTS_FOLDER}/tmp",
"ending_message": "Thank you for listening to this episode. Don't forget to subscribe to our podcast for more interesting conversations.",
"default_tts_model": self.provider,
self.provider: {
@ -66,8 +69,6 @@ class PodcastConfig(ObjectModel):
},
}
logger.error(conversation_config)
# conversation_config = {}
logger.debug(
f"Generating episode {episode_name} with config {conversation_config}"
)
@ -75,7 +76,6 @@ class PodcastConfig(ObjectModel):
audio_file = generate_podcast(
conversation_config=conversation_config, text=text, tts_model=self.provider
)
logger.warning(audio_file)
episode = PodcastEpisode(
name=episode_name,
template=self.name,
@ -85,10 +85,19 @@ class PodcastConfig(ObjectModel):
)
episode.save()
@field_validator(
"name", "podcast_name", "podcast_tagline", "output_language", "model"
)
@classmethod
def validate_required_strings(cls, value: str, field) -> str:
if value is None or value.strip() == "":
raise ValueError(f"{field.field_name} cannot be None or empty string")
return value.strip()
@field_validator("wordcount")
def validate_wordcount(cls, value):
if not 500 <= value <= 6000:
raise ValueError("Wordcount must be between 500 and 10000")
if not 400 <= value <= 6000:
raise ValueError("Wordcount must be between 400 and 10000")
return value
@field_validator("creativity")

View file

@ -6,7 +6,7 @@ from stream_app.note import note_list_item
from stream_app.source import source_list_item
st.set_page_config(
layout="wide", page_title="🔍 Open Notebook", initial_sidebar_state="expanded"
layout="wide", page_title="🔍 Search", initial_sidebar_state="expanded"
)
# search_tab, ask_tab = st.tabs(["Search", "Ask"])

View file

@ -10,6 +10,11 @@ from open_notebook.plugins.podcasts import (
participant_roles,
)
st.set_page_config(
layout="wide", page_title="🎙️ Podcasts", initial_sidebar_state="expanded"
)
episodes_tab, templates_tab = st.tabs(["Episodes", "Templates"])
with episodes_tab:
@ -66,6 +71,9 @@ with templates_tab:
pd_cfg["creativity"] = st.slider(
"Creativity", min_value=0.0, max_value=1.0, step=0.05
)
pd_cfg["ending_message"] = st.text_input(
"Ending Message", placeholder="Thank you for listening!"
)
pd_cfg["provider"] = st.selectbox("Provider", ["openai", "elevenlabs", "edge"])
pd_cfg["voice1"] = st.text_input(
"Voice 1", help="You can use Elevenlabs voice ID"
@ -81,10 +89,13 @@ with templates_tab:
"OpenAI: tts-1 or tts-1-hd, Elevenlabs: eleven_multilingual_v2, eleven_turbo_v2_5"
)
if st.button("Save"):
pd = PodcastConfig(**pd_cfg)
pd_cfg = {}
pd.save()
st.rerun()
try:
pd = PodcastConfig(**pd_cfg)
pd_cfg = {}
pd.save()
st.rerun()
except Exception as e:
st.error(e)
for pd_config in PodcastConfig.get_all():
with st.expander(pd_config.name):
@ -161,6 +172,12 @@ with templates_tab:
value=pd_config.creativity,
key=f"creativity_{pd_config.id}",
)
pd_config.ending_message = st.text_input(
"Ending Message",
value=pd_config.ending_message,
placeholder="Thank you for listening!",
key=f"ending_message_{pd_config.id}",
)
pd_config.provider = st.selectbox(
"Provider",
["openai", "elevenlabs", "edge"],
@ -190,8 +207,11 @@ with templates_tab:
)
if st.button("Save Config", key=f"btn_save{pd_config.id}"):
pd_config.save()
st.rerun()
try:
pd_config.save()
st.toast("Podcast template saved")
except Exception as e:
st.error(e)
if st.button("Duplicate Config", key=f"btn_duplicate{pd_config.id}"):
pd_config.name = f"{pd_config.name} - Copy"