feat: extract think tags from reasoning models

This commit is contained in:
LUIS NOVO 2025-06-26 11:41:15 -03:00
parent 01dc2240a2
commit 7eee271232
4 changed files with 104 additions and 10 deletions

View file

@ -3,9 +3,7 @@ from typing import Annotated, List
from ai_prompter import Prompter
from langchain_core.output_parsers.pydantic import PydanticOutputParser
from langchain_core.runnables import (
RunnableConfig,
)
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, START, StateGraph
from langgraph.types import Send
from pydantic import BaseModel, Field
@ -13,6 +11,7 @@ from typing_extensions import TypedDict
from open_notebook.domain.notebook import vector_search
from open_notebook.graphs.utils import provision_langchain_model
from open_notebook.utils import clean_thinking_content
class SubGraphState(TypedDict):
@ -59,10 +58,19 @@ async def call_model_with_messages(state: ThreadState, config: RunnableConfig) -
config.get("configurable", {}).get("strategy_model"),
"tools",
max_tokens=2000,
structured=dict(type="json"),
)
# model = model.bind_tools(tools)
ai_message = (model | parser).invoke(system_prompt)
return {"strategy": ai_message}
# First get the raw response from the model
ai_message = model.invoke(system_prompt)
# Clean the thinking content from the response
cleaned_content = clean_thinking_content(ai_message.content)
# Parse the cleaned JSON content
strategy = parser.parse(cleaned_content)
return {"strategy": strategy}
async def trigger_queries(state: ThreadState, config: RunnableConfig):
@ -99,7 +107,7 @@ async def provide_answer(state: SubGraphState, config: RunnableConfig) -> dict:
max_tokens=2000,
)
ai_message = model.invoke(system_prompt)
return {"answers": [ai_message.content]}
return {"answers": [clean_thinking_content(ai_message.content)]}
async def write_final_answer(state: ThreadState, config: RunnableConfig) -> dict:
@ -111,7 +119,7 @@ async def write_final_answer(state: ThreadState, config: RunnableConfig) -> dict
max_tokens=2000,
)
ai_message = model.invoke(system_prompt)
return {"final_answer": ai_message.content}
return {"final_answer": clean_thinking_content(ai_message.content)}
agent_state = StateGraph(ThreadState)

View file

@ -7,6 +7,7 @@ from typing_extensions import TypedDict
from open_notebook.domain.notebook import Source
from open_notebook.domain.transformation import DefaultPrompts, Transformation
from open_notebook.graphs.utils import provision_langchain_model
from open_notebook.utils import clean_thinking_content
class TransformationState(TypedDict):
@ -42,11 +43,15 @@ def run_transformation(state: dict, config: RunnableConfig) -> dict:
)
response = chain.invoke(payload)
# Clean thinking content from the response
cleaned_content = clean_thinking_content(response.content)
if source:
source.add_insight(transformation.title, response.content)
source.add_insight(transformation.title, cleaned_content)
return {
"output": response.content,
"output": cleaned_content,
}

View file

@ -1,6 +1,7 @@
import re
import unicodedata
from importlib.metadata import PackageNotFoundError, version
from typing import Tuple
from urllib.parse import urlparse
import requests
@ -217,3 +218,66 @@ def compare_versions(version1: str, version2: str) -> int:
return 1
else:
return 0
def parse_thinking_content(content: str) -> Tuple[str, str]:
    """
    Split a model response into its reasoning text and its visible answer.

    Reasoning models may wrap their chain of thought in ``<think>...</think>``
    blocks. This function collects the text of every such block and returns
    it separately from the rest of the content.

    Args:
        content (str): The original message content. ``None`` and other
            non-string values are tolerated; they are never scanned for tags.

    Returns:
        Tuple[str, str]: (thinking_content, cleaned_content)
            - thinking_content: Text from inside all <think> blocks, each
              stripped and joined with a blank line; "" when there are none.
            - cleaned_content: The content with <think> blocks removed and
              surrounding whitespace tidied. When no block is present the
              input is returned unchanged (``None`` becomes "").

    Example:
        >>> content = "<think>Let me analyze this</think>Here's my answer"
        >>> thinking, cleaned = parse_thinking_content(content)
        >>> print(thinking)
        Let me analyze this
        >>> print(cleaned)
        Here's my answer
    """
    # Message content is not guaranteed to be a string; passing anything else
    # to the regex functions below would raise TypeError. Degrade gracefully
    # while still honouring the (str, str) return contract.
    if content is None:
        return "", ""
    if not isinstance(content, str):
        return "", str(content)

    # Non-greedy so that several separate blocks are matched individually;
    # DOTALL (below) lets a block span multiple lines.
    think_pattern = r"<think>(.*?)</think>"

    thinking_matches = re.findall(think_pattern, content, re.DOTALL)
    if not thinking_matches:
        # Nothing to strip: hand the content back untouched.
        return "", content

    # Join all thinking content with double newlines
    thinking_content = "\n\n".join(match.strip() for match in thinking_matches)

    # Remove all <think>...</think> blocks from the original content
    cleaned_content = re.sub(think_pattern, "", content, flags=re.DOTALL)

    # Removing a block can leave a run of blank lines behind; collapse runs of
    # three or more line breaks to a single blank line and trim the ends.
    cleaned_content = re.sub(r"\n\s*\n\s*\n", "\n\n", cleaned_content).strip()

    return thinking_content, cleaned_content
def clean_thinking_content(content: str) -> str:
    """
    Return an AI response with its <think> blocks stripped out.

    Thin wrapper around parse_thinking_content for callers that only want
    the visible answer and have no use for the extracted reasoning text.

    Args:
        content (str): The original message content, which may contain
            <think> tags

    Returns:
        str: The cleaned content produced by parse_thinking_content

    Example:
        >>> content = "<think>Let me think...</think>Here's the answer"
        >>> clean_thinking_content(content)
        "Here's the answer"
    """
    # The second element of the parsed pair is the cleaned content; the
    # first (the reasoning text) is deliberately discarded.
    return parse_thinking_content(content)[1]

View file

@ -14,6 +14,8 @@ from pages.stream_app.utils import (
create_session_for_notebook,
)
from open_notebook.utils import parse_thinking_content
from .note import make_note_from_chat
@ -186,11 +188,26 @@ def chat_sidebar(current_notebook: Notebook, current_session: ChatSession):
continue
with st.chat_message(name=msg.type):
st.markdown(convert_source_references(msg.content))
if msg.type == "ai":
# Parse thinking content for AI messages
thinking_content, cleaned_content = parse_thinking_content(msg.content)
# Show thinking content in expander if present
if thinking_content:
with st.expander("🤔 AI Reasoning", expanded=False):
st.markdown(thinking_content)
# Show the cleaned regular content
if cleaned_content:
st.markdown(convert_source_references(cleaned_content))
# New Note button for AI messages
if st.button("💾 New Note", key=f"render_save_{msg.id}"):
make_note_from_chat(
content=msg.content,
notebook_id=current_notebook.id,
)
st.rerun()
else:
# Human messages - display normally
st.markdown(convert_source_references(msg.content))