Gengzigang commited on
Commit
3831a1e
·
1 Parent(s): 9d797ea
Files changed (1) hide show
  1. modeling_clip.py +108 -60
modeling_clip.py CHANGED
@@ -39,7 +39,6 @@ from transformers.utils import (
39
  )
40
  from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
41
 
42
-
43
  if is_flash_attn_2_available():
44
  from transformers.modeling_flash_attention_utils import _flash_attention_forward
45
 
@@ -603,16 +602,15 @@ class CLIPPreTrainedModel(PreTrainedModel):
603
  fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
604
  nn.init.normal_(module.fc1.weight, std=fc_std)
605
  nn.init.normal_(module.fc2.weight, std=in_proj_std)
606
- elif isinstance(module, CLIPModel):
607
- pass
608
  # nn.init.normal_(
609
  # module.text_projection.weight,
610
  # std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
611
  # )
612
- # nn.init.normal_(
613
- # module.visual_projection.weight,
614
- # std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
615
- # )
616
  elif isinstance(module, CLIPVisionModelWithProjection):
617
  nn.init.normal_(
618
  module.visual_projection.weight,
@@ -1112,80 +1110,97 @@ class CLIPVisionModel(CLIPPreTrainedModel):
1112
 
1113
 
1114
  @add_start_docstrings(CLIP_START_DOCSTRING)
1115
- class CLIPModel(CLIPPreTrainedModel):
1116
  config_class = CLIPConfig
1117
  _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer", "CLIPVisionEmbeddings"]
1118
 
1119
  def __init__(self, config: CLIPConfig):
1120
  super().__init__(config)
 
 
 
 
 
 
1121
  if not isinstance(config.vision_config, CLIPVisionConfig):
1122
  raise TypeError(
1123
  "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
1124
  f" {type(config.vision_config)}."
1125
  )
1126
 
 
1127
  vision_config = config.vision_config
1128
 
1129
  self.projection_dim = config.projection_dim
 
1130
  self.vision_embed_dim = vision_config.hidden_size
 
 
 
 
 
 
1131
 
1132
  vision_model = CLIPVisionModel._from_config(vision_config, attn_implementation=config._attn_implementation)
1133
  self.vision_model = vision_model.vision_model
1134
 
1135
- # self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
1136
- scale = self.vision_embed_dim ** -0.5
1137
- self.visual_projection = nn.Parameter(scale * torch.randn(self.vision_embed_dim, self.projection_dim))
1138
  self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
1139
 
1140
  # Initialize weights and apply final processing
1141
  self.post_init()
1142
-
1143
- @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
1144
- def get_text_features(
1145
- self,
1146
- input_ids: Optional[torch.Tensor] = None,
1147
- attention_mask: Optional[torch.Tensor] = None,
1148
- position_ids: Optional[torch.Tensor] = None,
1149
- output_attentions: Optional[bool] = None,
1150
- output_hidden_states: Optional[bool] = None,
1151
- return_dict: Optional[bool] = None,
1152
- ) -> torch.FloatTensor:
1153
- r"""
1154
- Returns:
1155
- text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
1156
- applying the projection layer to the pooled output of [`CLIPTextModel`].
1157
-
1158
- Examples:
1159
-
1160
- ```python
1161
- >>> from transformers import AutoTokenizer, CLIPModel
1162
-
1163
- >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
1164
- >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
1165
-
1166
- >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
1167
- >>> text_features = model.get_text_features(**inputs)
1168
- ```"""
1169
- # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
1170
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1171
- output_hidden_states = (
1172
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1173
- )
1174
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1175
-
1176
- text_outputs = self.text_model(
1177
- input_ids=input_ids,
1178
- attention_mask=attention_mask,
1179
- position_ids=position_ids,
1180
- output_attentions=output_attentions,
1181
- output_hidden_states=output_hidden_states,
1182
- return_dict=return_dict,
1183
- )
1184
-
1185
- pooled_output = text_outputs[1]
1186
- text_features = self.text_projection(pooled_output)
1187
-
1188
- return text_features
 
 
 
 
1189
 
1190
  @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
1191
  def get_image_features(
@@ -1232,7 +1247,7 @@ class CLIPModel(CLIPPreTrainedModel):
1232
  )
1233
 
1234
  pooled_output = vision_outputs[1] # pooled_output
1235
- image_features = pooled_output @ self.visual_projection
1236
 
1237
  return image_features
1238
 
@@ -1413,7 +1428,40 @@ class CLIPTextModelWithProjection(CLIPPreTrainedModel):
1413
  attentions=text_outputs.attentions,
1414
  )
