# DeLoFi / models/networks/network.py
# Uploaded by: tronskel
# Commit: 71026d8 ("init", verified)
import torch
import numpy as np
from abc import ABC, abstractmethod
from torch import nn
from hydra.utils import instantiate
import copy
from peft import LoraConfig, get_peft_model
from utils.model_utils import print_trainable_parameters
def freeze(model):
    """Stop gradient tracking for every parameter of ``model`` and switch it to eval mode.

    Args:
        model: the ``nn.Module`` to freeze.
    """
    for param in model.parameters():
        param.requires_grad_(False)
    model.eval()
def unfreeze(
    model,
    keep_frozen=(
        "clip.vision_model.post_layernorm.weight",
        "clip.vision_model.post_layernorm.bias",
    ),
):
    """Make the parameters of ``model`` trainable and switch it to train mode.

    Every parameter becomes trainable except those whose fully-qualified name
    (as yielded by ``model.named_parameters()``) appears in ``keep_frozen``;
    those are explicitly set non-trainable. The default list is the
    post-layernorm of a ``clip.vision_model`` submodule. Presumably this is
    because that layer's output is not used by the backbone, so it would never
    receive gradients (TODO confirm against the backbone in use).

    Args:
        model: the ``nn.Module`` to unfreeze.
        keep_frozen: iterable of parameter names that must stay frozen.
    """
    # Build a set once so each membership test is O(1).
    frozen_names = set(keep_frozen)
    for name, param in model.named_parameters():
        param.requires_grad = name not in frozen_names
    model.train()
def unfreeze_last(model, layer="11", depth_index=4):
    """Make a single layer of ``model`` trainable and freeze everything else.

    A parameter is trainable only when both of these hold:

    * its dotted name (from ``model.named_parameters()``) has more than
      ``depth_index + 1`` components;
    * the component at 0-based position ``depth_index`` equals ``layer``.

    With the defaults this matches names such as
    ``clip.vision_model.encoder.layers.11.<...>``. That is presumably the last
    block of a 12-layer encoder (TODO confirm against the backbone in use).
    The model is switched to train mode afterwards.

    Args:
        model: the ``nn.Module`` whose parameters are (un)frozen.
        layer: identifier of the layer to unfreeze; passed through ``str`` so
            an int such as ``11`` also works.
        depth_index: position of the layer identifier inside the dotted name.
    """
    target = str(layer)
    for name, param in model.named_parameters():
        parts = name.split(".")  # split once instead of per condition
        param.requires_grad = (
            len(parts) > depth_index + 1 and parts[depth_index] == target
        )
    model.train()
class FrozenBackbone(nn.Module):
    """Backbone -> mid -> head network whose backbone is frozen.

    ``backbone``, ``mid`` and ``head`` are wrappers exposing the actual module
    through an ``.instance`` attribute. ``head`` also provides ``target_key``,
    which is re-exposed on this network.
    """

    def __init__(self, backbone, mid, head):
        super().__init__()
        self.backbone = backbone.instance
        self.mid = mid.instance
        self.head = head.instance
        self.target_key = head.target_key
        # Turns off requires_grad on the backbone and puts it in eval mode.
        freeze(self.backbone)

    def forward(self, x):
        """Run the backbone without autograd, then the trainable mid and head.

        Args:
            x: backbone input (a tensor or a dict, depending on the backbone).

        Returns:
            The head's output.
        """
        with torch.no_grad():
            embedded = self.backbone(x)
        return self.head(self.mid(embedded))
class UnfrozenBackbone(nn.Module):
    """Backbone -> mid -> head network trained end to end.

    ``unfreeze`` is applied to the backbone at construction time. This makes
    its parameters trainable (apart from any names ``unfreeze`` keeps frozen)
    and switches it to train mode.
    """

    def __init__(self, backbone, mid, head):
        super().__init__()
        self.backbone = backbone.instance
        self.mid = mid.instance
        self.head = head.instance
        self.target_key = head.target_key
        unfreeze(self.backbone)

    def forward(self, x):
        """Pipe ``x`` through backbone, mid and head (no ``no_grad`` guard).

        Args:
            x: backbone input (a tensor or a dict, depending on the backbone).

        Returns:
            The head's output.
        """
        for stage in (self.backbone, self.mid, self.head):
            x = stage(x)
        return x
