File size: 5,817 Bytes
1cb4998 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
# ------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License (MIT). See LICENSE in the repo root for license information.
# ------------------------------------------------------------------------------------------
from dataclasses import dataclass
from typing import Any, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import nn
from torch import Tensor as T
from transformers import BertForMaskedLM
from transformers.modeling_outputs import ModelOutput

from .configuration_cxrbert import CXRBertConfig
# Tuple form of the `CXRBertModel.forward` result, returned when `return_dict` is false. Entries are, in order:
# (last_hidden_state, logits, cls_projected_embedding, hidden_states, attentions).
# NOTE(review): the alias types every entry as a tensor, but `forward` places None at index 2 when the projected
# embedding is not requested, and indices 3 and 4 are passed through from the parent model output - confirm
# whether the alias should be loosened.
BERTTupleOutput = Tuple[T, T, T, T, T]
@dataclass
class CXRBertOutput(ModelOutput):
    """
    Output container returned by `CXRBertModel.forward` when a dictionary-style output is requested.

    The class is declared as a dataclass because `ModelOutput` subclasses are expected to be dataclasses; recent
    releases of the transformers library refuse to instantiate a subclass that is not one. `ModelOutput` also allows
    at most one field without a default, which is why every field after `last_hidden_state` defaults to None.

    :param last_hidden_state: Last entry of the hidden states produced by the underlying masked language model.
    :param logits: Logits of the underlying masked language model.
    :param cls_projected_embedding: Projection head output for the first token of the last hidden state,
        or None when the projection was not requested.
    :param hidden_states: Hidden states of the underlying model, or None when they were not requested.
    :param attentions: Attentions as returned by the underlying model.
    """

    last_hidden_state: torch.FloatTensor
    logits: Optional[torch.FloatTensor] = None
    cls_projected_embedding: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
class BertProjectionHead(nn.Module):
    """
    Projection head applied to the BERT CLS token embedding. Its layout resembles `BertPredictionHeadTransform`
    from the HuggingFace library: linear layer, GELU activation, layer normalisation, and a final linear layer.

    :param config: CXRBertConfig, read for `hidden_size` and `projection_size`.
    :return: (batch_size, output_size)
    """

    def __init__(self, config: CXRBertConfig) -> None:
        super().__init__()
        input_width = config.hidden_size
        projection_width = config.projection_size

        # Attribute names are part of the checkpoint layout (state dict keys), so they are kept as they are.
        self.dense_to_hidden = nn.Linear(input_width, projection_width)
        self.transform_act_fn = nn.functional.gelu
        self.LayerNorm = nn.LayerNorm(projection_width, eps=1e-12)
        self.dense_to_output = nn.Linear(projection_width, projection_width)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Project the given embeddings into the joint latent space.

        :param hidden_states: Embeddings whose last dimension is `config.hidden_size`.
        :return: Projected embeddings whose last dimension is `config.projection_size`.
        """
        projected = self.dense_to_hidden(hidden_states)
        activated = self.transform_act_fn(projected)
        normalised = self.LayerNorm(activated)
        return self.dense_to_output(normalised)
class CXRBertModel(BertForMaskedLM):
    """
    Implements the CXR-BERT model outlined in the manuscript:
    Boecking et al. "Making the Most of Text Semantics to Improve Biomedical Vision-Language Processing", 2022
    https://arxiv.org/abs/2204.09817

    Extends the HuggingFace BertForMaskedLM model by adding a separate projection head. The projected "[CLS]" token
    embedding is used to align the latent vectors of image and text modalities.
    """

    config_class = CXRBertConfig

    def __init__(self, config: CXRBertConfig):
        """
        :param config: Model configuration, handed to both the parent constructor and the projection head.
        """
        super().__init__(config)

        self.cls_projection_head = BertProjectionHead(config)
        # Called after the projection head is attached - presumably so that the head takes part in the
        # weight initialisation as well (confirm against the parent implementation).
        self.init_weights()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_cls_projected_embedding: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs: Any
    ) -> Union[BERTTupleOutput, CXRBertOutput]:
        """
        Run the masked language model and, optionally, the CLS projection head.

        :param input_ids: (batch_size, sequence_length)
        :param attention_mask: (batch_size, sequence_length)
        :param token_type_ids: Forwarded unchanged to the parent model.
        :param position_ids: Forwarded unchanged to the parent model.
        :param head_mask: Forwarded unchanged to the parent model.
        :param inputs_embeds: Forwarded unchanged to the parent model.
        :param output_attentions: Forwarded unchanged to the parent model.
        :param output_hidden_states: Whether the dictionary-style output should expose all hidden states.
            When None, `config.output_hidden_states` decides.
        :param output_cls_projected_embedding: Whether to compute the projected embedding of the first token.
        :param return_dict: Whether to return a `CXRBertOutput` rather than a tuple. When None,
            `config.use_return_dict` decides.
        :param kwargs: Accepted for call compatibility; not used by this method.
        :return: Either a `CXRBertOutput`, or the tuple
            (last_hidden_state, logits, cls_projected_embedding, hidden_states, attentions).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # Resolve an unspecified flag against the config, in the same way as `return_dict` above. Treating None as
        # False would silently ignore `config.output_hidden_states`.
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states

        # Hidden states are always requested from the parent because the last one feeds the projection head,
        # and the parent is always asked for a dictionary-style output so fields can be read by name.
        bert_for_masked_lm_output = super().forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
        )

        last_hidden_state = bert_for_masked_lm_output.hidden_states[-1]
        # Index 0 along the sequence dimension selects the first token of every sequence.
        cls_projected_embedding = (
            self.cls_projection_head(last_hidden_state[:, 0, :]) if output_cls_projected_embedding else None
        )

        if return_dict:
            return CXRBertOutput(
                last_hidden_state=last_hidden_state,
                logits=bert_for_masked_lm_output.logits,
                cls_projected_embedding=cls_projected_embedding,
                hidden_states=bert_for_masked_lm_output.hidden_states if output_hidden_states else None,
                attentions=bert_for_masked_lm_output.attentions,
            )

        # The tuple form has always carried the hidden states regardless of `output_hidden_states`;
        # this is kept as it is so that existing callers indexing into the tuple keep working.
        return (
            last_hidden_state,
            bert_for_masked_lm_output.logits,
            cls_projected_embedding,
            bert_for_masked_lm_output.hidden_states,
            bert_for_masked_lm_output.attentions,
        )

    def get_projected_text_embeddings(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """
        Returns l2-normalised projected cls token embeddings for the given input token ids and attention mask.
        The joint latent space is trained using a contrastive objective between image and text data modalities.

        :param input_ids: (batch_size, sequence_length)
        :param attention_mask: (batch_size, sequence_length)
        :return: (batch_size, projection_size)
        """
        outputs = self.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_cls_projected_embedding=True,
            return_dict=True,
        )
        # Both asserts narrow the declared types for static checkers; the arguments passed above
        # request a dictionary-style output that includes the projected embedding.
        assert isinstance(outputs, CXRBertOutput)
        assert outputs.cls_projected_embedding is not None

        normalized_cls_embedding = F.normalize(outputs.cls_projected_embedding, dim=1)
        return normalized_cls_embedding
|