diff --git a/open_notebook/graphs/chat.py b/open_notebook/graphs/chat.py index 84a30be..88f5cb1 100644 --- a/open_notebook/graphs/chat.py +++ b/open_notebook/graphs/chat.py @@ -22,12 +22,12 @@ class ThreadState(TypedDict): def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict: - model_name = config.get("configurable", {}).get( - "model_name", DEFAULT_MODELS.default_chat_model + model_id = config.get("configurable", {}).get( + "model_id", DEFAULT_MODELS.default_chat_model ) ai_message = run_pattern( "chat", - model_name, + model_id, messages=state["messages"], state=state, ) diff --git a/open_notebook/graphs/doc_query.py b/open_notebook/graphs/doc_query.py index 9a2ffb9..34445e7 100644 --- a/open_notebook/graphs/doc_query.py +++ b/open_notebook/graphs/doc_query.py @@ -19,10 +19,10 @@ class DocQueryState(TypedDict): def call_model(state: dict, config: RunnableConfig) -> dict: - model_name = config.get("configurable", {}).get( - "model_name", os.environ.get("RETRIEVAL_MODEL") + model_id = config.get("configurable", {}).get( + "model_id", os.environ.get("RETRIEVAL_MODEL") ) - return {"answer": run_pattern("doc_query", model_name, state)} + return {"answer": run_pattern("doc_query", model_id, state)} # todo: there is probably a better way to do this and avoid repetition diff --git a/open_notebook/graphs/multipattern.py b/open_notebook/graphs/multipattern.py index e74d7a3..b1ef19b 100644 --- a/open_notebook/graphs/multipattern.py +++ b/open_notebook/graphs/multipattern.py @@ -18,8 +18,8 @@ class PatternChainState(TypedDict): def call_model(state: dict, config: RunnableConfig) -> dict: - model_name = config.get("configurable", {}).get( - "model_name", DEFAULT_MODELS.default_transformation_model + model_id = config.get("configurable", {}).get( + "model_id", DEFAULT_MODELS.default_transformation_model ) transformations = state["transformations"] current_transformation = transformations.pop(0) @@ -34,7 +34,7 @@ def call_model(state: dict, config: RunnableConfig) -> dict: transformation_result = run_pattern( pattern_name=current_transformation, - model_name=model_name, + model_id=model_id, state=input_args, ) return { diff --git a/open_notebook/graphs/pattern.py b/open_notebook/graphs/pattern.py index c47cc14..faf15f9 100644 --- a/open_notebook/graphs/pattern.py +++ b/open_notebook/graphs/pattern.py @@ -15,13 +15,13 @@ class PatternState(TypedDict): def call_model(state: dict, config: RunnableConfig) -> dict: - model_name = config.get("configurable", {}).get( - "model_name", DEFAULT_MODELS.default_transformation_model + model_id = config.get("configurable", {}).get( + "model_id", DEFAULT_MODELS.default_transformation_model ) return { "output": run_pattern( pattern_name=state["pattern"], - model_name=model_name, + model_id=model_id, state=state, ) } diff --git a/open_notebook/graphs/recursive_toc.py b/open_notebook/graphs/recursive_toc.py index a9cb795..f52ceae 100644 --- a/open_notebook/graphs/recursive_toc.py +++ b/open_notebook/graphs/recursive_toc.py @@ -49,13 +49,13 @@ def chunk_condition(state: TocState) -> Literal["get_chunk", END]: # type: igno def call_model(state: TocState, config: RunnableConfig) -> dict: - model_name = config.get("configurable", {}).get( - "model_name", DEFAULT_MODELS.default_transformation_model + model_id = config.get("configurable", {}).get( + "model_id", DEFAULT_MODELS.default_transformation_model ) return { "toc": run_pattern( pattern_name="recursive_toc", - model_name=model_name, + model_id=model_id, state=state, ).content } diff --git a/open_notebook/graphs/summary.py b/open_notebook/graphs/summary.py index df54ff5..8cd0ec1 100644 --- a/open_notebook/graphs/summary.py +++ b/open_notebook/graphs/summary.py @@ -59,14 +59,14 @@ def chunk_condition(state: SummaryState) -> Literal["get_chunk", END]: # type: def call_model(state: dict, config: RunnableConfig) -> dict: - model_name = config.get("configurable", {}).get( - "model_name", DEFAULT_MODELS.default_transformation_model + model_id = config.get("configurable", {}).get( + "model_id", DEFAULT_MODELS.default_transformation_model ) parser = PydanticOutputParser(pydantic_object=SummaryResponse) return { "output": run_pattern( pattern_name="summarize", - model_name=model_name, + model_id=model_id, state=state, parser=parser, ) diff --git a/open_notebook/graphs/utils.py b/open_notebook/graphs/utils.py index f004e7f..cdba57a 100644 --- a/open_notebook/graphs/utils.py +++ b/open_notebook/graphs/utils.py @@ -9,36 +9,41 @@ from open_notebook.utils import token_count def run_pattern( pattern_name: str, - model_name=None, + model_id=None, messages=[], state: dict = {}, parser=None, - output_fixing_model_name=None, + output_fixing_model_id=None, ) -> dict: system_prompt = Prompter(prompt_template=pattern_name, parser=parser).render( data=state ) tokens = token_count(str(system_prompt) + str(messages)) - if tokens > 105_000 and DEFAULT_MODELS.large_context_model: - model_name = DEFAULT_MODELS.large_context_model - logger.debug( - f"Using large context model ({model_name}) because the content has {tokens} tokens" - ) - logger.warning(system_prompt) - elif tokens > 105_000 and not DEFAULT_MODELS.large_context_model: - logger.critical( - f"Content has {tokens} tokens, but no large context model is configured" - ) - elif not model_name: - model_name = DEFAULT_MODELS.default_transformation_model - chain = get_model(model_name, model_type="language") + model_id = ( + DEFAULT_MODELS.large_context_model + or DEFAULT_MODELS.default_transformation_model + or DEFAULT_MODELS.default_chat_model + ) + if tokens > 105_000: + model_id = DEFAULT_MODELS.large_context_model + logger.debug( + f"Using large context model ({model_id}) because the content has {tokens} tokens" + ) + + model_id = ( + model_id + or DEFAULT_MODELS.default_transformation_model + or DEFAULT_MODELS.default_chat_model + ) + + chain = get_model(model_id, model_type="language") if parser: chain = chain | parser - if output_fixing_model_name and parser: - output_fix_model = get_model(output_fixing_model_name, model_type="language") + if output_fixing_model_id and parser: + output_fix_model = get_model(output_fixing_model_id, model_type="language") chain = chain | OutputFixingParser.from_llm( parser=parser, llm=output_fix_model,