add model router and improve prompts

This commit is contained in:
LUIS NOVO 2024-10-22 18:24:24 -03:00
parent f96fc580b3
commit 9042b08ae3
13 changed files with 173 additions and 236 deletions

View file

@ -1,17 +1,27 @@
# YOUR LLM API KEYS
# DEFAULT MODEL_CONFIGURATIONS
DEFAULT_MODEL="openai/gpt-4o-mini"
SUMMARIZATION_MODEL="openai/gpt-4o-mini"
RETRIEVAL_MODEL="openai/gpt-4o-mini"
# OPENAI
# USE MODEL NAMES AS "openai/<modelname>"
# EXAMPLE - openai/gpt-4o-mini
OPENAI_API_KEY=
# ANTHROPIC
# USE MODEL NAMES AS "anthropic/<modelname>"
# EXAMPLE - anthropic/claude-3-5-sonnet-20240620
ANTHROPIC_API_KEY=
# MODEL_CONFIGURATIONS
# Only OpenAI models are supported for now
DEFAULT_MODEL="gpt-4o-mini" # This is the default model used for all the features
SUMMARIZATION_MODEL="gpt-4o-mini" # This is the model used for summarization, defaults to the DEFAULT_MODEL if empty
RETRIEVAL_MODEL="gpt-4o-mini" # This is the model used for retrieval, defaults to the DEFAULT_MODEL if empty
# OLLAMA
# USE MODEL NAMES AS "ollama/<modelname>"
# EXAMPLE - ollama/gemma2
OLLAMA_API_BASE="http://10.20.30.20:11434"
# OPEN ROUTER
# USE MODEL NAMES AS "openrouter/<modelname>"
# EXAMPLE - openrouter/nvidia/llama-3.1-nemotron-70b-instruct
OPENROUTER_BASE_URL="https://openrouter.ai/api/v1"
OPENROUTER_API_KEY=
@ -21,7 +31,6 @@ OPENROUTER_API_KEY=
# LANGCHAIN_API_KEY=
# LANGCHAIN_PROJECT="Open Notebook"
# CONNECTION DETAILS FOR YOUR SURREAL DB
SURREAL_ADDRESS="ws://localhost:8000/rpc"
SURREAL_USER="root"

View file

@ -345,7 +345,7 @@ class Source(ObjectModel):
try:
config = RunnableConfig(configurable=dict(thread_id=self.id))
result = summarizer.invoke({"content": self.full_text}, config=config)[
"summary"
"output"
]
self._add_insight("summary", surreal_clean(result.summary))
self.title = surreal_clean(result.title)
@ -355,7 +355,7 @@ class Source(ObjectModel):
except Exception as e:
logger.error(f"Error summarizing source {self.id}: {str(e)}")
logger.exception(e)
raise DatabaseOperationError("Failed to summarize source")
raise DatabaseOperationError(e)
class Note(ObjectModel):

View file

