rename model_name to model_id
This commit is contained in:
parent
af315a0bab
commit
c65bf8ba12
7 changed files with 40 additions and 35 deletions
|
|
@ -22,12 +22,12 @@ class ThreadState(TypedDict):
|
|||
|
||||
|
||||
def call_model_with_messages(state: ThreadState, config: RunnableConfig) -> dict:
|
||||
model_name = config.get("configurable", {}).get(
|
||||
"model_name", DEFAULT_MODELS.default_chat_model
|
||||
model_id = config.get("configurable", {}).get(
|
||||
"model_id", DEFAULT_MODELS.default_chat_model
|
||||
)
|
||||
ai_message = run_pattern(
|
||||
"chat",
|
||||
model_name,
|
||||
model_id,
|
||||
messages=state["messages"],
|
||||
state=state,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -19,10 +19,10 @@ class DocQueryState(TypedDict):
|
|||
|
||||
|
||||
def call_model(state: dict, config: RunnableConfig) -> dict:
|
||||
model_name = config.get("configurable", {}).get(
|
||||
"model_name", os.environ.get("RETRIEVAL_MODEL")
|
||||
model_id = config.get("configurable", {}).get(
|
||||
"model_id", os.environ.get("RETRIEVAL_MODEL")
|
||||
)
|
||||
return {"answer": run_pattern("doc_query", model_name, state)}
|
||||
return {"answer": run_pattern("doc_query", model_id, state)}
|
||||
|
||||
|
||||
# todo: there is probably a better way to do this and avoid repetition
|
||||
|
|
|
|||
|
|
@ -18,8 +18,8 @@ class PatternChainState(TypedDict):
|
|||
|
||||
|
||||
def call_model(state: dict, config: RunnableConfig) -> dict:
|
||||
model_name = config.get("configurable", {}).get(
|
||||
"model_name", DEFAULT_MODELS.default_transformation_model
|
||||
model_id = config.get("configurable", {}).get(
|
||||
"model_id", DEFAULT_MODELS.default_transformation_model
|
||||
)
|
||||
transformations = state["transformations"]
|
||||
current_transformation = transformations.pop(0)
|
||||
|
|
@ -34,7 +34,7 @@ def call_model(state: dict, config: RunnableConfig) -> dict:
|
|||
|
||||
transformation_result = run_pattern(
|
||||
pattern_name=current_transformation,
|
||||
model_name=model_name,
|
||||
model_id=model_id,
|
||||
state=input_args,
|
||||
)
|
||||
return {
|
||||
|
|
|
|||
|
|
@ -15,13 +15,13 @@ class PatternState(TypedDict):
|
|||
|
||||
|
||||
def call_model(state: dict, config: RunnableConfig) -> dict:
|
||||
model_name = config.get("configurable", {}).get(
|
||||
"model_name", DEFAULT_MODELS.default_transformation_model
|
||||
model_id = config.get("configurable", {}).get(
|
||||
"model_id", DEFAULT_MODELS.default_transformation_model
|
||||
)
|
||||
return {
|
||||
"output": run_pattern(
|
||||
pattern_name=state["pattern"],
|
||||
model_name=model_name,
|
||||
model_id=model_id,
|
||||
state=state,
|
||||
)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -49,13 +49,13 @@ def chunk_condition(state: TocState) -> Literal["get_chunk", END]: # type: igno
|
|||
|
||||
|
||||
def call_model(state: TocState, config: RunnableConfig) -> dict:
|
||||
model_name = config.get("configurable", {}).get(
|
||||
"model_name", DEFAULT_MODELS.default_transformation_model
|
||||
model_id = config.get("configurable", {}).get(
|
||||
"model_id", DEFAULT_MODELS.default_transformation_model
|
||||
)
|
||||
return {
|
||||
"toc": run_pattern(
|
||||
pattern_name="recursive_toc",
|
||||
model_name=model_name,
|
||||
model_id=model_id,
|
||||
state=state,
|
||||
).content
|
||||
}
|
||||
|
|
|
|||
|
|
@ -59,14 +59,14 @@ def chunk_condition(state: SummaryState) -> Literal["get_chunk", END]: # type:
|
|||
|
||||
|
||||
def call_model(state: dict, config: RunnableConfig) -> dict:
|
||||
model_name = config.get("configurable", {}).get(
|
||||
"model_name", DEFAULT_MODELS.default_transformation_model
|
||||
model_id = config.get("configurable", {}).get(
|
||||
"model_id", DEFAULT_MODELS.default_transformation_model
|
||||
)
|
||||
parser = PydanticOutputParser(pydantic_object=SummaryResponse)
|
||||
return {
|
||||
"output": run_pattern(
|
||||
pattern_name="summarize",
|
||||
model_name=model_name,
|
||||
model_id=model_id,
|
||||
state=state,
|
||||
parser=parser,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -9,36 +9,41 @@ from open_notebook.utils import token_count
|
|||
|
||||
def run_pattern(
|
||||
pattern_name: str,
|
||||
model_name=None,
|
||||
model_id=None,
|
||||
messages=[],
|
||||
state: dict = {},
|
||||
parser=None,
|
||||
output_fixing_model_name=None,
|
||||
output_fixing_model_id=None,
|
||||
) -> dict:
|
||||
system_prompt = Prompter(prompt_template=pattern_name, parser=parser).render(
|
||||
data=state
|
||||
)
|
||||
|
||||
tokens = token_count(str(system_prompt) + str(messages))
|
||||
if tokens > 105_000 and DEFAULT_MODELS.large_context_model:
|
||||
model_name = DEFAULT_MODELS.large_context_model
|
||||
logger.debug(
|
||||
f"Using large context model ({model_name}) because the content has {tokens} tokens"
|
||||
)
|
||||
logger.warning(system_prompt)
|
||||
elif tokens > 105_000 and not DEFAULT_MODELS.large_context_model:
|
||||
logger.critical(
|
||||
f"Content has {tokens} tokens, but no large context model is configured"
|
||||
)
|
||||
elif not model_name:
|
||||
model_name = DEFAULT_MODELS.default_transformation_model
|
||||
|
||||
chain = get_model(model_name, model_type="language")
|
||||
model_id = (
|
||||
DEFAULT_MODELS.large_context_model
|
||||
or DEFAULT_MODELS.default_transformation_model
|
||||
or DEFAULT_MODELS.default_chat_model
|
||||
)
|
||||
if tokens > 105_000:
|
||||
model_id = DEFAULT_MODELS.large_context_model
|
||||
logger.debug(
|
||||
f"Using large context model ({model_id}) because the content has {tokens} tokens"
|
||||
)
|
||||
|
||||
model_id = (
|
||||
model_id
|
||||
or DEFAULT_MODELS.default_transformation_model
|
||||
or DEFAULT_MODELS.default_chat_model
|
||||
)
|
||||
|
||||
chain = get_model(model_id, model_type="language")
|
||||
if parser:
|
||||
chain = chain | parser
|
||||
|
||||
if output_fixing_model_name and parser:
|
||||
output_fix_model = get_model(output_fixing_model_name, model_type="language")
|
||||
if output_fixing_model_id and parser:
|
||||
output_fix_model = get_model(output_fixing_model_id, model_type="language")
|
||||
chain = chain | OutputFixingParser.from_llm(
|
||||
parser=parser,
|
||||
llm=output_fix_model,
|
||||
|
|
|
|||
Loading…
Reference in a new issue