import copy
from abc import ABC, abstractmethod

import numpy as np
import torch
from hydra.utils import instantiate
from peft import LoraConfig, get_peft_model
from torch import nn

from utils.model_utils import print_trainable_parameters
def freeze(model):
    """Put *model* into an inference-only state.

    Every parameter stops receiving gradients and the module is switched
    to eval mode (affects layers such as dropout and batch-norm).
    """
    model.requires_grad_(False)
    model.eval()
def unfreeze(model):
    """Make all parameters trainable except CLIP's final post-layernorm.

    The two parameter names listed below (the post-layernorm of a wrapped
    CLIP vision tower) stay frozen; every other parameter gets
    ``requires_grad = True``.  The model is switched to train mode.
    """
    keep_frozen = {
        "clip.vision_model.post_layernorm.weight",
        "clip.vision_model.post_layernorm.bias",
    }
    for name, param in model.named_parameters():
        param.requires_grad = name not in keep_frozen
    model.train()
def unfreeze_last(model):
    """Make only the last transformer block (index ``"11"``) trainable.

    A parameter is unfrozen iff its dotted name has more than five
    components and the fifth component equals ``"11"``; everything else is
    frozen.  The model is switched to train mode.
    """
    for name, param in model.named_parameters():
        parts = name.split(".")
        # NOTE(review): position 4 == "11" matches names shaped like
        # "x.y.encoder.layers.11.<...>" — confirm against the backbone used.
        param.requires_grad = len(parts) > 5 and parts[4] == "11"
    model.train()
class FrozenBackbone(nn.Module):
    """Feature extractor with a frozen backbone and trainable mid/head.

    The backbone is frozen at construction time (no gradients, eval mode);
    only ``mid`` and ``head`` receive gradient updates.
    """

    def __init__(self, backbone, mid, head):
        super().__init__()
        # The config wrappers carry the instantiated module in ``.instance``.
        self.backbone = backbone.instance
        self.mid = mid.instance
        self.head = head.instance
        # Key under which the head's training target lives in the batch.
        self.target_key = head.target_key
        freeze(self.backbone)

    def forward(self, x):
        """Run the frozen backbone (without gradients), then mid and head.

        x : Union[torch.Tensor, dict] input batch for the backbone.
        """
        with torch.no_grad():
            features = self.backbone(x)
        return self.head(self.mid(features))
class UnfrozenBackbone(nn.Module):
    """Network whose backbone trains end-to-end.

    ``unfreeze`` is applied to the backbone at construction time, so all of
    its parameters (except the CLIP post-layernorm handled there) receive
    gradients alongside ``mid`` and ``head``.
    """

    def __init__(self, backbone, mid, head):
        super().__init__()
        # The config wrappers carry the instantiated module in ``.instance``.
        self.backbone = backbone.instance
        self.mid = mid.instance
        self.head = head.instance
        # Key under which the head's training target lives in the batch.
        self.target_key = head.target_key
        unfreeze(self.backbone)

    def forward(self, x):
        """Run backbone, mid and head with gradients enabled.

        x : Union[torch.Tensor, dict] input batch for the backbone.
        """
        features = self.backbone(x)
        return self.head(self.mid(features))
class UnfrozenPartBackbone(nn.Module):
    """Network that fine-tunes only the last transformer block of the backbone.

    ``unfreeze_last`` keeps every backbone parameter frozen except those in
    layer ``"11"``; ``mid`` and ``head`` train normally.
    """

    def __init__(self, backbone, mid, head):
        super().__init__()
        # The config wrappers carry the instantiated module in ``.instance``.
        self.backbone = backbone.instance
        self.mid = mid.instance
        self.head = head.instance
        # Key under which the head's training target lives in the batch.
        self.target_key = head.target_key
        unfreeze_last(self.backbone)

    def forward(self, x):
        """Run backbone, mid and head with gradients enabled.

        x : Union[torch.Tensor, dict] input batch for the backbone.
        """
        features = self.backbone(x)
        return self.head(self.mid(features))
class NoFeatureBackbone(nn.Module):
    """Head-only model: the raw input is fed straight to the head (no backbone)."""

    def __init__(self, head):
        super().__init__()
        # The config wrapper carries the instantiated module in ``.instance``.
        self.head = head.instance
        # Key under which the head's training target lives in the batch.
        self.target_key = head.target_key

    def forward(self, x):
        """Apply the head directly to the raw input.

        x : Union[torch.Tensor, dict] raw input (no backbone features).
        """
        return self.head(x)