@ -1,39 +1,17 @@
import os
import sqlite3
from typing import Annotated, List, Optional
from typing import Annotated, Optional
from langchain_core.runnables import (
RunnableConfig,
)
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import START, StateGraph
from langgraph.graph import END, START, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode, tools_condition
from loguru import logger
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
from open_notebook.domain import Notebook
from open_notebook.graphs.tools import ask_the_document, get_current_timestamp
from open_notebook.prompter import Prompter
tools = [get_current_timestamp, ask_the_document]
tool_node = ToolNode(tools)
class ChatResponse(BaseModel):
"""Respond to the user with this"""
title: Optional[str] = Field(
description="A title to be used if your question would become a new note on the project"
)
message: str = Field(
description="The actual message you'd like to reply to the user"
)
citations: Optional[List[str]] = Field(
description="The ids for the documents you used to formulate your answer"
)
from open_notebook.graphs.utils import run_pattern
class ThreadState(TypedDict):
@ -41,17 +19,16 @@ class ThreadState(TypedDict):
notebook: Optional[Notebook]
context: Optional[str]
context_config: Optional[dict]
response: Optional[ChatResponse]
def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict:
model = ChatOpenAI(model=os.environ["DEFAULT_MODEL"], temperature=0).bind_tools(
tools
model_name = config.get("configurable", {}).get("model_name", None)
ai_message = run_pattern(
"chat",
model_name,
messages=state["messages"],
state=state,
)
messages = state["messages"]
system_prompt = Prompter(prompt_template="chat").render(data=state)
logger.warning(f"System prompt: {system_prompt}")
ai_message = model.invoke([system_prompt] + messages)
return {"messages": ai_message}
@ -63,12 +40,6 @@ memory = SqliteSaver(conn)
agent_state = StateGraph(ThreadState)
agent_state.add_node("agent", call_model_with_messages)
agent_state.add_node("tools", tool_node)
agent_state.add_edge(START, "agent")
agent_state.add_conditional_edges(
"agent",
tools_condition,
)
agent_state.add_edge("tools", "agent")
agent_state.add_edge("agent", END)
graph = agent_state.compile(checkpointer=memory)

View file

@ -4,12 +4,10 @@ from langchain_core.runnables import (
RunnableConfig,
)
from langgraph.graph import END, START, StateGraph
from loguru import logger
from typing_extensions import TypedDict
from open_notebook.domain import Note, Notebook, Source
from open_notebook.model_configs import get_langchain_model
from open_notebook.prompter import Prompter
from open_notebook.graphs.utils import run_pattern
class DocQueryState(TypedDict):
@ -20,17 +18,11 @@ class DocQueryState(TypedDict):
notebook: Notebook
def call_model_with_messages(state: DocQueryState, config: RunnableConfig) -> dict:
if config.get("configurable", {}).get("model_name", None):
model_name = config.get("configurable", {}).get("model_name", None)
else:
model_name = os.environ.get("RETRIEVAL_MODEL", os.environ["DEFAULT_MODEL"])
model = get_langchain_model(model_name)
system_prompt = Prompter(prompt_template="doc_query").render(data=state)
logger.debug(f"System prompt: {system_prompt}")
ai_message = model.invoke(system_prompt)
return {"answer": ai_message}
def call_model(state: dict, config: RunnableConfig) -> dict:
    """Answer the document query by running the "doc_query" pattern.

    Args:
        state: Graph state; it is handed to ``run_pattern`` as the data used
            to render the "doc_query" prompt.
        config: Runnable config. ``config["configurable"]["model_name"]``,
            when set, selects the model. Otherwise the ``RETRIEVAL_MODEL``
            environment variable is used, and if that is also unset
            ``run_pattern`` falls back to ``DEFAULT_MODEL``.

    Returns:
        A state update of the form ``{"answer": <model response>}``.
    """
    # `or` (rather than a .get default) so that an explicit but empty
    # model_name in the config still falls back to RETRIEVAL_MODEL.
    model_name = config.get("configurable", {}).get("model_name") or os.environ.get(
        "RETRIEVAL_MODEL"
    )
    # The state must be passed by keyword: run_pattern's third positional
    # parameter is `messages`, so a positional dict would be treated as the
    # chat history and never reach the prompt template.
    return {"answer": run_pattern("doc_query", model_name, state=state)}
# todo: there is probably a better way to do this and avoid repetition
@ -46,7 +38,7 @@ def get_content(state: DocQueryState) -> dict:
agent_state = StateGraph(DocQueryState)
agent_state.add_node("get_content", get_content)
agent_state.add_node("agent", call_model_with_messages)
agent_state.add_node("agent", call_model)
agent_state.add_edge(START, "get_content")
agent_state.add_edge("get_content", "agent")
agent_state.add_edge("agent", END)

View file

@ -1,35 +1,30 @@
import os
from typing import List, Literal
from langchain_core.output_parsers import PydanticOutputParser
from langchain_core.runnables import (
RunnableConfig,
)
from langchain_openai import ChatOpenAI
from langgraph.graph import END, START, StateGraph
from langgraph.prebuilt import ToolNode
from pydantic import BaseModel, Field
from pydantic import BaseModel
from typing_extensions import TypedDict
from open_notebook.graphs.tools import get_current_timestamp
from open_notebook.prompter import Prompter
from open_notebook.graphs.utils import run_pattern
from open_notebook.utils import split_text
tools = [get_current_timestamp]
tool_node = ToolNode(tools)
class SummaryResponse(BaseModel):
"""Respond to the user with this"""
"""This is schema of your response. Please provide a JSON object with the enclosed keys"""
summary: str = Field(description="The summary of the content")
topics: List[str] = Field(description="List of 4-7 topics related to this content")
title: str = Field(description="The title of the content")
summary: str
topics: List[str]
title: str
class SummaryState(TypedDict):
chunks: List[str]
content: str
summary: SummaryResponse
output: SummaryResponse
def build_chunks(state: SummaryState) -> dict:
@ -63,19 +58,19 @@ def chunk_condition(state: SummaryState) -> Literal["get_chunk", END]: # type:
# todo: build a helper method for LLM communication on all graphs
def call_model_with_messages(state: SummaryState, config: RunnableConfig) -> dict:
model = (
ChatOpenAI(
model=os.environ.get("SUMMARIZATION_MODEL", os.environ["DEFAULT_MODEL"]),
temperature=0,
)
.bind_tools(tools)
.with_structured_output(SummaryResponse)
def call_model(state: SummaryState, config: RunnableConfig) -> dict:
model_name = config.get("configurable", {}).get(
"model_name", os.environ.get("SUMMARIZATION_MODEL")
)
system_prompt = Prompter(prompt_template="summarize").render(data=state)
ai_message = model.invoke(system_prompt)
return {"summary": ai_message}
parser = PydanticOutputParser(pydantic_object=SummaryResponse)
return {
"output": run_pattern(
pattern_name="summarize",
model_name=model_name,
state=state,
parser=parser,
)
}
agent_state = StateGraph(SummaryState)
@ -86,7 +81,7 @@ agent_state.add_conditional_edges(
chunk_condition,
)
agent_state.add_node("get_chunk", setup_next_chunk)
agent_state.add_node("agent", call_model_with_messages)
agent_state.add_node("agent", call_model)
agent_state.add_edge("get_chunk", "agent")
agent_state.add_conditional_edges(
"agent",

View file

@ -0,0 +1,43 @@
import os
from langchain.output_parsers import OutputFixingParser
from loguru import logger
from open_notebook.llm_router import get_langchain_model
from open_notebook.prompter import Prompter
def run_pattern(
    pattern_name: str,
    model_name=None,
    messages=None,
    state: "dict | None" = None,
    parser=None,
    output_fixing_model_name=None,
):
    """Render a prompt pattern and run it through the selected model.

    Args:
        pattern_name: Name of the prompt template handed to ``Prompter``.
        model_name: Provider-prefixed model name (e.g. "openai/gpt-4o-mini").
            When empty, the ``DEFAULT_MODEL`` environment variable is used.
        messages: Optional list of chat messages appended after the rendered
            system prompt. ``None`` or an empty list sends the prompt alone.
        state: Data used to render the template. ``None`` means no data.
        parser: Optional output parser. It is passed to ``Prompter`` and
            appended to the model chain so the response comes back parsed.
        output_fixing_model_name: Optional model used to repair output the
            parser rejects. Only used when ``parser`` is given.

    Returns:
        Whatever the chain produces: the model's message when no parser is
        given, otherwise the parser's output.

    Raises:
        KeyError: If ``model_name`` is empty and ``DEFAULT_MODEL`` is not set.
    """
    if not model_name:
        model_name = os.environ["DEFAULT_MODEL"]

    # Fresh containers per call: mutable defaults in the signature would be
    # shared between invocations.
    messages = list(messages) if messages else []
    state = {} if state is None else state

    chain = get_langchain_model(model_name)

    if parser is not None:
        output_parser = parser
        if output_fixing_model_name:
            # The fixing parser wraps the original parser and must see the raw
            # model output, so it replaces the plain parser instead of being
            # chained after it (where it would receive an already-parsed
            # object and could never repair a failure).
            output_parser = OutputFixingParser.from_llm(
                parser=parser,
                llm=get_langchain_model(output_fixing_model_name),
            )
        chain = chain | output_parser

    system_prompt = Prompter(prompt_template=pattern_name, parser=parser).render(
        data=state
    )

    if messages:
        # Debug level: this dumps the whole conversation and is not a warning.
        logger.debug(messages)
        return chain.invoke([system_prompt] + messages)
    return chain.invoke(system_prompt)

View file

@ -0,0 +1,33 @@
from open_notebook.llms import (
AnthropicLanguageModel,
LiteLLMLanguageModel,
OllamaLanguageModel,
OpenAILanguageModel,
OpenRouterLanguageModel,
VertexAILanguageModel,
VertexAnthropicLanguageModel,
)
# Map provider names to classes.
# Each key is the provider prefix of a model name, i.e. the text before the
# first "/" (for example "openai" in "openai/gpt-4o-mini"). get_langchain_model
# looks that prefix up here to choose the class it instantiates.
PROVIDER_CLASS_MAP = {
"ollama": OllamaLanguageModel,
"openrouter": OpenRouterLanguageModel,
"vertexai-anthropic": VertexAnthropicLanguageModel,
"litellm": LiteLLMLanguageModel,
"vertexai": VertexAILanguageModel,
"anthropic": AnthropicLanguageModel,
"openai": OpenAILanguageModel,
}
def get_langchain_model(model_name, json=False):
    """Build the LangChain model for a provider-prefixed model name.

    Args:
        model_name: Name in the form "<provider>/<model>", e.g.
            "openai/gpt-4o-mini". Only the first "/" separates the provider;
            the remainder is kept intact, so names such as
            "openrouter/nvidia/llama-3.1-nemotron-70b-instruct" work.
        json: Forwarded unchanged to the provider class constructor.

    Returns:
        The result of ``to_langchain()`` on the provider class instance.

    Raises:
        ValueError: If the provider prefix is not in ``PROVIDER_CLASS_MAP``
            or no model name follows the provider.
    """
    provider, _, model_name_without_provider = model_name.partition("/")

    if provider not in PROVIDER_CLASS_MAP:
        raise ValueError(
            f"Provider {provider} not found in config. Make sure you use the correct format for model names, example: openai/gpt-4o-mini"
        )
    if not model_name_without_provider:
        # A bare provider such as "openai" would otherwise build a model with
        # an empty name and only fail later, at request time.
        raise ValueError(
            f"Model name missing for provider {provider}. Make sure you use the correct format for model names, example: openai/gpt-4o-mini"
        )

    model_class = PROVIDER_CLASS_MAP[provider]
    return model_class(
        model_name=model_name_without_provider, json=json
    ).to_langchain()

View file

@ -1,78 +0,0 @@
from open_notebook.language_models import (
AnthropicLanguageModel,
LiteLLMLanguageModel,
OllamaLanguageModel,
OpenAILanguageModel,
OpenRouterLanguageModel,
VertexAILanguageModel,
VertexAnthropicLanguageModel,
)
LANGUAGE_MODEL_CONFIG = {
"OLLAMA": {
"class": OllamaLanguageModel,
"models": [
"mistral-nemo:latest",
"llama3.1:8b",
"qwen2.5:32b",
"nemotron-mini:latest",
"phi3.5:latest",
"gemma2",
"bling-phi-3.gguf",
"granite3-dense:8b",
"granite3-moe:latest",
"hermes3",
"llama3.2",
"phi3.5:3.8b-mini-instruct-fp16",
"phi3:14b",
"wizardlm2",
"zephyr",
"solar-pro",
],
},
"OPEN_ROUTER": {
"class": OpenRouterLanguageModel,
"models": [
"nvidia/llama-3.1-nemotron-70b-instruct",
"anthropic/claude-3.5-sonnet",
"google/gemini-flash-1.5",
],
},
"VERTEX_ANTHROPIC": {
"class": VertexAnthropicLanguageModel,
"models": ["claude-3-5-sonnet@20240620"],
},
"LITELLM": {
"class": LiteLLMLanguageModel,
"models": ["ollama/mistral-nemo:latest", "ollama/llama3.1:8b"],
},
"VERTEX_AI": {
"class": VertexAILanguageModel,
"models": ["gemini-1.5-flash-001", "gemini-1.5-pro-001"],
},
"ANTHROPIC": {
"class": AnthropicLanguageModel,
"models": ["claude-3-5-sonnet-20240620"],
},
"OPEN_AI": {"class": OpenAILanguageModel, "models": ["gpt-4o-mini", "gpt-4o"]},
}
# EMBEDDING_MODEL_CONFIG = {
# "OPEN_AI": {
# "class": OpenAIEmbeddingModel,
# "models": ["text-embedding-3-large"],
# "dimensions": [3072],
# },
# }
def get_model_class(model_name):
for config in LANGUAGE_MODEL_CONFIG.values():
if model_name in config["models"]:
return config["class"]
raise ValueError(f"Model {model_name} not found in config")
def get_langchain_model(model_name, json=False):
model_class = get_model_class(model_name=model_name)
return model_class(model_name=model_name, json=json).to_langchain()

View file

@ -30,7 +30,7 @@ class Prompter:
template: Optional[Union[str, Template]] = None
parser: Optional[Any] = None
def __init__(self, prompt_template=None, prompt_text=None):
def __init__(self, prompt_template=None, prompt_text=None, parser=None):
"""
Initialize the Prompter with either a template file or raw text.
@ -40,6 +40,7 @@ class Prompter:
"""
self.prompt_template = prompt_template
self.prompt_text = prompt_text
self.parser = parser
self.setup()
def setup(self):

View file

@ -1,45 +1,22 @@
# SYSTEM ROLE
You are a cognitive study assistant that helps users research and learn by engaging in focused discussions about documents in their workspace. You have access to project context and can analyze documents in detail using specialized tools.
# BACKGROUND
# CAPABILITIES
- Access to project information and selected documents (CONTEXT)
- Can engage in natural dialogue while maintaining academic rigor
You are a cognitive assistant that helps me study and research.
# FORMULATE YOUR DATA
- Generate your answer based on the CONTEXT information
- Ensure that your response is accurate and relevant to the user's query
# OUR WORKING FRAMEWORK
{% if notebook %}
# PROJECT INFORMATION
We are working within a virtual Notebook,
which is a learning workspace for a specific project.
You have access to some information about the project,
the contents that are selected for discussion, and relevant contexts.
Your goal is to respond to the user's commands and questions,
using purely the content in your context.
# YOUR TOOLS
You might find that some of the documents in the CONTEXT are worth an extra look. For that, you can use the `ask_the_document` tool.
Just ask the question as if you were talking to someone that knows the document deeply and the tool will provide you with the answer.
Use the document id to specify which source or note you'd like to ask about.
# INSTRUCTIONS
- You can ask tools until you are satisfied with the information
- You have an optional field in your answer called title. Only use this field if you believe your answer is important enough to be saved as a note. If it's just a quick chat, send an empty string to it.
- Please add to the citations list all the ids for sources and notes that you used for your answer.
# RESPONSE FORMAT
```
[ANSWER]
### CITATIONS
- id1
- id2
- etc
```
# PROJECT INFO
{{ notebook }}
{{notebook}}
{% endif %}
{% if context %}
# CONTEXT
{{ context }}
{{context}}
{% endif %}

View file

@ -1,11 +0,0 @@
# MISSION
You are a Sparse Priming Representation (SPR) writer. An SPR is a particular kind of use of language for advanced NLP, NLU, and NLG tasks, particularly useful for the latest generation of Large Language Models (LLMs). You will be given information by the USER which you are to render as an SPR.
# THEORY
LLMs are a kind of deep neural network. They have been demonstrated to embed knowledge, abilities, and concepts, ranging from reasoning to planning, and even to theory of mind. These are called latent abilities and latent content, collectively referred to as latent space. The latent space of an LLM can be activated with the correct series of words as inputs, which will create a useful internal state of the neural network. This is not unlike how the right shorthand cues can prime a human mind to think in a certain way. Like human minds, LLMs are associative, meaning you only need to use the correct associations to "prime" another model to think in the same way.
# METHODOLOGY
Render the input as a distilled list of succinct statements, assertions, associations, concepts, analogies, and metaphors. The idea is to capture as much, conceptually, as possible but with as few words as possible. Write it in a way that makes sense to you, as the future audience will be another language model, not a human. Use complete sentences.
{# thanks to https://github.com/daveshap/SparsePrimingRepresentations #}

View file

@ -1,28 +1,33 @@
{% include "spr.jinja" %}
# SYSTEM ROLE
You are a content summarization assistant that creates dense, information-rich summaries optimized for machine understanding. Your summaries should capture key concepts with minimal words while maintaining complete, clear sentences.
# YOUR TASK
# TASK
Analyze the provided content and create a summary that:
- Captures the core concepts and key information
- Uses clear, direct language
- Maintains context from any previous summaries
- Includes relevant topics/tags
- Creates an appropriate title
You are part of a content summarization platform.
Sometimes, you need to summarize the content gradually since it might be very big.
Please summarize the content below in a few sentences, making it the most complete, dense and SPR compatible as you can.
# OUTPUT SCHEMA
{'summary': {'type': 'string'},
'topics': {'items': {'type': 'string'}, 'type': 'array'},
'title': {'type': 'string'}}
## INSTRUCTIONS
# OUTPUT EXAMPLE
{
"title": "The title of the content",
"topics": ["topic1", "topic2"],
"summary": "The summary of the content"
}
- If the content already has a current summary, rewrite the summary to add the new information without losing the previous context
- Always make it dense and SPR compatible
- Do not reply with any feedback or message other than the summary itself
## FORMATTING INSTRUCTIONS
{{ format_instructions }}
## CONTENT
# CONTENT
{{content}}
## PREVIOUS SUMMARY
{% if summary %}
# PREVIOUS SUMMARY
{{summary}}
## SUMMARY
{% endif %}