# -*- coding: utf-8 -*- import torch from polos.tokenizers_ import TextEncoderBase def average_pooling( tokens: torch.Tensor, embeddings: torch.Tensor, mask: torch.Tensor, padding_index: int, ) -> torch.Tensor: """Average pooling function. :param tokens: Word ids [batch_size x seq_length] :param embeddings: Word embeddings [batch_size x seq_length x hidden_size] :param mask: Padding mask [batch_size x seq_length] :param padding_index: Padding value. """ wordemb = mask_fill(0.0, tokens, embeddings, padding_index) sentemb = torch.sum(wordemb, 1) sum_mask = mask.unsqueeze(-1).expand(embeddings.size()).float().sum(1) return sentemb / sum_mask def max_pooling( tokens: torch.Tensor, embeddings: torch.Tensor, padding_index: int ) -> torch.Tensor: """Max pooling function. :param tokens: Word ids [batch_size x seq_length] :param embeddings: Word embeddings [batch_size x seq_length x hidden_size] :param padding_index: Padding value. """ return mask_fill(float("-inf"), tokens, embeddings, padding_index).max(dim=1)[0] def mask_fill( fill_value: float, tokens: torch.Tensor, embeddings: torch.Tensor, padding_index: int, ) -> torch.Tensor: """ Function that masks embeddings representing padded elements. :param fill_value: the value to fill the embeddings belonging to padded tokens. :param tokens: The input sequences [bsz x seq_len]. :param embeddings: word embeddings [bsz x seq_len x hiddens]. :param padding_index: Index of the padding token. """ padding_mask = tokens.eq(padding_index).unsqueeze(-1) return embeddings.float().masked_fill_(padding_mask, fill_value).type_as(embeddings) def sort_sequences(inputs: torch.Tensor, input_lengths: torch.Tensor): """ Sort sequences according to lengths of the input sequence (descendingly). :param inputs (Tensor): input sequences, size [B, T, D] :param input_lengths (Tensor): length of each sequence, size [B] """ lengths_sorted, sorted_idx = input_lengths.sort(descending=True) _, unsorted_idx = sorted_idx.sort() return inputs[sorted_idx], lengths_sorted, unsorted_idx def apply_to_sample(f, sample): if hasattr(sample, "__len__") and len(sample) == 0: return {} def _apply(x): if torch.is_tensor(x): return f(x) elif isinstance(x, dict): return {key: _apply(value) for key, value in x.items()} elif isinstance(x, list): return [_apply(x) for x in x] else: return x return _apply(sample) def move_to_cuda(sample): """ Moves a sample to cuda. Works with dictionaries, tensors and lists. """ def _move_to_cuda(tensor): return tensor.cuda() return apply_to_sample(_move_to_cuda, sample) def move_to_cpu(sample): """ Moves a sample to cuda. Works with dictionaries, tensors and lists. """ def _move_to_cpu(tensor): return tensor.cpu() return apply_to_sample(_move_to_cpu, sample) # --------------- LASER auxiliar functions from facebook research ------------------------------ def buffered_arange(max): if not hasattr(buffered_arange, "buf"): buffered_arange.buf = torch.LongTensor() if max > buffered_arange.buf.numel(): torch.arange(max, out=buffered_arange.buf) return buffered_arange.buf[:max] def convert_padding_direction( src_tokens, padding_idx, right_to_left=False, left_to_right=False ): assert right_to_left ^ left_to_right pad_mask = src_tokens.eq(padding_idx) if not pad_mask.any(): # no padding, return early return src_tokens if left_to_right and not pad_mask[:, 0].any(): # already right padded return src_tokens if right_to_left and not pad_mask[:, -1].any(): # already left padded return src_tokens max_len = src_tokens.size(1) range = buffered_arange(max_len).type_as(src_tokens).expand_as(src_tokens) num_pads = pad_mask.long().sum(dim=1, keepdim=True) if right_to_left: index = torch.remainder(range - num_pads, max_len) else: index = torch.remainder(range + num_pads, max_len) return src_tokens.gather(1, index)