tomaarsen committed
Commit a2095b8
1 Parent(s): 18ae8db

Add new SentenceTransformer model.
README.md ADDED
@@ -0,0 +1,197 @@
1
+ ---
2
+ tags:
3
+ - feature-extraction
4
+ - sentence-similarity
5
+ - mteb
6
+ - clip
7
+ - vision
8
+ - transformers.js
9
+ language: en
10
+ inference: false
11
+ license: apache-2.0
12
+ library_name: transformers
13
+ ---
14
+
15
+ <br><br>
16
+
17
+ <p align="center">
18
+ <img src="https://aeiljuispo.cloudimg.io/v7/https://cdn-uploads.huggingface.co/production/uploads/603763514de52ff951d89793/AFoybzd5lpBQXEBrQHuTt.png?w=200&h=200&f=face" alt="Finetuner logo: Finetuner helps you to create experiments in order to improve embeddings on search tasks. It accompanies you to deliver the last mile of performance-tuning for neural search applications." width="150px">
19
+ </p>
20
+
21
+
22
+ <p align="center">
23
+ <b>The embedding set trained by <a href="https://jina.ai/"><b>Jina AI</b></a>.</b>
24
+ </p>
25
+
26
+ <p align="center">
27
+ <b>Jina CLIP: your CLIP model is also your text retriever!</b>
28
+ </p>
29
+
30
+
31
+ ## Intended Usage & Model Info
32
+
33
+ `jina-clip-v1` is a state-of-the-art English **multimodal (text-image) embedding model**.
34
+
35
+ Traditional text embedding models, such as [jina-embeddings-v2-base-en](https://huggingface.co/jinaai/jina-embeddings-v2-base-en), excel in text-to-text retrieval but are incapable of cross-modal tasks. Models like [openai/clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) effectively align image and text embeddings but are not optimized for text-to-text retrieval due to their training methodologies and context limitations.
36
+
37
+ `jina-clip-v1` bridges this gap by offering robust performance in both domains.
38
+ Its text component matches the retrieval efficiency of `jina-embeddings-v2-base-en`, while its overall architecture sets a new benchmark for cross-modal retrieval.
39
+ This dual capability makes it an excellent tool for multimodal retrieval-augmented generation (MuRAG) applications, enabling seamless text-to-text and text-to-image searches within a single model.
40
+
41
+
42
+ ## Data & Parameters
43
+
44
+ [Check out our paper](https://arxiv.org/abs/2405.20204)
45
+
46
+ ## Usage
47
+
48
+ 1. The easiest way to start using `jina-clip-v1` is through Jina AI's [Embeddings API](https://jina.ai/embeddings/).
49
+ 2. Alternatively, you can use Jina CLIP directly via the `transformers` package.
50
+
51
+ ```python
52
+ # Requires: pip install transformers einops timm pillow
53
+ from transformers import AutoModel
54
+
55
+ # Initialize the model
56
+ model = AutoModel.from_pretrained('jinaai/jina-clip-v1', trust_remote_code=True)
57
+
58
+ # Example sentences
59
+ sentences = ['A blue cat', 'A red cat']
60
+
61
+ # Public image URLs
62
+ image_urls = [
63
+ 'https://i.pinimg.com/600x315/21/48/7e/21487e8e0970dd366dafaed6ab25d8d8.jpg',
64
+ 'https://i.pinimg.com/736x/c9/f2/3e/c9f23e212529f13f19bad5602d84b78b.jpg'
65
+ ]
66
+
67
+ # Encode text and images
68
+ text_embeddings = model.encode_text(sentences)
69
+ image_embeddings = model.encode_image(image_urls) # also accepts PIL.image, local filenames, dataURI
70
+
71
+ # Compute similarities
72
+ print(text_embeddings[0] @ text_embeddings[1].T) # text embedding similarity
73
+ print(text_embeddings[0] @ image_embeddings[0].T) # text-image cross-modal similarity
74
+ print(text_embeddings[0] @ image_embeddings[1].T) # text-image cross-modal similarity
75
+ print(text_embeddings[1] @ image_embeddings[0].T) # text-image cross-modal similarity
76
+ print(text_embeddings[1] @ image_embeddings[1].T)  # text-image cross-modal similarity
77
+ ```
78
+
79
+ 3. JavaScript developers can use Jina CLIP via the [Transformers.js](https://huggingface.co/docs/transformers.js) library. Note that to use this model, you need to install Transformers.js [v3](https://github.com/xenova/transformers.js/tree/v3) from source using `npm install xenova/transformers.js#v3`.
80
+
81
+ ```js
82
+ import { AutoTokenizer, CLIPTextModelWithProjection, AutoProcessor, CLIPVisionModelWithProjection, RawImage, cos_sim } from '@xenova/transformers';
83
+
84
+ // Load tokenizer and text model
85
+ const tokenizer = await AutoTokenizer.from_pretrained('jinaai/jina-clip-v1');
86
+ const text_model = await CLIPTextModelWithProjection.from_pretrained('jinaai/jina-clip-v1');
87
+
88
+ // Load processor and vision model
89
+ const processor = await AutoProcessor.from_pretrained('Xenova/clip-vit-base-patch32');
90
+ const vision_model = await CLIPVisionModelWithProjection.from_pretrained('jinaai/jina-clip-v1');
91
+
92
+ // Run tokenization
93
+ const texts = ['A blue cat', 'A red cat'];
94
+ const text_inputs = tokenizer(texts, { padding: true, truncation: true });
95
+
96
+ // Compute text embeddings
97
+ const { text_embeds } = await text_model(text_inputs);
98
+
99
+ // Read images and run processor
100
+ const urls = [
101
+ 'https://i.pinimg.com/600x315/21/48/7e/21487e8e0970dd366dafaed6ab25d8d8.jpg',
102
+ 'https://i.pinimg.com/736x/c9/f2/3e/c9f23e212529f13f19bad5602d84b78b.jpg'
103
+ ];
104
+ const image = await Promise.all(urls.map(url => RawImage.read(url)));
105
+ const image_inputs = await processor(image);
106
+
107
+ // Compute vision embeddings
108
+ const { image_embeds } = await vision_model(image_inputs);
109
+
110
+ // Compute similarities
111
+ console.log(cos_sim(text_embeds[0].data, text_embeds[1].data)) // text embedding similarity
112
+ console.log(cos_sim(text_embeds[0].data, image_embeds[0].data)) // text-image cross-modal similarity
113
+ console.log(cos_sim(text_embeds[0].data, image_embeds[1].data)) // text-image cross-modal similarity
114
+ console.log(cos_sim(text_embeds[1].data, image_embeds[0].data)) // text-image cross-modal similarity
115
+ console.log(cos_sim(text_embeds[1].data, image_embeds[1].data)) // text-image cross-modal similarity
116
+ ```
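+
+ 4. The repository also ships Sentence Transformers support (see `custom_st.py` and `config_sentence_transformers.json`), so the model can presumably be loaded through the [sentence-transformers](https://www.sbert.net) library as well. A minimal sketch, assuming a recent 3.x release and that passing `trust_remote_code=True` is acceptable in your environment:
+
+ ```python
+ from sentence_transformers import SentenceTransformer
+
+ # trust_remote_code is required because the architecture is custom
+ model = SentenceTransformer('jinaai/jina-clip-v1', trust_remote_code=True)
+
+ # Plain strings are routed to the text tower, image URLs / PIL images to the vision tower
+ text_embeddings = model.encode(['A blue cat', 'A red cat'])
+ image_embeddings = model.encode([
+     'https://i.pinimg.com/600x315/21/48/7e/21487e8e0970dd366dafaed6ab25d8d8.jpg'
+ ])
+
+ print(model.similarity(text_embeddings, image_embeddings))  # cosine similarity matrix
+ ```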
117
+
118
+ ## Performance
119
+
120
+ ### Text-Image Retrieval
121
+
122
+ | Name | Flickr Image Retr. R@1 | Flickr Image Retr. R@5 | Flickr Text Retr. R@1 | Flickr Text Retr. R@5 |
123
+ |------------------|-------------------------|-------------------------|-----------------------|-----------------------|
124
+ | ViT-B-32 | 0.597 | 0.8398 | 0.781 | 0.938 |
125
+ | ViT-B-16 | 0.6216 | 0.8572 | 0.822 | 0.966 |
126
+ | jina-clip | 0.6748 | 0.8902 | 0.811 | 0.965 |
127
+
128
+
129
+ | Name | MSCOCO Image Retr. R@1 | MSCOCO Image Retr. R@5 | MSCOCO Text Retr. R@1 | MSCOCO Text Retr. R@5 |
130
+ |------------------|-------------------------|-------------------------|-----------------------|-----------------------|
131
+ | ViT-B-32 | 0.342 | 0.6001 | 0.5234 | 0.7634 |
132
+ | ViT-B-16 | 0.3309 | 0.5842 | 0.5242 | 0.767 |
133
+ | jina-clip | 0.4111 | 0.6644 | 0.5544 | 0.7904 |
134
+
135
+ ### Text-Text Retrieval
136
+
137
+ | Name | STS12 | STS15 | STS17 | STS13 | STS14 | STS16 | STS22 | STSBenchmark | SummEval |
138
+ |-----------------------|--------|--------|--------|--------|--------|--------|--------|--------------|----------|
139
+ | jina-embeddings-v2 | 0.7427 | 0.8755 | 0.8888 | 0.833 | 0.7917 | 0.836 | 0.6346 | 0.8404 | 0.3056 |
140
+ | jina-clip | 0.7352 | 0.8746 | 0.8976 | 0.8323 | 0.7868 | 0.8377 | 0.6583 | 0.8493 | 0.3048 |
141
+
142
+
143
+ | Name | ArguAna | FiQA2018 | NFCorpus | Quora | SCIDOCS | SciFact | TRECCOVID |
144
+ |--------------------|---------|----------|----------|-------|---------|---------|-----------|
145
+ | jina-embeddings-v2 | 0.4418 | 0.4158 | 0.3245 | 0.882 | 0.1986 | 0.6668 | 0.6591 |
146
+ | jina-clip | 0.4933 | 0.3827 | 0.3352 | 0.8789 | 0.2024 | 0.6734 | 0.7161 |
147
+
148
+ ## Contact
149
+
150
+ Join our [Discord community](https://discord.jina.ai) and chat with other community members about ideas.
151
+
152
+ ## Citation
153
+
154
+ If you find `jina-clip-v1` useful in your research, please cite the following paper:
155
+
156
+ ```bibtex
157
+ @misc{2405.20204,
158
+ Author = {Andreas Koukounas and Georgios Mastrapas and Michael Günther and Bo Wang and Scott Martens and Isabelle Mohr and Saba Sturua and Mohammad Kalim Akram and Joan Fontanals Martínez and Saahil Ognawala and Susana Guzman and Maximilian Werk and Nan Wang and Han Xiao},
159
+ Title = {Jina CLIP: Your CLIP Model Is Also Your Text Retriever},
160
+ Year = {2024},
161
+ Eprint = {arXiv:2405.20204},
162
+ }
163
+ ```
164
+
165
+ ## FAQ
166
+
167
+ ### I encountered this problem, what should I do?
168
+
169
+ ```
170
+ ValueError: The model class you are passing has a `config_class` attribute that is not consistent with the config class you passed (model has <class 'transformers_modules.jinaai.jina-clip-implementation.7f069e2d54d609ef1ad2eb578c7bf07b5a51de41.configuration_clip.JinaCLIPConfig'> and you passed <class 'transformers_modules.jinaai.jina-clip-implementation.7f069e2d54d609ef1ad2eb578c7bf07b5a51de41.configuration_cli.JinaCLIPConfig'>. Fix one of those so they match!
171
+ ```
172
+
173
+ There was a bug in the Transformers library in versions 4.40.x through 4.41.1. You can update `transformers` to >=4.41.2, or pin it to <=4.40.0.
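+
+ To check which version is currently installed (a quick sanity check; upgrade with `pip install -U "transformers>=4.41.2"` if needed):
+
+ ```python
+ import transformers
+
+ print(transformers.__version__)
+ ```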
174
+
175
+ ### Given one query, how can I merge its text-text and text-image cosine similarity?
176
+
177
+ Our empirical study shows that text-text cosine similarity is normally larger than text-image cosine similarity.
178
+ If you want to merge the two scores, we recommend two approaches (a runnable sketch of both follows this list):
179
+
180
+ 1. Weighted average of the text-text and text-image similarities:
181
+
182
+ ```python
183
+ combined_scores = sim(text, text) + weight * sim(text, image)  # pseudo code; "weight" stands in for lambda (a Python keyword). The optimal value depends on your dataset, but weight=2 is generally a good choice.
184
+ ```
185
+
186
+ 2. Apply z-score normalization before merging the scores:
187
+
188
+ ```python
189
+ # pseudo code
190
+ query_document_mean = np.mean(cos_sim_query_documents)
191
+ query_document_std = np.std(cos_sim_query_documents)
192
+ text_image_mean = np.mean(cos_sim_text_images)
193
+ text_image_std = np.std(cos_sim_text_images)
194
+
195
+ query_document_sim_normalized = (cos_sim_query_documents - query_document_mean) / query_document_std
196
+ text_image_sim_normalized = (cos_sim_text_images - text_image_mean) / text_image_std
197
+ ```
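+
+ Both strategies amount to a few lines of NumPy. The sketch below assumes you have already computed arrays of text-text and text-image cosine similarities for a query (the function and variable names are illustrative, not part of the model API):
+
+ ```python
+ import numpy as np
+
+ def merge_weighted(cos_sim_query_documents, cos_sim_text_images, weight=2.0):
+     # Strategy 1: weighted sum; the optimal weight depends on your dataset
+     return cos_sim_query_documents + weight * cos_sim_text_images
+
+ def merge_zscore(cos_sim_query_documents, cos_sim_text_images):
+     # Strategy 2: z-score normalize each score distribution before summing,
+     # compensating for text-text similarities typically being larger
+     qd = (cos_sim_query_documents - cos_sim_query_documents.mean()) / cos_sim_query_documents.std()
+     ti = (cos_sim_text_images - cos_sim_text_images.mean()) / cos_sim_text_images.std()
+     return qd + ti
+ ```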
config.json ADDED
@@ -0,0 +1,176 @@
1
+ {
2
+ "_commit_hash": null,
3
+ "_name_or_path": "jina-clip-v1-remote",
4
+ "add_projections": false,
5
+ "architectures": [
6
+ "JinaCLIPModel"
7
+ ],
8
+ "auto_map": {
9
+ "AutoConfig": "configuration_clip.JinaCLIPConfig",
10
+ "AutoModel": "modeling_clip.JinaCLIPModel"
11
+ },
12
+ "initializer_factor": 1.0,
13
+ "logit_scale_init_value": 2.6592,
14
+ "model_type": "jina_clip",
15
+ "projection_dim": 768,
16
+ "text_config": {
17
+ "_name_or_path": "",
18
+ "add_cross_attention": false,
19
+ "architectures": null,
20
+ "bad_words_ids": null,
21
+ "begin_suppress_tokens": null,
22
+ "bos_token_id": null,
23
+ "chunk_size_feed_forward": 0,
24
+ "cross_attention_hidden_size": null,
25
+ "decoder_start_token_id": null,
26
+ "diversity_penalty": 0.0,
27
+ "do_sample": false,
28
+ "early_stopping": false,
29
+ "embed_dim": 768,
30
+ "encoder_no_repeat_ngram_size": 0,
31
+ "eos_token_id": null,
32
+ "exponential_decay_length_penalty": null,
33
+ "finetuning_task": null,
34
+ "forced_bos_token_id": null,
35
+ "forced_eos_token_id": null,
36
+ "hf_model_config_kwargs": {
37
+ "use_flash_attn": false
38
+ },
39
+ "hf_model_name_or_path": "jinaai/jina-bert-flash-implementation",
40
+ "id2label": {
41
+ "0": "LABEL_0",
42
+ "1": "LABEL_1"
43
+ },
44
+ "is_decoder": false,
45
+ "is_encoder_decoder": false,
46
+ "label2id": {
47
+ "LABEL_0": 0,
48
+ "LABEL_1": 1
49
+ },
50
+ "length_penalty": 1.0,
51
+ "max_length": 20,
52
+ "min_length": 0,
53
+ "model_type": "jina_clip_text",
54
+ "no_repeat_ngram_size": 0,
55
+ "num_beam_groups": 1,
56
+ "num_beams": 1,
57
+ "num_return_sequences": 1,
58
+ "output_attentions": false,
59
+ "output_hidden_states": false,
60
+ "output_scores": false,
61
+ "pad_token_id": null,
62
+ "pooler_type": "mean_pooler",
63
+ "prefix": null,
64
+ "problem_type": null,
65
+ "proj_bias": false,
66
+ "proj_type": null,
67
+ "pruned_heads": {},
68
+ "remove_invalid_values": false,
69
+ "repetition_penalty": 1.0,
70
+ "return_dict": true,
71
+ "return_dict_in_generate": false,
72
+ "sep_token_id": null,
73
+ "suppress_tokens": null,
74
+ "task_specific_params": null,
75
+ "temperature": 1.0,
76
+ "tf_legacy_loss": false,
77
+ "tie_encoder_decoder": false,
78
+ "tie_word_embeddings": true,
79
+ "tokenizer_class": null,
80
+ "top_k": 50,
81
+ "top_p": 1.0,
82
+ "torch_dtype": null,
83
+ "torchscript": false,
84
+ "transformers_version": "4.41.2",
85
+ "typical_p": 1.0,
86
+ "use_bfloat16": false
87
+ },
88
+ "torch_dtype": "float32",
89
+ "transformers_version": null,
90
+ "use_text_flash_attn": null,
91
+ "use_vision_xformers": null,
92
+ "vision_config": {
93
+ "_name_or_path": "",
94
+ "add_cross_attention": false,
95
+ "architectures": null,
96
+ "bad_words_ids": null,
97
+ "begin_suppress_tokens": null,
98
+ "bos_token_id": null,
99
+ "chunk_size_feed_forward": 0,
100
+ "cross_attention_hidden_size": null,
101
+ "decoder_start_token_id": null,
102
+ "diversity_penalty": 0.0,
103
+ "do_sample": false,
104
+ "drop_path_rate": 0.0,
105
+ "early_stopping": false,
106
+ "embed_dim": 768,
107
+ "encoder_no_repeat_ngram_size": 0,
108
+ "eos_token_id": null,
109
+ "exponential_decay_length_penalty": null,
110
+ "finetuning_task": null,
111
+ "forced_bos_token_id": null,
112
+ "forced_eos_token_id": null,
113
+ "fused_layer_norm": false,
114
+ "head_width": 64,
115
+ "id2label": {
116
+ "0": "LABEL_0",
117
+ "1": "LABEL_1"
118
+ },
119
+ "image_size": 224,
120
+ "intp_freq": false,
121
+ "is_decoder": false,
122
+ "is_encoder_decoder": false,
123
+ "label2id": {
124
+ "LABEL_0": 0,
125
+ "LABEL_1": 1
126
+ },
127
+ "layers": 12,
128
+ "length_penalty": 1.0,
129
+ "ls_init_value": null,
130
+ "max_length": 20,
131
+ "min_length": 0,
132
+ "mlp_ratio": 2.6667,
133
+ "model_type": "jina_clip_vision",
134
+ "naive_swiglu": true,
135
+ "no_repeat_ngram_size": 0,
136
+ "num_beam_groups": 1,
137
+ "num_beams": 1,
138
+ "num_return_sequences": 1,
139
+ "output_attentions": false,
140
+ "output_hidden_states": false,
141
+ "output_scores": false,
142
+ "pad_token_id": null,
143
+ "patch_dropout": 0.1,
144
+ "patch_size": 16,
145
+ "post_norm": false,
146
+ "prefix": null,
147
+ "problem_type": null,
148
+ "proj_type": null,
149
+ "pruned_heads": {},
150
+ "pt_hw_seq_len": 14,
151
+ "qkv_bias": true,
152
+ "remove_invalid_values": false,
153
+ "repetition_penalty": 1.0,
154
+ "return_dict": true,
155
+ "return_dict_in_generate": false,
156
+ "rope_embeddings": true,
157
+ "sep_token_id": null,
158
+ "subln": true,
159
+ "suppress_tokens": null,
160
+ "task_specific_params": null,
161
+ "temperature": 1.0,
162
+ "tf_legacy_loss": false,
163
+ "tie_encoder_decoder": false,
164
+ "tie_word_embeddings": true,
165
+ "tokenizer_class": null,
166
+ "top_k": 50,
167
+ "top_p": 1.0,
168
+ "torch_dtype": null,
169
+ "torchscript": false,
170
+ "transformers_version": "4.41.2",
171
+ "typical_p": 1.0,
172
+ "use_bfloat16": false,
173
+ "width": 768,
174
+ "x_attention": false
175
+ }
176
+ }
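
The `text_config` and `vision_config` blocks above are exposed through the custom `JinaCLIPConfig` class defined in `configuration_clip.py` below. A minimal way to inspect them programmatically (a sketch; it assumes `transformers` is installed and that fetching the repository's remote code is acceptable):

```python
from transformers import AutoConfig

# auto_map routes AutoConfig to configuration_clip.JinaCLIPConfig
config = AutoConfig.from_pretrained('jinaai/jina-clip-v1', trust_remote_code=True)
print(config.text_config.embed_dim, config.vision_config.embed_dim)  # 768 768
```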
config_sentence_transformers.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "__version__": {
3
+ "sentence_transformers": "3.1.0.dev0",
4
+ "transformers": "4.41.2",
5
+ "pytorch": "2.3.1+cu121"
6
+ },
7
+ "prompts": {},
8
+ "default_prompt_name": null,
9
+ "similarity_fn_name": "cosine"
10
+ }
configuration_clip.py ADDED
@@ -0,0 +1,304 @@
1
+ # coding=utf-8
2
+ #
3
+ # Code mainly copied from:
4
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/configuration_clip.py
5
+ # and adjusted for Jina CLIP
6
+
7
+ import os
8
+ from copy import deepcopy
9
+ from typing import Any, Dict, Optional, Union
10
+
11
+ from transformers import PretrainedConfig, logging
12
+
13
+ logger = logging.get_logger(__name__)
14
+
15
+
16
+ """ Jina CLIP model configuration """
17
+
18
+
19
+ class JinaCLIPTextConfig(PretrainedConfig):
20
+ model_type = 'jina_clip_text'
21
+
22
+ def __init__(
23
+ self,
24
+ embed_dim: int = 768,
25
+ hf_model_name_or_path: str = 'jinaai/jina-bert-flash-implementation',
26
+ hf_model_config_kwargs: Optional[Dict[str, Any]] = None,
27
+ pooler_type: Optional[str] = None,
28
+ proj_type: Optional[str] = None,
29
+ proj_bias: bool = False,
30
+ **kwargs,
31
+ ):
32
+ super().__init__(**kwargs)
33
+
34
+ self.embed_dim = embed_dim
35
+ self.hf_model_name_or_path = hf_model_name_or_path
36
+ self.hf_model_config_kwargs = hf_model_config_kwargs or {}
37
+ self.pooler_type = pooler_type
38
+ self.proj_type = proj_type
39
+ self.proj_bias = proj_bias
40
+
41
+ @classmethod
42
+ def from_pretrained(
43
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
44
+ ) -> 'PretrainedConfig':
45
+ cls._set_token_in_kwargs(kwargs)
46
+
47
+ configdict, kwargs = cls.get_config_dict(
48
+ pretrained_model_name_or_path, **kwargs
49
+ )
50
+
51
+ # get the text config dict if we are loading from JinaCLIPConfig
52
+ if configdict.get('model_type') == 'jina_clip':
53
+ configdict = configdict['text_config']
54
+
55
+ if (
56
+ 'model_type' in configdict
57
+ and hasattr(cls, 'model_type')
58
+ and configdict['model_type'] != cls.model_type
59
+ ):
60
+ logger.warning(
61
+ f'You are using a model of type {configdict["model_type"]} to '
62
+ f'instantiate a model of type {cls.model_type}. This is not supported '
63
+ 'for all configurations of models and can yield errors.'
64
+ )
65
+
66
+ return cls.from_dict(configdict, **kwargs)
67
+
68
+
69
+ class JinaCLIPVisionConfig(PretrainedConfig):
70
+ model_type = 'jina_clip_vision'
71
+
72
+ def __init__(
73
+ self,
74
+ embed_dim: int = 768,
75
+ width: int = 768,
76
+ image_size: int = 224,
77
+ patch_size: int = 16,
78
+ layers: int = 12,
79
+ head_width: int = 64,
80
+ mlp_ratio: float = 4.0,
81
+ ls_init_value: Optional[float] = None,
82
+ patch_dropout: float = 0.0,
83
+ qkv_bias: bool = True,
84
+ fused_layer_norm: bool = False,
85
+ x_attention: bool = False,
86
+ post_norm: bool = False,
87
+ rope_embeddings: bool = False,
88
+ pt_hw_seq_len: int = 16,
89
+ intp_freq: bool = False,
90
+ naive_swiglu: bool = False,
91
+ subln: bool = False,
92
+ drop_path_rate: float = 0.0,
93
+ proj_type: Optional[str] = None,
94
+ **kwargs,
95
+ ):
96
+ super().__init__(**kwargs)
97
+
98
+ self.layers = layers
99
+ self.embed_dim = embed_dim
100
+ self.width = width
101
+ self.head_width = head_width
102
+ self.mlp_ratio = mlp_ratio
103
+ self.image_size = image_size
104
+ self.patch_size = patch_size
105
+ self.ls_init_value = ls_init_value
106
+ self.patch_dropout = patch_dropout
107
+ self.qkv_bias = qkv_bias
108
+ self.fused_layer_norm = fused_layer_norm
109
+ self.x_attention = x_attention
110
+ self.post_norm = post_norm
111
+ self.rope_embeddings = rope_embeddings
112
+ self.pt_hw_seq_len = pt_hw_seq_len
113
+ self.intp_freq = intp_freq
114
+ self.naive_swiglu = naive_swiglu
115
+ self.subln = subln
116
+ self.drop_path_rate = drop_path_rate
117
+ self.proj_type = proj_type
118
+
119
+ @classmethod
120
+ def from_pretrained(
121
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
122
+ ) -> 'PretrainedConfig':
123
+ cls._set_token_in_kwargs(kwargs)
124
+
125
+ configdict, kwargs = cls.get_config_dict(
126
+ pretrained_model_name_or_path, **kwargs
127
+ )
128
+
129
+ # get the vision config dict if we are loading from JinaCLIPConfig
130
+ if configdict.get('model_type') == 'jina_clip':
131
+ configdict = configdict['vision_config']
132
+
133
+ if (
134
+ 'model_type' in configdict
135
+ and hasattr(cls, 'model_type')
136
+ and configdict['model_type'] != cls.model_type
137
+ ):
138
+ logger.warning(
139
+ f'You are using a model of type {configdict["model_type"]} to '
140
+ f'instantiate a model of type {cls.model_type}. This is not supported '
141
+ 'for all configurations of models and can yield errors.'
142
+ )
143
+
144
+ return cls.from_dict(configdict, **kwargs)
145
+
146
+
147
+ class JinaCLIPConfig(PretrainedConfig):
148
+ model_type = 'jina_clip'
149
+ is_composition = True
150
+
151
+ def __init__(
152
+ self,
153
+ text_config: Optional[Dict] = None,
154
+ vision_config: Optional[Dict] = None,
155
+ add_projections: bool = False,
156
+ projection_dim: int = 768,
157
+ logit_scale_init_value: float = 2.6592,
158
+ use_text_flash_attn: Optional[bool] = None,
159
+ use_vision_xformers: Optional[bool] = None,
160
+ **kwargs,
161
+ ):
162
+ # If `_config_dict`s exist, we use them for backward compatibility.
163
+ # We pop out these 2 attributes before calling `super().__init__` to avoid
164
+ # them being saved (which causes a lot of confusion!).
165
+
166
+ text_config_dict: Optional[Dict] = kwargs.pop('text_config_dict', None)
167
+ vision_config_dict: Optional[Dict] = kwargs.pop('vision_config_dict', None)
168
+ self.use_text_flash_attn = use_text_flash_attn
169
+ self.use_vision_xformers = use_vision_xformers
170
+
171
+ super().__init__(**kwargs)
172
+
173
+ if text_config_dict is not None:
174
+ if text_config is None:
175
+ text_config = {}
176
+
177
+ # This is the complete result when using `text_config_dict`.
178
+ _text_config_dict = JinaCLIPTextConfig(**text_config_dict).to_dict()
179
+
180
+ # Give a warning if the values exist in both `_text_config_dict` and
181
+ # `text_config` but being different.
182
+ for key, value in _text_config_dict.items():
183
+ if (
184
+ key in text_config
185
+ and value != text_config[key]
186
+ and key not in ['transformers_version']
187
+ ):
188
+ # If specified in `text_config_dict`
189
+ if key in text_config_dict:
190
+ message = (
191
+ f'`{key}` is found in both `text_config_dict` and '
192
+ f'`text_config` but with different values. '
193
+ f'The value `text_config_dict["{key}"]` will be used '
194
+ f'instead.'
195
+ )
196
+ # If inferred from default argument values (
197
+ # just to be super careful)
198
+ else:
199
+ message = (
200
+ f'`text_config_dict` is provided which will be used to '
201
+ f'initialize `JinaCLIPTextConfig`. The '
202
+ f'value `text_config["{key}"]` will be overridden.'
203
+ )
204
+ logger.info(message)
205
+
206
+ # Update all values in `text_config` with the ones in `_text_config_dict`.
207
+ text_config.update(_text_config_dict)
208
+
209
+ if vision_config_dict is not None:
210
+ if vision_config is None:
211
+ vision_config = {}
212
+
213
+ # This is the complete result when using `vision_config_dict`.
214
+ _vision_config_dict = JinaCLIPVisionConfig(**vision_config_dict).to_dict()
215
+ # convert keys to string instead of integer
216
+ if 'id2label' in _vision_config_dict:
217
+ _vision_config_dict['id2label'] = {
218
+ str(key): value
219
+ for key, value in _vision_config_dict['id2label'].items()
220
+ }
221
+
222
+ # Give a warning if the values exist in both `_vision_config_dict`
223
+ # and `vision_config` but being different.
224
+ for key, value in _vision_config_dict.items():
225
+ if (
226
+ key in vision_config
227
+ and value != vision_config[key]
228
+ and key not in ['transformers_version']
229
+ ):
230
+ # If specified in `vision_config_dict`
231
+ if key in vision_config_dict:
232
+ message = (
233
+ f'`{key}` is found in both `vision_config_dict` and '
234
+ f'`vision_config` but with different '
235
+ f'values. The value `vision_config_dict["{key}"]` will '
236
+ f'be used instead.'
237
+ )
238
+ # If inferred from default argument values
239
+ # (just to be super careful)
240
+ else:
241
+ message = (
242
+ f'`vision_config_dict` is provided which will be used to '
243
+ f'initialize `JinaCLIPVisionConfig`. '
244
+ f'The value `vision_config["{key}"]` will be overridden.'
245
+ )
246
+ logger.info(message)
247
+
248
+ # Update all values in `vision_config` with the ones in
249
+ # `_vision_config_dict`.
250
+ vision_config.update(_vision_config_dict)
251
+
252
+ if text_config is None:
253
+ text_config = {}
254
+ logger.info(
255
+ '`text_config` is `None`. Initializing the `JinaCLIPTextConfig` with '
256
+ 'default values.'
257
+ )
258
+
259
+ if vision_config is None:
260
+ vision_config = {}
261
+ logger.info(
262
+ '`vision_config` is `None`. Initializing the `JinaCLIPVisionConfig` '
263
+ 'with default values.'
264
+ )
265
+
266
+ self.text_config = JinaCLIPTextConfig(**text_config)
267
+ self.vision_config = JinaCLIPVisionConfig(**vision_config)
268
+
269
+ self.add_projections = add_projections
270
+ self.projection_dim = projection_dim
271
+ self.logit_scale_init_value = logit_scale_init_value
272
+ self.initializer_factor = 1.0
273
+
274
+ if not self.add_projections:
275
+ if self.text_config.embed_dim != self.vision_config.embed_dim:
276
+ raise ValueError(
277
+ 'When projections are disabled (`add_projections=False`), text '
278
+ 'and vision towers need to have the same embedding dimensionality. '
279
+ f'Currently text embedding dim is {self.text_config.embed_dim} != '
280
+ f'{self.vision_config.embed_dim} of the vision tower. '
281
+ 'Either set the same output dim for both towers, or enable '
282
+ 'projections with `add_projections=True`.'
283
+ )
284
+
285
+ @classmethod
286
+ def from_text_vision_configs(
287
+ cls,
288
+ text_config: JinaCLIPTextConfig,
289
+ vision_config: JinaCLIPVisionConfig,
290
+ **kwargs,
291
+ ):
292
+ return cls(
293
+ text_config=text_config.to_dict(),
294
+ vision_config=vision_config.to_dict(),
295
+ projection_dim=text_config.projection_dim,
296
+ **kwargs,
297
+ )
298
+
299
+ def to_dict(self):
300
+ output = deepcopy(self.__dict__)
301
+ output['text_config'] = self.text_config.to_dict()
302
+ output['vision_config'] = self.vision_config.to_dict()
303
+ output['model_type'] = self.__class__.model_type
304
+ return output
custom_st.py ADDED
@@ -0,0 +1,174 @@
1
+ import base64
2
+ from io import BytesIO
3
+ import json
4
+ import os
5
+ from typing import Any, Dict, List, Optional, Tuple, Union
6
+
7
+ from .custom_st_2 import OtherClass
8
+ import requests
9
+ import torch
10
+ from torch import nn
11
+ from transformers import AutoConfig, AutoModel, AutoTokenizer, AutoImageProcessor
12
+ from PIL import Image
13
+
14
+ OtherClass()
15
+
16
+ class Transformer(nn.Module):
17
+ """Huggingface AutoModel to generate token embeddings.
18
+ Loads the correct class, e.g. BERT / RoBERTa etc.
19
+
20
+ Args:
21
+ model_name_or_path: Huggingface models name
22
+ (https://huggingface.co/models)
23
+ max_seq_length: Truncate any inputs longer than max_seq_length
24
+ model_args: Keyword arguments passed to the Huggingface
25
+ Transformers model
26
+ tokenizer_args: Keyword arguments passed to the Huggingface
27
+ Transformers tokenizer
28
+ config_args: Keyword arguments passed to the Huggingface
29
+ Transformers config
30
+ cache_dir: Cache dir for Huggingface Transformers to store/load
31
+ models
32
+ do_lower_case: If true, lowercases the input (independent if the
33
+ model is cased or not)
34
+ tokenizer_name_or_path: Name or path of the tokenizer. When
35
+ None, then model_name_or_path is used
36
+ """
37
+
38
+ def __init__(
39
+ self,
40
+ model_name_or_path: str,
41
+ max_seq_length: Optional[int] = None,
42
+ model_args: Optional[Dict[str, Any]] = None,
43
+ tokenizer_args: Optional[Dict[str, Any]] = None,
44
+ config_args: Optional[Dict[str, Any]] = None,
45
+ cache_dir: Optional[str] = None,
46
+ do_lower_case: bool = False,
47
+ tokenizer_name_or_path: str = None,
48
+ ) -> None:
49
+ super(Transformer, self).__init__()
50
+ self.config_keys = ["max_seq_length", "do_lower_case"]
51
+ self.do_lower_case = do_lower_case
52
+ if model_args is None:
53
+ model_args = {}
54
+ if tokenizer_args is None:
55
+ tokenizer_args = {}
56
+ if config_args is None:
57
+ config_args = {}
58
+
59
+ config = AutoConfig.from_pretrained(model_name_or_path, **config_args, cache_dir=cache_dir)
60
+ self.jina_clip = AutoModel.from_pretrained(
61
+ model_name_or_path, config=config, cache_dir=cache_dir, **model_args
62
+ )
63
+
64
+ if max_seq_length is not None and "model_max_length" not in tokenizer_args:
65
+ tokenizer_args["model_max_length"] = max_seq_length
66
+ self.tokenizer = AutoTokenizer.from_pretrained(
67
+ tokenizer_name_or_path if tokenizer_name_or_path is not None else model_name_or_path,
68
+ cache_dir=cache_dir,
69
+ **tokenizer_args,
70
+ )
71
+ self.preprocessor = AutoImageProcessor.from_pretrained(
72
+ tokenizer_name_or_path if tokenizer_name_or_path is not None else model_name_or_path,
73
+ cache_dir=cache_dir,
74
+ **tokenizer_args,
75
+ )
76
+
77
+ # No max_seq_length set. Try to infer from model
78
+ if max_seq_length is None:
79
+ if (
80
+ hasattr(self.jina_clip, "config")
81
+ and hasattr(self.jina_clip.config, "max_position_embeddings")
82
+ and hasattr(self.tokenizer, "model_max_length")
83
+ ):
84
+ max_seq_length = min(self.jina_clip.config.max_position_embeddings, self.tokenizer.model_max_length)
85
+
86
+ self.max_seq_length = max_seq_length
87
+
88
+ if tokenizer_name_or_path is not None:
89
+ self.jina_clip.config.tokenizer_class = self.tokenizer.__class__.__name__
90
+
91
+ def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
92
+ """Returns token_embeddings, cls_token"""
93
+ if "input_ids" in features:
94
+ embedding = self.jina_clip.get_text_features(input_ids=features["input_ids"])
95
+ else:
96
+ embedding = self.jina_clip.get_image_features(pixel_values=features["pixel_values"])
97
+ return {"sentence_embedding": embedding}
98
+
99
+ def get_word_embedding_dimension(self) -> int:
100
+ return self.jina_clip.config.text_config.embed_dim
101
+
102
+ @staticmethod
+ def decode_data_image(data_image_str):
103
+ header, data = data_image_str.split(',', 1)
104
+ image_data = base64.b64decode(data)
105
+ return Image.open(BytesIO(image_data))
106
+
107
+ def tokenize(
108
+ self, batch: Union[List[str]], padding: Union[str, bool] = True
109
+ ) -> Dict[str, torch.Tensor]:
110
+ """Tokenizes a text and maps tokens to token-ids"""
111
+ images = []
112
+ texts = []
113
+ for sample in batch:
114
+ if isinstance(sample, str):
115
+ if sample.startswith('http'):
116
+ response = requests.get(sample)
117
+ images.append(Image.open(BytesIO(response.content)).convert('RGB'))
118
+ elif sample.startswith('data:image/'):
119
+ images.append(self.decode_data_image(sample).convert('RGB'))
120
+ else:
121
+ # TODO: Make sure that Image.open fails for non-image files
122
+ try:
123
+ images.append(Image.open(sample).convert('RGB'))
124
+ except Exception:
125
+ texts.append(sample)
126
+ elif isinstance(sample, Image.Image):
127
+ images.append(sample.convert('RGB'))
128
+
129
+ if images and texts:
130
+ raise ValueError('Batch must contain either images or texts, not both')
131
+
132
+ if texts:
133
+ return self.tokenizer(
134
+ texts,
135
+ padding=padding,
136
+ truncation="longest_first",
137
+ return_tensors="pt",
138
+ max_length=self.max_seq_length,
139
+ )
140
+ elif images:
141
+ return self.preprocessor(images)
142
+ return {}
143
+
144
+ def save(self, output_path: str, safe_serialization: bool = True) -> None:
145
+ self.jina_clip.save_pretrained(output_path, safe_serialization=safe_serialization)
146
+ self.tokenizer.save_pretrained(output_path)
147
+ self.preprocessor.save_pretrained(output_path)
148
+
149
+ @staticmethod
150
+ def load(input_path: str) -> "Transformer":
151
+ # Old classes used other config names than 'sentence_bert_config.json'
152
+ for config_name in [
153
+ "sentence_bert_config.json",
154
+ "sentence_roberta_config.json",
155
+ "sentence_distilbert_config.json",
156
+ "sentence_camembert_config.json",
157
+ "sentence_albert_config.json",
158
+ "sentence_xlm-roberta_config.json",
159
+ "sentence_xlnet_config.json",
160
+ ]:
161
+ sbert_config_path = os.path.join(input_path, config_name)
162
+ if os.path.exists(sbert_config_path):
163
+ break
164
+
165
+ with open(sbert_config_path) as fIn:
166
+ config = json.load(fIn)
167
+ # Don't allow configs to set trust_remote_code
168
+ if "model_args" in config and "trust_remote_code" in config["model_args"]:
169
+ config["model_args"].pop("trust_remote_code")
170
+ if "tokenizer_args" in config and "trust_remote_code" in config["tokenizer_args"]:
171
+ config["tokenizer_args"].pop("trust_remote_code")
172
+ if "config_args" in config and "trust_remote_code" in config["config_args"]:
173
+ config["config_args"].pop("trust_remote_code")
174
+ return Transformer(model_name_or_path=input_path, **config)
custom_st_2.py ADDED
@@ -0,0 +1,3 @@
1
+
2
+ class OtherClass:
3
+ pass
eva_model.py ADDED
@@ -0,0 +1,764 @@
1
+ # --------------------------------------------------------
2
+ # Adapted from EVA CLIP
3
+ # https://github.com/baaivision/EVA/tree/master/EVA-CLIP/rei/eva_clip
4
+ # --------------------------------------------------------
5
+
6
+ import math
7
+ import os
8
+ from functools import partial
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ try:
15
+ from timm.models.layers import drop_path, to_2tuple, trunc_normal_
16
+ except (ImportError, ModuleNotFoundError):
17
+ from timm.layers import drop_path, to_2tuple, trunc_normal_
18
+
19
+ from .rope_embeddings import VisionRotaryEmbeddingFast
20
+
21
+ if os.getenv('ENV_TYPE') == 'deepspeed':
22
+ try:
23
+ from deepspeed.runtime.activation_checkpointing.checkpointing import checkpoint
24
+ except (ImportError, ModuleNotFoundError):
25
+ from torch.utils.checkpoint import checkpoint
26
+ else:
27
+ from torch.utils.checkpoint import checkpoint
28
+
29
+ try:
30
+ import xformers.ops as xops
31
+ except ImportError:
32
+ xops = None
33
+
34
+
35
+ class PatchDropout(nn.Module):
36
+ """
37
+ https://arxiv.org/abs/2212.00794
38
+ """
39
+
40
+ def __init__(self, prob, exclude_first_token=True):
41
+ super().__init__()
42
+ assert 0 <= prob < 1.0
43
+ self.prob = prob
44
+ self.exclude_first_token = exclude_first_token # exclude CLS token
45
+
46
+ def forward(self, x):
47
+ if not self.training or self.prob == 0.0:
48
+ return x
49
+
50
+ if self.exclude_first_token:
51
+ cls_tokens, x = x[:, :1], x[:, 1:]
52
+ else:
53
+ cls_tokens = torch.jit.annotate(torch.Tensor, x[:, :1])
54
+
55
+ batch = x.size()[0]
56
+ num_tokens = x.size()[1]
57
+
58
+ batch_indices = torch.arange(batch)
59
+ batch_indices = batch_indices[..., None]
60
+
61
+ keep_prob = 1 - self.prob
62
+ num_patches_keep = max(1, int(num_tokens * keep_prob))
63
+
64
+ rand = torch.randn(batch, num_tokens)
65
+ patch_indices_keep = rand.topk(num_patches_keep, dim=-1).indices
66
+
67
+ x = x[batch_indices, patch_indices_keep]
68
+
69
+ if self.exclude_first_token:
70
+ x = torch.cat((cls_tokens, x), dim=1)
71
+
72
+ return x, patch_indices_keep
73
+
74
+
75
+ class DropPath(nn.Module):
76
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of
77
+ residual blocks)."""
78
+
79
+ def __init__(self, drop_prob=None):
80
+ super(DropPath, self).__init__()
81
+ self.drop_prob = drop_prob
82
+
83
+ def forward(self, x):
84
+ return drop_path(x, self.drop_prob, self.training)
85
+
86
+ def extra_repr(self) -> str:
87
+ return 'p={}'.format(self.drop_prob)
88
+
89
+
90
+ class Mlp(nn.Module):
91
+ def __init__(
92
+ self,
93
+ in_features,
94
+ hidden_features=None,
95
+ out_features=None,
96
+ act_layer=nn.GELU,
97
+ norm_layer=nn.LayerNorm,
98
+ drop=0.0,
99
+ subln=False,
100
+ ):
101
+ super().__init__()
102
+ out_features = out_features or in_features
103
+ hidden_features = hidden_features or in_features
104
+ self.fc1 = nn.Linear(in_features, hidden_features)
105
+ self.act = act_layer()
106
+
107
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
108
+
109
+ self.fc2 = nn.Linear(hidden_features, out_features)
110
+ self.drop = nn.Dropout(drop)
111
+
112
+ def forward(self, x):
113
+ x = self.fc1(x)
114
+ x = self.act(x)
115
+ # x = self.drop(x)
116
+ # uncomment the line above for the original BERT implementation
117
+ x = self.ffn_ln(x)
118
+
119
+ x = self.fc2(x)
120
+ x = self.drop(x)
121
+ return x
122
+
123
+
124
+ class SwiGLU(nn.Module):
125
+ def __init__(
126
+ self,
127
+ in_features,
128
+ hidden_features=None,
129
+ out_features=None,
130
+ act_layer=nn.SiLU,
131
+ drop=0.0,
132
+ norm_layer=nn.LayerNorm,
133
+ subln=False,
134
+ ):
135
+ super().__init__()
136
+ out_features = out_features or in_features
137
+ hidden_features = hidden_features or in_features
138
+
139
+ self.w1 = nn.Linear(in_features, hidden_features)
140
+ self.w2 = nn.Linear(in_features, hidden_features)
141
+
142
+ self.act = act_layer()
143
+ self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
144
+ self.w3 = nn.Linear(hidden_features, out_features)
145
+
146
+ self.drop = nn.Dropout(drop)
147
+
148
+ def forward(self, x):
149
+ x1 = self.w1(x)
150
+ x2 = self.w2(x)
151
+ hidden = self.act(x1) * x2
152
+ x = self.ffn_ln(hidden)
153
+ x = self.w3(x)
154
+ x = self.drop(x)
155
+ return x
156
+
157
+
158
+ class Attention(nn.Module):
159
+ def __init__(
160
+ self,
161
+ dim,
162
+ num_heads=8,
163
+ qkv_bias=False,
164
+ qk_scale=None,
165
+ attn_drop=0.0,
166
+ proj_drop=0.0,
167
+ window_size=None,
168
+ attn_head_dim=None,
169
+ xattn=False,
170
+ rope=None,
171
+ subln=False,
172
+ norm_layer=nn.LayerNorm,
173
+ ):
174
+ super().__init__()
175
+ self.num_heads = num_heads
176
+ head_dim = dim // num_heads
177
+ if attn_head_dim is not None:
178
+ head_dim = attn_head_dim
179
+ all_head_dim = head_dim * self.num_heads
180
+ self.scale = qk_scale or head_dim**-0.5
181
+
182
+ self.subln = subln
183
+ if self.subln:
184
+ self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
185
+ self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
186
+ self.v_proj = nn.Linear(dim, all_head_dim, bias=False)
187
+ else:
188
+ self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
189
+
190
+ if qkv_bias:
191
+ self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
192
+ self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
193
+ else:
194
+ self.q_bias = None
195
+ self.v_bias = None
196
+
197
+ if window_size:
198
+ self.window_size = window_size
199
+ self.num_relative_distance = (2 * window_size[0] - 1) * (
200
+ 2 * window_size[1] - 1
201
+ ) + 3
202
+ self.relative_position_bias_table = nn.Parameter(
203
+ torch.zeros(self.num_relative_distance, num_heads)
204
+ ) # 2*Wh-1 * 2*Ww-1, nH
205
+ # cls to token & token 2 cls & cls to cls
206
+
207
+ # get pair-wise relative position index for each token inside the window
208
+ coords_h = torch.arange(window_size[0])
209
+ coords_w = torch.arange(window_size[1])
210
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
211
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
212
+ relative_coords = (
213
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
214
+ ) # 2, Wh*Ww, Wh*Ww
215
+ relative_coords = relative_coords.permute(
216
+ 1, 2, 0
217
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
218
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
219
+ relative_coords[:, :, 1] += window_size[1] - 1
220
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
221
+ relative_position_index = torch.zeros(
222
+ size=(window_size[0] * window_size[1] + 1,) * 2,
223
+ dtype=relative_coords.dtype,
224
+ )
225
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
226
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
227
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
228
+ relative_position_index[0, 0] = self.num_relative_distance - 1
229
+
230
+ self.register_buffer('relative_position_index', relative_position_index)
231
+ else:
232
+ self.window_size = None
233
+ self.relative_position_bias_table = None
234
+ self.relative_position_index = None
235
+
236
+ self.attn_drop = nn.Dropout(attn_drop)
237
+ self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()
238
+ # self.proj = nn.Linear(all_head_dim, all_head_dim)
239
+ self.proj = nn.Linear(all_head_dim, dim)
240
+ self.proj_drop = nn.Dropout(proj_drop)
241
+ self.xattn = xattn
242
+ self.xattn_drop = attn_drop
243
+
244
+ self.rope = rope
245
+
246
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
247
+ B, N, C = x.shape
248
+ if self.subln:
249
+ q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
250
+ k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
251
+ v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)
252
+
253
+ q = q.reshape(B, N, self.num_heads, -1).permute(
254
+ 0, 2, 1, 3
255
+ ) # B, num_heads, N, C
256
+ k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
257
+ v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
258
+ else:
259
+ qkv_bias = None
260
+ if self.q_bias is not None:
261
+ qkv_bias = torch.cat(
262
+ (
263
+ self.q_bias,
264
+ torch.zeros_like(self.v_bias, requires_grad=False),
265
+ self.v_bias,
266
+ )
267
+ )
268
+
269
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
270
+ qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(
271
+ 2, 0, 3, 1, 4
272
+ ) # 3, B, num_heads, N, C
273
+ q, k, v = qkv[0], qkv[1], qkv[2]
274
+
275
+ if self.rope:
276
+ # slightly fast impl
277
+ q_t = q[:, :, 1:, :]
278
+ ro_q_t = self.rope(q_t)
279
+ q = torch.cat((q[:, :, :1, :], ro_q_t), -2).type_as(v)
280
+
281
+ k_t = k[:, :, 1:, :]
282
+ ro_k_t = self.rope(k_t)
283
+ k = torch.cat((k[:, :, :1, :], ro_k_t), -2).type_as(v)
284
+
285
+ if self.xattn:
286
+ if xops is None:
287
+ raise ValueError(
288
+ "Can't use xattn without xformers. Please 'pip install xformers'"
289
+ )
290
+ q = q.permute(0, 2, 1, 3) # B, num_heads, N, C -> B, N, num_heads, C
291
+ k = k.permute(0, 2, 1, 3)
292
+ v = v.permute(0, 2, 1, 3)
293
+
294
+ x = xops.memory_efficient_attention(
295
+ q,
296
+ k,
297
+ v,
298
+ p=self.xattn_drop,
299
+ scale=self.scale,
300
+ )
301
+ x = x.reshape(B, N, -1)
302
+ x = self.inner_attn_ln(x)
303
+ x = self.proj(x)
304
+ x = self.proj_drop(x)
305
+ else:
306
+ q = q * self.scale
307
+ attn = q @ k.transpose(-2, -1)
308
+
309
+ if self.relative_position_bias_table is not None:
310
+ relative_position_bias = self.relative_position_bias_table[
311
+ self.relative_position_index.view(-1)
312
+ ].view(
313
+ self.window_size[0] * self.window_size[1] + 1,
314
+ self.window_size[0] * self.window_size[1] + 1,
315
+ -1,
316
+ ) # Wh*Ww,Wh*Ww,nH
317
+ relative_position_bias = relative_position_bias.permute(
318
+ 2, 0, 1
319
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
320
+ attn = attn + relative_position_bias.unsqueeze(0).type_as(attn)
321
+
322
+ if rel_pos_bias is not None:
323
+ attn = attn + rel_pos_bias.type_as(attn)
324
+
325
+ if attn_mask is not None:
326
+ attn_mask = attn_mask.bool()
327
+ attn = attn.masked_fill(~attn_mask[:, None, None, :], float('-inf'))
328
+
329
+ attn = attn.softmax(dim=-1)
330
+ attn = self.attn_drop(attn)
331
+
332
+ x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
333
+ x = self.inner_attn_ln(x)
334
+ x = self.proj(x)
335
+ x = self.proj_drop(x)
336
+ return x
337
+
338
+
339
+ class Block(nn.Module):
340
+ def __init__(
341
+ self,
342
+ dim,
343
+ num_heads,
344
+ mlp_ratio=4.0,
345
+ qkv_bias=False,
346
+ qk_scale=None,
347
+ drop=0.0,
348
+ attn_drop=0.0,
349
+ drop_path=0.0,
350
+ init_values=None,
351
+ act_layer=nn.GELU,
352
+ norm_layer=nn.LayerNorm,
353
+ window_size=None,
354
+ attn_head_dim=None,
355
+ xattn=False,
356
+ rope=None,
357
+ postnorm=False,
358
+ subln=False,
359
+ naiveswiglu=False,
360
+ ):
361
+ super().__init__()
362
+ self.norm1 = norm_layer(dim)
363
+ self.attn = Attention(
364
+ dim,
365
+ num_heads=num_heads,
366
+ qkv_bias=qkv_bias,
367
+ qk_scale=qk_scale,
368
+ attn_drop=attn_drop,
369
+ proj_drop=drop,
370
+ window_size=window_size,
371
+ attn_head_dim=attn_head_dim,
372
+ xattn=xattn,
373
+ rope=rope,
374
+ subln=subln,
375
+ norm_layer=norm_layer,
376
+ )
377
+ # NOTE: drop path for stochastic depth, we shall see if this is better
378
+ # than dropout here
379
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
380
+ self.norm2 = norm_layer(dim)
381
+ mlp_hidden_dim = int(dim * mlp_ratio)
382
+
383
+ if naiveswiglu:
384
+ self.mlp = SwiGLU(
385
+ in_features=dim,
386
+ hidden_features=mlp_hidden_dim,
387
+ subln=subln,
388
+ norm_layer=norm_layer,
389
+ )
390
+ else:
391
+ self.mlp = Mlp(
392
+ in_features=dim,
393
+ hidden_features=mlp_hidden_dim,
394
+ act_layer=act_layer,
395
+ subln=subln,
396
+ drop=drop,
397
+ )
398
+
399
+ if init_values is not None and init_values > 0:
400
+ self.gamma_1 = nn.Parameter(
401
+ init_values * torch.ones((dim,)), requires_grad=True
402
+ )
403
+ self.gamma_2 = nn.Parameter(
404
+ init_values * torch.ones((dim,)), requires_grad=True
405
+ )
406
+ else:
407
+ self.gamma_1, self.gamma_2 = None, None
408
+
409
+ self.postnorm = postnorm
410
+
411
+ def forward(self, x, rel_pos_bias=None, attn_mask=None):
412
+ if self.gamma_1 is None:
413
+ if self.postnorm:
414
+ x = x + self.drop_path(
415
+ self.norm1(
416
+ self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)
417
+ )
418
+ )
419
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
420
+ else:
421
+ x = x + self.drop_path(
422
+ self.attn(
423
+ self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask
424
+ )
425
+ )
426
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
427
+ else:
428
+ if self.postnorm:
429
+ x = x + self.drop_path(
430
+ self.gamma_1
431
+ * self.norm1(
432
+ self.attn(x, rel_pos_bias=rel_pos_bias, attn_mask=attn_mask)
433
+ )
434
+ )
435
+ x = x + self.drop_path(self.gamma_2 * self.norm2(self.mlp(x)))
436
+ else:
437
+ x = x + self.drop_path(
438
+ self.gamma_1
439
+ * self.attn(
440
+ self.norm1(x), rel_pos_bias=rel_pos_bias, attn_mask=attn_mask
441
+ )
442
+ )
443
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
444
+ return x
445
+
446
+
447
+ class PatchEmbed(nn.Module):
448
+ """Image to Patch Embedding"""
449
+
450
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
451
+ super().__init__()
452
+ img_size = to_2tuple(img_size)
453
+ patch_size = to_2tuple(patch_size)
454
+ num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
455
+ self.patch_shape = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
456
+ self.img_size = img_size
457
+ self.patch_size = patch_size
458
+ self.num_patches = num_patches
459
+
460
+ self.proj = nn.Conv2d(
461
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
462
+ )
463
+
464
+ def forward(self, x, **kwargs):
465
+ target_dtype = self.proj.weight.dtype
466
+ B, C, H, W = x.shape
467
+ # FIXME look at relaxing size constraints
468
+ assert H == self.img_size[0] and W == self.img_size[1], (
469
+ f"Input image size ({H}*{W}) doesn't match model "
470
+ f'({self.img_size[0]}*{self.img_size[1]}).'
471
+ )
472
+ x = self.proj(x.to(dtype=target_dtype)).flatten(2).transpose(1, 2)
473
+ return x
474
+
475
+
476
+ class RelativePositionBias(nn.Module):
477
+ def __init__(self, window_size, num_heads):
478
+ super().__init__()
479
+ self.window_size = window_size
480
+ self.num_relative_distance = (2 * window_size[0] - 1) * (
481
+ 2 * window_size[1] - 1
482
+ ) + 3
483
+ self.relative_position_bias_table = nn.Parameter(
484
+ torch.zeros(self.num_relative_distance, num_heads)
485
+ ) # 2*Wh-1 * 2*Ww-1, nH
486
+ # cls to token & token 2 cls & cls to cls
487
+
488
+ # get pair-wise relative position index for each token inside the window
489
+ coords_h = torch.arange(window_size[0])
490
+ coords_w = torch.arange(window_size[1])
491
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
492
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
493
+ relative_coords = (
494
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
495
+ ) # 2, Wh*Ww, Wh*Ww
496
+ relative_coords = relative_coords.permute(
497
+ 1, 2, 0
498
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
499
+ relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
500
+ relative_coords[:, :, 1] += window_size[1] - 1
501
+ relative_coords[:, :, 0] *= 2 * window_size[1] - 1
502
+ relative_position_index = torch.zeros(
503
+ size=(window_size[0] * window_size[1] + 1,) * 2, dtype=relative_coords.dtype
504
+ )
505
+ relative_position_index[1:, 1:] = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
506
+ relative_position_index[0, 0:] = self.num_relative_distance - 3
507
+ relative_position_index[0:, 0] = self.num_relative_distance - 2
508
+ relative_position_index[0, 0] = self.num_relative_distance - 1
509
+
510
+ self.register_buffer('relative_position_index', relative_position_index)
511
+
512
+ def forward(self):
513
+ relative_position_bias = self.relative_position_bias_table[
514
+ self.relative_position_index.view(-1)
515
+ ].view(
516
+ self.window_size[0] * self.window_size[1] + 1,
517
+ self.window_size[0] * self.window_size[1] + 1,
518
+ -1,
519
+ ) # Wh*Ww,Wh*Ww,nH
520
+ return relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
521
+
522
+
523
+ class EVAVisionTransformer(nn.Module):
524
+ """Vision Transformer with support for patch or hybrid CNN input stage"""
525
+
526
+ def __init__(
527
+ self,
528
+ img_size=224,
529
+ patch_size=16,
530
+ in_chans=3,
531
+ num_classes=0,
532
+ embed_dim=768,
533
+ depth=12,
534
+ num_heads=12,
535
+ mlp_ratio=4.0,
536
+ qkv_bias=False,
537
+ qk_scale=None,
538
+ drop_rate=0.0,
539
+ attn_drop_rate=0.0,
540
+ drop_path_rate=0.0,
541
+ norm_layer=nn.LayerNorm,
542
+ init_values=None,
543
+ patch_dropout=0.0,
544
+ use_abs_pos_emb=True,
545
+ use_rel_pos_bias=False,
546
+ use_shared_rel_pos_bias=False,
547
+ rope=False,
548
+ use_mean_pooling=True,
549
+ init_scale=0.001,
550
+ grad_checkpointing=False,
551
+ xattn=False,
552
+ postnorm=False,
553
+ pt_hw_seq_len=16,
554
+ intp_freq=False,
555
+ naiveswiglu=False,
556
+ subln=False,
557
+ proj_type=None,
558
+ ):
559
+ super().__init__()
560
+ self.image_size = img_size
561
+ self.num_classes = num_classes
562
+ self.num_features = (
563
+ self.embed_dim
564
+ ) = embed_dim # num_features for consistency with other models
565
+
566
+ self.patch_embed = PatchEmbed(
567
+ img_size=img_size,
568
+ patch_size=patch_size,
569
+ in_chans=in_chans,
570
+ embed_dim=embed_dim,
571
+ )
572
+ num_patches = self.patch_embed.num_patches
573
+
574
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
575
+ # self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
576
+ if use_abs_pos_emb:
577
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
578
+ else:
579
+ self.pos_embed = None
580
+ self.pos_drop = nn.Dropout(p=drop_rate)
581
+
582
+ if use_shared_rel_pos_bias:
583
+ self.rel_pos_bias = RelativePositionBias(
584
+ window_size=self.patch_embed.patch_shape, num_heads=num_heads
585
+ )
586
+ else:
587
+ self.rel_pos_bias = None
588
+
589
+ if rope:
590
+ half_head_dim = embed_dim // num_heads // 2
591
+ hw_seq_len = img_size // patch_size
592
+ self.rope = VisionRotaryEmbeddingFast(
593
+ dim=half_head_dim,
594
+ pt_seq_len=pt_hw_seq_len,
595
+ ft_seq_len=hw_seq_len if intp_freq else None,
596
+ patch_dropout=patch_dropout,
597
+ )
598
+ else:
599
+ self.rope = None
600
+
601
+ self.naiveswiglu = naiveswiglu
602
+
603
+ dpr = [
604
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
605
+ ] # stochastic depth decay rule
606
+ self.use_rel_pos_bias = use_rel_pos_bias
607
+ self.blocks = nn.ModuleList(
608
+ [
609
+ Block(
610
+ dim=embed_dim,
611
+ num_heads=num_heads,
612
+ mlp_ratio=mlp_ratio,
613
+ qkv_bias=qkv_bias,
614
+ qk_scale=qk_scale,
615
+ drop=drop_rate,
616
+ attn_drop=attn_drop_rate,
617
+ drop_path=dpr[i],
618
+ norm_layer=norm_layer,
619
+ init_values=init_values,
620
+ window_size=self.patch_embed.patch_shape
621
+ if use_rel_pos_bias
622
+ else None,
623
+ xattn=xattn,
624
+ rope=self.rope,
625
+ postnorm=postnorm,
626
+ subln=subln,
627
+ naiveswiglu=naiveswiglu,
628
+ )
629
+ for i in range(depth)
630
+ ]
631
+ )
632
+ self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
633
+ self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
634
+ if (num_classes == embed_dim) and (proj_type is None):
635
+ self.head = nn.Identity()
636
+ elif proj_type == 'linear':
637
+ self.head = nn.Linear(embed_dim, num_classes, bias=qkv_bias)
638
+ elif proj_type == 'mlp':
639
+ hidden_size = (embed_dim + num_classes) // 2
640
+ self.proj = nn.Sequential(
641
+ nn.Linear(embed_dim, hidden_size, bias=qkv_bias),
642
+ nn.GELU(),
643
+ nn.Linear(hidden_size, num_classes, bias=qkv_bias),
644
+ )
645
+
646
+ if self.pos_embed is not None:
647
+ trunc_normal_(self.pos_embed, std=0.02)
648
+
649
+ trunc_normal_(self.cls_token, std=0.02)
650
+
651
+ self.apply(self._init_weights)
652
+ self.fix_init_weight()
653
+
654
+ if isinstance(self.head, nn.Linear):
655
+ trunc_normal_(self.head.weight, std=0.02)
656
+ self.head.weight.data.mul_(init_scale)
657
+ if qkv_bias:
658
+ self.head.bias.data.mul_(init_scale)
659
+
660
+ # setting a patch_dropout of 0. would mean it is disabled and this function
661
+ # would be the identity fn
662
+ self.patch_dropout = (
663
+ PatchDropout(patch_dropout) if patch_dropout > 0.0 else nn.Identity()
664
+ )
665
+
666
+ self.grad_checkpointing = grad_checkpointing
667
+
668
+ def fix_init_weight(self):
669
+ def rescale(param, layer_id):
670
+ param.div_(math.sqrt(2.0 * layer_id))
671
+
672
+ for layer_id, layer in enumerate(self.blocks):
673
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
674
+ if self.naiveswiglu:
675
+ rescale(layer.mlp.w3.weight.data, layer_id + 1)
676
+ else:
677
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
678
+
679
+ def get_cast_dtype(self) -> torch.dtype:
680
+ return self.blocks[0].mlp.fc2.weight.dtype
681
+
682
+ def _init_weights(self, m):
683
+ if isinstance(m, nn.Linear):
684
+ trunc_normal_(m.weight, std=0.02)
685
+ if m.bias is not None:
686
+ nn.init.constant_(m.bias, 0)
687
+ elif isinstance(m, nn.LayerNorm):
688
+ nn.init.constant_(m.bias, 0)
689
+ nn.init.constant_(m.weight, 1.0)
690
+
691
+ def get_num_layers(self):
692
+ return len(self.blocks)
693
+
694
+ def lock(self, unlocked_groups=0, freeze_bn_stats=False):
695
+ assert (
696
+ unlocked_groups == 0
697
+ ), 'partial locking not currently supported for this model'
698
+ for param in self.parameters():
699
+ param.requires_grad = False
700
+
701
+ @torch.jit.ignore
702
+ def set_grad_checkpointing(self, enable=True):
703
+ self.grad_checkpointing = enable
704
+
705
+ @torch.jit.ignore
706
+ def no_weight_decay(self):
707
+ return {'pos_embed', 'cls_token'}
708
+
709
+ def get_classifier(self):
710
+ return self.head
711
+
712
+ def reset_classifier(self, num_classes, global_pool=''):
713
+ self.num_classes = num_classes
714
+ self.head = (
715
+ nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
716
+ )
717
+
718
+ def forward_features(self, x, return_all_features=False):
719
+ x = self.patch_embed(x)
720
+ batch_size, seq_len, _ = x.size()
721
+
722
+ cls_tokens = self.cls_token.expand(
723
+ batch_size, -1, -1
724
+ ) # stole cls_tokens impl from Phil Wang, thanks
725
+ x = torch.cat((cls_tokens, x), dim=1)
726
+ if self.pos_embed is not None:
727
+ x = x + self.pos_embed
728
+ x = self.pos_drop(x)
729
+
730
+ # a patch_dropout of 0. would mean it is disabled and this function would do
731
+ # nothing but return what was passed in
732
+ if self.rope is not None:
733
+ if self.training and not isinstance(self.patch_dropout, nn.Identity):
734
+ x, patch_indices_keep = self.patch_dropout(x)
735
+ self.rope.forward = partial(
736
+ self.rope.forward, patch_indices_keep=patch_indices_keep
737
+ )
738
+ else:
739
+ self.rope.forward = partial(self.rope.forward, patch_indices_keep=None)
740
+ x = self.patch_dropout(x)
741
+ else:
742
+ x = self.patch_dropout(x)
743
+
744
+ rel_pos_bias = self.rel_pos_bias() if self.rel_pos_bias is not None else None
745
+ for blk in self.blocks:
746
+ if self.grad_checkpointing:
747
+ x = checkpoint(blk, x, (rel_pos_bias,))
748
+ else:
749
+ x = blk(x, rel_pos_bias=rel_pos_bias)
750
+
751
+ if not return_all_features:
752
+ x = self.norm(x)
753
+ if self.fc_norm is not None:
754
+ return self.fc_norm(x.mean(1))
755
+ else:
756
+ return x[:, 0]
757
+ return x
758
+
759
+ def forward(self, x, return_all_features=False):
760
+ if return_all_features:
761
+ return self.forward_features(x, return_all_features)
762
+ x = self.forward_features(x)
763
+ x = self.head(x)
764
+ return x
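
Two initialisation details in `EVAVisionTransformer` above are easy to miss: the stochastic depth schedule `dpr` grows linearly with block index, and `fix_init_weight` shrinks the attention/MLP output weights of deeper blocks by `1/sqrt(2 * layer_id)`. The standalone sketch below reproduces both rules for a hypothetical 12-block tower (it does not import `eva_model.py` and uses illustrative values, not the shipped configuration).

```python
# Standalone sketch of the two depth-dependent rules above
# (hypothetical depth=12, drop_path_rate=0.1).
import math
import torch

depth, drop_path_rate = 12, 0.1

# stochastic depth decay rule: later blocks get a higher drop-path probability
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
print([round(v, 3) for v in dpr])

# fix_init_weight: projection weights of block i are divided by sqrt(2 * (i + 1)),
# so deeper blocks start with smaller residual contributions
scales = [1 / math.sqrt(2.0 * (i + 1)) for i in range(depth)]
print([round(v, 3) for v in scales])
```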
hf_model.py ADDED
@@ -0,0 +1,297 @@
 
 
1
+ import re
2
+ from typing import Dict, Optional, Tuple
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ from transformers import AutoConfig, AutoModel, PretrainedConfig
7
+ from transformers.modeling_outputs import (
8
+ BaseModelOutput,
9
+ BaseModelOutputWithPooling,
10
+ BaseModelOutputWithPoolingAndCrossAttentions,
11
+ )
12
+
13
+ """
14
+ HF architecture mapping
15
+ """
16
+
17
+ _HF_ARCH_DICT = {
18
+ # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
19
+ 'roberta': {
20
+ 'config_names': {
21
+ 'context_length': 'max_position_embeddings',
22
+ 'vocab_size': 'vocab_size',
23
+ 'width': 'hidden_size',
24
+ 'heads': 'num_attention_heads',
25
+ 'layers': 'num_hidden_layers',
26
+ 'layer_attr': 'layer',
27
+ 'token_embeddings_attr': 'embeddings',
28
+ },
29
+ 'pooler': 'mean_pooler',
30
+ },
31
+ # https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaConfig
32
+ 'xlm-roberta': {
33
+ 'config_names': {
34
+ 'context_length': 'max_position_embeddings',
35
+ 'vocab_size': 'vocab_size',
36
+ 'width': 'hidden_size',
37
+ 'heads': 'num_attention_heads',
38
+ 'layers': 'num_hidden_layers',
39
+ 'layer_attr': 'layer',
40
+ 'token_embeddings_attr': 'embeddings',
41
+ },
42
+ 'pooler': 'mean_pooler',
43
+ },
44
+ # https://huggingface.co/docs/transformers/model_doc/mt5#mt5
45
+ 'mt5': {
46
+ 'config_names': {
47
+ # unlimited seqlen
48
+ # https://github.com/google-research/text-to-text-transfer-transformer/issues/273
49
+ # https://github.com/huggingface/transformers/blob/v4.24.0/src/transformers/models/t5/modeling_t5.py#L374
50
+ 'context_length': '',
51
+ 'vocab_size': 'vocab_size',
52
+ 'width': 'd_model',
53
+ 'heads': 'num_heads',
54
+ 'layers': 'num_layers',
55
+ 'layer_attr': 'block',
56
+ 'token_embeddings_attr': 'embed_tokens',
57
+ },
58
+ 'pooler': 'mean_pooler',
59
+ },
60
+ # https://huggingface.co/docs/transformers/model_doc/bert
61
+ 'bert': {
62
+ 'config_names': {
63
+ 'context_length': 'max_position_embeddings',
64
+ 'vocab_size': 'vocab_size',
65
+ 'width': 'hidden_size',
66
+ 'heads': 'num_attention_heads',
67
+ 'layers': 'num_hidden_layers',
68
+ },
69
+ 'pooler': 'cls_pooler',
70
+ },
71
+ # https://huggingface.co/docs/transformers/model_doc/m2m_100
72
+ 'm2m_100': {
73
+ 'config_names': {
74
+ 'context_length': 'max_position_embeddings',
75
+ 'vocab_size': 'vocab_size',
76
+ 'width': 'd_model',
77
+ 'heads': 'encoder_attention_heads',
78
+ 'layers': 'encoder_layers',
79
+ },
80
+ 'pooler': 'cls_pooler',
81
+ },
82
+ }
83
+
84
+
85
+ """
86
+ Pooling functions
87
+ """
88
+
89
+ _POOLERS = {}
90
+
91
+
92
+ def _camel2snake(s):
93
+ return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
94
+
95
+
96
+ def register_pooler(cls):
97
+ """Decorator registering pooler class"""
98
+ _POOLERS[_camel2snake(cls.__name__)] = cls
99
+ return cls
100
+
101
+
102
+ @register_pooler
103
+ class MeanPooler(nn.Module):
104
+ """Mean pooling"""
105
+
106
+ @staticmethod
107
+ def forward(x: BaseModelOutput, attention_mask: torch.Tensor):
108
+ masked_output = x.last_hidden_state * attention_mask.unsqueeze(-1)
109
+ return masked_output.sum(dim=1) / attention_mask.sum(-1, keepdim=True)
110
+
111
+
112
+ @register_pooler
113
+ class MaxPooler(nn.Module):
114
+ """
115
+ Max pooling
116
+ """
117
+
118
+ @staticmethod
119
+ def forward(x: BaseModelOutput, attention_mask: torch.Tensor):
120
+ masked_output = x.last_hidden_state.masked_fill(
121
+ (attention_mask == 0).unsqueeze(-1), -torch.inf  # mask out padding positions
122
+ )
123
+ return masked_output.max(1).values
124
+
125
+
126
+ @register_pooler
127
+ class ClsPooler(nn.Module):
128
+ """
129
+ CLS token pooling
130
+ """
131
+
132
+ def __init__(self, use_pooler_output=True):
133
+ super().__init__()
134
+ self.cls_token_position = 0
135
+ self.use_pooler_output = use_pooler_output
136
+
137
+ def forward(self, x: BaseModelOutput, _: torch.Tensor):
138
+ if (
139
+ self.use_pooler_output
140
+ and isinstance(
141
+ x,
142
+ (
143
+ BaseModelOutputWithPooling,
144
+ BaseModelOutputWithPoolingAndCrossAttentions,
145
+ ),
146
+ )
147
+ and (x.pooler_output is not None)
148
+ ):
149
+ return x.pooler_output
150
+
151
+ return x.last_hidden_state[:, self.cls_token_position, :]
152
+
153
+
154
+ """
155
+ HF text model
156
+ """
157
+
158
+
159
+ class HFTextEncoder(nn.Module):
160
+ output_tokens: torch.jit.Final[bool]
161
+
162
+ def __init__(
163
+ self,
164
+ model_name_or_path: str,
165
+ output_dim: int,
166
+ config: PretrainedConfig = None,
167
+ pooler_type: str = None,
168
+ proj_type: str = None,
169
+ proj_bias: bool = False,
170
+ pretrained: bool = True,
171
+ output_tokens: bool = False,
172
+ trust_remote_code: bool = False,
173
+ revision: Optional[str] = None,
174
+ model_config_kwargs: Optional[Dict] = None,
175
+ ):
176
+ super().__init__()
177
+ self.output_tokens = output_tokens
178
+ self.output_dim = output_dim
179
+
180
+ # TODO: find better way to get this information
181
+ uses_transformer_pooler = pooler_type == 'cls_pooler'
182
+ model_config_kwargs = model_config_kwargs or {}
183
+
184
+ if config is None:
185
+ self.config = AutoConfig.from_pretrained(
186
+ model_name_or_path,
187
+ trust_remote_code=trust_remote_code,
188
+ code_revision=revision,
189
+ )
190
+ self.config.update(model_config_kwargs)
191
+ create_func, model_args = (
192
+ (AutoModel.from_pretrained, model_name_or_path)
193
+ if pretrained
194
+ else (AutoModel.from_config, self.config)
195
+ )
196
+ # TODO: do all model configs have this attribute?
197
+ # PretrainedConfig does so yes??
198
+ if (
199
+ hasattr(self.config, 'is_encoder_decoder')
200
+ and self.config.is_encoder_decoder
201
+ ):
202
+ self.transformer = create_func(model_args)
203
+ self.transformer = self.transformer.encoder
204
+ else:
205
+ self.transformer = create_func(
206
+ model_args,
207
+ trust_remote_code=trust_remote_code,
208
+ add_pooling_layer=uses_transformer_pooler,
209
+ code_revision=revision,
210
+ )
211
+ else:
212
+ self.config = config
213
+ self.config.update(model_config_kwargs)
214
+ self.transformer = AutoModel.from_config(self.config)
215
+
216
+ if pooler_type is None: # get default arch pooler
217
+ pooler_type = _HF_ARCH_DICT[self.config.model_type]['pooler']
218
+
219
+ # FIXME downstream users of OpenCLIP models use these attr,
220
+ # need to verify valid across all models
221
+ self.vocab_size = getattr(self.config, 'vocab_size', 0)
222
+ self.context_length = getattr(self.config, 'max_position_embeddings', 0)
223
+
224
+ self.pooler = _POOLERS[pooler_type]()
225
+
226
+ d_model = getattr(
227
+ self.config, _HF_ARCH_DICT[self.config.model_type]['config_names']['width']
228
+ )
229
+ if (d_model == output_dim) and (proj_type is None): # do we always need a proj?
230
+ self.proj = nn.Identity()
231
+ elif proj_type == 'linear':
232
+ self.proj = nn.Linear(d_model, output_dim, bias=proj_bias)
233
+ elif proj_type == 'mlp':
234
+ hidden_size = (d_model + output_dim) // 2
235
+ self.proj = nn.Sequential(
236
+ nn.Linear(d_model, hidden_size, bias=proj_bias),
237
+ nn.GELU(),
238
+ nn.Linear(hidden_size, output_dim, bias=proj_bias),
239
+ )
240
+
241
+ def forward(self, x: torch.Tensor):
242
+ attn_mask = (x != self.config.pad_token_id).long()
243
+ out = self.transformer(input_ids=x, attention_mask=attn_mask)
244
+ pooled_out = self.pooler(out, attn_mask)
245
+ projected = self.proj(pooled_out)
246
+
247
+ seq_len = out.last_hidden_state.shape[1]
248
+ tokens = (
249
+ out.last_hidden_state[
250
+ :, torch.arange(seq_len) != self.pooler.cls_token_position, :
251
+ ]
252
+ if isinstance(self.pooler, ClsPooler)
253
+ else out.last_hidden_state
254
+ )
255
+
256
+ if self.output_tokens:
257
+ return projected, tokens
258
+ return projected
259
+
260
+ def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
261
+ if not unlocked_layers: # full freezing
262
+ for n, p in self.transformer.named_parameters():
263
+ p.requires_grad = (
264
+ (not freeze_layer_norm) if 'LayerNorm' in n.split('.') else False
265
+ )
266
+ return
267
+
268
+ encoder = (
269
+ self.transformer.encoder
270
+ if hasattr(self.transformer, 'encoder')
271
+ else self.transformer
272
+ )
273
+ layer_list = getattr(
274
+ encoder, _HF_ARCH_DICT[self.config.model_type]['config_names']['layer_attr']
275
+ )
276
+ print(f'Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model')
277
+ embeddings = getattr(
278
+ self.transformer,
279
+ _HF_ARCH_DICT[self.config.model_type]['config_names'][
280
+ 'token_embeddings_attr'
281
+ ],
282
+ )
283
+ modules = [embeddings, *layer_list][:-unlocked_layers]
284
+ # freeze layers
285
+ for module in modules:
286
+ for n, p in module.named_parameters():
287
+ p.requires_grad = (
288
+ (not freeze_layer_norm) if 'LayerNorm' in n.split('.') else False
289
+ )
290
+
291
+ @torch.jit.ignore
292
+ def set_grad_checkpointing(self, _=True):
293
+ self.transformer.gradient_checkpointing_enable()
294
+
295
+ def init_parameters(self):
296
+ pass
297
+
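
For readers unfamiliar with the pooler registry above: `register_pooler` stores each class under its snake_case name (`MeanPooler` becomes `mean_pooler`), which is the string `_HF_ARCH_DICT` selects per architecture. The sketch below re-implements the registry key and the mean-pooling arithmetic on dummy tensors so the masking step is explicit; it is illustrative only and does not import `hf_model.py`.

```python
# Illustrative re-implementation of the registry key and MeanPooler math above.
import re
import torch

def camel2snake(s):
    return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()

print(camel2snake('MeanPooler'))   # 'mean_pooler', the key stored in _POOLERS

hidden = torch.randn(2, 4, 8)                        # (batch, seq_len, hidden)
mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])    # 1 = token, 0 = padding
masked = hidden * mask.unsqueeze(-1)                 # zero out padded positions
mean_pooled = masked.sum(dim=1) / mask.sum(-1, keepdim=True)
print(mean_pooled.shape)                             # torch.Size([2, 8])
```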
modeling_clip.py ADDED
@@ -0,0 +1,570 @@
 
 
1
+ # coding=utf-8
2
+ #
3
+ # Code mainly copied from:
4
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/modeling_clip.py
5
+ # and adjusted for Jina CLIP
6
+
7
+ from functools import partial
8
+ from typing import List, Optional, Tuple, Union
9
+ from io import BytesIO
10
+ import requests
11
+ import base64
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn.functional as f
15
+ import torch.utils.checkpoint
16
+ from torch import nn
17
+ from transformers import (
18
+ AutoImageProcessor,
19
+ AutoTokenizer,
20
+ BatchEncoding,
21
+ BatchFeature,
22
+ PreTrainedModel,
23
+ logging,
24
+ )
25
+ from transformers.models.clip.modeling_clip import (
26
+ CLIPOutput,
27
+ CLIPTextModelOutput,
28
+ CLIPVisionModelOutput,
29
+ clip_loss,
30
+ )
31
+
32
+ try:
33
+ from tqdm.autonotebook import trange
34
+
35
+ has_tqdm = True
36
+ except ImportError:
37
+ has_tqdm = False
38
+
39
+ from .configuration_clip import JinaCLIPConfig, JinaCLIPTextConfig, JinaCLIPVisionConfig
40
+ from .eva_model import EVAVisionTransformer
41
+ from .hf_model import HFTextEncoder
42
+ # needed for HF to correctly import in cache
43
+ from .rope_embeddings import VisionRotaryEmbeddingFast # noqa: F401
44
+ from .transform import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD, image_transform # noqa: F401
45
+
46
+ logger = logging.get_logger(__name__)
47
+
48
+
49
+ """ Jina CLIP model implementation """
50
+
51
+
52
+ class LayerNorm(nn.LayerNorm):
53
+ """Subclass torch's LayerNorm (with cast back to input dtype)."""
54
+
55
+ def forward(self, x: torch.Tensor):
56
+ origtype = x.dtype
57
+ x = f.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
58
+ return x.to(origtype)
59
+
60
+
61
+ def _build_text_tower(config: JinaCLIPTextConfig) -> HFTextEncoder:
62
+ return HFTextEncoder(
63
+ model_name_or_path=config.hf_model_name_or_path,
64
+ output_dim=config.embed_dim,
65
+ pooler_type=config.pooler_type,
66
+ proj_type=config.proj_type,
67
+ proj_bias=config.proj_bias,
68
+ pretrained=False,
69
+ output_tokens=False,
70
+ trust_remote_code=True,
71
+ revision=None,
72
+ model_config_kwargs=config.hf_model_config_kwargs,
73
+ )
74
+
75
+
76
+ def _build_vision_tower(config: JinaCLIPVisionConfig) -> EVAVisionTransformer:
77
+ norm_layer = partial(LayerNorm, eps=1e-6)
78
+
79
+ if config.fused_layer_norm:
80
+ try:
81
+ from apex.normalization import FusedLayerNorm
82
+
83
+ norm_layer = partial(FusedLayerNorm, eps=1e-6)
84
+ except (ModuleNotFoundError, ImportError):
85
+ logger.warning('Please install apex to use fused layer norm, ignoring')
86
+
87
+ return EVAVisionTransformer(
88
+ img_size=config.image_size,
89
+ patch_size=config.patch_size,
90
+ num_classes=config.embed_dim,
91
+ use_mean_pooling=False,
92
+ init_values=config.ls_init_value,
93
+ patch_dropout=config.patch_dropout,
94
+ embed_dim=config.width,
95
+ depth=config.layers,
96
+ num_heads=config.width // config.head_width,
97
+ mlp_ratio=config.mlp_ratio,
98
+ qkv_bias=config.qkv_bias,
99
+ drop_path_rate=config.drop_path_rate,
100
+ norm_layer=norm_layer,
101
+ xattn=config.x_attention,
102
+ rope=config.rope_embeddings,
103
+ postnorm=config.post_norm,
104
+ pt_hw_seq_len=config.pt_hw_seq_len,
105
+ intp_freq=config.intp_freq,
106
+ naiveswiglu=config.naive_swiglu,
107
+ subln=config.subln,
108
+ proj_type=config.proj_type,
109
+ )
110
+
111
+
112
+ class JinaCLIPPreTrainedModel(PreTrainedModel):
113
+ """
114
+ An abstract class to handle weights initialization and a simple interface for
115
+ downloading and loading pretrained models.
116
+ """
117
+
118
+ config_class = JinaCLIPConfig
119
+ base_model_prefix = 'clip'
120
+ supports_gradient_checkpointing = True
121
+
122
+ def _init_weights(self, module):
123
+ """Initialize the weights"""
124
+ if isinstance(module, JinaCLIPModel):
125
+ if isinstance(module.text_projection, nn.Linear):
126
+ nn.init.normal_(
127
+ module.text_projection.weight,
128
+ std=module.text_embed_dim**-0.5 * self.config.initializer_factor,
129
+ )
130
+ if isinstance(module.visual_projection, nn.Linear):
131
+ nn.init.normal_(
132
+ module.visual_projection.weight,
133
+ std=module.vision_embed_dim**-0.5 * self.config.initializer_factor,
134
+ )
135
+ if isinstance(module, nn.LayerNorm):
136
+ module.bias.data.zero_()
137
+ module.weight.data.fill_(1.0)
138
+ if isinstance(module, nn.Linear) and module.bias is not None:
139
+ module.bias.data.zero_()
140
+
141
+
142
+ class JinaCLIPTextModel(JinaCLIPPreTrainedModel):
143
+ config_class = JinaCLIPTextConfig
144
+
145
+ def __init__(self, config: JinaCLIPTextConfig):
146
+ super().__init__(config)
147
+ self.text_model = _build_text_tower(config)
148
+ self.post_init()
149
+
150
+ def forward(
151
+ self,
152
+ input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
153
+ return_dict: Optional[bool] = None,
154
+ *_,
155
+ **__,
156
+ ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPTextModelOutput]:
157
+ return_dict = (
158
+ return_dict if return_dict is not None else self.config.use_return_dict
159
+ )
160
+ x = input_ids.input_ids if isinstance(input_ids, BatchEncoding) else input_ids
161
+ feats = self.text_model(x=x)
162
+ out = CLIPTextModelOutput(text_embeds=feats)
163
+ return out if return_dict else out.to_tuple()
164
+
165
+
166
+ class JinaCLIPVisionModel(JinaCLIPPreTrainedModel):
167
+ config_class = JinaCLIPVisionConfig
168
+ main_input_name = 'pixel_values'
169
+
170
+ def __init__(self, config: JinaCLIPVisionConfig):
171
+ super().__init__(config)
172
+ self.vision_model = _build_vision_tower(config)
173
+ self.post_init()
174
+
175
+ def forward(
176
+ self,
177
+ pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
178
+ return_dict: Optional[bool] = None,
179
+ *_,
180
+ **__,
181
+ ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPVisionModelOutput]:
182
+ return_dict = (
183
+ return_dict if return_dict is not None else self.config.use_return_dict
184
+ )
185
+ x = (
186
+ pixel_values.pixel_values
187
+ if isinstance(pixel_values, BatchFeature)
188
+ else pixel_values
189
+ )
190
+ feats = self.vision_model(x=x)
191
+ out = CLIPVisionModelOutput(image_embeds=feats)
192
+ return out if return_dict else out.to_tuple()
193
+
194
+
195
+ class JinaCLIPModel(JinaCLIPPreTrainedModel):
196
+ config_class = JinaCLIPConfig
197
+
198
+ def __init__(self, config: JinaCLIPConfig):
199
+ super().__init__(config)
200
+
201
+ if not isinstance(config.text_config, JinaCLIPTextConfig):
202
+ raise ValueError(
203
+ 'Attribute config.text_config is expected to be of type '
204
+ f'JinaCLIPTextConfig but is of type {type(config.text_config)}.'
205
+ )
206
+
207
+ if not isinstance(config.vision_config, JinaCLIPVisionConfig):
208
+ raise ValueError(
209
+ 'Attribute config.vision_config is expected to be of type '
210
+ f'JinaCLIPVisionConfig but is of type {type(config.vision_config)}.'
211
+ )
212
+
213
+ text_config = config.text_config
214
+ vision_config = config.vision_config
215
+
216
+ if config.use_text_flash_attn is not None:
217
+ text_config.hf_model_config_kwargs['use_flash_attn'] = config.use_text_flash_attn
218
+ if config.use_vision_xformers is not None:
219
+ vision_config.x_attention = config.use_vision_xformers
220
+
221
+ self.add_projections = config.add_projections
222
+ self.projection_dim = config.projection_dim
223
+ self.text_embed_dim = text_config.embed_dim
224
+ self.vision_embed_dim = vision_config.embed_dim
225
+
226
+ self.text_model = _build_text_tower(text_config)
227
+ self.vision_model = _build_vision_tower(vision_config)
228
+ self.logit_scale = nn.Parameter(
229
+ torch.tensor(self.config.logit_scale_init_value)
230
+ )
231
+
232
+ if self.add_projections:
233
+ self.visual_projection = nn.Linear(
234
+ self.vision_embed_dim, self.projection_dim, bias=False
235
+ )
236
+ self.text_projection = nn.Linear(
237
+ self.text_embed_dim, self.projection_dim, bias=False
238
+ )
239
+ else:
240
+ self.visual_projection = nn.Identity()
241
+ self.text_projection = nn.Identity()
242
+
243
+ self.tokenizer = None
244
+ self.preprocess = None
245
+ self.post_init()
246
+
247
+ def get_tokenizer(self):
248
+ if not self.tokenizer:
249
+ self.tokenizer = AutoTokenizer.from_pretrained(
250
+ self.config._name_or_path, trust_remote_code=True
251
+ )
252
+ return self.tokenizer
253
+
254
+ def get_preprocess(self):
255
+ if not self.preprocess:
256
+ self.preprocess = AutoImageProcessor.from_pretrained(
257
+ self.config._name_or_path, trust_remote_code=True
258
+ )
259
+ return self.preprocess
260
+
261
+ def get_text_features(
262
+ self,
263
+ input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
264
+ *_,
265
+ **__,
266
+ ) -> torch.FloatTensor:
267
+ x = input_ids.input_ids if isinstance(input_ids, BatchEncoding) else input_ids
268
+ return self.text_projection(self.text_model(x=x))
269
+
270
+ def get_image_features(
271
+ self,
272
+ pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
273
+ *_,
274
+ **__,
275
+ ) -> torch.FloatTensor:
276
+ x = (
277
+ pixel_values.pixel_values
278
+ if isinstance(pixel_values, BatchFeature)
279
+ else pixel_values
280
+ )
281
+ return self.visual_projection(self.vision_model(x=x))
282
+
283
+ @torch.inference_mode()
284
+ def encode_text(
285
+ self,
286
+ sentences: Union[str, List[str]],
287
+ batch_size: int = 32,
288
+ show_progress_bar: Optional[bool] = None,
289
+ convert_to_numpy: bool = True,
290
+ convert_to_tensor: bool = False,
291
+ device: Optional[torch.device] = None,
292
+ normalize_embeddings: bool = True,
293
+ **tokenizer_kwargs,
294
+ ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
295
+ """
296
+ Computes sentence embeddings
297
+ Args:
298
+ sentences(`str` or `List[str]`):
299
+ Sentence or sentences to be encoded
300
+ batch_size(`int`, *optional*, defaults to 32):
301
+ Batch size for the computation
302
+ show_progress_bar(`bool`, *optional*, defaults to None):
303
+ Show a progress bar when encoding sentences.
304
+ If set to None, progress bar is only shown when
305
+ `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
306
+ convert_to_numpy(`bool`, *optional*, defaults to True):
307
+ If true, the output is a list of numpy vectors.
308
+ Else, it is a list of pytorch tensors.
309
+ convert_to_tensor(`bool`, *optional*, defaults to False):
310
+ If true, you get one large tensor as return.
311
+ Overwrites any setting from convert_to_numpy
312
+ device(`torch.device`, *optional*, defaults to None):
313
+ Which torch.device to use for the computation
314
+ normalize_embeddings(`bool`, *optional*, defaults to True):
315
+ If set to true, returned vectors will have length 1. In that case,
316
+ the faster dot-product (util.dot_score) instead of cosine similarity
317
+ can be used.
318
+ tokenizer_kwargs(`Dict[str, Any]`, *optional*, defaults to {}):
319
+ Keyword arguments for the tokenizer
320
+ Returns:
321
+ By default, a list of tensors is returned.
322
+ If convert_to_tensor, a stacked tensor is returned.
323
+ If convert_to_numpy, a numpy matrix is returned.
324
+ """
325
+ is_training = self.training
326
+ self.eval()
327
+ all_embeddings = []
328
+
329
+ self.tokenizer = self.get_tokenizer()
330
+
331
+ if show_progress_bar is None:
332
+ show_progress_bar = (
333
+ logger.getEffectiveLevel() == logging.INFO
334
+ or logger.getEffectiveLevel() == logging.DEBUG
335
+ )
336
+
337
+ if convert_to_tensor:
338
+ convert_to_numpy = False
339
+
340
+ input_was_string = False
341
+ if isinstance(sentences, str) or not hasattr(sentences, '__len__'):
342
+ sentences = [sentences]
343
+ input_was_string = True
344
+
345
+ if device is not None:
346
+ self.to(device)
347
+
348
+ permutation = np.argsort([-len(i) for i in sentences])
349
+ inverse_permutation = np.argsort(permutation)
350
+ sentences = [sentences[idx] for idx in permutation]
351
+
352
+ tokenizer_kwargs['padding'] = tokenizer_kwargs.get('padding', True)
353
+ tokenizer_kwargs['max_length'] = tokenizer_kwargs.get('max_length', 512)
354
+ tokenizer_kwargs['truncation'] = tokenizer_kwargs.get('truncation', True)
355
+
356
+ if has_tqdm:
357
+ range_iter = trange(
358
+ 0,
359
+ len(sentences),
360
+ batch_size,
361
+ desc='Encoding',
362
+ disable=not show_progress_bar,
363
+ )
364
+ else:
365
+ range_iter = range(0, len(sentences), batch_size)
366
+
367
+ for i in range_iter:
368
+ encoded_input = self.tokenizer(
369
+ sentences[i : i + batch_size],
370
+ return_tensors='pt',
371
+ **tokenizer_kwargs,
372
+ ).to(self.device)
373
+
374
+ embeddings = self.get_text_features(input_ids=encoded_input)
375
+ if normalize_embeddings:
376
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
377
+ if convert_to_numpy:
378
+ embeddings = embeddings.cpu()
379
+ all_embeddings.extend(embeddings)
380
+
381
+ all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
382
+
383
+ if convert_to_tensor:
384
+ all_embeddings = torch.stack(all_embeddings)
385
+ elif convert_to_numpy:
386
+ all_embeddings = np.asarray([emb.to(torch.float32).numpy() for emb in all_embeddings])
387
+
388
+ if input_was_string:
389
+ all_embeddings = all_embeddings[0]
390
+
391
+ self.train(is_training)
392
+ return all_embeddings
393
+
394
+ @staticmethod
+ def decode_data_image(data_image_str):
395
+ from PIL import Image  # imported lazily, mirroring encode_image below
+ header, data = data_image_str.split(',', 1)
396
+ image_data = base64.b64decode(data)
397
+ return Image.open(BytesIO(image_data))
398
+
399
+ @torch.inference_mode()
400
+ def encode_image(
401
+ self,
402
+ images: Union[str, List[Union[str, "Image.Image"]]],
403
+ batch_size: int = 32,
404
+ show_progress_bar: Optional[bool] = None,
405
+ convert_to_numpy: bool = True,
406
+ convert_to_tensor: bool = False,
407
+ device: Optional[torch.device] = None,
408
+ normalize_embeddings: bool = True,
409
+ ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
410
+ """
411
+ Computes image embeddings.
412
+
413
+ Args:
414
+ images(`str` or `List[Union[str, Image.Image]]`):
415
+ image paths, URLs, PIL images, or data:image/ strings to be encoded
416
+ batch_size(`int`, *optional*, defaults to 32):
417
+ Batch size for the computation
418
+ show_progress_bar(`bool`, *optional*, defaults to None):
419
+ Show a progress bar when encoding images.
420
+ If set to None, progress bar is only shown when
421
+ `logger.level == logging.INFO` or `logger.level == logging.DEBUG`.
422
+ convert_to_numpy(`bool`, *optional*, defaults to True):
423
+ If true, the output is a list of numpy vectors.
424
+ Else, it is a list of pytorch tensors.
425
+ convert_to_tensor(`bool`, *optional*, defaults to False):
426
+ If true, you get one large tensor as return.
427
+ Overwrites any setting from convert_to_numpy
428
+ device(`torch.device`, *optional*, defaults to None):
429
+ Which torch.device to use for the computation
430
+ normalize_embeddings(`bool`, *optional*, defaults to True):
431
+ If set to true, returned vectors will have length 1. In that case,
432
+ the faster dot-product (util.dot_score) instead of cosine similarity
433
+ can be used.
434
+ Returns:
435
+ By default, a list of tensors is returned.
436
+ If convert_to_tensor, a stacked tensor is returned.
437
+ If convert_to_numpy, a numpy matrix is returned.
438
+ """
439
+
440
+ is_training = self.training
441
+ self.eval()
442
+
443
+ self.preprocess = self.get_preprocess()
444
+ all_embeddings = []
445
+
446
+ if show_progress_bar is None:
447
+ show_progress_bar = (
448
+ logger.getEffectiveLevel() == logging.INFO
449
+ or logger.getEffectiveLevel() == logging.DEBUG
450
+ )
451
+
452
+ if convert_to_tensor:
453
+ convert_to_numpy = False
454
+
455
+ input_was_single_img = False
456
+ if isinstance(images, str) or not hasattr(images, '__len__'):
457
+ images = [images]
458
+ input_was_single_img = True
459
+
460
+ if device is not None:
461
+ self.to(device)
462
+
463
+ permutation = np.argsort([-len(str(i)) for i in images])
464
+ inverse_permutation = np.argsort(permutation)
465
+ images = [images[idx] for idx in permutation]
466
+
467
+ if has_tqdm:
468
+ range_iter = trange(
469
+ 0,
470
+ len(images),
471
+ batch_size,
472
+ desc='Encoding',
473
+ disable=not show_progress_bar,
474
+ )
475
+ else:
476
+ range_iter = range(0, len(images), batch_size)
477
+
478
+ from PIL import Image
479
+
480
+ for i in range_iter:
481
+ batch_images = images[i:i+batch_size]
482
+ processed_inputs = []
483
+
484
+ for img in batch_images:
485
+ if isinstance(img, str):
486
+ if img.startswith('http'):
487
+ response = requests.get(img)
488
+ image = Image.open(BytesIO(response.content)).convert('RGB')
489
+ elif img.startswith('data:image/'):
490
+ image = self.decode_data_image(img).convert('RGB')
491
+ else:
492
+ image = Image.open(img).convert('RGB')
493
+ elif isinstance(img, Image.Image):
494
+ image = img.convert('RGB')
495
+ else:
496
+ raise ValueError("Unsupported image format")
497
+
498
+ processed_inputs.append(image)
499
+
500
+ processed_inputs = self.preprocess(processed_inputs)
501
+ processed_inputs = processed_inputs.to(self.device)
502
+ embeddings = self.get_image_features(processed_inputs)
503
+
504
+ if normalize_embeddings:
505
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
506
+ if convert_to_numpy:
507
+ embeddings = embeddings.cpu()
508
+ all_embeddings.extend(embeddings)
509
+
510
+ all_embeddings = [all_embeddings[idx] for idx in inverse_permutation]
511
+
512
+ if convert_to_tensor:
513
+ all_embeddings = torch.stack(all_embeddings)
514
+ elif convert_to_numpy:
515
+ all_embeddings = np.asarray([emb.to(torch.float32).numpy() for emb in all_embeddings])
516
+
517
+ if input_was_single_img:
518
+ all_embeddings = all_embeddings[0]
519
+
520
+ self.train(is_training)
521
+ return all_embeddings
522
+
523
+ def forward(
524
+ self,
525
+ input_ids: Union[None, torch.Tensor, BatchEncoding] = None,
526
+ pixel_values: Union[None, torch.FloatTensor, BatchFeature] = None,
527
+ return_dict: Optional[bool] = None,
528
+ return_loss: Optional[bool] = None,
529
+ *_,
530
+ **__,
531
+ ) -> Union[Tuple[Optional[torch.FloatTensor], ...], CLIPOutput]:
532
+ return_dict = (
533
+ return_dict if return_dict is not None else self.config.use_return_dict
534
+ )
535
+ image_embeds = self.get_image_features(pixel_values=pixel_values)
536
+ text_embeds = self.get_text_features(input_ids=input_ids)
537
+
538
+ # normalized features
539
+ image_embeds = image_embeds / image_embeds.norm(p=2, dim=-1, keepdim=True)
540
+ text_embeds = text_embeds / text_embeds.norm(p=2, dim=-1, keepdim=True)
541
+
542
+ # cosine similarity as logits
543
+ logit_scale = self.logit_scale.exp()
544
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
545
+ logits_per_image = logits_per_text.t()
546
+
547
+ loss = None
548
+ if return_loss:
549
+ loss = clip_loss(logits_per_text)
550
+
551
+ if not return_dict:
552
+ output = (
553
+ logits_per_image,
554
+ logits_per_text,
555
+ text_embeds,
556
+ image_embeds,
557
+ None,
558
+ None,
559
+ )
560
+ return ((loss,) + output) if loss is not None else output
561
+
562
+ return CLIPOutput(
563
+ loss=loss,
564
+ logits_per_image=logits_per_image,
565
+ logits_per_text=logits_per_text,
566
+ text_embeds=text_embeds,
567
+ image_embeds=image_embeds,
568
+ text_model_output=None,
569
+ vision_model_output=None,
570
+ )
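
`JinaCLIPModel.forward` L2-normalises both towers and scales the text-image similarity matrix by `logit_scale.exp()`. Below is a hedged sketch of the same computation through the lower-level `get_text_features` / `get_image_features` methods; it assumes `model`, `tokenizer` and `processor` objects have already been created (for example via `AutoModel`, `AutoTokenizer` and `AutoImageProcessor` with `trust_remote_code=True`) and that `pil_image` is a `PIL.Image` you supply.

```python
# Sketch only: mirrors the normalisation + logit-scale step of forward().
import torch

text_inputs = tokenizer(['A blue cat', 'A red cat'], return_tensors='pt', padding=True)
image_inputs = processor([pil_image])            # BatchFeature with pixel_values

with torch.inference_mode():
    text_embeds = model.get_text_features(input_ids=text_inputs)
    image_embeds = model.get_image_features(pixel_values=image_inputs)
    text_embeds = torch.nn.functional.normalize(text_embeds, p=2, dim=-1)
    image_embeds = torch.nn.functional.normalize(image_embeds, p=2, dim=-1)
    logits_per_text = model.logit_scale.exp() * text_embeds @ image_embeds.t()

print(logits_per_text.shape)                     # (num_texts, num_images)
```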
modules.json ADDED
@@ -0,0 +1,8 @@
 
 
1
+ [
2
+ {
3
+ "idx": 0,
4
+ "name": "0",
5
+ "path": "",
6
+ "type": "jina-clip-implementation-st.6d5aa7d8b428eaba8d7908d86f43f5dd5ad6ad93.custom_st.Transformer"
7
+ }
8
+ ]
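
`modules.json` is what makes the repository loadable as a Sentence Transformers pipeline: it points module 0 at the custom `Transformer` class shipped with the implementation repo. A hedged loading sketch, assuming a `sentence-transformers` version that supports `trust_remote_code`:

```python
# Sketch: Sentence Transformers reads modules.json to assemble the pipeline.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer('jinaai/jina-clip-v1', trust_remote_code=True)
embeddings = model.encode(['A blue cat', 'A red cat'])
print(embeddings.shape)   # (2, embedding_dim)
```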
preprocessor_config.json ADDED
@@ -0,0 +1,22 @@
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoImageProcessor": "processing_clip.JinaCLIPImageProcessor",
4
+ "AutoProcessor": "jinaai/jina-clip-implementation--processing_clip.JinaCLIPProcessor"
5
+ },
6
+ "fill_color": 0,
7
+ "image_processor_type": "JinaCLIPImageProcessor",
8
+ "interpolation": "bicubic",
9
+ "mean": [
10
+ 0.48145466,
11
+ 0.4578275,
12
+ 0.40821073
13
+ ],
14
+ "processor_class": "JinaCLIPProcessor",
15
+ "resize_mode": "shortest",
16
+ "size": 224,
17
+ "std": [
18
+ 0.26862954,
19
+ 0.26130258,
20
+ 0.27577711
21
+ ]
22
+ }
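
The config above wires `AutoImageProcessor` to `JinaCLIPImageProcessor` with a 224 px shortest-edge resize, bicubic interpolation and the OpenAI CLIP normalisation statistics. A minimal loading sketch follows; the output shape shown is the expected one under those settings, not something verified here.

```python
# Sketch: instantiating the image processor declared by this config.
from PIL import Image
from transformers import AutoImageProcessor

processor = AutoImageProcessor.from_pretrained('jinaai/jina-clip-v1', trust_remote_code=True)
batch = processor([Image.new('RGB', (640, 480))])    # any PIL image
print(batch['pixel_values'].shape)                   # expected: [1, 3, 224, 224]
```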
processing_clip.py ADDED
@@ -0,0 +1,88 @@
 
 
1
+ # coding=utf-8
2
+ #
3
+ # Code mainly copied from:
4
+ # https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/image_processing_clip.py
5
+ # and adjusted for Jina CLIP
6
+
7
+ from typing import Tuple, Union
8
+
9
+ import torch
10
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
11
+ from transformers.image_utils import ImageInput, make_list_of_images
12
+ from transformers.models.clip import CLIPProcessor
13
+
14
+ from .transform import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD, image_transform
15
+
16
+ """ Jina CLIP processor implementation """
17
+
18
+
19
+ class JinaCLIPProcessor(CLIPProcessor):
20
+ image_processor_class = 'AutoImageProcessor'
21
+ tokenizer_class = 'AutoTokenizer'
22
+
23
+
24
+ """ Jina CLIP image processor implementation """
25
+
26
+
27
+ class JinaCLIPImageProcessor(BaseImageProcessor):
28
+ model_input_names = ['pixel_values']
29
+ _valid_processor_keys = [
30
+ 'size',
31
+ 'mean',
32
+ 'std',
33
+ 'resize_mode',
34
+ 'interpolation',
35
+ 'fill_color',
36
+ ]
37
+
38
+ def __init__(
39
+ self,
40
+ size: Union[int, Tuple[int, int]] = 224,
41
+ mean: Union[float, Tuple[float]] = OPENAI_DATASET_MEAN,
42
+ std: Union[float, Tuple[float]] = OPENAI_DATASET_STD,
43
+ resize_mode: str = 'shortest',
44
+ interpolation: str = 'bicubic',
45
+ fill_color: int = 0,
46
+ **kwargs,
47
+ ) -> None:
48
+ super().__init__(**kwargs)
49
+ self.size = size
50
+ self.mean = mean
51
+ self.std = std
52
+ self.resize_mode = resize_mode
53
+ self.interpolation = interpolation
54
+ self.fill_color = fill_color
55
+ self.transform = self._build_transform()
56
+
57
+ def _build_transform(self):
58
+ return image_transform(
59
+ image_size=self.size,
60
+ is_train=False,
61
+ mean=self.mean,
62
+ std=self.std,
63
+ resize_mode=self.resize_mode,
64
+ interpolation=self.interpolation,
65
+ fill_color=self.fill_color,
66
+ aug_cfg=None,
67
+ )
68
+
69
+ def to_dict(self):
70
+ output = super().to_dict()
71
+ output.pop('transform')
72
+ return output
73
+
74
+ def preprocess(self, images: ImageInput, **kwargs) -> BatchFeature:
75
+
76
+ _transform_needs_rebuild = False
77
+ for k, v in kwargs.items():
78
+ if k in self._valid_processor_keys:
79
+ if v != getattr(self, k):
80
+ setattr(self, k, v)
81
+ _transform_needs_rebuild = True
82
+
83
+ if _transform_needs_rebuild:
84
+ self.transform = self._build_transform()
85
+
86
+ images = make_list_of_images(images)
87
+ out = torch.stack([self.transform(image) for image in images], dim=0)
88
+ return BatchFeature(data={'pixel_values': out})
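
Note how `JinaCLIPImageProcessor.preprocess` treats keyword arguments: any key in `_valid_processor_keys` that differs from the stored value is written back onto the processor and the torchvision transform is rebuilt before the batch is encoded. A hedged sketch, reusing the `processor` instance from the snippet above:

```python
# Sketch of the kwarg-override path in preprocess(); `processor` as above.
from PIL import Image

img = Image.new('RGB', (500, 300))
default_batch = processor(img)            # uses size=224 from the config
larger_batch = processor(img, size=336)   # mutates self.size and rebuilds the transform
print(default_batch['pixel_values'].shape, larger_batch['pixel_values'].shape)
```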
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5af329d790c12cf109dabb4e31bf20e24dc07f8aab26509fb39004998cd9674e
3
+ size 890826430
rope_embeddings.py ADDED
@@ -0,0 +1,165 @@
 
 
1
+ # --------------------------------------------------------
2
+ # Adapted from EVA CLIP
3
+ # https://github.com/baaivision/EVA/tree/master/EVA-CLIP/rei/eva_clip
4
+ # --------------------------------------------------------
5
+
6
+ import logging
7
+ from math import pi
8
+
9
+ import torch
10
+ from einops import rearrange, repeat
11
+ from torch import nn
12
+
13
+
14
+ def broadcast(tensors, dim=-1):
15
+ num_tensors = len(tensors)
16
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
17
+ assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
18
+ shape_len = list(shape_lens)[0]
19
+ dim = (dim + shape_len) if dim < 0 else dim
20
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
21
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
22
+ assert all(
23
+ [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
24
+ ), 'invalid dimensions for broadcastable concatenation'
25
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
26
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
27
+ expanded_dims.insert(dim, (dim, dims[dim]))
28
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
29
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
30
+ return torch.cat(tensors, dim=dim)
31
+
32
+
33
+ def rotate_half(x):
34
+ x = rearrange(x, '... (d r) -> ... d r', r=2)
35
+ x1, x2 = x.unbind(dim=-1)
36
+ x = torch.stack((-x2, x1), dim=-1)
37
+ return rearrange(x, '... d r -> ... (d r)')
38
+
39
+
40
+ class VisionRotaryEmbedding(nn.Module):
41
+ def __init__(
42
+ self,
43
+ dim,
44
+ pt_seq_len,
45
+ ft_seq_len=None,
46
+ custom_freqs=None,
47
+ freqs_for='lang',
48
+ theta=10000,
49
+ max_freq=10,
50
+ num_freqs=1,
51
+ ):
52
+ super().__init__()
53
+ if custom_freqs:
54
+ freqs = custom_freqs
55
+ elif freqs_for == 'lang':
56
+ freqs = 1.0 / (
57
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
58
+ )
59
+ elif freqs_for == 'pixel':
60
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
61
+ elif freqs_for == 'constant':
62
+ freqs = torch.ones(num_freqs).float()
63
+ else:
64
+ raise ValueError(f'unknown modality {freqs_for}')
65
+
66
+ if ft_seq_len is None:
67
+ ft_seq_len = pt_seq_len
68
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
69
+
70
+ freqs_h = torch.einsum('..., f -> ... f', t, freqs)
71
+ freqs_h = repeat(freqs_h, '... n -> ... (n r)', r=2)
72
+
73
+ freqs_w = torch.einsum('..., f -> ... f', t, freqs)
74
+ freqs_w = repeat(freqs_w, '... n -> ... (n r)', r=2)
75
+
76
+ freqs = broadcast((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
77
+
78
+ self.register_buffer('freqs_cos', freqs.cos())
79
+ self.register_buffer('freqs_sin', freqs.sin())
80
+
81
+ logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
82
+
83
+ def forward(self, t, start_index=0):
84
+ rot_dim = self.freqs_cos.shape[-1]
85
+ end_index = start_index + rot_dim
86
+ assert rot_dim <= t.shape[-1], (
87
+ f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in '
88
+ f'all the positions {rot_dim}'
89
+ )
90
+ t_left, t, t_right = (
91
+ t[..., :start_index],
92
+ t[..., start_index:end_index],
93
+ t[..., end_index:],
94
+ )
95
+ t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
96
+
97
+ return torch.cat((t_left, t, t_right), dim=-1)
98
+
99
+
100
+ class VisionRotaryEmbeddingFast(nn.Module):
101
+ def __init__(
102
+ self,
103
+ dim,
104
+ pt_seq_len,
105
+ ft_seq_len=None,
106
+ custom_freqs=None,
107
+ freqs_for='lang',
108
+ theta=10000,
109
+ max_freq=10,
110
+ num_freqs=1,
111
+ patch_dropout=0.0,
112
+ ):
113
+ super().__init__()
114
+ if custom_freqs:
115
+ freqs = custom_freqs
116
+ elif freqs_for == 'lang':
117
+ freqs = 1.0 / (
118
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
119
+ )
120
+ elif freqs_for == 'pixel':
121
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
122
+ elif freqs_for == 'constant':
123
+ freqs = torch.ones(num_freqs).float()
124
+ else:
125
+ raise ValueError(f'unknown modality {freqs_for}')
126
+
127
+ if ft_seq_len is None:
128
+ ft_seq_len = pt_seq_len
129
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
130
+
131
+ freqs = torch.einsum('..., f -> ... f', t, freqs)
132
+ freqs = repeat(freqs, '... n -> ... (n r)', r=2)
133
+ freqs = broadcast((freqs[:, None, :], freqs[None, :, :]), dim=-1)
134
+
135
+ freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
136
+ freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
137
+
138
+ self.patch_dropout = patch_dropout
139
+
140
+ self.register_buffer('freqs_cos', freqs_cos)
141
+ self.register_buffer('freqs_sin', freqs_sin)
142
+
143
+ logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
144
+
145
+ def forward(self, t, patch_indices_keep=None):
146
+ if patch_indices_keep is not None:
147
+ batch = t.size()[0]
148
+ batch_indices = torch.arange(batch)
149
+ batch_indices = batch_indices[..., None]
150
+
151
+ freqs_cos = repeat(
152
+ self.freqs_cos, 'i j -> n i m j', n=t.shape[0], m=t.shape[1]
153
+ )
154
+ freqs_sin = repeat(
155
+ self.freqs_sin, 'i j -> n i m j', n=t.shape[0], m=t.shape[1]
156
+ )
157
+
158
+ freqs_cos = freqs_cos[batch_indices, patch_indices_keep]
159
+ freqs_cos = rearrange(freqs_cos, 'n i m j -> n m i j')
160
+ freqs_sin = freqs_sin[batch_indices, patch_indices_keep]
161
+ freqs_sin = rearrange(freqs_sin, 'n i m j -> n m i j')
162
+
163
+ return t * freqs_cos + rotate_half(t) * freqs_sin
164
+
165
+ return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
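
Both rotary classes apply the same update, `t * cos(theta) + rotate_half(t) * sin(theta)`, where `rotate_half` maps each feature pair `(x1, x2)` to `(-x2, x1)`. The standalone sketch below re-implements that step on dummy tensors; the shapes are illustrative, not the model's actual head dimensions.

```python
# Illustrative re-implementation of the rotary update used above.
import torch
from einops import rearrange

def rotate_half(x):
    x = rearrange(x, '... (d r) -> ... d r', r=2)
    x1, x2 = x.unbind(dim=-1)
    return rearrange(torch.stack((-x2, x1), dim=-1), '... d r -> ... (d r)')

t = torch.randn(1, 4, 8)        # (batch, patches, rotary_dim) -- dummy sizes
theta = torch.rand(4, 8)        # per-position, per-dimension angles
rotated = t * theta.cos() + rotate_half(t) * theta.sin()
print(rotated.shape)            # torch.Size([1, 4, 8])
```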
special_tokens_map.json ADDED
@@ -0,0 +1,37 @@
 
 
1
+ {
2
+ "cls_token": {
3
+ "content": "[CLS]",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "mask_token": {
10
+ "content": "[MASK]",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "[PAD]",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "sep_token": {
24
+ "content": "[SEP]",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "unk_token": {
31
+ "content": "[UNK]",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ }
37
+ }
tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_config.json ADDED
@@ -0,0 +1,64 @@
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "clean_up_tokenization_spaces": true,
45
+ "cls_token": "[CLS]",
46
+ "do_basic_tokenize": true,
47
+ "do_lower_case": true,
48
+ "mask_token": "[MASK]",
49
+ "max_length": 8192,
50
+ "model_max_length": 8192,
51
+ "never_split": null,
52
+ "pad_to_multiple_of": null,
53
+ "pad_token": "[PAD]",
54
+ "pad_token_type_id": 0,
55
+ "padding_side": "right",
56
+ "sep_token": "[SEP]",
57
+ "stride": 0,
58
+ "strip_accents": null,
59
+ "tokenize_chinese_chars": true,
60
+ "tokenizer_class": "BertTokenizer",
61
+ "truncation_side": "right",
62
+ "truncation_strategy": "longest_first",
63
+ "unk_token": "[UNK]"
64
+ }
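
The tokenizer declared above is a plain `BertTokenizer`, but `model_max_length` is set to 8192 to match the long-context text tower, so long inputs are truncated rather than rejected. A small sketch (no remote code is needed for the tokenizer itself):

```python
# Sketch: the 8192-token limit declared in tokenizer_config.json in action.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('jinaai/jina-clip-v1')
print(tokenizer.model_max_length)                     # 8192
enc = tokenizer('a very long document ' * 5000, truncation=True, return_tensors='pt')
print(enc['input_ids'].shape[1])                      # capped at 8192
```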
transform.py ADDED
@@ -0,0 +1,458 @@
 
 
1
+ import numbers
2
+ import random
3
+ import warnings
4
+ from dataclasses import asdict, dataclass
5
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
6
+
7
+ import torch
8
+ import torchvision.transforms.functional as F
9
+ from torchvision.transforms import (
10
+ CenterCrop,
11
+ ColorJitter,
12
+ Compose,
13
+ Grayscale,
14
+ InterpolationMode,
15
+ Normalize,
16
+ RandomResizedCrop,
17
+ Resize,
18
+ ToTensor,
19
+ )
20
+ from transformers.image_utils import OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
21
+
22
+ OPENAI_DATASET_MEAN = tuple(OPENAI_CLIP_MEAN)
23
+ OPENAI_DATASET_STD = tuple(OPENAI_CLIP_STD)
24
+
25
+
26
+ @dataclass
27
+ class PreprocessCfg:
28
+ size: Union[int, Tuple[int, int]] = 224
29
+ mode: str = 'RGB'
30
+ mean: Tuple[float, ...] = OPENAI_DATASET_MEAN
31
+ std: Tuple[float, ...] = OPENAI_DATASET_STD
32
+ interpolation: str = 'bicubic'
33
+ resize_mode: str = 'shortest'
34
+ fill_color: int = 0
35
+
36
+ def __post_init__(self):
37
+ assert self.mode in ('RGB',)
38
+
39
+ @property
40
+ def num_channels(self):
41
+ return 3
42
+
43
+ @property
44
+ def input_size(self):
45
+ return (self.num_channels,) + (self.size, self.size)
46
+
47
+
48
+ _PREPROCESS_KEYS = set(asdict(PreprocessCfg()).keys())
49
+
50
+
51
+ def merge_preprocess_dict(
52
+ base: Union[PreprocessCfg, Dict],
53
+ overlay: Dict,
54
+ ):
55
+ """Merge overlay key-value pairs on top of base preprocess cfg or dict.
56
+ Input dicts are filtered based on PreprocessCfg fields.
57
+ """
58
+ if isinstance(base, PreprocessCfg):
59
+ base_clean = asdict(base)
60
+ else:
61
+ base_clean = {k: v for k, v in base.items() if k in _PREPROCESS_KEYS}
62
+ if overlay:
63
+ overlay_clean = {
64
+ k: v for k, v in overlay.items() if k in _PREPROCESS_KEYS and v is not None
65
+ }
66
+ base_clean.update(overlay_clean)
67
+ return base_clean
68
+
69
+
70
+ def merge_preprocess_kwargs(base: Union[PreprocessCfg, Dict], **kwargs):
71
+ return merge_preprocess_dict(base, kwargs)
72
+
73
+
74
+ @dataclass
75
+ class AugmentationCfg:
76
+ scale: Tuple[float, float] = (0.9, 1.0)
77
+ ratio: Optional[Tuple[float, float]] = None
78
+ color_jitter: Optional[
79
+ Union[float, Tuple[float, float, float], Tuple[float, float, float, float]]
80
+ ] = None
81
+ re_prob: Optional[float] = None
82
+ re_count: Optional[int] = None
83
+ use_timm: bool = False
84
+
85
+ # params for simclr_jitter_gray
86
+ color_jitter_prob: float = None
87
+ gray_scale_prob: float = None
88
+
89
+
90
+ def _setup_size(size, error_msg):
91
+ if isinstance(size, numbers.Number):
92
+ return int(size), int(size)
93
+
94
+ if isinstance(size, Sequence) and len(size) == 1:
95
+ return size[0], size[0]
96
+
97
+ if len(size) != 2:
98
+ raise ValueError(error_msg)
99
+
100
+ return size
101
+
102
+
103
+ class ResizeKeepRatio:
104
+ """Resize and Keep Ratio
105
+
106
+ Copy & paste from `timm`
107
+ """
108
+
109
+ def __init__(
110
+ self,
111
+ size,
112
+ longest=0.0,
113
+ interpolation=InterpolationMode.BICUBIC,
114
+ random_scale_prob=0.0,
115
+ random_scale_range=(0.85, 1.05),
116
+ random_aspect_prob=0.0,
117
+ random_aspect_range=(0.9, 1.11),
118
+ ):
119
+ if isinstance(size, (list, tuple)):
120
+ self.size = tuple(size)
121
+ else:
122
+ self.size = (size, size)
123
+ self.interpolation = interpolation
124
+ self.longest = float(longest) # [0, 1] where 0 == shortest edge, 1 == longest
125
+ self.random_scale_prob = random_scale_prob
126
+ self.random_scale_range = random_scale_range
127
+ self.random_aspect_prob = random_aspect_prob
128
+ self.random_aspect_range = random_aspect_range
129
+
130
+ @staticmethod
131
+ def get_params(
132
+ img,
133
+ target_size,
134
+ longest,
135
+ random_scale_prob=0.0,
136
+ random_scale_range=(0.85, 1.05),
137
+ random_aspect_prob=0.0,
138
+ random_aspect_range=(0.9, 1.11),
139
+ ):
140
+ """Get parameters"""
141
+ source_size = img.size[::-1] # h, w
142
+ h, w = source_size
143
+ target_h, target_w = target_size
144
+ ratio_h = h / target_h
145
+ ratio_w = w / target_w
146
+ ratio = max(ratio_h, ratio_w) * longest + min(ratio_h, ratio_w) * (
147
+ 1.0 - longest
148
+ )
149
+ if random_scale_prob > 0 and random.random() < random_scale_prob:
150
+ ratio_factor = random.uniform(random_scale_range[0], random_scale_range[1])
151
+ ratio_factor = (ratio_factor, ratio_factor)
152
+ else:
153
+ ratio_factor = (1.0, 1.0)
154
+ if random_aspect_prob > 0 and random.random() < random_aspect_prob:
155
+ aspect_factor = random.uniform(
156
+ random_aspect_range[0], random_aspect_range[1]
157
+ )
158
+ ratio_factor = (
159
+ ratio_factor[0] / aspect_factor,
160
+ ratio_factor[1] * aspect_factor,
161
+ )
162
+ size = [round(x * f / ratio) for x, f in zip(source_size, ratio_factor)]
163
+ return size
164
+
165
+ def __call__(self, img):
166
+ """
167
+ Args:
168
+ img (PIL Image): Image to be cropped and resized.
169
+
170
+ Returns:
171
+ PIL Image: Resized, padded to at least target size, possibly
172
+ cropped to exactly target size
173
+ """
174
+ size = self.get_params(
175
+ img,
176
+ self.size,
177
+ self.longest,
178
+ self.random_scale_prob,
179
+ self.random_scale_range,
180
+ self.random_aspect_prob,
181
+ self.random_aspect_range,
182
+ )
183
+ img = F.resize(img, size, self.interpolation)
184
+ return img
185
+
186
+ def __repr__(self):
187
+ format_string = self.__class__.__name__ + '(size={0}'.format(self.size)
188
+ format_string += f', interpolation={self.interpolation})'
189
+ format_string += f', longest={self.longest:.3f})'
190
+ return format_string
191
+
192
+
193
+ def center_crop_or_pad(
194
+ img: torch.Tensor, output_size: List[int], fill=0
195
+ ) -> torch.Tensor:
196
+ """Center crops and/or pads the given image.
197
+ If the image is torch Tensor, it is expected
198
+ to have [..., H, W] shape, where ... means an arbitrary number of leading
199
+ dimensions. If image size is smaller than output size along any edge, image is
200
+ padded with 0 and then center cropped.
201
+
202
+ Args:
203
+ img (PIL Image or Tensor): Image to be cropped.
204
+ output_size (sequence or int): (height, width) of the crop box. If int or
205
+ sequence with single int, it is used for both directions.
206
+ fill (int, Tuple[int]): Padding color
207
+
208
+ Returns:
209
+ PIL Image or Tensor: Cropped image.
210
+ """
211
+ if isinstance(output_size, numbers.Number):
212
+ output_size = (int(output_size), int(output_size))
213
+ elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
214
+ output_size = (output_size[0], output_size[0])
215
+
216
+ _, image_height, image_width = F.get_dimensions(img)
217
+ crop_height, crop_width = output_size
218
+
219
+ if crop_width > image_width or crop_height > image_height:
220
+ padding_ltrb = [
221
+ (crop_width - image_width) // 2 if crop_width > image_width else 0,
222
+ (crop_height - image_height) // 2 if crop_height > image_height else 0,
223
+ (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
224
+ (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
225
+ ]
226
+ img = F.pad(img, padding_ltrb, fill=fill)
227
+ _, image_height, image_width = F.get_dimensions(img)
228
+ if crop_width == image_width and crop_height == image_height:
229
+ return img
230
+
231
+ crop_top = int(round((image_height - crop_height) / 2.0))
232
+ crop_left = int(round((image_width - crop_width) / 2.0))
233
+ return F.crop(img, crop_top, crop_left, crop_height, crop_width)
234
+
235
+
236
+ class CenterCropOrPad(torch.nn.Module):
237
+ """Crops the given image at the center.
238
+ If the image is torch Tensor, it is expected
239
+ to have [..., H, W] shape, where ... means an arbitrary number of leading
240
+ dimensions. If image size is smaller than output size along any edge, image is
241
+ padded with 0 and then center cropped.
242
+
243
+ Args:
244
+ size (sequence or int): Desired output size of the crop. If size is an
245
+ int instead of sequence like (h, w), a square crop (size, size) is
246
+ made. If provided a sequence of length 1, it will be interpreted as
247
+ (size[0], size[0]).
248
+ """
249
+
250
+ def __init__(self, size, fill=0):
251
+ super().__init__()
252
+ self.size = _setup_size(
253
+ size, error_msg='Please provide only two dimensions (h, w) for size.'
254
+ )
255
+ self.fill = fill
256
+
257
+ def forward(self, img):
258
+ """
259
+ Args:
260
+ img (PIL Image or Tensor): Image to be cropped.
261
+
262
+ Returns:
263
+ PIL Image or Tensor: Cropped image.
264
+ """
265
+ return center_crop_or_pad(img, self.size, fill=self.fill)
266
+
267
+ def __repr__(self) -> str:
268
+ return f'{self.__class__.__name__}(size={self.size})'
269
+
270
+
271
+ def _convert_to_rgb(image):
272
+ return image.convert('RGB')
273
+
274
+
275
+ class _ColorJitter(object):
276
+ """
277
+ Apply Color Jitter to the PIL image with a specified probability.
278
+ """
279
+
280
+ def __init__(self, brightness=0.0, contrast=0.0, saturation=0.0, hue=0.0, p=0.8):
281
+ assert 0.0 <= p <= 1.0
282
+ self.p = p
283
+ self.transf = ColorJitter(
284
+ brightness=brightness, contrast=contrast, saturation=saturation, hue=hue
285
+ )
286
+
287
+ def __call__(self, img):
288
+ if random.random() < self.p:
289
+ return self.transf(img)
290
+ else:
291
+ return img
292
+
293
+
294
+ class _GrayScale(object):
295
+ """
296
+ Apply Gray Scale to the PIL image with a specified probability.
297
+ """
298
+
299
+ def __init__(self, p=0.2):
300
+ assert 0.0 <= p <= 1.0
301
+ self.p = p
302
+ self.transf = Grayscale(num_output_channels=3)
303
+
304
+ def __call__(self, img):
305
+ if random.random() < self.p:
306
+ return self.transf(img)
307
+ else:
308
+ return img
309
+
310
+
311
+ def image_transform(
+     image_size: Union[int, Tuple[int, int]],
+     is_train: bool,
+     mean: Optional[Tuple[float, ...]] = None,
+     std: Optional[Tuple[float, ...]] = None,
+     resize_mode: Optional[str] = None,
+     interpolation: Optional[str] = None,
+     fill_color: int = 0,
+     aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
+ ):
+     mean = mean or OPENAI_DATASET_MEAN
+     if not isinstance(mean, (list, tuple)):
+         mean = (mean,) * 3
+
+     std = std or OPENAI_DATASET_STD
+     if not isinstance(std, (list, tuple)):
+         std = (std,) * 3
+
+     interpolation = interpolation or 'bicubic'
+     assert interpolation in ['bicubic', 'bilinear', 'random']
+     # NOTE random is ignored for interpolation_mode, so defaults to BICUBIC for
+     # inference if set
+     interpolation_mode = (
+         InterpolationMode.BILINEAR
+         if interpolation == 'bilinear'
+         else InterpolationMode.BICUBIC
+     )
+
+     resize_mode = resize_mode or 'shortest'
+     assert resize_mode in ('shortest', 'longest', 'squash')
+
+     if isinstance(aug_cfg, dict):
+         aug_cfg = AugmentationCfg(**aug_cfg)
+     else:
+         aug_cfg = aug_cfg or AugmentationCfg()
+
+     normalize = Normalize(mean=mean, std=std)
+
+     if is_train:
+         aug_cfg_dict = {k: v for k, v in asdict(aug_cfg).items() if v is not None}
+         use_timm = aug_cfg_dict.pop('use_timm', False)
+         if use_timm:
+             from timm.data import create_transform  # timm can still be optional
+
+             if isinstance(image_size, (tuple, list)):
+                 assert len(image_size) >= 2
+                 input_size = (3,) + image_size[-2:]
+             else:
+                 input_size = (3, image_size, image_size)
+
+             aug_cfg_dict.setdefault('color_jitter', None)  # disable by default
+             # drop extra non-timm items
+             aug_cfg_dict.pop('color_jitter_prob', None)
+             aug_cfg_dict.pop('gray_scale_prob', None)
+
+             train_transform = create_transform(
+                 input_size=input_size,
+                 is_training=True,
+                 hflip=0.0,
+                 mean=mean,
+                 std=std,
+                 re_mode='pixel',
+                 interpolation=interpolation,
+                 **aug_cfg_dict,
+             )
+         else:
+             train_transform = [
+                 RandomResizedCrop(
+                     image_size,
+                     scale=aug_cfg_dict.pop('scale'),
+                     interpolation=InterpolationMode.BICUBIC,
+                 ),
+                 _convert_to_rgb,
+             ]
+             if aug_cfg.color_jitter_prob:
+                 assert (
+                     aug_cfg.color_jitter is not None and len(aug_cfg.color_jitter) == 4
+                 )
+                 train_transform.extend(
+                     [_ColorJitter(*aug_cfg.color_jitter, p=aug_cfg.color_jitter_prob)]
+                 )
+             if aug_cfg.gray_scale_prob:
+                 train_transform.extend([_GrayScale(aug_cfg.gray_scale_prob)])
+             train_transform.extend(
+                 [
+                     ToTensor(),
+                     normalize,
+                 ]
+             )
+             train_transform = Compose(train_transform)
+             if aug_cfg_dict:
+                 warnings.warn(
+                     f'Unused augmentation cfg items, specify `use_timm` to use '
+                     f'({list(aug_cfg_dict.keys())}).'
+                 )
+         return train_transform
+     else:
+         if resize_mode == 'longest':
+             transforms = [
+                 ResizeKeepRatio(
+                     image_size, interpolation=interpolation_mode, longest=1
+                 ),
+                 CenterCropOrPad(image_size, fill=fill_color),
+             ]
+         elif resize_mode == 'squash':
+             if isinstance(image_size, int):
+                 image_size = (image_size, image_size)
+             transforms = [
+                 Resize(image_size, interpolation=interpolation_mode),
+             ]
+         else:
+             assert resize_mode == 'shortest'
+             if not isinstance(image_size, (tuple, list)):
+                 image_size = (image_size, image_size)
+             if image_size[0] == image_size[1]:
+                 # simple case, use torchvision built-in Resize w/ shortest edge mode
+                 # (scalar size arg)
+                 transforms = [Resize(image_size[0], interpolation=interpolation_mode)]
+             else:
+                 # resize shortest edge to matching target dim for non-square target
+                 transforms = [ResizeKeepRatio(image_size)]
+             transforms += [CenterCrop(image_size)]
+
+         transforms.extend(
+             [
+                 _convert_to_rgb,
+                 ToTensor(),
+                 normalize,
+             ]
+         )
+         return Compose(transforms)
+
+
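`image_transform` is the factory the rest of the preprocessing code builds on: at inference time it composes a resize (per `resize_mode`), a center crop or pad, RGB conversion, `ToTensor` and normalization with the OpenAI CLIP mean/std defaults; at training time it builds a `RandomResizedCrop`-based pipeline, or delegates to `timm.data.create_transform` when `use_timm` is set. A hedged usage sketch follows; the 224-pixel target and the reliance on the default `AugmentationCfg` (defined earlier in this file) are illustrative assumptions:

```python
# Sketch only, not part of the commit.
preprocess_val = image_transform(
    image_size=224,
    is_train=False,
    resize_mode='shortest',   # alternatives: 'longest' (pads with fill_color) or 'squash'
)
# -> Compose([Resize(224, bicubic), CenterCrop((224, 224)), _convert_to_rgb, ToTensor(), Normalize(...)])

preprocess_train = image_transform(
    image_size=224,
    is_train=True,            # default AugmentationCfg -> RandomResizedCrop path, no timm
)
```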
+ def image_transform_v2(
+     cfg: PreprocessCfg,
+     is_train: bool,
+     aug_cfg: Optional[Union[Dict[str, Any], AugmentationCfg]] = None,
+ ):
+     return image_transform(
+         image_size=cfg.size,
+         is_train=is_train,
+         mean=cfg.mean,
+         std=cfg.std,
+         interpolation=cfg.interpolation,
+         resize_mode=cfg.resize_mode,
+         fill_color=cfg.fill_color,
+         aug_cfg=aug_cfg,
+     )
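`image_transform_v2` simply unpacks a `PreprocessCfg` (defined earlier in this file, not shown in this hunk) into the same factory. A sketch under the assumption that the config accepts exactly the fields accessed in the wrapper above, with illustrative values:

```python
# Sketch only, not part of the commit; the PreprocessCfg field values are assumptions.
cfg = PreprocessCfg(
    size=224,
    mean=OPENAI_DATASET_MEAN,
    std=OPENAI_DATASET_STD,
    interpolation='bicubic',
    resize_mode='shortest',
    fill_color=0,
)
preprocess = image_transform_v2(cfg, is_train=False)
```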
vocab.txt ADDED
The diff for this file is too large to render. See raw diff