"""Trainable V1 motion-energy front end plus filter visualisation helpers.

The module defines a bank of learnable spatial Gabor filters
(``GaborFilters``), a bank of learnable temporal filters
(``TemporalFilter``) and the multi-scale ``V1`` module that combines them
into normalised motion-energy responses.  The remaining functions only draw
or export pictures of those filters.
"""
import math
import os
from io import BytesIO

import numpy
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.cuda.amp import autocast as autocast
from torch.nn import functional as F

# The plotting stack is only needed by the visualisation helpers, so a missing
# package must not prevent the model classes from being imported.
try:
    import matplotlib.pyplot as plt
except ImportError:
    plt = None
try:
    import imageio
except ImportError:
    imageio = None

# Hand-tuned temporal pooling weights that emphasise the middle time steps.
_TUNED_TEMPORAL_POOLING = (0.05, 0.1, 0.4, 0.4, 0.1, 0.05)


def _require_plotting():
    """Raise ``ImportError`` if matplotlib or imageio could not be imported."""
    missing = [name for name, module in (("matplotlib", plt), ("imageio", imageio))
               if module is None]
    if missing:
        raise ImportError("visualisation helpers need: " + ", ".join(missing))


def _make_param(in_channels, values, requires_grad=True, dtype=None):
    """Build an ``nn.Parameter`` of shape ``(in_channels, len(values))``.

    ``values`` is converted to ``dtype`` (``float32`` when omitted), laid out
    as a single row and repeated ``in_channels`` times.
    """
    if dtype is None:
        dtype = 'float32'
    values = numpy.require(values, dtype=dtype)
    data = torch.from_numpy(values).view(1, -1)
    data = data.repeat(in_channels, 1)
    return torch.nn.Parameter(data=data, requires_grad=requires_grad)


def cart2pol(x, y):
    """Convert Cartesian ``(x, y)`` to polar ``(rho, phi)``."""
    rho = np.sqrt(x ** 2 + y ** 2)
    phi = np.arctan2(y, x)
    return (rho, phi)


def pol2cart(rho, phi):
    """Convert polar ``(rho, phi)`` to Cartesian ``(x, y)``."""
    x = rho * np.cos(phi)
    y = rho * np.sin(phi)
    return (x, y)


def inverse_sigmoid(p, eps=1e-6):
    """Return the logit ``log(p / (1 - p))`` of ``p``.

    ``p`` is clipped to ``[eps, 1 - eps]`` first so that the end points 0 and
    1 map to large finite values instead of ``-inf`` / ``+inf``; infinite
    values would otherwise end up inside learnable parameters.
    """
    p = np.clip(p, eps, 1 - eps)
    return np.log(p / (1 - p))


def artanh(y):
    """Return the inverse hyperbolic tangent of ``y``."""
    return 0.5 * np.log((1 + y) / (1 - y))


