# fengshen/utils/transfo_xl_utils.py
# encoding=utf-8
import math

import torch
import torch.nn.functional as F
def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
# This function has been mostly taken from huggingface conversational ai code at
# https://medium.com/huggingface/how-to-build-a-state-of-the-art-conversational-ai-with-transfer-learning-2d818ac26313
if top_k > 0:
# Remove all tokens with a probability less than the last token of the top-k
indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
logits[indices_to_remove] = filter_value
if top_p > 0.0:
        # Sort logits in descending order so we can accumulate probability mass per row
sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
# Remove tokens with cumulative probability above the threshold
sorted_indices_to_remove = cumulative_probs > top_p
# Shift the indices to the right to keep also the first token above the threshold
sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
sorted_indices_to_remove[..., 0] = 0
        # Map the removal mask back to the original (unsorted) vocab indices, row by row
        for i in range(sorted_indices.size()[0]):
indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
logits[i][indices_to_remove] = filter_value
return logits
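
# A minimal, illustrative sketch of how top_k_logits is typically used; the shapes
# and hyper-parameters below are made up, not taken from this module's callers.
def _demo_top_k_logits():
    logits = torch.randn(2, 100)  # [bs, vocab_size]
    filtered = top_k_logits(logits.clone(), top_k=5, top_p=0.9)  # clone: filtering is in-place
    probs = F.softmax(filtered, dim=-1)  # filtered entries get probability ~0
    return torch.multinomial(probs, num_samples=1)  # one sampled token id per row
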
def enforce_repetition_penalty(lprobs, prev_output_tokens, repetition_penalty=1.5):
    """Repetition penalty (from the CTRL paper, https://arxiv.org/abs/1909.05858)."""
    # Convert to a Python list first: iterating a tensor yields 0-dim tensors that
    # hash by identity, so set() would not deduplicate repeated token ids.
    for previous_token in set(prev_output_tokens.tolist()):
        # If the score is negative, the penalty has to be multiplied
        # (not divided) to reduce the previous token's probability.
        if lprobs[previous_token] < 0:
            lprobs[previous_token] *= repetition_penalty
        else:
            lprobs[previous_token] /= repetition_penalty


def switch(next_value, init, is_update):
    # Swap in the real token: keep `init` (the ground-truth context token) for rows
    # still consuming their prompt; take the sampled `next_value` once it is exhausted.
is_update = is_update.type_as(next_value)
return (1-is_update)*init + is_update*next_value
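
# A minimal sketch of `switch` with made-up values: rows whose context is not yet
# exhausted keep the ground-truth token; the others keep the freshly sampled one.
def _demo_switch():
    sampled = torch.tensor([11, 22, 33])
    context = torch.tensor([1, 2, 3])
    context_exhausted = torch.tensor([True, False, True])
    return switch(sampled, context, context_exhausted)  # -> tensor([11, 2, 33])
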
def get_atten_mask(batch_size, seq_length, memory_length=0):
memory_attention_mask = torch.ones(
(batch_size, 1, seq_length, seq_length + memory_length), dtype=torch.int16)
memory_attention_mask = torch.tril(
torch.triu(memory_attention_mask, 1 - seq_length + memory_length), memory_length)
return memory_attention_mask # [bs, 1, seq_len, seq_len+M]
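
# A small worked example of get_atten_mask (values verified by hand): each query
# position attends to a window of `seq_length` columns ending at its own
# (memory-shifted) column, so the band slides rightward row by row.
def _demo_get_atten_mask():
    mask = get_atten_mask(batch_size=1, seq_length=3, memory_length=2)
    # mask[0, 0] ==
    # [[1, 1, 1, 0, 0],
    #  [0, 1, 1, 1, 0],
    #  [0, 0, 1, 1, 1]]
    return mask
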
def get_masks_and_position_ids(data, mem_length=0):  # default 0, not None: mem_length is used in arithmetic below
# Extract batch size and sequence length.
batch_size, seq_length = data.size()
# Attention mask (lower triangular).
attention_mask = torch.ones((1, seq_length, seq_length + mem_length), device=data.device)
attention_mask = torch.tril(torch.triu(attention_mask, 1 - seq_length + mem_length), mem_length)
attention_mask = attention_mask.unsqueeze(1)
# Position ids.
position_ids = torch.arange(seq_length, dtype=torch.long,
device=data.device)
position_ids = position_ids.unsqueeze(0).expand_as(data)
return attention_mask, position_ids
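
# A minimal sketch of get_masks_and_position_ids with mem_length=0, where the band
# degenerates to an ordinary causal (lower-triangular) mask.
def _demo_get_masks_and_position_ids():
    data = torch.zeros(2, 4, dtype=torch.long)
    attention_mask, position_ids = get_masks_and_position_ids(data, mem_length=0)
    # attention_mask: [1, 1, 4, 4] lower-triangular; position_ids: [[0, 1, 2, 3]] x 2
    return attention_mask, position_ids
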
def sample_sequence_batch(model, context_tokens_tensor, context_length_tensor, max_out_seq=None, mems=None,
end_token_id=None, repetition_penalty=1.0, temperature=1.0, top_k=0, top_p=0.0):
"""_summary_
Args:
model (_type_): _description_
context_tokens_tensor (Tensor): [bs, seq_len]
context_length_tensor (Tensor): [bs, ]
max_out_seq (_type_, optional): _description_. Defaults to None.
mems (_type_, optional): _description_. Defaults to None.
end_token_id (_type_, optional): _description_. Defaults to None.
repetition_penalty (float, optional): _description_. Defaults to 1.0.
temperature (float, optional): _description_. Defaults to 1.0.
top_k (int, optional): _description_. Defaults to 0.
top_p (float, optional): _description_. Defaults to 0.0.
Returns:
_type_: _description_
"""
model_dtype = next(model.parameters()).dtype
org_context_length = torch.min(context_length_tensor).item()
batch_size = context_tokens_tensor.shape[0]
tokens = context_tokens_tensor[:, :org_context_length]
    attention_mask = get_atten_mask(batch_size, org_context_length).cuda(
        context_tokens_tensor.device).to(model_dtype)
position_ids = torch.arange(org_context_length, dtype=torch.long,
device=tokens.device)
position_ids = position_ids.unsqueeze(0).expand_as(tokens)
counter, mem_length = 0, 0
if mems is None:
mems = []
if end_token_id is None:
end_token_id = 50000
if max_out_seq is None:
max_out_seq = 512
output_tokens_lists = []
# record order
origin_order = torch.tensor(range(batch_size), device=tokens.device)
output_order = []
# record log_probs
log_probs_tensor = torch.tensor([0.0] * batch_size, device=tokens.device)
log_probs_list = []
with torch.no_grad():
# while counter < (max_out_seq - org_context_length):
while counter < max_out_seq:
index = org_context_length + counter
if counter == 0:
output = model.forward(input_ids=tokens, position_ids=position_ids,
attention_mask=attention_mask, hidden_states=mems)
logits, mems = output.logits, output.hidden_states
else:
                output = model.forward(
                    input_ids=tokens[:, index - 1: index],
                    position_ids=tokens.new_ones((1, 1)) * (index - 1),
                    attention_mask=tokens.new_ones(batch_size, 1, 1, mem_length + 1).to(model_dtype),
                    hidden_states=mems)
                logits, mems = output.logits, output.hidden_states
logits = logits[:, -1]
logits /= temperature
logits = top_k_logits(logits, top_k=top_k, top_p=top_p)
            # NOTE: the CTRL repetition penalty is currently disabled here. To enable
            # it, apply enforce_repetition_penalty row by row before sampling:
            # if repetition_penalty != 1.0:
            #     for bz in range(batch_size):
            #         enforce_repetition_penalty(logits[bz, :], tokens[bz, :], repetition_penalty)
            log_probs = F.softmax(logits, dim=-1)  # [bs, vocab_size]; probabilities, despite the name
prev = torch.multinomial(log_probs, num_samples=1).view(-1)
            if index < torch.max(context_length_tensor).item():
                # Rows whose context extends past `index` are still teacher-forced:
                # keep their real context token instead of the sampled one.
                prev = switch(
                    prev, context_tokens_tensor[:, index], context_length_tensor <= index)
            for i in range(batch_size):
                if index > context_length_tensor[i] and prev[i] != end_token_id:
                    # Accumulate the log-probability of each freely sampled token.
                    log_probs_tensor[i] += math.log(log_probs[i][prev[i]] + 1e-6)
                if prev[i] == end_token_id:
                    # Normalize the accumulated score by the generated length
                    # (note the divisor is negative, flipping the sign).
                    log_probs_tensor[i] /= (context_length_tensor[i].cpu() - index)
stop_idx = prev == end_token_id
if torch.all(stop_idx).item():
output_order.extend(origin_order[stop_idx].tolist())
break
finished = tokens[stop_idx]
output_tokens_lists.extend(finished.detach().cpu().tolist())
log_probs_list.extend(log_probs_tensor[stop_idx].tolist())
output_order.extend(origin_order[stop_idx].tolist())
# continue with non-ending tokens
conti_idx = (prev != end_token_id)
origin_order = origin_order[conti_idx]
tokens, prev = tokens[conti_idx], prev[conti_idx]
context_tokens_tensor = context_tokens_tensor[conti_idx]
context_length_tensor = context_length_tensor[conti_idx]
log_probs_tensor = log_probs_tensor[conti_idx]
batch_size = tokens.shape[0]
for im in range(len(mems)):
mems[im] = mems[im][conti_idx, :, :]
tokens = torch.cat((tokens, prev.view(batch_size, 1)), dim=-1)
counter += 1
    output_tokens_lists.extend(tokens.detach().cpu().tolist())
    log_probs_list.extend(log_probs_tensor.tolist())
    output_order.extend(origin_order.tolist())
    # Truncate at the end token, then restore the original batch order.
    output_tokens_lists = [tokens[:tokens.index(end_token_id)] if end_token_id in tokens else tokens
                           for tokens in output_tokens_lists]
    output_tokens_lists = [tokens for _, tokens in sorted(zip(output_order, output_tokens_lists))]
    output_log_probs = [prob for _, prob in sorted(zip(output_order, log_probs_list))]
    return output_tokens_lists, output_log_probs
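
# A hedged usage sketch for sample_sequence_batch; `my_model` is a placeholder for a
# CUDA model whose forward(input_ids, position_ids, attention_mask, hidden_states)
# returns an object with .logits and .hidden_states, as the loop above assumes.
def _demo_sample_sequence_batch(my_model):
    context = torch.tensor([[101, 102, 103, 0],
                            [101, 102, 0, 0]]).cuda()  # right-padded contexts
    lengths = torch.tensor([3, 2]).cuda()              # true context lengths
    token_lists, scores = sample_sequence_batch(
        my_model, context, lengths, max_out_seq=64,
        end_token_id=50000, temperature=0.9, top_k=50, top_p=0.9)
    return token_lists, scores
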
def sample_sequence(model, tokens, attention_mask, do_sampling=True,
repetition_penalty=1.0, max_out_seq=None, mems=None, end_token_id=None,
mem_length=0, temperature=1.0, top_k=0, top_p=0.0):
"""_summary_
Args:
model (_type_): _description_
tokens (Tensor): [1, seq_len]
attention_mask (Tensor): [1, 1, seq_len, seq_len]
do_sampling (bool, optional): _description_. Defaults to True.
repetition_penalty (float, optional): _description_. Defaults to 1.0.
max_out_seq (_type_, optional): _description_. Defaults to None.
mems (_type_, optional): _description_. Defaults to None.
end_token (_type_, optional): _description_. Defaults to None.
mem_length (int, optional): _description_. Defaults to 0.
temperature (float, optional): _description_. Defaults to 1.0.
top_k (int, optional): _description_. Defaults to 0.
top_p (float, optional): _description_. Defaults to 0.0.
Returns:
_type_: _description_
"""
counter = 0
if mems is None:
mems = []
if end_token_id is None:
end_token_id = 50000
if max_out_seq is None:
max_out_seq = 512
org_context_length = tokens.size(1)
with torch.no_grad():
# while counter < (max_out_seq - org_context_length):
while counter < max_out_seq:
if counter == 0:
logits, *mems = model(input_ids=tokens, position_ids=None,
attention_mask=attention_mask, mems=mems)
else:
index = org_context_length + counter
                logits, *mems = model(input_ids=tokens[:, index - 1: index],
                                      position_ids=None,
                                      attention_mask=tokens.new_ones(1, 1, 1, mem_length + 1),
                                      mems=mems)
logits = logits[:, -1]
logits /= temperature
if do_sampling:
logits = top_k_logits(logits, top_k=top_k, top_p=top_p)
log_probs = F.softmax(logits, dim=-1)
if repetition_penalty != 1.0:
enforce_repetition_penalty(
log_probs[0, :], tokens[0, :], repetition_penalty)
prev = torch.multinomial(log_probs, num_samples=1)[0]
is_end = (prev == end_token_id)
if is_end:
break
tokens = torch.cat((tokens, prev.view(1, 1)), dim=1)
counter += 1
    # tolist() on a [1, seq_len] tensor gives a nested list; work on the inner list
    # so the end-token check and truncation operate on token ids, not lists.
    output_tokens_list = tokens.detach().cpu().tolist()[0]
    if end_token_id in output_tokens_list:
        output_tokens_list = output_tokens_list[:output_tokens_list.index(end_token_id)]
    return output_tokens_list, mems
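
# A hedged usage sketch for sample_sequence; `my_model` is a placeholder for a model
# called as model(input_ids, position_ids, attention_mask, mems) that returns
# (logits, *mems), as the loop above assumes.
def _demo_sample_sequence(my_model):
    tokens = torch.tensor([[101, 102, 103]])
    attention_mask, _ = get_masks_and_position_ids(tokens, mem_length=0)
    out_tokens, mems = sample_sequence(
        my_model, tokens, attention_mask, max_out_seq=64,
        end_token_id=50000, temperature=0.9, top_k=50, top_p=0.9)
    return out_tokens, mems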