improve object validation
This commit is contained in:
parent
9719146210
commit
8dac6dd2ac
4 changed files with 66 additions and 23 deletions
|
|
@ -4,7 +4,7 @@ from typing import Any, ClassVar, Dict, List, Literal, Optional, Type, TypeVar
|
|||
|
||||
from langchain_core.runnables.config import RunnableConfig
|
||||
from loguru import logger
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from pydantic import BaseModel, Field, ValidationError, field_validator
|
||||
|
||||
from open_notebook.exceptions import (
|
||||
DatabaseOperationError,
|
||||
|
|
@ -35,7 +35,13 @@ class ObjectModel(BaseModel):
|
|||
def get_all(cls: Type[T]) -> List[T]:
|
||||
try:
|
||||
result = repo_query(f"SELECT * FROM {cls.table_name}")
|
||||
objects = [cls(**obj) for obj in result]
|
||||
objects = []
|
||||
for obj in result:
|
||||
try:
|
||||
objects.append(cls(**obj))
|
||||
except Exception as e:
|
||||
logger.critical(f"Error creating object: {str(e)}")
|
||||
|
||||
return objects
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching all {cls.table_name}: {str(e)}")
|
||||
|
|
@ -64,6 +70,8 @@ class ObjectModel(BaseModel):
|
|||
|
||||
def save(self) -> None:
|
||||
try:
|
||||
logger.debug(f"Validating {self.__class__.__name__}")
|
||||
self.model_validate(self.model_dump(), strict=True)
|
||||
data = self._prepare_save_data()
|
||||
|
||||
if self.needs_embedding():
|
||||
|
|
@ -86,6 +94,13 @@ class ObjectModel(BaseModel):
|
|||
else:
|
||||
setattr(self, key, value)
|
||||
|
||||
except ValidationError as e:
|
||||
logger.error(f"Validation failed: {e}")
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving record: {e}")
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving {self.__class__.table_name}: {str(e)}")
|
||||
logger.exception(e)
|
||||
|
|
@ -139,7 +154,7 @@ class Notebook(ObjectModel):
|
|||
def sources(self) -> List["Source"]:
|
||||
try:
|
||||
srcs = repo_query(f"""
|
||||
select * from (
|
||||
select * OMIT full_text from (
|
||||
select
|
||||
<- source as source
|
||||
from reference
|
||||
|
|
@ -158,7 +173,7 @@ class Notebook(ObjectModel):
|
|||
def notes(self) -> List["Note"]:
|
||||
try:
|
||||
srcs = repo_query(f"""
|
||||
select * from (
|
||||
select * OMIT content from (
|
||||
select
|
||||
<- note as note
|
||||
from artifact
|
||||
|
|
@ -322,7 +337,6 @@ class Source(ObjectModel):
|
|||
output = pattern_graph.invoke(
|
||||
dict(content_stack=[result["toc"]], transformations=transformations)
|
||||
)
|
||||
logger.warning(output["output"])
|
||||
self.title = surreal_clean(output["output"])
|
||||
self.save()
|
||||
return self
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from typing import ClassVar, List, Literal
|
||||
from datetime import datetime
|
||||
from typing import ClassVar, List, Literal, Optional
|
||||
|
||||
from loguru import logger
|
||||
from podcastfy.client import generate_podcast
|
||||
|
|
@ -27,13 +28,15 @@ class PodcastConfig(ObjectModel):
|
|||
conversation_style: List[str]
|
||||
engagement_technique: List[str]
|
||||
dialogue_structure: List[str]
|
||||
user_instructions: str
|
||||
wordcount: int = Field(gt=500, lt=10000)
|
||||
user_instructions: Optional[str] = None
|
||||
ending_message: Optional[str] = None
|
||||
wordcount: int = Field(ge=400, le=10000)
|
||||
creativity: float = Field(ge=0, le=1)
|
||||
provider: Literal["openai", "elevenlabs", "edge"] = Field(default="openai")
|
||||
voice1: str
|
||||
voice2: str
|
||||
voice1: Optional[str] = None
|
||||
voice2: Optional[str] = None
|
||||
model: str
|
||||
created: Optional[datetime] = Field(default_factory=datetime.now)
|
||||
|
||||
def generate_episode(self, episode_name, text, instructions=None):
|
||||
self.user_instructions = (
|
||||
|
|
@ -52,7 +55,7 @@ class PodcastConfig(ObjectModel):
|
|||
"engagement_techniques": self.engagement_technique,
|
||||
"creativity": self.creativity,
|
||||
"text_to_speech": {
|
||||
# "temp_audio_dir": "./data/audio/tmp",
|
||||
# "temp_audio_dir": f"{PODCASTS_FOLDER}/tmp",
|
||||
"ending_message": "Thank you for listening to this episode. Don't forget to subscribe to our podcast for more interesting conversations.",
|
||||
"default_tts_model": self.provider,
|
||||
self.provider: {
|
||||
|
|
@ -66,8 +69,6 @@ class PodcastConfig(ObjectModel):
|
|||
},
|
||||
}
|
||||
|
||||
logger.error(conversation_config)
|
||||
# conversation_config = {}
|
||||
logger.debug(
|
||||
f"Generating episode {episode_name} with config {conversation_config}"
|
||||
)
|
||||
|
|
@ -75,7 +76,6 @@ class PodcastConfig(ObjectModel):
|
|||
audio_file = generate_podcast(
|
||||
conversation_config=conversation_config, text=text, tts_model=self.provider
|
||||
)
|
||||
logger.warning(audio_file)
|
||||
episode = PodcastEpisode(
|
||||
name=episode_name,
|
||||
template=self.name,
|
||||
|
|
@ -85,10 +85,19 @@ class PodcastConfig(ObjectModel):
|
|||
)
|
||||
episode.save()
|
||||
|
||||
@field_validator(
|
||||
"name", "podcast_name", "podcast_tagline", "output_language", "model"
|
||||
)
|
||||
@classmethod
|
||||
def validate_required_strings(cls, value: str, field) -> str:
|
||||
if value is None or value.strip() == "":
|
||||
raise ValueError(f"{field.field_name} cannot be None or empty string")
|
||||
return value.strip()
|
||||
|
||||
@field_validator("wordcount")
|
||||
def validate_wordcount(cls, value):
|
||||
if not 500 <= value <= 6000:
|
||||
raise ValueError("Wordcount must be between 500 and 10000")
|
||||
if not 400 <= value <= 6000:
|
||||
raise ValueError("Wordcount must be between 400 and 10000")
|
||||
return value
|
||||
|
||||
@field_validator("creativity")
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from stream_app.note import note_list_item
|
|||
from stream_app.source import source_list_item
|
||||
|
||||
st.set_page_config(
|
||||
layout="wide", page_title="🔍 Open Notebook", initial_sidebar_state="expanded"
|
||||
layout="wide", page_title="🔍 Search", initial_sidebar_state="expanded"
|
||||
)
|
||||
|
||||
# search_tab, ask_tab = st.tabs(["Search", "Ask"])
|
||||
|
|
|
|||
|
|
@ -10,6 +10,11 @@ from open_notebook.plugins.podcasts import (
|
|||
participant_roles,
|
||||
)
|
||||
|
||||
st.set_page_config(
|
||||
layout="wide", page_title="🎙️ Podcasts", initial_sidebar_state="expanded"
|
||||
)
|
||||
|
||||
|
||||
episodes_tab, templates_tab = st.tabs(["Episodes", "Templates"])
|
||||
|
||||
with episodes_tab:
|
||||
|
|
@ -66,6 +71,9 @@ with templates_tab:
|
|||
pd_cfg["creativity"] = st.slider(
|
||||
"Creativity", min_value=0.0, max_value=1.0, step=0.05
|
||||
)
|
||||
pd_cfg["ending_message"] = st.text_input(
|
||||
"Ending Message", placeholder="Thank you for listening!"
|
||||
)
|
||||
pd_cfg["provider"] = st.selectbox("Provider", ["openai", "elevenlabs", "edge"])
|
||||
pd_cfg["voice1"] = st.text_input(
|
||||
"Voice 1", help="You can use Elevenlabs voice ID"
|
||||
|
|
@ -81,10 +89,13 @@ with templates_tab:
|
|||
"OpenAI: tts-1 or tts-1-hd, Elevenlabs: eleven_multilingual_v2, eleven_turbo_v2_5"
|
||||
)
|
||||
if st.button("Save"):
|
||||
pd = PodcastConfig(**pd_cfg)
|
||||
pd_cfg = {}
|
||||
pd.save()
|
||||
st.rerun()
|
||||
try:
|
||||
pd = PodcastConfig(**pd_cfg)
|
||||
pd_cfg = {}
|
||||
pd.save()
|
||||
st.rerun()
|
||||
except Exception as e:
|
||||
st.error(e)
|
||||
|
||||
for pd_config in PodcastConfig.get_all():
|
||||
with st.expander(pd_config.name):
|
||||
|
|
@ -161,6 +172,12 @@ with templates_tab:
|
|||
value=pd_config.creativity,
|
||||
key=f"creativity_{pd_config.id}",
|
||||
)
|
||||
pd_config.ending_message = st.text_input(
|
||||
"Ending Message",
|
||||
value=pd_config.ending_message,
|
||||
placeholder="Thank you for listening!",
|
||||
key=f"ending_message_{pd_config.id}",
|
||||
)
|
||||
pd_config.provider = st.selectbox(
|
||||
"Provider",
|
||||
["openai", "elevenlabs", "edge"],
|
||||
|
|
@ -190,8 +207,11 @@ with templates_tab:
|
|||
)
|
||||
|
||||
if st.button("Save Config", key=f"btn_save{pd_config.id}"):
|
||||
pd_config.save()
|
||||
st.rerun()
|
||||
try:
|
||||
pd_config.save()
|
||||
st.toast("Podcast template saved")
|
||||
except Exception as e:
|
||||
st.error(e)
|
||||
|
||||
if st.button("Duplicate Config", key=f"btn_duplicate{pd_config.id}"):
|
||||
pd_config.name = f"{pd_config.name} - Copy"
|
||||
|
|
|
|||
Loading…
Reference in a new issue