# encoding=utf-8
from typing import List, Union

import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import T5Tokenizer

from fengshen.models.transfo_xl_reasoning import TransfoXLModel
from fengshen.utils import sample_sequence_batch

# Token id used both to pad prompt batches and as the end token handed to
# `sample_sequence_batch`.
# NOTE(review): presumably the tokenizer's end-of-sequence id -- confirm
# against the tokenizer vocabulary.
_END_TOKEN_ID = 50000

# Source/target punctuation for `en_to_zh`, paired position by position.
# NOTE: `"` and `'` each occur twice in the source string; the later pair
# wins when the table is built, so they always map to the closing forms
# (” and ’) and the opening forms are never produced.
# NOTE(review): as written, some entries map a character to itself --
# confirm whether fullwidth forms were intended.
_EN_PUNCTUATION = u",.!?[]()<>\"\"''"
_ZH_PUNCTUATION = u",。!?【】()《》“”‘’"

# Built once at import time instead of on every `en_to_zh` call.
_PUNCTUATION_TABLE = {
    ord(src): ord(dst) for src, dst in zip(_EN_PUNCTUATION, _ZH_PUNCTUATION)
}


def en_to_zh(sentence: str):
    """Replace English punctuation in *sentence* with its Chinese counterpart.

    Characters without an entry in the punctuation table are left unchanged.
    """
    return sentence.translate(_PUNCTUATION_TABLE)


def _generate_from_prompts(
        model: TransfoXLModel,
        tokenizer: T5Tokenizer,
        prompts: List[str],
        device: int,
        batch_size: int,
        temperature: float,
        repetition_penalty: float,
        max_out_seq: int,
        top_p: float) -> List[str]:
    """Sample one continuation per prompt, processing `batch_size` at a time.

    Shared implementation behind `deduction_generate` and
    `abduction_generate`; the callers only differ in how they build
    *prompts*.

    Returns:
        The decoded continuations, in the same order as *prompts*, with
        spaces removed and punctuation passed through `en_to_zh`.

    Raises:
        ValueError: if *batch_size* is smaller than 1.
    """
    if batch_size < 1:
        # A zero step makes range() raise ValueError anyway; a negative one
        # would silently yield no batches and return an empty result.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    model = model.eval().cuda(device)
    if not prompts:
        # Nothing to generate; skip the tokenizer call entirely.
        return []

    # The last id of every encoded prompt is dropped before sampling.
    # NOTE(review): presumably the end token appended by the tokenizer --
    # confirm.
    input_ids = [
        torch.tensor(ids[:-1]) for ids in tokenizer(prompts).input_ids
    ]
    input_length = [len(ids) for ids in input_ids]

    output = []
    for start in range(0, len(input_ids), batch_size):
        stop = start + batch_size
        batch_lengths = input_length[start:stop]
        input_ids_batch = pad_sequence(
            input_ids[start:stop],
            batch_first=True,
            padding_value=_END_TOKEN_ID,
        )
        input_ids_length = torch.tensor(batch_lengths)

        res_ids_batch, _ = sample_sequence_batch(
            model=model,
            context_tokens_tensor=input_ids_batch.cuda(device=device),
            context_length_tensor=input_ids_length.cuda(device=device),
            end_token_id=_END_TOKEN_ID,
            top_k=0,
            top_p=top_p,
            max_out_seq=max_out_seq,
            repetition_penalty=repetition_penalty,
            temperature=temperature
        )

        # Keep only what follows each prompt's own (unpadded) length.
        output.extend(
            en_to_zh(tokenizer.decode(ids[length:])).replace(" ", "")
            for ids, length in zip(res_ids_batch, batch_lengths)
        )
    return output


def deduction_generate(
        model: TransfoXLModel,
        tokenizer: T5Tokenizer,
        input_text: Union[str, List[str]],
        device: int = 0,
        batch_size: int = 2,
        temperature: float = 1.0,
        repetition_penalty: float = 2.0,
        max_out_seq: int = 512,
        top_p: float = 0.6) -> List[str]:
    """Generate with fixed prompt of deduction.

    Each input text is suffixed with ",因而" ("therefore") and the model
    samples a continuation for it.

    Args:
        model: Model to sample from; switched to eval mode and moved to the
            CUDA device *device*.
        tokenizer: Tokenizer used to encode the prompts and decode results.
        input_text: A single text or a list of texts.
        device: CUDA device index.
        batch_size: Number of prompts sampled per batch; must be >= 1.
        temperature: Forwarded to `sample_sequence_batch`.
        repetition_penalty: Forwarded to `sample_sequence_batch`.
        max_out_seq: Forwarded to `sample_sequence_batch`.
        top_p: Forwarded to `sample_sequence_batch`.

    Returns:
        One generated string per input text, in input order.

    Raises:
        ValueError: if *batch_size* is smaller than 1.
    """
    if isinstance(input_text, str):
        input_text = [input_text]
    prompts = [f"{text},因而" for text in input_text]
    return _generate_from_prompts(
        model=model,
        tokenizer=tokenizer,
        prompts=prompts,
        device=device,
        batch_size=batch_size,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        max_out_seq=max_out_seq,
        top_p=top_p,
    )


def abduction_generate(
        model: TransfoXLModel,
        tokenizer: T5Tokenizer,
        input_text: Union[str, List[str]],
        device: int = 0,
        batch_size: int = 2,
        temperature: float = 1.0,
        repetition_penalty: float = 2.0,
        top_p: float = 0.6,
        max_out_seq: int = 512) -> List[str]:
    """Generate with fixed prompt of abduction.

    Each input text is wrapped as "之所以{text},是因为" ("the reason why
    {text} is because") and the model samples a continuation for it.

    Args:
        model: Model to sample from; switched to eval mode and moved to the
            CUDA device *device*.
        tokenizer: Tokenizer used to encode the prompts and decode results.
        input_text: A single text or a list of texts.
        device: CUDA device index.
        batch_size: Number of prompts sampled per batch; must be >= 1.
        temperature: Forwarded to `sample_sequence_batch`.
        repetition_penalty: Forwarded to `sample_sequence_batch`.
        top_p: Forwarded to `sample_sequence_batch`.
        max_out_seq: Forwarded to `sample_sequence_batch`. Placed last (unlike
            in `deduction_generate`) so existing positional calls keep their
            meaning; the default matches the previously hard-coded 512.

    Returns:
        One generated string per input text, in input order.

    Raises:
        ValueError: if *batch_size* is smaller than 1.
    """
    if isinstance(input_text, str):
        input_text = [input_text]
    prompts = [f"之所以{text},是因为" for text in input_text]
    return _generate_from_prompts(
        model=model,
        tokenizer=tokenizer,
        prompts=prompts,
        device=device,
        batch_size=batch_size,
        temperature=temperature,
        repetition_penalty=repetition_penalty,
        max_out_seq=max_out_seq,
        top_p=top_p,
    )