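Change model provisioning strategy: forward extra keyword arguments (e.g. max_tokens) from provision_langchain_model to the model manager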
This commit is contained in:
parent
d9c0c93deb
commit
b4ba3ef4c8
2 changed files with 7 additions and 5 deletions
|
|
@ -25,7 +25,7 @@ class ThreadState(TypedDict):
|
|||
def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict:
|
||||
system_prompt = Prompter(prompt_template="chat").render(data=state)
|
||||
payload = [system_prompt] + state.get("messages", [])
|
||||
model = provision_langchain_model(str(payload), config, "chat")
|
||||
model = provision_langchain_model(str(payload), config, "chat", max_tokens=2000)
|
||||
ai_message = model.invoke(payload)
|
||||
return {"messages": ai_message}
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from open_notebook.prompter import Prompter
|
|||
from open_notebook.utils import token_count
|
||||
|
||||
|
||||
def provision_langchain_model(content, config, default_type) -> BaseChatModel:
|
||||
def provision_langchain_model(content, config, default_type, **kwargs) -> BaseChatModel:
|
||||
"""
|
||||
Returns the best model to use based on the context size and on whether there is a specific model being requested in Config.
|
||||
If context > 105_000, returns the large_context_model
|
||||
|
|
@ -21,11 +21,13 @@ def provision_langchain_model(content, config, default_type) -> BaseChatModel:
|
|||
logger.debug(
|
||||
f"Using large context model because the content has {tokens} tokens"
|
||||
)
|
||||
model = model_manager.get_default_model("large_context")
|
||||
model = model_manager.get_default_model("large_context", **kwargs)
|
||||
elif config.get("configurable", {}).get("model_id"):
|
||||
model = model_manager.get_model(config.get("configurable", {}).get("model_id"))
|
||||
model = model_manager.get_model(
|
||||
config.get("configurable", {}).get("model_id"), **kwargs
|
||||
)
|
||||
else:
|
||||
model = model_manager.get_default_model(default_type)
|
||||
model = model_manager.get_default_model(default_type, **kwargs)
|
||||
|
||||
assert isinstance(model, LanguageModel), f"Model is not a LanguageModel: {model}"
|
||||
return model.to_langchain()
|
||||
|
|
|
|||
Loading…
Reference in a new issue