# -*- coding: utf-8 -*-
import torch
from polos.tokenizers_ import TextEncoderBase


def average_pooling(
tokens: torch.Tensor,
embeddings: torch.Tensor,
mask: torch.Tensor,
padding_index: int,
) -> torch.Tensor:
"""Average pooling function.
:param tokens: Word ids [batch_size x seq_length]
:param embeddings: Word embeddings [batch_size x seq_length x hidden_size]
:param mask: Padding mask [batch_size x seq_length]
:param padding_index: Padding value.
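    Example (illustrative sketch; shapes and values are arbitrary)::

        emb = torch.randn(2, 4, 8)
        tok = torch.tensor([[5, 6, 1, 1], [7, 8, 9, 1]])  # padded with index 1
        mask = tok.ne(1)
        sent = average_pooling(tok, emb, mask, padding_index=1)  # [2, 8]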
"""
wordemb = mask_fill(0.0, tokens, embeddings, padding_index)
sentemb = torch.sum(wordemb, 1)
sum_mask = mask.unsqueeze(-1).expand(embeddings.size()).float().sum(1)
    return sentemb / sum_mask


def max_pooling(
tokens: torch.Tensor, embeddings: torch.Tensor, padding_index: int
) -> torch.Tensor:
"""Max pooling function.
:param tokens: Word ids [batch_size x seq_length]
:param embeddings: Word embeddings [batch_size x seq_length x hidden_size]
:param padding_index: Padding value.
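    Example (illustrative sketch; shapes and values are arbitrary)::

        emb = torch.randn(2, 4, 8)
        tok = torch.tensor([[5, 6, 1, 1], [7, 8, 9, 1]])  # padded with index 1
        pooled = max_pooling(tok, emb, padding_index=1)   # [2, 8], pads ignored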
"""
    return mask_fill(float("-inf"), tokens, embeddings, padding_index).max(dim=1)[0]


def mask_fill(
fill_value: float,
tokens: torch.Tensor,
embeddings: torch.Tensor,
padding_index: int,
) -> torch.Tensor:
"""
Function that masks embeddings representing padded elements.
:param fill_value: the value to fill the embeddings belonging to padded tokens.
:param tokens: The input sequences [bsz x seq_len].
:param embeddings: word embeddings [bsz x seq_len x hiddens].
:param padding_index: Index of the padding token.
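    Example (illustrative sketch; shapes and values are arbitrary)::

        emb = torch.randn(2, 4, 8)
        tok = torch.tensor([[5, 6, 1, 1], [7, 8, 9, 1]])  # padded with index 1
        masked = mask_fill(0.0, tok, emb, padding_index=1)  # padded positions zeroed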
"""
padding_mask = tokens.eq(padding_index).unsqueeze(-1)
    # The out-of-place masked_fill is used so the caller's embeddings are not
    # modified: ``.float()`` returns the same tensor when it is already float32,
    # so the in-place variant would overwrite the input.
    return embeddings.float().masked_fill(padding_mask, fill_value).type_as(embeddings)


def sort_sequences(inputs: torch.Tensor, input_lengths: torch.Tensor):
"""
Sort sequences according to lengths of the input sequence (descendingly).
:param inputs (Tensor): input sequences, size [B, T, D]
:param input_lengths (Tensor): length of each sequence, size [B]
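    Example (illustrative sketch; values are arbitrary)::

        x = torch.randn(3, 5, 8)
        lengths = torch.tensor([3, 5, 2])
        sorted_x, sorted_lengths, unsort_idx = sort_sequences(x, lengths)
        restored = sorted_x[unsort_idx]  # back in the original order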
"""
lengths_sorted, sorted_idx = input_lengths.sort(descending=True)
_, unsorted_idx = sorted_idx.sort()
    return inputs[sorted_idx], lengths_sorted, unsorted_idx


def apply_to_sample(f, sample):
    """Applies ``f`` recursively to every tensor inside ``sample``, which may be
    a tensor or an arbitrarily nested dict/list of tensors."""
    if hasattr(sample, "__len__") and len(sample) == 0:
        return {}
def _apply(x):
if torch.is_tensor(x):
return f(x)
elif isinstance(x, dict):
return {key: _apply(value) for key, value in x.items()}
        elif isinstance(x, list):
            return [_apply(item) for item in x]
else:
return x
return _apply(sample)
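# Illustrative usage (sketch): apply_to_sample can run any tensor-level transform
# over a nested batch, e.g. detaching every tensor from the graph:
#   batch = {"tokens": torch.zeros(2, 4, dtype=torch.long), "emb": torch.randn(2, 4, 8)}
#   detached = apply_to_sample(lambda t: t.detach(), batch)
# The device-transfer helpers below are thin wrappers around it.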
def move_to_cuda(sample):
    """Moves a sample to CUDA. Works with dictionaries, tensors and lists."""

    def _move_to_cuda(tensor):
        return tensor.cuda()

    return apply_to_sample(_move_to_cuda, sample)


def move_to_cpu(sample):
    """Moves a sample to CPU. Works with dictionaries, tensors and lists."""
def _move_to_cpu(tensor):
return tensor.cpu()
return apply_to_sample(_move_to_cpu, sample)
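# Illustrative usage (sketch): nested batches can be moved wholesale, e.g.
#   batch = move_to_cuda(batch) if torch.cuda.is_available() else batch
# and moved back with move_to_cpu(batch) before serialization.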
# --------------- LASER auxiliary functions from Facebook Research ------------------------------
def buffered_arange(max):
if not hasattr(buffered_arange, "buf"):
buffered_arange.buf = torch.LongTensor()
if max > buffered_arange.buf.numel():
torch.arange(max, out=buffered_arange.buf)
return buffered_arange.buf[:max]
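# Note: buffered_arange caches a single growing LongTensor across calls, so e.g.
# buffered_arange(4) returns tensor([0, 1, 2, 3]) without reallocating each time;
# convert_padding_direction below uses it to build per-row gather indices.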
def convert_padding_direction(
src_tokens, padding_idx, right_to_left=False, left_to_right=False
):
assert right_to_left ^ left_to_right
pad_mask = src_tokens.eq(padding_idx)
if not pad_mask.any():
# no padding, return early
return src_tokens
if left_to_right and not pad_mask[:, 0].any():
# already right padded
return src_tokens
if right_to_left and not pad_mask[:, -1].any():
# already left padded
return src_tokens
    max_len = src_tokens.size(1)
    # Per-row positions 0..max_len-1; rotating them by each row's pad count
    # moves the padding from one side of the sequence to the other.
    positions = buffered_arange(max_len).type_as(src_tokens).expand_as(src_tokens)
    num_pads = pad_mask.long().sum(dim=1, keepdim=True)
    if right_to_left:
        index = torch.remainder(positions - num_pads, max_len)
    else:
        index = torch.remainder(positions + num_pads, max_len)
    return src_tokens.gather(1, index)
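# Illustrative usage (sketch): with padding_idx=0,
#   src = torch.tensor([[0, 0, 5, 6], [0, 7, 8, 9]])
#   convert_padding_direction(src, 0, left_to_right=True)
# yields tensor([[5, 6, 0, 0], [7, 8, 9, 0]]), i.e. left padding moved to the right.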