class UnfrozenPartBackbone(nn.Module):
    """Backbone -> mid -> head network where only part of the backbone trains.

    ``unfreeze_last`` is applied to the backbone at construction time. It keeps
    a single selected layer trainable, freezes the rest, and switches the
    backbone to train mode.
    """

    def __init__(self, backbone, mid, head):
        super().__init__()
        self.backbone = backbone.instance
        self.mid = mid.instance
        self.head = head.instance
        self.target_key = head.target_key
        unfreeze_last(self.backbone)

    def forward(self, x):
        """Pipe ``x`` through backbone, mid and head (no ``no_grad`` guard).

        Args:
            x: backbone input (a tensor or a dict, depending on the backbone).

        Returns:
            The head's output.
        """
        for stage in (self.backbone, self.mid, self.head):
            x = stage(x)
        return x
class NoFeatureBackbone(nn.Module):
    """Network without a feature extractor: inputs are fed straight to the head.

    ``head`` is a wrapper exposing the module via ``.instance`` and the
    prediction key via ``.target_key``.
    """

    def __init__(self, head):
        super().__init__()
        self.head = head.instance
        self.target_key = head.target_key

    def forward(self, x):
        """Apply the head directly to ``x`` and return its output."""
        head_module = self.head
        return head_module(x)
class ContrastiveFrozenBackbone(FrozenBackbone):
    """Frozen-backbone network that also embeds a positive view of the input.

    When ``mode != "eval"``, every ``pos_<name>`` entry of the input dict is
    re-keyed as ``<name>`` and passed through the frozen backbone too. The
    ``[:, 0, :]`` slice of both backbone outputs is returned alongside the
    head's output. That slice is presumably the CLS token (TODO confirm).
    """

    def __init__(self, backbone, mid, head, mode):
        super().__init__(backbone, mid, head)
        self.mode = mode

    def forward(self, x):
        """Embed ``x`` (and its positive view outside eval mode), then run mid and head.

        Args:
            x: input dict for the backbone. Positive-view entries are expected
                under keys prefixed with ``"pos_"``.

        Returns:
            dict containing:

            * ``"features"``: ``backbone(x)[:, 0, :]``;
            * ``"pos_features"``: the same slice for the positive view, omitted
              when ``mode == "eval"``;
            * every key of the head's output dict. Head keys win on collision.
        """
        prefix = "pos_"
        with torch.no_grad():
            features = self.backbone(x)
            if self.mode != "eval":
                # Slice the prefix off: ``str.strip("pos_")`` would strip any of
                # the characters p/o/s/_ from both ends ("pos_gps" -> "g").
                x_pos = {
                    k[len(prefix):]: v.clone()
                    if isinstance(v, torch.Tensor)
                    else copy.deepcopy(v)
                    for k, v in x.items()
                    if k.startswith(prefix)
                }
                pos_features = self.backbone(x_pos)
        out = self.mid(features)
        out = self.head(out)
        if self.mode != "eval":
            return {
                "features": features[:, 0, :],
                "pos_features": pos_features[:, 0, :],
                **out,
            }
        return {
            "features": features[:, 0, :],
            **out,
        }
class ContrastiveUnFrozenPartBackbone(UnfrozenPartBackbone):
    """Partially-trainable-backbone network that also embeds a positive view.

    The backbone setup is inherited from ``UnfrozenPartBackbone`` (only a
    selected layer is trainable). When ``mode != "eval"``, every
    ``pos_<name>`` entry of the input dict is re-keyed as ``<name>`` and passed
    through the backbone too. The ``[:, 0, :]`` slice of both backbone outputs
    is returned alongside the head's output. That slice is presumably the CLS
    token (TODO confirm).
    """

    def __init__(self, backbone, mid, head, mode):
        super().__init__(backbone, mid, head)
        self.mode = mode

    def forward(self, x):
        """Embed ``x`` (and its positive view outside eval mode), then run mid and head.

        Args:
            x: input dict for the backbone. Positive-view entries are expected
                under keys prefixed with ``"pos_"``.

        Returns:
            dict containing:

            * ``"features"``: ``backbone(x)[:, 0, :]``;
            * ``"pos_features"``: the same slice for the positive view, omitted
              when ``mode == "eval"``;
            * every key of the head's output dict. Head keys win on collision.
        """
        prefix = "pos_"
        features = self.backbone(x)
        if self.mode != "eval":
            # Slice the prefix off: ``str.strip("pos_")`` would strip any of
            # the characters p/o/s/_ from both ends ("pos_gps" -> "g").
            x_pos = {
                k[len(prefix):]: v.clone()
                if isinstance(v, torch.Tensor)
                else copy.deepcopy(v)
                for k, v in x.items()
                if k.startswith(prefix)
            }
            pos_features = self.backbone(x_pos)
        out = self.mid(features)
        out = self.head(out)
        if self.mode != "eval":
            return {
                "features": features[:, 0, :],
                "pos_features": pos_features[:, 0, :],
                **out,
            }
        return {
            "features": features[:, 0, :],
            **out,
        }
