File size: 9,586 Bytes
1a597d0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 |
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel
from collections.abc import Sequence
from monai.networks.blocks.patchembedding import PatchEmbeddingBlock
from monai.networks.blocks.transformerblock import TransformerBlock
try:
import torch.distributed.nn
from torch import distributed as dist
has_distributed = True
except ImportError:
has_distributed = False
from .configuration_m3d_clip import M3DCLIPConfig
from transformers import BertModel, BertConfig
def gather_features(
    image_features,
    text_features,
    local_loss=False,
    gather_with_grad=True,
    rank=0,
    world_size=1,
):
    """
    Collect image and text features from every process and concatenate them.

    Args:
        image_features: local image feature tensor to share with all processes.
        text_features: local text feature tensor to share with all processes.
        local_loss: only consulted when ``gather_with_grad`` is False. When it is
            False, the slot belonging to ``rank`` is overwritten with the original
            local tensor after gathering.
        gather_with_grad: when True, gather through ``torch.distributed.nn.all_gather``;
            otherwise gather into pre-allocated buffers with ``dist.all_gather``.
        rank: index of the local slot in the gathered list (buffer path only).
        world_size: number of buffers to allocate (buffer path only).

    Returns:
        A tuple ``(all_image_features, all_text_features)``, each being the
        gathered tensors concatenated along dim 0.

    Raises:
        AssertionError: if ``torch.distributed`` could not be imported.
    """
    assert has_distributed, 'torch.distributed did not import correctly, please use a PyTorch version with support.'

    if gather_with_grad:
        # Collectives are issued image first, then text, on every process.
        image_parts = torch.distributed.nn.all_gather(image_features)
        text_parts = torch.distributed.nn.all_gather(text_features)
        return torch.cat(image_parts, dim=0), torch.cat(text_parts, dim=0)

    # Buffer path: one receive buffer per process for each modality.
    image_parts = [torch.zeros_like(image_features) for _ in range(world_size)]
    text_parts = [torch.zeros_like(text_features) for _ in range(world_size)]
    dist.all_gather(image_parts, image_features)
    dist.all_gather(text_parts, text_features)

    if not local_loss:
        # The gathered copies carry no gradient, so put the original local
        # tensors back into this rank's slot to keep them in the autograd graph.
        image_parts[rank] = image_features
        text_parts[rank] = text_features

    return torch.cat(image_parts, dim=0), torch.cat(text_parts, dim=0)
