Merge pull request #544 from danrush777/fix/gemini-envelope-parsing
fix: handle structured content format in LLM response parsing
This commit is contained in:
commit
96525b4457
8 changed files with 43 additions and 35 deletions
|
|
@ -14,6 +14,7 @@ from open_notebook.domain.notebook import vector_search
|
|||
from open_notebook.exceptions import OpenNotebookError
|
||||
from open_notebook.utils import clean_thinking_content
|
||||
from open_notebook.utils.error_classifier import classify_error
|
||||
from open_notebook.utils.text_utils import extract_text_content
|
||||
|
||||
|
||||
class SubGraphState(TypedDict):
|
||||
|
|
@ -65,11 +66,7 @@ async def call_model_with_messages(state: ThreadState, config: RunnableConfig) -
|
|||
ai_message = await model.ainvoke(system_prompt)
|
||||
|
||||
# Clean the thinking content from the response
|
||||
message_content = (
|
||||
ai_message.content
|
||||
if isinstance(ai_message.content, str)
|
||||
else str(ai_message.content)
|
||||
)
|
||||
message_content = extract_text_content(ai_message.content)
|
||||
cleaned_content = clean_thinking_content(message_content)
|
||||
|
||||
# Parse the cleaned JSON content
|
||||
|
|
@ -118,11 +115,7 @@ async def provide_answer(state: SubGraphState, config: RunnableConfig) -> dict:
|
|||
max_tokens=2000,
|
||||
)
|
||||
ai_message = await model.ainvoke(system_prompt)
|
||||
ai_content = (
|
||||
ai_message.content
|
||||
if isinstance(ai_message.content, str)
|
||||
else str(ai_message.content)
|
||||
)
|
||||
ai_content = extract_text_content(ai_message.content)
|
||||
return {"answers": [clean_thinking_content(ai_content)]}
|
||||
except OpenNotebookError:
|
||||
raise
|
||||
|
|
@ -141,11 +134,7 @@ async def write_final_answer(state: ThreadState, config: RunnableConfig) -> dict
|
|||
max_tokens=2000,
|
||||
)
|
||||
ai_message = await model.ainvoke(system_prompt)
|
||||
final_content = (
|
||||
ai_message.content
|
||||
if isinstance(ai_message.content, str)
|
||||
else str(ai_message.content)
|
||||
)
|
||||
final_content = extract_text_content(ai_message.content)
|
||||
return {"final_answer": clean_thinking_content(final_content)}
|
||||
except OpenNotebookError:
|
||||
raise
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ from open_notebook.domain.notebook import Notebook
|
|||
from open_notebook.exceptions import OpenNotebookError
|
||||
from open_notebook.utils import clean_thinking_content
|
||||
from open_notebook.utils.error_classifier import classify_error
|
||||
from open_notebook.utils.text_utils import extract_text_content
|
||||
|
||||
|
||||
class ThreadState(TypedDict):
|
||||
|
|
@ -72,11 +73,7 @@ def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict
|
|||
ai_message = model.invoke(payload)
|
||||
|
||||
# Clean thinking content from AI response (e.g., <think>...</think> tags)
|
||||
content = (
|
||||
ai_message.content
|
||||
if isinstance(ai_message.content, str)
|
||||
else str(ai_message.content)
|
||||
)
|
||||
content = extract_text_content(ai_message.content)
|
||||
cleaned_content = clean_thinking_content(content)
|
||||
cleaned_message = ai_message.model_copy(update={"content": cleaned_content})
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from langgraph.graph import END, START, StateGraph
|
|||
from typing_extensions import TypedDict
|
||||
|
||||
from open_notebook.ai.provision import provision_langchain_model
|
||||
from open_notebook.utils.text_utils import clean_thinking_content
|
||||
from open_notebook.utils.text_utils import clean_thinking_content, extract_text_content
|
||||
|
||||
|
||||
class PatternChainState(TypedDict):
|
||||
|
|
@ -33,7 +33,7 @@ async def call_model(state: dict, config: RunnableConfig) -> dict:
|
|||
response = await chain.ainvoke(payload)
|
||||
|
||||
# Clean thinking tags from response (handles extended thinking models)
|
||||
output = clean_thinking_content(str(response.content))
|
||||
output = clean_thinking_content(extract_text_content(response.content))
|
||||
return {"output": output}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from open_notebook.exceptions import OpenNotebookError
|
|||
from open_notebook.utils import clean_thinking_content
|
||||
from open_notebook.utils.context_builder import ContextBuilder
|
||||
from open_notebook.utils.error_classifier import classify_error
|
||||
from open_notebook.utils.text_utils import extract_text_content
|
||||
|
||||
|
||||
class SourceChatState(TypedDict):
|
||||
|
|
@ -172,11 +173,7 @@ def _call_model_with_source_context_inner(
|
|||
ai_message = model.invoke(payload)
|
||||
|
||||
# Clean thinking content from AI response (e.g., <think>...</think> tags)
|
||||
content = (
|
||||
ai_message.content
|
||||
if isinstance(ai_message.content, str)
|
||||
else str(ai_message.content)
|
||||
)
|
||||
content = extract_text_content(ai_message.content)
|
||||
cleaned_content = clean_thinking_content(content)
|
||||
cleaned_message = ai_message.model_copy(update={"content": cleaned_content})
|
||||
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from open_notebook.domain.transformation import DefaultPrompts, Transformation
|
|||
from open_notebook.exceptions import OpenNotebookError
|
||||
from open_notebook.utils import clean_thinking_content
|
||||
from open_notebook.utils.error_classifier import classify_error
|
||||
from open_notebook.utils.text_utils import extract_text_content
|
||||
|
||||
|
||||
class TransformationState(TypedDict):
|
||||
|
|
@ -51,9 +52,7 @@ async def run_transformation(state: dict, config: RunnableConfig) -> dict:
|
|||
response = await chain.ainvoke(payload)
|
||||
|
||||
# Clean thinking content from the response
|
||||
response_content = (
|
||||
response.content if isinstance(response.content, str) else str(response.content)
|
||||
)
|
||||
response_content = extract_text_content(response.content)
|
||||
cleaned_content = clean_thinking_content(response_content)
|
||||
|
||||
if source:
|
||||
|
|
|
|||
|
|
@ -117,3 +117,29 @@ def clean_thinking_content(content: str) -> str:
|
|||
"""
|
||||
_, cleaned_content = parse_thinking_content(content)
|
||||
return cleaned_content
|
||||
|
||||
|
||||
def extract_text_content(content) -> str:
|
||||
"""Extract text from LLM response content.
|
||||
|
||||
Handles both plain string responses and structured content formats
|
||||
(e.g. Gemini's envelope format):
|
||||
[{'type': 'text', 'text': '...', 'extras': {...}}]
|
||||
|
||||
Args:
|
||||
content: The content from an AI message, either a string or a list of parts.
|
||||
|
||||
Returns:
|
||||
The extracted text content as a string.
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
text_parts = []
|
||||
for part in content:
|
||||
if isinstance(part, dict) and "text" in part:
|
||||
text_parts.append(part["text"])
|
||||
elif isinstance(part, str):
|
||||
text_parts.append(part)
|
||||
return "".join(text_parts)
|
||||
return str(content)
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ dependencies = [
|
|||
"ai-prompter>=0.3,<1",
|
||||
"esperanto>=2.19.3,<3",
|
||||
"surrealdb>=1.0.4",
|
||||
"podcast-creator>=0.9.1,<1",
|
||||
"podcast-creator>=0.9.4,<1",
|
||||
"surreal-commands>=1.3.1,<2",
|
||||
"numpy>=2.4.1",
|
||||
]
|
||||
|
|
|
|||
8
uv.lock
8
uv.lock
|
|
@ -2168,7 +2168,7 @@ requires-dist = [
|
|||
{ name = "loguru", specifier = ">=0.7.2" },
|
||||
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.11.1" },
|
||||
{ name = "numpy", specifier = ">=2.4.1" },
|
||||
{ name = "podcast-creator", specifier = ">=0.9.1,<1" },
|
||||
{ name = "podcast-creator", specifier = ">=0.9.4,<1" },
|
||||
{ name = "pre-commit", marker = "extra == 'dev'", specifier = ">=4.0.1" },
|
||||
{ name = "pydantic", specifier = ">=2.9.2" },
|
||||
{ name = "pytest", marker = "extra == 'dev'", specifier = ">=8.0.0" },
|
||||
|
|
@ -2519,7 +2519,7 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "podcast-creator"
|
||||
version = "0.9.1"
|
||||
version = "0.9.4"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "ai-prompter" },
|
||||
|
|
@ -2535,9 +2535,9 @@ dependencies = [
|
|||
{ name = "requests" },
|
||||
{ name = "tiktoken" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/7d/de/f7ee60b502dad23b724d669be31fdeb6a790e306968c2cd6a079388262be/podcast_creator-0.9.1.tar.gz", hash = "sha256:177ae68b18c7efd815e555dcec3c644e541bd053e2c63669fd0a18a008b2f374", size = 470751, upload-time = "2026-02-16T17:58:44.275Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/97/4a/9f23b55659d7d236645593a4b75141837ed88568ba6a6a370b01d97827e6/podcast_creator-0.9.4.tar.gz", hash = "sha256:9e40a77c105d0b02f04a3eef7881a34454ef556fabd8297fe68d50307ca5f926", size = 472357, upload-time = "2026-02-17T20:21:57.257Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/e0/d7/687284d059fc490a19d60af8f07a66b19895e15946e7ced143096d3c5ea0/podcast_creator-0.9.1-py3-none-any.whl", hash = "sha256:e3e513f2aacccd96c15bcab891216ff447568551c4392b3f12575aa0cf0cbeee", size = 74421, upload-time = "2026-02-16T17:58:42.818Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/0c/ac/b331aae683771964f0574189c8dbc1bc0c7b22aca9a376d61c3248180848/podcast_creator-0.9.4-py3-none-any.whl", hash = "sha256:2bd1138cbd1a4deda9da657e7e2b9c8a7d8c0cc43c649506af4837aeb708d46f", size = 74844, upload-time = "2026-02-17T20:21:58.271Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
|
|||
Loading…
Reference in a new issue