|
from typing import * |
|
|
|
import numpy as np |
|
import torch |
|
from scipy.optimize import linear_sum_assignment |
|
from torch.nn.utils.rnn import pad_sequence |
|
|
|
|
|
def num2mask(
    nums: torch.Tensor,
    max_length: Optional[int] = None
) -> torch.Tensor:
    """
    Convert lengths to boolean masks.

    E.g. input a tensor [2, 3, 4], return [[T T F F], [T T T F], [T T T T]]
    :param nums: Length tensor. Shape [batch] (any shape works; the mask dim is appended).
    :param max_length: maximum length. if not provided, will choose the largest number from nums.
    :return: Boolean mask of shape [*nums.shape, max_length].
    """
    shape_backup = nums.shape
    nums = nums.flatten()
    # Use an explicit `is None` test: the previous `max_length or ...` idiom
    # silently ignored a caller-supplied max_length of 0.
    if max_length is None:
        max_length = int(nums.max())
    # Broadcasting: position index i is True iff i < length for that row.
    positions = torch.arange(max_length, device=nums.device).unsqueeze(0)
    ret = positions < nums.unsqueeze(1)
    return ret.reshape(*shape_backup, max_length)
|
|
|
|
|
def mask2idx(
    mask: torch.Tensor,
    max_length: Optional[int] = None,
    padding_value: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Convert a boolean mask into per-row lists of True positions, right-padded.

    E.g. input a tensor [[T T F F], [T T T F], [F F F T]] with padding value -1,
    return [[0, 1, -1], [0, 1, 2], [3, -1, -1]]
    :param mask: Mask tensor. Boolean. Not necessarily to be 2D.
    :param max_length: If provided, will truncate.
    :param padding_value: Padding value. Default to 0.
    :return: (index tensor, count of True entries along the last dim).
    """
    prefix_shape = mask.shape[:-1]
    length = mask.shape[-1]
    # Candidate positions 0..length-1; each row keeps those where the mask is True.
    positions = torch.arange(length, device=mask.device)
    per_row = [positions[row] for row in mask.flatten(0, -2).unbind(0)]
    padded = pad_sequence(per_row, batch_first=True, padding_value=padding_value)
    if max_length is not None:
        padded = padded[:, :max_length]
    return padded.reshape(*prefix_shape, -1), mask.sum(-1)
|
|
|
|
|
def one_hot(tags: torch.Tensor, num_tags: Optional[int] = None) -> torch.Tensor:
    """
    One-hot encode integer tags along a new trailing dimension.

    :param tags: Integer tag indices. Any shape, at least 1-D.
    :param num_tags: Number of classes. Defaults to max tag value + 1.
    :return: Boolean tensor of shape [*tags.shape, num_tags].
    """
    if num_tags is None:
        # +1 because tags are 0-based: a max tag of 4 implies 5 classes.
        # The previous default (max alone) always raised, since the max tag
        # itself was out of range for a size-max scatter dim.
        num_tags = int(tags.max()) + 1
    ret = tags.new_zeros(size=[*tags.shape, num_tags], dtype=torch.bool)
    # Scatter on the last dim so tags of any rank work (was hardcoded to dim 2,
    # which only supported 2-D tags).
    ret.scatter_(-1, tags.unsqueeze(-1), tags.new_ones([*tags.shape, 1], dtype=torch.bool))
    return ret
|
|
|
|
|
def numpy2torch(
    dict_obj: dict
) -> dict:
    """
    Convert list/np.ndarray values of a dict to torch.Tensor with a leading
    batch dimension of size 1; all other values pass through unchanged.
    """
    return {
        key: torch.tensor(value).unsqueeze(0) if isinstance(value, (list, np.ndarray)) else value
        for key, value in dict_obj.items()
    }
|
|
|
|
|
def max_match(mat: np.ndarray):
    """
    Maximum-weight bipartite matching score.

    Solves the assignment problem on `mat` in maximize mode and returns the
    sum of the matrix entries selected by the optimal assignment.
    """
    rows, cols = linear_sum_assignment(mat, maximize=True)
    selected = mat[rows, cols]
    return selected.sum()
|
|