remove duplicate file
This commit is contained in:
parent
5491964e81
commit
b678a55f63
2 changed files with 1 addition and 74 deletions
|
|
@@ -15,7 +15,7 @@ from .beam import BeamPyTorchInference
|
||||||
from .eow_detection import fire_at_boundary, load_cif
|
from .eow_detection import fire_at_boundary, load_cif
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from token_buffer import TokenBuffer
|
from .token_buffer import TokenBuffer
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from .generation_progress import *
|
from .generation_progress import *
|
||||||
|
|
|
||||||
|
|
@@ -1,73 +0,0 @@
|
||||||
import torch
|
|
||||||
import sys
|
|
||||||
class TokenBuffer:
    """A text buffer that can be rendered as token ids or as a tensor.

    The buffer stores plain text plus a list of prefix token ids. Token ids
    are produced on demand by encoding the text with the configured tokenizer
    and prepending the prefix ids.

    The tokenizer is duck-typed. The methods below call ``encode(text)``,
    ``decode(token_ids)`` and ``split_to_word_tokens(token_ids)`` on it.
    """

    def __init__(self, text="", tokenizer=None, device=None, prefix_token_ids=None):
        """Create a buffer.

        Args:
            text: Initial text content.
            tokenizer: Default tokenizer used by the token-based methods.
            device: Default device used by ``as_tensor``.
            prefix_token_ids: Token ids prepended by ``as_token_ids``.
                ``None`` (the default) means no prefix.
        """
        self.text = text
        # Use a fresh list per instance: a literal ``[]`` default would be a
        # single object shared by every buffer created without a prefix.
        self.prefix_token_ids = [] if prefix_token_ids is None else prefix_token_ids
        self.tokenizer = tokenizer
        self.device = device

    def _require_tokenizer(self, tokenizer=None):
        """Return ``tokenizer`` or the buffer's own; raise if neither is set."""
        if tokenizer is None:
            tokenizer = self.tokenizer
        if tokenizer is None:
            # Raised explicitly instead of via ``assert`` so the check also
            # runs when Python is started with -O.
            raise ValueError("Tokenizer is not set.")
        return tokenizer

    def as_token_ids(self, tokenizer=None):
        """Return the prefix token ids followed by the encoded text.

        Args:
            tokenizer: Optional override for the buffer's tokenizer.

        Raises:
            ValueError: If no tokenizer is available.
        """
        tokenizer = self._require_tokenizer(tokenizer)
        return list(self.prefix_token_ids) + list(tokenizer.encode(self.text))

    def as_tensor(self, device=None):
        """Return the token ids as a ``torch.long`` tensor of shape (1, n).

        Args:
            device: Optional override for the buffer's device.

        Raises:
            ValueError: If no device or no tokenizer is available.
        """
        if device is None:
            device = self.device
        if device is None:
            raise ValueError("Device is not set.")
        tok_ids = self.as_token_ids()
        # unsqueeze(0) adds the leading dimension of size 1.
        return torch.tensor(tok_ids, dtype=torch.long, device=device).unsqueeze(0)

    def as_tensor_beam(self, beam, device=None):
        """Return the ``as_tensor`` row repeated ``beam`` times along dim 0."""
        t = self.as_tensor(device=device)
        return t.repeat_interleave(beam, dim=0)

    def as_text(self):
        """Return the buffered text."""
        return self.text

    @classmethod
    def empty(cls, *a, **kw):
        """Create a buffer, forwarding all arguments to the constructor."""
        return cls(*a, **kw)

    @classmethod
    def from_text(cls, text, *a, **kw):
        """Create a buffer holding ``text``.

        Remaining positional and keyword arguments are forwarded to the
        constructor after ``text`` (``tokenizer``, ``device``,
        ``prefix_token_ids``).
        """
        # ``text`` is passed positionally: passing it as a keyword next to
        # ``*a`` would bind the first extra positional argument to ``text``
        # as well and fail with "multiple values for argument 'text'".
        return cls(text, *a, **kw)

    def is_empty(self):
        """Return True when the text is ``None`` or the empty string."""
        return self.text is None or self.text == ""

    def trim_words(self, num=1, after=0):
        """Remove leading words from the text that follows a static prefix.

        Args:
            num: How many words to trim from the beginning.
            after: How many characters to skip (length of the static prompt).

        Returns:
            The number of tokens belonging to the removed words, or 0 when
            the tokenizer reports no words.

        Raises:
            ValueError: If no tokenizer is set.
        """
        tokenizer = self._require_tokenizer()

        ids = tokenizer.encode(self.text[after:])
        # split_to_word_tokens is expected to return parallel sequences:
        # the word strings and, per word, its token ids.
        words, wids = tokenizer.split_to_word_tokens(ids)
        # Diagnostic output, kept on stderr as in the original implementation.
        print(words, file=sys.stderr)
        print(wids, file=sys.stderr)
        if not words:
            return 0
        # Keep the first ``after`` characters untouched and rebuild the rest
        # from the words that survive the trim.
        self.text = self.text[:after] + "".join(words[num:])
        return sum(len(wi) for wi in wids[:num])

    def append_token_ids(self, token_ids):
        """Decode ``token_ids`` and append the resulting text to the buffer.

        Raises:
            ValueError: If no tokenizer is set.
        """
        tokenizer = self._require_tokenizer()
        self.text += tokenizer.decode(token_ids)

    def as_split_word_tokens(self):
        """Return ``tokenizer.split_to_word_tokens`` applied to the encoded text.

        The prefix token ids are not included.

        Raises:
            ValueError: If no tokenizer is set.
        """
        tokenizer = self._require_tokenizer()
        ids = tokenizer.encode(self.text)
        return tokenizer.split_to_word_tokens(ids)
|
|
||||||
Loading…
Reference in a new issue