import json
import re

import jsonlines
import numpy as np
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm, trange


def set_seed(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


def filter_noise(text):
    # Collapse whitespace that was inserted between adjacent CJK characters,
    # digits, or common Chinese punctuation marks.
    space_pattern = r'([\u4e00-\u9fa5|0-9|οΌ|γ|οΌ|οΌ|@|Β₯|β¦β¦|ββ|γ|γ|β|β|γ|οΌ|οΌ|β|β|οΌ|οΌ|γ|γ|γ|γ|Β·|ο½|-|+])\s+([\u4e00-\u9fa5|0-9|οΌ|γ|οΌ|οΌ|@|Β₯|β¦β¦|ββ|γ|γ|β|β|γ|οΌ|οΌ|β|β|οΌ|οΌ|γ|γ|γ|γ|Β·|ο½|-|+])'
    # re.sub cannot match overlapping spans, so a run like "A B C" only loses
    # every other space per pass; the substitution is therefore applied twice.
    text = re.sub(space_pattern, r'\1\2', text)
    text = re.sub(space_pattern, r'\1\2', text)
    # Strip boilerplate fragments common in crawled Chinese text: "εΌη¨ζ₯ζ"
    # (citation date), "εθθ΅ζ" (references), bracketed citation markers,
    # "εζε°ε/εζ转载/ζ¬ζ转θͺ/ζ¬ζζθ¦" (original address / reposted from /
    # abstract) headers, and stray <unk> tokens.
    patterns = ['εΌη¨ζ₯ζ.*$', 'εθθ΅ζ.*$', r'\[.*\]', 'γ.*γ', 'εζε°εοΌ', 'εζ转载οΌ', 'ζ¬ζ转θͺοΌ', 'ζ¬ζζθ¦οΌ', '<unk>']
    for pattern in patterns:
        text = re.sub(pattern, "", text)
    return text.strip()


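# Minimal sketch (illustrative only, not part of the original pipeline): shows
# what filter_noise does on a small made-up string. The sample text and the
# helper name `_example_filter_noise` are assumptions for this example.
def _example_filter_noise():
    noisy = "δ½  ε₯½ δΈ η [1] εζε°εοΌ"
    # spaces between CJK characters are collapsed (hence the double re.sub),
    # then the bracketed citation marker and the "εζε°εοΌ" header are stripped
    return filter_noise(noisy)  # -> "δ½ ε₯½δΈη"

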
def get_raw_data(raw_data):
    # Read a JSON-lines file and collect values column-wise:
    # {key: [value from line 1, value from line 2, ...]}.
    train_data = {}
    with open(raw_data, 'r', encoding='utf8') as fh:
        for line in fh:
            line = json.loads(line)
            for key in line.keys():
                if key not in train_data:
                    train_data[key] = [line[key]]
                else:
                    train_data[key].append(line[key])
    return train_data


def save_output(input_text, output, output_file):
    # Append (input, output) text pairs to a JSON-lines file.
    with jsonlines.open(output_file, mode='a') as writer:
        for text_in, text_out in zip(input_text, output):
            otc = {}
            otc['text_a'] = str(text_in)
            otc['text_b'] = str(text_out)
            writer.write(otc)


def enforce_repetition_penalty(lprobs, prev_output_tokens, repetition_penalty=1.5):
    """Repetition penalty (from the CTRL paper, https://arxiv.org/abs/1909.05858).

    Modifies `lprobs` in place: scores of already generated tokens are pushed
    down (divided if positive, multiplied if negative) so they are less likely
    to be sampled again.
    """
    for i in range(len(prev_output_tokens)):
        for previous_token in set(prev_output_tokens[i]):
            if lprobs[i, previous_token] < 0:
                lprobs[i, previous_token] *= repetition_penalty
            else:
                lprobs[i, previous_token] /= repetition_penalty


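# Minimal sketch (illustrative only, not part of the original pipeline): shows
# the in-place effect of enforce_repetition_penalty on a toy score matrix. The
# values and the helper name `_example_repetition_penalty` are assumptions.
def _example_repetition_penalty():
    scores = torch.tensor([[0.5, 2.0, -1.0]])  # one row, three candidate tokens
    prev_tokens = torch.tensor([[1, 2]])       # tokens 1 and 2 were generated before
    enforce_repetition_penalty(scores, prev_tokens, repetition_penalty=1.5)
    # token 1 has a positive score and is divided:    2.0 -> ~1.33
    # token 2 has a negative score and is multiplied: -1.0 -> -1.5
    return scores

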
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering.

    Args:
        logits: logits distribution of shape (batch size, vocabulary size)
        top_k > 0: keep only the top k tokens with highest probability (top-k filtering).
        top_p > 0.0: keep the smallest set of tokens whose cumulative probability
            exceeds top_p (nucleus filtering, Holtzman et al., http://arxiv.org/abs/1904.09751).
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    top_k = min(top_k, logits.size(-1))  # safety check
    if top_k > 0:
        # Remove all tokens with a logit below the k-th largest logit
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask right so the first token above the threshold is kept
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # Map the mask back to the original (unsorted) token order, row by row
        for i in range(sorted_indices.size()[0]):
            indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
            logits[i][indices_to_remove] = filter_value

    return logits


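# Minimal sketch (illustrative only, not part of the original pipeline): applies
# the filter above to a toy batch of logits. The values and the helper name
# `_example_top_k_top_p` are assumptions for this example.
def _example_top_k_top_p():
    logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
    # top_k=2: only the two largest logits survive, the rest become -inf
    filtered = top_k_top_p_filtering(logits.clone(), top_k=2)
    probs = F.softmax(filtered, dim=-1)  # sampling can now only pick token 0 or 1
    return torch.multinomial(probs, num_samples=1)

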
def sample_sequence_conditional(model, length, context, latent_z=None, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device='cpu'):
    # Autoregressively sample a single sequence from the decoder, conditioned on
    # the latent code `latent_z`, starting from the given context tokens.
    context = torch.tensor(context, dtype=torch.long, device=device)
    context = context.unsqueeze(0)
    generated = context
    with torch.no_grad():
        for i in trange(length):
            if i == 2:
                # Tokenizer-specific special case: reset token id 127 at position 1 to 0.
                generated[generated[:, 1] == 127, 1] = 0
            attention_mask = model.get_attn_mask(generated.shape[1]).to(device)
            inputs = {'input_ids': generated, 'latent_state': latent_z, 'attention_mask': attention_mask, 'mems': None}
            outputs = model(**inputs)
            next_token_logits = outputs[0][:, -1, :] / temperature
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)

            # Note: these are probabilities (softmax output), which is what
            # torch.multinomial expects; the repetition penalty is applied to them in place.
            log_probs = F.softmax(filtered_logits, dim=-1)
            if repetition_penalty != 1.0:
                enforce_repetition_penalty(log_probs, generated, repetition_penalty)
            next_token = torch.multinomial(log_probs, num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)

            # Token id 50000 is used as the end-of-text token.
            if next_token[0, 0] == 50000:
                break

    return generated


def latent_code_from_text(text, tokenizer_encoder, model_vae, args, scale=1.0):
    # Encode a single sentence into a latent code via the VAE encoder and the
    # reparameterization trick. `device` is expected to be defined at module level.
    tokenized1 = tokenizer_encoder.encode(text)
    coded = torch.Tensor([tokenized1]).long()
    with torch.no_grad():
        coded = coded.to(device)
        outputs = model_vae.encoder(coded, attention_mask=(coded > 0).float())
        pooled_hidden_fea = outputs[1]

        mean, logvar = model_vae.encoder.linear(pooled_hidden_fea).chunk(2, -1)
        std = logvar.mul(0.5).exp()
        eps = torch.zeros_like(std).normal_()

        return mean + torch.mul(eps, std) * scale


def text_from_latent_code(latent_z, model_vae, args, tokenizer_decoder, prompt=None):
    # Decode a single latent code back into text, optionally prefixed with a prompt.
    bos_token = tokenizer_decoder.convert_tokens_to_ids(tokenizer_decoder.bos_token)
    context_tokens = [bos_token]

    if prompt is not None:
        # extend (not append) so the context stays a flat list of token ids
        context_tokens.extend(tokenizer_decoder.encode(prompt)[:-1])

    out = sample_sequence_conditional(
        model=model_vae.decoder,
        context=context_tokens,
        latent_z=latent_z,
        length=args.max_out_length,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
        device=device
    )

    out_tokens = out[0, :].tolist()
    # Drop the leading bos token and cut at the end-of-text token (id 50000).
    out_tokens = out_tokens[1:out_tokens.index(50000)] if 50000 in out_tokens else out_tokens
    text_x1 = tokenizer_decoder.decode(out_tokens, clean_up_tokenization_spaces=True)

    return text_x1


def simulate(model_vae, tokenizer_encoder, tokenizer_decoder, args, sent_input, prompt=None):
    # latent_code_from_text returns the latent code tensor directly
    latent_z = latent_code_from_text(sent_input, tokenizer_encoder, model_vae, args)
    text_analogy = text_from_latent_code(latent_z, model_vae, args, tokenizer_decoder, prompt=prompt)

    return text_analogy


def switch(next_value, init, is_update):
    # Element-wise select: rows where `is_update` is 1 take `next_value`, rows
    # where it is 0 keep `init`. Used below to force prompt tokens while a
    # sequence is still inside its prompt and to use sampled tokens afterwards.
    is_update = is_update.type_as(next_value)
    return (1 - is_update) * init + is_update * next_value


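# Minimal sketch (illustrative only, not part of the original pipeline): switch()
# keeping the forced prompt token for the row that is still inside its prompt.
# The values and the helper name `_example_switch` are assumptions.
def _example_switch():
    sampled = torch.tensor([11, 22])            # tokens proposed by the sampler
    forced = torch.tensor([7, 8])               # next prompt tokens per row
    past_prompt = torch.tensor([True, False])   # row 0 left its prompt, row 1 has not
    return switch(sampled, forced, past_prompt)  # -> tensor([11, 8])

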
def sample_sequence_conditional_batch(model, max_out_length, context_tokens_tensor, context_length_tensor, latent_z=None, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0, device='cpu', end_token=50000):
    # Batched autoregressive sampling conditioned on latent_z. Prompt tokens are
    # teacher-forced via switch(); finished sequences are removed from the batch
    # as soon as they emit end_token and reassembled in the original order at the end.
    org_context_length = torch.min(context_length_tensor).item()
    batch_size = context_tokens_tensor.shape[0]

    generated = context_tokens_tensor[:, :org_context_length]
    counter = org_context_length

    output_tokens_lists = []
    output_order = []
    orig_order = torch.LongTensor(list(range(batch_size)))

    with torch.no_grad():
        while counter < max_out_length:
            if counter == org_context_length + 2:
                # Tokenizer-specific special case: reset token id 127 right after
                # the shared context to 0 (mirrors sample_sequence_conditional).
                generated[generated[:, org_context_length] == 127, org_context_length] = 0
            attention_mask = model.get_attn_mask(generated.shape[1]).to(device)
            inputs = {'input_ids': generated, 'latent_state': latent_z, 'attention_mask': attention_mask}
            outputs = model(**inputs)
            next_token_logits = outputs[0][:, -1, :] / temperature
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)

            log_probs = F.softmax(filtered_logits, dim=-1)

            if repetition_penalty != 1.0:
                enforce_repetition_penalty(log_probs, generated, repetition_penalty)

            # Stop if filtering/penalties left some row without any probability mass.
            if (log_probs.sum(dim=-1) <= 0.0).any():
                break
            next_token = torch.multinomial(log_probs, num_samples=1).view(-1)
            # While a row is still inside its prompt, force the next prompt token
            # instead of the sampled one.
            next_token = switch(next_token, context_tokens_tensor[:, counter], context_length_tensor <= counter)

            if torch.all(next_token == end_token).item():
                break

            # Move rows that just emitted end_token to the output buffers.
            stop_idx = next_token == end_token
            output_order.extend(orig_order[stop_idx].tolist())

            finished = generated[stop_idx]
            output_tokens_lists.extend(finished.detach().cpu().tolist())

            # Keep only the unfinished rows in the working batch.
            conti_idx = next_token != end_token
            orig_order = orig_order[conti_idx]
            generated = generated[conti_idx]
            latent_z = latent_z[conti_idx]

            next_token = next_token[conti_idx]
            context_tokens_tensor = context_tokens_tensor[conti_idx]
            context_length_tensor = context_length_tensor[conti_idx]
            batch_size = generated.shape[0]

            generated = torch.cat((generated, next_token.view(batch_size, 1)), dim=-1)
            counter += 1

    # Flush the rows that never emitted end_token, truncate at end_token where
    # present, and restore the original batch order.
    output_order.extend(orig_order.tolist())
    generated = generated.detach().cpu().tolist()
    output_tokens_lists.extend(generated)
    output_tokens_lists = [tokens[:tokens.index(end_token)] if end_token in tokens else tokens for tokens in output_tokens_lists]

    output_tokens_lists = [tokens for _, tokens in sorted(zip(output_order, output_tokens_lists))]

    return output_tokens_lists


def latent_code_from_text_batch(texts, tokenizer_encoder, model_vae, args):
    # Encode a batch of sentences into latent codes. Inputs are truncated to 510
    # tokens, wrapped with the BERT-style [CLS] (101) and [SEP] (102) ids, and
    # right-padded with 0 so the padding positions are masked out below.
    tokens_tensor_list = []
    for text in texts:
        tokens = tokenizer_encoder.encode(text)[:510]
        tokens_tensor_list.append(torch.tensor([101] + tokens + [102]))

    coded = pad_sequence(tokens_tensor_list, batch_first=True, padding_value=0).long()
    with torch.no_grad():
        coded = coded.to(device)
        pooled_hidden_fea = model_vae.encoder(coded, attention_mask=(coded > 0).float())[1]
        mean, logvar = model_vae.encoder.linear(pooled_hidden_fea).chunk(2, -1)

        # Reparameterization trick, with the noise scaled by args.std_scale.
        std = logvar.mul(0.5).exp()
        eps = torch.zeros_like(std).normal_()

        latent_z = mean + torch.mul(eps, std) * args.std_scale

    return latent_z


def text_from_latent_code_batch(latent_z, model_vae, args, tokenizer_decoder, prompt=None):
    # Decode a batch of latent codes back into text, optionally teacher-forcing a
    # per-sample prompt at the start of each sequence.
    batch_size = latent_z.shape[0]
    bos_token = tokenizer_decoder.convert_tokens_to_ids(tokenizer_decoder.bos_token)
    end_token = tokenizer_decoder.convert_tokens_to_ids(tokenizer_decoder.eos_token)

    if prompt is not None:
        prompt = [[bos_token] + tokenizer_decoder.encode(text)[:-1] for text in prompt]
    else:
        prompt = [[bos_token]] * batch_size

    # Context tensor padded with the end token; the true prompt length of each
    # sample is tracked in context_length_tensor.
    context_tokens_tensor = torch.tensor([[end_token] * args.max_out_length] * batch_size).to(device)
    context_length_tensor = torch.tensor([1] * batch_size).to(device)
    for i in range(batch_size):
        context_tokens_tensor[i, :len(prompt[i])] = torch.tensor(prompt[i]).long().to(device)
        context_length_tensor[i] = len(prompt[i])

    out = sample_sequence_conditional_batch(
        model=model_vae.decoder,
        max_out_length=args.max_out_length,
        context_tokens_tensor=context_tokens_tensor,
        context_length_tensor=context_length_tensor,
        latent_z=latent_z,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
        device=device
    )

    out_text = []
    for i, tokens in enumerate(out):
        # Drop the prompt tokens, cut at the end token, decode and denoise.
        tokens = tokens[len(prompt[i]):]
        tokens = tokens[:tokens.index(end_token)] if end_token in tokens else tokens
        text = tokenizer_decoder.decode(tokens, clean_up_tokenization_spaces=True)
        out_text.append(filter_noise(text))
    return out_text


def simulate_batch(model_vae, tokenizer_encoder, tokenizer_decoder, args, sent_inputs, prompt=None):
    # Encode a batch of input sentences and decode new text from their latent codes.
    latent_z = latent_code_from_text_batch(sent_inputs, tokenizer_encoder, model_vae, args)
    text_analogy = text_from_latent_code_batch(latent_z, model_vae, args, tokenizer_decoder, prompt=prompt)
    return text_analogy


def simulate_bz(model_vae, tokenizer_encoder, tokenizer_decoder, args, sent_inputs, prompt=None):
    # Only returns the latent codes; tokenizer_decoder and prompt are unused, but
    # the signature mirrors simulate_batch.
    latent_z = latent_code_from_text_batch(sent_inputs, tokenizer_encoder, model_vae, args)
    return latent_z


def my_shuffle(x, index):
    # Reorder the elements of `x` according to `index`.
    result = []
    for field in index:
        result.append(x[field])
    return result


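# Illustrative usage sketch (not part of the original pipeline). It assumes the
# VAE model and the encoder/decoder tokenizers are built by the surrounding
# training/inference script; the args fields below are the ones read by the
# functions above, and their values here are only example settings.
def _example_simulate_batch(model_vae, tokenizer_encoder, tokenizer_decoder):
    from argparse import Namespace
    args = Namespace(max_out_length=128, temperature=1.0, top_k=0, top_p=0.9,
                     repetition_penalty=1.5, std_scale=1.0)
    sentences = ["θΏζ―η¬¬δΈδΈͺθΎε₯ε₯ε­γ", "θΏζ―η¬¬δΊδΈͺθΎε₯ε₯ε­γ"]
    return simulate_batch(model_vae, tokenizer_encoder, tokenizer_decoder, args, sentences)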