class ContrastiveUnFrozenBackbone(UnfrozenBackbone):
    """Trainable-backbone network that also embeds a positive view of the input.

    The backbone setup is inherited from ``UnfrozenBackbone``. When
    ``mode != "eval"``, every ``pos_<name>`` entry of the input dict is
    re-keyed as ``<name>`` and passed through the backbone too. The
    ``[:, 0, :]`` slice of both backbone outputs is returned alongside the
    head's output. That slice is presumably the CLS token (TODO confirm).
    """

    def __init__(self, backbone, mid, head, mode):
        super().__init__(backbone, mid, head)
        self.mode = mode

    def forward(self, x):
        """Embed ``x`` (and its positive view outside eval mode), then run mid and head.

        Args:
            x: input dict for the backbone. Positive-view entries are expected
                under keys prefixed with ``"pos_"``.

        Returns:
            dict containing:

            * ``"features"``: ``backbone(x)[:, 0, :]``;
            * ``"pos_features"``: the same slice for the positive view, omitted
              when ``mode == "eval"``;
            * every key of the head's output dict. Head keys win on collision.
        """
        prefix = "pos_"
        features = self.backbone(x)
        if self.mode != "eval":
            # Slice the prefix off: ``str.strip("pos_")`` would strip any of
            # the characters p/o/s/_ from both ends ("pos_gps" -> "g").
            x_pos = {
                k[len(prefix):]: v.clone()
                if isinstance(v, torch.Tensor)
                else copy.deepcopy(v)
                for k, v in x.items()
                if k.startswith(prefix)
            }
            pos_features = self.backbone(x_pos)
        out = self.mid(features)
        out = self.head(out)
        if self.mode != "eval":
            return {
                "features": features[:, 0, :],
                "pos_features": pos_features[:, 0, :],
                **out,
            }
        return {
            "features": features[:, 0, :],
            **out,
        }
class TextContrastiveUnFrozenBackbone(UnfrozenBackbone):
    """Trainable-backbone network whose backbone returns two outputs.

    The backbone must return an iterable of exactly two items
    ``(con, features)``:

    * ``features`` goes through ``mid`` and ``head``;
    * ``con`` is returned unchanged under the ``"features"`` key.

    Construction (``backbone, mid, head``) is inherited from ``UnfrozenBackbone``.
    """

    def forward(self, x):
        """Run the backbone, feed its second output to mid/head and return both.

        Args:
            x: backbone input (a tensor or a dict, depending on the backbone).

        Returns:
            dict with ``"features"`` set to the backbone's first output, plus
            every key of the head's output dict. Head keys win on collision.
        """
        con, features = self.backbone(x)
        out = self.head(self.mid(features))
        return {
            "features": con,
            **out,
        }
class LoraBackbone(nn.Module):
    """Backbone -> mid -> head network where the backbone is tuned through LoRA.

    The backbone is first frozen, then wrapped with ``peft.get_peft_model``
    using a ``LoraConfig`` built from the constructor arguments. The trainable
    parameter summary is printed once the model is assembled.
    """

    def __init__(
        self, backbone, mid, head, r, alpha, dropout, bias, target_modules=None
    ):
        """Build the network and attach LoRA adapters to the backbone.

        Args:
            backbone, mid, head: wrappers exposing the module via ``.instance``.
                ``head`` also exposes ``target_key``.
            r: LoRA rank (``LoraConfig.r``).
            alpha: LoRA scaling factor (``LoraConfig.lora_alpha``).
            dropout: dropout on the LoRA branch (``LoraConfig.lora_dropout``).
            bias: ``LoraConfig.bias`` setting.
            target_modules: backbone submodule names to adapt. Either a string
                (passed through as-is) or an iterable of names. Defaults to
                ``["q_proj", "k_proj", "v_proj"]``, the previously hard-coded
                value.
        """
        super().__init__()
        self.backbone = backbone.instance
        self.mid = mid.instance
        self.head = head.instance
        self.target_key = head.target_key
        freeze(self.backbone)
        if target_modules is None:
            target_modules = ["q_proj", "k_proj", "v_proj"]
        elif not isinstance(target_modules, str):
            # Normalize e.g. a Hydra ListConfig or tuple to a plain list.
            target_modules = list(target_modules)
        config = LoraConfig(
            r=r,
            lora_alpha=alpha,
            lora_dropout=dropout,
            bias=bias,
            target_modules=target_modules,
        )
        self.backbone = get_peft_model(self.backbone, config)
        print_trainable_parameters(self)

    def forward(self, x):
        """Run ``x`` through the LoRA-wrapped backbone, then mid and head.

        Args:
            x: backbone input (a tensor or a dict, depending on the backbone).

        Returns:
            The head's output.
        """
        x = self.backbone(x)
        x = self.mid(x)
        return self.head(x)
