Spaces:
Runtime error
Runtime error
File size: 25,950 Bytes
d1b8c9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import nn, Tensor
from torchmultimodal.models.albef.image_encoder import ALBEFVisionEncoder
from torchmultimodal.models.albef.model import ALBEFModel, ALBEFModelWithSimilarity
from torchmultimodal.models.albef.multimodal_encoder import ALBEFMultimodalEncoder
from torchmultimodal.modules.encoders.bert_text_encoder import bert_text_encoder
from torchmultimodal.modules.layers.text_embedding import BERTTextEmbeddings
from torchmultimodal.modules.losses.albef import (
CausalLanguageModelingLoss,
ImageTextContrastiveLoss,
)
from torchmultimodal.utils.attention import get_causal_attention_mask
from torchmultimodal.utils.common import momentum_update, remove_grad
# Download locations of the pretrained ALBEF task checkpoints, keyed by the task
# name that the builder functions in this module look up ("vqa", "retrieval").
_ALBEF_PRETRAINED_URLS = dict(
    vqa="https://download.pytorch.org/models/multimodal/albef/pretrained_vqa_checkpoint.pt",
    retrieval="https://download.pytorch.org/models/multimodal/albef/pretrained_retrieval_checkpoint.pt",
)
class PredictionHead(nn.Module):
    """
    Map hidden states to per-position scores over the vocabulary.

    The hidden states pass through a dense projection, the activation function,
    layer normalization, and finally a linear decoder onto the vocabulary.

    Args:
        vocab_size (int): Number of scores produced per position (vocabulary size).
        hidden_size (int): Width of the incoming hidden states.
        layer_norm_eps (float): Epsilon of the layer normalization.
        transform_act_fn (Callable[[Tensor], Tensor]): Activation applied after the
            dense projection.

    Inputs:
        hidden_states (Tensor): Hidden states whose last dimension is ``hidden_size``.

    Returns:
        Tensor: Scores with the same leading dimensions as ``hidden_states`` and a
            last dimension of size ``vocab_size``.
    """

    def __init__(
        self,
        vocab_size: int = 30522,
        hidden_size: int = 768,
        layer_norm_eps: float = 1e-12,
        transform_act_fn: Callable[[Tensor], Tensor] = nn.functional.gelu,
    ) -> None:
        super().__init__()
        # Attribute names (dense / layer_norm / decoder) form the state_dict keys
        # that checkpoints loaded through load_state_dict rely on; the creation
        # order is kept so random initialization is unchanged.
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.transform_act_fn = transform_act_fn
        self.layer_norm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.decoder = nn.Linear(hidden_size, vocab_size)

    def forward(self, hidden_states: Tensor) -> Tensor:
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.decoder(self.layer_norm(activated))
class ALBEFDecoder(nn.Module):
    """
    Generate the prediction scores for answers from image and question hidden states.

    The answer tokens are embedded, passed through the multimodal encoder with a
    causal + padding self-attention mask and the encoder hidden states as
    cross-attention context, and projected by the prediction head.

    Args:
        text_embeddings (BERTTextEmbeddings): Instantiated BERTTextEmbeddings.
        multimodal_encoder (ALBEFMultimodalEncoder): Instantiated ALBEFMultimodalEncoder.
        prediction_head (PredictionHead): Instantiated PredictionHead.

    Inputs:
        input_ids (Tensor of shape (batch_size, seq_len)):
            Input ids for input text tokens.
        attention_mask (Tensor of shape (batch_size, seq_len)):
            Input attention mask to avoid performing attention on padding token indices.
        encoder_hidden_states (Tensor of shape (batch_size, encoder_seq_len, hidden_size)):
            The encoder hidden states.
        encoder_attention_mask (Tensor of shape (batch_size, encoder_seq_len)):
            The attention mask for encoder hidden states.

    Returns:
        Tensor: Prediction scores for answers.
    """

    def __init__(
        self,
        text_embeddings: BERTTextEmbeddings,
        multimodal_encoder: ALBEFMultimodalEncoder,
        prediction_head: PredictionHead,
    ) -> None:
        super().__init__()
        self.text_embeddings = text_embeddings
        self.multimodal_encoder = multimodal_encoder
        self.prediction_head = prediction_head

    def get_extended_attention_mask_for_decoder(self, attention_mask: Tensor) -> Tensor:
        """
        Apply a causal mask in addition to the padding mask and make the mask broadcastable,
        such that future and masked tokens are ignored.

        Args:
            attention_mask (Tensor of shape (batch_size, seq_len)):
                Padding mask with ones indicating tokens to attend to, zeros for tokens to ignore.

        Returns:
            extended_attention_mask (Tensor of shape (batch_size, 1, seq_len, seq_len)):
                The broadcastable attention mask, with the same dtype as ``attention_mask.dtype``.
        """
        # Unpacking also enforces that the padding mask is 2-D.
        _, seq_length = attention_mask.shape
        causal_mask = get_causal_attention_mask(seq_length).to(attention_mask.device)
        # Broadcast (1, 1, S, S) against (B, 1, 1, S) -> (B, 1, S, S) instead of
        # materializing one copy of the causal mask per batch element.
        extended_attention_mask = (
            causal_mask[None, None, :, :] * attention_mask[:, None, None, :]
        )
        return extended_attention_mask.to(dtype=attention_mask.dtype)

    def forward(
        self,
        input_ids: Tensor,
        attention_mask: Tensor,
        encoder_hidden_states: Tensor,
        encoder_attention_mask: Tensor,
    ) -> Tensor:
        hidden_states = self.text_embeddings(input_ids)
        attention_mask = self.get_extended_attention_mask_for_decoder(attention_mask)
        decoder_output = self.multimodal_encoder(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )
        return self.prediction_head(decoder_output)
