import numpy as np
import torch
from torch import nn

from rlkit.pythonplusplus import identity


class CNN(nn.Module):
    def __init__(
            self,
            input_width,
            input_height,
            input_channels,
            output_size,
            kernel_sizes,
            n_channels,
            strides,
            paddings,
            hidden_sizes=None,
            added_fc_input_size=0,
            batch_norm_conv=False,
            batch_norm_fc=False,
            init_w=1e-4,
            hidden_init=nn.init.xavier_uniform_,
            hidden_activation=nn.ReLU(),
            output_activation=identity,
    ):
        if hidden_sizes is None:
            hidden_sizes = []
        assert len(kernel_sizes) == \
               len(n_channels) == \
               len(strides) == \
               len(paddings)
        super().__init__()

        self.hidden_sizes = hidden_sizes
        self.input_width = input_width
        self.input_height = input_height
        self.input_channels = input_channels
        self.output_size = output_size
        self.output_activation = output_activation
        self.hidden_activation = hidden_activation
        self.batch_norm_conv = batch_norm_conv
        self.batch_norm_fc = batch_norm_fc
        self.added_fc_input_size = added_fc_input_size
        self.conv_input_length = (
            self.input_width * self.input_height * self.input_channels
        )

        self.conv_layers = nn.ModuleList()
        self.conv_norm_layers = nn.ModuleList()
        self.fc_layers = nn.ModuleList()
        self.fc_norm_layers = nn.ModuleList()

        for out_channels, kernel_size, stride, padding in \
                zip(n_channels, kernel_sizes, strides, paddings):
            conv = nn.Conv2d(input_channels,
                             out_channels,
                             kernel_size,
                             stride=stride,
                             padding=padding)
            hidden_init(conv.weight)
            conv.bias.data.fill_(0)

            self.conv_layers.append(conv)
            input_channels = out_channels

        # Find the output size of the conv stack by running a dummy forward
        # pass, and create one batch-norm layer per conv layer along the way.
        # The dummy tensor follows the NCHW convention, (N, C, H, W), to
        # match the reshape done in forward().
        test_mat = torch.zeros(1, self.input_channels, self.input_height,
                               self.input_width)
        # Initially the model is on CPU (the caller should then move it to
        # GPU if needed).
        for conv_layer in self.conv_layers:
            test_mat = conv_layer(test_mat)
            self.conv_norm_layers.append(nn.BatchNorm2d(test_mat.shape[1]))

        fc_input_size = int(np.prod(test_mat.shape))
        # added_fc_input_size is used only for injecting extra (non-image)
        # input directly into the fc layers.
        fc_input_size += added_fc_input_size

        for hidden_size in hidden_sizes:
            fc_layer = nn.Linear(fc_input_size, hidden_size)
            norm_layer = nn.BatchNorm1d(hidden_size)
            fc_layer.weight.data.uniform_(-init_w, init_w)
            fc_layer.bias.data.uniform_(-init_w, init_w)

            self.fc_layers.append(fc_layer)
            self.fc_norm_layers.append(norm_layer)
            fc_input_size = hidden_size

        self.last_fc = nn.Linear(fc_input_size, output_size)
        self.last_fc.weight.data.uniform_(-init_w, init_w)
        self.last_fc.bias.data.uniform_(-init_w, init_w)

    def forward(self, input):
        fc_input = (self.added_fc_input_size != 0)
        conv_input = input.narrow(start=0,
                                  length=self.conv_input_length,
                                  dim=1).contiguous()
        if fc_input:
            extra_fc_input = input.narrow(start=self.conv_input_length,
                                          length=self.added_fc_input_size,
                                          dim=1)
        # Reshape from a batch of flattened images into (channels, height,
        # width).
        h = conv_input.view(conv_input.shape[0],
                            self.input_channels,
                            self.input_height,
                            self.input_width)
        h = self.apply_forward(h, self.conv_layers, self.conv_norm_layers,
                               use_batch_norm=self.batch_norm_conv)
        # Flatten the conv output for the fc layers.
        h = h.view(h.size(0), -1)
        if fc_input:
            h = torch.cat((h, extra_fc_input), dim=1)
        h = self.apply_forward(h, self.fc_layers, self.fc_norm_layers,
                               use_batch_norm=self.batch_norm_fc)
        return self.output_activation(self.last_fc(h))

    def apply_forward(self, input, hidden_layers, norm_layers,
                      use_batch_norm=False):
        h = input
        for layer, norm_layer in zip(hidden_layers, norm_layers):
            h = layer(h)
            if use_batch_norm:
                h = norm_layer(h)
            h = self.hidden_activation(h)
        return h
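# A minimal usage sketch for CNN (illustrative only; the helper name and all
# hyperparameters below are assumptions, not values from any experiment).
# It shows the input layout forward() expects: each image is flattened and
# any extra fc features are concatenated after it along dim 1.
def _example_cnn_usage():
    cnn = CNN(
        input_width=32,
        input_height=32,
        input_channels=3,
        output_size=10,
        kernel_sizes=[5, 3],
        n_channels=[16, 32],
        strides=[2, 2],
        paddings=[0, 0],
        hidden_sizes=[128],
        added_fc_input_size=4,
    )
    images = torch.randn(8, 3 * 32 * 32)  # batch of flattened images
    extras = torch.randn(8, 4)            # extra non-image features
    out = cnn(torch.cat((images, extras), dim=1))
    assert out.shape == (8, 10)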
class TwoHeadDCNN(nn.Module):
    def __init__(
            self,
            fc_input_size,
            hidden_sizes,
            deconv_input_width,
            deconv_input_height,
            deconv_input_channels,
            deconv_output_kernel_size,
            deconv_output_strides,
            deconv_output_channels,
            kernel_sizes,
            n_channels,
            strides,
            paddings,
            batch_norm_deconv=False,
            batch_norm_fc=False,
            init_w=1e-3,
            hidden_init=nn.init.xavier_uniform_,
            hidden_activation=nn.ReLU(),
            output_activation=identity,
    ):
        assert len(kernel_sizes) == \
               len(n_channels) == \
               len(strides) == \
               len(paddings)
        super().__init__()

        self.hidden_sizes = hidden_sizes
        self.output_activation = output_activation
        self.hidden_activation = hidden_activation
        self.deconv_input_width = deconv_input_width
        self.deconv_input_height = deconv_input_height
        self.deconv_input_channels = deconv_input_channels
        deconv_input_size = (
            self.deconv_input_channels
            * self.deconv_input_height
            * self.deconv_input_width
        )
        self.batch_norm_deconv = batch_norm_deconv
        self.batch_norm_fc = batch_norm_fc

        self.deconv_layers = nn.ModuleList()
        self.deconv_norm_layers = nn.ModuleList()
        self.fc_layers = nn.ModuleList()
        self.fc_norm_layers = nn.ModuleList()

        for hidden_size in hidden_sizes:
            fc_layer = nn.Linear(fc_input_size, hidden_size)
            norm_layer = nn.BatchNorm1d(hidden_size)
            fc_layer.weight.data.uniform_(-init_w, init_w)
            fc_layer.bias.data.uniform_(-init_w, init_w)

            self.fc_layers.append(fc_layer)
            self.fc_norm_layers.append(norm_layer)
            fc_input_size = hidden_size

        self.last_fc = nn.Linear(fc_input_size, deconv_input_size)
        self.last_fc.weight.data.uniform_(-init_w, init_w)
        self.last_fc.bias.data.uniform_(-init_w, init_w)

        for out_channels, kernel_size, stride, padding in \
                zip(n_channels, kernel_sizes, strides, paddings):
            deconv = nn.ConvTranspose2d(deconv_input_channels,
                                        out_channels,
                                        kernel_size,
                                        stride=stride,
                                        padding=padding)
            hidden_init(deconv.weight)
            deconv.bias.data.fill_(0)

            self.deconv_layers.append(deconv)
            deconv_input_channels = out_channels

        # Find the output size of the deconv stack by running a dummy forward
        # pass (NCHW convention, matching the reshape in forward()), and
        # create one batch-norm layer per deconv layer along the way.
        test_mat = torch.zeros(1, self.deconv_input_channels,
                               self.deconv_input_height,
                               self.deconv_input_width)
        # Initially the model is on CPU (the caller should then move it to
        # GPU if needed).
        for deconv_layer in self.deconv_layers:
            test_mat = deconv_layer(test_mat)
            self.deconv_norm_layers.append(nn.BatchNorm2d(test_mat.shape[1]))

        self.first_deconv_output = nn.ConvTranspose2d(
            deconv_input_channels,
            deconv_output_channels,
            deconv_output_kernel_size,
            stride=deconv_output_strides,
        )
        hidden_init(self.first_deconv_output.weight)
        self.first_deconv_output.bias.data.fill_(0)

        self.second_deconv_output = nn.ConvTranspose2d(
            deconv_input_channels,
            deconv_output_channels,
            deconv_output_kernel_size,
            stride=deconv_output_strides,
        )
        hidden_init(self.second_deconv_output.weight)
        self.second_deconv_output.bias.data.fill_(0)

    def forward(self, input):
        h = self.apply_forward(input, self.fc_layers, self.fc_norm_layers,
                               use_batch_norm=self.batch_norm_fc)
        h = self.hidden_activation(self.last_fc(h))
        h = h.view(-1,
                   self.deconv_input_channels,
                   self.deconv_input_height,
                   self.deconv_input_width)
        h = self.apply_forward(h, self.deconv_layers, self.deconv_norm_layers,
                               use_batch_norm=self.batch_norm_deconv)
        first_output = self.output_activation(self.first_deconv_output(h))
        second_output = self.output_activation(self.second_deconv_output(h))
        return first_output, second_output

    def apply_forward(self, input, hidden_layers, norm_layers,
                      use_batch_norm=False):
        h = input
        for layer, norm_layer in zip(hidden_layers, norm_layers):
            h = layer(h)
            if use_batch_norm:
                h = norm_layer(h)
            h = self.hidden_activation(h)
        return h
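# A minimal usage sketch for TwoHeadDCNN (illustrative only; the helper name
# and all hyperparameters are assumptions). A 16-d input is mapped through
# the fc stack to a 3x3x64 spatial block, upsampled once by the deconv
# stack, then decoded by two output heads of identical shape (e.g. the mean
# and log-variance heads of a decoder).
def _example_two_head_dcnn_usage():
    decoder = TwoHeadDCNN(
        fc_input_size=16,
        hidden_sizes=[128],
        deconv_input_width=3,
        deconv_input_height=3,
        deconv_input_channels=64,
        deconv_output_kernel_size=6,
        deconv_output_strides=3,
        deconv_output_channels=3,
        kernel_sizes=[3],
        n_channels=[32],
        strides=[2],
        paddings=[0],
    )
    latent = torch.randn(8, 16)
    first, second = decoder(latent)
    # ConvTranspose2d sizes: (3 - 1) * 2 + 3 = 7 after the deconv stack,
    # then (7 - 1) * 3 + 6 = 24 after each output head.
    assert first.shape == second.shape == (8, 3, 24, 24)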
class DCNN(TwoHeadDCNN):
    """Single-head variant: keeps only the first decoder output."""

    def forward(self, x):
        return super().forward(x)[0]
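# A minimal usage sketch for DCNN (illustrative only; it reuses the assumed
# decoder settings from the TwoHeadDCNN sketch above). DCNN takes the same
# constructor arguments but returns a single tensor instead of a pair.
def _example_dcnn_usage():
    decoder = DCNN(
        fc_input_size=16,
        hidden_sizes=[128],
        deconv_input_width=3,
        deconv_input_height=3,
        deconv_input_channels=64,
        deconv_output_kernel_size=6,
        deconv_output_strides=3,
        deconv_output_channels=3,
        kernel_sizes=[3],
        n_channels=[32],
        strides=[2],
        paddings=[0],
    )
    out = decoder(torch.randn(8, 16))
    assert out.shape == (8, 3, 24, 24)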