import abc
import gc
import math
from typing import Iterable

import numpy as np
import torch
from transformers.generation.logits_process import (
    LogitsProcessorList,
    RepetitionPenaltyLogitsProcessor,
    TemperatureLogitsWarper,
    TopKLogitsWarper,
    TopPLogitsWarper,
)


# For DFS
def softmax_bias(answers, temperature=1):
    """Turn scores into a probability vector using base-10 weights on a 400 scale.

    Each score ``s`` gets weight ``10 ** ((s / temperature) / 400)`` and the
    weights are normalized to sum to 1 (the same logistic form as Elo ratings).

    Args:
        answers: Iterable of numeric scores.
        temperature: Divides every score before exponentiation; larger values
            flatten the distribution. Must be non-zero.

    Returns:
        1-D float64 ``np.ndarray`` of probabilities in the order of
        ``answers``; an empty array when ``answers`` is empty.

    Raises:
        ZeroDivisionError: if ``temperature`` is zero and ``answers`` is
            non-empty.
    """
    scores = np.fromiter(answers, dtype=np.float64)
    if scores.size == 0:
        return np.array([])
    if temperature == 0:
        raise ZeroDivisionError("softmax_bias: temperature must be non-zero")
    exponents = (scores / temperature) / 400
    # Subtracting the max exponent cancels out in the normalization, but keeps
    # 10**x from overflowing (a plain float power raises OverflowError once a
    # scaled score exceeds ~308).
    exponents -= exponents.max()
    weights = np.power(10.0, exponents)
    return weights / weights.sum()


def compute_epsilon_new_node(p_new_node):
    """Convert a target probability into a score relative to a baseline of 1000.

    This inverts the weighting used by ``softmax_bias`` at temperature 1. A
    node scored with the returned value, weighed against a single node scored
    1000, receives probability ``p_new_node``. The offset is
    ``delta = 400 * log10(p / (1 - p))``.

    Args:
        p_new_node: Target probability; must lie strictly between 0 and 1
            (0 or values above 1 make ``math.log10`` raise ``ValueError``;
            exactly 1 raises ``ZeroDivisionError``).

    Returns:
        ``1000 + delta`` as a float.
    """
    delta = 400 * math.log10(p_new_node / (1 - p_new_node))
    return 1000 + delta


# For prediction parsing, into ReACT format
def _extract_field(text, start_marker, end_marker=None):
    """Return the text after the first ``start_marker``, up to the next ``end_marker``.

    Returns "" when ``start_marker`` is absent. The field runs to the end of
    ``text`` when ``end_marker`` is None or does not occur after the start.
    """
    start = text.find(start_marker)
    if start == -1:
        return ""
    start += len(start_marker)
    if end_marker is None:
        return text[start:]
    end = text.find(end_marker, start)
    # A missing end marker must not be treated as index -1: that would
    # silently drop the last character of the field.
    return text[start:] if end == -1 else text[start:end]


def react_parser(string):
    """Split a ReAct-formatted prediction into ``(thought, action, action_input)``.

    Expects text shaped like::

        Thought: <thought>
        Action: <action>
        Action Input: <action input>

    A field whose marker is missing comes back as "". ``Action Input`` runs to
    the end of the string.
    """
    thought = _extract_field(string, "Thought: ", "\nAction: ")
    action = _extract_field(string, "Action: ", "\nAction Input: ")
    action_input = _extract_field(string, "Action Input: ")
    return thought, action, action_input


# For toolllama's predictions
def prepare_logits_processor(
    temperature: float, repetition_penalty: float, top_p: float, top_k: int
) -> LogitsProcessorList:
    """Build the logits processors for the given sampling settings.

    Each processor is added only when its setting would actually change the
    logits; settings that would be a no-op are skipped.
    """
    processor_list = LogitsProcessorList()
    # TemperatureLogitsWarper doesn't accept 0.0, 1.0 makes it a no-op so we skip two cases.
    if temperature >= 1e-5 and temperature != 1.0:
        processor_list.append(TemperatureLogitsWarper(temperature))
    if repetition_penalty > 1.0:
        processor_list.append(RepetitionPenaltyLogitsProcessor(repetition_penalty))
    if 1e-8 <= top_p < 1.0:
        processor_list.append(TopPLogitsWarper(top_p))
    if top_k > 0:
        processor_list.append(TopKLogitsWarper(top_k))
    return processor_list


