Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,574 Bytes
1de8821 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 |
import torch
from .position import PositionEmbeddingSine
def split_feature(feature, num_splits=2, channel_last=False):
    """Split a 4-D feature map into ``num_splits x num_splits`` non-overlapping windows.

    The windows are stacked along the batch axis. The window in spatial row
    ``i`` and column ``j`` of sample ``n`` ends up at batch index
    ``n * K * K + i * K + j`` (with ``K = num_splits``). This is the order
    ``merge_splits`` expects when reassembling.

    Args:
        feature: tensor shaped ``[B, C, H, W]`` (default) or ``[B, H, W, C]``
            when ``channel_last`` is True.
        num_splits: number of windows ``K`` along each spatial axis.
        channel_last: whether the channel axis is the last axis.

    Returns:
        Tensor shaped ``[B*K*K, C, H/K, W/K]``, or ``[B*K*K, H/K, W/K, C]``
        when ``channel_last`` is True.

    Raises:
        ValueError: if ``num_splits`` < 1, or if ``H`` or ``W`` is not
            divisible by ``num_splits``.
    """
    if channel_last:  # [B, H, W, C]
        b, h, w, c = feature.size()
    else:  # [B, C, H, W]
        b, c, h, w = feature.size()

    # Explicit checks rather than `assert`, which is removed under `python -O`.
    if num_splits < 1:
        raise ValueError(f"num_splits must be >= 1, got {num_splits}")
    if h % num_splits or w % num_splits:
        raise ValueError(
            f"spatial size ({h}, {w}) must be divisible by num_splits={num_splits}"
        )

    k = num_splits
    h_new = h // k
    w_new = w // k
    b_new = b * k * k

    if channel_last:
        # [B, K, H/K, K, W/K, C] -> [B, K, K, H/K, W/K, C] -> [B*K*K, H/K, W/K, C]
        windows = feature.view(b, k, h_new, k, w_new, c).permute(0, 1, 3, 2, 4, 5)
        return windows.reshape(b_new, h_new, w_new, c)

    # [B, C, K, H/K, K, W/K] -> [B, K, K, C, H/K, W/K] -> [B*K*K, C, H/K, W/K]
    windows = feature.view(b, c, k, h_new, k, w_new).permute(0, 2, 4, 1, 3, 5)
    return windows.reshape(b_new, c, h_new, w_new)
def merge_splits(splits,
                 num_splits=2,
                 channel_last=False,
                 ):
    """Tile windows produced by ``split_feature`` back into full feature maps.

    The batch axis is read as ``[B, K, K]``, meaning sample, window row and
    window column (``K = num_splits``). The windows are laid out back into a
    ``K*h x K*w`` spatial grid. This undoes ``split_feature`` called with the
    same ``num_splits`` and ``channel_last``.

    Args:
        splits: ``[B*K*K, C, h, w]`` (default) or ``[B*K*K, h, w, C]`` when
            ``channel_last`` is True.
        num_splits: number of windows ``K`` along each spatial axis.
        channel_last: whether the channel axis is the last axis.

    Returns:
        ``[B, C, K*h, K*w]``, or ``[B, K*h, K*w, C]`` when ``channel_last``.
    """
    k = num_splits

    if not channel_last:
        n, c, h, w = splits.size()  # [B*K*K, C, H/K, W/K]
        batch = n // k // k
        # [B, K, K, C, h, w] -> [B, C, K, h, K, w]
        tiled = splits.view(batch, k, k, c, h, w).permute(0, 3, 1, 4, 2, 5)
        return tiled.contiguous().view(batch, c, k * h, k * w)  # [B, C, H, W]

    n, h, w, c = splits.size()  # [B*K*K, H/K, W/K, C]
    batch = n // k // k
    # [B, K, K, h, w, C] -> [B, K, h, K, w, C]
    tiled = splits.view(batch, k, k, h, w, c).permute(0, 1, 3, 2, 4, 5)
    return tiled.contiguous().view(batch, k * h, k * w, c)  # [B, H, W, C]
def normalize_img(img0, img1):
    """Rescale two images from [0, 255] and standardize them with ImageNet statistics.

    Each image is divided by 255, then the per-channel ImageNet mean is
    subtracted and the result is divided by the per-channel ImageNet std.
    The constants are shaped ``(1, 3, 1, 1)``, so the inputs are expected to
    be channel-first with 3 channels (e.g. ``[B, 3, H, W]``). The constants
    are placed on ``img1``'s device, so ``img0`` must live on that same device.

    Returns:
        Tuple ``(img0, img1)`` of normalized tensors.
    """
    device = img1.device
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

    def _standardize(img):
        return (img / 255. - mean) / std

    return _standardize(img0), _standardize(img1)
def feature_add_position(feature0, feature1, attn_splits, feature_channels):
    """Add a sine positional encoding to two feature maps.

    A ``PositionEmbeddingSine`` with ``feature_channels // 2`` position
    features is built, evaluated once, and the same encoding is added to both
    ``feature0`` and ``feature1``.

    When ``attn_splits > 1``, both features are first cut into
    ``attn_splits x attn_splits`` windows via ``split_feature``. The encoding
    is computed on ``feature0``'s windows and added to every window of both
    features, which are then tiled back with ``merge_splits``. The encoding is
    therefore local to each window rather than to the whole map.

    NOTE(review): assumes features are ``[B, C, H, W]`` (the default layout of
    ``split_feature``/``merge_splits``) and that ``pos_enc`` returns a tensor
    broadcastable to its input — confirm against ``PositionEmbeddingSine``.

    Returns:
        Tuple ``(feature0, feature1)`` with positions added.
    """
    pos_enc = PositionEmbeddingSine(num_pos_feats=feature_channels // 2)

    if attn_splits <= 1:
        # Single window: encode over the full spatial extent.
        position = pos_enc(feature0)
        return feature0 + position, feature1 + position

    windows0 = split_feature(feature0, num_splits=attn_splits)
    windows1 = split_feature(feature1, num_splits=attn_splits)
    # Every window has the same size, so one encoding serves all of them.
    position = pos_enc(windows0)
    merged0 = merge_splits(windows0 + position, num_splits=attn_splits)
    merged1 = merge_splits(windows1 + position, num_splits=attn_splits)
    return merged0, merged1
|