from typing import Dict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat

from models.base import CaptionMetaMixin
from utils.model_util import init


class WmlEncoderKdWrapper(nn.Module, CaptionMetaMixin):
    """Weighted multi-layer encoder KD: the student's pooled embedding (`fc_emb`)
    is projected into a query/value pair, each teacher layer output into a
    key/value pair, and the per-layer MSE terms are aggregated with attention
    weights computed between the student query and the teacher-layer keys."""

    def __init__(self,
                 model: nn.Module,
                 shared_dim: int,
                 tchr_layer_to_dims: Dict[str, int],
                 loss_type: str = "mse"):
        super().__init__()
        self.model = model
        self.tchr_layers = list(tchr_layer_to_dims.keys())
        self.stdnt_qv_proj = nn.Linear(model.encoder.fc_emb_size,
                                       2 * shared_dim)
        self.stdnt_qv_proj.apply(init)
        # One key/value projection per teacher layer, registered by layer name.
        for layer, dim in tchr_layer_to_dims.items():
            self.add_module(f'tchr_kv_proj_{layer}',
                            nn.Linear(dim, 2 * shared_dim))
            getattr(self, f'tchr_kv_proj_{layer}').apply(init)
        if loss_type == "mse":
            self.loss_fn = nn.MSELoss(reduction="none")
        else:
            # Only the MSE objective is implemented here.
            raise ValueError(f"unsupported loss_type: {loss_type}")

    def forward(self, input_dict: Dict):
        output_dict = self.model(input_dict)
        if "tchr_output" in input_dict:
            stdnt_emb = output_dict["fc_emb"]
            stdnt_qv = self.stdnt_qv_proj(stdnt_emb)
            stdnt_q, stdnt_v = torch.chunk(stdnt_qv, 2, dim=-1)
            tchr_output = input_dict["tchr_output"]
            layer_ks, layer_vs = [], []
            for layer in self.tchr_layers:
                layer_kv = getattr(self, f'tchr_kv_proj_{layer}')(tchr_output[layer])
                layer_k, layer_v = torch.chunk(layer_kv, 2, dim=-1)
                layer_ks.append(layer_k)
                layer_vs.append(layer_v)
            layer_ks = torch.stack(layer_ks, dim=1)  # (batch, n_layers, shared_dim)
            layer_vs = torch.stack(layer_vs, dim=1)  # (batch, n_layers, shared_dim)
            # Attention weights over teacher layers: (batch, 1, n_layers).
            weights = torch.softmax(stdnt_q.unsqueeze(1) @ layer_ks.transpose(1, 2),
                                    dim=-1)
            stdnt_v = repeat(stdnt_v, 'b d -> b n d', n=len(self.tchr_layers))
            # Per-layer MSE (batch, n_layers, 1), then a weighted sum over layers.
            loss = self.loss_fn(stdnt_v, layer_vs).mean(dim=-1, keepdim=True)
            loss = (weights @ loss).mean()
            output_dict["enc_kd_loss"] = loss
        return output_dict
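
# Note on the teacher input expected by WmlEncoderKdWrapper (a sketch of the
# assumed format; the layer names below are purely illustrative and the code
# does not validate them): input_dict["tchr_output"] should map each key of
# `tchr_layer_to_dims` to a (batch, dim) tensor, e.g.
#     {"layer_6": torch.randn(B, 768), "layer_12": torch.randn(B, 768)}
# The softmax weights are nonnegative and sum to one over layers, so the final
# loss is a per-sample convex combination of the per-layer MSE terms.
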
class MseEncoderKdWrapper(nn.Module, CaptionMetaMixin):
    """Encoder KD with an MSE objective: the projected student embedding is
    regressed onto the (optionally projected, optionally L2-normalized)
    teacher embedding."""

    def __init__(self,
                 model: nn.Module,
                 shared_dim: int,
                 tchr_dim: int,
                 use_tchr_proj: bool = True,
                 l2_norm: bool = False,
                 ):
        super().__init__()
        self.model = model
        self.use_tchr_proj = use_tchr_proj
        if not use_tchr_proj:
            assert shared_dim == tchr_dim
        self.tchr_dim = tchr_dim
        self.l2_norm = l2_norm
        if hasattr(model, "encoder"):
            self.stdnt_proj = nn.Linear(model.encoder.fc_emb_size,
                                        shared_dim)
        else:
            self.stdnt_proj = nn.Linear(model.fc_emb_size,
                                        shared_dim)
        self.stdnt_proj.apply(init)
        if use_tchr_proj:
            self.tchr_proj = nn.Linear(tchr_dim, shared_dim)
            self.tchr_proj.apply(init)
        else:
            self.tchr_proj = nn.Identity()

    def forward(self, input_dict: Dict):
        unsup = input_dict.get("unsup", False)
        if unsup is False:
            if self.use_tchr_proj:
                output_dict = self.model(input_dict)
                stdnt_emb = output_dict["fc_emb"]
            else:
                # Without a teacher projection, the student projection is applied
                # to the encoder outputs before decoding.
                encoder_output = self.model.encoder(input_dict)
                stdnt_emb = encoder_output["fc_emb"]
                encoder_output["fc_emb"] = self.stdnt_proj(encoder_output["fc_emb"])
                encoder_output["attn_emb"] = self.stdnt_proj(encoder_output["attn_emb"])
                output_dict = self.model.forward_decoder(input_dict, encoder_output)
        else:
            output_dict = self.model.encoder(input_dict)
            stdnt_emb = output_dict["fc_emb"]
        if "tchr_output" in input_dict:
            stdnt_emb = self.stdnt_proj(stdnt_emb)
            tchr_emb = input_dict["tchr_output"]["embedding"]
            tchr_emb = self.tchr_proj(tchr_emb)
            if self.l2_norm:
                stdnt_emb = F.normalize(stdnt_emb, dim=-1)
                tchr_emb = F.normalize(tchr_emb, dim=-1)
            loss = F.mse_loss(stdnt_emb, tchr_emb)
            output_dict["enc_kd_loss"] = loss
        return output_dict

class ContraEncoderKdWrapper(nn.Module, CaptionMetaMixin):
    """Encoder KD with a contrastive objective: student and teacher embeddings
    are projected to a shared space, L2-normalized, and trained with a
    symmetric cross-entropy over the in-batch similarity matrix scaled by a
    learnable `logit_scale`."""

    def __init__(self,
                 model: nn.Module,
                 shared_dim: int,
                 tchr_dim: int,
                 ):
        super().__init__()
        self.model = model
        self.tchr_dim = tchr_dim
        if hasattr(model, "encoder"):
            self.stdnt_proj = nn.Linear(model.encoder.fc_emb_size,
                                        shared_dim)
        else:
            self.stdnt_proj = nn.Linear(model.fc_emb_size,
                                        shared_dim)
        self.stdnt_proj.apply(init)
        self.tchr_proj = nn.Linear(tchr_dim, shared_dim)
        self.tchr_proj.apply(init)
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, input_dict: Dict):
        unsup = input_dict.get("unsup", False)
        if unsup is False:
            output_dict = self.model(input_dict)
        else:
            output_dict = self.model.encoder(input_dict)
        if "tchr_output" in input_dict:
            stdnt_emb = output_dict["fc_emb"]
            stdnt_emb = self.stdnt_proj(stdnt_emb)
            tchr_emb = input_dict["tchr_output"]["embedding"]
            tchr_emb = self.tchr_proj(tchr_emb)
            stdnt_emb = F.normalize(stdnt_emb, dim=-1)
            tchr_emb = F.normalize(tchr_emb, dim=-1)
            # In-batch similarity matrix; matching indices are the positives.
            unscaled_logit = stdnt_emb @ tchr_emb.transpose(0, 1)
            logit = self.logit_scale * unscaled_logit
            label = torch.arange(logit.shape[0]).to(logit.device)
            loss1 = F.cross_entropy(logit, label)
            loss2 = F.cross_entropy(logit.transpose(0, 1), label)
            loss = (loss1 + loss2) / 2
            output_dict["enc_kd_loss"] = loss
        return output_dict

class ContraMseEncoderKdWrapper(nn.Module, CaptionMetaMixin):
    """Encoder KD combining the MSE and contrastive objectives: the final
    `enc_kd_loss` is the sum of the two terms."""

    def __init__(self,
                 model: nn.Module,
                 shared_dim: int,
                 tchr_dim: int,
                 use_tchr_proj: bool = True,
                 l2_norm: bool = False,
                 ):
        super().__init__()
        self.model = model
        self.use_tchr_proj = use_tchr_proj
        if not use_tchr_proj:
            assert shared_dim == tchr_dim
        self.tchr_dim = tchr_dim
        self.l2_norm = l2_norm
        if hasattr(model, "encoder"):
            self.stdnt_proj = nn.Linear(model.encoder.fc_emb_size,
                                        shared_dim)
        else:
            self.stdnt_proj = nn.Linear(model.fc_emb_size,
                                        shared_dim)
        self.stdnt_proj.apply(init)
        if use_tchr_proj:
            self.tchr_proj = nn.Linear(tchr_dim, shared_dim)
            self.tchr_proj.apply(init)
        else:
            self.tchr_proj = nn.Identity()
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward(self, input_dict: Dict):
        unsup = input_dict.get("unsup", False)
        if unsup is False:
            if self.use_tchr_proj:
                output_dict = self.model(input_dict)
                stdnt_emb = output_dict["fc_emb"]
            else:
                encoder_output = self.model.encoder(input_dict)
                stdnt_emb = encoder_output["fc_emb"]
                encoder_output["fc_emb"] = self.stdnt_proj(encoder_output["fc_emb"])
                encoder_output["attn_emb"] = self.stdnt_proj(encoder_output["attn_emb"])
                output_dict = self.model.forward_decoder(input_dict, encoder_output)
        else:
            output_dict = self.model.encoder(input_dict)
            stdnt_emb = output_dict["fc_emb"]
        if "tchr_output" in input_dict:
            stdnt_emb = self.stdnt_proj(stdnt_emb)
            tchr_emb = input_dict["tchr_output"]["embedding"]
            tchr_emb = self.tchr_proj(tchr_emb)
            if self.l2_norm:
                stdnt_emb = F.normalize(stdnt_emb, dim=-1)
                tchr_emb = F.normalize(tchr_emb, dim=-1)
            mse_loss = F.mse_loss(stdnt_emb, tchr_emb)
            # Contrastive term on L2-normalized embeddings (re-normalizing is a
            # no-op when `l2_norm` is already True).
            stdnt_emb = F.normalize(stdnt_emb, dim=-1)
            tchr_emb = F.normalize(tchr_emb, dim=-1)
            unscaled_logit = stdnt_emb @ tchr_emb.transpose(0, 1)
            logit = self.logit_scale * unscaled_logit
            label = torch.arange(logit.shape[0]).to(logit.device)
            loss1 = F.cross_entropy(logit, label)
            loss2 = F.cross_entropy(logit.transpose(0, 1), label)
            cntr_loss = (loss1 + loss2) / 2
            output_dict["enc_kd_loss"] = mse_loss + cntr_loss
        return output_dict
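

# ---------------------------------------------------------------------------
# Minimal smoke-test sketch (not part of the training pipeline). Everything
# below is illustrative: `_DummyCaptionModel`, the "feat" key and the tensor
# shapes are assumptions standing in for a real student captioning model and
# pre-computed teacher embeddings, which normally come from the rest of the
# repo.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _DummyEncoder(nn.Module):
        def __init__(self, fc_emb_size: int = 256):
            super().__init__()
            self.fc_emb_size = fc_emb_size
            self.proj = nn.Linear(64, fc_emb_size)

        def forward(self, input_dict: Dict):
            # Project hypothetical frame features and mean-pool over time.
            attn_emb = self.proj(input_dict["feat"])
            return {"fc_emb": attn_emb.mean(dim=1), "attn_emb": attn_emb}

    class _DummyCaptionModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = _DummyEncoder()

        def forward(self, input_dict: Dict):
            return self.encoder(input_dict)

    wrapper = MseEncoderKdWrapper(_DummyCaptionModel(),
                                  shared_dim=128,
                                  tchr_dim=512)
    batch = {
        "feat": torch.randn(4, 10, 64),
        "tchr_output": {"embedding": torch.randn(4, 512)},
    }
    print(wrapper(batch)["enc_kd_loss"].item())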