# HMMC_t2v_search/modules/modeling.py
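"""Model definitions for the HMMC text-to-video search project.

Contains BirdPreTrainedModel (video-text pretraining with momentum encoders,
negative queues and the weighted FAM / VTM / FTM / MLM losses), BirdModel
(retrieval fine-tuning with video- and frame-level contrastive losses), and
the MLP projection/prediction head.
"""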
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import logging
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, BertTokenizer
from modules.tokenization_clip import SimpleTokenizer as ClipTokenizer
from modules.until_module import PreTrainedModel, AllGather, CrossEn, Dual_CrossEn
from modules.module_cross import TextEncoder, VisualEncoder, CrossConfig, BertLMPredictionHead
logger = logging.getLogger(__name__)
allgather = AllGather.apply
class CLIP4ClipPreTrainedModel(PreTrainedModel, nn.Module):
""" An abstract class to handle weights initialization and
a simple interface for downloading and loading pretrained models.
"""
def __init__(self, cross_config, *inputs, **kwargs):
super(CLIP4ClipPreTrainedModel, self).__init__(cross_config)
self.cross_config = cross_config
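# Builds a CrossConfig from cross_model_name, instantiates the subclass, and,
# when a state_dict is given, loads the weights through init_preweight.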
@classmethod
def from_pretrained(cls, cross_model_name, state_dict=None, cache_dir=None, type_vocab_size=2, *inputs, **kwargs):
task_config = None
if "task_config" in kwargs.keys():
task_config = kwargs["task_config"]
if not hasattr(task_config, "local_rank"):
task_config.__dict__["local_rank"] = 0
elif task_config.local_rank == -1:
task_config.local_rank = 0
cross_config, _ = CrossConfig.get_config(cross_model_name, cache_dir, type_vocab_size, state_dict=None,
task_config=task_config)
model = cls(cross_config, *inputs, **kwargs)
if state_dict is not None:
model = cls.init_preweight(model, state_dict, task_config=task_config)
return model
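# Module-level helpers: show_log logs only on rank 0, update_attr copies an
# attribute from the task config into the target config, and check_attr tests
# whether a flag is present and truthy on the task config.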
def show_log(task_config, info):
if task_config is None or task_config.local_rank == 0:
logger.warning(info)
def update_attr(target_name, target_config, target_attr_name, source_config, source_attr_name, default_value=None):
if hasattr(source_config, source_attr_name):
if default_value is None or getattr(source_config, source_attr_name) != default_value:
setattr(target_config, target_attr_name, getattr(source_config, source_attr_name))
show_log(source_config, "Set {}.{}: {}.".format(target_name,
target_attr_name, getattr(target_config, target_attr_name)))
return target_config
def check_attr(target_name, task_config):
return hasattr(task_config, target_name) and task_config.__dict__[target_name]
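# Pretraining model: query/key (momentum) text and visual encoders, MoCo-style
# negative queues, and four weighted losses (FAM, VTM, FTM and MLM, matching
# the weight_* entries in cross_config).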
class BirdPreTrainedModel(CLIP4ClipPreTrainedModel):
def __init__(self, cross_config, task_config):
super(BirdPreTrainedModel, self).__init__(cross_config)
self.task_config = task_config
self.rank = task_config.local_rank
self.mlm_probability = cross_config.mlm_probability
self.top_frames = task_config.top_frames
# self.weight_sum = torch.nn.Parameter(torch.tensor([0.5], dtype=torch.float32), requires_grad=True)
self.weight_FAM = cross_config.weight_FAM
self.weight_VTM = cross_config.weight_VTM
self.weight_FTM = cross_config.weight_FTM
self.weight_MLM = cross_config.weight_MLM
self.contrast_momentum = task_config.contrast_momentum
self.contrast_temperature = task_config.contrast_temperature
self.contrast_num_negative = task_config.contrast_num_negative
################## text encoder (BertTokenizer for Chinese, ClipTokenizer otherwise)
if self.task_config.language == "chinese":
self.tokenizer = BertTokenizer.from_pretrained(self.task_config.pretrained_text)
else:
self.tokenizer = ClipTokenizer()
if self.rank == 0:
logger.info("voacb_size:{}".format(self.tokenizer.vocab_size))
t_config = AutoConfig.from_pretrained(self.task_config.pretrained_text)
self.text_encoder = TextEncoder(self.task_config, cross_config)
self.text_encoder_k = TextEncoder(self.task_config, cross_config)
self.t_projector = MLP(num_layers=cross_config.proj_num_layers)
self.t_projector_k = MLP(num_layers=cross_config.proj_num_layers)
nn.SyncBatchNorm.convert_sync_batchnorm(self.t_projector)
nn.SyncBatchNorm.convert_sync_batchnorm(self.t_projector_k)
# for MLM
t_config.hidden_size = cross_config.temporal_hidden_size
t_config.vocab_size = self.tokenizer.vocab_size
self.cls = BertLMPredictionHead(t_config)
################## visual_encoder
self.visual_encoder = VisualEncoder(self.task_config, cross_config)
self.visual_encoder_k = VisualEncoder(self.task_config, cross_config)
self.v_projector = MLP(num_layers=cross_config.proj_num_layers)
self.v_projector_k = MLP(num_layers=cross_config.proj_num_layers)
self.v_predictor = MLP(num_layers=cross_config.pred_num_layers)
nn.SyncBatchNorm.convert_sync_batchnorm(self.v_projector)
nn.SyncBatchNorm.convert_sync_batchnorm(self.v_projector_k)
nn.SyncBatchNorm.convert_sync_batchnorm(self.v_predictor)
################## momentum model pairs
self.model_pairs = [[self.visual_encoder, self.visual_encoder_k],
[self.text_encoder, self.text_encoder_k],
[self.v_projector, self.v_projector_k],
[self.t_projector, self.t_projector_k],
]
self.copy_params()
################## create queue
self.register_buffer("queue_v_cross_ng", torch.randn(cross_config.temporal_hidden_size, self.contrast_num_negative))
self.register_buffer("queue_frame_proj_ng", torch.randn(cross_config.temporal_hidden_size,
self.contrast_num_negative * self.task_config.max_frames))
self.register_buffer("queue_frame_cross_ng", torch.randn(cross_config.temporal_hidden_size,
self.contrast_num_negative * self.task_config.max_frames))
self.register_buffer("queue_title_cross_ng", torch.randn(cross_config.temporal_hidden_size, self.contrast_num_negative))
self.register_buffer("queue_tag_cross_ng", torch.randn(cross_config.temporal_hidden_size, self.contrast_num_negative))
self.queue_v_cross_ng = F.normalize(self.queue_v_cross_ng, dim=0)
self.queue_frame_proj_ng = F.normalize(self.queue_frame_proj_ng, dim=0)
self.queue_frame_cross_ng = F.normalize(self.queue_frame_cross_ng, dim=0)
self.queue_title_cross_ng = F.normalize(self.queue_title_cross_ng, dim=0)
self.queue_tag_cross_ng = F.normalize(self.queue_tag_cross_ng, dim=0)
self.register_buffer("queue_ptr", torch.zeros(1, dtype=torch.long))
################## loss function
self.loss_fct = CrossEn()
self.loss_fct_dual = Dual_CrossEn()
# self.apply(self.init_weights)
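# Masked language modeling: masks input tokens BERT-style, re-encodes them with
# the text encoder, and scores the masked positions with BertLMPredictionHead.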
def get_mlm_loss(self, input_ids, input_mask):
to_mask_input_ids = input_ids.clone()
input_labels = to_mask_input_ids.clone()
input_probability_matrix = torch.full(input_labels.shape, self.mlm_probability)
masked_input_ids, input_labels = self.mask(to_mask_input_ids, self.tokenizer.vocab_size,
input_mask.device, targets=input_labels,
probability_matrix=input_probability_matrix)
masked_input_output = self.text_encoder(masked_input_ids, input_mask, return_hidden=True)
mlm_input_loss = self.calculate_mlm_loss(masked_input_output, input_labels)
return mlm_input_loss
def calculate_mlm_loss(self, sequence_output_mlm, labels):
mlm_scores = self.cls(sequence_output_mlm)
# logger.info("sequence_output_mlm.shape:{}".format(sequence_output_mlm.shape))
# logger.info("mlm_scores.shape:{}".format(mlm_scores.shape))
# logger.info("labels.shape:{}".format(labels.shape))
mlm_loss = F.cross_entropy(mlm_scores.view(-1, self.tokenizer.vocab_size),
labels.view(-1), ignore_index=-100)
return mlm_loss
def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None):
if masked_indices is None:
masked_indices = torch.bernoulli(probability_matrix).bool()
masked_indices[input_ids == self.tokenizer.pad_token_id] = False
masked_indices[input_ids == self.tokenizer.cls_token_id] = False
# logger.info("masked_indices:{}".format(masked_indices))
# logger.info("masked_indices.shape:{}".format(masked_indices.shape))
if targets is not None:
targets[~masked_indices] = -100 # We only compute loss on masked tokens
# 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
input_ids[indices_replaced] = self.tokenizer.mask_token_id
# 10% of the time, we replace masked input tokens with random word
indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced
random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device)
input_ids[indices_random] = random_words[indices_random]
# The rest of the time (10% of the time) we keep the masked input tokens unchanged
if targets is not None:
return input_ids, targets
else:
return input_ids
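# L2-normalizes text and video features and returns scaled dot-product logits,
# with the learned logit_scale clamped to at most 100 (as in CLIP).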
def loose_similarity(self, sequence_output, visual_output):
sequence_output, visual_output = sequence_output.contiguous(), visual_output.contiguous()
visual_output = visual_output.squeeze()
visual_output = visual_output / visual_output.norm(dim=-1, keepdim=True)
sequence_output = sequence_output.squeeze()
sequence_output = sequence_output / sequence_output.norm(dim=-1, keepdim=True)
logit_scale = self.text_encoder.logit_scale.exp()
logit_scale.data = torch.clamp(logit_scale.data, max=100)
# if self.rank == 0:
# logger.info("logit_scale:{},dtype:{}".format(logit_scale, logit_scale.dtype))
# logger.info("sequence_output.shape:{}".format(sequence_output.shape))
# logger.info("visual_output.shape:{}".format(visual_output.shape))
if len(visual_output.shape) == 2:
retrieve_logits = logit_scale * torch.matmul(sequence_output, visual_output.t())
else:
visual_temp = visual_output.permute(0, 2, 1)
retrieve_logits = logit_scale * torch.matmul(sequence_output, visual_temp)
if len(retrieve_logits.shape) == 3:
retrieve_logits = retrieve_logits.permute(1, 0, 2)
return retrieve_logits
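# Momentum (key) encoder maintenance: copy_params initializes each key model
# from its query model and freezes it; _momentum_update keeps the key weights
# as an exponential moving average controlled by contrast_momentum.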
@torch.no_grad()
def copy_params(self):
for model_pair in self.model_pairs:
for param, param_k in zip(model_pair[0].parameters(), model_pair[1].parameters()):
param_k.data.copy_(param.data) # initialize
param_k.requires_grad = False # not update by gradient
@torch.no_grad()
def _momentum_update(self):
for model_pair in self.model_pairs:
for param, param_k in zip(model_pair[0].parameters(), model_pair[1].parameters()):
param_k.data = param_k.data * self.contrast_momentum + param.data * (1. - self.contrast_momentum)
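# Enqueues the normalized key features into the fixed-size negative queues and
# advances the shared pointer. As in MoCo, this presumably assumes the queue
# size (contrast_num_negative) is a multiple of the per-rank batch size.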
@torch.no_grad()
def _dequeue_and_enqueue(self, v_fea_k, tag_fea_k, title_fea_k, frame_fea_k, frame_proj_k):
# gather keys before updating queue
# [bs,hidden]
v_fea_k = F.normalize(v_fea_k, dim=1)
tag_fea_k = F.normalize(tag_fea_k, dim=1)
title_fea_k = F.normalize(title_fea_k, dim=1)
# [bs,frame,hidden]
frame_fea_k = F.normalize(frame_fea_k, dim=2)
frame_proj_k = F.normalize(frame_proj_k, dim=2)
batch_size = v_fea_k.size(0)
frame_num = frame_fea_k.size(1)
frame_fea_k = frame_fea_k.view(-1, frame_fea_k.size(-1))
frame_proj_k = frame_proj_k.view(-1, frame_proj_k.size(-1))
ptr = int(self.queue_ptr)
# if self.rank == 0:
# logger.info(
# "begin>>>>: ptr:{},batch_size:{},frame_num:{},queue_size:{}".format(ptr, batch_size, frame_num, self.contrast_num_negative))
# logger.info("v1_self_k.shape:{},tag_cross_k.shape:{},frame_proj_k.shape:{}".format(v_fea_k.shape, tag_fea_k.shape, frame_proj_k.shape))
# replace the keys at ptr (dequeue and enqueue)
self.queue_v_cross_ng[:, ptr:ptr + batch_size] = v_fea_k.T
self.queue_tag_cross_ng[:, ptr:ptr + batch_size] = tag_fea_k.T
self.queue_title_cross_ng[:, ptr:ptr + batch_size] = title_fea_k.T
self.queue_frame_proj_ng[:, ptr * frame_num:(ptr + batch_size) * frame_num] = frame_proj_k.T
self.queue_frame_cross_ng[:, ptr * frame_num:(ptr + batch_size) * frame_num] = frame_fea_k.T
# move pointer
ptr = (ptr + batch_size) % self.contrast_num_negative
# if self.rank == 0:
# logger.info("end>>>>: ptr:{}".format(ptr))
self.queue_ptr[0] = ptr
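# InfoNCE loss: the positive logits are the diagonal of q @ k.T, the negatives
# come from the detached queue, and the target class is always index 0.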
def contrastive_loss(self, q, k, queue):
q = q.squeeze()
q = F.normalize(q, dim=1)
k = k.squeeze()
k = F.normalize(k, dim=1)
bs = q.size(0)
# logger.info("q.dtype:{},k.dtype:{}".format(q.dtype, k.dtype))
# positive logits: Nx1
# NOTE: torch.einsum raised an error under apex amp (opt_level O1), so matmul + diag is used instead
# l_pos = torch.einsum('nc,nc->n', [q, k]).unsqueeze(-1)
l_pos = torch.matmul(q, k.T)
l_pos = torch.diag(l_pos).reshape([bs, -1])
# negative logits: NxK
# l_neg = torch.einsum('nc,ck->nk', [q, queue.clone().detach()])
l_neg = torch.matmul(q, queue.clone().detach())
# logits: Nx(1+K)
logits = torch.cat([l_pos, l_neg], dim=1)
# if self.rank == 0:
# logger.info("logits.shape:{}".format(logits.shape))
# apply temperature
logits /= self.contrast_temperature
# labels: positive key indicators
labels = torch.zeros(logits.shape[0], dtype=torch.long).cuda()
return F.cross_entropy(logits, labels)
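# Frame-level self-supervision: each frame's prediction is contrasted against
# the momentum projection of its neighbouring frame, in both directions,
# averaged over adjacent frame pairs.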
def frame_self_loss(self, frame_fea, frame_fea_k, queue_frame_ng):
loss = 0.
for i in range(frame_fea.size(1) - 1):
frame_loss = self.contrastive_loss(frame_fea[:, i, :], frame_fea_k[:, i+1, :], queue_frame_ng) \
+ self.contrastive_loss(frame_fea[:, i+1, :], frame_fea_k[:, i, :], queue_frame_ng)
loss += frame_loss
loss = loss / (frame_fea.size(1) - 1)
return loss
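# Frame-text contrast: symmetric InfoNCE between the text feature and every
# frame feature, averaged over frames.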
def frame_cross_loss(self, frame_fea, frame_fea_k, queue_frame_ng, text_fea, text_fea_k, queue_text_ng):
loss = 0.
for i in range(frame_fea.size(1)):
frame_loss = self.contrastive_loss(text_fea, frame_fea_k[:, i, :], queue_frame_ng) + \
self.contrastive_loss(frame_fea[:, i, :], text_fea_k, queue_text_ng)
loss += frame_loss
loss = loss / frame_fea.size(1)
return loss
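# Pretraining forward pass: encodes video, tag and title, computes key features
# with the momentum encoders under no_grad, accumulates the weighted FAM / VTM /
# FTM / MLM losses, and refreshes the negative queues.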
def forward(self, video_data, video_frame, tag_ids, tag_mask, title_ids, title_mask, global_step):
tag_ids = tag_ids.view(-1, tag_ids.shape[-1])
tag_mask = tag_mask.view(-1, tag_mask.shape[-1])
title_ids = title_ids.view(-1, title_ids.shape[-1])
title_mask = title_mask.view(-1, title_mask.shape[-1])
# bs x frames x 3 x H x W
video = torch.as_tensor(video_data)
if self.rank == 0 and global_step % self.task_config.n_display == 0:
logger.info("video1.shape:{}, dtype:{}, device:{}".format(video.shape, video.dtype, video.device))
if self.training:
# loss = 0.0
v_fea, frame_fea = self.visual_encoder(video, video_frame)
if self.task_config.dataset == "bird":
tag_fea = self.text_encoder(tag_ids, tag_mask)
title_fea = self.text_encoder(title_ids, title_mask)
# for video self supervised learning
# [bs,hidden_size]
bs, frame, hidden = frame_fea.shape
frame_fea = frame_fea.view(-1, hidden)
frame_proj = self.v_projector(frame_fea)
frame_pred = self.v_predictor(frame_proj)
frame_fea = frame_fea.view(bs, frame, hidden)
frame_proj = frame_proj.view(bs, frame, hidden)
frame_pred = frame_pred.view(bs, frame, hidden)
if self.rank == 0 and global_step % self.task_config.n_display == 0:
logger.info("v_fea.shape:{},device:{}".format(v_fea.shape, v_fea.device))
logger.info("frame_fea.shape:{},device:{}".format(frame_fea.shape, frame_fea.device))
logger.info("frame_proj.shape:{},device:{}".format(frame_proj.shape, frame_proj.device))
logger.info("title_fea.shape:{}".format(title_fea.shape))
logger.info("queue_v_cross_ng.shape:{}".format(self.queue_v_cross_ng.shape))
# compute key features
with torch.no_grad(): # no gradient to keys
self._momentum_update() # update the key encoder
tag_fea_k = self.text_encoder_k(tag_ids, tag_mask)
title_fea_k = self.text_encoder_k(title_ids, title_mask)
#
v_fea_k, frame_fea_k = self.visual_encoder_k(video, video_frame)
frame_fea_k = frame_fea_k.view(-1, hidden)
frame_proj_k = self.v_projector_k(frame_fea_k)
frame_fea_k = frame_fea_k.view(bs, frame, hidden)
frame_proj_k = frame_proj_k.view(bs, frame, hidden)
# compute loss
if self.rank == 0 and global_step % self.task_config.n_display == 0:
logger.info(
"dtype: v_fea:{},v_fea_k:{},title_fea:{}".format(v_fea.dtype, v_fea_k.dtype, title_fea.dtype))
# single video modality: video queue loss
loss_FAM = self.frame_self_loss(frame_pred, frame_proj_k, self.queue_frame_proj_ng)
# cross modality: cross queue loss
v_title_queue_loss = self.contrastive_loss(v_fea, title_fea_k, self.queue_title_cross_ng) \
+ self.contrastive_loss(title_fea, v_fea_k, self.queue_v_cross_ng)
if self.task_config.dataset == "bird":
v_tag_queue_loss = self.contrastive_loss(v_fea, tag_fea_k, self.queue_tag_cross_ng) \
+ self.contrastive_loss(tag_fea, v_fea_k, self.queue_v_cross_ng)
loss_VTM = (v_tag_queue_loss + v_title_queue_loss) / 2
else:
loss_VTM = v_title_queue_loss
loss_FTM = 0.
if self.task_config.use_frame_fea:
frame_title_loss = self.frame_cross_loss(frame_fea, frame_fea_k, self.queue_frame_cross_ng, title_fea,
title_fea_k, self.queue_title_cross_ng)
if self.task_config.dataset == "bird":
frame_tag_loss = self.frame_cross_loss(frame_fea, frame_fea_k, self.queue_frame_cross_ng, tag_fea,
tag_fea_k, self.queue_tag_cross_ng)
loss_FTM += (frame_tag_loss + frame_title_loss) / 2
else:
loss_FTM = frame_title_loss
# single text modality: text queue loss
# t_queue_loss = self.contrastive_loss(title_fea, tag_fea_k, self.queue_tag_cross_ng) \
# + self.contrastive_loss(tag_fea, title_fea_k, self.queue_v_cross_ng)
# dequeue_and_enqueue
self._dequeue_and_enqueue(v_fea_k, tag_fea_k, title_fea_k, frame_fea_k, frame_proj_k)
# mlm loss
mlm_title_loss = self.get_mlm_loss(title_ids, title_mask)
if self.task_config.dataset == "bird":
mlm_tag_loss = self.get_mlm_loss(tag_ids, tag_mask)
loss_MLM = (mlm_tag_loss + mlm_title_loss) / 2
else:
loss_MLM = mlm_title_loss
# total loss
loss = self.weight_FAM * loss_FAM + self.weight_VTM * loss_VTM + self.weight_FTM * loss_FTM + self.weight_MLM * loss_MLM
if self.rank == 0:
if global_step % self.task_config.n_display == 0:
logger.info("loss:{},loss_FAM:{},loss_VTM:{},loss_FTM:{},loss_MLM:{}"
"".format(loss, loss_FAM, loss_VTM, loss_FTM, loss_MLM))
if self.task_config.logdir:
loss_item = {"loss": float(loss), "loss_FAM": float(loss_FAM), "loss_VTM": float(loss_VTM),
"loss_FTM": float(loss_FTM), "loss_MLM": float(loss_MLM)}
self.task_config.writer.add_scalars('loss', loss_item, global_step=global_step)
# self.task_config.writer.add_scalar('loss', video_cross_loss, global_step=global_step)
return loss
else:
return None
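# Retrieval fine-tuning model. Note that __init__ calls
# super(BirdPreTrainedModel, self).__init__, which appears to deliberately skip
# BirdPreTrainedModel.__init__ so the momentum encoders and queues are not built.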
class BirdModel(BirdPreTrainedModel):
def __init__(self, cross_config, task_config):
super(BirdPreTrainedModel, self).__init__(cross_config)
self.task_config = task_config
self.rank = task_config.local_rank
# self.weight_sim = torch.nn.Parameter(torch.tensor([0.9], dtype=torch.float32), requires_grad=True)
self.weight_VTM_finetune = cross_config.weight_VTM_finetune
self.weight_FTM_finetune = cross_config.weight_FTM_finetune
self.top_frames = task_config.top_frames
################## text Encoder
self.text_encoder = TextEncoder(self.task_config, cross_config)
################## visual_encoder
self.visual_encoder = VisualEncoder(self.task_config, cross_config)
################## loss function
self.loss_fct = CrossEn()
self.loss_fct_dual = Dual_CrossEn()
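# Query-to-frame loss: symmetric CrossEn loss between the query and each frame,
# averaged over frames.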
def frame_loss(self, query_output, frame_output):
frame_num = frame_output.size(1)
loss = 0.
for i in range(frame_num):
frame_single = frame_output[:, i, :].squeeze()
sim_matrix = self.loose_similarity(query_output, frame_single)
sim_loss = self.loss_fct(sim_matrix) + self.loss_fct(sim_matrix.T)
loss += sim_loss / frame_num
# logger.info("frame_output.shape:{},dtype:{}".format(frame_output.shape, frame_output.dtype))
# logger.info("query_output.shape:{},dtype:{}".format(query_output.shape, frame_output.dtype))
# sim_matrix = self.loose_similarity(query_output, frame_output)
# sim_matrix = torch.topk(sim_matrix, k=self.top_frames, dim=2)[0]
# sim_matrix = torch.mean(sim_matrix, dim=2)
# sim_loss = self.loss_fct(sim_matrix) + self.loss_fct(sim_matrix.T)
# loss += sim_loss
return loss
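# Fine-tuning forward pass: video-level similarity loss plus, when use_frame_fea
# is set, the frame-level loss, each scaled by its finetune weight.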
def forward(self, query_ids, query_mask, video_data, video_frame, idx, global_step):
query_ids = query_ids.view(-1, query_ids.shape[-1])
query_mask = query_mask.view(-1, query_mask.shape[-1])
# T x 3 x H x W
video = torch.as_tensor(video_data)
# if self.rank == 0:
# logger.info("video.shape:{}, dtype:{}".format(video.shape, video.dtype))
if self.training:
loss = 0.0
query_output = self.text_encoder(query_ids, query_mask)
visual_output, frame_output = self.visual_encoder(video, video_frame)
# if self.rank == 0:
# logger.info("query_output.shape:{},dtype:{}".format(query_output.shape, query_output.dtype))
# logger.info("visual_output.shape:{},dtype:{}".format(visual_output.shape, visual_output.dtype))
# logger.info("frame_output.shape:{},dtype:{}".format(frame_output.shape, frame_output.dtype))
# frame loss
if self.task_config.use_frame_fea:
frame_loss = self.frame_loss(query_output, frame_output)
loss += self.weight_FTM_finetune * frame_loss
# video loss
sim_matrix = self.loose_similarity(query_output, visual_output)
sim_loss = self.loss_fct(sim_matrix) + self.loss_fct(sim_matrix.T)
loss += self.weight_VTM_finetune * sim_loss
# loss += sim_loss
if self.task_config.local_rank == 0:
if global_step % self.task_config.n_display == 0:
logger.info(
"loss:{},frame_loss:{},sim_loss:{},type:{},sim_matrix.shape:{}".format(loss, loss - sim_loss,
sim_loss, sim_loss.dtype, sim_matrix.shape))
if self.task_config.logdir:
self.task_config.writer.add_scalar('loss', float(loss), global_step=global_step)
return loss
else:
return None
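# Projection / prediction head: num_layers - 1 hidden Linear + BatchNorm1d + ReLU
# blocks followed by a Linear output layer (identity when num_layers < 1).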
class MLP(nn.Module):
def __init__(self, in_dim=512, inner_dim=4096, out_dim=512, num_layers=2):
super(MLP, self).__init__()
# hidden layers
linear_hidden = [nn.Identity()]
for i in range(num_layers - 1):
linear_hidden.append(nn.Linear(in_dim if i == 0 else inner_dim, inner_dim))
linear_hidden.append(nn.BatchNorm1d(inner_dim))
linear_hidden.append(nn.ReLU(inplace=True))
self.linear_hidden = nn.Sequential(*linear_hidden)
self.linear_out = nn.Linear(in_dim if num_layers == 1 else inner_dim,
out_dim) if num_layers >= 1 else nn.Identity()
def forward(self, x):
x = self.linear_hidden(x)
x = self.linear_out(x)
return x