change model provisioning parameters

This commit is contained in:
LUIS NOVO 2024-11-08 16:08:54 -03:00
parent 99b8ada280
commit 183149014e
3 changed files with 14 additions and 21 deletions

View file

@ -25,7 +25,12 @@ class ThreadState(TypedDict):
def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict:
system_prompt = Prompter(prompt_template="chat").render(data=state)
payload = [system_prompt] + state.get("messages", [])
model = provision_langchain_model(str(payload), config, "chat", max_tokens=2000)
model = provision_langchain_model(
str(payload),
config.get("configurable", {}).get("model_id"),
"chat",
max_tokens=2000,
)
ai_message = model.invoke(payload)
return {"messages": ai_message}

View file

@ -1,10 +1,7 @@
from datetime import datetime
from typing import List
from langchain.tools import tool
from open_notebook.domain.notebook import hybrid_search
# todo: turn this into a system prompt variable
@tool
@ -14,14 +11,3 @@ def get_current_timestamp() -> str:
Returns the current timestamp in the format YYYYMMDDHHmmss.
"""
return datetime.now().strftime("%Y%m%d%H%M%S")
@tool
def repository_search(keyword_searches: List[str], vector_searches: List[str]) -> str:
"""
name: repository_search
Makes a search in the content repository for the given query.
keyword_searches: List[str] - A list of search terms to search for using keyword search.
vector_searches: List[str] - A list of search terms to search for using vector search.
"""
return hybrid_search(keyword_searches, vector_searches, 20)

View file

@ -8,7 +8,9 @@ from open_notebook.prompter import Prompter
from open_notebook.utils import token_count
def provision_langchain_model(content, config, default_type, **kwargs) -> BaseChatModel:
def provision_langchain_model(
content, model_id, default_type, **kwargs
) -> BaseChatModel:
"""
Returns the best model to use based on the context size and on whether there is a specific model being requested in Config.
If context > 105_000, returns the large_context_model
@ -22,10 +24,8 @@ def provision_langchain_model(content, config, default_type, **kwargs) -> BaseCh
f"Using large context model because the content has {tokens} tokens"
)
model = model_manager.get_default_model("large_context", **kwargs)
elif config.get("configurable", {}).get("model_id"):
model = model_manager.get_model(
config.get("configurable", {}).get("model_id"), **kwargs
)
elif model_id:
model = model_manager.get_model(model_id, **kwargs)
else:
model = model_manager.get_default_model(default_type, **kwargs)
@ -45,7 +45,9 @@ def run_pattern(
data=state
)
payload = [system_prompt] + messages
chain = provision_langchain_model(str(payload), config, "transformation")
chain = provision_langchain_model(
str(payload), config.get("configurable", {}).get("model_id"), "transformation"
)
response = chain.invoke(payload)