import copy
import os
from typing import Any

import numpy as np
import pandas as pd
import pytorch_lightning as L
import torch
import torch.nn as nn
from hydra.utils import instantiate


class Geolocalizer(L.LightningModule):
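    """Lightning module for geolocalization.

    The network, optional text encoder, loss, metrics, optimizer and scheduler
    are all instantiated from the Hydra config passed at construction.
    """
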
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.model = instantiate(cfg.network.instance)
        if cfg.text_tuning:
            # Separate text encoder, only built when text tuning is enabled.
            self.text_model = instantiate(cfg.text_network.instance)
        self.loss = instantiate(cfg.loss)
        self.val_metrics = instantiate(cfg.val_metrics)
        self.test_metrics = instantiate(cfg.test_metrics)
        self.text_tuning = cfg.text_tuning

    def training_step(self, batch, batch_idx):
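        """Forward the batch, compute the loss dict and log every term."""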
        pred = self.model(batch)
        if self.text_tuning:
            pred["text_features"] = self.text_model(batch)
        # self.loss returns a dict of loss terms; Lightning uses its "loss" key
        # for the backward pass.
        loss = self.loss(pred, batch, average=True)
        for metric_name, metric_value in loss.items():
            self.log(
                f"train/{metric_name}",
                metric_value,
                sync_dist=True,
                on_step=True,
                on_epoch=True,
            )
        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
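        """Run a validation forward pass, update the metrics and log the loss."""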
        pred = self.model(batch)
        if self.text_tuning:
            pred["text_features"] = self.text_model(batch)
        loss = self.loss(pred, batch, average=True)["loss"]
        self.val_metrics.update(pred, batch)
        self.log("val/loss", loss, sync_dist=True, on_step=False, on_epoch=True)

    def on_validation_epoch_end(self):
        metrics = self.val_metrics.compute()
        for metric_name, metric_value in metrics.items():
            self.log(
                f"val/{metric_name}",
                metric_value,
                sync_dist=True,
                on_step=False,
                on_epoch=True,
            )

    @torch.no_grad()
    def test_step(self, batch, batch_idx):
        pred = self.model(batch)
        self.test_metrics.update(pred, batch)

    def on_test_epoch_end(self):
        metrics = self.test_metrics.compute()
        for metric_name, metric_value in metrics.items():
            self.log(
                f"test/{metric_name}",
                metric_value,
                sync_dist=True,
                on_step=False,
                on_epoch=True,
            )

    def configure_optimizers(self):
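        """Build the optimizer and LR scheduler from the config.

        Parameters are split into LoRA, backbone, last-block and remaining
        groups so that each group can receive its own learning rate, and
        LayerNorm/bias parameters can optionally be excluded from weight decay.
        """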
        lora_params = []
        backbone_params = []
        other_params = []
        last_block_params = []
        for name, param in self.model.named_parameters():
            if "lora" in name:
                lora_params.append(param)
            elif "backbone" in name:
                # Optionally give the last backbone block (".11.") its own group.
                if self.cfg.optimizer.diff_backbone_last and ".11." in name:
                    last_block_params.append(param)
                else:
                    backbone_params.append(param)
            else:
                other_params.append(param)

        params_to_optimize = [{"params": other_params}]
        if self.cfg.optimizer.unfreeze_lr:
            params_to_optimize += [
                {"params": backbone_params, "lr": self.cfg.optimizer.backbone_lr}
            ]
            if self.cfg.optimizer.diff_backbone_last:
                params_to_optimize += [
                    {
                        "params": last_block_params,
                        "lr": self.cfg.optimizer.last_block_lr,
                    }
                ]
        if len(lora_params) > 0:
            params_to_optimize += [
                {"params": lora_params, "lr": self.cfg.optimizer.lora_lr}
            ]

        if self.cfg.optimizer.exclude_ln_and_biases_from_weight_decay:
            # Apply weight decay to all parameters except LayerNorm weights and biases.
            parameters_names_wd = get_parameter_names(self.model, [nn.LayerNorm])
            parameters_names_wd = [
                name for name in parameters_names_wd if "bias" not in name
            ]
            optimizer_grouped_parameters = [
                {
                    "params": [
                        p
                        for n, p in self.model.named_parameters()
                        if n in parameters_names_wd
                    ],
                    "weight_decay": self.cfg.optimizer.optim.weight_decay,
                },
                {
                    "params": [
                        p
                        for n, p in self.model.named_parameters()
                        if n not in parameters_names_wd
                    ],
                    "weight_decay": 0.0,
                },
            ]
            optimizer = instantiate(
                self.cfg.optimizer.optim, optimizer_grouped_parameters
            )
        else:
            optimizer = instantiate(self.cfg.optimizer.optim, params_to_optimize)
        scheduler = instantiate(self.cfg.lr_scheduler)(optimizer)
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

    def lr_scheduler_step(self, scheduler, metric):
        # The scheduler is stepped manually with the global step (interval="step").
        scheduler.step(self.global_step)


def get_parameter_names(model, forbidden_layer_types):
    """
    Returns the names of the model parameters that are not inside a forbidden layer.
    Taken from HuggingFace transformers.
    """
    result = []
    for name, child in model.named_children():
        result += [
            f"{name}.{n}"
            for n in get_parameter_names(child, forbidden_layer_types)
            if not isinstance(child, tuple(forbidden_layer_types))
        ]
    # Add the parameters defined directly on this module (not inside any child).
    result += list(model._parameters.keys())
    return result