class HybridFrozenBackbone(FrozenBackbone):
    """Frozen-backbone network whose head also receives the ground-truth label.

    ``x["label"]`` is forwarded to the head only while the module is in
    training mode. Otherwise the head receives ``None``.
    """

    def forward(self, x):
        """Embed ``x`` without autograd, then run mid and the label-aware head.

        Args:
            x: backbone input dict. It must contain ``"label"`` when
                ``self.training`` is true.

        Returns:
            The head's output.
        """
        if self.training:
            gt_label = x["label"]
        else:
            gt_label = None
        with torch.no_grad():
            embedded = self.backbone(x)
        return self.head(self.mid(embedded), gt_label)
class HybridUnfrozenBackbone(UnfrozenBackbone):
    """Trainable-backbone network whose head also receives the ground-truth label.

    ``x["label"]`` is forwarded to the head only while the module is in
    training mode. Otherwise the head receives ``None``.
    """

    def forward(self, x):
        """Run backbone and mid on ``x``, then the label-aware head.

        Args:
            x: backbone input dict. It must contain ``"label"`` when
                ``self.training`` is true.

        Returns:
            The head's output.
        """
        if self.training:
            gt_label = x["label"]
        else:
            gt_label = None
        hidden = self.mid(self.backbone(x))
        return self.head(hidden, gt_label)
class ContrastiveHybridUnFrozenBackbone(UnfrozenBackbone):
    """Trainable-backbone network that embeds a positive view and gives the head the label.

    Behaviour:

    * The backbone setup is inherited from ``UnfrozenBackbone``.
    * While training, ``x["label"]`` is passed to the head; otherwise ``None``.
    * When ``mode != "eval"``, every ``pos_<name>`` entry of the input dict is
      re-keyed as ``<name>`` and passed through the backbone too.
    * The ``[:, 0, :]`` slice of both backbone outputs is returned alongside
      the head's output. That slice is presumably the CLS token (TODO confirm).
    """

    def __init__(self, backbone, mid, head, mode):
        super().__init__(backbone, mid, head)
        self.mode = mode

    def forward(self, x):
        """Embed ``x`` (and its positive view outside eval mode), then run mid and head.

        Args:
            x: input dict for the backbone. It must contain ``"label"`` when
                ``self.training`` is true. Positive-view entries are expected
                under keys prefixed with ``"pos_"``.

        Returns:
            dict containing:

            * ``"features"``: ``backbone(x)[:, 0, :]``;
            * ``"pos_features"``: the same slice for the positive view, omitted
              when ``mode == "eval"``;
            * every key of the head's output dict. Head keys win on collision.
        """
        prefix = "pos_"
        gt_label = x["label"] if self.training else None
        features = self.backbone(x)
        if self.mode != "eval":
            # Slice the prefix off: ``str.strip("pos_")`` would strip any of
            # the characters p/o/s/_ from both ends ("pos_gps" -> "g").
            x_pos = {
                k[len(prefix):]: v.clone()
                if isinstance(v, torch.Tensor)
                else copy.deepcopy(v)
                for k, v in x.items()
                if k.startswith(prefix)
            }
            pos_features = self.backbone(x_pos)
        out = self.mid(features)
        out = self.head(out, gt_label)
        if self.mode != "eval":
            return {
                "features": features[:, 0, :],
                "pos_features": pos_features[:, 0, :],
                **out,
            }
        return {
            "features": features[:, 0, :],
            **out,
        }