diff --git a/open_notebook/graphs/chat.py b/open_notebook/graphs/chat.py index 5e3b4ca..7342ca3 100644 --- a/open_notebook/graphs/chat.py +++ b/open_notebook/graphs/chat.py @@ -25,7 +25,12 @@ class ThreadState(TypedDict): def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict: system_prompt = Prompter(prompt_template="chat").render(data=state) payload = [system_prompt] + state.get("messages", []) - model = provision_langchain_model(str(payload), config, "chat", max_tokens=2000) + model = provision_langchain_model( + str(payload), + config.get("configurable", {}).get("model_id"), + "chat", + max_tokens=2000, + ) ai_message = model.invoke(payload) return {"messages": ai_message} diff --git a/open_notebook/graphs/tools.py b/open_notebook/graphs/tools.py index 620fac4..9c3df13 100644 --- a/open_notebook/graphs/tools.py +++ b/open_notebook/graphs/tools.py @@ -1,10 +1,7 @@ from datetime import datetime -from typing import List from langchain.tools import tool -from open_notebook.domain.notebook import hybrid_search - # todo: turn this into a system prompt variable @tool @@ -14,14 +11,3 @@ def get_current_timestamp() -> str: Returns the current timestamp in the format YYYYMMDDHHmmss. """ return datetime.now().strftime("%Y%m%d%H%M%S") - - -@tool -def repository_search(keyword_searches: List[str], vector_searches: List[str]) -> str: - """ - name: repository_search - Makes a search in the content repository for the given query. - keyword_searches: List[str] - A list of search terms to search for using keyword search. - vector_searches: List[str] - A list of search terms to search for using vector search. - """ - return hybrid_search(keyword_searches, vector_searches, 20) diff --git a/open_notebook/graphs/utils.py b/open_notebook/graphs/utils.py index 07365ea..3c79d85 100644 --- a/open_notebook/graphs/utils.py +++ b/open_notebook/graphs/utils.py @@ -8,7 +8,9 @@ from open_notebook.prompter import Prompter from open_notebook.utils import token_count -def provision_langchain_model(content, config, default_type, **kwargs) -> BaseChatModel: +def provision_langchain_model( + content, model_id, default_type, **kwargs +) -> BaseChatModel: """ Returns the best model to use based on the context size and on whether there is a specific model being requested in Config. If context > 105_000, returns the large_context_model @@ -22,10 +24,8 @@ def provision_langchain_model(content, config, default_type, **kwargs) -> BaseCh f"Using large context model because the content has {tokens} tokens" ) model = model_manager.get_default_model("large_context", **kwargs) - elif config.get("configurable", {}).get("model_id"): - model = model_manager.get_model( - config.get("configurable", {}).get("model_id"), **kwargs - ) + elif model_id: + model = model_manager.get_model(model_id, **kwargs) else: model = model_manager.get_default_model(default_type, **kwargs) @@ -45,7 +45,9 @@ def run_pattern( data=state ) payload = [system_prompt] + messages - chain = provision_langchain_model(str(payload), config, "transformation") + chain = provision_langchain_model( + str(payload), config.get("configurable", {}).get("model_id"), "transformation" + ) response = chain.invoke(payload)