|
|
|
''' |
|
@File : inference_cogview.py |
|
@Time : 2021/10/09 19:41:58 |
|
@Author : Ming Ding |
|
@Contact : dm18@mails.tsinghua.edu.cn |
|
''' |
|
|
|
|
|
import os |
|
import sys |
|
import math |
|
import random |
|
import torch |
|
import argparse |
|
import stat |
|
|
|
from SwissArmyTransformer import mpu, get_args, get_tokenizer |
|
from SwissArmyTransformer.model import CachedAutoregressiveModel |
|
from SwissArmyTransformer.generation.sampling_strategies import BaseStrategy |
|
from SwissArmyTransformer.generation.autoregressive_sampling import filling_sequence |
|
from SwissArmyTransformer.generation.utils import timed_name, generate_continually |
|
from SwissArmyTransformer.training import set_random_seed |
|
|
|
import json |
|
|
|
def main(args):
    """Generate answers for JSON prompts from ``args.input_source``.

    Loads a CachedAutoregressiveModel checkpoint, then runs autoregressive
    sampling for each input line and appends the decoded first sample to
    ``<output_path>/test_eval.txt``.

    2022/06/17: changed load_checkpoint to from_pretrained.
    """
    model_path = '/path/to/checkpoints/'
    model, args = CachedAutoregressiveModel.from_pretrained(args, model_path)

    if args.fp16:
        model = model.half()
    model = model.to(args.device)
    set_random_seed(args.seed)
    model.eval()

    tokenizer = get_tokenizer(args)

    end_tokens = [tokenizer.get_command('eos').Id]
    strategy = BaseStrategy(temperature=args.temperature, top_k=args.top_k, end_tokens=end_tokens)

    def process(raw_text):
        """Generate for one input line (optionally '<id>\\t<json>')."""
        if args.with_id:
            query_id, raw_text = raw_text.split('\t')
        raw_text = json.loads(raw_text)
        # Build the prompt: question text followed by the answer cue "答:".
        question = raw_text["question"] + "答:"
        raw_text = question
        seq = tokenizer._encode(raw_text)
        if len(seq) != 0 and seq[0] == 20005:
            # Drop the tokenizer's leading special token (id 20005) if present.
            seq = seq[1:]
        seq = [tokenizer.get_command('ENC').Id] + seq
        # Validate length BEFORE padding (the original padded first, which only
        # worked because a negative repeat count yields an empty list).
        if len(seq) > args.max_sequence_length:
            raise ValueError('text too long.')
        # -1 marks the positions filling_sequence should generate.
        seq += [-1] * (args.max_sequence_length - len(seq))

        # Modern tensor construction; torch.cuda.LongTensor(...) is legacy.
        seq = torch.tensor(seq, dtype=torch.long, device=args.device)
        mbz = args.max_inference_batch_size
        # Real validation instead of assert (asserts vanish under `python -O`).
        if not (args.batch_size < mbz or args.batch_size % mbz == 0):
            raise ValueError(
                'batch_size must be smaller than or a multiple of '
                'max_inference_batch_size')
        output_list = []
        # Run generation in micro-batches of at most `mbz` samples.
        for _ in range(max(args.batch_size // mbz, 1)):
            output = filling_sequence(model, seq.clone(),
                        batch_size=min(args.batch_size, mbz),
                        strategy=strategy,
                        log_attention_weights=None
                        )[0]
            if isinstance(output, torch.Tensor):
                output = list(output)

            output_list.extend(output)

        # Trim each sample: cut at the first remaining -1 (unfilled slot),
        # strip a trailing end token, then drop the leading ENC token and the
        # eos separator between prompt and answer.
        for i in range(len(output_list)):
            output = output_list[i].tolist()
            try:
                unfinished = output.index(-1)
            except ValueError:
                unfinished = len(output)
            if unfinished > 0 and output[unfinished - 1] in end_tokens:
                unfinished -= 1
            try:
                bog = output.index(tokenizer.get_command('eos').Id)
                output_list[i] = output[1:bog] + output[bog + 1:unfinished]
            except ValueError:
                # Fix: the original crashed here when no eos token was
                # generated; fall back to just stripping the ENC token.
                output_list[i] = output[1:unfinished]

        txts = []
        for s in output_list:
            decode_tokens = tokenizer.DecodeIds(s)
            txts.append(decode_tokens)

        # Note: the original also computed a per-query output path
        # (query_id / timed_name) that was never used; removed as dead code.
        print(txts[0])
        test_eval_path = os.path.join(args.output_path, 'test_eval.txt')
        with open(test_eval_path, 'a', encoding='utf-8') as fout:
            fout.write(txts[0] + '\n')
        # World/group/user rwx so other jobs can append to the shared file.
        os.chmod(test_eval_path, stat.S_IRWXO | stat.S_IRWXG | stat.S_IRWXU)

    os.makedirs(args.output_path, exist_ok=True)
    generate_continually(process, args.input_source)
|
|
|
|
|
if __name__ == "__main__":
    # Script-specific flags are collected first; everything it does not
    # recognize is handed to SwissArmyTransformer's argument parser.
    extra_parser = argparse.ArgumentParser(add_help=False)

    extra_args, remaining_argv = extra_parser.parse_known_args()
    args = get_args(remaining_argv)
    # Merge both namespaces into a single flat args object.
    args = argparse.Namespace(**vars(args), **vars(extra_args))
    args.do_train = False

    # Inference only — no gradients needed anywhere below.
    with torch.no_grad():
        main(args)
|
|
|
|