import os

import pytorch_lightning as pl
import torch
from torch import nn
from torch.optim import SGD
from torchvision.utils import save_image

from utils import TripletLossBatch, pairwise_distance_squared, GetTransformedCoords, DistanceMapLogger
from model import Autoencoder


class AutoencoderModule(pl.LightningModule):
    def __init__(self, feature_dim=64, learning_rate=0.1, lambda_c=0.97,
                 initial_margin=1.0, initial_threshold=2.0,
                 save_interval=100, output_dir="output_images"):
        super().__init__()
        self.feature_dim = feature_dim
        self.learning_rate = learning_rate
        self.lambda_c = lambda_c
        self.margin_img = initial_margin
        self.margin_img_init = initial_margin
        self.threshold = initial_threshold
        self.model = Autoencoder(self.feature_dim)
        self.criterion = nn.MSELoss()
        self.triplet_loss = TripletLossBatch()
        self.losses = []
        self.save_interval = save_interval  # interval (in batches) between image dumps
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        # Each sample holds two augmented views, so the effective batch is 2x.
        img, mat, _ = batch
        batch_size, _, _, size, _ = img.shape  # (B, 2, 3, H, W); images assumed square
        img = img.view(batch_size * 2, 3, size, size)
        mat = mat.view(batch_size * 2, 3, 3)

        dec5_output, output = self.model(img)
        mse_loss = self.criterion(output, img)

        # Intra-image direction: triplet loss over coordinate triplets sampled
        # at two scales; the sampling std shrinks as training progresses.
        num_anchor_sets = 2 ** 12
        trip_loss = 0
        std_list = [2.5 * 1.025 ** self.current_epoch, 5 * 1.025 ** self.current_epoch]
        for c in std_list:
            std = size / c
            anchors = torch.randint(0, size, (batch_size * 2, num_anchor_sets, 1, 2), device=img.device)
            coords = anchors + torch.normal(0.0, std, (batch_size * 2, num_anchor_sets, 2, 2), device=img.device).long()
            # Zero out anchor sets where either sampled point falls outside the image.
            invalid_idx = (((coords >= 0) & (coords < size)).sum(3) == 2).sum(2) != 2
            coords[invalid_idx] = 0
            anchors[invalid_idx] = 0

            # The point closer to the anchor is the positive, the other the negative.
            d = pairwise_distance_squared(anchors.float(), coords.float())
            idx = torch.argmin(d, dim=2)
            anchors, positives, negatives = self._get_triplet_coordinates(anchors, coords, idx)

            # Extract feature vectors at the triplet coordinates from dec5_output.
            anchor_vectors, positive_vectors, negative_vectors = self._extract_feature_vectors(
                dec5_output, batch_size, anchors, positives, negatives)
            trip_loss += self.triplet_loss(anchor_vectors, positive_vectors, negative_vectors, self.margin_img)
        trip_loss /= len(std_list)
        # Adaptive margin: it grows while the triplet loss stays below its initial
        # value and shrinks otherwise (margin_new - margin_old = init - loss).
        self.margin_img = self.margin_img_init + self.margin_img - trip_loss.detach()

        # Transformation consistency across the two views of each image.
        num_samples = 2 ** 20
        tf_loss = self._compute_transformation_loss(dec5_output, mat, batch_size, size, num_samples)

        # Batch direction: repel features belonging to different images.
        bat_dist_loss = self._compute_batch_direction_loss(dec5_output, batch_size, size)

        # Total loss; the 1.0 decay base leaves the tf_loss weight constant at 0.001.
        loss = mse_loss + trip_loss + 0.001 * bat_dist_loss + (0.001 * 1.0 ** self.current_epoch) * tf_loss
        self.log("train_loss", loss)

        # VRAM management: free the largest intermediates early.
        del img, output
        torch.cuda.empty_cache()
        return loss

    def _get_triplet_coordinates(self, anchors, coords, idx):
        anchors = anchors.squeeze(2)
        batch_idx = torch.arange(coords.size(0), device=coords.device)[:, None, None]
        set_idx = torch.arange(coords.size(1), device=coords.device)[None, :, None]
        positives = coords[batch_idx, set_idx, idx[:, :, None]].squeeze(2)
        negatives = coords[batch_idx, set_idx, (1 - idx)[:, :, None]].squeeze(2)
        return anchors, positives, negatives
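    # Note on the advanced indexing in _extract_feature_vectors below:
    # dec5_output has shape (B*2, C, H, W) with C == feature_dim. The index
    # arrays broadcast as (B*2, 1, 1), (C,), (B*2, N, C), (B*2, N, C)
    # -> (B*2, N, C), so element [b, n, c] reads
    # dec5_output[b, c, y[b, n, c], x[b, n, c]]. Because y and x are constant
    # along the channel axis, each [b, n] row is the full C-dimensional
    # feature vector at one sampled pixel.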
    def _extract_feature_vectors(self, dec5_output, batch_size, anchors, positives, negatives):
        # Broadcast each (y, x) coordinate across the channel axis.
        y_anchors = anchors[:, :, 0].unsqueeze(2).expand(-1, -1, self.feature_dim)
        x_anchors = anchors[:, :, 1].unsqueeze(2).expand(-1, -1, self.feature_dim)
        y_positives = positives[:, :, 0].unsqueeze(2).expand(-1, -1, self.feature_dim)
        x_positives = positives[:, :, 1].unsqueeze(2).expand(-1, -1, self.feature_dim)
        y_negatives = negatives[:, :, 0].unsqueeze(2).expand(-1, -1, self.feature_dim)
        x_negatives = negatives[:, :, 1].unsqueeze(2).expand(-1, -1, self.feature_dim)
        batch_idx = torch.arange(batch_size * 2, device=dec5_output.device)[:, None, None]
        channel_idx = torch.arange(self.feature_dim, device=dec5_output.device)
        anchor_vectors = dec5_output[batch_idx, channel_idx, y_anchors, x_anchors]
        positive_vectors = dec5_output[batch_idx, channel_idx, y_positives, x_positives]
        negative_vectors = dec5_output[batch_idx, channel_idx, y_negatives, x_negatives]
        return anchor_vectors, positive_vectors, negative_vectors

    def _compute_transformation_loss(self, dec5_output, mat, batch_size, size, num_samples=2 ** 12):
        # Pair convention as in _compute_batch_direction_loss: indices 2i and 2i+1
        # are the two augmented views of image i. Sampling (2i, 2i) would compare a
        # view with itself and make the loss vanish, so each pair is (2i, 2i+1).
        pair_base = torch.randint(batch_size, (num_samples, 1), device=self.device) * 2
        anchor_indices = (pair_base + torch.arange(2, device=self.device)).reshape(num_samples * 2)
        # Both views of a pair share the same canonical coordinates.
        coords_x = torch.randint(0, size, (num_samples, 1), dtype=torch.float32, device=self.device).repeat(1, 2).reshape(num_samples * 2, 1)
        coords_y = torch.randint(0, size, (num_samples, 1), dtype=torch.float32, device=self.device).repeat(1, 2).reshape(num_samples * 2, 1)
        anchor_coords = torch.cat((coords_x, coords_y), 1)
        anchor_mat = mat[anchor_indices]
        tf_anchor_coords = GetTransformedCoords(anchor_mat, [size / 2, size / 2])(anchor_coords)

        # Gather features only where the transformed coordinates stay inside the image.
        anchor_vectors = torch.zeros([num_samples * 2, self.feature_dim], device=self.device)
        inner_idx_flat = ((0 <= tf_anchor_coords[:, 0]) & (tf_anchor_coords[:, 0] < size)) \
            & ((0 <= tf_anchor_coords[:, 1]) & (tf_anchor_coords[:, 1] < size))
        anchor_vectors[inner_idx_flat] = dec5_output[
            anchor_indices[inner_idx_flat], :,
            tf_anchor_coords[inner_idx_flat, 0], tf_anchor_coords[inner_idx_flat, 1]]
        # Keep only the pairs where both views landed inside the image.
        inner_idx_and = inner_idx_flat.view(num_samples, 2).t()[0] & inner_idx_flat.view(num_samples, 2).t()[1]
        anchor_vectors = anchor_vectors.view(num_samples, 2, self.feature_dim)[inner_idx_and]
        # Features at corresponding points of the two views should agree.
        return pairwise_distance_squared(anchor_vectors[:, 0], anchor_vectors[:, 1]).mean()

    def _compute_batch_direction_loss(self, dec5_output, batch_size, size):
        N = 2 ** 12
        device = dec5_output.device
        # Anchor: a random view (2i or 2i+1) of a random image.
        anchor_indices = torch.randint(0, batch_size, (N,), device=device) * 2 + torch.randint(0, 2, (N,), device=device)
        anchor_coords = torch.randint(0, size, (N, 2), device=device)
        # Others: random views of images other than the anchor's image
        # (sample from batch_size - 1 images, then shift past the anchor's image
        # so both of its views are excluded).
        other_images = torch.randint(0, batch_size - 1, (N, 2), device=device)
        other_images += (other_images >= (anchor_indices // 2).unsqueeze(1)).long()
        other_indices = other_images * 2 + torch.randint(0, 2, (N, 2), device=device)
        other_coords = torch.randint(0, size, (N, 2, 2), device=device)
        anchor_vectors = dec5_output[anchor_indices, :, anchor_coords[:, 0], anchor_coords[:, 1]]
        other_vectors = dec5_output[other_indices, :, other_coords[:, :, 0], other_coords[:, :, 1]]
        # Mean of squared distances below the threshold: nearby cross-image pairs
        # are penalized, already-separated pairs contribute nothing.
        distances = pairwise_distance_squared(anchor_vectors.unsqueeze(1), other_vectors)
        below = distances < self.threshold
        return distances[below].sum() / (below.sum() + 1e-10)

    def configure_optimizers(self):
        optimizer = SGD(self.parameters(), lr=self.learning_rate)
        # Exponential decay: lr at epoch e is learning_rate * lambda_c**e.
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: self.lambda_c ** epoch)
        return [optimizer], [scheduler]

    # def save_intermediate_image(self, output, epoch):
    #     save_image(output[:4], os.path.join(self.output_dir, f"epoch_{epoch}_output.png"), nrow=1)
    #     print(f"Saved intermediate image at epoch {epoch}")

    # def distance_map(self, _input, feature_map, epoch, x_coords=None, y_coords=None):
    #     save_path = os.path.join(self.output_dir, f"epoch_{epoch}_distance_map.png")
    #     DistanceMapLogger()(_input, feature_map, save_path, x_coords, y_coords)
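
# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the original training
# script). It assumes only what training_step assumes: batches of
# (img, mat, label) with img shaped (B, 2, 3, size, size) and mat shaped
# (B, 2, 3, 3). The random TensorDataset below stands in for a real paired
# dataset; identity matrices stand in for real augmentation transforms.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader, TensorDataset

    size, num_pairs = 64, 32
    imgs = torch.rand(num_pairs, 2, 3, size, size)          # paired augmented views
    mats = torch.eye(3).expand(num_pairs, 2, 3, 3).clone()  # placeholder transforms
    labels = torch.zeros(num_pairs, dtype=torch.long)       # unused by training_step
    loader = DataLoader(TensorDataset(imgs, mats, labels), batch_size=4)

    module = AutoencoderModule(feature_dim=64, learning_rate=0.1)
    trainer = pl.Trainer(max_epochs=1, accelerator="auto", devices=1, logger=False)
    trainer.fit(module, loader)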