|
import time |
|
from transformers import AutoTokenizer |
|
|
|
|
|
class BaseModel:
    """Base class for token-streaming chat models.

    Subclasses must:
      * set ``self.EOS`` (end-of-sequence token id) and ``self.SEQLEN``
        (maximum sequence length) before ``chat()`` is used, and
      * set ``self.model`` (presumably in ``load_model``; it must expose
        ``forward_first``, ``forward_next`` and ``token_length`` — TODO
        confirm against the concrete subclasses), and
      * implement ``encode_tokens``, ``load_model``, ``clear`` and
        ``update_history``.
    """

    def __init__(self, args):
        """Initialize shared chat state and load the tokenizer.

        Args:
            args: Parsed command-line arguments providing ``devid``
                (comma-separated device ids, e.g. ``"0,1"``) and
                ``tokenizer_path`` (passed to
                ``AutoTokenizer.from_pretrained``).
        """
        # Subclasses must fill these two in before chat() is called.
        self.EOS = None      # end-of-sequence token id
        self.SEQLEN = None   # maximum sequence length of the model
        self.input_str = ""
        self.system_prompt = ""
        self.history = []

        # Parse device ids such as "0,1" -> [0, 1].
        self.devices = [int(d) for d in args.devid.split(",")]

        # Load the tokenizer (trust_remote_code is required by some
        # model repositories that ship custom tokenizer code).
        print("Load " + args.tokenizer_path + " ...")
        self.tokenizer = AutoTokenizer.from_pretrained(
            args.tokenizer_path, trust_remote_code=True
        )

        # Warm-up decode — presumably so the first real decode during
        # streaming does not pay one-time initialization cost.
        self.tokenizer.decode([0])
        print("Done!")

    def _check_configured(self):
        """Raise if a subclass forgot to set ``EOS`` or ``SEQLEN``.

        Uses ``is None`` (not truthiness) so legitimate falsy values —
        e.g. an EOS token id of 0 — are accepted.
        """
        if self.EOS is None:
            raise NotImplementedError("Forget to set End of Sentence Token Id(EOS)")
        if self.SEQLEN is None:
            raise NotImplementedError("Forget to set Sequence Length(SEQLEN)")

    def chat(self):
        """Start an interactive chat session on stdin/stdout.

        Loops reading questions; ``q``/``quit``/``exit`` ends the
        session, ``clear``/``new`` resets it, anything else is answered
        by :meth:`stream_answer`.
        """
        self._check_configured()

        print(
            """\n===========================================================
1. If you want to quit, please enter one of [q, quit, exit]
2. To create a new chat session, please enter one of [clear, new]
==========================================================="""
        )

        while True:
            self.input_str = input("\nQuestion: ")

            if self.input_str in ["exit", "q", "quit"]:
                break
            elif self.input_str in ["clear", "new"]:
                self.clear()
            else:
                tokens = self.encode_tokens()

                # An invalid question re-prompts (`continue`) rather than
                # silently ending the whole session (the old code returned).
                if not tokens:
                    print("Sorry: your question is empty!!")
                    continue
                if len(tokens) > self.SEQLEN:
                    print(
                        "The maximum question length should be shorter than {} but we get {} instead.".format(
                            self.SEQLEN, len(tokens)
                        )
                    )
                    continue

                print("\nAnswer: ", end="")
                self.stream_answer(tokens)

    def stream_answer(self, tokens):
        """Generate and print the answer for *tokens*, token by token.

        Side effects: updates ``self.answer_token`` / ``self.answer_cur``,
        calls ``update_history()``, and prints first-token latency (FTL)
        and decode throughput (TPS).
        """
        tok_num = 0
        self.answer_cur = ""
        self.answer_token = []

        # Prefill (first forward pass) is timed separately from decoding.
        first_start = time.time()
        token = self.forward_first(tokens)
        first_end = time.time()

        while token != self.EOS and self.model.token_length < self.SEQLEN:
            # Decode [token, token] and strip the first copy: tokenizers
            # that fold leading whitespace into a token still yield the
            # correct incremental text for this single token.
            pre_word = self.decode_tokens([token])
            word = self.decode_tokens([token, token])[len(pre_word):]
            self.answer_token += [token]
            print(word, flush=True, end="")
            tok_num += 1
            token = self.forward_next()

        # Decode the full answer once after the loop (the old code
        # re-decoded it every iteration — quadratic in answer length —
        # and the value is only consumed after generation finishes).
        self.answer_cur = self.tokenizer.decode(self.answer_token)

        next_end = time.time()
        first_duration = first_end - first_start
        next_duration = next_end - first_end
        # Guard the division: no generated tokens can leave a ~0 interval.
        tps = tok_num / next_duration if next_duration > 0 else 0.0

        self.update_history()

        print()
        print(f"FTL: {first_duration:.3f} s")
        print(f"TPS: {tps:.3f} token/s")

    def stream_predict(self, query):
        """Stream (answer_so_far, history) pairs for *query*."""
        self.answer_cur = ""
        self.input_str = query
        tokens = self.encode_tokens()

        for answer_cur, history in self._generate_predictions(tokens):
            yield answer_cur, history

    def _generate_predictions(self, tokens):
        """Yield (partial answer, history) after each generated token.

        Stops at EOS, or at SEQLEN with a notice appended to the text.
        The per-iteration decode is required here because every partial
        answer is yielded to the caller.
        """
        next_token = self.forward_first(tokens)
        output_tokens = [next_token]

        while True:
            next_token = self.forward_next()
            if next_token == self.EOS:
                break
            output_tokens += [next_token]
            self.answer_cur = self.tokenizer.decode(output_tokens)
            if self.model.token_length >= self.SEQLEN:
                self.update_history()
                yield self.answer_cur + "\n\n\nReached the maximum length; The history context has been cleared.", self.history
                break
            else:
                yield self.answer_cur, self.history

        self.update_history()

    def forward_first(self, tokens):
        """Run the prefill pass over *tokens*; return the first token id."""
        token = self.model.forward_first(tokens)
        return token

    def forward_next(self):
        """Run one decode step; return the next token id."""
        token = self.model.forward_next()
        return token

    def decode_tokens(self, token):
        """Decode a list of token ids to text, skipping special tokens."""
        word = self.tokenizer.decode(token, skip_special_tokens=True)
        return word

    def encode_tokens(self):
        """Encode ``self.input_str`` to token ids. Subclass responsibility."""
        raise NotImplementedError

    def load_model(self):
        """Load the device model (must set ``self.model``). Subclass responsibility."""
        raise NotImplementedError

    def clear(self):
        """Reset the chat session state. Subclass responsibility."""
        raise NotImplementedError

    def update_history(self):
        """Fold the current answer into ``self.history``. Subclass responsibility."""
        raise NotImplementedError
|
|