rename model_name to model_id

This commit is contained in:
LUIS NOVO 2024-11-01 12:07:00 -03:00
parent af315a0bab
commit c65bf8ba12
7 changed files with 40 additions and 35 deletions

View file

@ -22,12 +22,12 @@ class ThreadState(TypedDict):
def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict:
model_name = config.get("configurable", {}).get(
"model_name", DEFAULT_MODELS.default_chat_model
model_id = config.get("configurable", {}).get(
"model_id", DEFAULT_MODELS.default_chat_model
)
ai_message = run_pattern(
"chat",
model_name,
model_id,
messages=state["messages"],
state=state,
)

View file

@ -19,10 +19,10 @@ class DocQueryState(TypedDict):
def call_model(state: dict, config: RunnableConfig) -> dict:
model_name = config.get("configurable", {}).get(
"model_name", os.environ.get("RETRIEVAL_MODEL")
model_id = config.get("configurable", {}).get(
"model_id", os.environ.get("RETRIEVAL_MODEL")
)
return {"answer": run_pattern("doc_query", model_name, state)}
return {"answer": run_pattern("doc_query", model_id, state)}
# todo: there is probably a better way to do this and avoid repetition

View file

@ -18,8 +18,8 @@ class PatternChainState(TypedDict):
def call_model(state: dict, config: RunnableConfig) -> dict:
model_name = config.get("configurable", {}).get(
"model_name", DEFAULT_MODELS.default_transformation_model
model_id = config.get("configurable", {}).get(
"model_id", DEFAULT_MODELS.default_transformation_model
)
transformations = state["transformations"]
current_transformation = transformations.pop(0)
@ -34,7 +34,7 @@ def call_model(state: dict, config: RunnableConfig) -> dict:
transformation_result = run_pattern(
pattern_name=current_transformation,
model_name=model_name,
model_id=model_id,
state=input_args,
)
return {

View file

@ -15,13 +15,13 @@ class PatternState(TypedDict):
def call_model(state: dict, config: RunnableConfig) -> dict:
model_name = config.get("configurable", {}).get(
"model_name", DEFAULT_MODELS.default_transformation_model
model_id = config.get("configurable", {}).get(
"model_id", DEFAULT_MODELS.default_transformation_model
)
return {
"output": run_pattern(
pattern_name=state["pattern"],
model_name=model_name,
model_id=model_id,
state=state,
)
}

View file

@ -49,13 +49,13 @@ def chunk_condition(state: TocState) -> Literal["get_chunk", END]: # type: igno
def call_model(state: TocState, config: RunnableConfig) -> dict:
model_name = config.get("configurable", {}).get(
"model_name", DEFAULT_MODELS.default_transformation_model
model_id = config.get("configurable", {}).get(
"model_id", DEFAULT_MODELS.default_transformation_model
)
return {
"toc": run_pattern(
pattern_name="recursive_toc",
model_name=model_name,
model_id=model_id,
state=state,
).content
}

View file

@ -59,14 +59,14 @@ def chunk_condition(state: SummaryState) -> Literal["get_chunk", END]: # type:
def call_model(state: dict, config: RunnableConfig) -> dict:
model_name = config.get("configurable", {}).get(
"model_name", DEFAULT_MODELS.default_transformation_model
model_id = config.get("configurable", {}).get(
"model_id", DEFAULT_MODELS.default_transformation_model
)
parser = PydanticOutputParser(pydantic_object=SummaryResponse)
return {
"output": run_pattern(
pattern_name="summarize",
model_name=model_name,
model_id=model_id,
state=state,
parser=parser,
)

View file

@ -9,36 +9,41 @@ from open_notebook.utils import token_count
def run_pattern(
pattern_name: str,
model_name=None,
model_id=None,
messages=[],
state: dict = {},
parser=None,
output_fixing_model_name=None,
output_fixing_model_id=None,
) -> dict:
system_prompt = Prompter(prompt_template=pattern_name, parser=parser).render(
data=state
)
tokens = token_count(str(system_prompt) + str(messages))
if tokens > 105_000 and DEFAULT_MODELS.large_context_model:
model_name = DEFAULT_MODELS.large_context_model
logger.debug(
f"Using large context model ({model_name}) because the content has {tokens} tokens"
)
logger.warning(system_prompt)
elif tokens > 105_000 and not DEFAULT_MODELS.large_context_model:
logger.critical(
f"Content has {tokens} tokens, but no large context model is configured"
)
elif not model_name:
model_name = DEFAULT_MODELS.default_transformation_model
chain = get_model(model_name, model_type="language")
model_id = (
DEFAULT_MODELS.large_context_model
or DEFAULT_MODELS.default_transformation_model
or DEFAULT_MODELS.default_chat_model
)
if tokens > 105_000:
model_id = DEFAULT_MODELS.large_context_model
logger.debug(
f"Using large context model ({model_id}) because the content has {tokens} tokens"
)
model_id = (
model_id
or DEFAULT_MODELS.default_transformation_model
or DEFAULT_MODELS.default_chat_model
)
chain = get_model(model_id, model_type="language")
if parser:
chain = chain | parser
if output_fixing_model_name and parser:
output_fix_model = get_model(output_fixing_model_name, model_type="language")
if output_fixing_model_id and parser:
output_fix_model = get_model(output_fixing_model_id, model_type="language")
chain = chain | OutputFixingParser.from_llm(
parser=parser,
llm=output_fix_model,