refactor: implement ai_prompter library

This commit is contained in:
LUIS NOVO 2025-06-01 08:09:33 -03:00
parent 858b5e0d6e
commit 2afbd36cb4
7 changed files with 21 additions and 122 deletions

View file

@ -1,6 +1,7 @@
import operator
from typing import Annotated, List
from ai_prompter import Prompter
from langchain_core.output_parsers.pydantic import PydanticOutputParser
from langchain_core.runnables import (
RunnableConfig,
@ -12,7 +13,6 @@ from typing_extensions import TypedDict
from open_notebook.domain.notebook import vector_search
from open_notebook.graphs.utils import provision_langchain_model
from open_notebook.prompter import Prompter
class SubGraphState(TypedDict):

View file

@ -1,6 +1,7 @@
import sqlite3
from typing import Annotated, Optional
from ai_prompter import Prompter
from langchain_core.messages import SystemMessage
from langchain_core.runnables import (
RunnableConfig,
@ -13,7 +14,6 @@ from typing_extensions import TypedDict
from open_notebook.config import LANGGRAPH_CHECKPOINT_FILE
from open_notebook.domain.notebook import Notebook
from open_notebook.graphs.utils import provision_langchain_model
from open_notebook.prompter import Prompter
class ThreadState(TypedDict):

View file

@ -1,15 +1,13 @@
from typing import Any, Optional
from ai_prompter import Prompter
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import (
RunnableConfig,
)
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, START, StateGraph
from loguru import logger
from typing_extensions import TypedDict
from open_notebook.graphs.utils import provision_langchain_model
from open_notebook.prompter import Prompter
class PatternChainState(TypedDict):
@ -22,7 +20,7 @@ class PatternChainState(TypedDict):
def call_model(state: dict, config: RunnableConfig) -> dict:
content = state["input_text"]
system_prompt = Prompter(
prompt_text=state["prompt"], parser=state.get("parser")
template_text=state["prompt"], parser=state.get("parser")
).render(data=state)
logger.warning(content)
payload = [SystemMessage(content=system_prompt)] + [HumanMessage(content=content)]

View file

@ -1,14 +1,12 @@
from ai_prompter import Prompter
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.runnables import (
RunnableConfig,
)
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, START, StateGraph
from typing_extensions import TypedDict
from open_notebook.domain.notebook import Source
from open_notebook.domain.transformation import DefaultPrompts, Transformation
from open_notebook.graphs.utils import provision_langchain_model
from open_notebook.prompter import Prompter
class TransformationState(TypedDict):
@ -25,14 +23,16 @@ def run_transformation(state: dict, config: RunnableConfig) -> dict:
transformation: Transformation = state["transformation"]
if not content:
content = source.full_text
transformation_prompt_text = transformation.prompt
transformation_template_text = transformation.prompt
default_prompts: DefaultPrompts = DefaultPrompts()
if default_prompts.transformation_instructions:
transformation_prompt_text = f"{default_prompts.transformation_instructions}\n\n{transformation_prompt_text}"
transformation_template_text = f"{default_prompts.transformation_instructions}\n\n{transformation_template_text}"
transformation_prompt_text = f"{transformation_prompt_text}\n\n# INPUT"
transformation_template_text = f"{transformation_template_text}\n\n# INPUT"
system_prompt = Prompter(prompt_text=transformation_prompt_text).render(data=state)
system_prompt = Prompter(template_text=transformation_template_text).render(
data=state
)
payload = [SystemMessage(content=system_prompt)] + [HumanMessage(content=content)]
chain = provision_langchain_model(
str(payload),

View file

@ -1,102 +0,0 @@
"""
A prompt management module using Jinja to generate complex prompts with simple templates.
"""
import os
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Optional, Union
from jinja2 import Environment, FileSystemLoader, Template
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(current_dir)
env = Environment(
loader=FileSystemLoader(
os.path.join(project_root, os.environ.get("PROMPT_PATH", "prompts"))
)
)
@dataclass
class Prompter:
"""
A class for managing and rendering prompt templates.
Attributes:
prompt_template (str, optional): The name of the prompt template file.
prompt_variation (str, optional): The variation of the prompt template.
prompt_text (str, optional): The raw prompt text.
template (Union[str, Template], optional): The Jinja2 template object.
"""
prompt_template: Optional[str] = None
prompt_variation: Optional[str] = "default"
prompt_text: Optional[str] = None
template: Optional[Union[str, Template]] = None
parser: Optional[Any] = None
def __init__(self, prompt_template=None, prompt_text=None, parser=None):
"""
Initialize the Prompter with either a template file or raw text.
Args:
prompt_template (str, optional): The name of the prompt template file.
prompt_text (str, optional): The raw prompt text.
"""
self.prompt_template = prompt_template
self.prompt_text = prompt_text
self.parser = parser
self.setup()
def setup(self):
"""
Set up the Jinja2 template based on the provided template file or text.
Raises:
ValueError: If neither prompt_template nor prompt_text is provided.
"""
if self.prompt_template:
self.template = env.get_template(f"{self.prompt_template}.jinja")
elif self.prompt_text:
self.template = Template(self.prompt_text)
else:
raise ValueError("Prompter must have a prompt_template or prompt_text")
assert self.prompt_template or self.prompt_text, "Prompt is required"
@classmethod
def from_text(cls, text: str):
"""
Create a Prompter instance from raw text, which can contain Jinja code.
Args:
text (str): The raw prompt text.
Returns:
Prompter: A new Prompter instance.
"""
return cls(prompt_text=text)
def render(self, data) -> str:
"""
Render the prompt template with the given data.
Args:
data (dict): The data to be used in rendering the template.
Returns:
str: The rendered prompt text.
Raises:
AssertionError: If the template is not defined or not a Jinja2 Template.
"""
data["current_time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
if self.parser:
data["format_instructions"] = self.parser.get_format_instructions()
assert self.template, "Prompter template is not defined"
assert isinstance(
self.template, Template
), "Prompter template is not a Jinja2 Template"
return self.template.render(data)

View file

@ -1,6 +1,6 @@
[project]
name = "open-notebook"
version = "0.2.0"
version = "0.2.1"
description = "An open source implementation of a research assistant, inspired by Google Notebook LM"
authors = [
{name = "Luis Novo", email = "lfnovo@gmail.com"}
@ -41,6 +41,7 @@ dependencies = [
"podcastfy",
"nest-asyncio>=1.6.0",
"content-core>=1.0.0",
"ai-prompter>=0.3",
]
[tool.setuptools]

10
uv.lock
View file

@ -20,16 +20,16 @@ resolution-markers = [
[[package]]
name = "ai-prompter"
version = "0.2.3"
version = "0.3.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "jinja2" },
{ name = "pip" },
{ name = "pydantic" },
]
sdist = { url = "https://files.pythonhosted.org/packages/39/ff/cf13c31b88c06e11a1ffeed505601c167293b23d3e2e4e02adac93cc9300/ai_prompter-0.2.3.tar.gz", hash = "sha256:40f55c18f87df250a13f84d0cf7a4e8b31815a01f27666039386d6592849694b", size = 72955 }
sdist = { url = "https://files.pythonhosted.org/packages/88/1a/263b2fb49a485d1b394ead887361cb8855ab28daa20a184cef0d2a0f8f2c/ai_prompter-0.3.0.tar.gz", hash = "sha256:3369555345386c6b9eebb7edbbb96df268977ab2657acb2890c217290bf92569", size = 74091 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/5f/11/9e3712b8393dbef152258c68617baec343040c3d08b372d77b57e51d8e5d/ai_prompter-0.2.3-py3-none-any.whl", hash = "sha256:e8c0becbb3c8bdff399e372830e2c0a3cc3292e02d67921e2b255871329ee477", size = 7345 },
{ url = "https://files.pythonhosted.org/packages/90/ae/cc493d9d37cd1501e442154aa7265fa05814d0e8519ddf549ebd2f5fcb1b/ai_prompter-0.3.0-py3-none-any.whl", hash = "sha256:b70569bf6a64258ab3453e1ff99a7a4cd1c7709296093dc2a35127230d408e7b", size = 8419 },
]
[[package]]
@ -2701,9 +2701,10 @@ wheels = [
[[package]]
name = "open-notebook"
version = "0.2.0"
version = "0.2.1"
source = { editable = "." }
dependencies = [
{ name = "ai-prompter" },
{ name = "content-core" },
{ name = "google-generativeai" },
{ name = "groq" },
@ -2752,6 +2753,7 @@ dev = [
[package.metadata]
requires-dist = [
{ name = "ai-prompter", specifier = ">=0.3" },
{ name = "content-core", specifier = ">=1.0.0" },
{ name = "google-generativeai", specifier = ">=0.8.3" },
{ name = "groq", specifier = ">=0.12.0" },