import os

import pytorch_lightning as pl
import torch
from torch import nn
from torch.optim import SGD
from torchvision.utils import save_image

from model import Autoencoder
from utils import TripletLossBatch, pairwise_distance_squared, GetTransformedCoords, DistanceMapLogger
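
# Note on the utils helpers (a sketch of the assumed contract inferred from the call
# sites below, not the actual utils implementation): pairwise_distance_squared(a, b)
# is assumed to return the squared Euclidean distance over the last dimension while
# broadcasting the leading dimensions, i.e. roughly ((a - b) ** 2).sum(dim=-1).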


class AutoencoderModule(pl.LightningModule):
    def __init__(self, feature_dim=64, learning_rate=0.1, lambda_c=0.97,
                 initial_margin=1.0, initial_threshold=2.0,
                 save_interval=100, output_dir="output_images"):
        super().__init__()
        self.feature_dim = feature_dim
        self.learning_rate = learning_rate
        self.lambda_c = lambda_c
        self.margin_img = initial_margin
        self.margin_img_init = initial_margin
        self.threshold = initial_threshold
        self.model = Autoencoder(self.feature_dim)
        self.criterion = nn.MSELoss()
        self.triplet_loss = TripletLossBatch()
        self.losses = []
        self.save_interval = save_interval  # how often (in batches) to dump intermediate images
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        # Each sample is an augmented image pair; fold the pair dimension into the batch.
        img, mat, _ = batch
        batch_size, _, _, size, _ = img.shape
        img = img.view(batch_size * 2, 3, size, size)
        mat = mat.view(batch_size * 2, 3, 3)
        dec5_output, output = self.model(img)
        mse_loss = self.criterion(output, img)

        # Intra-image triplet loss: for each anchor pixel, sample two nearby pixels.
        num_anchor_sets = 2 ** 12
        trip_loss = 0
        std_list = [2.5 * 1.025 ** self.current_epoch, 5 * 1.025 ** self.current_epoch]
        for c in std_list:
            std = size / c
            anchors = torch.randint(0, size, (batch_size * 2, num_anchor_sets, 1, 2))
            coords = anchors + torch.normal(0, std, (batch_size * 2, num_anchor_sets, 2, 2)).long()
            # Zero out anchor sets where either sampled coordinate leaves the image.
            invalid_coords_idx = (((coords >= 0) & (coords < size)).sum(3) == 2).sum(2) != 2
            coords[invalid_coords_idx] = 0
            anchors[invalid_coords_idx] = 0
            # The coordinate nearer the anchor becomes the positive, the other the negative.
            d = pairwise_distance_squared(anchors.float(), coords.float())
            idx = torch.argmin(d, dim=2)
            anchors, positives, negatives = self._get_triplet_coordinates(anchors, coords, idx)
            # Extract the corresponding feature vectors from dec5_output.
            anchor_vectors, positive_vectors, negative_vectors = self._extract_feature_vectors(
                dec5_output, batch_size, anchors, positives, negatives)
            trip_loss += self.triplet_loss(anchor_vectors, positive_vectors, negative_vectors, self.margin_img)
        trip_loss /= len(std_list)
        # Adaptive margin: grows while the triplet loss stays below its initial value, shrinks otherwise.
        self.margin_img = self.margin_img_init + self.margin_img - trip_loss.detach()

        # Transformation-equivariance loss.
        num_samples = 2 ** 20
        tf_loss = self._compute_transformation_loss(dec5_output, mat, batch_size, size, num_samples)

        # Batch-direction loss: keep features from different images apart.
        bat_dist_loss = self._compute_batch_direction_loss(dec5_output, batch_size, size)

        # Total loss (the 1.**epoch factor is a hook for an epoch-dependent weight; currently constant).
        loss = mse_loss + trip_loss + 0.001 * bat_dist_loss + (0.001 * 1. ** self.current_epoch) * tf_loss
        self.log("train_loss", loss)

        # Release the large intermediates to keep VRAM usage in check.
        del img, output
        torch.cuda.empty_cache()
        return loss

    def _get_triplet_coordinates(self, anchors, coords, idx):
        # idx selects the coordinate nearest to the anchor (the positive); 1 - idx is the negative.
        anchors = anchors.squeeze(2)
        batch_idx = torch.arange(coords.size(0))[:, None, None]
        set_idx = torch.arange(coords.size(1))[None, :, None]
        positives = coords[batch_idx, set_idx, idx[:, :, None]].squeeze(2)
        negatives = coords[batch_idx, set_idx, (1 - idx)[:, :, None]].squeeze(2)
        return anchors, positives, negatives

    def _extract_feature_vectors(self, dec5_output, batch_size, anchors, positives, negatives):
        # Expand the (y, x) pixel indices so all feature channels are gathered at once.
        y_anchors = anchors[:, :, 0].unsqueeze(2).expand(-1, -1, self.feature_dim)
        x_anchors = anchors[:, :, 1].unsqueeze(2).expand(-1, -1, self.feature_dim)
        y_positives = positives[:, :, 0].unsqueeze(2).expand(-1, -1, self.feature_dim)
        x_positives = positives[:, :, 1].unsqueeze(2).expand(-1, -1, self.feature_dim)
        y_negatives = negatives[:, :, 0].unsqueeze(2).expand(-1, -1, self.feature_dim)
        x_negatives = negatives[:, :, 1].unsqueeze(2).expand(-1, -1, self.feature_dim)
        batch_idx = torch.arange(batch_size * 2)[:, None, None]
        channel_idx = torch.arange(self.feature_dim)
        anchor_vectors = dec5_output[batch_idx, channel_idx, y_anchors, x_anchors]
        positive_vectors = dec5_output[batch_idx, channel_idx, y_positives, x_positives]
        negative_vectors = dec5_output[batch_idx, channel_idx, y_negatives, x_negatives]
        return anchor_vectors, positive_vectors, negative_vectors

    def _compute_transformation_loss(self, dec5_output, mat, batch_size, size, num_samples=2 ** 12):
        # Sample one source coordinate per image pair, map it through each view's affine
        # matrix, and penalise the feature distance between the two transformed locations.
        pair_indices = torch.randint(batch_size, (num_samples, 1), device=self.device)
        # Views 2i and 2i + 1 in the flattened batch come from the same source image i.
        anchor_indices = (pair_indices * 2 + torch.arange(2, device=self.device)).reshape(num_samples * 2)
        coords_x = torch.randint(0, size, (num_samples, 1), dtype=torch.float32, device=self.device).repeat(1, 2).reshape(num_samples * 2, 1)
        coords_y = torch.randint(0, size, (num_samples, 1), dtype=torch.float32, device=self.device).repeat(1, 2).reshape(num_samples * 2, 1)
        anchor_coords = torch.cat((coords_x, coords_y), 1)
        anchor_mat = mat[anchor_indices]
        tf_anchor_coords = GetTransformedCoords(anchor_mat, [size / 2, size / 2])(anchor_coords)
        anchor_vectors = torch.zeros([num_samples * 2, self.feature_dim], device=self.device)
        # Keep only samples whose transformed coordinates land inside the image.
        inner_idx_flat = ((0 <= tf_anchor_coords[:, 0]) & (tf_anchor_coords[:, 0] < size)) \
                         & ((0 <= tf_anchor_coords[:, 1]) & (tf_anchor_coords[:, 1] < size))
        anchor_vectors[inner_idx_flat] = dec5_output[anchor_indices[inner_idx_flat], :, tf_anchor_coords[inner_idx_flat, 0], tf_anchor_coords[inner_idx_flat, 1]]
        # A pair contributes only if both of its transformed coordinates are in bounds.
        inner_idx_and = inner_idx_flat.view(num_samples, 2).all(dim=1)
        anchor_vectors = anchor_vectors.view(num_samples, 2, self.feature_dim)[inner_idx_and]
        return pairwise_distance_squared(anchor_vectors[:, 0], anchor_vectors[:, 1]).mean()

    def _compute_batch_direction_loss(self, dec5_output, batch_size, size):
        # Compare random pixels of an anchor image against random pixels from two other images.
        N = 2 ** 12
        anchor_indices = torch.randint(0, batch_size, (N,)) * 2 + torch.randint(0, 2, (N,))
        anchor_coords = torch.randint(0, size, (N, 2))
        other_indices = torch.randint(0, batch_size - 1, (N, 2)) * 2 + torch.randint(0, 2, (N, 2))
        # Shift indices past the anchor's image pair so "other" samples come from different images.
        other_indices += (other_indices >= anchor_indices.unsqueeze(1)).long() * 2
        other_coords = torch.randint(0, size, (N, 2, 2))
        anchor_vectors = dec5_output[anchor_indices, :, anchor_coords[:, 0], anchor_coords[:, 1]]
        other_vectors = dec5_output[other_indices, :, other_coords[:, :, 0], other_coords[:, :, 1]]
        distances = pairwise_distance_squared(anchor_vectors.unsqueeze(1), other_vectors)
        # Average only the distances below the threshold; the epsilon avoids division by zero.
        below = distances < self.threshold
        return distances[below].sum() / (below.sum() + 1e-10)

    def configure_optimizers(self):
        optimizer = SGD(self.parameters(), lr=self.learning_rate)
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: self.lambda_c ** epoch)
        return [optimizer], [scheduler]

    # def save_intermediate_image(self, output, epoch):
    #     save_image(output[:4], os.path.join(self.output_dir, f"epoch_{epoch}_output.png"), nrow=1)
    #     print(f"Saved intermediate image at epoch {epoch}")

    # def distance_map(self, _input, feature_map, epoch, x_coords=None, y_coords=None):
    #     save_path = os.path.join(self.output_dir, f"epoch_{epoch}_distance_map.png")
    #     DistanceMapLogger()(_input, feature_map, save_path, x_coords, y_coords)
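
# Minimal usage sketch (an assumption-laden example, not part of the original module):
# a Dataset is assumed to yield (img, mat, label) tuples with img of shape
# (2, 3, size, size), an augmented image pair, and mat holding the two 3x3 affine
# matrices that produced the views. "PairedAugmentationDataset" is a hypothetical
# placeholder name.
#
# from torch.utils.data import DataLoader
# from dataset import PairedAugmentationDataset  # hypothetical
#
# train_loader = DataLoader(PairedAugmentationDataset("data/"), batch_size=8, shuffle=True)
# module = AutoencoderModule(feature_dim=64, learning_rate=0.1)
# trainer = pl.Trainer(max_epochs=100, accelerator="gpu", devices=1)
# trainer.fit(module, train_loader)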