diff --git a/open_notebook/graphs/chat.py b/open_notebook/graphs/chat.py index 52041ea..810d29f 100644 --- a/open_notebook/graphs/chat.py +++ b/open_notebook/graphs/chat.py @@ -3,8 +3,10 @@ import sqlite3 from typing import Annotated, Optional from ai_prompter import Prompter -from langchain_core.messages import SystemMessage +from langchain_core.messages import AIMessage, SystemMessage from langchain_core.runnables import RunnableConfig + +from open_notebook.utils import clean_thinking_content from langgraph.checkpoint.sqlite import SqliteSaver from langgraph.graph import END, START, StateGraph from langgraph.graph.message import add_messages @@ -66,7 +68,13 @@ def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict ) ai_message = model.invoke(payload) - return {"messages": ai_message} + + # Clean thinking content from AI response (e.g., <think>...</think> tags) + content = ai_message.content if isinstance(ai_message.content, str) else str(ai_message.content) + cleaned_content = clean_thinking_content(content) + cleaned_message = AIMessage(content=cleaned_content) + + return {"messages": cleaned_message} conn = sqlite3.connect( diff --git a/open_notebook/graphs/source_chat.py b/open_notebook/graphs/source_chat.py index 868164d..a173a56 100644 --- a/open_notebook/graphs/source_chat.py +++ b/open_notebook/graphs/source_chat.py @@ -3,8 +3,10 @@ import sqlite3 from typing import Annotated, Dict, List, Optional from ai_prompter import Prompter -from langchain_core.messages import SystemMessage +from langchain_core.messages import AIMessage, SystemMessage from langchain_core.runnables import RunnableConfig + +from open_notebook.utils import clean_thinking_content from langgraph.checkpoint.sqlite import SqliteSaver from langgraph.graph import END, START, StateGraph from langgraph.graph.message import add_messages @@ -154,9 +156,14 @@ def call_model_with_source_context( ai_message = model.invoke(payload) + # Clean thinking content from AI response (e.g., <think>...</think>
tags) + content = ai_message.content if isinstance(ai_message.content, str) else str(ai_message.content) + cleaned_content = clean_thinking_content(content) + cleaned_message = AIMessage(content=cleaned_content) + # Update state with context information return { - "messages": ai_message, + "messages": cleaned_message, "source": source, "insights": insights, "context": formatted_context, diff --git a/open_notebook/utils/text_utils.py b/open_notebook/utils/text_utils.py index 0024ca3..b2a7720 100644 --- a/open_notebook/utils/text_utils.py +++ b/open_notebook/utils/text_utils.py @@ -11,8 +11,11 @@ from langchain_text_splitters import RecursiveCharacterTextSplitter from .token_utils import token_count -# Pattern for matching thinking content in AI responses +# Patterns for matching thinking content in AI responses +# Standard pattern: <think>...</think> THINK_PATTERN = re.compile(r"(.*?)", re.DOTALL) +# Pattern for malformed output: content</think> (missing opening tag) +THINK_PATTERN_NO_OPEN = re.compile(r"^(.*?)", re.DOTALL) def split_text(txt: str, chunk_size=500): @@ -77,6 +80,9 @@ def parse_thinking_content(content: str) -> Tuple[str, str]: """ Parse message content to extract thinking content from <think> tags. + Handles both well-formed tags and malformed output where the opening + <think> tag is missing but </think> is present. + Args: content (str): The original message content @@ -101,22 +107,31 @@ def parse_thinking_content(content: str) -> Tuple[str, str]: if len(content) > 100000: return "", content - # Find all thinking blocks + # Find all well-formed thinking blocks thinking_matches = THINK_PATTERN.findall(content) - if not thinking_matches: - return "", content + if thinking_matches: + # Join all thinking content with double newlines + thinking_content = "\n\n".join(match.strip() for match in thinking_matches) - # Join all thinking content with double newlines - thinking_content = "\n\n".join(match.strip() for match in thinking_matches) + # Remove all <think>...</think>
blocks from the original content + cleaned_content = THINK_PATTERN.sub("", content) - # Remove all <think>...</think> blocks from the original content - cleaned_content = THINK_PATTERN.sub("", content) + # Clean up extra whitespace + cleaned_content = re.sub(r"\n\s*\n\s*\n", "\n\n", cleaned_content).strip() - # Clean up extra whitespace - cleaned_content = re.sub(r"\n\s*\n\s*\n", "\n\n", cleaned_content).strip() + return thinking_content, cleaned_content - return thinking_content, cleaned_content + # Handle malformed output: content</think> (missing opening tag) + # Some models like Nemotron output thinking without the opening <think> tag + malformed_match = THINK_PATTERN_NO_OPEN.match(content) + if malformed_match: + thinking_content = malformed_match.group(1).strip() + # Remove the thinking content and </think> tag + cleaned_content = content[malformed_match.end():].strip() + return thinking_content, cleaned_content + + return "", content def clean_thinking_content(content: str) -> str: