From 15048b08392b4da3489bb4e69cbe1ad1730e9a7b Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Fri, 1 Nov 2024 21:32:40 -0300 Subject: [PATCH] simplify model provisioning --- open_notebook/graphs/chat.py | 7 ++--- open_notebook/graphs/utils.py | 11 ++----- stream_app/utils.py | 54 +++++++++++++++++++++++++++++++++-- 3 files changed, 58 insertions(+), 14 deletions(-) diff --git a/open_notebook/graphs/chat.py b/open_notebook/graphs/chat.py index 5d0939f..c87af21 100644 --- a/open_notebook/graphs/chat.py +++ b/open_notebook/graphs/chat.py @@ -24,10 +24,9 @@ class ThreadState(TypedDict): def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict: system_prompt = Prompter(prompt_template="chat").render(data=state) - model = provision_model( - str(system_prompt) + str(state.get("messages", [])), config, "chat" - ) - ai_message = model.invoke([system_prompt] + state.get("messages", [])) + payload = [system_prompt] + state.get("messages", []) + model = provision_model(str(payload), config, "chat") + ai_message = model.invoke(payload, []) return {"messages": ai_message} diff --git a/open_notebook/graphs/utils.py b/open_notebook/graphs/utils.py index ab78147..5d8339c 100644 --- a/open_notebook/graphs/utils.py +++ b/open_notebook/graphs/utils.py @@ -39,9 +39,8 @@ def run_pattern( system_prompt = Prompter(prompt_template=pattern_name, parser=parser).render( data=state ) - chain = provision_model( - str(system_prompt) + str(messages), config, "transformation" - ) + payload = [system_prompt] + messages + chain = provision_model(str(payload), config, "transformation") if parser: chain = chain | parser @@ -53,10 +52,6 @@ def run_pattern( llm=output_fix_model, ) - # todo: precisa deste if? - if len(messages) > 0: - response = chain.invoke([system_prompt] + messages) - else: - response = chain.invoke(system_prompt) + response = chain.invoke(payload) return response diff --git a/stream_app/utils.py b/stream_app/utils.py index 55d9db1..911305c 100644 --- a/stream_app/utils.py +++ b/stream_app/utils.py @@ -1,6 +1,8 @@ import streamlit as st +from open_notebook.database.migrate import MigrationManager from open_notebook.graphs.chat import ThreadState, graph +from open_notebook.models import model_manager from open_notebook.utils import ( compare_versions, get_installed_version, @@ -40,9 +42,57 @@ def setup_stream_state(session_id) -> None: existing_state = graph.get_state({"configurable": {"thread_id": session_id}}).values if len(existing_state.keys()) == 0: st.session_state[session_id] = ThreadState( - messages=[], context=None, notebook=None, context_config={}, response=None + messages=[], context=None, notebook=None, context_config={} ) else: st.session_state[session_id] = existing_state st.session_state["active_session"] = session_id - st.session_state["active_session"] = session_id + + +def check_migration(): + mm = MigrationManager() + if mm.needs_migration: + st.warning("The Open Notebook database needs a migration to run properly.") + if st.button("Run Migration"): + mm.run_migration_up() + st.success("Migration successful") + st.rerun() + st.stop() + + +def check_models(): + default_models = model_manager.defaults + if ( + not default_models.default_chat_model + or not default_models.default_transformation_model + ): + st.warning( + "You don't have default chat and transformation models selected. Please, select them on the settings page." + ) + st.stop() + elif not default_models.default_embedding_model: + st.warning( + "You don't have a default embedding model selected. Vector search will not be possible and your assistant will be less able to answer your queries. Please, select one on the settings page." + ) + st.stop() + elif not default_models.default_speech_to_text_model: + st.warning( + "You don't have a default speech to text model selected. Your assistant will not be able to transcribe audio. Please, select one on the settings page." + ) + st.stop() + elif not default_models.default_text_to_speech_model: + st.warning( + "You don't have a default text to speech model selected. Your assistant will not be able to generate audio and podcasts. Please, select one on the settings page." + ) + st.stop() + elif not default_models.large_context_model: + st.warning( + "You don't have a large context model selected. Your assistant will not be able to process large documents. Please, select one on the settings page." + ) + st.stop() + + +def page_commons(): + version_sidebar() + check_migration() + check_models()