refactor: implement ai_prompter library
This commit is contained in:
parent
858b5e0d6e
commit
2afbd36cb4
7 changed files with 21 additions and 122 deletions
|
|
@ -1,6 +1,7 @@
|
|||
import operator
|
||||
from typing import Annotated, List
|
||||
|
||||
from ai_prompter import Prompter
|
||||
from langchain_core.output_parsers.pydantic import PydanticOutputParser
|
||||
from langchain_core.runnables import (
|
||||
RunnableConfig,
|
||||
|
|
@ -12,7 +13,6 @@ from typing_extensions import TypedDict
|
|||
|
||||
from open_notebook.domain.notebook import vector_search
|
||||
from open_notebook.graphs.utils import provision_langchain_model
|
||||
from open_notebook.prompter import Prompter
|
||||
|
||||
|
||||
class SubGraphState(TypedDict):
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
import sqlite3
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from ai_prompter import Prompter
|
||||
from langchain_core.messages import SystemMessage
|
||||
from langchain_core.runnables import (
|
||||
RunnableConfig,
|
||||
|
|
@ -13,7 +14,6 @@ from typing_extensions import TypedDict
|
|||
from open_notebook.config import LANGGRAPH_CHECKPOINT_FILE
|
||||
from open_notebook.domain.notebook import Notebook
|
||||
from open_notebook.graphs.utils import provision_langchain_model
|
||||
from open_notebook.prompter import Prompter
|
||||
|
||||
|
||||
class ThreadState(TypedDict):
|
||||
|
|
|
|||
|
|
@ -1,15 +1,13 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
from ai_prompter import Prompter
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.runnables import (
|
||||
RunnableConfig,
|
||||
)
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from loguru import logger
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from open_notebook.graphs.utils import provision_langchain_model
|
||||
from open_notebook.prompter import Prompter
|
||||
|
||||
|
||||
class PatternChainState(TypedDict):
|
||||
|
|
@ -22,7 +20,7 @@ class PatternChainState(TypedDict):
|
|||
def call_model(state: dict, config: RunnableConfig) -> dict:
|
||||
content = state["input_text"]
|
||||
system_prompt = Prompter(
|
||||
prompt_text=state["prompt"], parser=state.get("parser")
|
||||
template_text=state["prompt"], parser=state.get("parser")
|
||||
).render(data=state)
|
||||
logger.warning(content)
|
||||
payload = [SystemMessage(content=system_prompt)] + [HumanMessage(content=content)]
|
||||
|
|
|
|||
|
|
@ -1,14 +1,12 @@
|
|||
from ai_prompter import Prompter
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.runnables import (
|
||||
RunnableConfig,
|
||||
)
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph import END, START, StateGraph
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from open_notebook.domain.notebook import Source
|
||||
from open_notebook.domain.transformation import DefaultPrompts, Transformation
|
||||
from open_notebook.graphs.utils import provision_langchain_model
|
||||
from open_notebook.prompter import Prompter
|
||||
|
||||
|
||||
class TransformationState(TypedDict):
|
||||
|
|
@ -25,14 +23,16 @@ def run_transformation(state: dict, config: RunnableConfig) -> dict:
|
|||
transformation: Transformation = state["transformation"]
|
||||
if not content:
|
||||
content = source.full_text
|
||||
transformation_prompt_text = transformation.prompt
|
||||
transformation_template_text = transformation.prompt
|
||||
default_prompts: DefaultPrompts = DefaultPrompts()
|
||||
if default_prompts.transformation_instructions:
|
||||
transformation_prompt_text = f"{default_prompts.transformation_instructions}\n\n{transformation_prompt_text}"
|
||||
transformation_template_text = f"{default_prompts.transformation_instructions}\n\n{transformation_template_text}"
|
||||
|
||||
transformation_prompt_text = f"{transformation_prompt_text}\n\n# INPUT"
|
||||
transformation_template_text = f"{transformation_template_text}\n\n# INPUT"
|
||||
|
||||
system_prompt = Prompter(prompt_text=transformation_prompt_text).render(data=state)
|
||||
system_prompt = Prompter(template_text=transformation_template_text).render(
|
||||
data=state
|
||||
)
|
||||
payload = [SystemMessage(content=system_prompt)] + [HumanMessage(content=content)]
|
||||
chain = provision_langchain_model(
|
||||
str(payload),
|
||||
|
|
|
|||
|
|
@ -1,102 +0,0 @@
|
|||
"""
|
||||
A prompt management module using Jinja to generate complex prompts with simple templates.
|
||||
"""
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader, Template
|
||||
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
project_root = os.path.dirname(current_dir)
|
||||
|
||||
env = Environment(
|
||||
loader=FileSystemLoader(
|
||||
os.path.join(project_root, os.environ.get("PROMPT_PATH", "prompts"))
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Prompter:
|
||||
"""
|
||||
A class for managing and rendering prompt templates.
|
||||
|
||||
Attributes:
|
||||
prompt_template (str, optional): The name of the prompt template file.
|
||||
prompt_variation (str, optional): The variation of the prompt template.
|
||||
prompt_text (str, optional): The raw prompt text.
|
||||
template (Union[str, Template], optional): The Jinja2 template object.
|
||||
"""
|
||||
|
||||
prompt_template: Optional[str] = None
|
||||
prompt_variation: Optional[str] = "default"
|
||||
prompt_text: Optional[str] = None
|
||||
template: Optional[Union[str, Template]] = None
|
||||
parser: Optional[Any] = None
|
||||
|
||||
def __init__(self, prompt_template=None, prompt_text=None, parser=None):
|
||||
"""
|
||||
Initialize the Prompter with either a template file or raw text.
|
||||
|
||||
Args:
|
||||
prompt_template (str, optional): The name of the prompt template file.
|
||||
prompt_text (str, optional): The raw prompt text.
|
||||
"""
|
||||
self.prompt_template = prompt_template
|
||||
self.prompt_text = prompt_text
|
||||
self.parser = parser
|
||||
self.setup()
|
||||
|
||||
def setup(self):
|
||||
"""
|
||||
Set up the Jinja2 template based on the provided template file or text.
|
||||
Raises:
|
||||
ValueError: If neither prompt_template nor prompt_text is provided.
|
||||
"""
|
||||
if self.prompt_template:
|
||||
self.template = env.get_template(f"{self.prompt_template}.jinja")
|
||||
elif self.prompt_text:
|
||||
self.template = Template(self.prompt_text)
|
||||
else:
|
||||
raise ValueError("Prompter must have a prompt_template or prompt_text")
|
||||
|
||||
assert self.prompt_template or self.prompt_text, "Prompt is required"
|
||||
|
||||
@classmethod
|
||||
def from_text(cls, text: str):
|
||||
"""
|
||||
Create a Prompter instance from raw text, which can contain Jinja code.
|
||||
|
||||
Args:
|
||||
text (str): The raw prompt text.
|
||||
|
||||
Returns:
|
||||
Prompter: A new Prompter instance.
|
||||
"""
|
||||
return cls(prompt_text=text)
|
||||
|
||||
def render(self, data) -> str:
|
||||
"""
|
||||
Render the prompt template with the given data.
|
||||
|
||||
Args:
|
||||
data (dict): The data to be used in rendering the template.
|
||||
|
||||
Returns:
|
||||
str: The rendered prompt text.
|
||||
|
||||
Raises:
|
||||
AssertionError: If the template is not defined or not a Jinja2 Template.
|
||||
"""
|
||||
data["current_time"] = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
if self.parser:
|
||||
data["format_instructions"] = self.parser.get_format_instructions()
|
||||
assert self.template, "Prompter template is not defined"
|
||||
assert isinstance(
|
||||
self.template, Template
|
||||
), "Prompter template is not a Jinja2 Template"
|
||||
return self.template.render(data)
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "open-notebook"
|
||||
version = "0.2.0"
|
||||
version = "0.2.1"
|
||||
description = "An open source implementation of a research assistant, inspired by Google Notebook LM"
|
||||
authors = [
|
||||
{name = "Luis Novo", email = "lfnovo@gmail.com"}
|
||||
|
|
@ -41,6 +41,7 @@ dependencies = [
|
|||
"podcastfy",
|
||||
"nest-asyncio>=1.6.0",
|
||||
"content-core>=1.0.0",
|
||||
"ai-prompter>=0.3",
|
||||
]
|
||||
|
||||
[tool.setuptools]
|
||||
|
|
|
|||
10
uv.lock
10
uv.lock
|
|
@ -20,16 +20,16 @@ resolution-markers = [
|
|||
|
||||
[[package]]
|
||||
name = "ai-prompter"
|
||||
version = "0.2.3"
|
||||
version = "0.3.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "jinja2" },
|
||||
{ name = "pip" },
|
||||
{ name = "pydantic" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/39/ff/cf13c31b88c06e11a1ffeed505601c167293b23d3e2e4e02adac93cc9300/ai_prompter-0.2.3.tar.gz", hash = "sha256:40f55c18f87df250a13f84d0cf7a4e8b31815a01f27666039386d6592849694b", size = 72955 }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/88/1a/263b2fb49a485d1b394ead887361cb8855ab28daa20a184cef0d2a0f8f2c/ai_prompter-0.3.0.tar.gz", hash = "sha256:3369555345386c6b9eebb7edbbb96df268977ab2657acb2890c217290bf92569", size = 74091 }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/5f/11/9e3712b8393dbef152258c68617baec343040c3d08b372d77b57e51d8e5d/ai_prompter-0.2.3-py3-none-any.whl", hash = "sha256:e8c0becbb3c8bdff399e372830e2c0a3cc3292e02d67921e2b255871329ee477", size = 7345 },
|
||||
{ url = "https://files.pythonhosted.org/packages/90/ae/cc493d9d37cd1501e442154aa7265fa05814d0e8519ddf549ebd2f5fcb1b/ai_prompter-0.3.0-py3-none-any.whl", hash = "sha256:b70569bf6a64258ab3453e1ff99a7a4cd1c7709296093dc2a35127230d408e7b", size = 8419 },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -2701,9 +2701,10 @@ wheels = [
|
|||
|
||||
[[package]]
|
||||
name = "open-notebook"
|
||||
version = "0.2.0"
|
||||
version = "0.2.1"
|
||||
source = { editable = "." }
|
||||
dependencies = [
|
||||
{ name = "ai-prompter" },
|
||||
{ name = "content-core" },
|
||||
{ name = "google-generativeai" },
|
||||
{ name = "groq" },
|
||||
|
|
@ -2752,6 +2753,7 @@ dev = [
|
|||
|
||||
[package.metadata]
|
||||
requires-dist = [
|
||||
{ name = "ai-prompter", specifier = ">=0.3" },
|
||||
{ name = "content-core", specifier = ">=1.0.0" },
|
||||
{ name = "google-generativeai", specifier = ">=0.8.3" },
|
||||
{ name = "groq", specifier = ">=0.12.0" },
|
||||
|
|
|
|||
Loading…
Reference in a new issue