From 9042b08ae3395e262633f6b541d4d640a56f8ce3 Mon Sep 17 00:00:00 2001 From: LUIS NOVO Date: Tue, 22 Oct 2024 18:24:24 -0300 Subject: [PATCH] add model router and improve prompts --- .env.example | 25 ++++-- open_notebook/domain.py | 4 +- open_notebook/graphs/chat.py | 49 +++--------- open_notebook/graphs/doc_query.py | 22 ++---- open_notebook/graphs/summary.py | 47 +++++------ open_notebook/graphs/utils.py | 43 ++++++++++ open_notebook/llm_router.py | 33 ++++++++ open_notebook/{language_models.py => llms.py} | 0 open_notebook/model_configs.py | 78 ------------------- open_notebook/prompter.py | 3 +- prompts/chat.jinja | 53 ++++--------- prompts/spr.jinja | 11 --- prompts/summarize.jinja | 41 +++++----- 13 files changed, 173 insertions(+), 236 deletions(-) create mode 100644 open_notebook/graphs/utils.py create mode 100644 open_notebook/llm_router.py rename open_notebook/{language_models.py => llms.py} (100%) delete mode 100644 open_notebook/model_configs.py delete mode 100644 prompts/spr.jinja diff --git a/.env.example b/.env.example index 3f60403..3363878 100644 --- a/.env.example +++ b/.env.example @@ -1,17 +1,27 @@ -# YOUR LLM API KEYS + +# DEFAULT MODEL_CONFIGURATIONS +DEFAULT_MODEL="openai/gpt-4o-mini" +SUMMARIZATION_MODEL="openai/gpt-4o-mini" +RETRIEVAL_MODEL="openai/gpt-4o-mini" + +# OPENAI +# USE MODEL NAMES AS "openai/" +# EXAMPLE - openai/gpt-4o-mini OPENAI_API_KEY= + +# ANTHROPIC +# USE MODEL NAMES AS "anthropic/" +# EXAMPLE - anthropic/claude-3-5-sonnet-20240620 ANTHROPIC_API_KEY= -# MODEL_CONFIGURATIONS -# Only OpenAI models are supported for now -DEFAULT_MODEL="gpt-4o-mini" # This is the default model used for all the features -SUMMARIZATION_MODEL="gpt-4o-mini" # This is the model used for summarization, defaults to the DEFAULT_MODEL if empty -RETRIEVAL_MODEL="gpt-4o-mini" # This is the model used for retrieval, defaults to the DEFAULT_MODEL if empty - # OLLAMA +# USE MODEL NAMES AS "ollama/" +# EXAMPLE - ollama/gemma2 OLLAMA_API_BASE="http://10.20.30.20:11434" # OPEN ROUTER +# USE MODEL NAMES AS "openrouter/" +# EXAMPLE - openrouter/nvidia/llama-3.1-nemotron-70b-instruct OPENROUTER_BASE_URL="https://openrouter.ai/api/v1" OPENROUTER_API_KEY= @@ -21,7 +31,6 @@ OPENROUTER_API_KEY= # LANGCHAIN_API_KEY= # LANGCHAIN_PROJECT="Open Notebook" - # CONNECTION DETAILS FOR YOUR SURREAL DB SURREAL_ADDRESS="ws://localhost:8000/rpc" SURREAL_USER="root" diff --git a/open_notebook/domain.py b/open_notebook/domain.py index d94c620..5141a65 100644 --- a/open_notebook/domain.py +++ b/open_notebook/domain.py @@ -345,7 +345,7 @@ class Source(ObjectModel): try: config = RunnableConfig(configurable=dict(thread_id=self.id)) result = summarizer.invoke({"content": self.full_text}, config=config)[ - "summary" + "output" ] self._add_insight("summary", surreal_clean(result.summary)) self.title = surreal_clean(result.title) @@ -355,7 +355,7 @@ class Source(ObjectModel): except Exception as e: logger.error(f"Error summarizing source {self.id}: {str(e)}") logger.exception(e) - raise DatabaseOperationError("Failed to summarize source") + raise DatabaseOperationError(e) class Note(ObjectModel): diff --git a/open_notebook/graphs/chat.py b/open_notebook/graphs/chat.py index 8852c0d..3be8f31 100644 --- a/open_notebook/graphs/chat.py +++ b/open_notebook/graphs/chat.py @@ -1,39 +1,17 @@ import os import sqlite3 -from typing import Annotated, List, Optional +from typing import Annotated, Optional from langchain_core.runnables import ( RunnableConfig, ) -from langchain_openai import ChatOpenAI from langgraph.checkpoint.sqlite import SqliteSaver -from langgraph.graph import START, StateGraph +from langgraph.graph import END, START, StateGraph from langgraph.graph.message import add_messages -from langgraph.prebuilt import ToolNode, tools_condition -from loguru import logger -from pydantic import BaseModel, Field from typing_extensions import TypedDict from open_notebook.domain import Notebook -from open_notebook.graphs.tools import ask_the_document, get_current_timestamp -from open_notebook.prompter import Prompter - -tools = [get_current_timestamp, ask_the_document] -tool_node = ToolNode(tools) - - -class ChatResponse(BaseModel): - """Respond to the user with this""" - - title: Optional[str] = Field( - description="A title to be used if your question would become a new note on the project" - ) - message: str = Field( - description="The actual message you'd like to reply to the user" - ) - citations: Optional[List[str]] = Field( - description="The ids for the documents you used to formulate your answer" - ) +from open_notebook.graphs.utils import run_pattern class ThreadState(TypedDict): @@ -41,17 +19,16 @@ class ThreadState(TypedDict): notebook: Optional[Notebook] context: Optional[str] context_config: Optional[dict] - response: Optional[ChatResponse] def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict: - model = ChatOpenAI(model=os.environ["DEFAULT_MODEL"], temperature=0).bind_tools( - tools + model_name = config.get("configurable", {}).get("model_name", None) + ai_message = run_pattern( + "chat", + model_name, + messages=state["messages"], + state=state, ) - messages = state["messages"] - system_prompt = Prompter(prompt_template="chat").render(data=state) - logger.warning(f"System prompt: {system_prompt}") - ai_message = model.invoke([system_prompt] + messages) return {"messages": ai_message} @@ -63,12 +40,6 @@ memory = SqliteSaver(conn) agent_state = StateGraph(ThreadState) agent_state.add_node("agent", call_model_with_messages) -agent_state.add_node("tools", tool_node) agent_state.add_edge(START, "agent") -agent_state.add_conditional_edges( - "agent", - tools_condition, -) -agent_state.add_edge("tools", "agent") - +agent_state.add_edge("agent", END) graph = agent_state.compile(checkpointer=memory) diff --git a/open_notebook/graphs/doc_query.py b/open_notebook/graphs/doc_query.py index 068d8a4..2c673db 100644 --- a/open_notebook/graphs/doc_query.py +++ b/open_notebook/graphs/doc_query.py @@ -4,12 +4,10 @@ from langchain_core.runnables import ( RunnableConfig, ) from langgraph.graph import END, START, StateGraph -from loguru import logger from typing_extensions import TypedDict from open_notebook.domain import Note, Notebook, Source -from open_notebook.model_configs import get_langchain_model -from open_notebook.prompter import Prompter +from open_notebook.graphs.utils import run_pattern class DocQueryState(TypedDict): @@ -20,17 +18,11 @@ class DocQueryState(TypedDict): notebook: Notebook -def call_model_with_messages(state: DocQueryState, config: RunnableConfig) -> dict: - if config.get("configurable", {}).get("model_name", None): - model_name = config.get("configurable", {}).get("model_name", None) - else: - model_name = os.environ.get("RETRIEVAL_MODEL", os.environ["DEFAULT_MODEL"]) - - model = get_langchain_model(model_name) - system_prompt = Prompter(prompt_template="doc_query").render(data=state) - logger.debug(f"System prompt: {system_prompt}") - ai_message = model.invoke(system_prompt) - return {"answer": ai_message} +def call_model(state: dict, config: RunnableConfig) -> dict: + model_name = config.get("configurable", {}).get( + "model_name", os.environ.get("RETRIEVAL_MODEL") + ) + return {"answer": run_pattern("doc_query", model_name, state)} # todo: there is probably a better way to do this and avoid repetition @@ -46,7 +38,7 @@ def get_content(state: DocQueryState) -> dict: agent_state = StateGraph(DocQueryState) agent_state.add_node("get_content", get_content) -agent_state.add_node("agent", call_model_with_messages) +agent_state.add_node("agent", call_model) agent_state.add_edge(START, "get_content") agent_state.add_edge("get_content", "agent") agent_state.add_edge("agent", END) diff --git a/open_notebook/graphs/summary.py b/open_notebook/graphs/summary.py index 1c771b7..27d6396 100644 --- a/open_notebook/graphs/summary.py +++ b/open_notebook/graphs/summary.py @@ -1,35 +1,30 @@ import os from typing import List, Literal +from langchain_core.output_parsers import PydanticOutputParser from langchain_core.runnables import ( RunnableConfig, ) -from langchain_openai import ChatOpenAI from langgraph.graph import END, START, StateGraph -from langgraph.prebuilt import ToolNode -from pydantic import BaseModel, Field +from pydantic import BaseModel from typing_extensions import TypedDict -from open_notebook.graphs.tools import get_current_timestamp -from open_notebook.prompter import Prompter +from open_notebook.graphs.utils import run_pattern from open_notebook.utils import split_text -tools = [get_current_timestamp] -tool_node = ToolNode(tools) - class SummaryResponse(BaseModel): - """Respond to the user with this""" + """This is schema of your response. Please provide a JSON object with the enclosed keys""" - summary: str = Field(description="The summary of the content") - topics: List[str] = Field(description="List of 4-7 topics related to this content") - title: str = Field(description="The title of the content") + summary: str + topics: List[str] + title: str class SummaryState(TypedDict): chunks: List[str] content: str - summary: SummaryResponse + output: SummaryResponse def build_chunks(state: SummaryState) -> dict: @@ -63,19 +58,19 @@ def chunk_condition(state: SummaryState) -> Literal["get_chunk", END]: # type: # todo: build a helper method for LLM communication on all graphs -def call_model_with_messages(state: SummaryState, config: RunnableConfig) -> dict: - model = ( - ChatOpenAI( - model=os.environ.get("SUMMARIZATION_MODEL", os.environ["DEFAULT_MODEL"]), - temperature=0, - ) - .bind_tools(tools) - .with_structured_output(SummaryResponse) +def call_model(state: SummaryState, config: RunnableConfig) -> dict: + model_name = config.get("configurable", {}).get( + "model_name", os.environ.get("SUMMARIZATION_MODEL") ) - - system_prompt = Prompter(prompt_template="summarize").render(data=state) - ai_message = model.invoke(system_prompt) - return {"summary": ai_message} + parser = PydanticOutputParser(pydantic_object=SummaryResponse) + return { + "output": run_pattern( + pattern_name="summarize", + model_name=model_name, + state=state, + parser=parser, + ) + } agent_state = StateGraph(SummaryState) @@ -86,7 +81,7 @@ agent_state.add_conditional_edges( chunk_condition, ) agent_state.add_node("get_chunk", setup_next_chunk) -agent_state.add_node("agent", call_model_with_messages) +agent_state.add_node("agent", call_model) agent_state.add_edge("get_chunk", "agent") agent_state.add_conditional_edges( "agent", diff --git a/open_notebook/graphs/utils.py b/open_notebook/graphs/utils.py new file mode 100644 index 0000000..8c51458 --- /dev/null +++ b/open_notebook/graphs/utils.py @@ -0,0 +1,43 @@ +import os + +from langchain.output_parsers import OutputFixingParser +from loguru import logger + +from open_notebook.llm_router import get_langchain_model +from open_notebook.prompter import Prompter + + +def run_pattern( + pattern_name: str, + model_name=None, + messages=[], + state: dict = {}, + parser=None, + output_fixing_model_name=None, +) -> dict: + if not model_name: + model_name = os.environ["DEFAULT_MODEL"] + + chain = get_langchain_model(model_name) + if parser: + chain = chain | parser + + if output_fixing_model_name and parser: + output_fix_model = get_langchain_model(output_fixing_model_name) + chain = chain | OutputFixingParser.from_llm( + parser=parser, + llm=output_fix_model, + ) + + system_prompt = Prompter(prompt_template=pattern_name, parser=parser).render( + data=state + ) + # logger.debug(f"System prompt: {system_prompt}") + + if len(messages) > 0: + logger.warning(messages) + response = chain.invoke([system_prompt] + messages) + else: + response = chain.invoke(system_prompt) + + return response diff --git a/open_notebook/llm_router.py b/open_notebook/llm_router.py new file mode 100644 index 0000000..9fdb85d --- /dev/null +++ b/open_notebook/llm_router.py @@ -0,0 +1,33 @@ +from open_notebook.llms import ( + AnthropicLanguageModel, + LiteLLMLanguageModel, + OllamaLanguageModel, + OpenAILanguageModel, + OpenRouterLanguageModel, + VertexAILanguageModel, + VertexAnthropicLanguageModel, +) + +# Map provider names to classes +PROVIDER_CLASS_MAP = { + "ollama": OllamaLanguageModel, + "openrouter": OpenRouterLanguageModel, + "vertexai-anthropic": VertexAnthropicLanguageModel, + "litellm": LiteLLMLanguageModel, + "vertexai": VertexAILanguageModel, + "anthropic": AnthropicLanguageModel, + "openai": OpenAILanguageModel, +} + + +def get_langchain_model(model_name, json=False): + parts = model_name.split("/") + provider = parts[0] + model_name_wihout_provider = "/".join(parts[1:]) + if provider not in PROVIDER_CLASS_MAP.keys(): + raise ValueError( + f"Provider {provider} not found in config. Make sure you use the correct format for model names, example: openai/gpt-4o-mini" + ) + return PROVIDER_CLASS_MAP[provider]( + model_name=model_name_wihout_provider, json=json + ).to_langchain() diff --git a/open_notebook/language_models.py b/open_notebook/llms.py similarity index 100% rename from open_notebook/language_models.py rename to open_notebook/llms.py diff --git a/open_notebook/model_configs.py b/open_notebook/model_configs.py deleted file mode 100644 index ce9fdd1..0000000 --- a/open_notebook/model_configs.py +++ /dev/null @@ -1,78 +0,0 @@ -from open_notebook.language_models import ( - AnthropicLanguageModel, - LiteLLMLanguageModel, - OllamaLanguageModel, - OpenAILanguageModel, - OpenRouterLanguageModel, - VertexAILanguageModel, - VertexAnthropicLanguageModel, -) - -LANGUAGE_MODEL_CONFIG = { - "OLLAMA": { - "class": OllamaLanguageModel, - "models": [ - "mistral-nemo:latest", - "llama3.1:8b", - "qwen2.5:32b", - "nemotron-mini:latest", - "phi3.5:latest", - "gemma2", - "bling-phi-3.gguf", - "granite3-dense:8b", - "granite3-moe:latest", - "hermes3", - "llama3.2", - "phi3.5:3.8b-mini-instruct-fp16", - "phi3:14b", - "wizardlm2", - "zephyr", - "solar-pro", - ], - }, - "OPEN_ROUTER": { - "class": OpenRouterLanguageModel, - "models": [ - "nvidia/llama-3.1-nemotron-70b-instruct", - "anthropic/claude-3.5-sonnet", - "google/gemini-flash-1.5", - ], - }, - "VERTEX_ANTHROPIC": { - "class": VertexAnthropicLanguageModel, - "models": ["claude-3-5-sonnet@20240620"], - }, - "LITELLM": { - "class": LiteLLMLanguageModel, - "models": ["ollama/mistral-nemo:latest", "ollama/llama3.1:8b"], - }, - "VERTEX_AI": { - "class": VertexAILanguageModel, - "models": ["gemini-1.5-flash-001", "gemini-1.5-pro-001"], - }, - "ANTHROPIC": { - "class": AnthropicLanguageModel, - "models": ["claude-3-5-sonnet-20240620"], - }, - "OPEN_AI": {"class": OpenAILanguageModel, "models": ["gpt-4o-mini", "gpt-4o"]}, -} - -# EMBEDDING_MODEL_CONFIG = { -# "OPEN_AI": { -# "class": OpenAIEmbeddingModel, -# "models": ["text-embedding-3-large"], -# "dimensions": [3072], -# }, -# } - - -def get_model_class(model_name): - for config in LANGUAGE_MODEL_CONFIG.values(): - if model_name in config["models"]: - return config["class"] - raise ValueError(f"Model {model_name} not found in config") - - -def get_langchain_model(model_name, json=False): - model_class = get_model_class(model_name=model_name) - return model_class(model_name=model_name, json=json).to_langchain() diff --git a/open_notebook/prompter.py b/open_notebook/prompter.py index 6782ec3..fc10679 100644 --- a/open_notebook/prompter.py +++ b/open_notebook/prompter.py @@ -30,7 +30,7 @@ class Prompter: template: Optional[Union[str, Template]] = None parser: Optional[Any] = None - def __init__(self, prompt_template=None, prompt_text=None): + def __init__(self, prompt_template=None, prompt_text=None, parser=None): """ Initialize the Prompter with either a template file or raw text. @@ -40,6 +40,7 @@ class Prompter: """ self.prompt_template = prompt_template self.prompt_text = prompt_text + self.parser = parser self.setup() def setup(self): diff --git a/prompts/chat.jinja b/prompts/chat.jinja index 5593915..0ba3471 100644 --- a/prompts/chat.jinja +++ b/prompts/chat.jinja @@ -1,45 +1,22 @@ +# SYSTEM ROLE +You are a cognitive study assistant that helps users research and learn by engaging in focused discussions about documents in their workspace. You have access to project context and can analyze documents in detail using specialized tools. -# BACKGROUND +# CAPABILITIES +- Access to project information and selected documents (CONTEXT) +- Can engage in natural dialogue while maintaining academic rigor -Your are a cognitive assistant that helps me study and research. +# FORMULATE YOUR DATA +- Generate your answer based on the CONTEXT information +- Ensure that your response is accurate and relevant to the user's query -# OUR WORKING FRAMEWORK +{% if notebook %} +# PROJECT INFORMATION -We are working within a virtual Notebook, -which is a learning workspace for a specific project. - -You have access to some information about the project, -the contents that are selected for discussion, and relevant contexts. - -Your goal is to respond to the user's commands and questions, -using purely the content in your context. - -# YOUR TOOLS - -You might find that some of the documents in the CONTEXT are worth an extra look. For that, you can use the `ask_the_document` tool. -Just ask the question as if you were talking to someone that knows the document deeply and the tool will provide you with the answer. -Use the document id to specify which source or note you'd like to ask about. - -# INSTRUCTIONS - -- You can ask tools until you are satisfied with the information -- You have a optional field in your answer called title. Only use this field if you believe your answer is important to be saved as a note. If it's just a quick chat, send an empty string to it. -- Please add to the citations list all the ids for sources and notes that you used for your anwer. - -# RESPONSE FORMAT -``` -[ANSWER] - -### CITATIONS -- id1 -- id2 -- etc -``` - -# PROJECT INFO - -{{ notebook }} +{{notebook}} +{% endif %} +{% if context %} # CONTEXT -{{ context }} +{{context}} +{% endif %} \ No newline at end of file diff --git a/prompts/spr.jinja b/prompts/spr.jinja deleted file mode 100644 index 441e0e1..0000000 --- a/prompts/spr.jinja +++ /dev/null @@ -1,11 +0,0 @@ - -# MISSION -You are a Sparse Priming Representation (SPR) writer. An SPR is a particular kind of use of language for advanced NLP, NLU, and NLG tasks, particularly useful for the latest generation of Large Language Models (LLMs). You will be given information by the USER which you are to render as an SPR. - -# THEORY -LLMs are a kind of deep neural network. They have been demonstrated to embed knowledge, abilities, and concepts, ranging from reasoning to planning, and even to theory of mind. These are called latent abilities and latent content, collectively referred to as latent space. The latent space of an LLM can be activated with the correct series of words as inputs, which will create a useful internal state of the neural network. This is not unlike how the right shorthand cues can prime a human mind to think in a certain way. Like human minds, LLMs are associative, meaning you only need to use the correct associations to "prime" another model to think in the same way. - -# METHODOLOGY -Render the input as a distilled list of succinct statements, assertions, associations, concepts, analogies, and metaphors. The idea is to capture as much, conceptually, as possible but with as few words as possible. Write it in a way that makes sense to you, as the future audience will be another language model, not a human. Use complete sentences. - -{# thanks to https://github.com/daveshap/SparsePrimingRepresentations #} \ No newline at end of file diff --git a/prompts/summarize.jinja b/prompts/summarize.jinja index c106dea..f8b65ab 100644 --- a/prompts/summarize.jinja +++ b/prompts/summarize.jinja @@ -1,28 +1,33 @@ -{% include "spr.jinja" %} +# SYSTEM ROLE +You are a content summarization assistant that creates dense, information-rich summaries optimized for machine understanding. Your summaries should capture key concepts with minimal words while maintaining complete, clear sentences. -# YOUR TASK +# TASK +Analyze the provided content and create a summary that: +- Captures the core concepts and key information +- Uses clear, direct language +- Maintains context from any previous summaries +- Includes relevant topics/tags +- Creates an appropriate title -You are part of a content summarization platform. -Sometimes, you need to summarize the content gradually since it might be very big. -Please summarize the content below in a few sentences, making it the most complete, dense and SPR compatible as you can. +# OUTPUT SCHEMA +{'summary': {'type': 'string'}, + 'topics': {'items': {'type': 'string'}, 'type': 'array'}, + 'title': {'type': 'string'}} -## INSTRUCTIONS +# OUTPUT EXAMPLE +{ + "title": "The title of the content", + "topics": ["topic1", "topic2"], + "summary": "The summary of the content" +} -- If the content already has a current summary, rewrite the summary to add the new information without losing the previous context -- Always make it dense and SPR compatible -- Do not reply with anything feedback or message other than the summary itself - -## FORMATTING INSTRUCTIONS - -{{ format_instructions }} - -## CONTENT +# CONTENT {{content}} -## PREVIOUS SUMMARY +{% if summary %} +# PREVIOUS SUMMARY {{summary}} - -## SUMMARY \ No newline at end of file +{% endif %}