simplify model provisioning
This commit is contained in:
parent
fcd883f393
commit
15048b0839
3 changed files with 58 additions and 14 deletions
|
|
@ -24,10 +24,9 @@ class ThreadState(TypedDict):
|
|||
|
||||
def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict:
|
||||
system_prompt = Prompter(prompt_template="chat").render(data=state)
|
||||
model = provision_model(
|
||||
str(system_prompt) + str(state.get("messages", [])), config, "chat"
|
||||
)
|
||||
ai_message = model.invoke([system_prompt] + state.get("messages", []))
|
||||
payload = [system_prompt] + state.get("messages", [])
|
||||
model = provision_model(str(payload), config, "chat")
|
||||
ai_message = model.invoke(payload, [])
|
||||
return {"messages": ai_message}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -39,9 +39,8 @@ def run_pattern(
|
|||
system_prompt = Prompter(prompt_template=pattern_name, parser=parser).render(
|
||||
data=state
|
||||
)
|
||||
chain = provision_model(
|
||||
str(system_prompt) + str(messages), config, "transformation"
|
||||
)
|
||||
payload = [system_prompt] + messages
|
||||
chain = provision_model(str(payload), config, "transformation")
|
||||
|
||||
if parser:
|
||||
chain = chain | parser
|
||||
|
|
@ -53,10 +52,6 @@ def run_pattern(
|
|||
llm=output_fix_model,
|
||||
)
|
||||
|
||||
# todo: precisa deste if?
|
||||
if len(messages) > 0:
|
||||
response = chain.invoke([system_prompt] + messages)
|
||||
else:
|
||||
response = chain.invoke(system_prompt)
|
||||
response = chain.invoke(payload)
|
||||
|
||||
return response
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
import streamlit as st
|
||||
|
||||
from open_notebook.database.migrate import MigrationManager
|
||||
from open_notebook.graphs.chat import ThreadState, graph
|
||||
from open_notebook.models import model_manager
|
||||
from open_notebook.utils import (
|
||||
compare_versions,
|
||||
get_installed_version,
|
||||
|
|
@ -40,9 +42,57 @@ def setup_stream_state(session_id) -> None:
|
|||
existing_state = graph.get_state({"configurable": {"thread_id": session_id}}).values
|
||||
if len(existing_state.keys()) == 0:
|
||||
st.session_state[session_id] = ThreadState(
|
||||
messages=[], context=None, notebook=None, context_config={}, response=None
|
||||
messages=[], context=None, notebook=None, context_config={}
|
||||
)
|
||||
else:
|
||||
st.session_state[session_id] = existing_state
|
||||
st.session_state["active_session"] = session_id
|
||||
st.session_state["active_session"] = session_id
|
||||
|
||||
|
||||
def check_migration():
|
||||
mm = MigrationManager()
|
||||
if mm.needs_migration:
|
||||
st.warning("The Open Notebook database needs a migration to run properly.")
|
||||
if st.button("Run Migration"):
|
||||
mm.run_migration_up()
|
||||
st.success("Migration successful")
|
||||
st.rerun()
|
||||
st.stop()
|
||||
|
||||
|
||||
def check_models():
|
||||
default_models = model_manager.defaults
|
||||
if (
|
||||
not default_models.default_chat_model
|
||||
or not default_models.default_transformation_model
|
||||
):
|
||||
st.warning(
|
||||
"You don't have default chat and transformation models selected. Please, select them on the settings page."
|
||||
)
|
||||
st.stop()
|
||||
elif not default_models.default_embedding_model:
|
||||
st.warning(
|
||||
"You don't have a default embedding model selected. Vector search will not be possible and your assistant will be less able to answer your queries. Please, select one on the settings page."
|
||||
)
|
||||
st.stop()
|
||||
elif not default_models.default_speech_to_text_model:
|
||||
st.warning(
|
||||
"You don't have a default speech to text model selected. Your assistant will not be able to transcribe audio. Please, select one on the settings page."
|
||||
)
|
||||
st.stop()
|
||||
elif not default_models.default_text_to_speech_model:
|
||||
st.warning(
|
||||
"You don't have a default text to speech model selected. Your assistant will not be able to generate audio and podcasts. Please, select one on the settings page."
|
||||
)
|
||||
st.stop()
|
||||
elif not default_models.large_context_model:
|
||||
st.warning(
|
||||
"You don't have a large context model selected. Your assistant will not be able to process large documents. Please, select one on the settings page."
|
||||
)
|
||||
st.stop()
|
||||
|
||||
|
||||
def page_commons():
|
||||
version_sidebar()
|
||||
check_migration()
|
||||
check_models()
|
||||
|
|
|
|||
Loading…
Reference in a new issue