class ContrastiveFrozenBackbone(FrozenBackbone):
    """Frozen-backbone model that also returns contrastive feature pairs.

    Outside of ``eval`` mode the batch must carry a positive view of each
    input under ``"pos_"``-prefixed keys; the backbone embeds both views and
    their first-token embeddings are returned alongside the head output.
    """

    def __init__(self, backbone, mid, head, mode):
        super().__init__(backbone, mid, head)
        # Any mode other than "eval" enables the positive-pair branch.
        self.mode = mode

    def forward(self, x):
        """Return head outputs plus ``features`` (and ``pos_features`` in non-eval mode).

        x : dict batch; in non-eval mode every key to re-embed must have a
            ``pos_``-prefixed twin holding the positive view.
        """
        prefix = "pos_"
        with torch.no_grad():
            features = self.backbone(x)
            if self.mode != "eval":
                # BUGFIX: the original ``k.strip("pos_")`` strips the
                # characters 'p'/'o'/'s'/'_' from both ends (mangling keys
                # such as "pos_pose" -> "e"); slice off the exact prefix.
                x_pos = {
                    k[len(prefix):]: v.clone()
                    if isinstance(v, torch.Tensor)
                    else copy.deepcopy(v)
                    for k, v in x.items()
                    if k.startswith(prefix)
                }
                # Backbone is frozen, so no_grad here only saves memory.
                pos_features = self.backbone(x_pos)
        x = self.mid(features)
        x = self.head(x)
        # [:, 0, :] assumes (batch, tokens, dim) features with the CLS/first
        # token in front — TODO confirm for this backbone.
        if self.mode != "eval":
            return {
                "features": features[:, 0, :],
                "pos_features": pos_features[:, 0, :],
                **x,
            }
        return {
            "features": features[:, 0, :],
            **x,
        }
class ContrastiveUnFrozenPartBackbone(UnfrozenPartBackbone):
    """Partially-unfrozen backbone that also returns contrastive feature pairs.

    Only the last backbone block trains (see ``UnfrozenPartBackbone``).
    Outside of ``eval`` mode the batch must carry a positive view of each
    input under ``"pos_"``-prefixed keys; both views are embedded and their
    first-token embeddings returned alongside the head output.
    """

    def __init__(self, backbone, mid, head, mode):
        super().__init__(backbone, mid, head)
        # Any mode other than "eval" enables the positive-pair branch.
        self.mode = mode

    def forward(self, x):
        """Return head outputs plus ``features`` (and ``pos_features`` in non-eval mode).

        x : dict batch; in non-eval mode every key to re-embed must have a
            ``pos_``-prefixed twin holding the positive view.
        """
        features = self.backbone(x)
        if self.mode != "eval":
            prefix = "pos_"
            # BUGFIX: the original ``k.strip("pos_")`` strips the characters
            # 'p'/'o'/'s'/'_' from both ends (mangling keys such as
            # "pos_pose" -> "e"); slice off the exact prefix instead.
            x_pos = {
                k[len(prefix):]: v.clone()
                if isinstance(v, torch.Tensor)
                else copy.deepcopy(v)
                for k, v in x.items()
                if k.startswith(prefix)
            }
            pos_features = self.backbone(x_pos)
        x = self.mid(features)
        x = self.head(x)
        # [:, 0, :] assumes (batch, tokens, dim) features with the CLS/first
        # token in front — TODO confirm for this backbone.
        if self.mode != "eval":
            return {
                "features": features[:, 0, :],
                "pos_features": pos_features[:, 0, :],
                **x,
            }
        return {
            "features": features[:, 0, :],
            **x,
        }
class ContrastiveUnFrozenBackbone(UnfrozenBackbone):
    """Fully-trainable backbone that also returns contrastive feature pairs.

    Outside of ``eval`` mode the batch must carry a positive view of each
    input under ``"pos_"``-prefixed keys; both views are embedded and their
    first-token embeddings returned alongside the head output.
    """

    def __init__(self, backbone, mid, head, mode):
        super().__init__(backbone, mid, head)
        # Any mode other than "eval" enables the positive-pair branch.
        self.mode = mode

    def forward(self, x):
        """Return head outputs plus ``features`` (and ``pos_features`` in non-eval mode).

        x : dict batch; in non-eval mode every key to re-embed must have a
            ``pos_``-prefixed twin holding the positive view.
        """
        features = self.backbone(x)
        if self.mode != "eval":
            prefix = "pos_"
            # BUGFIX: the original ``k.strip("pos_")`` strips the characters
            # 'p'/'o'/'s'/'_' from both ends (mangling keys such as
            # "pos_pose" -> "e"); slice off the exact prefix instead.
            x_pos = {
                k[len(prefix):]: v.clone()
                if isinstance(v, torch.Tensor)
                else copy.deepcopy(v)
                for k, v in x.items()
                if k.startswith(prefix)
            }
            pos_features = self.backbone(x_pos)
        x = self.mid(features)
        x = self.head(x)
        # [:, 0, :] assumes (batch, tokens, dim) features with the CLS/first
        # token in front — TODO confirm for this backbone.
        if self.mode != "eval":
            return {
                "features": features[:, 0, :],
                "pos_features": pos_features[:, 0, :],
                **x,
            }
        return {
            "features": features[:, 0, :],
            **x,
        }
