marcusinthesky commited on
Commit
5e3f7ec
·
1 Parent(s): 5b6ff34

Upload model

Browse files
Files changed (3) hide show
  1. config.json +183 -0
  2. modelling.py +181 -0
  3. pytorch_model.bin +3 -0
config.json ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_commit_hash": "092b10bbf4bc3d008a454897fba1141fb67c0b9e",
3
+ "_name_or_path": "flavour/vtde-dinov2-small-jina-embedding-t-en-v1",
4
+ "architectures": [
5
+ "VTDEModel"
6
+ ],
7
+ "auto_map": {
8
+ "AutoConfig": "modelling.VTDEConfig",
9
+ "AutoModelForZeroShotImageClassification": "modelling.VTDEModel"
10
+ },
11
+ "logit_scale_init_value": 2.6592,
12
+ "model_type": "vtde",
13
+ "projection_dim": 384,
14
+ "text_config": {
15
+ "_name_or_path": "jinaai/jina-embedding-t-en-v1",
16
+ "add_cross_attention": false,
17
+ "architectures": null,
18
+ "attention_probs_dropout_prob": 0.1,
19
+ "bad_words_ids": null,
20
+ "begin_suppress_tokens": null,
21
+ "bos_token_id": null,
22
+ "cell": {},
23
+ "chunk_size_feed_forward": 0,
24
+ "classifier_dropout": null,
25
+ "cross_attention_hidden_size": null,
26
+ "decoder_start_token_id": null,
27
+ "diversity_penalty": 0.0,
28
+ "do_sample": false,
29
+ "early_stopping": false,
30
+ "emb_size": 312,
31
+ "encoder_no_repeat_ngram_size": 0,
32
+ "eos_token_id": null,
33
+ "exponential_decay_length_penalty": null,
34
+ "finetuning_task": null,
35
+ "forced_bos_token_id": null,
36
+ "forced_eos_token_id": null,
37
+ "hidden_act": "gelu",
38
+ "hidden_dropout_prob": 0.1,
39
+ "hidden_size": 312,
40
+ "id2label": {
41
+ "0": "LABEL_0",
42
+ "1": "LABEL_1"
43
+ },
44
+ "initializer_range": 0.02,
45
+ "intermediate_size": 1200,
46
+ "is_decoder": false,
47
+ "is_encoder_decoder": false,
48
+ "label2id": {
49
+ "LABEL_0": 0,
50
+ "LABEL_1": 1
51
+ },
52
+ "layer_norm_eps": 1e-12,
53
+ "length_penalty": 1.0,
54
+ "max_length": 20,
55
+ "max_position_embeddings": 512,
56
+ "min_length": 0,
57
+ "model_type": "bert",
58
+ "no_repeat_ngram_size": 0,
59
+ "num_attention_heads": 12,
60
+ "num_beam_groups": 1,
61
+ "num_beams": 1,
62
+ "num_hidden_layers": 4,
63
+ "num_return_sequences": 1,
64
+ "output_attentions": false,
65
+ "output_hidden_states": false,
66
+ "output_scores": false,
67
+ "pad_token_id": 0,
68
+ "position_embedding_type": "absolute",
69
+ "pre_trained": "",
70
+ "prefix": null,
71
+ "problem_type": null,
72
+ "pruned_heads": {},
73
+ "remove_invalid_values": false,
74
+ "repetition_penalty": 1.0,
75
+ "return_dict": true,
76
+ "return_dict_in_generate": false,
77
+ "sep_token_id": null,
78
+ "structure": [],
79
+ "suppress_tokens": null,
80
+ "task_specific_params": null,
81
+ "temperature": 1.0,
82
+ "tf_legacy_loss": false,
83
+ "tie_encoder_decoder": false,
84
+ "tie_word_embeddings": true,
85
+ "tokenizer_class": null,
86
+ "top_k": 50,
87
+ "top_p": 1.0,
88
+ "torch_dtype": null,
89
+ "torchscript": false,
90
+ "transformers_version": "4.32.0.dev0",
91
+ "type_vocab_size": 2,
92
+ "typical_p": 1.0,
93
+ "use_bfloat16": false,
94
+ "use_cache": true,
95
+ "vocab_size": 30522
96
+ },
97
+ "text_pooling_mode": "mean",
98
+ "torch_dtype": "float32",
99
+ "transformers_version": null,
100
+ "vision_config": {
101
+ "_name_or_path": "facebook/dinov2-small",
102
+ "add_cross_attention": false,
103
+ "architectures": [
104
+ "Dinov2Model"
105
+ ],
106
+ "attention_probs_dropout_prob": 0.0,
107
+ "bad_words_ids": null,
108
+ "begin_suppress_tokens": null,
109
+ "bos_token_id": null,
110
+ "chunk_size_feed_forward": 0,
111
+ "cross_attention_hidden_size": null,
112
+ "decoder_start_token_id": null,
113
+ "diversity_penalty": 0.0,
114
+ "do_sample": false,
115
+ "drop_path_rate": 0.0,
116
+ "early_stopping": false,
117
+ "encoder_no_repeat_ngram_size": 0,
118
+ "eos_token_id": null,
119
+ "exponential_decay_length_penalty": null,
120
+ "finetuning_task": null,
121
+ "forced_bos_token_id": null,
122
+ "forced_eos_token_id": null,
123
+ "hidden_act": "gelu",
124
+ "hidden_dropout_prob": 0.0,
125
+ "hidden_size": 384,
126
+ "id2label": {
127
+ "0": "LABEL_0",
128
+ "1": "LABEL_1"
129
+ },
130
+ "image_size": 518,
131
+ "initializer_range": 0.02,
132
+ "is_decoder": false,
133
+ "is_encoder_decoder": false,
134
+ "label2id": {
135
+ "LABEL_0": 0,
136
+ "LABEL_1": 1
137
+ },
138
+ "layer_norm_eps": 1e-06,
139
+ "layerscale_value": 1.0,
140
+ "length_penalty": 1.0,
141
+ "max_length": 20,
142
+ "min_length": 0,
143
+ "mlp_ratio": 4,
144
+ "model_type": "dinov2",
145
+ "no_repeat_ngram_size": 0,
146
+ "num_attention_heads": 6,
147
+ "num_beam_groups": 1,
148
+ "num_beams": 1,
149
+ "num_channels": 3,
150
+ "num_hidden_layers": 12,
151
+ "num_return_sequences": 1,
152
+ "output_attentions": false,
153
+ "output_hidden_states": false,
154
+ "output_scores": false,
155
+ "pad_token_id": null,
156
+ "patch_size": 14,
157
+ "prefix": null,
158
+ "problem_type": null,
159
+ "pruned_heads": {},
160
+ "qkv_bias": true,
161
+ "remove_invalid_values": false,
162
+ "repetition_penalty": 1.0,
163
+ "return_dict": true,
164
+ "return_dict_in_generate": false,
165
+ "sep_token_id": null,
166
+ "suppress_tokens": null,
167
+ "task_specific_params": null,
168
+ "temperature": 1.0,
169
+ "tf_legacy_loss": false,
170
+ "tie_encoder_decoder": false,
171
+ "tie_word_embeddings": true,
172
+ "tokenizer_class": null,
173
+ "top_k": 50,
174
+ "top_p": 1.0,
175
+ "torch_dtype": "float32",
176
+ "torchscript": false,
177
+ "transformers_version": "4.32.0.dev0",
178
+ "typical_p": 1.0,
179
+ "use_bfloat16": false,
180
+ "use_swiglu_ffn": false
181
+ },
182
+ "vision_pooling_mode": "max"
183
+ }
modelling.py ADDED
@@ -0,0 +1,181 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AUTOGENERATED! DO NOT EDIT! File to edit: ../notebooks/12_modelling.ipynb.
2
+
3
+ # %% auto 0
4
+ __all__ = ['VTDEConfig', 'VTDEModel']
5
+
6
+ # %% ../notebooks/12_modelling.ipynb 1
7
+ from transformers.models.clip.modeling_clip import CLIPOutput, clip_loss
8
+ from typing import Optional, Tuple, Union
9
+ from transformers import PreTrainedModel, VisionTextDualEncoderModel
10
+ import torch
11
+ from transformers import VisionTextDualEncoderConfig
12
+
13
+ class VTDEConfig(VisionTextDualEncoderConfig):
14
+ model_type = "vtde"
15
+
16
+ def __init__(self, projection_dim=512, logit_scale_init_value=2.6592,
17
+ text_pooling_mode='mean',
18
+ vision_pooling_mode='max',
19
+ **kwargs):
20
+ """
21
+ pooling_mode in ['mean', 'max', 'cls']
22
+ https://arxiv.org/pdf/2210.09996.pdf
23
+ https://github.com/kahnchana/clippy/blob/3c102c29c32f7c66c6e52e09b795fe9c061bbb03/src/open_clip/hf_model.py#L56
24
+ also
25
+ https://arxiv.org/pdf/2301.07836.pdf
26
+ """
27
+ self.text_pooling_mode = text_pooling_mode
28
+ self.vision_pooling_mode = vision_pooling_mode
29
+ super().__init__(projection_dim, logit_scale_init_value, **kwargs)
30
+
31
+ VTDEConfig.register_for_auto_class()
32
+
33
+
34
+ class VTDEModel(VisionTextDualEncoderModel):
35
+ config_class = VTDEConfig
36
+ base_model_prefix = "vtde"
37
+
38
+ def __init__(
39
+ self,
40
+ config: Optional[VTDEConfig] = None,
41
+ vision_model: Optional[PreTrainedModel] = None,
42
+ text_model: Optional[PreTrainedModel] = None,
43
+ ):
44
+ # You can customize the constructor if needed
45
+ super().__init__(config, vision_model, text_model)
46
+ self.text_pooling_mode = config.text_pooling_mode
47
+ self.vision_pooling_mode = config.vision_pooling_mode
48
+
49
+ def get_text_features(
50
+ self,
51
+ input_ids=None,
52
+ attention_mask=None,
53
+ position_ids=None,
54
+ token_type_ids=None,
55
+ output_attentions=None,
56
+ output_hidden_states=None,
57
+ return_dict=None,
58
+ ):
59
+ text_outputs = self.text_model(
60
+ input_ids=input_ids,
61
+ attention_mask=attention_mask,
62
+ token_type_ids=token_type_ids,
63
+ position_ids=position_ids,
64
+ output_attentions=output_attentions,
65
+ output_hidden_states=output_hidden_states,
66
+ return_dict=return_dict,
67
+ )
68
+ if self.text_pooling_mode == 'cls':
69
+ pooled_output = text_outputs[1]
70
+ elif self.text_pooling_mode == 'mean':
71
+ pooled_output = torch.mean(text_outputs[0], dim=1)
72
+ elif self.text_pooling_mode == 'max':
73
+ pooled_output = torch.max(text_outputs[0], dim=1)[0]
74
+ elif self.text_pooling_mode == 'norm':
75
+ """we select the patch with the largest norm"""
76
+ last_hidden_states = text_outputs[0]
77
+ patch_norms = torch.norm(last_hidden_states[:, 1:, :], dim=-1)
78
+ max_norm_idx = torch.argmax(patch_norms, dim=1)
79
+ pooled_output = last_hidden_states[:, max_norm_idx, :][:, 0, :]
80
+ else:
81
+ "We want to raise the name of the pooling mode"
82
+ raise NotImplementedError
83
+
84
+ text_features = self.text_projection(pooled_output)
85
+
86
+ return text_features
87
+
88
+ def get_image_features(
89
+ self,
90
+ pixel_values=None,
91
+ output_attentions=None,
92
+ output_hidden_states=None,
93
+ return_dict=None,
94
+ ):
95
+ vision_outputs = self.vision_model(
96
+ pixel_values=pixel_values,
97
+ output_attentions=output_attentions,
98
+ output_hidden_states=output_hidden_states,
99
+ return_dict=return_dict,
100
+ )
101
+
102
+ if self.vision_pooling_mode == 'cls':
103
+ pooled_output = vision_outputs[1]
104
+ elif self.vision_pooling_mode == 'mean':
105
+ pooled_output = torch.mean(vision_outputs[0], dim=1)
106
+ elif self.vision_pooling_mode == 'max':
107
+ pooled_output = torch.max(vision_outputs[0], dim=1)[0]
108
+ elif self.vision_pooling_mode == 'norm':
109
+ """we select the patch with the largest norm"""
110
+ last_hidden_states = vision_outputs[0]
111
+ patch_norms = torch.norm(last_hidden_states[:, 1:, :], dim=-1)
112
+ max_norm_idx = torch.argmax(patch_norms, dim=1)
113
+ pooled_output = last_hidden_states[:, max_norm_idx, :][:, 0, :]
114
+ else:
115
+ raise NotImplementedError
116
+
117
+ image_features = self.visual_projection(pooled_output)
118
+
119
+ return image_features
120
+
121
+ def forward(
122
+ self,
123
+ input_ids: Optional[torch.LongTensor] = None,
124
+ pixel_values: Optional[torch.FloatTensor] = None,
125
+ attention_mask: Optional[torch.Tensor] = None,
126
+ position_ids: Optional[torch.LongTensor] = None,
127
+ return_loss: Optional[bool] = None,
128
+ token_type_ids: Optional[torch.LongTensor] = None,
129
+ output_attentions: Optional[bool] = None,
130
+ output_hidden_states: Optional[bool] = None,
131
+ return_dict: Optional[bool] = None,
132
+ ) -> Union[Tuple[torch.Tensor], CLIPOutput]:
133
+
134
+ return_dict = return_dict if return_dict is not None else self.config.return_dict
135
+
136
+ image_embeds = self.get_image_features(
137
+ pixel_values=pixel_values,
138
+ output_attentions=output_attentions,
139
+ output_hidden_states=output_hidden_states,
140
+ return_dict=return_dict,
141
+ )
142
+
143
+ text_embeds = self.get_text_features(
144
+ input_ids=input_ids,
145
+ attention_mask=attention_mask,
146
+ position_ids=position_ids,
147
+ output_attentions=output_attentions,
148
+ output_hidden_states=output_hidden_states,
149
+ return_dict=return_dict,
150
+ )
151
+
152
+ # normalized features
153
+ image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
154
+ text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
155
+
156
+ # cosine similarity as logits
157
+ logit_scale = self.logit_scale.exp()
158
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
159
+ logits_per_image = logits_per_text.T
160
+
161
+ loss = None
162
+ if return_loss:
163
+ loss = clip_loss(logits_per_text)
164
+
165
+ if not return_dict:
166
+ output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_embeds, image_embeds)
167
+ return ((loss,) + output) if loss is not None else output
168
+
169
+ return CLIPOutput(
170
+ loss=loss,
171
+ logits_per_image=logits_per_image,
172
+ logits_per_text=logits_per_text,
173
+ text_embeds=text_embeds,
174
+ image_embeds=image_embeds,
175
+ text_model_output=text_embeds,
176
+ vision_model_output=image_embeds,
177
+ )
178
+
179
+
180
+ VTDEModel.register_for_auto_class("AutoModel")
181
+ VTDEModel.register_for_auto_class("AutoModelForZeroShotImageClassification")
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:122ed6b3647aba3eb34884797aa201d2a1480e7aad1d57b2b39e45186f9099a0
3
+ size 147385401