improved typing
This commit is contained in:
parent
7dc37a3ac7
commit
d9c0c93deb
3 changed files with 19 additions and 28 deletions
|
|
@@ -11,7 +11,7 @@ from typing_extensions import TypedDict
|
|||
|
||||
from open_notebook.config import LANGGRAPH_CHECKPOINT_FILE
|
||||
from open_notebook.domain.notebook import Notebook
|
||||
from open_notebook.graphs.utils import provision_model
|
||||
from open_notebook.graphs.utils import provision_langchain_model
|
||||
from open_notebook.prompter import Prompter
|
||||
|
||||
|
||||
|
|
@@ -25,7 +25,7 @@ class ThreadState(TypedDict):
|
|||
def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict:
|
||||
system_prompt = Prompter(prompt_template="chat").render(data=state)
|
||||
payload = [system_prompt] + state.get("messages", [])
|
||||
model = provision_model(str(payload), config, "chat")
|
||||
model = provision_langchain_model(str(payload), config, "chat")
|
||||
ai_message = model.invoke(payload)
|
||||
return {"messages": ai_message}
|
||||
|
||||
|
|
|
|||
|
|
@@ -1,13 +1,14 @@
|
|||
from langchain.output_parsers import OutputFixingParser
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import BaseMessage
|
||||
from loguru import logger
|
||||
|
||||
from open_notebook.domain.models import model_manager
|
||||
from open_notebook.models.llms import LanguageModel
|
||||
from open_notebook.prompter import Prompter
|
||||
from open_notebook.utils import token_count
|
||||
|
||||
|
||||
def provision_model(content, config, default_type):
|
||||
def provision_langchain_model(content, config, default_type) -> BaseChatModel:
|
||||
"""
|
||||
Returns the best model to use based on the context size and on whether there is a specific model being requested in Config.
|
||||
If context > 105_000, returns the large_context_model
|
||||
|
|
@@ -20,13 +21,14 @@ def provision_model(content, config, default_type):
|
|||
logger.debug(
|
||||
f"Using large context model because the content has {tokens} tokens"
|
||||
)
|
||||
return model_manager.get_default_model("large_context").to_langchain()
|
||||
model = model_manager.get_default_model("large_context")
|
||||
elif config.get("configurable", {}).get("model_id"):
|
||||
return model_manager.get_model(
|
||||
config.get("configurable", {}).get("model_id")
|
||||
).to_langchain()
|
||||
model = model_manager.get_model(config.get("configurable", {}).get("model_id"))
|
||||
else:
|
||||
return model_manager.get_default_model(default_type).to_langchain()
|
||||
model = model_manager.get_default_model(default_type)
|
||||
|
||||
assert isinstance(model, LanguageModel), f"Model is not a LanguageModel: {model}"
|
||||
return model.to_langchain()
|
||||
|
||||
|
||||
# todo: turn into a graph
|
||||
|
|
@@ -36,23 +38,12 @@ def run_pattern(
|
|||
messages=[],
|
||||
state: dict = {},
|
||||
parser=None,
|
||||
output_fixing_model_id=None,
|
||||
) -> AIMessage:
|
||||
) -> BaseMessage:
|
||||
system_prompt = Prompter(prompt_template=pattern_name, parser=parser).render(
|
||||
data=state
|
||||
)
|
||||
payload = [system_prompt] + messages
|
||||
chain = provision_model(str(payload), config, "transformation")
|
||||
|
||||
if parser:
|
||||
chain = chain | parser
|
||||
|
||||
if output_fixing_model_id and parser:
|
||||
output_fix_model = model_manager.get_model(output_fixing_model_id)
|
||||
chain = chain | OutputFixingParser.from_llm(
|
||||
parser=parser,
|
||||
llm=output_fix_model,
|
||||
)
|
||||
chain = provision_langchain_model(str(payload), config, "transformation")
|
||||
|
||||
response = chain.invoke(payload)
|
||||
|
||||
|
|
|
|||
|
|
@@ -28,7 +28,7 @@ def split_text(txt: str, chunk=1000, overlap=0, separator=" "):
|
|||
return text_splitter.split_text(txt)
|
||||
|
||||
|
||||
def token_count(input_string):
|
||||
def token_count(input_string) -> int:
|
||||
"""
|
||||
Count the number of tokens in the input string using the 'o200k_base' encoding.
|
||||
|
||||
|
|
@@ -46,7 +46,7 @@ def token_count(input_string):
|
|||
return token_count
|
||||
|
||||
|
||||
def token_cost(token_count, cost_per_million=0.150):
|
||||
def token_cost(token_count, cost_per_million=0.150) -> float:
|
||||
"""
|
||||
Calculate the cost of tokens based on the token count and cost per million tokens.
|
||||
|
||||
|
|
@@ -60,11 +60,11 @@ def token_cost(token_count, cost_per_million=0.150):
|
|||
return cost_per_million * (token_count / 1_000_000)
|
||||
|
||||
|
||||
def remove_non_ascii(text):
|
||||
def remove_non_ascii(text) -> str:
|
||||
return re.sub(r"[^\x00-\x7F]+", "", text)
|
||||
|
||||
|
||||
def remove_non_printable(text):
|
||||
def remove_non_printable(text) -> str:
|
||||
# Remove control characters, except newlines and tabs
|
||||
text = "".join(
|
||||
char for char in text if unicodedata.category(char)[0] != "C" or char in "\n\t"
|
||||
|
|
@@ -74,7 +74,7 @@ def remove_non_printable(text):
|
|||
return re.sub(r"[^\w\s.,!?\-\n\t]", "", text, flags=re.UNICODE)
|
||||
|
||||
|
||||
def surreal_clean(text):
|
||||
def surreal_clean(text) -> str:
|
||||
"""
|
||||
Clean the input text by removing non-ASCII and non-printable characters,
|
||||
and adjusting colon placement for SurrealDB compatibility.
|
||||
|
|
|
|||
Loading…
Reference in a new issue