AnyTool/toolbench/inference/utils.py
2024-02-23 15:13:06 +08:00

267 lines
8.8 KiB
Python

import gc
import abc
import numpy as np
import math
from typing import Iterable
import torch
from transformers.generation.logits_process import (
LogitsProcessorList,
RepetitionPenaltyLogitsProcessor,
TemperatureLogitsWarper,
TopKLogitsWarper,
TopPLogitsWarper,
)
# For DFS
def softmax_bias(answers, temperature=1):
    """Turn real-valued scores into a probability distribution.

    Each score ``s`` receives the weight ``10 ** ((s / temperature) / 400)``
    and the weights are normalised so that they sum to one.

    Args:
        answers: Iterable of real numbers (Python or NumPy scalars, or a
            1-D NumPy array).
        temperature: Non-zero divisor applied to every score before
            exponentiation.

    Returns:
        A ``numpy.ndarray`` of probabilities in the order of ``answers``;
        an empty array when ``answers`` is empty.

    Raises:
        ZeroDivisionError: If ``temperature`` is zero.
    """
    if temperature == 0:
        raise ZeroDivisionError("temperature must be non-zero")
    exponents = np.asarray(list(answers), dtype=float) / temperature / 400
    if exponents.size == 0:
        return np.array([])
    # Subtracting the maximum exponent cancels out in the normalisation but
    # caps the largest weight at 1, so very large scores cannot overflow.
    weights = np.power(10.0, exponents - exponents.max())
    return weights / weights.sum()
def compute_epsilon_new_node(p_new_node):
    """Convert a probability into a score via the delta formula.

    The odds ``p / (1 - p)`` are passed through a base-10 logarithm, scaled
    by 400 and added to a baseline of 1000, so ``p_new_node == 0.5`` maps to
    exactly 1000.
    """
    odds = p_new_node / (1 - p_new_node)
    return 1000 + 400 * math.log10(odds)
# For prediction parsing, into ReACT format
def react_parser(string):
    """Split a ReACT-formatted prediction into thought, action and input.

    The fields are located with ``str.find`` on the literal markers
    ``"Thought: "``, ``"Action: "`` and ``"Action Input: "``. No validation is
    done: a missing marker makes ``find`` return -1 and the corresponding
    slice is taken relative to that value.

    Returns:
        A ``(thought, action, action_input)`` tuple of strings.
    """
    thought_begin = string.find("Thought: ") + len("Thought: ")
    thought_end = string.find("\nAction: ")

    action_begin = string.find("Action: ") + len("Action: ")
    action_end = string.find("\nAction Input: ")

    # The action input runs to the end of the string.
    input_begin = string.find("Action Input: ") + len("Action Input: ")

    return (
        string[thought_begin:thought_end],
        string[action_begin:action_end],
        string[input_begin:],
    )