class V1(nn.Module):
    """Multi-scale spatio-temporal motion-energy layer.

    Each scale owns one ``GaborFilters`` bank (spatial quadrature pair) and
    one ``TemporalFilter`` bank (temporal quadrature pair).  ``forward`` takes
    a list of frames, computes the motion energy per scale on a progressively
    rescaled copy of the frames and normalises the concatenated result.

    ``n_frames`` is fixed to 11; the spatial kernel is
    ``2 * kernel_radius + 1`` pixels wide and the temporal kernel spans
    ``kernel_size`` frames.
    """

    def __init__(self, spatial_num=32, scale_num=8, scale_factor=16, kernel_radius=7, num_ft=32,
                 kernel_size=6, average_time=True):
        """Create the filter banks and the pooling / normalisation parameters.

        Args:
            spatial_num: number of Gabor units per scale (must equal ``num_ft``).
            scale_num: number of pyramid levels.
            scale_factor: total down-scaling between the first and last level.
            kernel_radius: radius of the spatial kernels.
            num_ft: number of temporal filters per scale.
            kernel_size: length of the temporal kernels in frames.
            average_time: pool the energy over time with learnable weights
                instead of keeping one channel per time step.
        """
        super(V1, self).__init__()
        assert spatial_num == num_ft
        if scale_num > 1:
            # Geometric step so that scale_num - 1 steps give 1 / scale_factor.
            scale_each_level = float(np.exp(1 / (scale_num - 1) * np.log(1 / scale_factor)))
        else:
            # A single level has no step; avoids a division by zero.
            scale_each_level = 1.0
        self.scale_each_level = scale_each_level
        self.scale_num = scale_num
        self.cell_index = 0
        self.spatial_filter = nn.ModuleList(
            [GaborFilters(kernel_radius=kernel_radius, num_units=spatial_num, random=False)
             for _ in range(scale_num)])
        self.temporal_decay = 0.2
        self.spatial_decay = 0.2
        self.spatial_radius = kernel_radius
        self.spatial_kernel_size = kernel_radius * 2 + 1
        self.spatial_num = spatial_num
        self.temporal_filter = nn.ModuleList(
            [TemporalFilter(num_ft=num_ft, kernel_size=kernel_size, random=False)
             for _ in range(scale_num)])
        self.n_frames = 11
        time_steps = self.n_frames - kernel_size + 1  # length of a "valid" temporal convolution
        self._num_after_st = spatial_num * scale_num
        if not average_time:
            self._num_after_st = self._num_after_st * time_steps
        if average_time:
            # TODO: concentrate on middle frame
            if time_steps == len(_TUNED_TEMPORAL_POOLING):
                pooling_init = _TUNED_TEMPORAL_POOLING
            else:
                # The tuned weights only fit six time steps; use uniform ones otherwise.
                pooling_init = np.ones(time_steps)
            self.temporal_pooling = _make_param(self._num_after_st, pooling_init, requires_grad=True)
        self.norm_sigma = _make_param(1, np.array([0.2]), requires_grad=True)
        self.spontaneous_firing = _make_param(1, np.array([0.3]), requires_grad=True)
        self.norm_k = _make_param(1, np.array([4.0]), requires_grad=True)
        self._average_time = average_time
        # Filters of the scale currently being processed; set by forward().
        self.t_sin = None
        self.t_cos = None
        self.s_sin = None
        self.s_cos = None

    def infer_scale(self, x, scale):
        """Compute the motion energy of one scale.

        Reads ``self.s_sin`` / ``self.s_cos`` / ``self.t_sin`` / ``self.t_cos``,
        which ``forward`` sets before calling this method.

        Args:
            x: list of ``n`` frames, each of shape ``(B, C, H, W)``; the
                spatial convolution uses ``groups=1`` with single-channel
                kernels, so ``C`` has to be 1.
            scale: index of the scale, used to pick the pooling weights.

        Returns:
            Tensor of shape ``(B, channels, H, W)`` with one channel per unit
            when time is averaged, otherwise one per unit and time step.
        """
        n = len(x)
        B, C, H, W = x[0].shape
        x = torch.stack(list(x), dim=0).reshape(n * B, C, H, W)

        kz = self.spatial_kernel_size
        # Flip the kernels so that conv2d (a correlation) performs a convolution.
        gb_sin = torch.flip(self.s_sin.view(self.spatial_num, 1, kz, kz), dims=[-1, -2])
        gb_cos = torch.flip(self.s_cos.view(self.spatial_num, 1, kz, kz), dims=[-1, -2])
        res_sin = F.conv2d(input=x, weight=gb_sin, padding=self.spatial_radius, groups=1)
        res_cos = F.conv2d(input=x, weight=gb_cos, padding=self.spatial_radius, groups=1)
        g_asin_list = res_sin.reshape(n, B, -1, H, W)
        g_acos_list = res_cos.reshape(n, B, -1, H, W)

        energy_list = []
        for channel in range(self.spatial_filter[0].n_channels_post_conv):
            # Reverse the impulse responses (same reason as the spatial flip).
            k_sin = torch.flip(self.t_sin[channel, ...][None], dims=(-1,))
            k_cos = torch.flip(self.t_cos[channel, ...][None], dims=(-1,))
            # (n, B, H, W) -> (B*H*W, 1, n): every pixel becomes a 1-D time series.
            g_asin = g_asin_list[:, :, channel, :, :].reshape(n, B * H * W, 1).permute(1, 2, 0)
            g_acos = g_acos_list[:, :, channel, :, :].reshape(n, B * H * W, 1).permute(1, 2, 0)

            g_o = (F.conv1d(g_acos, k_sin, padding="valid", bias=None)
                   + F.conv1d(g_asin, k_cos, padding="valid", bias=None))
            g_e = (F.conv1d(g_acos, k_cos, padding="valid", bias=None)
                   - F.conv1d(g_asin, k_sin, padding="valid", bias=None))
            energy_component = g_o ** 2 + g_e ** 2 + self.spontaneous_firing.square()
            energy_component = energy_component.reshape(B, H, W, g_e.size(-1)).permute(0, 3, 1, 2)
            if self._average_time:
                # Weighted average of the motion energy across time.
                total_channel = scale * self.spatial_num + channel
                pooling = self.temporal_pooling[total_channel][None, ..., None, None]
                energy_component = abs(torch.mean(energy_component * pooling, dim=1, keepdim=True))
            energy_list.append(energy_component)
        return torch.cat(energy_list, dim=1)

    def forward(self, image_list):
        """Return the normalised multi-scale motion energy of ``image_list``.

        Args:
            image_list: list of frames of shape ``(B, 1, H, W)``.  Frames
                whose first element has a maximum above 10 are divided by 255.

        Returns:
            Tensor of shape ``(B, channels, H // 8, W // 8)``.
        """
        _, _, H, W = image_list[0].shape
        MT_size = (H // 8, W // 8)
        self.cell_index = 0
        with torch.no_grad():
            if image_list[0].max() > 10:
                image_list = [img / 255.0 for img in image_list]  # [B, 1, H, W] 0-1
        ms_com = []
        for scale in range(self.scale_num):
            self.t_sin, self.t_cos = self.temporal_filter[scale].make_temporal_filter()
            self.s_sin, self.s_cos = self.spatial_filter[scale].make_gabor_filters(quadrature=True)
            st_component = self.infer_scale(image_list, scale)
            st_component = F.interpolate(st_component, size=MT_size, mode="bilinear", align_corners=True)
            ms_com.append(st_component)
            # Shrink the frames for the next (coarser) scale.
            image_list = [F.interpolate(img, scale_factor=self.scale_each_level, mode="bilinear")
                          for img in image_list]
        motion_energy = self.normalize(torch.cat(ms_com, dim=1))
        # self.visualize_activation(motion_energy)
        return motion_energy

    def normalize(self, x):
        """Divide ``x`` by its channel mean plus ``norm_sigma ** 2`` and scale by ``|norm_k|``."""
        # TODO
        sum_activation = torch.mean(x, dim=[1], keepdim=True) + torch.square(self.norm_sigma)
        x = self.norm_k.abs() * x / sum_activation
        return x

    def _get_v1_order(self):
        """Describe every unit by its preferred orientation, frequencies and speed.

        Returns:
            List of dicts with keys ``theta``, ``fs``, ``ft``, ``speed`` and
            ``index``, ordered scale by scale.
        """
        scale_each_level = self.scale_each_level
        neural_representation = []
        index = 0
        for scale_idx, (gabor_scale, temporal_scale) in enumerate(
                zip(self.spatial_filter, self.temporal_filter)):
            theta_scale = torch.sigmoid(gabor_scale.thetas) * 2 * torch.pi  # orientation in 0-2pi
            fs_scale = torch.sigmoid(gabor_scale.fs) * 0.25
            fs_scale = fs_scale * (scale_each_level ** scale_idx)
            ft_scale = torch.sigmoid(temporal_scale.ft) * 0.25
            # reshape(-1) keeps a 1-D array even when a scale has a single unit.
            theta_scale = theta_scale.cpu().detach().numpy().reshape(-1)
            fs_scale = fs_scale.cpu().detach().numpy().reshape(-1)
            ft_scale = ft_scale.cpu().detach().numpy().reshape(-1)
            for gabor_idx in range(len(theta_scale)):
                speed = ft_scale[gabor_idx] / fs_scale[gabor_idx]
                assert speed >= 0
                angle = theta_scale[gabor_idx]
                neural_representation.append({"theta": -angle + np.pi, "fs": fs_scale[gabor_idx],
                                              "ft": ft_scale[gabor_idx], "speed": speed,
                                              "index": index})
                index = index + 1
        return neural_representation

    def visualize_activation(self, activation, if_log=True):
        """Render the mean activation of every unit as a polar scatter plot.

        Args:
            activation: tensor of shape ``(B, channels, H, W)``; a 14-pixel
                border is cropped and only the first batch element is drawn.
            if_log: use a symlog radial axis.

        Returns:
            The rendered PNG as returned by ``imageio.imread``.
        """
        _require_plotting()
        neural_representation = self._get_v1_order()
        activation = activation[:, :, 14:-14, 14:-14]  # eliminate boundary
        activation = torch.mean(activation, dim=[2, 3], keepdim=False)[0]
        per_unit = activation.squeeze().cpu().detach().numpy()

        theta_list = np.array([unit["theta"] for unit in neural_representation])
        v_list = np.array([unit["speed"] for unit in neural_representation])
        energy_list = np.array([per_unit[unit["index"]] for unit in neural_representation])

        ax1 = plt.subplot(111, projection='polar')
        fig = ax1.figure
        try:
            plt.scatter(theta_list, v_list, c=energy_list, cmap="rainbow", s=(energy_list + 20), alpha=0.5)
            plt.axis('on')
            if if_log:
                ax1.set_rscale('symlog')
            plt.colorbar()
            with BytesIO() as buf:
                plt.savefig(buf, format='png')
                buf.seek(0)
                # read the buffer and convert to an image
                image = imageio.imread(buf)
        finally:
            # Close exactly this figure; calling clf() after close() would open a new one.
            plt.close(fig)
        return image

    @staticmethod
    def demo():
        """Time 100 forward/backward passes on random-free dummy input (needs CUDA)."""
        import time

        frames = [torch.ones(2, 1, 256, 256).cuda() for _ in range(11)]
        model = V1(spatial_num=16, scale_num=16, scale_factor=16, kernel_radius=7, num_ft=16,
                   kernel_size=6, average_time=True).cuda()
        for _ in range(100):
            start = time.time()
            with autocast(enabled=True):
                x = model(frames)
                print(x.shape)
                torch.mean(x).backward()
            end = time.time()
            print(end - start)
            print("#================================++#")

    @property
    def num_after_st(self):
        """Number of output channels after the spatio-temporal stage."""
        return self._num_after_st


class TemporalFilter(nn.Module):
    """Bank of learnable temporal quadrature filters.

    Every filter is an exponentially decaying sine / cosine pair whose
    frequency (``ft``) and time constant (``tao``) are learnable.
    """

    def __init__(self, in_channels=1, num_ft=8, kernel_size=6, random=True):
        """Create ``num_ft`` filters of ``kernel_size`` samples.

        Args:
            in_channels: number of rows of the parameters; the filter
                construction reshapes them assuming this is 1.
            num_ft: number of temporal filters.
            kernel_size: number of samples of each filter.
            random: draw the initial frequencies at random instead of evenly.
        """
        # 40ms per time unit, 200ms -> 5+1 frames
        # use exponential decay plus sin wave
        super().__init__()
        self.kernel_size = kernel_size
        indices = torch.arange(kernel_size, dtype=torch.float32)
        self.register_buffer('indices', indices)
        if random:
            ft_init = numpy.random.uniform(0.01, 0.99, num_ft)
        else:
            # evenly distributed
            ft_init = numpy.linspace(0.01, 0.99, num_ft)
        self.ft = _make_param(in_channels, values=inverse_sigmoid(ft_init), requires_grad=True)
        self.tao = _make_param(in_channels, values=numpy.arange(num_ft) / 2 + 1, requires_grad=True)
        self.feat_dim = num_ft
        self.temporal_decay = 0.2

    def _impulse_response(self, t):
        """Evaluate the sine / cosine filters at the 1-D time points ``t``.

        Returns two tensors of shape ``(feat_dim, 1, len(t))``.
        """
        fts = torch.sigmoid(self.ft) * 0.25
        tao = torch.sigmoid(self.tao) * (-self.kernel_size / np.log(self.temporal_decay))
        fts = fts.view(1, fts.shape[1], 1)
        tao = tao.view(1, tao.shape[1], 1)
        n_points = t.shape[0]
        t = t.view(1, 1, n_points)
        envelope = torch.exp(-t / tao)
        phase = 2 * torch.pi * fts * t
        temporal_sin = (envelope * torch.sin(phase)).view(self.feat_dim, 1, n_points)
        temporal_cos = (envelope * torch.cos(phase)).view(self.feat_dim, 1, n_points)
        return temporal_sin, temporal_cos

    def make_temporal_filter(self):
        """Return the ``(sin, cos)`` kernels sampled at the integer frame indices."""
        return self._impulse_response(self.indices)

    def demo_temporal_filter(self, points=100):
        """Return the ``(sin, cos)`` filters densely sampled at ``points`` positions.

        The curves use the same frequency and decay scaling as
        ``make_temporal_filter`` so the plot shows the filters that are
        actually applied; the current frequencies and time constants are
        printed.
        """
        t = torch.linspace(float(self.indices[0]), float(self.indices[-1]), steps=points,
                           device=self.indices.device)
        fts = torch.sigmoid(self.ft) * 0.25
        tao = torch.sigmoid(self.tao) * (-self.kernel_size / np.log(self.temporal_decay))
        print("ft:" + str(fts.view(1, fts.shape[1], 1)))
        print("tao:" + str(tao.view(1, tao.shape[1], 1)))
        return self._impulse_response(t)

    def forward(self, x_sin, x_cos):
        """Return one motion-energy tensor per temporal filter.

        Args:
            x_sin, x_cos: spatial quadrature responses of shape
                ``(batch, channels, sequence)``.

        Returns:
            List of ``feat_dim`` tensors with the shape of the inputs.
        """
        in_channels = x_sin.size(1)
        me = []
        t_sin, t_cos = self.make_temporal_filter()
        for n_t in range(self.feat_dim):
            k_sin = t_sin[n_t, ...].expand(in_channels, -1, -1)
            k_cos = t_cos[n_t, ...].expand(in_channels, -1, -1)
            g_o = (F.conv1d(x_sin, weight=k_cos, padding="same", groups=in_channels, bias=None)
                   + F.conv1d(x_cos, weight=k_sin, padding="same", groups=in_channels, bias=None))
            g_e = (F.conv1d(x_sin, weight=k_sin, padding="same", groups=in_channels, bias=None)
                   - F.conv1d(x_cos, weight=k_cos, padding="same", groups=in_channels, bias=None))
            me.append(g_o ** 2 + g_e ** 2)
        return me


class GaborFilters(nn.Module):
    """Bank of learnable spatial Gabor filters (quadrature pairs)."""

    def __init__(self, in_channels=1, kernel_radius=7, num_units=512, random=True):
        """Create ``num_units`` Gabor units with ``2 * kernel_radius + 1`` wide kernels.

        Args:
            in_channels: number of rows of every parameter.
            kernel_radius: radius of the kernels in pixels.
            num_units: total number of units of this scale.
            random: draw the initial parameters at random instead of evenly.
        """
        super().__init__()
        self.in_channels = in_channels
        kernel_size = kernel_radius * 2 + 1
        self.kernel_size = kernel_size
        self.kernel_radius = kernel_radius

        # build all learnable parameters (stored as logits, see the sigmoid calls
        # in _gabor_bank)
        if random:
            # random distribution
            sigma_init = np.random.uniform(0.8, 0.99, num_units)
            fs_init = numpy.random.uniform(0.2, 0.8, num_units)
            theta_init = numpy.random.uniform(0.01, 0.99, num_units)
        else:
            # even distribution
            sigma_init = np.linspace(0.8, 0.99, num_units)
            fs_init = numpy.linspace(0.01, 0.99, num_units)
            theta_init = numpy.linspace(0, 1, num_units)
        self.sigmas = _make_param(in_channels, inverse_sigmoid(sigma_init))
        self.fs = _make_param(in_channels, values=inverse_sigmoid(fs_init))  # maximum is 0.25 cycle/pixel
        self.gammas = _make_param(in_channels, numpy.ones(num_units))  # TODO: fix gamma or not
        self.psis = _make_param(in_channels, np.zeros(num_units), requires_grad=False)  # fix phase
        self.thetas = _make_param(in_channels, values=inverse_sigmoid(theta_init), requires_grad=True)

        indices = torch.arange(kernel_size, dtype=torch.float32) - (kernel_size - 1) / 2
        self.register_buffer('indices', indices)
        self.spatial_decay = 0.5
        # number of channels after the conv
        self.n_channels_post_conv = num_units

    def _gabor_bank(self, x, y, quadrature):
        """Evaluate the Gabor units on the grid spanned by the 1-D axes ``x`` and ``y``.

        Returns ``(gb_sin, gb_cos)`` of shape ``(-1, 1, len(y), len(x))`` with
        the mean removed when ``quadrature`` is true, otherwise a single
        cosine bank of shape ``(-1, len(y), len(x))``.
        """
        # Upper bound of the Gaussian std, derived from kernel_radius and spatial_decay.
        sigmas = torch.sigmoid(self.sigmas) * np.sqrt(
            (self.kernel_radius - 1) ** 2 * 0.5 / np.log(1 / self.spatial_decay))
        fs = torch.sigmoid(self.fs) * 0.25  # carrier frequency, kept in (0, 0.25)
        gammas = torch.abs(self.gammas)  # aspect ratio of the Gaussian window, 1 by default
        psis = self.psis  # phase of the carrier
        thetas = torch.sigmoid(self.thetas) * 2 * torch.pi  # orientation constrained to 0-2pi

        in_channels = sigmas.shape[0]
        assert in_channels == fs.shape[0]
        assert in_channels == gammas.shape[0]
        height, width = y.shape[0], x.shape[0]

        sigmas = sigmas.view(in_channels, sigmas.shape[1], 1, 1)
        fs = fs.view(in_channels, fs.shape[1], 1, 1)
        gammas = gammas.view(in_channels, gammas.shape[1], 1, 1)
        psis = psis.view(in_channels, psis.shape[1], 1, 1)
        thetas = thetas.view(in_channels, thetas.shape[1], 1, 1)
        y = y.view(1, 1, height, 1)
        x = x.view(1, 1, 1, width)

        sigma_x = sigmas
        sigma_y = sigmas / gammas
        sin_t = torch.sin(thetas)
        cos_t = torch.cos(thetas)
        y_theta = -x * sin_t + y * cos_t
        x_theta = x * cos_t + y * sin_t
        envelope = torch.exp(-.5 * (x_theta ** 2 / sigma_x ** 2 + y_theta ** 2 / sigma_y ** 2))
        phase = 2.0 * math.pi * x_theta * fs + psis

        if not quadrature:
            gb = envelope * torch.cos(phase)
            return gb.view(-1, height, width)

        gb_cos = (envelope * torch.cos(phase)).reshape(-1, 1, height, width)
        gb_sin = (envelope * torch.sin(phase)).reshape(-1, 1, height, width)
        # remove DC
        area = height * width
        gb_cos = gb_cos - torch.sum(gb_cos, dim=[-1, -2], keepdim=True) / area
        gb_sin = gb_sin - torch.sum(gb_sin, dim=[-1, -2], keepdim=True) / area
        return gb_sin, gb_cos

    def make_gabor_filters(self, quadrature=True):
        """Return the kernels sampled on the integer pixel grid of the kernel."""
        return self._gabor_bank(self.indices, self.indices, quadrature)

    def demo_gabor_filters(self, quadrature=True, points=100):
        """Return the filters densely sampled on a ``points`` x ``points`` grid.

        The current orientations and frequencies are printed.
        """
        print("theta:" + str(torch.sigmoid(self.thetas) * 2 * torch.pi))
        print("fs:" + str(torch.sigmoid(self.fs) * 0.25))
        lo, hi = float(self.indices[0]), float(self.indices[-1])
        x = torch.linspace(lo, hi, points, device=self.indices.device)
        y = torch.linspace(lo, hi, points, device=self.indices.device)
        return self._gabor_bank(x, y, quadrature)

    def forward(self, x):
        """Convolve ``x`` (``(B, C, H, W)``) with the sine and cosine banks.

        Returns:
            ``(res_sin, res_cos)``, each of shape ``(B, -1, H, W)``.
        """
        batch_size = x.size(0)
        sy = x.size(2)
        sx = x.size(3)
        gb_sin, gb_cos = self.make_gabor_filters(quadrature=True)
        assert gb_sin.shape[0] == self.n_channels_post_conv
        assert gb_sin.shape[2] == self.kernel_size
        assert gb_sin.shape[3] == self.kernel_size
        gb_sin = gb_sin.view(self.n_channels_post_conv, 1, self.kernel_size, self.kernel_size)
        gb_cos = gb_cos.view(self.n_channels_post_conv, 1, self.kernel_size, self.kernel_size)
        # flip the kernels so that conv2d performs a true convolution
        gb_sin = torch.flip(gb_sin, dims=[-1, -2])
        gb_cos = torch.flip(gb_cos, dims=[-1, -2])
        res_sin = F.conv2d(input=x, weight=gb_sin, padding=self.kernel_radius, groups=self.in_channels)
        res_cos = F.conv2d(input=x, weight=gb_cos, padding=self.kernel_radius, groups=self.in_channels)
        # ``rotation_invariant`` / ``n_thetas`` are not set by __init__; the max
        # over orientations only runs when something else defined them.
        if getattr(self, "rotation_invariant", False):
            res_sin = res_sin.view(batch_size, self.in_channels, -1, self.n_thetas, sy, sx)
            res_sin, _ = res_sin.max(dim=3)
            res_cos = res_cos.view(batch_size, self.in_channels, -1, self.n_thetas, sy, sx)
            res_cos, _ = res_cos.max(dim=3)
        res_sin = res_sin.view(batch_size, -1, sy, sx)
        res_cos = res_cos.view(batch_size, -1, sy, sx)
        return res_sin, res_cos


def _save_filter_figure(gabor_sin, gabor_cos, temporal_sin, temporal_cos, t_axis, out_path, show=False):
    """Draw one 2x3 overview of a spatial/temporal filter pair and save it to ``out_path``."""
    sin_img = gabor_sin.squeeze().detach().numpy()
    cos_img = gabor_cos.squeeze().detach().numpy()
    sin_t = temporal_sin.squeeze().detach().numpy()
    cos_t = temporal_cos.squeeze().detach().numpy()

    fig = plt.figure(figsize=(14, 9), dpi=80)
    try:
        plt.subplot(2, 3, 1)
        plt.imshow(sin_img)
        plt.title("Gabor Sin")
        plt.subplot(2, 3, 2)
        plt.imshow(cos_img)
        plt.title("Gabor Cos")
        plt.subplot(2, 3, 3)
        plt.plot(t_axis, sin_t, label='sin')
        plt.plot(t_axis, cos_t, label='cos')
        plt.xlabel('Time (s)')
        plt.ylabel('Response to pulse at t=0')
        plt.legend()
        plt.title("Temporal filter")

        # Space-time slices built from row 5 of the sampled spatial filters.
        sin_row = sin_img[5, :]
        cos_row = cos_img[5, :]
        g_o = np.outer(cos_t, sin_row) + np.outer(sin_t, cos_row)
        g_e = np.outer(sin_t, sin_row) - np.outer(cos_t, cos_row)
        energy_component = g_o ** 2 + g_e ** 2

        plt.subplot(2, 3, 4)
        plt.imshow(g_o, cmap="gray")
        plt.title("Spatial Temporal odd")
        plt.subplot(2, 3, 5)
        plt.imshow(g_e, cmap="gray")
        plt.title("Spatial Temporal even")
        plt.subplot(2, 3, 6)
        plt.imshow(energy_component, cmap="gray")
        plt.title("energy")
        plt.savefig(out_path)
        if show:
            plt.show()
    finally:
        # One figure per filter adds up quickly; release it once it is saved.
        plt.close(fig)


def _frames_to_gif(filenames, gif_path):
    """Assemble the image files into ``gif_path`` and delete them afterwards."""
    try:
        with imageio.get_writer(gif_path, mode='I') as writer:
            for filename in filenames:
                writer.append_data(imageio.imread(filename))
    finally:
        # Remove files
        for filename in set(filenames):
            if os.path.exists(filename):
                os.remove(filename)


def te_gabor_(num_units=48):
    """Show the cosine, sine and energy image of every unit of a random Gabor bank."""
    _require_plotting()
    s_point = 100
    s_kz = 7
    gb_sin, gb_cos = GaborFilters(num_units=num_units, kernel_radius=s_kz).demo_gabor_filters(points=s_point)
    gb = gb_sin ** 2 + gb_cos ** 2
    print(gb_sin.shape)
    for c in range(gb_sin.size(0)):
        for position, bank in enumerate((gb_cos, gb_sin, gb), start=1):
            plt.subplot(1, 3, position)
            plt.imshow(bank[c].detach().cpu().squeeze().numpy())
        plt.show()


def te_spatial_temporal():
    """Render every spatial/temporal filter combination into ``filters_orientation.gif``."""
    _require_plotting()
    t_point = 6 * 100
    s_point = 14 * 100
    s_kz = 7
    t_kz = 6
    filenames = []
    gb_sin_b, gb_cos_b = GaborFilters(num_units=48, kernel_radius=s_kz).demo_gabor_filters(points=s_point)
    temporal = TemporalFilter(num_ft=2, kernel_size=t_kz)
    t_sin, t_cos = temporal.demo_temporal_filter(points=t_point)
    x = np.linspace(0, t_kz, t_point)
    for i in range(gb_sin_b.size(0)):
        for j in range(t_sin.size(0)):
            filename = 'filter_%d.png' % (len(filenames))
            _save_filter_figure(gb_sin_b[i], gb_cos_b[i], t_sin[j], t_cos[j], x, filename, show=True)
            filenames.append(filename)
    # build gif
    _frames_to_gif(filenames, 'filters_orientation.gif')


def te_temporal_():
    """Plot the sine / cosine impulse response of every filter of a random bank."""
    _require_plotting()
    k_size = 6
    n_points = k_size * 100
    temporal = TemporalFilter(num_ft=8, kernel_size=k_size)
    # Sample as many points as the x axis has, otherwise plt.plot rejects the shapes.
    sin, cos = temporal.demo_temporal_filter(points=n_points)
    print(sin.shape)
    x = np.linspace(0, k_size, n_points)
    # plot temporal filters to illustrate what they look like.
    for c in range(sin.size(0)):
        plt.plot(x, cos[c].detach().cpu().squeeze().numpy(), label='cos')
        plt.plot(x, sin[c].detach().cpu().squeeze().numpy(), label='sin')
        plt.xlabel('Time (s)')
        plt.ylabel('Response to pulse at t=0')
        plt.legend()
        plt.show()


def circular_hist(ax, x, bins=16, density=True, offset=0, gaps=True):
    """
    Produce a circular histogram of angles on ax.

    Parameters
    ----------
    ax : matplotlib.axes._subplots.PolarAxesSubplot
        axis instance created with subplot_kw=dict(projection='polar').
    x : array
        Angles to plot, expected in units of radians.
    bins : int, optional
        Defines the number of equal-width bins in the range. The default is 16.
    density : bool, optional
        If True plot frequency proportional to area. If False plot frequency
        proportional to radius. The default is True.
    offset : float, optional
        Sets the offset for the location of the 0 direction in units of
        radians. The default is 0.
    gaps : bool, optional
        Whether to allow gaps between bins. When gaps = False the bins are
        forced to partition the entire [-pi, pi] range. The default is True.

    Returns
    -------
    n : array or list of arrays
        The number of values in each bin.
    bins : array
        The edges of the bins.
    patches : `.BarContainer` or list of a single `.Polygon`
        Container of individual artists used to create the histogram
        or list of such containers if there are multiple input datasets.
    """
    # Wrap angles to [-pi, pi)
    x = (x + np.pi) % (2 * np.pi) - np.pi
    # Force bins to partition entire circle
    if not gaps:
        bins = np.linspace(-np.pi, np.pi, num=bins + 1)
    # Bin data and record counts
    n, bins = np.histogram(x, bins=bins)
    # Compute width of each bin
    widths = np.diff(bins)
    if density:
        # Plot frequency proportional to area: derive the radius from each bin's share.
        area = n / x.size
        radius = (area / np.pi) ** .5
    else:
        # Otherwise plot frequency proportional to radius
        radius = n
    # Plot data on ax
    patches = ax.bar(bins[:-1], radius, zorder=1, align='edge', width=widths,
                     edgecolor='C0', fill=False, linewidth=1)
    # Set the direction of the zero angle
    ax.set_theta_offset(offset)
    # Remove ylabels for area plots (they are mostly obstructive)
    if density:
        ax.set_yticks([])
    return n, bins, patches


def show_trained_model(file_name="/home/2TSSD/experiment/FFMEDNN/Sintel_fixv1_10.62_ckpt.pth.tar"):
    """Plot the pooling weights and filter statistics of an ``FFV1DNN`` model.

    NOTE(review): restoring the checkpoint is commented out, so ``file_name``
    is currently unused and the freshly initialised model is shown.
    """
    _require_plotting()
    import utils.torch_utils as utils
    from model.fle_version_2_3.FFV1MT_MS import FFV1DNN
    model = FFV1DNN(num_scales=8, num_cells=256, upsample_factor=8, feature_channels=256,
                    scale_factor=16, num_layers=6)
    # model = utils.restore_model(model, file_name)
    model = model.ffv1
    t_point = 100
    s_point = 100
    t_kz = 6

    # --- temporal pooling weights -------------------------------------------------
    temporal = model.temporal_pooling.data.cpu().squeeze().numpy()
    frame_ms = np.arange(temporal.shape[1]) * 40  # 40 ms per time step
    mean = np.mean(temporal, axis=0)
    plt.figure(figsize=(10, 10))
    plt.subplot(2, 1, 1)
    for weights in temporal:
        plt.plot(frame_ms, weights)
    plt.subplot(2, 1, 2)
    plt.plot(frame_ms, mean, label="mean")
    plt.xlabel("times (ms)")
    plt.ylabel("temporal pooling weight")
    plt.legend()
    plt.grid(True)
    plt.show()

    # --- speed / orientation preference ------------------------------------------
    neural_representation = model._get_v1_order()
    fs = np.array([ne["fs"] for ne in neural_representation])
    ft = np.array([ne["ft"] for ne in neural_representation])
    v_list = np.array([ne["speed"] for ne in neural_representation])
    theta_list = np.array([ne["theta"] for ne in neural_representation])

    plt.subplot(131, projection='polar')
    plt.scatter(theta_list, v_list, c=v_list, cmap="rainbow", s=(v_list + 20), alpha=0.8)
    plt.axis('on')
    plt.grid(True)
    plt.subplot(132, projection='polar')
    plt.scatter(theta_list, np.ones_like(v_list))

    n_scales = len(model.spatial_filter)
    units_per_scale = len(neural_representation) // n_scales
    lst = []
    for scale in range(n_scales):
        lst += ["scale %d" % scale] * units_per_scale
    data = {"Spatial Frequency": fs, 'Temporal Frequency': ft, "Class": lst}
    df = pd.DataFrame(data=data)

    ax = plt.subplot(133, projection='polar')
    print(len(theta_list))
    bins_number = 8  # the [0, 360) interval is subdivided into this number of equal bins
    zone = np.pi / 8
    wrapped = theta_list < (-np.pi + zone)
    theta_list[wrapped] = theta_list[wrapped] + np.pi * 2
    bins = np.linspace(-np.pi + zone, np.pi + zone, bins_number + 1)
    plt.hist(theta_list, bins, edgecolor="black")
    ax.set_yticklabels([])
    plt.grid(True)

    import seaborn as sns
    sns.jointplot(data=df, x="Spatial Frequency", y="Temporal Frequency", hue="Class",
                  xlim=[0, 0.3], ylim=[0, 0.3])
    plt.grid(True)
    sns.jointplot(data=df, x="Spatial Frequency", y="Temporal Frequency",
                  xlim=[0, 0.25], ylim=[0, 0.25])
    plt.grid(True)
    plt.show()

    # --- spatial and temporal frequency preference of every unit -----------------
    x = np.linspace(0, t_kz, t_point)
    filenames = []
    for scale in range(n_scales):
        t_sin, t_cos = model.temporal_filter[scale].demo_temporal_filter(points=t_point)
        gb_sin_b, gb_cos_b = model.spatial_filter[scale].demo_gabor_filters(points=s_point)
        for i in range(gb_sin_b.size(0)):
            filename = 'filter_%d.png' % (len(filenames))
            _save_filter_figure(gb_sin_b[i], gb_cos_b[i], t_sin[i], t_cos[i], x, filename)
            filenames.append(filename)
    # build gif
    _frames_to_gif(filenames, 'filters_orientation.gif')


if __name__ == "__main__":
    # V1.demo()
    # te_spatial_temporal()
    show_trained_model()