change model provisioning parameters
This commit is contained in:
parent
99b8ada280
commit
183149014e
3 changed files with 14 additions and 21 deletions
|
|
@ -25,7 +25,12 @@ class ThreadState(TypedDict):
|
|||
def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict:
|
||||
system_prompt = Prompter(prompt_template="chat").render(data=state)
|
||||
payload = [system_prompt] + state.get("messages", [])
|
||||
model = provision_langchain_model(str(payload), config, "chat", max_tokens=2000)
|
||||
model = provision_langchain_model(
|
||||
str(payload),
|
||||
config.get("configurable", {}).get("model_id"),
|
||||
"chat",
|
||||
max_tokens=2000,
|
||||
)
|
||||
ai_message = model.invoke(payload)
|
||||
return {"messages": ai_message}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,7 @@
|
|||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from langchain.tools import tool
|
||||
|
||||
from open_notebook.domain.notebook import hybrid_search
|
||||
|
||||
|
||||
# todo: turn this into a system prompt variable
|
||||
@tool
|
||||
|
|
@ -14,14 +11,3 @@ def get_current_timestamp() -> str:
|
|||
Returns the current timestamp in the format YYYYMMDDHHmmss.
|
||||
"""
|
||||
return datetime.now().strftime("%Y%m%d%H%M%S")
|
||||
|
||||
|
||||
@tool
|
||||
def repository_search(keyword_searches: List[str], vector_searches: List[str]) -> str:
|
||||
"""
|
||||
name: repository_search
|
||||
Makes a search in the content repository for the given query.
|
||||
keyword_searches: List[str] - A list of search terms to search for using keyword search.
|
||||
vector_searches: List[str] - A list of search terms to search for using vector search.
|
||||
"""
|
||||
return hybrid_search(keyword_searches, vector_searches, 20)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,9 @@ from open_notebook.prompter import Prompter
|
|||
from open_notebook.utils import token_count
|
||||
|
||||
|
||||
def provision_langchain_model(content, config, default_type, **kwargs) -> BaseChatModel:
|
||||
def provision_langchain_model(
|
||||
content, model_id, default_type, **kwargs
|
||||
) -> BaseChatModel:
|
||||
"""
|
||||
Returns the best model to use based on the context size and on whether there is a specific model being requested in Config.
|
||||
If context > 105_000, returns the large_context_model
|
||||
|
|
@ -22,10 +24,8 @@ def provision_langchain_model(content, config, default_type, **kwargs) -> BaseCh
|
|||
f"Using large context model because the content has {tokens} tokens"
|
||||
)
|
||||
model = model_manager.get_default_model("large_context", **kwargs)
|
||||
elif config.get("configurable", {}).get("model_id"):
|
||||
model = model_manager.get_model(
|
||||
config.get("configurable", {}).get("model_id"), **kwargs
|
||||
)
|
||||
elif model_id:
|
||||
model = model_manager.get_model(model_id, **kwargs)
|
||||
else:
|
||||
model = model_manager.get_default_model(default_type, **kwargs)
|
||||
|
||||
|
|
@ -45,7 +45,9 @@ def run_pattern(
|
|||
data=state
|
||||
)
|
||||
payload = [system_prompt] + messages
|
||||
chain = provision_langchain_model(str(payload), config, "transformation")
|
||||
chain = provision_langchain_model(
|
||||
str(payload), config.get("configurable", {}).get("model_id"), "transformation"
|
||||
)
|
||||
|
||||
response = chain.invoke(payload)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue