glip-zeroshot-demo / maskrcnn_benchmark /utils /shallow_contrastive_loss_helper.py
haotiz's picture
initial commit
708dec4
raw
history blame
2.1 kB
import torch
import maskrcnn_benchmark.utils.dist as dist
def normalized_positive_map(positive_map):
    """Divide every slice of ``positive_map`` along dim 2 by its own sum.

    The input is cast to float first. Where a sum along dimension 2 is
    exactly zero, 1e-6 is used as the divisor instead so the division never
    hits zero.

    Returns a new float tensor; the input is not modified.
    """
    as_float = positive_map.float()
    totals = as_float.sum(dim=2)
    # Swap zero totals for a tiny epsilon to avoid dividing by zero.
    safe_totals = totals.masked_fill(totals == 0, 1e-6)
    return as_float / safe_totals.unsqueeze(-1)
def pad_tensor_given_dim_length(tensor, dim, length, padding_value=0, batch_first=True):
    """Pad ``tensor`` along ``dim`` up to ``length`` entries.

    A new tensor with the same dtype and device as ``tensor`` is created,
    filled with ``padding_value``, and the original data is copied into the
    leading ``tensor.size(dim)`` positions of dimension ``dim``.

    Args:
        tensor: tensor to pad.
        dim: dimension to pad; negative values count from the last dimension.
        length: size of ``dim`` in the result; must be at least
            ``tensor.size(dim)``, otherwise torch raises a RuntimeError.
        padding_value: value written to the padded positions.
        batch_first: kept for backward compatibility only. The copy now
            always follows ``dim``, which matches the previous behaviour for
            ``dim=1, batch_first=True`` and ``dim=0, batch_first=False``.

    Returns:
        The padded tensor.
    """
    if dim < 0:
        dim += tensor.dim()

    new_size = list(tensor.size())
    new_size[dim] = length
    out_tensor = tensor.new_empty(new_size).fill_(padding_value)

    # Copy along the dimension that was actually enlarged. The previous
    # implementation always copied along dim 0 or dim 1 regardless of `dim`.
    out_tensor.narrow(dim, 0, tensor.size(dim)).copy_(tensor)
    return out_tensor
def pad_random_negative_tensor_given_length(positive_tensor, negative_padding_tensor, length=None):
    """Concatenate ``negative_padding_tensor`` after ``positive_tensor`` along dim 0.

    Args:
        positive_tensor: tensor placed first in the result.
        negative_padding_tensor: tensor appended after ``positive_tensor``.
        length: optional expected size of dim 0 of the result. When given,
            the combined number of rows must equal it. When left as ``None``
            no size check is performed (previously the default always
            tripped the assertion, because a sum never equals ``None``).

    Returns:
        ``torch.cat((positive_tensor, negative_padding_tensor), dim=0)``.

    Raises:
        AssertionError: if ``length`` is given and does not match.
    """
    if length is not None:
        total = positive_tensor.shape[0] + negative_padding_tensor.shape[0]
        assert total == length, (
            "expected {} rows after padding, got {}".format(length, total)
        )
    return torch.cat((positive_tensor, negative_padding_tensor), dim=0)
def gather_tensors(tensor):
    """Stack ``tensor`` from every process along a new leading dimension.

    When distributed mode is not available/initialized, the result is just
    ``tensor`` stacked on its own (leading dimension of size 1).

    *** Warning ***: torch.distributed.all_gather has no gradient, so the
    slot belonging to the current rank is replaced with the original
    ``tensor`` before stacking.
    """
    if not dist.is_dist_avail_and_initialized():
        return torch.stack((tensor,), dim=0)

    world_size = dist.get_world_size()
    gathered = [torch.zeros_like(tensor) for _ in range(world_size)]
    torch.distributed.all_gather(gathered, tensor, async_op=False)

    # Put the local tensor back in its own slot so gradients can still
    # propagate through this process's contribution.
    this_rank = torch.distributed.get_rank()
    gathered[this_rank] = tensor

    return torch.stack(gathered, dim=0)
def convert_to_roi_format(boxes):
    """Return ``boxes.bbox`` with a leading column of zeros.

    The zero column has ``len(boxes)`` rows and shares the dtype and device
    of ``boxes.bbox``; it is concatenated in front of the box coordinates
    along dim 1.
    """
    coords = boxes.bbox
    # new_zeros inherits dtype and device from the coordinates tensor.
    zero_ids = coords.new_zeros((len(boxes), 1))
    return torch.cat((zero_ids, coords), dim=1)