def _truncate_at_stop_str(output, stop_str, rfind_start):
    """Cut ``output`` before the last occurrence of a stop string.

    Args:
        output: Decoded text.
        stop_str: A single stop string, or an iterable of stop strings. For an
            iterable, the first entry (in iteration order) that occurs wins.
        rfind_start: Only matches lying entirely in ``output[rfind_start:]``
            are considered.

    Returns:
        ``(text, found)``. ``text`` is ``output`` truncated before the match,
        or unchanged when nothing matched.

    Raises:
        ValueError: if ``stop_str`` is neither a str nor an iterable.
    """
    if isinstance(stop_str, str):
        pos = output.rfind(stop_str, rfind_start)
        if pos != -1:
            return output[:pos], True
        return output, False
    if isinstance(stop_str, Iterable):
        for each_stop in stop_str:
            pos = output.rfind(each_stop, rfind_start)
            if pos != -1:
                return output[:pos], True
        return output, False
    raise ValueError("Invalid stop field type.")


@torch.inference_mode()
def generate_stream(
    model, tokenizer, params, device, context_len=8192, stream_interval=2, force_generate=False
):
    """Autoregressively sample up to ``max_new_tokens`` tokens for ``params["prompt"]``.

    The prompt is run once to fill the KV cache, and then one token is fed per
    step. Sampling applies ``prepare_logits_processor`` (temperature,
    repetition penalty, top-p, top-k). It uses greedy argmax when
    ``temperature < 1e-5`` or ``top_p < 1e-8``.

    Args:
        model: Model exposing ``config.is_encoder_decoder``. Encoder-decoder
            models must also expose ``encoder``, ``decoder``, ``lm_head`` and
            ``generation_config.decoder_start_token_id`` (presumably a Hugging
            Face model; not checked here).
        tokenizer: Tokenizer providing ``__call__(...).input_ids``, ``decode``
            and ``eos_token_id``.
        params: Dict with a required ``"prompt"`` and these optional keys:

            - ``"temperature"`` (1.0)
            - ``"repetition_penalty"`` (1.0)
            - ``"top_p"`` (1.0)
            - ``"top_k"`` (-1 = disabled)
            - ``"max_new_tokens"`` (256)
            - ``"stop"`` (str or iterable of str)
            - ``"echo"`` (True)
            - ``"stop_token_ids"`` (list; EOS is always added)

            ``params`` is not mutated.
        device: Device for input tensors. With ``"mps"``, logits are moved to
            the CPU before sampling.
        context_len: Token budget. For decoder-only models, the prompt is
            left-truncated to ``context_len - max_new_tokens - 8`` tokens.
        stream_interval: Accepted for API compatibility but unused. Text is
            yielded only once generation ends.
        force_generate: If True, a stop token sampled on the first step does
            not end generation.

    Yields:
        Dicts with keys ``"text"``, ``"usage"`` and ``"finish_reason"``.

        - When generation ends, one dict with ``finish_reason=None`` (skipped
          if ``max_new_tokens <= 0``).
        - Then a final dict whose ``finish_reason`` is ``"stop"`` (stop token
          or stop string hit) or ``"length"`` (token budget used up).

        ``text`` includes the prompt when ``echo`` is True.
    """
    prompt = params["prompt"]
    len_prompt = len(prompt)
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    top_k = int(params.get("top_k", -1))  # -1 means disable
    max_new_tokens = int(params.get("max_new_tokens", 256))
    stop_str = params.get("stop", None)
    echo = bool(params.get("echo", True))
    # Copy before appending EOS. Appending to the caller's list in place would
    # make it grow by one EOS id on every call that reuses the same params.
    stop_token_ids = list(params.get("stop_token_ids", None) or [])
    stop_token_ids.append(tokenizer.eos_token_id)

    logits_processor = prepare_logits_processor(
        temperature, repetition_penalty, top_p, top_k
    )

    input_ids = tokenizer(prompt).input_ids
    input_echo_len = len(input_ids)
    output_ids = list(input_ids)

    if model.config.is_encoder_decoder:
        max_src_len = context_len
    else:
        # Leave room for the generated tokens plus a small safety margin.
        # NOTE(review): if max_new_tokens >= context_len - 8, this is <= 0 and
        # the slice below keeps the whole prompt (0) or drops its head (< 0).
        max_src_len = context_len - max_new_tokens - 8

    input_ids = input_ids[-max_src_len:]

    if model.config.is_encoder_decoder:
        encoder_output = model.encoder(
            input_ids=torch.as_tensor([input_ids], device=device)
        )[0]
        start_ids = torch.as_tensor(
            [[model.generation_config.decoder_start_token_id]],
            dtype=torch.int64,
            device=device,
        )

    past_key_values = out = None
    # Fallbacks for when the loop body never runs (max_new_tokens <= 0).
    output = prompt if echo else ""
    stopped = False
    i = 0
    for i in range(max_new_tokens):
        if i == 0:
            # First step: process the whole (truncated) prompt and fill the cache.
            if model.config.is_encoder_decoder:
                out = model.decoder(
                    input_ids=start_ids,
                    encoder_hidden_states=encoder_output,
                    use_cache=True,
                )
                logits = model.lm_head(out[0])
            else:
                out = model(torch.as_tensor([input_ids], device=device), use_cache=True)
                logits = out.logits
            past_key_values = out.past_key_values
        else:
            # Later steps: feed only the previously sampled token, reusing the cache.
            if model.config.is_encoder_decoder:
                out = model.decoder(
                    input_ids=torch.as_tensor([[token]], device=device),
                    encoder_hidden_states=encoder_output,
                    use_cache=True,
                    past_key_values=past_key_values,
                )
                logits = model.lm_head(out[0])
            else:
                out = model(
                    input_ids=torch.as_tensor([[token]], device=device),
                    use_cache=True,
                    past_key_values=past_key_values,
                )
                logits = out.logits
            past_key_values = out.past_key_values

        if logits_processor:
            # Only the repetition penalty needs the tokens seen so far.
            if repetition_penalty > 1.0:
                tmp_output_ids = torch.as_tensor([output_ids], device=logits.device)
            else:
                tmp_output_ids = None
            last_token_logits = logits_processor(tmp_output_ids, logits[:, -1, :])[0]
        else:
            last_token_logits = logits[0, -1, :]

        if device == "mps":
            # Switch to CPU by avoiding some bugs in mps backend.
            last_token_logits = last_token_logits.float().to("cpu")

        if temperature < 1e-5 or top_p < 1e-8:  # greedy
            token = int(torch.argmax(last_token_logits))
        else:
            probs = torch.softmax(last_token_logits, dim=-1)
            token = int(torch.multinomial(probs, num_samples=1))

        output_ids.append(token)

        stopped = token in stop_token_ids
        if i == 0 and force_generate:
            stopped = False

        if i == max_new_tokens - 1 or stopped:
            if echo:
                tmp_output_ids = output_ids
                rfind_start = len_prompt
            else:
                tmp_output_ids = output_ids[input_echo_len:]
                rfind_start = 0

            output = tokenizer.decode(
                tmp_output_ids,
                skip_special_tokens=True,
                spaces_between_special_tokens=False,
            )
            if stop_str:
                output, found = _truncate_at_stop_str(output, stop_str, rfind_start)
                if found:
                    stopped = True

            # NOTE(review): completion_tokens reports the loop index i, one less
            # than the number of tokens sampled so far. Confirm intended.
            yield {
                "text": output,
                "usage": {
                    "prompt_tokens": input_echo_len,
                    "completion_tokens": i,
                    "total_tokens": input_echo_len + i,
                },
                "finish_reason": None,
            }

        if stopped:
            break

    # Finish stream event carrying the finish reason. A stop token or stop
    # string takes precedence, even if it was hit on the very last step.
    finish_reason = "stop" if stopped else "length"

    yield {
        "text": output,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": i,
            "total_tokens": input_echo_len + i,
        },
        "finish_reason": finish_reason,
    }

    # clean
    del past_key_values, out
    gc.collect()
    torch.cuda.empty_cache()