1415
 
1416
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1417
  @add_start_docstrings(
1418
  """
1419
  CLIP Vision Model with a projection layer on top (a linear layer on top of the pooled output).
 
39
  )
40
  from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
41
 
 
42
  if is_flash_attn_2_available():
43
  from transformers.modeling_flash_attention_utils import _flash_attention_forward
44
 
 
602
  fc_std = (2 * module.config.hidden_size) ** -0.5 * factor
603
  nn.init.normal_(module.fc1.weight, std=fc_std)
604
  nn.init.normal_(module.fc2.weight, std=in_proj_std)
605
+ elif isinstance(module, LLM2CLIPModel):
 
606
  # nn.init.normal_(
607
  # module.text_projection.weight,
608
  # std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
609
  # )
610
+ nn.init.normal_(
611
+ module.visual_projection.weight,
612
+ std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
613
+ )
614
  elif isinstance(module, CLIPVisionModelWithProjection):
615
  nn.init.normal_(
616
  module.visual_projection.weight,
 
1110
 
1111
 
1112
  @add_start_docstrings(CLIP_START_DOCSTRING)
1113
+ class LLM2CLIPModel(CLIPPreTrainedModel):
1114
  config_class = CLIPConfig
1115
  _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer", "CLIPVisionEmbeddings"]
1116
 
1117
  def __init__(self, config: CLIPConfig):
1118
  super().__init__(config)
1119
+ # if not isinstance(config.text_config, CLIPTextConfig):
1120
+ # raise TypeError(
1121
+ # "config.text_config is expected to be of type CLIPTextConfig but is of type"
1122
+ # f" {type(config.text_config)}."
1123
+ # )
1124
+
1125
  if not isinstance(config.vision_config, CLIPVisionConfig):
1126
  raise TypeError(
1127
  "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
1128
  f" {type(config.vision_config)}."
1129
  )
1130
 
1131
+ # text_config = config.text_config
1132
  vision_config = config.vision_config
1133
 
1134
  self.projection_dim = config.projection_dim
1135
+ # self.text_embed_dim = text_config.hidden_size
1136
  self.vision_embed_dim = vision_config.hidden_size
1137
+
1138
+ adapter = LLM2CLIP_Adapter()
1139
+ self.text_adapter = adapter
1140
+
1141
+ # text_model = CLIPTextModel._from_config(text_config, attn_implementation=config._attn_implementation)
1142
+ # self.text_model = text_model.text_model
1143
 
1144
  vision_model = CLIPVisionModel._from_config(vision_config, attn_implementation=config._attn_implementation)
1145
  self.vision_model = vision_model.vision_model
1146
 
1147
+ self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False)
1148
+ # self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False)
 
1149
  self.logit_scale = nn.Parameter(torch.tensor(self.config.logit_scale_init_value))
1150
 
1151
  # Initialize weights and apply final processing
1152
  self.post_init()
1153
+
1154
+ def get_text_features(self, inputs):
1155
+ #TODO: make this more flexible and configurable
1156
+ return self.text_adapter(inputs)
1157
+
1158
+ # @add_start_docstrings_to_model_forward(CLIP_TEXT_INPUTS_DOCSTRING)
1159
+ # def get_text_features(
1160
+ # self,
1161
+ # input_ids: Optional[torch.Tensor] = None,
1162
+ # attention_mask: Optional[torch.Tensor] = None,
1163
+ # position_ids: Optional[torch.Tensor] = None,
1164
+ # output_attentions: Optional[bool] = None,
1165
+ # output_hidden_states: Optional[bool] = None,
1166
+ # return_dict: Optional[bool] = None,
1167
+ # ) -> torch.FloatTensor:
1168
+ # r"""
1169
+ # Returns:
1170
+ # text_features (`torch.FloatTensor` of shape `(batch_size, output_dim`): The text embeddings obtained by
1171
+ # applying the projection layer to the pooled output of [`CLIPTextModel`].
1172
+
1173
+ # Examples:
1174
+
1175
+ # ```python
1176
+ # >>> from transformers import AutoTokenizer, CLIPModel
1177
+
1178
+ # >>> model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
1179
+ # >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
1180
+
1181
+ # >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
1182
+ # >>> text_features = model.get_text_features(**inputs)
1183
+ # ```"""
1184
+ # # Use CLIP model's config for some fields (if specified) instead of those of vision & text components.
1185
+ # output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1186
+ # output_hidden_states = (
1187
+ # output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1188
+ # )
1189
+ # return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1190
+
1191
+ # text_outputs = self.text_model(
1192
+ # input_ids=input_ids,
1193
+ # attention_mask=attention_mask,
1194
+ # position_ids=position_ids,
1195
+ # output_attentions=output_attentions,
1196
+ # output_hidden_states=output_hidden_states,
1197
+ # return_dict=return_dict,
1198
+ # )
1199
+
1200
+ # pooled_output = text_outputs[1]
1201
+ # text_features = self.text_projection(pooled_output)
1202
+
1203
+ # return text_features
1204
 
1205
  @add_start_docstrings_to_model_forward(CLIP_VISION_INPUTS_DOCSTRING)
1206
  def get_image_features(
 
1247
  )
1248
 
1249
  pooled_output = vision_outputs[1] # pooled_output
1250
+ image_features = self.visual_projection(pooled_output)
1251
 
1252
  return image_features
1253
 
 
1428
  attentions=text_outputs.attentions,
1429
  )