class ViT(nn.Module):
"""
Vision Transformer (ViT), based on: "Dosovitskiy et al.,
An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>"
ViT supports Torchscript but only works for Pytorch after 1.8.
"""
def __init__(
self,
in_channels: int,
img_size: Sequence[int] | int,
patch_size: Sequence[int] | int,
hidden_size: int = 768,
mlp_dim: int = 3072,
num_layers: int = 12,
num_heads: int = 12,
pos_embed: str = "conv",
classification: bool = False,
num_classes: int = 2,
dropout_rate: float = 0.0,
spatial_dims: int = 3,
post_activation="Tanh",
qkv_bias: bool = False,
save_attn: bool = False,
) -> None:
"""
Args:
in_channels (int): dimension of input channels.
img_size (Union[Sequence[int], int]): dimension of input image.
patch_size (Union[Sequence[int], int]): dimension of patch size.
hidden_size (int, optional): dimension of hidden layer. Defaults to 768.
mlp_dim (int, optional): dimension of feedforward layer. Defaults to 3072.
num_layers (int, optional): number of transformer blocks. Defaults to 12.
num_heads (int, optional): number of attention heads. Defaults to 12.
pos_embed (str, optional): position embedding layer type. Defaults to "conv".
classification (bool, optional): bool argument to determine if classification is used. Defaults to False.
num_classes (int, optional): number of classes if classification is used. Defaults to 2.
dropout_rate (float, optional): faction of the input units to drop. Defaults to 0.0.
spatial_dims (int, optional): number of spatial dimensions. Defaults to 3.
post_activation (str, optional): add a final acivation function to the classification head
when `classification` is True. Default to "Tanh" for `nn.Tanh()`.
Set to other values to remove this function.
qkv_bias (bool, optional): apply bias to the qkv linear layer in self attention block. Defaults to False.
save_attn (bool, optional): to make accessible the attention in self attention block. Defaults to False.
Examples::
# for single channel input with image size of (96,96,96), conv position embedding and segmentation backbone
>>> net = ViT(in_channels=1, img_size=(96,96,96), pos_embed='conv')
# for 3-channel with image size of (128,128,128), 24 layers and classification backbone
>>> net = ViT(in_channels=3, img_size=(128,128,128), pos_embed='conv', classification=True)
# for 3-channel with image size of (224,224), 12 layers and classification backbone
>>> net = ViT(in_channels=3, img_size=(224,224), pos_embed='conv', classification=True, spatial_dims=2)
"""
super().__init__()
if not (0 <= dropout_rate <= 1):
raise ValueError("dropout_rate should be between 0 and 1.")
if hidden_size % num_heads != 0:
raise ValueError("hidden_size should be divisible by num_heads.")
self.hidden_size = hidden_size
self.classification = classification
self.patch_embedding = PatchEmbeddingBlock(
in_channels=in_channels,
img_size=img_size,
patch_size=patch_size,
hidden_size=hidden_size,
num_heads=num_heads,
pos_embed=pos_embed,
dropout_rate=dropout_rate,
spatial_dims=spatial_dims,
)
self.blocks = nn.ModuleList(
[
TransformerBlock(hidden_size, mlp_dim, num_heads, dropout_rate, qkv_bias, save_attn)
for i in range(num_layers)
]
)
self.norm = nn.LayerNorm(hidden_size)
if self.classification:
self.cls_token = nn.Parameter(torch.zeros(1, 1, hidden_size))
# if post_activation == "Tanh":
# self.classification_head = nn.Sequential(nn.Linear(hidden_size, num_classes), nn.Tanh())
# else:
# self.classification_head = nn.Linear(hidden_size, num_classes) # type: ignore
def forward(self, x):
x = self.patch_embedding(x)
if hasattr(self, "cls_token"):
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
hidden_states_out = []
for blk in self.blocks:
x = blk(x)
hidden_states_out.append(x)
x = self.norm(x)
# if hasattr(self, "classification_head"):
# x = self.classification_head(x[:, 0])
return x, hidden_states_out
class M3DCLIP(PreTrainedModel):
    """
    Contrastive image-text model that pairs the ``ViT`` vision encoder with a
    pretrained BERT language encoder.

    Both encoders are followed by a linear projection and L2 normalisation;
    ``forward`` compares the first token of each modality and returns a
    symmetric cross-entropy loss over the similarity logits.
    """

    config_class = M3DCLIPConfig

    def __init__(self, config):
        """
        Build the encoders and projection layers from ``config``.

        The config fields read here are the ViT hyper-parameters (``in_channels``,
        ``img_size``, ``patch_size``, ``hidden_size``, ``mlp_dim``, ``num_layers``,
        ``num_heads``, ``pos_embed``, ``dropout_rate``, ``spatial_dims``),
        ``language_model_name_or_path``, ``local_loss`` and ``gather_loss``.
        """
        super().__init__(config)

        vit_fields = (
            "in_channels",
            "img_size",
            "patch_size",
            "hidden_size",
            "mlp_dim",
            "num_layers",
            "num_heads",
            "pos_embed",
            "dropout_rate",
            "spatial_dims",
        )
        vit_kwargs = {name: getattr(config, name) for name in vit_fields}
        # classification=True makes the ViT prepend a class token at position 0.
        self.vision_encoder = ViT(classification=True, **vit_kwargs)

        # Pretrained weights are loaded at construction time.
        self.language_encoder = BertModel.from_pretrained(config.language_model_name_or_path)

        width = config.hidden_size
        self.mm_vision_proj = nn.Linear(width, width)
        # NOTE(review): assumes the language encoder's hidden width equals
        # config.hidden_size -- confirm against the chosen checkpoint.
        self.mm_language_proj = nn.Linear(width, width)

        initial_scale = np.log(1 / 0.07)
        self.logit_scale = nn.Parameter(torch.ones([]) * initial_scale)

        self.local_loss = config.local_loss
        self.gather_loss = config.gather_loss

    def encode_image(self, image):
        """Return projected, L2-normalised (last dim) features for every ViT output token."""
        tokens, _ = self.vision_encoder(image)
        projected = self.mm_vision_proj(tokens)
        return F.normalize(projected, dim=-1)

    def encode_text(self, input_id, attention_mask):
        """Return projected, L2-normalised (last dim) features for every text token."""
        encoder_output = self.language_encoder(input_id, attention_mask=attention_mask)
        projected = self.mm_language_proj(encoder_output["last_hidden_state"])
        return F.normalize(projected, dim=-1)

    def forward(self, images, input_ids, attention_mask, labels, **kwargs):
        """
        Compute the contrastive loss between images and texts.

        Args:
            images: input passed to the vision encoder.
            input_ids: token ids passed to the language encoder.
            attention_mask: attention mask passed to the language encoder.
            labels: class-index targets given to ``F.cross_entropy`` for both
                logit matrices.
            **kwargs: accepted and ignored.

        Returns:
            dict with ``"loss"`` (mean of the image->text and text->image
            cross-entropies) and ``"logits"`` (mean of the two logit matrices).
        """
        # Index 0 along dim 1 is the class token on the image side
        # (presumably the [CLS] position on the text side -- TODO confirm).
        image_features = self.encode_image(images)[:, 0]
        text_features = self.encode_text(input_ids, attention_mask)[:, 0]

        # NOTE(review): the stored scale is multiplied in directly, not exponentiated.
        scale = self.logit_scale

        if not self.gather_loss:
            logits_per_image = scale * image_features @ text_features.T
            logits_per_text = scale * text_features @ image_features.T
        else:
            all_image_features, all_text_features = gather_features(image_features, text_features)
            if self.local_loss:
                # Local rows against the features gathered from every process.
                logits_per_image = scale * image_features @ all_text_features.T
                logits_per_text = scale * text_features @ all_image_features.T
            else:
                logits_per_image = scale * all_image_features @ all_text_features.T
                logits_per_text = logits_per_image.T

        image_loss = F.cross_entropy(logits_per_image, labels)
        text_loss = F.cross_entropy(logits_per_text, labels)
        loss = (image_loss + text_loss) / 2

        return {
            "loss": loss,
            "logits": (logits_per_image + logits_per_text) / 2.0,
        }
|