# For IO presentation
class ChatIO(abc.ABC):
    """Abstract interface for presenting a chat: reading input and showing output."""

    @abc.abstractmethod
    def prompt_for_input(self, role: str) -> str:
        """Prompt for input from a role."""

    @abc.abstractmethod
    def prompt_for_output(self, role: str):
        """Prompt for output from a role."""

    @abc.abstractmethod
    def stream_output(self, output_stream):
        """Stream output."""

    @abc.abstractmethod
    def return_output(self, output_stream):
        """Return output."""


class SimpleChatIO(ChatIO):
    """Console implementation of ``ChatIO`` using ``input`` and ``print``."""

    def prompt_for_input(self, role) -> str:
        return input(f"{role}: ")

    def prompt_for_output(self, role: str):
        print(f"{role}: ", end="", flush=True)

    def stream_output(self, output_stream):
        """Print text word by word as events arrive, then return the final text.

        Every event's ``"text"`` holds the full text so far. All words except
        the possibly incomplete last one are printed as they arrive, and the
        remainder is printed when the stream ends. An empty stream prints an
        empty line and returns "".
        """
        pre = 0
        output_text = []
        for outputs in output_stream:
            output_text = outputs["text"].strip().split(" ")
            now = len(output_text) - 1
            if now > pre:
                print(" ".join(output_text[pre:now]), end=" ", flush=True)
                pre = now
        print(" ".join(output_text[pre:]), flush=True)
        return " ".join(output_text)

    def return_output(self, output_stream):
        """Consume the stream silently and return the last event's text, stripped.

        Returns "" for an empty stream.
        """
        output_text = ""
        for outputs in output_stream:
            output_text = outputs["text"].strip()
        return output_text