class TextContrastiveUnFrozenBackbone(UnfrozenBackbone):
    """Trainable backbone returning a contrastive embedding next to the head output.

    The backbone is expected to return a ``(contrastive, features)`` pair:
    the head consumes the features while the contrastive embedding is passed
    through unchanged under the ``"features"`` key.
    """

    def __init__(self, backbone, mid, head):
        super().__init__(backbone, mid, head)

    def forward(self, x):
        """Run the backbone and head; merge the contrastive embedding into the output.

        x : Union[torch.Tensor, dict] input batch for the backbone.
        """
        contrastive, features = self.backbone(x)
        head_out = self.head(self.mid(features))
        return {"features": contrastive, **head_out}
class LoraBackbone(nn.Module):
    """Backbone wrapped with LoRA adapters (PEFT); base weights stay frozen.

    The backbone is frozen first, then wrapped by ``get_peft_model`` so only
    the injected low-rank adapters on the attention projections train,
    together with ``mid`` and ``head``.
    """

    def __init__(self, backbone, mid, head, r, alpha, dropout, bias):
        super().__init__()
        # The config wrappers carry the instantiated module in ``.instance``.
        self.backbone = backbone.instance
        self.mid = mid.instance
        self.head = head.instance
        # Key under which the head's training target lives in the batch.
        self.target_key = head.target_key
        freeze(self.backbone)

        lora_cfg = LoraConfig(
            r=r,
            lora_alpha=alpha,
            lora_dropout=dropout,
            bias=bias,
            # Adapt only the attention query/key/value projections.
            target_modules=["q_proj", "k_proj", "v_proj"],
        )
        self.backbone = get_peft_model(self.backbone, lora_cfg)
        # Log how many parameters remain trainable after wrapping.
        print_trainable_parameters(self)

    def forward(self, x):
        """Run the LoRA-adapted backbone, then mid and head.

        x : Union[torch.Tensor, dict] input batch for the backbone.
        """
        features = self.backbone(x)
        return self.head(self.mid(features))
class HybridFrozenBackbone(FrozenBackbone):
    """Frozen-backbone variant whose head also receives the ground-truth label.

    While training, the batch's ``"label"`` entry is forwarded to the head
    (e.g. for margin-based heads); at eval time the head receives ``None``.
    """

    def forward(self, x):
        """Run the frozen backbone, then mid, then the label-aware head.

        x : dict batch; must contain ``"label"`` while training.
        """
        gt_label = x["label"] if self.training else None
        with torch.no_grad():
            features = self.backbone(x)
        return self.head(self.mid(features), gt_label)
class HybridUnfrozenBackbone(UnfrozenBackbone):
    """Trainable-backbone variant whose head also receives the ground-truth label.

    While training, the batch's ``"label"`` entry is forwarded to the head;
    at eval time the head receives ``None``.
    """

    def forward(self, x):
        """Run the backbone, then mid, then the label-aware head.

        x : dict batch; must contain ``"label"`` while training.
        """
        gt_label = x["label"] if self.training else None
        features = self.backbone(x)
        return self.head(self.mid(features), gt_label)
class ContrastiveHybridUnFrozenBackbone(UnfrozenBackbone):
    """Trainable backbone combining contrastive features with a label-aware head.

    Outside of ``eval`` mode the batch must carry positive views under
    ``"pos_"``-prefixed keys; both views are embedded and their first-token
    embeddings returned next to the head output.  While training, the head
    additionally receives the batch's ground-truth ``"label"``.
    """

    def __init__(self, backbone, mid, head, mode):
        super().__init__(backbone, mid, head)
        # Any mode other than "eval" enables the positive-pair branch.
        self.mode = mode

    def forward(self, x):
        """Return head outputs plus ``features`` (and ``pos_features`` in non-eval mode).

        x : dict batch; must contain ``"label"`` while training and, in
            non-eval mode, a ``pos_``-prefixed twin of every key to re-embed.
        """
        gt_label = x["label"] if self.training else None
        features = self.backbone(x)
        if self.mode != "eval":
            prefix = "pos_"
            # BUGFIX: the original ``k.strip("pos_")`` strips the characters
            # 'p'/'o'/'s'/'_' from both ends (mangling keys such as
            # "pos_pose" -> "e"); slice off the exact prefix instead.
            x_pos = {
                k[len(prefix):]: v.clone()
                if isinstance(v, torch.Tensor)
                else copy.deepcopy(v)
                for k, v in x.items()
                if k.startswith(prefix)
            }
            pos_features = self.backbone(x_pos)
        x = self.mid(features)
        x = self.head(x, gt_label)
        # [:, 0, :] assumes (batch, tokens, dim) features with the CLS/first
        # token in front — TODO confirm for this backbone.
        if self.mode != "eval":
            return {
                "features": features[:, 0, :],
                "pos_features": pos_features[:, 0, :],
                **x,
            }
        return {
            "features": features[:, 0, :],
            **x,
        }