# Page header captured together with this source file: "Spaces: Running on Zero".
import math
from copy import deepcopy
from io import BytesIO

# Third-party and project imports keep their original relative order, in case a
# project module configures matplotlib before pyplot is imported.
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F
from MT import FeatureTransformer
from torch.cuda.amp import autocast as autocast
from flow_tools import viz_img_seq, save_img_seq, plt_show_img_flow
from V1 import V1
import matplotlib.pyplot as plt
from PIL import Image
def conv(in_planes, out_planes, kernel_size=3, stride=1, dilation=1, isReLU=True):
    """Build a ``Conv2d`` (with bias), optionally followed by an activation.

    Padding is ``((kernel_size - 1) * dilation) // 2``, which preserves the
    spatial size for odd kernels at stride 1.

    Args:
        in_planes: input channel count.
        out_planes: output channel count.
        kernel_size: square kernel size.
        stride: convolution stride.
        dilation: convolution dilation.
        isReLU: despite the name, appends an ``nn.GELU`` (not ReLU) when True.

    Returns:
        ``nn.Sequential`` holding the conv, plus the GELU when ``isReLU``.
    """
    same_padding = (dilation * (kernel_size - 1)) // 2
    layers = [
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=same_padding,
            bias=True,
        )
    ]
    if isReLU:
        layers.append(nn.GELU())
    return nn.Sequential(*layers)
def plt_attention(attention, h, w):
    """Render one attention map per iteration for a single query location.

    Panels are laid out in two rows; the last panel is titled "Final Iteration".
    Each panel marks the query location (w, h) with two concentric dots.

    Args:
        attention: sequence of tensors, each indexed as ``attn[0, :, :, h, w]``
            to obtain a 2-D map for batch element 0 and query (h, w).
        h: query row index, also used as the marker's y coordinate.
        w: query column index, also used as the marker's x coordinate.

    Returns:
        numpy array decoded (via PIL) from the figure rendered as PNG.
    """
    n_maps = len(attention)
    # Round up so an odd number of maps (or a single map) still fits in two rows;
    # ``len // 2`` made add_subplot fail for those counts.
    n_cols = max(1, math.ceil(n_maps / 2))
    fig = plt.figure(figsize=(10, 8))
    for i, attn in enumerate(attention):
        viz = attn[0, :, :, h, w].detach().cpu().numpy()
        ax = fig.add_subplot(2, n_cols, i + 1)
        ax.imshow(viz, cmap="rainbow", interpolation="bilinear")
        # Two concentric markers highlight the query location.
        ax.scatter(w, h, color='grey', s=300, alpha=0.5)
        ax.scatter(w, h, color='red', s=150, alpha=0.5)
        if i == n_maps - 1:
            ax.set_title(" Final Iteration")
        else:
            ax.set_title(" Iteration %d" % (i + 1))
        ax.set_xticks([])
        ax.set_yticks([])
    fig.tight_layout()
    # Render to an in-memory PNG and decode it into an array.
    buf = BytesIO()
    fig.savefig(buf, format='png')
    buf.seek(0)
    plt.close(fig)
    return np.array(Image.open(buf))
class FlowDecoder(nn.Module):
    """Densely chained stack of 1x1 convs that regresses a 2-channel flow map.

    From the third stage onward, each stage consumes the channel-wise
    concatenation of the two preceding stages' outputs. The final stage
    (``predict_flow``) uses ``conv``'s default kernel size of 3 and no activation.
    Original author's note: "can reduce 25% of training time."
    """

    def __init__(self, ch_in):
        super().__init__()
        # Attribute names determine state_dict keys; keep them stable.
        self.conv1 = conv(ch_in, 256, kernel_size=1)
        self.conv2 = conv(256, 128, kernel_size=1)
        self.conv3 = conv(256 + 128, 96, kernel_size=1)
        self.conv4 = conv(96 + 128, 64, kernel_size=1)
        self.conv5 = conv(96 + 64, 32, kernel_size=1)
        self.feat_dim = 32
        self.predict_flow = conv(64 + 32, 2, isReLU=False)

    def forward(self, x):
        """Return the predicted flow (2 channels) for feature map ``x``."""
        older = self.conv1(x)
        newer = self.conv2(older)
        # Slide a window of the two most recent outputs through the remaining stages.
        for stage in (self.conv3, self.conv4, self.conv5, self.predict_flow):
            older, newer = newer, stage(torch.cat([older, newer], dim=1))
        return newer