# For toolllama's predictions
def prepare_logits_processor(
    temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
    """Assemble the logits processors implied by the sampling settings.

    Processors are appended in a fixed order -- temperature, repetition
    penalty, top-p, top-k -- and only when the corresponding setting is in
    its active range. The returned list is empty when none apply.
    """
    # Each entry: (is the setting active?, processor class, constructor argument).
    stages = (
        # TemperatureLogitsWarper doesn't accept 0.0 and 1.0 would be a no-op,
        # so both cases are skipped.
        (temperature >= 1e-5 and temperature != 1.0, TemperatureLogitsWarper, temperature),
        (repetition_penalty > 1.0, RepetitionPenaltyLogitsProcessor, repetition_penalty),
        (1e-8 <= top_p < 1.0, TopPLogitsWarper, top_p),
        (top_k > 0, TopKLogitsWarper, top_k),
    )
    processors = LogitsProcessorList()
    for active, processor_cls, setting in stages:
        if active:
            processors.append(processor_cls(setting))
    return processors
@torch.inference_mode()
def generate_stream(
    model, tokenizer, params, device, context_len=8192, stream_interval=2, force_generate=False
):
    """Sample a completion for ``params["prompt"]`` and yield the decoded text.

    Tokens are produced one at a time, feeding ``past_key_values`` back into
    the model at every step. Text is decoded only once -- when a stop token is
    sampled or the last of ``max_new_tokens`` steps has run -- so the
    generator yields exactly two dicts: the result with ``finish_reason``
    ``None``, then the same payload with the final ``finish_reason``.

    Args:
        model: Model exposing ``config.is_encoder_decoder``. Encoder-decoder
            models are driven through ``model.encoder``, ``model.decoder`` and
            ``model.lm_head``; other models are called directly and their
            output must provide ``logits`` and ``past_key_values``.
        tokenizer: Called on the prompt to obtain ``input_ids``; also provides
            ``decode`` and ``eos_token_id``.
        params: Request options. ``"prompt"`` is required. Optional keys and
            their defaults: ``"temperature"`` (1.0), ``"repetition_penalty"``
            (1.0), ``"top_p"`` (1.0), ``"top_k"`` (-1, disabled),
            ``"max_new_tokens"`` (256), ``"stop"`` (``None``; a string or an
            iterable of strings), ``"echo"`` (``True``) and
            ``"stop_token_ids"`` (``None``).
        device: Device on which input tensors are created. The string
            ``"mps"`` additionally moves the logits to the CPU before sampling.
        context_len: Token budget used to truncate the prompt from the left.
        stream_interval: Not used by this implementation.
        force_generate: When true, a stop token sampled at the very first
            step does not end generation.

    Yields:
        Dicts with the keys ``"text"``, ``"usage"`` (``prompt_tokens``,
        ``completion_tokens``, ``total_tokens``) and ``"finish_reason"``
        (``None``, ``"length"`` or ``"stop"``). ``completion_tokens`` is the
        zero-based index of the last generation step.

    Raises:
        ValueError: If ``max_new_tokens`` is smaller than 1, or if ``"stop"``
            is set to something that is neither a string nor an iterable.
    """
    prompt = params["prompt"]
    len_prompt = len(prompt)
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    top_k = int(params.get("top_k", -1))  # -1 means disable
    max_new_tokens = int(params.get("max_new_tokens", 256))
    if max_new_tokens < 1:
        # With no generation step there is no token, text or step index to
        # report, so reject the request up front.
        raise ValueError("max_new_tokens must be at least 1.")
    stop_str = params.get("stop", None)
    echo = bool(params.get("echo", True))
    # Work on a copy so the caller's list is not extended on every call.
    stop_token_ids = list(params.get("stop_token_ids", None) or [])
    stop_token_ids.append(tokenizer.eos_token_id)

    logits_processor = prepare_logits_processor(
        temperature, repetition_penalty, top_p, top_k
    )

    input_ids = tokenizer(prompt).input_ids
    input_echo_len = len(input_ids)
    # output_ids starts from the full, untruncated prompt; input_echo_len
    # refers to that same untruncated length.
    output_ids = list(input_ids)

    if model.config.is_encoder_decoder:
        max_src_len = context_len
    else:
        # Leave room for the new tokens plus a small margin.
        max_src_len = context_len - max_new_tokens - 8

    # Keep only the most recent tokens of the prompt as model input.
    input_ids = input_ids[-max_src_len:]

    if model.config.is_encoder_decoder:
        encoder_output = model.encoder(
            input_ids=torch.as_tensor([input_ids], device=device)
        )[0]
        start_ids = torch.as_tensor(
            [[model.generation_config.decoder_start_token_id]],
            dtype=torch.int64,
            device=device,
        )

    past_key_values = out = None
    try:
        for i in range(max_new_tokens):
            if i == 0:
                # First step: run the whole prompt (or the decoder start token).
                if model.config.is_encoder_decoder:
                    out = model.decoder(
                        input_ids=start_ids,
                        encoder_hidden_states=encoder_output,
                        use_cache=True,
                    )
                    logits = model.lm_head(out[0])
                else:
                    out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
                    logits = out.logits
                past_key_values = out.past_key_values
            else:
                # Later steps: feed only the previously sampled token and reuse the cache.
                if model.config.is_encoder_decoder:
                    out = model.decoder(
                        input_ids=torch.as_tensor([[token]], device=device),
                        encoder_hidden_states=encoder_output,
                        use_cache=True,
                        past_key_values=past_key_values,
                    )
                    logits = model.lm_head(out[0])
                else:
                    out = model(
                        input_ids=torch.as_tensor([[token]], device=device),
                        use_cache=True,
                        past_key_values=past_key_values,
                    )
                    logits = out.logits
                past_key_values = out.past_key_values

            if logits_processor:
                # Token history is only passed along when the repetition
                # penalty is active.
                if repetition_penalty > 1.0:
                    tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
                else:
                    tmp_output_ids = None
                last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
            else:
                last_token_logits = logits[0, -1, :]

            if device == "mps":
                # Switch to CPU by avoiding some bugs in mps backend.
                last_token_logits = last_token_logits.float().to("cpu")

            if temperature < 1e-5 or top_p < 1e-8:  # greedy
                token = int(torch.argmax(last_token_logits))
            else:
                probs = torch.softmax(last_token_logits, dim=-1)
                token = int(torch.multinomial(probs, num_samples=1))

            output_ids.append(token)

            stopped = token in stop_token_ids
            if i == 0 and force_generate:
                stopped = False

            if i == max_new_tokens - 1 or stopped:
                if echo:
                    tmp_output_ids = output_ids
                    # Search for stop strings only after the prompt's characters.
                    rfind_start = len_prompt
                else:
                    tmp_output_ids = output_ids[input_echo_len:]
                    rfind_start = 0

                output = tokenizer.decode(
                    tmp_output_ids,
                    skip_special_tokens=True,
                    spaces_between_special_tokens=False,
                )
                if stop_str:
                    if isinstance(stop_str, str):
                        pos = output.rfind(stop_str, rfind_start)
                        if pos != -1:
                            output = output[:pos]
                            stopped = True
                    elif isinstance(stop_str, Iterable):
                        for each_stop in stop_str:
                            pos = output.rfind(each_stop, rfind_start)
                            if pos != -1:
                                output = output[:pos]
                                stopped = True
                                break
                    else:
                        raise ValueError("Invalid stop field type.")

                yield {
                    "text": output,
                    "usage": {
                        "prompt_tokens": input_echo_len,
                        "completion_tokens": i,
                        "total_tokens": input_echo_len + i,
                    },
                    "finish_reason": None,
                }

            if stopped:
                break

        # finish stream event, which contains finish reason
        if i == max_new_tokens - 1:
            finish_reason = "length"
        elif stopped:
            finish_reason = "stop"
        else:
            finish_reason = None

        yield {
            "text": output,
            "usage": {
                "prompt_tokens": input_echo_len,
                "completion_tokens": i,
                "total_tokens": input_echo_len + i,
            },
            "finish_reason": finish_reason,
        }
    finally:
        # clean -- also runs when the consumer closes the generator early or
        # a step raises, so the cached tensors are always released.
        del past_key_values, out
        gc.collect()
        torch.cuda.empty_cache()
# For IO presentation
class ChatIO(abc.ABC):
    """Abstract interface for the input/output side of a chat session.

    Concrete subclasses must implement all four methods below.
    """

    @abc.abstractmethod
    def prompt_for_input(self, role: str) -> str:
        """Ask the given role for input and return it as a string."""

    @abc.abstractmethod
    def prompt_for_output(self, role: str):
        """Announce that the given role is about to produce output."""

    @abc.abstractmethod
    def stream_output(self, output_stream):
        """Present the items of ``output_stream`` as a stream."""

    @abc.abstractmethod
    def return_output(self, output_stream):
        """Consume ``output_stream`` and return its output."""
class SimpleChatIO(ChatIO):
    """Console implementation of ``ChatIO`` built on ``input`` and ``print``."""

    def prompt_for_input(self, role) -> str:
        """Show ``"<role>: "`` as a prompt and return what ``input`` reads."""
        return input(f"{role}: ")

    def prompt_for_output(self, role: str):
        """Print ``"<role>: "`` without a trailing newline and flush."""
        print(f"{role}: ", end="", flush=True)

    def stream_output(self, output_stream):
        """Print the streamed text word by word and return the final text.

        Every item of ``output_stream`` must be a mapping whose ``"text"``
        entry holds the text so far. After each item, all words except the
        last one that have not been printed yet are written out; the remaining
        words are printed, followed by a newline, once the stream ends.

        Returns:
            The stripped text of the last item, or ``""`` for an empty stream.
        """
        words = []  # stays empty if the stream yields nothing
        printed = 0
        for outputs in output_stream:
            words = outputs["text"].strip().split(" ")
            # Hold back the final word of each snapshot (presumably because a
            # later snapshot may still extend it).
            complete = len(words) - 1
            if complete > printed:
                print(" ".join(words[printed:complete]), end=" ", flush=True)
                printed = complete
        print(" ".join(words[printed:]), flush=True)
        return " ".join(words)

    def return_output(self, output_stream):
        """Drain ``output_stream`` silently and return the final text.

        Returns:
            The ``"text"`` of the last item with surrounding whitespace
            stripped, or ``""`` for an empty stream.
        """
        final_text = ""
        for outputs in output_stream:
            final_text = outputs["text"]
        # Splitting on " " and re-joining with " " is an identity, so
        # stripping alone gives the same string.
        return final_text.strip()