from __future__ import annotations

import configparser
import os
import pathlib
import typing

import torch
import transformers
from torch.nn.utils.rnn import pad_sequence

from .config import BELLE_PARAM, LIB_SO_PATH
from .model import BelleModel


class LyraBelle:
    def __init__(self, model_path, model_name, dtype='fp16', int8_mode=0) -> None:
        self.model_path = model_path
        self.model_name = model_name
        self.dtype = dtype
        # int8 weight quantization is only meaningful when dtype is 'int8'.
        if dtype != 'int8':
            int8_mode = 0
        self.int8_mode = int8_mode

        print(f'Loading model and tokenizer from {self.model_path}')
        self.model, self.tokenizer = self.load_model_and_tokenizer()
        print("Got model and tokenizer")

    def load_model_and_tokenizer(self):
        tokenizer = transformers.AutoTokenizer.from_pretrained(self.model_path)

        checkpoint_path = pathlib.Path(self.model_path)
        config_path = checkpoint_path / 'config.ini'

        if config_path.exists():
            # Read the model hyperparameters from the checkpoint's config.ini.
            cfg = configparser.ConfigParser()
            cfg.read(config_path)
            model_name = 'belle'
            inference_data_type = self.dtype
            if inference_data_type is None:
                inference_data_type = cfg.get(model_name, "weight_data_type")
            model_args = dict(
                head_num=cfg.getint(model_name, 'head_num'),
                size_per_head=cfg.getint(model_name, "size_per_head"),
                layer_num=cfg.getint(model_name, "num_layer"),
                tensor_para_size=cfg.getint(model_name, "tensor_para_size"),
                vocab_size=cfg.getint(model_name, "vocab_size"),
                start_id=cfg.getint(model_name, "start_id"),
                end_id=cfg.getint(model_name, "end_id"),
                weights_data_type=cfg.get(model_name, "weight_data_type"),
                layernorm_eps=cfg.getfloat(model_name, 'layernorm_eps'),
                inference_data_type=inference_data_type)
        else:
            # No config.ini found: fall back to the default BELLE_PARAM values.
            inference_data_type = self.dtype
            if inference_data_type is None:
                inference_data_type = BELLE_PARAM.weights_data_type
            model_args = dict(head_num=BELLE_PARAM.num_heads,
                              size_per_head=BELLE_PARAM.size_per_head,
                              vocab_size=BELLE_PARAM.vocab_size,
                              start_id=BELLE_PARAM.start_id or tokenizer.bos_token_id,
                              end_id=BELLE_PARAM.end_id or tokenizer.eos_token_id,
                              layer_num=BELLE_PARAM.num_layers,
                              tensor_para_size=BELLE_PARAM.tensor_para_size,
                              weights_data_type=BELLE_PARAM.weights_data_type,
                              inference_data_type=inference_data_type)

        model_args.update(dict(
            lib_path=LIB_SO_PATH,
            pipeline_para_size=BELLE_PARAM.pipeline_para_size,
            shared_contexts_ratio=BELLE_PARAM.shared_contexts_ratio,
            int8_mode=self.int8_mode
        ))

        print('[FT][INFO] Loading our FT highly optimized BELLE model')
        for k, v in model_args.items():
            print(f' - {k.ljust(25, ".")}: {v}')

        # Sanity-check the assembled arguments before building the model.
        checklist = ['head_num', 'size_per_head', 'vocab_size', 'layer_num',
                     'tensor_para_size', 'weights_data_type']
        if None in [model_args[k] for k in checklist]:
            none_params = [p for p in checklist if model_args[p] is None]
            print(f'[FT][WARNING] Found None parameters {none_params}. They must '
                  'be provided either by a config file or CLI arguments.')
        if model_args['start_id'] != tokenizer.bos_token_id:
            print('[FT][WARNING] Given start_id does not match the bos token '
                  'id of the pretrained tokenizer.')
        if model_args['end_id'] not in (tokenizer.pad_token_id, tokenizer.eos_token_id):
            print('[FT][WARNING] Given end_id matches neither the pad token '
                  'id nor the eos token id of the pretrained tokenizer.')

        model = BelleModel(**model_args)
        if not model.load(ckpt_path=os.path.join(self.model_path, self.model_name)):
            print('[FT][WARNING] Skip model loading since no checkpoints were found')

        return model, tokenizer

    def generate(self, prompts: typing.List[str] | str,
                 output_length: int = 512,
                 beam_width: int = 1,
                 top_k: typing.Optional[torch.IntTensor] = 1,
                 top_p: typing.Optional[torch.FloatTensor] = 1.0,
                 beam_search_diversity_rate: typing.Optional[torch.FloatTensor] = 0.0,
                 temperature: typing.Optional[torch.FloatTensor] = 1.0,
                 len_penalty: typing.Optional[torch.FloatTensor] = 0.0,
                 repetition_penalty: typing.Optional[torch.FloatTensor] = 1.0,
                 presence_penalty: typing.Optional[torch.FloatTensor] = None,
                 min_length: typing.Optional[torch.IntTensor] = None,
                 bad_words_list: typing.Optional[torch.IntTensor] = None,
                 do_sample: bool = False,
                 return_output_length: bool = False,
                 return_cum_log_probs: int = 0):
        if isinstance(prompts, str):
            prompts = [prompts]

        # Wrap each prompt in the BELLE conversation template.
        inputs = ['Human: ' + prompt.strip() +
                  '\n\nAssistant:' for prompt in prompts]
        batch_size = len(inputs)
        ones_int = torch.ones(size=[batch_size], dtype=torch.int32)
        ones_float = torch.ones(size=[batch_size], dtype=torch.float32)

        # Tokenize each prompt separately, then right-pad to a rectangular batch.
        input_token_ids = [self.tokenizer(text, return_tensors="pt").input_ids.int().squeeze()
                           for text in inputs]
        input_lengths = torch.IntTensor([len(ids) for ids in input_token_ids])

        input_token_ids = pad_sequence(input_token_ids, batch_first=True,
                                       padding_value=self.tokenizer.eos_token_id)

        # Sampling needs a random seed per sequence; greedy/beam search does not.
        random_seed = None
        if do_sample:
            random_seed = torch.randint(0, 262144, (batch_size,), dtype=torch.long)

        outputs = self.model(start_ids=input_token_ids,
                             start_lengths=input_lengths,
                             output_len=output_length,
                             beam_width=beam_width,
                             top_k=top_k * ones_int,
                             top_p=top_p * ones_float,
                             beam_search_diversity_rate=beam_search_diversity_rate * ones_float,
                             temperature=temperature * ones_float,
                             len_penalty=len_penalty * ones_float,
                             repetition_penalty=repetition_penalty * ones_float,
                             presence_penalty=presence_penalty,
                             min_length=min_length,
                             random_seed=random_seed,
                             bad_words_list=bad_words_list,
                             return_output_length=return_output_length,
                             return_cum_log_probs=return_cum_log_probs)

        if return_cum_log_probs > 0:
            outputs = outputs[0]

        # Drop the prompt tokens and keep only the first beam of each sequence.
        output_token_ids = [out[0, length:].cpu()
                            for out, length in zip(outputs, input_lengths)]

        output_texts = self.tokenizer.batch_decode(
            output_token_ids, skip_special_tokens=True)

        return output_texts
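

# A minimal usage sketch (illustrative only, not part of the library API). Because
# of the relative imports above, it has to be run as a module, e.g.
# `python -m <package>.<this_module>`. The checkpoint directory and checkpoint
# file name below are hypothetical placeholders for a converted FasterTransformer
# BELLE checkpoint stored alongside its tokenizer files.
if __name__ == "__main__":
    belle = LyraBelle(model_path="./models/belle-7b-ft",  # hypothetical path
                      model_name="1-gpu-fp16.bin",        # hypothetical checkpoint file
                      dtype='fp16',
                      int8_mode=0)
    # Greedy decoding is the default; enable sampling with do_sample plus the
    # usual top_k / top_p / temperature knobs.
    replies = belle.generate("What is tensor parallelism?",
                             output_length=256,
                             do_sample=True,
                             top_k=30,
                             top_p=0.85,
                             temperature=1.0)
    print(replies[0])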