class FFV1DNN(nn.Module):
    """Optical-flow network: V1 front end -> feature transformer -> flow decoder.

    ``self.ffv1`` maps a list of frames to a feature map (``st_component``).
    ``self.MT`` (a ``FeatureTransformer``) refines it over ``num_layers``
    iterations. The shared ``self.decoder`` turns the V1 features and each
    refined feature map into a 2-channel flow, which is upsampled 8x.
    """

    def __init__(self,
                 num_scales=8,
                 num_cells=256,
                 upsample_factor=8,
                 feature_channels=256,
                 scale_factor=16,
                 num_layers=6,
                 ):
        super(FFV1DNN, self).__init__()
        self.ffv1 = V1(spatial_num=num_cells // num_scales, scale_num=num_scales, scale_factor=scale_factor,
                       kernel_radius=7, num_ft=num_cells // num_scales,
                       kernel_size=6, average_time=True)
        self.v1_kz = 7
        self.scale_factor = scale_factor
        # Per-level ratio r = (1 / scale_factor) ** (1 / (num_scales - 1)),
        # i.e. r ** (num_scales - 1) == 1 / scale_factor.
        scale_each_level = np.exp(1 / (num_scales - 1) * np.log(1 / scale_factor))
        self.scale_num = num_scales
        self.scale_each_level = scale_each_level
        v1_channel = self.ffv1.num_after_st
        self.num_scales = num_scales
        self.MT_channel = feature_channels
        # The transformer consumes the V1 output directly, so the widths must agree.
        assert self.MT_channel == v1_channel
        self.feature_channels = feature_channels
        self.upsample_factor = upsample_factor
        self.num_layers = num_layers
        # Convex upsampling head: input is cat([flow (2 ch), guidance feature]),
        # output is upsample_factor**2 * 9 mask logits per coarse pixel.
        self.upsampler_1 = nn.Sequential(nn.Conv2d(2 + feature_channels, 256, 3, 1, 1),
                                         nn.ReLU(inplace=True),
                                         nn.Conv2d(256, 256, 3, 1, 1),
                                         nn.ReLU(inplace=True),
                                         nn.Conv2d(256, upsample_factor ** 2 * 9, 3, 1, 1))
        self.decoder = FlowDecoder(feature_channels)
        # NOTE(review): conv_feat is not used by any forward path in this file;
        # it is kept so existing checkpoints still load.
        self.conv_feat = nn.ModuleList([conv(v1_channel, feature_channels, 1) for _ in range(num_scales)])
        self.MT = FeatureTransformer(d_model=feature_channels, num_layers=self.num_layers)

    @staticmethod
    def _normalize_frames(image_list, to_gray=True):
        """Bring frames to [0, 1] and, optionally, to a single gray channel.

        All frames are divided by 255 when the first frame's maximum exceeds 10
        (treated as 0-255 input). When ``to_gray`` is set and the first frame has
        3 channels, each frame becomes ``0.299 R + 0.587 G + 0.114 B`` with a
        singleton channel dimension. Runs under ``torch.no_grad()``.
        """
        with torch.no_grad():
            if image_list[0].max() > 10:
                image_list = [img / 255.0 for img in image_list]  # [B, 1, H, W] 0-1
            if to_gray and image_list[0].shape[1] == 3:
                image_list = [(img[:, 0, :, :] * 0.299 + img[:, 1, :, :] * 0.587
                               + img[:, 2, :, :] * 0.114).unsqueeze(1) for img in image_list]
        return image_list

    def _decode_flows(self, st_component, motion_feature, attn):
        """Decode one flow per stage and upsample each by 8.

        Returns:
            ``(flows_bi, flows_up)``. ``flows_bi`` holds bilinear upsamplings of
            the V1 flow followed by one flow per ``motion_feature`` entry.
            ``flows_up`` reuses the bilinear V1 flow, then convex-upsamples each
            transformer flow guided by the matching ``attn`` entry.
        """
        flows = [self.decoder(st_component)] + [self.decoder(feature) for feature in motion_feature]
        flows_bi = [self.upsample_flow(flow, feature=None, bilinear=True, upsample_factor=8) for flow in flows]
        flows_up = [flows_bi[0]] + [
            self.upsample_flow(flow, upsampler=self.upsampler_1, feature=guide, upsample_factor=8)
            for flow, guide in zip(flows[1:], attn)
        ]
        # zip() silently truncates; catch a transformer that returned fewer guides than flows.
        assert len(flows_bi) == len(flows_up)
        return flows_bi, flows_up

    def upsample_flow(self, flow, feature, upsampler=None, bilinear=False, upsample_factor=4):
        """Upsample a coarse flow field by an integer factor K (``upsample_factor``).

        Args:
            flow: tensor [B, C, H, W].
            feature: guidance tensor concatenated with ``flow`` on dim 1 and fed to
                ``upsampler``; ignored when ``bilinear`` is True.
            upsampler: callable mapping ``cat([flow, feature])`` to ``9 * K * K``
                mask logits per coarse pixel; required when ``bilinear`` is False.
            bilinear: use bilinear interpolation instead of convex upsampling.
            upsample_factor: integer scale factor K.

        Returns:
            tensor [B, C, K*H, K*W]. Values are multiplied by K so displacements
            are expressed in output-resolution pixels.

        Raises:
            TypeError: if ``bilinear`` is False and ``upsampler`` is None.
        """
        if bilinear:
            return F.interpolate(flow, scale_factor=upsample_factor,
                                 mode='bilinear', align_corners=True) * upsample_factor
        if upsampler is None:
            raise TypeError("upsample_flow: an upsampler is required when bilinear=False")
        # Convex upsampling: every fine pixel is a softmax-weighted combination of
        # the 3x3 coarse neighbourhood of its parent pixel.
        concat = torch.cat((flow, feature), dim=1)
        mask = upsampler(concat)
        b, flow_channel, h, w = flow.shape
        mask = mask.view(b, 1, 9, upsample_factor, upsample_factor, h, w)  # [B, 1, 9, K, K, H, W]
        mask = torch.softmax(mask, dim=2)
        up_flow = F.unfold(upsample_factor * flow, [3, 3], padding=1)
        up_flow = up_flow.view(b, flow_channel, 9, 1, 1, h, w)  # [B, C, 9, 1, 1, H, W]
        up_flow = torch.sum(mask * up_flow, dim=2)  # [B, C, K, K, H, W]
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)  # [B, C, H, K, W, K]
        return up_flow.reshape(b, flow_channel, upsample_factor * h,
                               upsample_factor * w)  # [B, C, K*H, K*W]

    def forward(self, image_list, mix_enable=True, layer=6):
        """Estimate flow for a list of frames.

        Args:
            image_list: list of frame tensors [B, C, H, W]; 0-255 input is
                rescaled and 3-channel input is converted to gray.
            mix_enable: enable autocast mixed precision.
            layer: if not None, persistently sets ``self.MT.num_layers`` and
                ``self.num_layers`` (number of transformer iterations).

        Returns:
            dict with ``"flow_seq"`` (V1 flow, then one convex-upsampled flow per
            transformer output) and ``"flow_seq_bi"`` (bilinear counterparts).
            With 0 layers, only ``"flow_seq"`` with the bilinear V1 flow.
        """
        if layer is not None:
            self.MT.num_layers = layer
            self.num_layers = layer
        results_dict = {}
        image_list = self._normalize_frames(image_list)
        with autocast(enabled=mix_enable):
            # TODO(original): consider torch.no_grad() here to test whether a trainable V1 is needed.
            st_component = self.ffv1(image_list)
            if self.num_layers == 0:
                # No transformer stage: decode V1 features directly, bilinear upsampling only.
                flows = [self.decoder(st_component)]
                results_dict["flow_seq"] = [
                    self.upsample_flow(flow, feature=None, bilinear=True, upsample_factor=8) for flow in flows
                ]
                return results_dict
            motion_feature, attn = self.MT.forward_save_mem(st_component)
            flows_bi, flows_up = self._decode_flows(st_component, motion_feature, attn)
            results_dict["flow_seq"] = flows_up
            results_dict["flow_seq_bi"] = flows_bi
        return results_dict

    def forward_test(self, image_list, mix_enable=True, layer=6):
        """Variant of ``forward`` used for testing; returns the same dict layout.

        Unlike ``forward``, frames are only rescaled (no RGB-to-gray conversion).
        """
        if layer is not None:
            self.MT.num_layers = layer
            self.num_layers = layer
        results_dict = {}
        image_list = self._normalize_frames(image_list, to_gray=False)
        with autocast(enabled=mix_enable):
            st_component = self.ffv1(image_list)
            if self.num_layers == 0:
                flows = [self.decoder(st_component)]
                results_dict["flow_seq"] = [
                    self.upsample_flow(flow, feature=None, bilinear=True, upsample_factor=8) for flow in flows
                ]
                return results_dict
            # NOTE(review): forward() unpacks two values from forward_save_mem while this
            # unpacks three; one of the two presumably fails - confirm against FeatureTransformer.
            motion_feature, attn, _ = self.MT.forward_save_mem(st_component)
            flows_bi, flows_up = self._decode_flows(st_component, motion_feature, attn)
            results_dict["flow_seq"] = flows_up
            results_dict["flow_seq_bi"] = flows_bi
        return results_dict

    def forward_viz(self, image_list, layer=None, x=50, y=50):
        """Run the model and build visualisations (activations, flow plot, attention plot).

        Args:
            image_list: list of frame tensors [B, C, H, W].
            layer: if not None, persistently sets ``self.MT.num_layers``.
            x, y: query location in percent (0-100) of the width / height of the
                H//8 x W//8 attention grid; clamped to the grid.

        Returns:
            dict with ``"flow_seq"``, ``"activation"``, ``"attention"`` and ``"flow"``.
        """
        x = x / 100
        y = y / 100
        if layer is not None:
            self.MT.num_layers = layer
        results_dict = {}
        image_list = self._normalize_frames(image_list)
        image_list_ori = deepcopy(image_list)
        _, _, H, W = image_list[0].shape
        MT_size = (H // 8, W // 8)
        with autocast(enabled=True):
            st_component = self.ffv1(image_list)
            activation = self.ffv1.visualize_activation(st_component)
            motion_feature, attn, attn_viz = self.MT(st_component)
            _, flows_up = self._decode_flows(st_component, motion_feature, attn)
            results_dict["flow_seq"] = flows_up
            # Plot stages 1, 3, 5 and the final one.
            flows_up = [flows_up[i] for i in [0, 2, 4]] + [flows_up[-1]]
            attn_viz = [attn_viz[i] for i in [0, 2, 4]] + [attn_viz[-1]]
            flow = plt_show_img_flow(image_list_ori, flows_up)
            # Clamp so x or y == 100 does not index one past the last grid cell.
            h = min(max(int(MT_size[0] * y), 0), MT_size[0] - 1)
            w = min(max(int(MT_size[1] * x), 0), MT_size[1] - 1)
            attention = plt_attention(attn_viz, h=h, w=w)
            results_dict["activation"] = activation
            results_dict["attention"] = attention
            results_dict["flow"] = flow
            plt.clf()
            plt.cla()
            plt.close()
        return results_dict

    def num_parameters(self):
        """Return the number of elements in parameters that require gradients."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def init_weights(self):
        """Kaiming-normal init for conv weights and zero init for their biases.

        Applies to every ``Conv1d``, ``Conv2d`` and ``ConvTranspose2d`` in the
        module tree. NOTE(review): this now really recurses into ``self.ffv1`` and
        ``self.MT``; previously it iterated ``named_modules()`` tuples and matched
        nothing, so it was a no-op.
        """
        for layer in self.modules():
            if isinstance(layer, (nn.Conv2d, nn.Conv1d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(layer.weight)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)

    @staticmethod
    def demo(file=None):
        """Time ``forward_viz`` + backward on random CUDA frames for 100 iterations.

        Args:
            file: optional checkpoint passed to ``utils.restore_model``.
        """
        import time
        from utils import torch_utils as utils
        # NOTE: list repetition aliases one tensor, so all 11 frames are identical.
        frame_list = [torch.randn([4, 1, 512, 512], device="cuda")] * 11
        model = FFV1DNN(num_scales=8, scale_factor=16, num_cells=256, upsample_factor=8, num_layers=6,
                        feature_channels=256).cuda()
        if file is not None:
            model = utils.restore_model(model, file)
        print(model.num_parameters())
        for _ in range(100):
            start = time.time()
            output = model.forward_viz(frame_list, layer=7)
            torch.mean(output["flow_seq"][-1]).backward()
            print(torch.any(torch.isnan(output["flow_seq"][-1])))
            end = time.time()
            print(end - start)
            print("#================================++#")
if __name__ == "__main__":
    # Script entry point: run the CUDA timing / NaN-check demo without a checkpoint.
    FFV1DNN.demo(file=None)