1430
 
1431
+ class LinearBlock(nn.Module):
1432
+ def __init__(self, dim, expansion_factor=4, dropout=0.,norm_layer=nn.LayerNorm):
1433
+ super().__init__()
1434
+ self.fn = nn.Sequential(
1435
+ nn.Linear(dim, int(expansion_factor * dim)),
1436
+ nn.GELU(),
1437
+ nn.Dropout(dropout),
1438
+ nn.Linear(int(expansion_factor * dim), dim),
1439
+ )
1440
+ self.ln = norm_layer(dim)
1441
+
1442
+ def forward(self, x):
1443
+ return x + self.fn(self.ln(x))
1444
+
1445
+ class LLM2CLIP_Adapter(nn.Module):
1446
+ def __init__(self):
1447
+ super().__init__()
1448
+ #TODO: make this more flexible and configurable
1449
+ # hard-coded values from the LLM2CLIP model
1450
+ text_embedding_dim = 4096
1451
+ expansion_factor = 2
1452
+ adaptor_num_layers = 4
1453
+ proj_bias = True
1454
+ output_dim = 1280
1455
+ self.adaptor = nn.Sequential(
1456
+ *[LinearBlock(text_embedding_dim, expansion_factor) for _ in range(adaptor_num_layers)],
1457
+ nn.LayerNorm(text_embedding_dim),
1458
+ nn.Linear(text_embedding_dim, output_dim, bias=proj_bias),
1459
+ )
1460
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
1461
+ hidden_states = torch.nn.functional.normalize(hidden_states, p=2, dim=1)
1462
+ hidden_states = self.adaptor(hidden_states)
1463
+ return hidden_states
1464
+
1465
  @add_start_docstrings(
1466
  """
1467
  CLIP Vision Model with a projection layer on top (a linear layer on top of the pooled output).