class ALBEFModelForVQA(nn.Module):
    """
    ALBEF Model for VQA finetuning and inference.

    Args:
        model (ALBEFModel): Instantiated ALBEFModel.
        answer_decoder (ALBEFDecoder): Instantiated ALBEFDecoder.
        loss (CausalLanguageModelingLoss): Instantiated CausalLanguageModelingLoss.

    Inputs:
        image (Tensor of shape (B, C, H, W)): Image features.
        question (Tensor of shape (B, L)): Question text features.
        question_atts (Tensor of shape (B, L)): Question attention mask.
        answers (Tensor of shape (N, M)): Answer text features.
        answers_atts (Tensor of shape (N, M)): Answer attention mask.
        ans_weights (Optional[Tensor] of shape (N)): Weights for each answer.
            Required if is_train is True.
        ans_lengths (Optional[List[int]] of length B): Number of answers for each question.
            ans_lengths should sum to N.
            Required if is_train is True.
        alpha (Optional[float]): The interpolation value between clm_loss and loss_distill.
            Used if is_train is True. Default is 0.
        k (Optional[int]): The number of answers to return for inference.
            Used if is_train is False. Clamped to N if larger. Default is 128.
        is_train (Optional[bool]): Whether the model is in training. Default is True.

    Returns:
        is_train is True:
            Tensor: The masked language modeling loss for input.
        is_train is False:
            Tuple[Tensor, Tensor]: The ids (row indices into ``answers``) and
                probabilities for the top k predicted answers.

    Raises:
        ValueError: If is_train is True and ans_weights or ans_lengths is None.
    """

    def __init__(
        self,
        model: ALBEFModel,
        answer_decoder: ALBEFDecoder,
        loss: CausalLanguageModelingLoss,
    ) -> None:
        super().__init__()
        self.model = model
        self.answer_decoder = answer_decoder
        self.loss = loss
        # Momentum copy of the decoder; its gradients are removed and it is only
        # updated through momentum_update() in _train_forward.
        self.answer_decoder_m = copy.deepcopy(self.answer_decoder)
        remove_grad(self.answer_decoder_m)

    def _train_forward(
        self,
        image: Tensor,
        question: Tensor,
        question_atts: Tensor,
        answers: Tensor,
        answers_atts: Tensor,
        ans_weights: Tensor,
        ans_lengths: List[int],
        alpha: float,
    ) -> Tensor:
        """
        Forward step for training. Encode the inputs with the ALBEFModel.
        Generate pseudo-targets using answer_decoder_m (momentum decoder model).
        Generate answer predictions using answer_decoder.
        Compute masked language modeling loss of the predictions using answers as labels,
        pseudo-targets as soft-labels, and alpha as their interpolation value.

        Inputs:
            image (Tensor of shape (B, C, H, W)): Image features.
            question (Tensor of shape (B, L)): Question text features.
            question_atts (Tensor of shape (B, L)): Question attention mask.
            answers (Tensor of shape (N, M)): Answer text features.
            answers_atts (Tensor of shape (N, M)): Answer attention mask.
            ans_weights (Tensor of shape (N)): Weights for each answer.
            ans_lengths (List[int] of length B): Number of answers for each question.
                ans_lengths should sum to N.
            alpha (float): The interpolation value between clm_loss and loss_distill.

        Returns:
            Tensor: The weighted loss summed over answers and divided by B.
        """
        # Encode image-question pairs, then repeat each pair's embeddings and mask
        # once per answer so they line up with the N answer rows.
        encoder_outputs = self.model(image, question, question_atts)
        (
            encoder_hidden_states,
            encoder_hidden_states_m,
            encoder_attention_mask,
        ) = self._encoder_hidden_states(
            encoder_outputs.multimodal_embeddings,
            encoder_outputs.multimodal_embeddings_m,
            question_atts,
            ans_lengths,
        )

        # Update the momentum decoder and use it to generate pseudo-targets.
        with torch.no_grad():
            momentum_update(
                self.answer_decoder, self.answer_decoder_m, self.model.momentum
            )
            prediction_scores_m = self.answer_decoder_m(
                input_ids=answers,
                attention_mask=answers_atts,
                encoder_hidden_states=encoder_hidden_states_m,
                encoder_attention_mask=encoder_attention_mask,
            )

        # Generate answer predictions with the trainable decoder.
        prediction_scores = self.answer_decoder(
            input_ids=answers,
            attention_mask=answers_atts,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
        )

        # Token id 0 is replaced by the loss's mask_token_id; presumably 0 is the
        # padding id and mask_token_id is the label the loss ignores -- TODO confirm.
        labels = answers.masked_fill(answers == 0, self.loss.mask_token_id)
        loss = self.loss(labels, prediction_scores, prediction_scores_m, alpha)
        # Weight each answer's loss, then normalize by the number of images (B).
        loss = ans_weights * loss
        return loss.sum() / image.size(0)

    def _eval_forward(
        self,
        image: Tensor,
        question: Tensor,
        question_atts: Tensor,
        answers: Tensor,
        answer_atts: Tensor,
        k: int = 128,
    ) -> Tuple[Tensor, Tensor]:
        """
        Forward step for evaluation. Encode the inputs with the ALBEFModel.
        Generate answer autoregressively using the decoder, starting with the [CLS] token.
        The candidates are first ranked by the probability of their first token; the
        top k are then re-ranked by the log-probability of the whole answer sequence.
        Compute the answer ids and their respective probabilities of the top k predictions.

        Inputs:
            image (Tensor of shape (B, C, H, W)): Image features.
            question (Tensor of shape (B, L)): Question text features.
            question_atts (Tensor of shape (B, L)): Question attention mask.
            answers (Tensor of shape (N, M)): Answer text features.
            answer_atts (Tensor of shape (N, M)): Answer attention mask.
            k (int): The number of answers to return for inference.
                Values larger than N are clamped to N.

        Returns:
            Tuple[Tensor, Tensor]: The ids (row indices into ``answers``) and
                probabilities for the top k predicted answers, each of shape (B, k).
        """
        # Multimodal embeddings of the image-question pairs are the decoder's
        # cross-attention context.
        encoder_outputs = self.model(image, question, question_atts)
        question_embeds = encoder_outputs.multimodal_embeddings

        # topk() raises when asked for more candidates than there are answers.
        k = min(k, answers.size(0))
        num_ques = question.size(0)

        # Use the first token of the first answer as every question's start token.
        start_ids = answers[0, 0].repeat(num_ques, 1)
        start_atts = torch.ones(start_ids.shape, device=image.device)
        prediction_scores = self.answer_decoder(
            input_ids=start_ids,
            attention_mask=start_atts,
            encoder_hidden_states=question_embeds,
            encoder_attention_mask=question_atts,
        )

        # Probability of each candidate's token at position 1, predicted from the
        # start token; keep the k most likely candidates per question.
        logits = prediction_scores[:, 0, :]
        answer_first_token = answers[:, 1]
        prob_first_token = F.softmax(logits, dim=1).index_select(
            dim=1, index=answer_first_token
        )
        topk_probs, topk_ids = prob_first_token.topk(k, dim=1)

        # Gather the selected candidates question-major into (B * k, M) rows.
        flat_ids = topk_ids.reshape(-1)
        input_ids = answers.index_select(dim=0, index=flat_ids)
        input_atts = answer_atts.index_select(dim=0, index=flat_ids)
        targets_ids = input_ids.masked_fill(input_ids == 0, self.loss.mask_token_id)

        question_states = question_embeds.repeat_interleave(k, dim=0)
        question_atts = question_atts.repeat_interleave(k, dim=0)
        prediction_scores = self.answer_decoder(
            input_ids=input_ids,
            attention_mask=input_atts,
            encoder_hidden_states=question_states,
            encoder_attention_mask=question_atts,
        )
        answer_loss = self.loss(targets_ids, prediction_scores)
        answer_loss = answer_loss.view(input_ids.size(0), -1)

        # Chain rule: the first-token log-probability plus -answer_loss (used as the
        # log-probability of the rest of the sequence) gives each answer's score.
        log_probs = torch.cat([topk_probs.view(-1, 1).log(), -answer_loss], dim=1)
        log_probs_sum = log_probs.sum(1).view(num_ques, k)
        topk_probs = F.softmax(log_probs_sum, dim=-1)

        # Re-rank and map the re-ranked positions back to rows of ``answers``.
        topk_probs, rerank_id = topk_probs.topk(k, dim=1)
        topk_ids = torch.gather(topk_ids, 1, rerank_id)
        return topk_ids, topk_probs

    def _encoder_hidden_states(
        self,
        multimodal_embeds: Tensor,
        multimodal_embeds_m: Tensor,
        question_atts: Tensor,
        ans_lengths: List[int],
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Repeat each image-question input's embedding and mask to match the number of answers it has.

        Args:
            multimodal_embeds (Tensor): Image-question embeddings.
            multimodal_embeds_m (Tensor): Image-question embeddings from the momentum model.
            question_atts (Tensor): Question attention mask.
            ans_lengths (List[int]): The number of answers each image-question input has.

        Returns:
            encoder_hidden_states (Tensor): Image-question embeddings after the repetition.
            encoder_hidden_states_m (Tensor): Image-question embeddings from the momentum model after the repetition.
            encoder_attention_mask (Tensor): Question attention mask after the repetition.
        """
        encoder_hidden_states = torch.stack(
            [multimodal_embeds[b] for b, n in enumerate(ans_lengths) for _ in range(n)]
        )
        encoder_attention_mask = torch.stack(
            [question_atts[b] for b, n in enumerate(ans_lengths) for _ in range(n)]
        )
        with torch.no_grad():
            encoder_hidden_states_m = torch.stack(
                [
                    multimodal_embeds_m[b]
                    for b, n in enumerate(ans_lengths)
                    for _ in range(n)
                ]
            )
        return encoder_hidden_states, encoder_hidden_states_m, encoder_attention_mask

    def forward(
        self,
        image: Tensor,
        question: Tensor,
        question_atts: Tensor,
        answers: Tensor,
        answers_atts: Tensor,
        ans_weights: Optional[Tensor] = None,
        ans_lengths: Optional[List[int]] = None,
        alpha: Optional[float] = 0.0,
        k: Optional[int] = 128,
        is_train: Optional[bool] = True,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        if is_train:
            # Fail fast instead of crashing deep inside training after the encoder pass.
            if ans_weights is None or ans_lengths is None:
                raise ValueError(
                    "ans_weights and ans_lengths are required when is_train is True"
                )
            return self._train_forward(
                image,
                question,
                question_atts,
                answers,
                answers_atts,
                ans_weights,
                ans_lengths,
                alpha,
            )
        return self._eval_forward(
            image,
            question,
            question_atts,
            answers,
            answers_atts,
            k,
        )
class ALBEFModelForRetrieval(nn.Module):
    """
    ALBEF Model for Retrieval finetuning and inference.
    In training mode, the forward step computes image-text contrastive loss and
    image-text matching loss.
    In evaluation mode, the forward step takes 3 types of input:
        image: encode image input, project and normalize the embeddings.
        text: encode text input, project and normalize the embeddings.
        multimodal: create multimodal embeddings from image and text
            embeddings, and compute image-text matching scores.

    Args:
        model_with_similarity (ALBEFModelWithSimilarity): Instantiated ALBEFModelWithSimilarity.
        itc_loss (ImageTextContrastiveLoss): Instantiated ImageTextContrastiveLoss.
        hidden_size (int): Dimensionality of encoder outputs.

    Inputs:
        image (Optional[Tensor]): Required if is_train is True or input_type is
            "image" or "multimodal". Raw images of shape (B, C, H, W) for training
            and input_type "image". For input_type "multimodal" it is passed straight
            to the multimodal encoder, i.e. image embeddings (such as the first
            output of the "image" mode).
        text (Optional[Tensor]): Required if is_train is True or input_type is
            "text" or "multimodal". Token ids of shape (B, L) for training and
            input_type "text". For input_type "multimodal" it is passed straight to
            the multimodal encoder, i.e. text embeddings (such as the first output
            of the "text" mode).
        text_atts (Optional[Tensor] of shape (B, L)): Text attention mask.
            Required if is_train is True.
            Required if input_type is "text" or "multimodal".
        idx (Optional[Tensor] of shape (B)): Identifier for each image sample.
            Required if is_train is True.
        alpha (Optional[float]): Interpolation value passed to itc_loss together with
            the momentum similarities (presumably the distillation weight -- confirm
            in ImageTextContrastiveLoss). Default is 0.
        input_type (Optional[str]): "image", "text", or "multimodal" indicating the encoding type.
            Required if is_train is False.
        is_train (Optional[bool]): Whether the model is in training.
            Default is True.

    Returns:
        is_train is True:
            Tensor: The sum of itc loss and itm loss.
        is_train is False:
            input_type is "image":
                Tuple[Tensor, Tensor]: Image embeddings and projected image features.
            input_type is "text":
                Tuple[Tensor, Tensor]: Text embeddings and projected text features.
            input_type is "multimodal":
                Tensor: Scores for the retrieval task.

    Raises:
        ValueError: If is_train is False and input_type is not one of the three types.
    """

    def __init__(
        self,
        model_with_similarity: ALBEFModelWithSimilarity,
        itc_loss: ImageTextContrastiveLoss,
        hidden_size: int,
    ) -> None:
        super().__init__()
        self.model_with_similarity = model_with_similarity
        self.itc_loss = itc_loss
        # Binary image-text matching classifier; column 1 is the "matched" logit.
        self.itm_head = nn.Linear(hidden_size, 2)

    def _train_forward(
        self,
        image: Tensor,
        text: Tensor,
        text_atts: Tensor,
        idx: Tensor,
        alpha: float,
    ) -> Tensor:
        """Return the image-text contrastive loss plus the image-text matching loss."""
        encoder_output = self.model_with_similarity(image, text, text_atts, idx)

        # Image-text contrastive loss.
        similarity_outputs = encoder_output.similarity
        similarity_targets = encoder_output.sim_targets
        itc_loss = self.itc_loss(
            similarity_outputs.sim_i2t,
            similarity_outputs.sim_t2i,
            similarity_outputs.sim_i2t_m,
            similarity_outputs.sim_t2i_m,
            similarity_targets,
            alpha,
        )

        # Image-text matching loss on the first-token embeddings: positive pairs
        # are labelled 1, negative pairs 0.
        pos_embeddings = encoder_output.multimodal_embeddings[:, 0, :]
        neg_embeddings = encoder_output.multimodal_embeddings_neg[:, 0, :]
        vl_embeddings = torch.cat([pos_embeddings, neg_embeddings], dim=0)
        vl_output = self.itm_head(vl_embeddings)
        device = vl_embeddings.device
        itm_labels = torch.cat(
            [
                torch.ones(pos_embeddings.size(0), dtype=torch.long, device=device),
                torch.zeros(neg_embeddings.size(0), dtype=torch.long, device=device),
            ],
            dim=0,
        )
        itm_loss = F.cross_entropy(vl_output, itm_labels)
        return itc_loss + itm_loss

    def _encode_image(
        self,
        image: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Return image embeddings and the L2-normalized projection of their first token."""
        image_embed = self.model_with_similarity.albef_model.vision_encoder(image)
        image_feat = F.normalize(
            self.model_with_similarity.vision_proj(image_embed[:, 0, :]), dim=-1
        )
        return image_embed, image_feat

    def _encode_text(
        self,
        text: Tensor,
        text_atts: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Return text embeddings and the L2-normalized projection of their first token."""
        text_embed = self.model_with_similarity.albef_model.text_encoder(
            text, text_atts
        ).last_hidden_state
        text_feat = F.normalize(
            self.model_with_similarity.text_proj(text_embed[:, 0, :]), dim=-1
        )
        return text_embed, text_feat

    def _image_text_matching_score(
        self,
        image: Tensor,
        text: Tensor,
        text_atts: Tensor,
    ) -> Tensor:
        """Fuse text and image embeddings and return the ITM head's "matched" logit."""
        multimodal_embeds = self.model_with_similarity.albef_model.multimodal_encoder(
            text,
            text_atts,
            image,
        )
        return self.itm_head(multimodal_embeds[:, 0, :])[:, 1]

    def _eval_forward(
        self,
        input_type: str,
        image: Optional[Tensor],
        text: Optional[Tensor],
        text_atts: Optional[Tensor],
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Dispatch to the image, text, or multimodal evaluation path."""
        if input_type == "image":
            assert image is not None, "image input tensor cannot be None"
            return self._encode_image(image)
        elif input_type == "text":
            assert (
                text is not None and text_atts is not None
            ), "text and text attention mask cannot be None"
            return self._encode_text(text, text_atts)
        elif input_type == "multimodal":
            assert (
                image is not None and text is not None and text_atts is not None
            ), "image embeddings, text embeddings, and text attention mask cannot be None"
            return self._image_text_matching_score(image, text, text_atts)
        else:
            raise ValueError("input_type must be image, text, or multimodal")

    def forward(
        self,
        image: Optional[Tensor] = None,
        text: Optional[Tensor] = None,
        text_atts: Optional[Tensor] = None,
        idx: Optional[Tensor] = None,
        alpha: Optional[float] = 0.0,
        input_type: Optional[str] = None,
        is_train: Optional[bool] = True,
    ) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        if is_train:
            return self._train_forward(
                image,
                text,
                text_atts,
                idx,
                alpha,
            )
        return self._eval_forward(
            input_type,
            image,
            text,
            text_atts,
        )
def albef_model_for_vqa(
    config: Dict[str, Any], pretrained: bool = False
) -> ALBEFModelForVQA:
    """
    Assemble an ALBEFModelForVQA from a configuration dictionary.

    Args:
        config (Dict[str, Any]): Keyword-argument dicts under "vision_encoder_args",
            "text_encoder_args", "multimodal_encoder_args" (used for both the
            question encoder and the answer decoder's encoder), "text_embeddings_args"
            and "prediction_head_args".
        pretrained (bool): If True, download the "vqa" checkpoint (mapped to CPU)
            and load it into the assembled model.

    Returns:
        ALBEFModelForVQA: The assembled model.
    """
    # Submodules are created in a fixed order so random initialization under a
    # given seed stays the same.
    image_enc = ALBEFVisionEncoder(**config["vision_encoder_args"])
    text_enc = bert_text_encoder(**config["text_encoder_args"])
    question_fusion = ALBEFMultimodalEncoder(**config["multimodal_encoder_args"])
    answer_embeddings = BERTTextEmbeddings(**config["text_embeddings_args"])
    answer_fusion = ALBEFMultimodalEncoder(**config["multimodal_encoder_args"])
    head = PredictionHead(**config["prediction_head_args"])

    vqa_model = ALBEFModelForVQA(
        model=ALBEFModel(image_enc, text_enc, question_fusion),
        answer_decoder=ALBEFDecoder(answer_embeddings, answer_fusion, head),
        loss=CausalLanguageModelingLoss(),
    )
    if not pretrained:
        return vqa_model

    state_dict = torch.hub.load_state_dict_from_url(
        _ALBEF_PRETRAINED_URLS["vqa"], map_location="cpu"
    )
    vqa_model.load_state_dict(state_dict)
    return vqa_model
def albef_model_for_retrieval(
    config: Dict[str, Any], pretrained: bool = False
) -> ALBEFModelForRetrieval:
    """
    Assemble an ALBEFModelForRetrieval from a configuration dictionary.

    Args:
        config (Dict[str, Any]): Keyword-argument dicts under "vision_encoder_args",
            "text_encoder_args", "multimodal_encoder_args", "projection_args" (used
            for both the vision and the text projection) and "similarity_args", plus
            the integer "hidden_size" of the image-text matching head.
        pretrained (bool): If True, download the "retrieval" checkpoint (mapped to
            CPU) and load it into the assembled model.

    Returns:
        ALBEFModelForRetrieval: The assembled model.
    """
    # Submodules are created in a fixed order so random initialization under a
    # given seed stays the same.
    image_enc = ALBEFVisionEncoder(**config["vision_encoder_args"])
    text_enc = bert_text_encoder(**config["text_encoder_args"])
    fusion = ALBEFMultimodalEncoder(**config["multimodal_encoder_args"])
    image_projection = nn.Linear(**config["projection_args"])
    text_projection = nn.Linear(**config["projection_args"])

    backbone = ALBEFModel(image_enc, text_enc, fusion)
    retrieval_model = ALBEFModelForRetrieval(
        model_with_similarity=ALBEFModelWithSimilarity(
            backbone, image_projection, text_projection, **config["similarity_args"]
        ),
        itc_loss=ImageTextContrastiveLoss(),
        hidden_size=config["hidden_size"],
    )
    if not pretrained:
        return retrieval_model

    state_dict = torch.hub.load_state_dict_from_url(
        _ALBEF_PRETRAINED_URLS["retrieval"], map_location="cpu"
    )
    retrieval_model.load_state_dict(state_dict)
    return retrieval_model
|