# TripletGeoEncoder-demo / model_module.py
import pytorch_lightning as pl
import torch
from torch import nn
from torch.optim import SGD
from torchvision.utils import save_image
import os
from utils import TripletLossBatch, pairwise_distance_squared, GetTransformedCoords, DistanceMapLogger
from model import Autoencoder
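# The helpers imported from utils above are not defined in this file. The
# stand-ins below (note the _sketch suffix) are minimal, assumption-based
# sketches of the interfaces this module relies on; the real implementations
# in utils.py may differ.

def _pairwise_distance_squared_sketch(a, b):
    # Squared Euclidean distance along the last dimension, broadcasting as needed.
    return ((a - b) ** 2).sum(dim=-1)

class _TripletLossBatchSketch(nn.Module):
    # Batched triplet margin loss:
    # max(0, d(anchor, positive) - d(anchor, negative) + margin), averaged.
    def forward(self, anchor, positive, negative, margin):
        d_pos = ((anchor - positive) ** 2).sum(dim=-1)
        d_neg = ((anchor - negative) ** 2).sum(dim=-1)
        return torch.clamp(d_pos - d_neg + margin, min=0).mean()

class _GetTransformedCoordsSketch:
    # Applies a batch of 3x3 matrices to 2D pixel coordinates about `center`
    # and returns integer coordinates, matching how GetTransformedCoords is
    # invoked below: GetTransformedCoords(mat, center)(coords).
    def __init__(self, mat, center):
        self.mat = mat  # (N, 3, 3)
        self.center = center  # [cx, cy]

    def __call__(self, coords):
        center = torch.tensor(self.center, dtype=coords.dtype, device=coords.device)
        ones = torch.ones(coords.shape[0], 1, dtype=coords.dtype, device=coords.device)
        homo = torch.cat([coords - center, ones], dim=1).unsqueeze(2)  # (N, 3, 1)
        out = torch.bmm(self.mat, homo).squeeze(2)  # (N, 3)
        return (out[:, :2] / out[:, 2:3] + center).long()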
class AutoencoderModule(pl.LightningModule):
    def __init__(self, feature_dim=64, learning_rate=0.1, lambda_c=0.97, initial_margin=1.0,
                 initial_threshold=2.0, save_interval=100, output_dir="output_images"):
        super().__init__()
        self.feature_dim = feature_dim
        self.learning_rate = learning_rate
        self.lambda_c = lambda_c
        self.margin_img = initial_margin
        self.margin_img_init = initial_margin
        self.threshold = initial_threshold
        self.model = Autoencoder(self.feature_dim)
        self.criterion = nn.MSELoss()
        self.triplet_loss = TripletLossBatch()
        self.losses = []
        self.save_interval = save_interval  # interval (in batches) between saved outputs
        self.output_dir = output_dir
        os.makedirs(self.output_dir, exist_ok=True)
    def forward(self, x):
        return self.model(x)
    def training_step(self, batch, batch_idx):
        # Each batch element is a pair of views plus their 3x3 transform matrices.
        img, mat, _ = batch
        batch_size, _, _, size, _ = img.shape  # assumes square images
        img = img.view(batch_size * 2, 3, size, size)
        mat = mat.view(batch_size * 2, 3, 3)
        dec5_output, output = self.model(img)
        mse_loss = self.criterion(output, img)

        # Intra-image triplet loss: sample anchor pixels plus two perturbed
        # candidates at two noise scales; the scales shrink as epochs advance.
        num_anchor_sets = 2**12
        trip_loss = 0
        std_list = [2.5 * 1.025**self.current_epoch, 5 * 1.025**self.current_epoch]
        for c in std_list:
            std = size / c
            anchors = torch.randint(0, size, (batch_size * 2, num_anchor_sets, 1, 2))
            coords = anchors + torch.normal(0, std, (batch_size * 2, num_anchor_sets, 2, 2)).long()
            # Zero out anchor sets where either candidate falls outside the image.
            invalid_coords_idx = (((coords >= 0) & (coords < size)).sum(3) == 2).sum(2) != 2
            coords[invalid_coords_idx] = 0
            anchors[invalid_coords_idx] = 0
            # The candidate closest to the anchor becomes the positive; the other, the negative.
            d = pairwise_distance_squared(anchors.float(), coords.float())
            idx = torch.argmin(d, dim=2)
            anchors, positives, negatives = self._get_triplet_coordinates(anchors, coords, idx)
            # Extract feature vectors at those coordinates from dec5_output.
            anchor_vectors, positive_vectors, negative_vectors = self._extract_feature_vectors(
                dec5_output, batch_size, anchors, positives, negatives)
            trip_loss += self.triplet_loss(anchor_vectors, positive_vectors, negative_vectors, self.margin_img)
        trip_loss /= len(std_list)
        # Adaptive margin: it grows while the triplet loss stays below the initial margin.
        self.margin_img = self.margin_img_init + self.margin_img - trip_loss.detach()

        # Transformation-consistency loss.
        num_samples = 2**20
        tf_loss = self._compute_transformation_loss(dec5_output, mat, batch_size, size, num_samples)

        # Batch-direction loss.
        bat_dist_loss = self._compute_batch_direction_loss(dec5_output, batch_size, size)

        # Total loss (the base of 1. keeps the tf_loss coefficient at a constant 0.001).
        loss = mse_loss + trip_loss + 0.001 * bat_dist_loss + (0.001 * 1.**self.current_epoch) * tf_loss
        self.log("train_loss", loss)

        # VRAM management: release intermediates before the next step.
        del img, output
        torch.cuda.empty_cache()
        return loss
    def _get_triplet_coordinates(self, anchors, coords, idx):
        # idx selects, per anchor set, which of the two candidates is closest.
        anchors = anchors.squeeze(2)
        b_idx = torch.arange(coords.size(0))[:, None, None]
        s_idx = torch.arange(coords.size(1))[None, :, None]
        positives = coords[b_idx, s_idx, idx[:, :, None]].squeeze(2)
        negatives = coords[b_idx, s_idx, (1 - idx)[:, :, None]].squeeze(2)
        return anchors, positives, negatives
    def _extract_feature_vectors(self, dec5_output, batch_size, anchors, positives, negatives):
        # Expand each (y, x) coordinate so it indexes all feature_dim channels at once.
        y_anchors = anchors[:, :, 0].unsqueeze(2).expand(-1, -1, self.feature_dim)
        x_anchors = anchors[:, :, 1].unsqueeze(2).expand(-1, -1, self.feature_dim)
        y_positives = positives[:, :, 0].unsqueeze(2).expand(-1, -1, self.feature_dim)
        x_positives = positives[:, :, 1].unsqueeze(2).expand(-1, -1, self.feature_dim)
        y_negatives = negatives[:, :, 0].unsqueeze(2).expand(-1, -1, self.feature_dim)
        x_negatives = negatives[:, :, 1].unsqueeze(2).expand(-1, -1, self.feature_dim)
        # Advanced indexing pulls one feature_dim-vector per sampled pixel;
        # each result has shape (batch_size*2, num_anchor_sets, feature_dim).
        batch_idx = torch.arange(batch_size * 2)[:, None, None]
        channel_idx = torch.arange(self.feature_dim)
        anchor_vectors = dec5_output[batch_idx, channel_idx, y_anchors, x_anchors]
        positive_vectors = dec5_output[batch_idx, channel_idx, y_positives, x_positives]
        negative_vectors = dec5_output[batch_idx, channel_idx, y_negatives, x_negatives]
        return anchor_vectors, positive_vectors, negative_vectors
    def _compute_transformation_loss(self, dec5_output, mat, batch_size, size, num_samples=2**12):
        # Transformation-consistency loss: map the same source pixel through each
        # view's transform and penalize the squared feature distance between the
        # two views of a pair.
        pair_indices = torch.randint(batch_size, (num_samples, 1), device=self.device)
        # Address both views of each sampled pair (indices 2k and 2k+1).
        anchor_indices = (pair_indices * 2 + torch.arange(2, device=self.device)).reshape(num_samples * 2)
        coords_x = torch.randint(0, size, (num_samples, 1), dtype=torch.float32, device=self.device).repeat(1, 2).reshape(num_samples * 2, 1)
        coords_y = torch.randint(0, size, (num_samples, 1), dtype=torch.float32, device=self.device).repeat(1, 2).reshape(num_samples * 2, 1)
        anchor_coords = torch.cat((coords_x, coords_y), 1)
        anchor_mat = mat[anchor_indices]
        tf_anchor_coords = GetTransformedCoords(anchor_mat, [size / 2, size / 2])(anchor_coords)
        anchor_vectors = torch.zeros([num_samples * 2, self.feature_dim], device=self.device)
        # Keep only transformed coordinates that land inside the image.
        inner_idx_flat = ((0 <= tf_anchor_coords[:, 0]) & (tf_anchor_coords[:, 0] < size)) & ((0 <= tf_anchor_coords[:, 1]) & (tf_anchor_coords[:, 1] < size))
        anchor_vectors[inner_idx_flat] = dec5_output[anchor_indices[inner_idx_flat], :, tf_anchor_coords[inner_idx_flat, 0], tf_anchor_coords[inner_idx_flat, 1]]
        # A pair contributes only when both of its transformed points are in bounds.
        inner_idx_and = inner_idx_flat.view(num_samples, 2).all(dim=1)
        anchor_vectors = anchor_vectors.view(num_samples, 2, self.feature_dim)[inner_idx_and]
        return pairwise_distance_squared(anchor_vectors[:, 0], anchor_vectors[:, 1]).mean()
    def _compute_batch_direction_loss(self, dec5_output, batch_size, size):
        # Batch-direction loss: average the squared feature distances between
        # pixels of different image pairs, counting only distances that fall
        # below self.threshold.
        N = 2**12
        anchor_pairs = torch.randint(0, batch_size, (N,))
        anchor_indices = anchor_pairs * 2 + torch.randint(0, 2, (N,))
        anchor_coords = torch.randint(0, size, (N, 2))
        # Sample comparison pixels from pairs other than the anchor's: draw a
        # pair index in [0, batch_size - 1) and shift it past the anchor's pair.
        other_pairs = torch.randint(0, batch_size - 1, (N, 2))
        other_pairs += (other_pairs >= anchor_pairs.unsqueeze(1)).long()
        other_indices = other_pairs * 2 + torch.randint(0, 2, (N, 2))
        other_coords = torch.randint(0, size, (N, 2, 2))
        anchor_vectors = dec5_output[anchor_indices, :, anchor_coords[:, 0], anchor_coords[:, 1]]
        other_vectors = dec5_output[other_indices, :, other_coords[:, :, 0], other_coords[:, :, 1]]
        distances = pairwise_distance_squared(anchor_vectors.unsqueeze(1), other_vectors)
        # Mean over sub-threshold distances; the epsilon avoids division by zero.
        return distances[distances < self.threshold].sum() / ((distances < self.threshold).sum() + 1e-10)
    def configure_optimizers(self):
        optimizer = SGD(self.parameters(), lr=self.learning_rate)
        # Exponential decay: the learning rate is scaled by lambda_c**epoch.
        scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda epoch: self.lambda_c**epoch)
        return [optimizer], [scheduler]
    # def save_intermediate_image(self, output, epoch):
    #     save_image(output[:4], os.path.join(self.output_dir, f"epoch_{epoch}_output.png"), nrow=1)
    #     print(f"Saved intermediate image at epoch {epoch}")

    # def distance_map(self, _input, feature_map, epoch, x_coords=None, y_coords=None):
    #     save_path = os.path.join(self.output_dir, f"epoch_{epoch}_distance_map.png")
    #     DistanceMapLogger()(_input, feature_map, save_path, x_coords, y_coords)
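if __name__ == "__main__":
    # Minimal smoke-test sketch with synthetic data. Assumptions: the real
    # dataset yields (img, mat, label) with img of shape (B, 2, 3, size, size)
    # (two views per sample) and mat of shape (B, 2, 3, 3); every name below
    # is illustrative, not part of this repo.
    from torch.utils.data import DataLoader, TensorDataset

    B, size = 2, 64
    imgs = torch.rand(B, 2, 3, size, size)
    mats = torch.eye(3).repeat(B, 2, 1, 1)  # identity transforms per view
    labels = torch.zeros(B)
    loader = DataLoader(TensorDataset(imgs, mats, labels), batch_size=B)

    module = AutoencoderModule(feature_dim=64)
    trainer = pl.Trainer(max_epochs=1, accelerator="cpu", logger=False)
    trainer.fit(module, train_dataloaders=loader)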