pfzhu commited on
Commit
071945c
1 Parent(s): a0fd4bb

Upload folder using huggingface_hub

Browse files
LICENSE ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
178
+
179
+ APPENDIX: How to apply the Apache License to your work.
180
+
181
+ To apply the Apache License to your work, attach the following
182
+ boilerplate notice, with the fields enclosed by brackets "[]"
183
+ replaced with your own identifying information. (Don't include
184
+ the brackets!) The text should be enclosed in the appropriate
185
+ comment syntax for the file format. We also recommend that a
186
+ file or class name and description of purpose be included on the
187
+ same "printed page" as the copyright notice for easier
188
+ identification within third-party archives.
189
+
190
+ Copyright 2024 LY Corporation
191
+
192
+ Licensed under the Apache License, Version 2.0 (the "License");
193
+ you may not use this file except in compliance with the License.
194
+ You may obtain a copy of the License at
195
+
196
+ http://www.apache.org/licenses/LICENSE-2.0
197
+
198
+ Unless required by applicable law or agreed to in writing, software
199
+ distributed under the License is distributed on an "AS IS" BASIS,
200
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
+ See the License for the specific language governing permissions and
202
+ limitations under the License.
README.md CHANGED
@@ -1,3 +1,76 @@
1
  ---
 
2
  license: apache-2.0
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ language: ja
3
  license: apache-2.0
4
+ tags:
5
+ - clip
6
+ - japanese-clip
7
+ pipeline_tag: feature-extraction
8
  ---
9
+
10
+ # clip-japanese-base
11
+
12
+ This is a Japanese [CLIP (Contrastive Language-Image Pre-training)](https://arxiv.org/abs/2103.00020) model developed by [LY Corporation](https://www.lycorp.co.jp/en/). This model was trained on ~1B web-collected image-text pairs, and it is applicable to various visual tasks including zero-shot image classification, text-to-image or image-to-text retrieval.
13
+
14
+ ## How to use
15
+ 1. Install packages
16
+ ```
17
+ pip install pillow requests sentencepiece transformers torch timm
18
+ ```
19
+ 2. Run
20
+ ```python
21
+ import io
22
+ import requests
23
+ from PIL import Image
24
+ import torch
25
+ from transformers import AutoImageProcessor, AutoModel, AutoTokenizer
26
+
27
+ HF_MODEL_PATH = 'line-corporation/clip-japanese-base'
28
+ tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH, trust_remote_code=True)
29
+ processor = AutoImageProcessor.from_pretrained(HF_MODEL_PATH, trust_remote_code=True)
30
+ model = AutoModel.from_pretrained(HF_MODEL_PATH, trust_remote_code=True)
31
+ device = "cuda" if torch.cuda.is_available() else "cpu"
32
+
33
+ image = Image.open(io.BytesIO(requests.get('https://images.pexels.com/photos/2253275/pexels-photo-2253275.jpeg?auto=compress&cs=tinysrgb&dpr=3&h=750&w=1260').content))
34
+ image = processor(image, return_tensors="pt")
35
+ text = tokenizer(["犬", "猫", "象"])
36
+
37
+ with torch.no_grad():
38
+ image_features = model.get_image_features(**image)
39
+ text_features = model.get_text_features(**text)
40
+ text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
41
+
42
+ print("Label probs:", text_probs)
43
+ # [[1., 0., 0.]]
44
+ ```
45
+
46
+ ## Model architecture
47
+
48
+ The model uses an [Eva02-B](https://huggingface.co/timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k) Transformer architecture as the image encoder and a 12-layer BERT as the text encoder. The text encoder was initialized from [rinna/japanese-clip-vit-b-16](https://huggingface.co/rinna/japanese-clip-vit-b-16).
49
+
50
+ ## Evaluation
51
+ ### Dataset
52
+ - [STAIR Captions](http://captions.stair.center/) (v2014 val set of MSCOCO) for image-to-text (i2t) and text-to-image (t2i) retrieval. We measure performance using R@1, which is the average recall of i2t and t2i retrieval.
53
+ - [Recruit Datasets](https://huggingface.co/datasets/recruit-jp/japanese-image-classification-evaluation-dataset) for image classification.
54
+ - [ImageNet-1K](https://www.image-net.org/download.php) for image classification. We translated all classnames into Japanese. The classnames and templates can be found in `ja-imagenet-1k-classnames.txt` and `ja-imagenet-1k-templates.txt`.
55
+
56
+ ### Result
57
+ | Model | Image Encoder Params | Text Encoder params | STAIR Captions (R@1) | Recruit Datasets (acc@1) | ImageNet-1K (acc@1) |
58
+ |-------------------|----------------------|---------------------|----------------|------------------|-----------------|
59
+ | [Ours](https://huggingface.co/line-corporation/clip-japanese-base) | 86M(Eva02-B) | 100M(BERT) | **0.30** | **0.89** | 0.58 |
60
+ | [Stable-ja-clip](https://huggingface.co/stabilityai/japanese-stable-clip-vit-l-16) | 307M(ViT-L) | 100M(BERT) | 0.24 | 0.77 | **0.68** |
61
+ | [Rinna-ja-clip](https://huggingface.co/rinna/japanese-clip-vit-b-16) | 86M(ViT-B) | 100M(BERT) | 0.13 | 0.54 | 0.56 |
62
+ | [Laion-clip](https://huggingface.co/laion/CLIP-ViT-H-14-frozen-xlm-roberta-large-laion5B-s13B-b90k) | 632M(ViT-H) | 561M(XLM-RoBERTa) | **0.30** | 0.83 | 0.58 |
63
+ | [Hakuhodo-ja-clip](https://huggingface.co/hakuhodo-tech/japanese-clip-vit-h-14-bert-wider) | 632M(ViT-H) | 100M(BERT) | 0.21 | 0.82 | 0.46 |
64
+
65
+ ## Licenses
66
+
67
+ [The Apache License, Version 2.0](https://www.apache.org/licenses/LICENSE-2.0)
68
+
69
+ ## Citation
70
+ ```
71
+ @misc{clip-japanese-base,
72
+ title = {CLIP Japanese Base},
73
+ author={Shuhei Yokoo, Shuntaro Okada, Peifei Zhu, Shuhei Nishimura and Naoki Takayama}
74
+ url = {https://huggingface.co/line-corporation/clip-japanese-base},
75
+ }
76
+ ```
config.json ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "./lycorp/clyp-eva02-b-16",
3
+ "architectures": [
4
+ "CLYPModel"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_clyp.CLYPConfig",
8
+ "AutoModel": "modeling_clyp.CLYPModel"
9
+ },
10
+ "itc_loss_config": null,
11
+ "learn_temperature": true,
12
+ "model_type": "clyp",
13
+ "temperature_init": 0.07,
14
+ "temperature_max": 1000.0,
15
+ "temperature_min": 0.01,
16
+ "text_encoder_config": {
17
+ "backbone_config": {
18
+ "model_name": "rinna/japanese-clip-vit-b-16"
19
+ },
20
+ "neck_config": {
21
+ "bias": false,
22
+ "in_channels": 768,
23
+ "out_channels": 512
24
+ },
25
+ "pooler_config": {
26
+ "input_type": "huggingface",
27
+ "return_patch_features": false
28
+ }
29
+ },
30
+ "torch_dtype": "float32",
31
+ "transformers_version": "4.39.1",
32
+ "vision_encoder_config": {
33
+ "backbone_config": {
34
+ "extra_kwargs": {},
35
+ "model_name": "eva02_base_patch16_clip_224.merged2b",
36
+ "pretrained": true
37
+ },
38
+ "neck_config": {
39
+ "bias": false,
40
+ "in_channels": 768,
41
+ "out_channels": 512
42
+ },
43
+ "pooler_config": {
44
+ "input_type": "timm",
45
+ "return_patch_features": false
46
+ }
47
+ }
48
+ }
configuration_clyp.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+
3
+ # Copyright 2024 LY Corporation.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from __future__ import annotations
17
+
18
+ from typing import Any, Literal, Optional
19
+
20
+ from transformers import PretrainedConfig
21
+
22
+
23
+ class CLYPConfig(PretrainedConfig):
24
+ model_type = "clyp"
25
+
26
+ def __init__(
27
+ self,
28
+ vision_encoder_config: Optional[dict] = None,
29
+ text_encoder_config: Optional[dict] = None,
30
+ itc_loss_config: Optional[dict] = None,
31
+ learn_temperature: bool = True,
32
+ temperature_init: float = 0.07,
33
+ temperature_min: float = 0.01,
34
+ temperature_max: float = 1000.0,
35
+ **kwargs,
36
+ ) -> None:
37
+ super().__init__(**kwargs)
38
+ vision_encoder_config = vision_encoder_config or {}
39
+ text_encoder_config = text_encoder_config or {}
40
+ self.vision_encoder_config = CLYPVisionEncoderConfig(**vision_encoder_config)
41
+ self.text_encoder_config = CLYPTextEncoderConfig(**text_encoder_config)
42
+ self.itc_loss_config = (
43
+ CLYPLossConfig(**itc_loss_config) if itc_loss_config else None
44
+ )
45
+ self.learn_temperature = learn_temperature
46
+ self.temperature_init = temperature_init
47
+ self.temperature_min = temperature_min
48
+ self.temperature_max = temperature_max
49
+
50
+ def to_diff_dict(self) -> dict[str, Any]:
51
+ serializable_config_dict = super().to_diff_dict()
52
+ sub_serializable_config_dict = {
53
+ "vision_encoder_config": _to_diff_dict(self.vision_encoder_config),
54
+ "text_encoder_config": _to_diff_dict(self.text_encoder_config),
55
+ }
56
+ self.dict_torch_dtype_to_str(sub_serializable_config_dict)
57
+ serializable_config_dict.update(sub_serializable_config_dict)
58
+ return serializable_config_dict
59
+
60
+
61
+ class CLYPVisionEncoderConfig(PretrainedConfig):
62
+ def __init__(
63
+ self,
64
+ backbone_config: Optional[dict] = None,
65
+ pooler_config: Optional[dict] = None,
66
+ neck_config: Optional[dict] = None,
67
+ **kwargs,
68
+ ) -> None:
69
+ super().__init__(**kwargs)
70
+ backbone_config = backbone_config or {}
71
+ pooler_config = pooler_config or {"input_type": "timm"}
72
+ neck_config = neck_config or {}
73
+ self.backbone_config = CLYPVisionBackboneConfig(**backbone_config)
74
+ self.pooler_config = CLYPPoolerConfig(**pooler_config)
75
+ self.neck_config = CLYPNeckConfig(**neck_config)
76
+
77
+ def to_diff_dict(self) -> dict[str, Any]:
78
+ serializable_config_dict = {
79
+ "backbone_config": _to_diff_dict(self.backbone_config),
80
+ "pooler_config": _to_diff_dict(self.pooler_config),
81
+ "neck_config": _to_diff_dict(self.neck_config),
82
+ }
83
+ self.dict_torch_dtype_to_str(serializable_config_dict)
84
+ return serializable_config_dict
85
+
86
+
87
+ class CLYPTextEncoderConfig(PretrainedConfig):
88
+ def __init__(
89
+ self,
90
+ backbone_config: Optional[dict] = None,
91
+ pooler_config: Optional[dict] = None,
92
+ neck_config: Optional[dict] = None,
93
+ **kwargs,
94
+ ) -> None:
95
+ super().__init__(**kwargs)
96
+ backbone_config = backbone_config or {}
97
+ pooler_config = pooler_config or {"input_type": "huggingface"}
98
+ neck_config = neck_config or {}
99
+ self.backbone_config = CLYPTextBackboneConfig(**backbone_config)
100
+ self.pooler_config = CLYPPoolerConfig(**pooler_config)
101
+ self.neck_config = CLYPNeckConfig(**neck_config)
102
+
103
+ def to_diff_dict(self) -> dict[str, Any]:
104
+ serializable_config_dict = {
105
+ "backbone_config": _to_diff_dict(self.backbone_config),
106
+ "pooler_config": _to_diff_dict(self.pooler_config),
107
+ "neck_config": _to_diff_dict(self.neck_config),
108
+ }
109
+ self.dict_torch_dtype_to_str(serializable_config_dict)
110
+ return serializable_config_dict
111
+
112
+
113
+ class CLYPVisionBackboneConfig(PretrainedConfig):
114
+ def __init__(
115
+ self,
116
+ model_name: str = "eva02_base_patch16_clip_224.merged2b",
117
+ pretrained: bool = True,
118
+ extra_kwargs: Optional[dict] = None,
119
+ **kwargs,
120
+ ) -> None:
121
+ super().__init__(**kwargs)
122
+ self.model_name = model_name
123
+ self.pretrained = pretrained
124
+ self.extra_kwargs = extra_kwargs or {}
125
+
126
+
127
+ class CLYPTextBackboneConfig(PretrainedConfig):
128
+ def __init__(
129
+ self,
130
+ model_name: str = "rinna/japanese-clip-vit-b-16",
131
+ **kwargs,
132
+ ) -> None:
133
+ super().__init__(**kwargs)
134
+ self.model_name = model_name
135
+
136
+
137
+ class CLYPPoolerConfig(PretrainedConfig):
138
+ def __init__(
139
+ self,
140
+ input_type: Literal["timm", "huggingface"] | None = None,
141
+ return_patch_features: bool = False,
142
+ **kwargs,
143
+ ) -> None:
144
+ super().__init__(**kwargs)
145
+ self.input_type = input_type
146
+ self.return_patch_features = return_patch_features
147
+
148
+
149
+ class CLYPNeckConfig(PretrainedConfig):
150
+ def __init__(
151
+ self,
152
+ in_channels: int = 768,
153
+ out_channels: int = 512,
154
+ bias: bool = False,
155
+ **kwargs,
156
+ ) -> None:
157
+ super().__init__(**kwargs)
158
+ self.in_channels = in_channels
159
+ self.out_channels = out_channels
160
+ self.bias = bias
161
+
162
+
163
+ class CLYPLossConfig(PretrainedConfig):
164
+ def __init__(
165
+ self,
166
+ learn_temperature: bool = True,
167
+ init_temperature: float = 0.07,
168
+ max_temperature: Optional[float] = None,
169
+ min_temperature: Optional[float] = None,
170
+ label_smoothing: float = 0.0,
171
+ gather_with_grad: bool = True,
172
+ **kwargs,
173
+ ) -> None:
174
+ super().__init__(**kwargs)
175
+ self.learn_temperature = learn_temperature
176
+ self.init_temperature = init_temperature
177
+ self.max_temperature = max_temperature
178
+ self.min_temperature = min_temperature
179
+ self.label_smoothing = label_smoothing
180
+ self.gather_with_grad = gather_with_grad
181
+
182
+
183
+ def _to_diff_dict(c: PretrainedConfig) -> dict:
184
+ """Function to override PretrainedConfig.to_diff_dict()
185
+
186
+ NOTE
187
+ ----
188
+ In transformers==4.38.1,
189
+ PretrainedConfig.__repr__ may not be able to show configs that has some sub-configs
190
+ """
191
+ d = c.to_diff_dict()
192
+ if "transformers_version" in d:
193
+ d.pop("transformers_version")
194
+ return d
195
+
196
+
197
+ if __name__ == "__main__":
198
+ conf = CLYPConfig.from_pretrained("config.json")
199
+ print(conf)
image_processing_clyp.py ADDED
@@ -0,0 +1,226 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+
3
+ # Copyright 2024 LY Corporation.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from __future__ import annotations
17
+
18
+ from typing import Literal, Optional
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torchvision.transforms as T
23
+ import torchvision.transforms.functional as F
24
+ from PIL import Image
25
+ from timm.data import (
26
+ IMAGENET_INCEPTION_MEAN,
27
+ IMAGENET_INCEPTION_STD,
28
+ OPENAI_CLIP_MEAN,
29
+ OPENAI_CLIP_STD,
30
+ )
31
+ from timm.data.transforms_factory import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
32
+ from transformers.image_processing_utils import BaseImageProcessor, BatchFeature
33
+ from transformers.image_utils import ImageInput, make_list_of_images
34
+ from transformers.utils import TensorType
35
+
36
+ NormalizationType = Literal["imagenet", "imagenet_inception", "openai_clip"]
37
+
38
+
39
+ class CLYPImageProcessor(BaseImageProcessor):
40
+ def __init__(
41
+ self,
42
+ image_size: int = 224,
43
+ normalization_type: NormalizationType = "imagenet",
44
+ **kwargs,
45
+ ):
46
+ super().__init__(**kwargs)
47
+ self.image_size = image_size
48
+ self.normalization_type: NormalizationType = normalization_type
49
+
50
+ def preprocess(
51
+ self,
52
+ images: ImageInput | list[ImageInput],
53
+ return_tensors: Optional[str | TensorType] = None,
54
+ **kwargs,
55
+ ) -> BatchFeature:
56
+ images = make_list_of_images(images, expected_ndims=3)
57
+ # TODO: Support train
58
+ transforms = TestTransform(
59
+ self.image_size, normalization_type=self.normalization_type
60
+ )
61
+ images = [transforms(image).numpy() for image in images]
62
+ return BatchFeature(data={"pixel_values": images}, tensor_type=return_tensors)
63
+
64
+
65
+ class TrainTransform:
66
+ def __init__(
67
+ self,
68
+ image_size: int,
69
+ scale_range_min: float,
70
+ scale_range_max: float,
71
+ normalization_type: NormalizationType = "imagenet",
72
+ ) -> None:
73
+ """
74
+ Args:
75
+ image_size (int): output-image size.
76
+ scale_range_min (float): minimum value of the scale to crop an input image.
77
+ scale_range_max (float): maximum value of the scale to crop an input image.
78
+ normalization_type (str): select mean and std for normalization (see get_mean_and_std).
79
+ """
80
+ scale = (scale_range_min, scale_range_max)
81
+ mean_and_std = get_mean_and_std(normalization_type)
82
+
83
+ self.transform = T.Compose(
84
+ [
85
+ T.RandomResizedCrop(
86
+ image_size, scale=scale, interpolation=T.InterpolationMode.BICUBIC
87
+ ),
88
+ _convert_to_rgb,
89
+ T.ToTensor(),
90
+ T.Normalize(**mean_and_std),
91
+ ]
92
+ )
93
+
94
+ def __call__(self, img):
95
+ return self.transform(img)
96
+
97
+
98
+ class TestTransform:
99
+ def __init__(
100
+ self, image_size: int, normalization_type: NormalizationType = "imagenet"
101
+ ) -> None:
102
+ """
103
+ Args:
104
+ image_size (int): output-image size.
105
+ normalization_type (str): select mean and std for normalization (see get_mean_and_std).
106
+ """
107
+ mean_and_std = get_mean_and_std(normalization_type)
108
+
109
+ self.transform = T.Compose(
110
+ [
111
+ ResizeMaxSize(image_size, fill=0),
112
+ T.CenterCrop(image_size),
113
+ _convert_to_rgb,
114
+ T.ToTensor(),
115
+ T.Normalize(**mean_and_std),
116
+ ]
117
+ )
118
+
119
+ def __call__(self, img):
120
+ return self.transform(img)
121
+
122
+
123
+ class SmallestMaxSize(T.Resize):
124
+ """Resize shorter side of an input image.
125
+
126
+ The shorter side of an input image is resized to the max_size.
127
+ Note that an large part of the input image is discarded when an aspect-ratio value of the input image is extremely small or large.
128
+ """
129
+
130
+ def __init__(self, max_size: int, **kwargs):
131
+ super().__init__(max_size, **kwargs)
132
+
133
+ @staticmethod
134
+ def target_size(w: int, h: int, size: int) -> tuple[int, int]:
135
+ if h < w:
136
+ w, h = int(size * w / h), size
137
+ else:
138
+ w, h = size, int(size * h / w)
139
+ return (h, w)
140
+
141
+ def __call__(self, img):
142
+ size = self.size
143
+ assert isinstance(size, int)
144
+ w, h = img.size
145
+ target_size = self.target_size(w, h, size)
146
+ return F.resize(img, list(target_size), self.interpolation)
147
+
148
+
149
+ class ResizeMaxSize(nn.Module):
150
+ """Resize longer side of an input image.
151
+
152
+ The longer side of an input image is resized to the max_size.
153
+ Note that an large part of the output image is padded when an aspect-ration value of the input image is extremely small or large.
154
+ Adapted from https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transform.py
155
+ """
156
+
157
+ def __init__(
158
+ self,
159
+ max_size: int,
160
+ interpolation: T.InterpolationMode = T.InterpolationMode.BICUBIC,
161
+ fn: str = "max",
162
+ fill: int = 0,
163
+ ):
164
+ super().__init__()
165
+ if not isinstance(max_size, int):
166
+ raise TypeError(f"Size should be int. Got {type(max_size)}")
167
+ self.max_size = max_size
168
+ self.interpolation = interpolation
169
+ self.fn = min if fn == "min" else min
170
+ self.fill = fill
171
+
172
+ def forward(self, img):
173
+ if isinstance(img, torch.Tensor):
174
+ height, width = img.shape[:2]
175
+ else:
176
+ width, height = img.size
177
+ scale = self.max_size / float(max(height, width))
178
+ if scale != 1.0:
179
+ new_size = tuple(round(dim * scale) for dim in (height, width))
180
+ img = F.resize(img, new_size, self.interpolation) # type: ignore
181
+ pad_h = self.max_size - new_size[0]
182
+ pad_w = self.max_size - new_size[1]
183
+ img = F.pad(
184
+ img,
185
+ padding=[
186
+ pad_w // 2,
187
+ pad_h // 2,
188
+ pad_w - pad_w // 2,
189
+ pad_h - pad_h // 2,
190
+ ],
191
+ fill=self.fill,
192
+ )
193
+ return img
194
+
195
+
196
+ def get_mean_and_std(normalization_type: NormalizationType) -> dict:
197
+ """Return mean and std tensors for T.Normalize()
198
+ NOTE:
199
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
200
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
201
+ IMAGENET_INCEPTION_MEAN = (0.5, 0.5, 0.5)
202
+ IMAGENET_INCEPTION_STD = (0.5, 0.5, 0.5)
203
+ OPENAI_CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
204
+ OPENAI_CLIP_STD = (0.26862954, 0.26130258, 0.27577711)
205
+ """
206
+ if normalization_type == "imagenet":
207
+ return {
208
+ "mean": torch.tensor(IMAGENET_DEFAULT_MEAN),
209
+ "std": torch.tensor(IMAGENET_DEFAULT_STD),
210
+ }
211
+ elif normalization_type == "imagenet_inception":
212
+ return {
213
+ "mean": torch.tensor(IMAGENET_INCEPTION_MEAN),
214
+ "std": torch.tensor(IMAGENET_INCEPTION_STD),
215
+ }
216
+ elif normalization_type == "openai_clip":
217
+ return {
218
+ "mean": torch.tensor(OPENAI_CLIP_MEAN),
219
+ "std": torch.tensor(OPENAI_CLIP_STD),
220
+ }
221
+ else:
222
+ raise ValueError(normalization_type)
223
+
224
+
225
+ def _convert_to_rgb(image: Image.Image) -> Image.Image:
226
+ return image.convert("RGB")
ja-imagenet-1k-classnames.txt ADDED
@@ -0,0 +1,1000 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ テンチ
2
+ 金魚
3
+ ホホジロザメ
4
+ イタチザメ
5
+ シュモクザメ
6
+ シビレエイ
7
+ アカエイ
8
+ 雄鶏,おんどり
9
+ 雌鶏,めんどり
10
+ ダチョウ
11
+ アトリ
12
+ ゴシキヒワ
13
+ メキシコマシコ
14
+ ユキヒメドリ
15
+ ルリノジコ
16
+ コマツグミ
17
+ 目黒
18
+ カケス
19
+ カササギ
20
+ 四十雀,シジュウカラ
21
+ カワガラス
22
+ トビ
23
+ ハクトウワシ,白頭鷲
24
+ ハゲワシ
25
+ カラフトフクロウ
26
+ ファイアサラマンダー
27
+ スベイモリ,オビイモリ
28
+ イモリ
29
+ スポテッドサラマンダー,キボシサンショウウオ
30
+ アホロートル
31
+ ウシガエル
32
+ アマガエル
33
+ オガエル
34
+ アカウミガメ
35
+ オサガメ
36
+ 鼈,ドロガメ
37
+ スッポン
38
+ ハコガメ
39
+ バンドトカゲモドキ
40
+ イグアナ
41
+ グリーンアノール
42
+ ハシリトカゲ
43
+ アガマトカゲ
44
+ エリマキトカゲ
45
+ アシナシトカゲ
46
+ アメリカドクトカゲ
47
+ ミドリカナヘビ
48
+ カメレオン
49
+ コモドオオトカゲ
50
+ ナイルワニ
51
+ ミシシッピワニ
52
+ トリケラトプス
53
+ 盲蛇,ミミズヘビ
54
+ リングネックスネーク
55
+ トウブシシバナヘビ
56
+ 緑のヘビ
57
+ キングスネーク
58
+ ガータースネーク
59
+ ミズヘビ
60
+ ツルヘビ
61
+ 夜行性のヘビ
62
+ ボアコンストリクター
63
+ アフリカニシキヘビ
64
+ インドコブラ
65
+ グリーンマンバ
66
+ ウミヘビ
67
+ サハラツノクサリヘビ
68
+ ダイヤガラガラヘビ
69
+ ヨコバイガラガラヘビ
70
+ 三葉虫
71
+ ザトウムシ
72
+ サソリ
73
+ コガネグモ
74
+ 納屋クモ
75
+ オニグモ
76
+ クロゴケグモ
77
+ タランチュラ
78
+ ドクグモ
79
+ ダニ
80
+ ムカデ
81
+ クロライチョウ
82
+ ライチョウ,雷鳥
83
+ エリマキライチョウ
84
+ 茶色の斑紋のあるライチョウ
85
+ クジャク
86
+ ウズラ
87
+ ヤマウズラ
88
+ ヨウム
89
+ コンゴウインコ
90
+ キバタン
91
+ ヒインコ
92
+ バンケン
93
+ ハチクイ
94
+ サイチョウ
95
+ ハチドリ
96
+ キリハシ,錐嘴
97
+ オオハシ
98
+ アヒル
99
+ ウミアイサ
100
+ ガチョウ
101
+ コクチョウ,黒鳥
102
+ 牙を持つ動物
103
+ ハリモグラ
104
+ カモノハシ
105
+ ワラビー
106
+ コアラ
107
+ ウォンバット
108
+ クラゲ
109
+ イソギンチャク
110
+ 脳珊瑚
111
+ 扁形動物
112
+ 線虫
113
+ ホラガイ,巻き貝
114
+ カタツムリ
115
+ ナメクジ
116
+ ウミウシ
117
+ ヒザラガイ,多板綱
118
+ オウムガイ
119
+ アメリカイチョウガニ
120
+ イワガニ
121
+ シオマネキ
122
+ タラバガニ
123
+ アメリカンロブスター
124
+ 伊勢エビ
125
+ ザリガニ
126
+ ヤドカリ
127
+ ワラジムシ,等脚類
128
+ コウノトリ
129
+ ナベコウ
130
+ ヘラサギ
131
+ フラミンゴ
132
+ ヒメアカクロサギ
133
+ ダイサギ
134
+ ヨシゴイ
135
+ ツル
136
+ ツルモドキ
137
+ バン,鷭
138
+ アメリカオオバン
139
+ ノガン
140
+ キョウジョシギ
141
+ ハマシギ
142
+ アカアシシギ
143
+ オオハシシギ
144
+ ミヤコドリ
145
+ ペリカン
146
+ キングペンギン
147
+ アホウドリ,アルバトロス
148
+ コククジラ
149
+ シャチ,鯱
150
+ ジュゴン
151
+ アシカ
152
+ チワワ
153
+
154
+ マルチーズ
155
+ ペキニーズ
156
+ シーズー
157
+ キングチャールズスパニエル
158
+ パピヨン
159
+ トイテリア
160
+ ローデシアン・リッジバック
161
+ アフガンハウンド
162
+ バセットハウンド
163
+ ビーグル
164
+ ブラッドハウンド
165
+ ブルーティッククーンハウンド
166
+ ブラック・アンド・タン・クーンハウンド
167
+ ツリーイング・ウォーカー・クーンハウンド
168
+ イングリッシュ・フォックスハウンド
169
+ レッドボーン・クーンハウンド
170
+ ボルゾイ
171
+ アイリッシュウルフハウンド
172
+ イタリアン・グレーハウンド
173
+ ウィペット
174
+ イビサン・ハウンド
175
+ ノルウェージャン・エルクハウンド
176
+ オッターハウンド
177
+ サルーキ
178
+ スコティッシュ・ディアハウンド
179
+ ワイマラナー
180
+ スタッフォードシャーブルテリア
181
+ アメリカンスタッフォードシャーテリア
182
+ ベドリントンテリア
183
+ ボーダーテリア
184
+ ケリーブルーテリア
185
+ アイリッシュテリア
186
+ ノーフォークテリア
187
+ ノーリッチテリア
188
+ ヨークシャーテリア
189
+ ワイヤーフォックステリア
190
+ レークランドテリア
191
+ シーリーハムテリア
192
+ エアデールテリア
193
+ ケアーン・テリア
194
+ オーストラリアン・テリア
195
+ ダンディ・ディンモント・テリア
196
+ ボストンテリア
197
+ ミニチュア・シュナウザー
198
+ ジャイアント・シュナウザー
199
+ スタンダード・シュナウザー
200
+ スコッチテリア
201
+ チベタンテリア
202
+ オーストラリアン・シルキー・テリア
203
+ ソフトコーテッド・ウィートン・テリア
204
+ ウエスト・ハイランド・ホワイト・テリア
205
+ ラサ・アプソ
206
+ フラットコーテッド・レトリーバー
207
+ カーリーコーテッド・レトリーバー
208
+ ゴールデン・レトリバー
209
+ ラブラドール・レトリバー
210
+ チェサピーク・ベイ・レトリーバー
211
+ ジャーマン・ショートヘア・ポインタ
212
+ ビズラ
213
+ イングリッシュ・セッター
214
+ アイリッシュ・セッター
215
+ ゴードン・セッター
216
+ ブリタニー・スパニエル
217
+ クラムバー・スパニエル
218
+ イングリッシュ・スプリンガー・スパニ��ル
219
+ ウェルシュ・スプリンガー・スパニエル
220
+ コッカー・スパニエル
221
+ サセックス・スパニエル
222
+ アイリッシュ・ウォーター・スパニエル
223
+ クバース犬
224
+ スキッパーキー
225
+ ベルジアン・シェパード・ドッグ・グローネンダール
226
+ マリノア
227
+ ブリアール
228
+ オーストラリアン・ケルピー
229
+ コモンドール
230
+ オールドイングリッシュシープドッグ
231
+ シェットランド・シープドッグ
232
+ コリー
233
+ ボーダー・コリー
234
+ ブーヴィエ・デ・フランドル
235
+ ロットワイラー
236
+ ジャーマンシェパード
237
+ ドーベルマン
238
+ ミニチュア・ピンシャー
239
+ グレータースイス・マウンテンドッグ
240
+ バーニーズ・マウンテン・ドッグ
241
+ アッペンツェラー・キャトル・ドッグ
242
+ エントレブッハー・キャトル・ドッグ
243
+ ボクサー犬
244
+ ブルマスティフ
245
+ チベタンマスティフ
246
+ フレンチブルドッグ
247
+ グレートデン
248
+ セントバーナード
249
+ エスキモー犬
250
+ アラスカン・マラミュート
251
+ シベリアンハスキー
252
+ ダルメシアン
253
+ アーフェンピンシャー
254
+ バセンジー
255
+ パグ
256
+ レオンベルガー
257
+ ニューファンドランド犬
258
+ グレートピレニーズ
259
+ サモエド
260
+ ポメラニアン
261
+ チャウチャウ
262
+ キースホンド
263
+ ブラバンソングリフォン
264
+ ペンブローク
265
+ ウェルシュコーギーカーディガン
266
+ トイプードル
267
+ ミニチュアプードル
268
+ スタンダードプードル
269
+ メキシカン・ヘアレス・ドッグ
270
+ ハイイロオオカミ
271
+ 白いオオカミ
272
+ レッドウルフ
273
+ コヨーテ
274
+ ディンゴ
275
+ ドール,豺
276
+ リカオン
277
+ ハイエナ
278
+ アカギツネ
279
+ キットギツネ
280
+ ホッキョクギツネ
281
+ ハイイロギツネ
282
+ トラネコ
283
+ ジャガーネコ
284
+ ペルシャ猫
285
+ シャム猫
286
+ エジプシャンマウ
287
+ ピューマ,クーガー
288
+ オオヤマネコ
289
+ ヒョウ
290
+ ユキヒョウ
291
+ ジャガー
292
+ ライオン
293
+
294
+ チーター
295
+ ヒグマ
296
+ アメリカグマ
297
+ ホッキョクグマ
298
+ ナマケグマ
299
+ マングース
300
+ ミーアキャット
301
+ ハンミョウ
302
+ てんとう虫
303
+ オサムシ
304
+ カミキリムシ
305
+ ハムシ
306
+ スカラベ,フンコロガシ
307
+ カブトムシ
308
+ ゾウムシ
309
+ ハエ
310
+
311
+
312
+ バッタ
313
+ コオロギ
314
+ ナナフシ
315
+ ゴキブリ
316
+ カマキリ
317
+
318
+ ヨコバイ
319
+ クサカゲロウ
320
+ トンボ
321
+ イトトンボ
322
+ ヨーロッパアカタテハ
323
+ ジャノメチョウ
324
+ オオカバマダラ
325
+ モンシロチョウ
326
+ キチョウ,黄色の蝶
327
+ ゴイシシジミ,シジミチョウ
328
+ ヒトデ
329
+ ウニ,海胆,雲丹
330
+ ナマコ
331
+ ワタオウサギ
332
+ 野ウサギ
333
+ アンゴラウサギ
334
+ ハムスター
335
+ ヤマアラシ
336
+ キツネリス
337
+ マーモット
338
+ ビーバー
339
+ モルモット
340
+ 栗毛の馬
341
+ シマウマ
342
+
343
+ イノシシ
344
+ イボイノシシ
345
+ カバ
346
+ 雄牛
347
+ 水牛
348
+ バイソン
349
+ 牡羊,雄羊
350
+ ビッグホーン
351
+ アイベックス
352
+ ハーテビースト
353
+ インパラ
354
+ ガゼル
355
+ アラビアラクダ
356
+ ラマ
357
+ イタチ
358
+ ミンク
359
+ ヨーロッパケナガイタチ
360
+ クロアシイタチ
361
+ カワウソ
362
+ スカンク
363
+ アナグマ
364
+ アルマジロ
365
+ ミユビナマケモノ
366
+ オランウータン
367
+ ゴリラ
368
+ チンパンジー
369
+ テナガザル
370
+ フクロテナガザル
371
+ オナガザル
372
+ パタスモンキー
373
+ ヒヒ
374
+ マカク
375
+ ラングール,ヤセザル
376
+ コロブス
377
+ テングザル
378
+ マーモセット
379
+ オマキザル
380
+ ハウラ,ホエザル
381
+ ティティ
382
+ クモザル
383
+ リスザル
384
+ ワオキツネザル
385
+ インドリ
386
+ インドゾウ
387
+ アフリカゾウ
388
+ レッサーパンダ
389
+ ジャイアントパンダ
390
+ スヌーク
391
+ ウナギ
392
+ ギンザケ,銀鮭
393
+ ロックビューティーエンゼルフィッシュ
394
+ クマノミ
395
+ チョウザメ
396
+ ガーフィッシュ
397
+ ミノカサゴ
398
+ フグ
399
+ そろばん
400
+ アバヤ,アラブの民族衣装
401
+ アカデミックガウン,法服
402
+ アコーディオン
403
+ アコースティックギター
404
+ 空母
405
+ 旅客機
406
+ 飛行船
407
+ 祭壇
408
+ 救急車
409
+ 水陸両用車
410
+ アナログ時計
411
+ 養蜂場
412
+ エプロン
413
+ ごみ箱
414
+ アサルトライフル
415
+ リュック,バックパック
416
+ パン屋,ベーカリー
417
+ 平均台
418
+ バルーン,気球,風船
419
+ ボールペン
420
+ 絆創膏
421
+ バンジョー
422
+ 手すり
423
+ バーベル
424
+ 理髪店のいす
425
+ 理髪店
426
+ 納屋
427
+ バロメーター,気圧計
428
+
429
+ 手押し車
430
+ 野球ボール
431
+ バスケットボール
432
+ バシネット
433
+ バスーン,ファゴット
434
+ 水泳帽
435
+ バスタオル
436
+ 浴槽
437
+ ステーションワゴン
438
+ 灯台
439
+ ビーカー
440
+ シャコー帽
441
+ ビール瓶
442
+ ビールグラス
443
+ 鐘塔,鐘楼
444
+ よだれ掛け
445
+ タンデム自転車
446
+ ビキニ
447
+ バインダー
448
+ 双眼鏡
449
+ 巣箱,鳥小屋
450
+ ボートハウス
451
+ ボブスレー
452
+ ループタイ
453
+ ボンネット
454
+ 本棚
455
+ 書店
456
+ 瓶の蓋
457
+ 狩猟弓
458
+ 蝶ネクタイ
459
+ 真鍮記念プレート
460
+ ブラジャー
461
+ 防波堤
462
+ 鎧の胸当て
463
+ ほうき
464
+ バケツ
465
+ バックル
466
+ 防弾チョッキ
467
+ 新幹線
468
+ 精肉店
469
+ タクシー
470
+ 大釜
471
+ キャンドル
472
+ 大砲
473
+ カヌー
474
+ 缶切り
475
+ カーディガン
476
+ 車のミラー
477
+ メリーゴーランド,回転���馬
478
+ 工具セット
479
+ 段ボール箱
480
+ 車輪
481
+ ATM
482
+ カセットテープ
483
+ カセットプレーヤー
484
+
485
+ カタマラン
486
+ CDプレーヤー
487
+ チェロ
488
+ 携帯電話
489
+
490
+ 金網フェンス
491
+ 鎖帷子,鎖かたびら
492
+ チェーンソー
493
+ チェスト,収納
494
+ 西洋だんす,シフォニア
495
+ チャイム,ベル,鐘
496
+ 食器棚
497
+ クリスマスストッキング
498
+ 教会
499
+ 映画館
500
+ チョッパー,肉包丁,クリーバー
501
+ 崖の住居
502
+ マント
503
+ サボ,下駄
504
+ カクテルシェーカー
505
+ コーヒーマグ
506
+ コーヒーポット
507
+ コイル
508
+ 組み合わせ錠,ダイヤル錠
509
+ コンピュータキーボード
510
+ 菓子屋
511
+ コンテナ船
512
+ オープンカー,コンバーチブル
513
+ コルク抜き
514
+ コルネット
515
+ カウボーイブーツ
516
+ カウボーイハット
517
+ ゆりかご
518
+ クレーン
519
+ クラッシュヘルメット
520
+ 木箱
521
+ ベビーベッド
522
+ スロークッカー
523
+ クロケットボール
524
+ 松葉杖
525
+ キュイラス,胸当て
526
+ ダム
527
+
528
+ デスクトップコンピューター
529
+ ダイヤル電話
530
+ おむつ
531
+ デジタル時計
532
+ デジタル腕時計
533
+ ダイニングテーブル
534
+ 布巾
535
+ 食器洗い機
536
+ ディスクブレーキ
537
+ ドック,船着き場
538
+ 犬ぞり
539
+ ドーム
540
+ 玄関マット
541
+ 掘削リグ
542
+ ドラム
543
+ ドラムスティック
544
+ ダンベル
545
+ ダッチオーブン
546
+ 扇風機
547
+ エレキギター
548
+ 電気機関車
549
+ 娯楽施設
550
+ 封筒
551
+ エスプレッソマシーン
552
+ フェースパウダー
553
+ フェザーボア
554
+ バインダー,書類キャビネット
555
+ 消防艇
556
+ 消防車
557
+ 防火用スクリーン
558
+ 旗竿
559
+ フルート
560
+ 折りたたみ椅子
561
+ アメリカンフットボールのヘルメット
562
+ フォークリフト
563
+ 噴水
564
+ 万年筆
565
+ 四柱ベッド
566
+ 貨車
567
+ フレンチホルン
568
+ フライパン
569
+ 毛皮のコート
570
+ ごみ収集車
571
+ ガスマスク
572
+ ガソリンポンプ
573
+ ゴブレット
574
+ ゴーカート
575
+ ゴルフボール
576
+ ゴルフカート
577
+ ゴンドラ
578
+ ゴング
579
+ ガウン
580
+ グランドピアノ
581
+ 植木室,温室
582
+ ラジエーターグリル
583
+ 食料品店
584
+ ギロチン
585
+ ヘアクリップ
586
+ ヘアスプレー
587
+ ハーフトラック
588
+ ハンマー
589
+ 洗濯かご
590
+ ヘアドライヤー
591
+ 携帯コンピュータ
592
+ ハンカチ
593
+ ハードディスクドライブ,HDD
594
+ ハーモニカ
595
+ ハープ,竪琴
596
+ 刈り取り機,コンバイン
597
+
598
+ ホルスター
599
+ ホームシアター
600
+ ハニカム
601
+ フック
602
+ フープスカート
603
+ 鉄棒
604
+ 馬車
605
+ 砂時計
606
+ iPod,アイポッド
607
+ 衣類用アイロン
608
+ ジャックオーランタン
609
+ ジーンズ
610
+ ジープ
611
+ Tシャツ
612
+ ジグソーパズル
613
+ 人力車
614
+ ジョイスティック
615
+ 着物
616
+ 膝パッド
617
+ 結び目
618
+ 白衣
619
+ レードル,ひしゃく
620
+ ランプシェード,秉燭
621
+ ノートパソコン
622
+ 芝刈り機
623
+ レンズキャップ
624
+ レターオープナー
625
+ 図書館
626
+ 救命ボート
627
+ ライター
628
+ リムジン
629
+ 定期船
630
+ 口紅
631
+ ローファー
632
+ ローション
633
+ スピーカー
634
+ ルーペ
635
+ 製材所
636
+ 磁気コンパス
637
+ メッセンジャーバッグ
638
+ 郵便受け
639
+ タイツ
640
+ ワンピース水着
641
+ マンホールの蓋
642
+ マラカス
643
+ マリンバ
644
+ マスク,仮面
645
+ マッチ棒
646
+ メイポール,五月柱
647
+ 迷路
648
+ 計量カップ
649
+ 薬箱
650
+ 巨石
651
+ マイク
652
+ 電子レンジ
653
+ 軍服
654
+ ミルク缶
655
+ ミニバス
656
+ ミニスカート
657
+ ミニバン
658
+ ミサイル
659
+ ミトン
660
+ ミキシングボウル
661
+ 移動式住宅
662
+ フォード・モデルT
663
+ モデム
664
+ 修道院
665
+ モニター
666
+ モペット
667
+ 乳鉢と乳棒
668
+ 卒業帽
669
+ モスク
670
+ 蚊帳
671
+ スクーター
672
+ マウンテンバイク
673
+ 山のテント
674
+ コンピュータマウス
675
+ ネズミ捕り
676
+ 引っ越しトラック
677
+ 銃口
678
+ 金属釘
679
+ ネックブレース
680
+ ネックレス
681
+ おしゃぶり
682
+ ノートパソコン
683
+ オベリスク
684
+ オーボエ
685
+ オカリナ
686
+ オドメーター
687
+ オイルフィルター
688
+ パイプオルガン
689
+ オシロスコープ
690
+ オーバースカート
691
+ 牛車
692
+ 酸素マスク
693
+ 小包
694
+ パドル
695
+ パドルホイール
696
+ 南京錠
697
+ 絵筆
698
+ パジャマ
699
+ 宮殿
700
+ パンフルート
701
+ ペーパータオル
702
+ パラシュート
703
+ 平行棒
704
+ 公園のベンチ
705
+ パーキングメーター
706
+ 客車,鉄道車両
707
+ パティオ
708
+ 公衆電話
709
+ 台座
710
+ 筆箱
711
+ 鉛筆削り
712
+ 香水
713
+ ペトリ皿
714
+ コピー機
715
+ ピック
716
+ ピッケルハウベ,スパイク付き鉄かぶと
717
+ ピケットフェンス
718
+ ピックアップトラック
719
+ 桟橋
720
+ 貯金箱
721
+ 錠剤瓶
722
+
723
+ ピンポン球
724
+ 風車
725
+ 海賊船
726
+ ピッチャー,水差し
727
+ 角鉋,かんな
728
+ プラネタリウム
729
+ ビニール袋
730
+ 皿立て
731
+ 農耕用プラウ
732
+ プランジャー
733
+ ポラロイドカメラ
734
+ ポール
735
+ 警察車
736
+ ポンチョ
737
+ ビリヤード台
738
+ ソーダボトル
739
+ 植木鉢
740
+ ろくろ
741
+ 電動ドリル
742
+ 礼拝用敷物
743
+ プリンタ
744
+ 刑務所
745
+ ミサイル
746
+ プロジェクター
747
+ ホッケーパック
748
+ サンドバッグ
749
+ がま口,銭入れ
750
+ 羽ペン
751
+ キルト
752
+ レーシングカー
753
+ ラケット
754
+ ラジエーター
755
+ ラジオ,無線
756
+ 電波望遠鏡
757
+ 天水桶
758
+ キャンピングカー
759
+ 釣りリール
760
+ 一眼レフカメラ
761
+ 冷蔵庫
762
+ リモコン
763
+ レストラン
764
+ リボルバー
765
+ ライフル
766
+ ロッキングチェア
767
+ 焼��料理店
768
+ 消しゴム
769
+ ラグビーボール
770
+ 定規
771
+ スニーカー
772
+ 金庫
773
+ 安全ピン
774
+ 塩入れ
775
+ サンダル
776
+ サロン
777
+ サックス
778
+
779
+ 体重計
780
+ スクールバス
781
+ スクーナー
782
+ スコアボード
783
+ CRTモニター
784
+ ねじ,スクリュー
785
+ ドライバー
786
+ シートベルト
787
+ ミシン
788
+
789
+ 靴屋
790
+ 障子
791
+ 買い物かご
792
+ ショッピングカート
793
+ シャベル
794
+ シャワーキャップ
795
+ シャワーカーテン
796
+ スキー
797
+ スキーマスク
798
+ 寝袋
799
+ 計算尺
800
+ 引戸
801
+ スロットマシン
802
+ シュノーケル
803
+ スノーモービル
804
+ 除雪機
805
+ ソープディスペンサー
806
+ サッカーボール
807
+ 靴下
808
+ 太陽炉
809
+ ソンブレロ
810
+ スープ皿
811
+ スペースキー
812
+ スペースヒーター
813
+ スペースシャトル
814
+ スパチュラ,へら
815
+ レース艇,モーターボート
816
+ クモの巣
817
+ 紡錘
818
+ スポーツカー
819
+ スポットライト
820
+ ステージ
821
+ 蒸気機関車
822
+ 通り抜けアーチ橋
823
+ スチールドラム
824
+ 聴診器
825
+ ストール
826
+ 石垣
827
+ ストップウォッチ
828
+ ストーブ
829
+ ろ過器,ストレーナー
830
+ 路面電車
831
+ 担架,ストレッチャー
832
+ カウチ,ソファ
833
+ 仏舎利塔
834
+ 潜水艦
835
+ スーツ
836
+ 日時計,日晷儀,晷針
837
+ サングラス
838
+ サングラス
839
+ 日焼け止め
840
+ 吊り橋
841
+ モップ
842
+ スウェットシャツ,トレーナー
843
+ 海パン
844
+ ブランコ
845
+ スイッチ
846
+ 注射器
847
+ 電気スタンド
848
+ タンク,戦車
849
+ テーププレーヤー
850
+ ティーポット,急須
851
+ テディベア
852
+ テレビ
853
+ テニスボール
854
+ 茅葺屋根
855
+ 劇場のカーテン
856
+ 指ぬき
857
+ 脱穀機
858
+ 玉座
859
+ 瓦屋根
860
+ トースター
861
+ タバコ屋
862
+ 便座
863
+ たいまつ
864
+ トーテムポール
865
+ レッカー車
866
+ 玩具屋
867
+ トラクター
868
+ トレーラートラック
869
+ お盆,トレイ
870
+ トレンチコート
871
+ 三輪車
872
+ トリマラン,三胴船
873
+ 三脚
874
+ 凱旋門
875
+ トロリーバス
876
+ トロンボーン
877
+ バスタブ
878
+ 回転ドア
879
+ タイプライターのキーボード
880
+
881
+ 一輪車
882
+ アップライトピアノ
883
+ 掃除機
884
+ 花瓶
885
+ 丸天井,円蓋
886
+ ベルベット
887
+ 自動販売機
888
+ 祭服,礼服
889
+ 高架橋
890
+ バイオリン
891
+ バレーボール
892
+ ワッフルメーカー
893
+ 壁掛け時計
894
+ 財布
895
+ ワードローブ
896
+ 軍用機
897
+ シンク,洗面器
898
+ ワッシャー,洗濯機
899
+ 水筒
900
+ 水差し
901
+ ウォータータワー,給水塔
902
+ ウイスキージャグ
903
+ ホイッスル
904
+ かつら
905
+ 窓網戸
906
+ ブラインド
907
+ ウィンザーネクタイ
908
+ ワインボトル
909
+ 飛行機の翼
910
+ 中華鍋
911
+ 木製スプーン
912
+ ウール
913
+ ワームフェンス
914
+ 難破船
915
+ 帆船
916
+ ユルト
917
+ ウェブサイト
918
+ 漫画本
919
+ クロスワードパズル
920
+ 道路標識
921
+ 信号機
922
+ ブックカバー
923
+ メニュー
924
+ お皿
925
+ ワカモレ
926
+ コンソメ
927
+ ホットポット,火鍋
928
+ パフェ,トライフル
929
+ アイスクリーム
930
+ アイスキャンディー
931
+ フランスパン
932
+ ベーグル
933
+ プレッツェル
934
+ チーズバーガー
935
+ ホットドッグ
936
+ マッシュポテト
937
+ キャベツ
938
+ ブロッコリー
939
+ カリフラワー
940
+ ズッキーニ
941
+ そうめんかぼちゃ
942
+ ドングリかぼちゃ
943
+ バターナッツかぼちゃ
944
+ キュウリ
945
+ アーティチョーク
946
+ ピーマン
947
+ カルドン
948
+ キノコ
949
+ 青リンゴ
950
+ イチゴ
951
+ オレンジ,ミカン
952
+ レモン,檸檬
953
+ イチジク
954
+ パイナップル
955
+ バナナ
956
+ ジャックフルーツ,パラミツ
957
+ カスタードアップル
958
+ ザクロ
959
+ 干し草
960
+ カルボナーラ
961
+ チョコレートソース
962
+ 生地
963
+ ミートローフ
964
+ ピザ
965
+ ポットパイ
966
+ ブリトー
967
+ 赤ワイン
968
+ エスプレッソ
969
+ カップ
970
+ エッグノッグ
971
+
972
+
973
+
974
+ サンゴ礁
975
+ 間欠泉
976
+ 湖畔
977
+
978
+ 砂州
979
+ 海岸
980
+
981
+ 火山
982
+ 野球選手
983
+ 婿,新郎
984
+ スキューバダイバー
985
+ 菜種
986
+ デイジー,ヒナギク,雛菊
987
+ パフィオペディラム
988
+ コーン,トウキビ,トウモロコシ
989
+ ドングリ
990
+ ローズヒップ
991
+ セイヨウトチノキ
992
+ ホウキタケ
993
+ ハラタケ
994
+ シャグマアミガサタケ
995
+ スッポンタケ
996
+ ツチグリ
997
+ マイタケ
998
+ ヤマドリタケ
999
+ トウモロコシの穂・芯
1000
+ ちり紙,トイレットペーパー
ja-imagenet-1k-templates.txt ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {c}の悪い写真
2
+ 多くの{c}の写真
3
+ {c}の彫刻
4
+ 見づらい{c}の写真
5
+ {c}の低解像度写真
6
+ {c}のレンダリング
7
+ {c}の落書き
8
+ {c}のトリミング写真
9
+ {c}のタトゥー
10
+ 刺繍された{c}
11
+ {c}の明るい写真
12
+ きれいな{c}の写真
13
+ 汚れた{c}の写真
14
+ {c}の暗い写真
15
+ {c}の絵
16
+ 私の{c}の写真
17
+ プラスチック製の{c}
18
+ かっこいい{c}の写真
19
+ {c}のクローズアップ写真
20
+ {c}の白黒写真
21
+ {c}のピクセル写真
22
+ jpegで加工した{c}の写真
23
+ {c}のぼやけた写真
24
+ {c}の写真
25
+ {c}の良い写真
26
+ ゲームに登場する{c}
27
+ 折り紙で作った{c}
28
+ {c}のスケッチ
29
+ おもちゃの{c}
30
+ {c}の演出
31
+ 大きな{c}の写真
32
+ 素敵な{c}の写真
33
+ 奇妙な{c}の写真
34
+ 漫画の{c}
35
+ {c}の芸術
36
+ {c}のぬいぐるみ
37
+ 小さな{c}の写真
model.py ADDED
@@ -0,0 +1,505 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+
3
+ # Copyright 2024 LY Corporation.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from __future__ import annotations
17
+
18
+ import logging
19
+ from typing import Optional, Union
20
+
21
+ import timm
22
+ import torch
23
+ import torch.distributed as dist
24
+ import torch.distributed.nn
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+ from timm.models.swin_transformer import SwinTransformer as TimmSwinTransformer
28
+ from transformers import PreTrainedModel
29
+ from transformers.utils.logging import get_logger
30
+
31
+ from .configuration_clyp import (
32
+ CLYPTextBackboneConfig,
33
+ CLYPTextEncoderConfig,
34
+ CLYPVisionBackboneConfig,
35
+ CLYPVisionEncoderConfig,
36
+ )
37
+ from .model_rinna import RinnaCLIPConfig, RinnaCLIPModel
38
+
39
+ DEFAULT_LOGGER = get_logger(__name__)
40
+
41
+
42
+ class VisionEncoder(nn.Module):
43
+ """Vision encoder to extract image feateurs.
44
+
45
+ Pooler and neck are optional.
46
+ Instead of defining pooler and neck in VisionEncoder, you can define them in algorithm classes.
47
+
48
+ Attributes:
49
+ backbone (nn.Module): backbone loaded from timm, huggingface or registry.
50
+ pooler (nn.Module): module to extract image-level features.
51
+ neck (nn.Module): module to adjust feature dimensions.
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ backbone: nn.Module,
57
+ pooler: Optional[nn.Module] = None,
58
+ neck: Optional[nn.Module] = None,
59
+ ) -> None:
60
+ super().__init__()
61
+ self.backbone = backbone
62
+ self.pooler = pooler
63
+ self.neck = neck
64
+
65
+ def forward(self, imgs: torch.Tensor):
66
+ """A method to extract image features.
67
+
68
+ Args:
69
+ imgs (torch.Tensor): shape=(batch_size, channels, height, width).
70
+
71
+ Returns:
72
+ out (torch.Tensor): the output shape changes depending on pooler, and the following shapes are usually expected.
73
+ - output only image-level features like CLIP: shape=(batch_size, embed_dim)
74
+ - output image-level and local patch features like BLIP2: shape=(batch_size, embed_dim, length)
75
+ """
76
+ out = self.backbone(imgs) # Shape=(batch_size, channels, height, width)
77
+ if self.pooler:
78
+ out = self.pooler(out)
79
+ if self.neck:
80
+ out = self.neck(out)
81
+ return out
82
+
83
+
84
+ class SwinTransformerPerm(nn.Module):
85
+ """Wrapper for SwinTransformer in timm.
86
+
87
+ This wrapper changes the output shape to (batch_size, channels, height, width).
88
+ The original shape of timm SwinTransformer is (batch_size, height, width, channels).
89
+ """
90
+
91
+ def __init__(self, swin: nn.Module) -> None:
92
+ super().__init__()
93
+ self.swin = swin
94
+
95
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
96
+ out = self.swin(x)
97
+ out = out.permute(0, 3, 1, 2)
98
+ return out
99
+
100
+
101
+ def load_from_timm(
102
+ config: CLYPVisionBackboneConfig,
103
+ use_gradient_checkpointing: bool = False,
104
+ path_weights: Optional[str] = None,
105
+ logger: logging.Logger = DEFAULT_LOGGER,
106
+ ):
107
+ """Create a backbone using a method: timm.create_model.
108
+
109
+ Args:
110
+ config (TimmBackboneConfig): config fed to timm.create_model.
111
+ use_gradient_checkpointing (bool): True if use gradient checkpointing.
112
+ path_weights (str): path to weights for backbone initialization.
113
+ """
114
+ # backbone
115
+ assert config is not None
116
+ backbone = timm.create_model(
117
+ model_name=config.model_name,
118
+ pretrained=config.pretrained,
119
+ **config.extra_kwargs,
120
+ )
121
+ backbone.reset_classifier(0, "")
122
+
123
+ logger.info(
124
+ f" - load from timm: model_name={config.model_name}, pretrained={config.pretrained}"
125
+ )
126
+
127
+ # gradient checkpointing
128
+ backbone.set_grad_checkpointing(enable=use_gradient_checkpointing)
129
+ if use_gradient_checkpointing:
130
+ logger.info(" - gradient checkpointing is enebled.")
131
+
132
+ # init weights
133
+ if path_weights:
134
+ state_dict = torch.load(path_weights, map_location="cpu")
135
+ checks = backbone.load_state_dict(state_dict, strict=False)
136
+ logger.info(f" - load weights from {path_weights}")
137
+ logger.info(f" - state dict checks: {checks}")
138
+
139
+ # swin
140
+ if isinstance(backbone, TimmSwinTransformer):
141
+ backbone = SwinTransformerPerm(backbone)
142
+ return backbone
143
+
144
+
145
+ def create_vision_encoder(
146
+ config: CLYPVisionEncoderConfig, logger: logging.Logger = DEFAULT_LOGGER
147
+ ) -> VisionEncoder:
148
+ assert config.pooler_config.input_type
149
+ backbone = load_from_timm(config.backbone_config, logger=logger)
150
+ pooler = CLSTokenPooling(
151
+ config.pooler_config.input_type, config.pooler_config.return_patch_features
152
+ )
153
+ neck = Linear(
154
+ config.neck_config.in_channels,
155
+ config.neck_config.out_channels,
156
+ config.neck_config.bias,
157
+ )
158
+ return VisionEncoder(backbone, pooler=pooler, neck=neck)
159
+
160
+
161
+ class TextEncoder(nn.Module):
162
+ """Text encoder to extract text features.
163
+
164
+ Pooler and neck are optional.
165
+ Instead of defining pooler and neck in TextEncoder, you can define them in algorithm classes.
166
+
167
+ Attributes:
168
+ backbone (nn.Module): backbone loaded from timm, huggingface or registry.
169
+ pooler (nn.Module): module to extract image-level features.
170
+ neck (nn.Module): module to adjust feature dimensions.
171
+
172
+ """
173
+
174
+ def __init__(
175
+ self,
176
+ backbone: nn.Module,
177
+ pooler: Optional[nn.Module] = None,
178
+ neck: Optional[nn.Module] = None,
179
+ ) -> None:
180
+ super().__init__()
181
+ self.backbone = backbone
182
+ self.pooler = pooler
183
+ self.neck = neck
184
+
185
+ def forward(self, inputs: dict) -> torch.Tensor:
186
+ """A method to extract text features.
187
+
188
+ Args:
189
+ inputs (dict): basic keys are shown below:
190
+ - input_ids (torch.Tensor)
191
+ - attention_mask (Optional[torch.Tensor])
192
+ - position_ids (Optional[torch.Tensor])
193
+ - token_type_ids (Optional[torch.Tensor])
194
+ - output_attentions Optional[bool]
195
+ - output_hidden_states Optional[bool]
196
+
197
+ Returns:
198
+ out (torch.Tensor): the output shape changes depending on pooler, and the following shapes are usually expected.
199
+ - output only class token like CLIP: shape=(batch_size, embed_dim)
200
+ - output all token features like BLIP2: shape=(batch_size, embed_dim, length)
201
+ """
202
+ out = self.backbone(**inputs)
203
+ if self.pooler:
204
+ out = self.pooler(out)
205
+ if self.neck:
206
+ out = self.neck(out)
207
+ return out
208
+
209
+
210
+ class TextBackboneModelWrapper(nn.Module):
211
+ def __init__(self, model: nn.Module) -> None:
212
+ super().__init__()
213
+ self.model = model.text_model
214
+
215
+ def forward(
216
+ self,
217
+ input_ids: Optional[torch.Tensor] = None,
218
+ attention_mask: Optional[torch.Tensor] = None,
219
+ position_ids: Optional[torch.Tensor] = None,
220
+ token_type_ids: Optional[torch.Tensor] = None,
221
+ ) -> torch.Tensor:
222
+ out = self.model(
223
+ input_ids=input_ids,
224
+ attention_mask=attention_mask,
225
+ position_ids=position_ids,
226
+ token_type_ids=token_type_ids,
227
+ )
228
+ return out
229
+
230
+ def set_gradient_checkpointing(self, enabled: bool) -> None:
231
+ if enabled:
232
+ self.model.gradient_checkpointing_enable()
233
+
234
+
235
+ def load_from_huggingface(
236
+ config: CLYPTextBackboneConfig,
237
+ use_gradient_checkpointing: bool = False,
238
+ path_weights: Optional[str] = None,
239
+ logger: logging.Logger = DEFAULT_LOGGER,
240
+ ) -> nn.Module:
241
+ """Load a backbone from huggingface.
242
+
243
+ Args:
244
+ config (HuggingfaceBackboneConfig): config fed to AutoModel.from_pretrained.
245
+ use_gradient_checkpointing (bool): True if use gradient checkpointing.
246
+ path_weights (str): path to weights for backbone initialization.
247
+ """
248
+
249
+ # NOTE:
250
+ # Initialize Rinna CLIP without pretrained weights here,
251
+ # because CLYP model loads its whole weights afterward
252
+ auto_config = RinnaCLIPConfig.from_pretrained(config.model_name)
253
+ backbone = RinnaCLIPModel(auto_config)
254
+
255
+ logger.info(f" - load from huggingface: model_name={config.model_name}")
256
+
257
+ # gradient checkpointing
258
+ if isinstance(backbone, PreTrainedModel):
259
+ if use_gradient_checkpointing:
260
+ backbone.gradient_checkpointing_enable()
261
+ logger.info(" - gradient checkpointing is enabled")
262
+ else:
263
+ raise NotImplementedError()
264
+
265
+ # init weights
266
+ if path_weights:
267
+ raise NotImplementedError()
268
+ return backbone
269
+
270
+
271
+ def create_text_encoder(
272
+ config: CLYPTextEncoderConfig, logger: logging.Logger = DEFAULT_LOGGER
273
+ ) -> TextEncoder:
274
+ assert config.pooler_config.input_type
275
+ backbone = TextBackboneModelWrapper(
276
+ load_from_huggingface(config.backbone_config, logger=logger)
277
+ )
278
+ pooler = CLSTokenPooling(
279
+ config.pooler_config.input_type, config.pooler_config.return_patch_features
280
+ )
281
+ neck = Linear(
282
+ config.neck_config.in_channels,
283
+ config.neck_config.out_channels,
284
+ bias=config.neck_config.bias,
285
+ )
286
+ return TextEncoder(backbone, pooler=pooler, neck=neck)
287
+
288
+
289
+ class Linear(nn.Module):
290
+ """Linear layer."""
291
+
292
+ def __init__(self, in_channels: int, out_channels: int, bias: bool) -> None:
293
+ """
294
+ Args:
295
+ in_channels (int): input feature dimension.
296
+ out_channels (out): output feature dimension.
297
+ bias (bool): True if use bias in nn.Linear.
298
+ """
299
+ super().__init__()
300
+ self.linear = nn.Linear(in_channels, out_channels, bias=bias)
301
+
302
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
303
+ """
304
+ Args:
305
+ x (torch.Tensor): shape=(batch_size, ..., in_channels).
306
+
307
+ Returns:
308
+ out (torch.Tensor): shape=(batch_size, ..., out_channels).
309
+ """
310
+ out = self.linear(x)
311
+ return out
312
+
313
+
314
+ class CLSTokenPooling(nn.Module):
315
+ """A module to extract class token."""
316
+
317
+ def __init__(self, input_type: str, return_patch_features: bool) -> None:
318
+ """
319
+ Args:
320
+ input_type (str): timm or huggingface.
321
+ - If input_type is timm, x[:, 0] is extracted as a class token.
322
+ - If input_type is huggingface, x.last_hidden_state[:,0] is extracted as a class token.
323
+ return_patch_features (bool): True if output local features.
324
+ """
325
+ super().__init__()
326
+ assert input_type in ["timm", "huggingface"]
327
+ self.input_type = input_type
328
+ self.return_patch_features = return_patch_features
329
+
330
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
331
+ """
332
+ Args:
333
+ x (torch.Tensor): shape=(batch_size, length, dim).
334
+
335
+ Returns:
336
+ out (torch.Tensor): shape=(batch_size, dim).
337
+ """
338
+ # tensor: shape=(batch_size, length, dim)
339
+ if self.input_type == "timm":
340
+ assert x.ndim == 3, "CLSTokenPooling: dimension of input tensor must be 3."
341
+ if self.return_patch_features:
342
+ return x
343
+ else:
344
+ return x[:, 0]
345
+
346
+ # huggingface
347
+ elif self.input_type == "huggingface":
348
+ out = x.last_hidden_state
349
+ if self.return_patch_features:
350
+ return out
351
+ else:
352
+ return out[:, 0]
353
+
354
+
355
+ class InfoNCELoss(nn.Module):
356
+ def __init__(
357
+ self,
358
+ learn_temperature: bool,
359
+ init_temperature: float,
360
+ max_temperature: Optional[float] = None,
361
+ min_temperature: Optional[float] = None,
362
+ label_smoothing: float = 0.0,
363
+ gather_with_grad: bool = False,
364
+ ):
365
+ super().__init__()
366
+ self.label_smoothing = label_smoothing
367
+ self.gather_with_grad = gather_with_grad
368
+
369
+ # set temperature
370
+ self.learn_temperature = learn_temperature
371
+ self.temperature = torch.ones([]) * init_temperature
372
+ if self.learn_temperature:
373
+ self.temperature = nn.Parameter(self.temperature)
374
+ self.max_temperature = max_temperature
375
+ self.min_temperature = min_temperature
376
+
377
+ # whether clip temperature or not
378
+ self.require_temperature_clipping = self.learn_temperature and (
379
+ self.max_temperature or self.min_temperature
380
+ )
381
+
382
+ def clip_temperature(self):
383
+ if self.require_temperature_clipping:
384
+ self.temperature.data = torch.clamp(
385
+ self.temperature, self.min_temperature, self.max_temperature
386
+ )
387
+
388
+ def forward(
389
+ self,
390
+ image_feats: torch.Tensor,
391
+ text_feats: torch.Tensor,
392
+ return_similarity: bool = False,
393
+ ) -> Union[torch.Tensor, tuple[torch.Tensor]]:
394
+ # gather image and text features
395
+ image_feats_all = concat_all_gather(
396
+ image_feats, with_grad=self.gather_with_grad
397
+ )
398
+ text_feats_all = concat_all_gather(text_feats, with_grad=self.gather_with_grad)
399
+
400
+ # compute cosine similarity
401
+ sim_i2t = image_to_text_similarity(
402
+ image_feats=image_feats,
403
+ text_feats=text_feats_all,
404
+ )
405
+ sim_t2i = text_to_image_similarity(
406
+ text_feats=text_feats,
407
+ image_feats=image_feats_all,
408
+ )
409
+
410
+ # logits, scaled cosine similarity
411
+ logits_i2t = sim_i2t / self.temperature
412
+ logits_t2i = sim_t2i / self.temperature
413
+
414
+ # obtain targets
415
+ rank = dist.get_rank()
416
+ batch_size = image_feats.size(0)
417
+ targets = torch.arange(batch_size) + batch_size * rank
418
+ targets = targets.to(dtype=torch.long, device=image_feats.device)
419
+
420
+ # calculate loss
421
+ loss_i2t = F.cross_entropy(
422
+ logits_i2t, targets, label_smoothing=self.label_smoothing
423
+ )
424
+ loss_t2i = F.cross_entropy(
425
+ logits_t2i, targets, label_smoothing=self.label_smoothing
426
+ )
427
+ loss = (loss_i2t + loss_t2i) / 2.0
428
+
429
+ if not return_similarity:
430
+ return loss
431
+ else:
432
+ return loss, sim_i2t, sim_t2i
433
+
434
+
435
+ def image_to_text_similarity(
436
+ image_feats: torch.Tensor, text_feats: torch.Tensor
437
+ ) -> torch.Tensor:
438
+ """
439
+ Args:
440
+ image_feats (torch.Tensor): shape=(num_imgs, embed_dim) or (num_imgs, num_query_tokens, embed_dim).
441
+ text_feats (torch.Tensor): shape=(num_texts, embed_dim).
442
+
443
+ Returns:
444
+ sim_i2t (torch.Tensor): shape=(num_imgs, num_texts).
445
+ """
446
+ assert image_feats.ndim in [2, 3]
447
+ assert text_feats.ndim == 2
448
+
449
+ # normalize features
450
+ image_feats = F.normalize(image_feats, dim=-1)
451
+ text_feats = F.normalize(text_feats, dim=-1)
452
+
453
+ if image_feats.ndim == 2:
454
+ sim_i2t = image_feats @ text_feats.T
455
+ else:
456
+ # a query token with maximum cosine similarity is selected
457
+ sim_i2t = torch.matmul(
458
+ image_feats.unsqueeze(1), text_feats.unsqueeze(0).unsqueeze(-1)
459
+ ).squeeze() # shape=(num_imgs, num_texts, num_query_tokens)
460
+ sim_i2t, _ = sim_i2t.max(dim=-1) # shape=(num_imgs, num_texts)
461
+ return sim_i2t
462
+
463
+
464
+ def text_to_image_similarity(text_feats: torch.Tensor, image_feats: torch.Tensor):
465
+ """
466
+ Args:
467
+ text_feats (torch.Tensor): shape=(num_texts, embed_dim).
468
+ image_feats (torch.Tensor): shape=(num_imgs, embed_dim) or (num_imgs, num_query_tokens, embed_dim).
469
+
470
+ Returns:
471
+ similarity_maxtrix (torch.Tensor): shape=(num_texts, num_imgs).
472
+ """
473
+ assert image_feats.ndim in [2, 3]
474
+ assert text_feats.ndim == 2
475
+
476
+ # normalize features
477
+ image_feats = F.normalize(image_feats, dim=-1)
478
+ text_feats = F.normalize(text_feats, dim=-1)
479
+
480
+ if image_feats.ndim == 2:
481
+ sim_t2i = text_feats @ image_feats.T
482
+ else:
483
+ # a query token with maximum cosine similarity is selected
484
+ sim_t2i = torch.matmul(
485
+ text_feats.unsqueeze(1).unsqueeze(1),
486
+ image_feats.permute(0, 2, 1).unsqueeze(0),
487
+ ).squeeze()
488
+ sim_t2i, _ = sim_t2i.max(dim=-1)
489
+ return sim_t2i
490
+
491
+
492
+ def concat_all_gather(tensor: torch.Tensor, with_grad: bool):
493
+ """
494
+ Performs all_gather operation on the provided tensors.
495
+ *** Warning ***: torch.distributed.all_gather has no gradient.
496
+
497
+ Another implementation: https://github.com/salesforce/LAVIS/blob/main/lavis/models/base_model.py#L202-L237
498
+ """
499
+ if with_grad:
500
+ output = torch.cat(torch.distributed.nn.all_gather(tensor), dim=0)
501
+ else:
502
+ tensors_gather = [torch.ones_like(tensor) for _ in range(dist.get_world_size())]
503
+ dist.all_gather(tensors_gather, tensor, async_op=False)
504
+ output = torch.cat(tensors_gather, dim=0)
505
+ return output
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f0997d294a0723358c5622fc51caa0b8589de2d36295b1ff40cfa11f9c9f8e9c
3
+ size 786788708
model_rinna.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+
3
+ # Copyright 2024 LY Corporation.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ # Almost copied from https://github.com/rinnakk/japanese-clip/blob/master/src/japanese_clip/clip/modeling_clip.py
18
+ # This code is distributed under the Apache License 2.0.
19
+ from __future__ import annotations
20
+
21
+ import copy
22
+ from typing import Optional
23
+
24
+ import torch
25
+ import torch.distributed.nn
26
+ import torch.nn as nn
27
+ from transformers import AutoConfig, AutoModel, PreTrainedModel
28
+ from transformers.configuration_utils import PretrainedConfig
29
+ from transformers.models.clip import (
30
+ CLIPVisionConfig,
31
+ CLIPVisionModel,
32
+ )
33
+ from transformers.models.clip.modeling_clip import CLIPOutput
34
+ from transformers.utils import logging
35
+
36
+ logger = logging.get_logger(__name__)
37
+
38
+
39
+ # Copied from transformers.models.clip.modeling_clip.contrastive_loss
40
+ def contrastive_loss(logits: torch.Tensor) -> torch.Tensor:
41
+ return nn.functional.cross_entropy(
42
+ logits, torch.arange(len(logits), device=logits.device)
43
+ )
44
+
45
+
46
+ # Copied from transformers.models.clip.modeling_clip.clip_loss
47
+ def clip_loss(similarity: torch.Tensor) -> torch.Tensor:
48
+ caption_loss = contrastive_loss(similarity)
49
+ image_loss = contrastive_loss(similarity.T)
50
+ return (caption_loss + image_loss) / 2.0
51
+
52
+
53
+ class RinnaCLIPConfig(PretrainedConfig):
54
+ model_type = "clip"
55
+ is_composition = True
56
+
57
+ def __init__(self, projection_dim=512, logit_scale_init_value=2.6592, **kwargs):
58
+ super().__init__(**kwargs)
59
+
60
+ if "vision_config" not in kwargs:
61
+ raise ValueError("`vision_config` can not be `None`.")
62
+
63
+ if "text_config" not in kwargs:
64
+ raise ValueError("`text_config` can not be `None`.")
65
+
66
+ vision_config = kwargs.pop("vision_config")
67
+ text_config = kwargs.pop("text_config")
68
+
69
+ vision_model_type = vision_config.pop("model_type")
70
+ text_model_type = text_config.pop("model_type")
71
+
72
+ if vision_model_type == "clip":
73
+ self.vision_config = AutoConfig.for_model(
74
+ vision_model_type, **vision_config
75
+ ).vision_config
76
+ elif vision_model_type == "clip_vision_model":
77
+ self.vision_config = CLIPVisionConfig(**vision_config)
78
+ else:
79
+ self.vision_config = AutoConfig.for_model(
80
+ vision_model_type, **vision_config
81
+ )
82
+
83
+ self.text_config = AutoConfig.for_model(text_model_type, **text_config)
84
+
85
+ self.projection_dim = projection_dim
86
+ self.logit_scale_init_value = logit_scale_init_value
87
+
88
+ @classmethod
89
+ def from_vision_text_configs(
90
+ cls, vision_config: PretrainedConfig, text_config: PretrainedConfig, **kwargs
91
+ ):
92
+ r"""
93
+ Instantiate a [`VisionTextDualEncoderConfig`] (or a derived class) from text model configuration and vision
94
+ model configuration.
95
+
96
+ Returns:
97
+ [`VisionTextDualEncoderConfig`]: An instance of a configuration object
98
+ """
99
+
100
+ return cls(
101
+ vision_config=vision_config.to_dict(),
102
+ text_config=text_config.to_dict(),
103
+ **kwargs,
104
+ )
105
+
106
+ def to_dict(self):
107
+ """
108
+ Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
109
+
110
+ Returns:
111
+ `Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
112
+ """
113
+ output = copy.deepcopy(self.__dict__)
114
+ output["vision_config"] = self.vision_config.to_dict()
115
+ output["text_config"] = self.text_config.to_dict()
116
+ output["model_type"] = self.__class__.model_type
117
+ return output
118
+
119
+
120
+ class RinnaCLIPModel(PreTrainedModel):
121
+ config_class = RinnaCLIPConfig
122
+ base_model_prefix = "clip"
123
+
124
+ def __init__(
125
+ self,
126
+ config: Optional[RinnaCLIPConfig] = None,
127
+ vision_model: Optional[PreTrainedModel] = None,
128
+ text_model: Optional[PreTrainedModel] = None,
129
+ ):
130
+ if config is None and (vision_model is None or text_model is None):
131
+ raise ValueError(
132
+ "Either a configuration or an vision and a text model has to be provided"
133
+ )
134
+
135
+ if config is None:
136
+ config = RinnaCLIPConfig.from_vision_text_configs(
137
+ vision_model.config,
138
+ text_model.config, # type: ignore[union-attr]
139
+ )
140
+ else:
141
+ if not isinstance(config, self.config_class):
142
+ raise ValueError(
143
+ f"config: {config} has to be of type {self.config_class}"
144
+ )
145
+
146
+ # initialize with config
147
+ super().__init__(config)
148
+
149
+ if vision_model is None:
150
+ if isinstance(config.vision_config, CLIPVisionConfig):
151
+ vision_model = CLIPVisionModel(
152
+ config.vision_config, add_pooling_layer=False
153
+ )
154
+ else:
155
+ vision_model = AutoModel.from_config(
156
+ config.vision_config, add_pooling_layer=False
157
+ )
158
+
159
+ if text_model is None:
160
+ text_model = AutoModel.from_config(
161
+ config.text_config, add_pooling_layer=False
162
+ )
163
+
164
+ self.vision_model = vision_model
165
+ self.text_model = text_model
166
+
167
+ # make sure that the individual model's config refers to the shared config
168
+ # so that the updates to the config will be synced
169
+ self.vision_model.config = self.config.vision_config
170
+ self.text_model.config = self.config.text_config
171
+
172
+ self.vision_embed_dim = config.vision_config.hidden_size
173
+ self.text_embed_dim = config.text_config.hidden_size
174
+ self.projection_dim = config.projection_dim
175
+
176
+ self.visual_projection = nn.Linear(
177
+ self.vision_embed_dim, self.projection_dim, bias=False
178
+ )
179
+ self.text_projection = nn.Linear(
180
+ self.text_embed_dim, self.projection_dim, bias=False
181
+ )
182
+ self.logit_scale = nn.Parameter(
183
+ torch.ones([]) * self.config.logit_scale_init_value
184
+ )
185
+
186
+ def get_text_features(
187
+ self,
188
+ input_ids=None,
189
+ attention_mask=None,
190
+ position_ids=None,
191
+ token_type_ids=None,
192
+ output_attentions=None,
193
+ output_hidden_states=None,
194
+ return_dict=None,
195
+ out=False,
196
+ ):
197
+ text_outputs = self.text_model(
198
+ input_ids=input_ids,
199
+ attention_mask=attention_mask,
200
+ position_ids=position_ids,
201
+ token_type_ids=token_type_ids,
202
+ output_attentions=output_attentions,
203
+ output_hidden_states=output_hidden_states,
204
+ return_dict=return_dict,
205
+ )
206
+ pooled_output = text_outputs.last_hidden_state[:, 0, :]
207
+ text_features = self.text_projection(pooled_output)
208
+ if out:
209
+ return text_features, text_outputs
210
+ return text_features
211
+
212
+ def get_image_features(
213
+ self,
214
+ pixel_values=None,
215
+ output_attentions=None,
216
+ output_hidden_states=None,
217
+ return_dict=None,
218
+ ):
219
+ vision_outputs = self.vision_model(
220
+ pixel_values=pixel_values,
221
+ output_attentions=output_attentions,
222
+ output_hidden_states=output_hidden_states,
223
+ return_dict=return_dict,
224
+ )
225
+
226
+ pooled_output = vision_outputs.last_hidden_state[:, 0, :]
227
+ image_features = self.visual_projection(pooled_output)
228
+
229
+ return image_features
230
+
231
+ def forward(
232
+ self,
233
+ input_ids=None,
234
+ pixel_values=None,
235
+ attention_mask=None,
236
+ position_ids=None,
237
+ return_loss=None,
238
+ token_type_ids=None,
239
+ output_attentions=None,
240
+ output_hidden_states=None,
241
+ return_dict=None,
242
+ ):
243
+ return_dict = (
244
+ return_dict if return_dict is not None else self.config.return_dict
245
+ )
246
+
247
+ vision_outputs = self.vision_model(
248
+ pixel_values=pixel_values,
249
+ output_attentions=output_attentions,
250
+ output_hidden_states=output_hidden_states,
251
+ return_dict=return_dict,
252
+ )
253
+
254
+ text_outputs = self.text_model(
255
+ input_ids=input_ids,
256
+ attention_mask=attention_mask,
257
+ token_type_ids=token_type_ids,
258
+ position_ids=position_ids,
259
+ output_attentions=output_attentions,
260
+ output_hidden_states=output_hidden_states,
261
+ return_dict=return_dict,
262
+ )
263
+ image_embeds = vision_outputs.last_hidden_state[:, 0, :]
264
+ image_embeds = self.visual_projection(image_embeds)
265
+
266
+ text_embeds = text_outputs.last_hidden_state[:, 0, :]
267
+ text_embeds = self.text_projection(text_embeds)
268
+
269
+ # normalized features
270
+ image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
271
+ text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)
272
+
273
+ # cosine similarity as logits
274
+ logit_scale = self.logit_scale.exp()
275
+ # logit_scale = self.logit_scale
276
+ logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
277
+ logits_per_image = logits_per_text.T
278
+
279
+ loss = None
280
+ if return_loss:
281
+ loss = clip_loss(logits_per_text)
282
+
283
+ if not return_dict:
284
+ output = (
285
+ logits_per_image,
286
+ logits_per_text,
287
+ text_embeds,
288
+ image_embeds,
289
+ text_outputs,
290
+ vision_outputs,
291
+ )
292
+ return ((loss,) + output) if loss is not None else output
293
+
294
+ return CLIPOutput(
295
+ loss=loss,
296
+ logits_per_image=logits_per_image,
297
+ logits_per_text=logits_per_text,
298
+ text_embeds=text_embeds,
299
+ image_embeds=image_embeds,
300
+ text_model_output=text_outputs,
301
+ vision_model_output=vision_outputs,
302
+ )
303
+
304
+ @classmethod
305
+ def from_pretrained(cls, *args, **kwargs):
306
+ # At the moment fast initialization is not supported
307
+ # for composite models
308
+ kwargs["_fast_init"] = False
309
+ return super().from_pretrained(*args, **kwargs)
310
+
311
+ @classmethod
312
+ def from_vision_text_pretrained(
313
+ cls,
314
+ vision_model_name_or_path: Optional[str] = None,
315
+ text_model_name_or_path: Optional[str] = None,
316
+ *model_args,
317
+ **kwargs,
318
+ ) -> PreTrainedModel:
319
+ kwargs_vision = {
320
+ argument[len("vision_") :]: value
321
+ for argument, value in kwargs.items()
322
+ if argument.startswith("vision_")
323
+ }
324
+
325
+ kwargs_text = {
326
+ argument[len("text_") :]: value
327
+ for argument, value in kwargs.items()
328
+ if argument.startswith("text_")
329
+ }
330
+
331
+ # remove vision, text kwargs from kwargs
332
+ for key in kwargs_vision.keys():
333
+ del kwargs["vision_" + key]
334
+ for key in kwargs_text.keys():
335
+ del kwargs["text_" + key]
336
+
337
+ # Load and initialize the vision and text model
338
+ vision_model = kwargs_vision.pop("model", None)
339
+ if vision_model is None:
340
+ if vision_model_name_or_path is None:
341
+ raise ValueError(
342
+ "If `vision_model` is not defined as an argument, a `vision_model_name_or_path` has to be defined"
343
+ )
344
+
345
+ if "config" not in kwargs_vision:
346
+ vision_config = AutoConfig.from_pretrained(vision_model_name_or_path)
347
+
348
+ if vision_config.model_type == "clip":
349
+ kwargs_vision["config"] = vision_config.vision_config
350
+ vision_model = CLIPVisionModel.from_pretrained(
351
+ vision_model_name_or_path,
352
+ add_pooling_layer=False,
353
+ *model_args,
354
+ **kwargs_vision,
355
+ )
356
+ # TODO: Should we use the pre-trained projection as well ?
357
+ else:
358
+ kwargs_vision["config"] = vision_config
359
+ vision_model = AutoModel.from_pretrained(
360
+ vision_model_name_or_path,
361
+ add_pooling_layer=False,
362
+ *model_args,
363
+ **kwargs_vision,
364
+ )
365
+
366
+ text_model = kwargs_text.pop("model", None)
367
+ if text_model is None:
368
+ if text_model_name_or_path is None:
369
+ raise ValueError(
370
+ "If `text_model` is not defined as an argument, a `text_model_name_or_path` has to be defined"
371
+ )
372
+
373
+ if "config" not in kwargs_text:
374
+ text_config = AutoConfig.from_pretrained(text_model_name_or_path)
375
+ kwargs_text["config"] = text_config
376
+
377
+ text_model = AutoModel.from_pretrained(
378
+ text_model_name_or_path,
379
+ add_pooling_layer=False,
380
+ *model_args,
381
+ **kwargs_text,
382
+ )
383
+
384
+ # instantiate config with corresponding kwargs
385
+ config = RinnaCLIPConfig.from_vision_text_configs(
386
+ vision_model.config, text_model.config, **kwargs
387
+ )
388
+
389
+ # init model
390
+ model = cls(config=config, vision_model=vision_model, text_model=text_model)
391
+
392
+ # the projection layers are always newly initialized when loading the model
393
+ # using pre-trained vision and text model.
394
+ # logger.warning(
395
+ # "The projection layer and logit scale weights `['visual_projection.weight', 'text_projection.weight', 'logit_scale']` "
396
+ # "are newly initialized. You should probably TRAIN this model on a down-stream task "
397
+ # "to be able to use it for predictions and inference."
398
+ # )
399
+
400
+ return model
modeling_clyp.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+
3
+ # Copyright 2024 LY Corporation.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from __future__ import annotations
17
+
18
+ from dataclasses import dataclass
19
+ from typing import Any, Optional
20
+
21
+ import torch
22
+ import torch.nn.functional as F
23
+ from transformers import PreTrainedModel
24
+ from transformers.models.clip.modeling_clip import CLIPOutput
25
+
26
+ from .configuration_clyp import CLYPConfig, CLYPLossConfig
27
+ from .model import InfoNCELoss, create_text_encoder, create_vision_encoder
28
+ from .model_rinna import RinnaCLIPModel # noqa
29
+
30
+
31
+ @dataclass
32
+ class CLYPOutput(CLIPOutput):
33
+ ...
34
+
35
+
36
+ class CLYPPreTrainedModel(PreTrainedModel):
37
+ config_class = CLYPConfig
38
+
39
+ def __init__(self, *args, **kwargs):
40
+ super().__init__(*args, **kwargs)
41
+
42
+ def _init_weights(self, module: Any) -> None:
43
+ pass
44
+
45
+
46
+ class CLYPModel(CLYPPreTrainedModel):
47
+ def __init__(self, config: CLYPConfig):
48
+ super().__init__(config)
49
+ self.vision_encoder = create_vision_encoder(config.vision_encoder_config)
50
+ self.text_encoder = create_text_encoder(config.text_encoder_config)
51
+ self.initialize_clip(
52
+ learn_temperature=config.learn_temperature,
53
+ temperature_init=config.temperature_init,
54
+ temperature_min=config.temperature_min,
55
+ temperature_max=config.temperature_max,
56
+ itc_loss_config=config.itc_loss_config,
57
+ )
58
+
59
+ def initialize_clip(
60
+ self,
61
+ learn_temperature: Optional[bool] = None,
62
+ temperature_init: Optional[float] = None,
63
+ temperature_min: Optional[float] = None,
64
+ temperature_max: Optional[float] = None,
65
+ itc_loss_config: Optional[CLYPLossConfig] = None,
66
+ ) -> None:
67
+ # create contrastive loss function
68
+ if itc_loss_config:
69
+ raise NotImplementedError
70
+ else:
71
+ assert learn_temperature is not None
72
+ assert temperature_init is not None
73
+ self.itc_loss_fn = InfoNCELoss(
74
+ learn_temperature=learn_temperature,
75
+ init_temperature=temperature_init,
76
+ max_temperature=temperature_max,
77
+ min_temperature=temperature_min,
78
+ gather_with_grad=True,
79
+ )
80
+
81
+ def forward(
82
+ self,
83
+ input_ids: Optional[torch.LongTensor] = None,
84
+ pixel_values: Optional[torch.FloatTensor] = None,
85
+ attention_mask: Optional[torch.Tensor] = None,
86
+ position_ids: Optional[torch.LongTensor] = None,
87
+ return_loss: Optional[bool] = None,
88
+ output_attentions: Optional[bool] = None,
89
+ output_hidden_states: Optional[bool] = None,
90
+ return_dict: Optional[bool] = None,
91
+ ) -> tuple | CLYPOutput:
92
+ image_feats = self.vision_encoder(pixel_values)
93
+ text_feats = self.text_encoder(
94
+ {
95
+ "input_ids": input_ids,
96
+ "attention_mask": attention_mask,
97
+ "position_ids": position_ids,
98
+ }
99
+ )
100
+
101
+ loss = None
102
+ if return_loss:
103
+ loss = self.itc_loss_fn(image_feats, text_feats)
104
+
105
+ image_embeds = F.normalize(image_feats, dim=-1)
106
+ text_embeds = F.normalize(text_feats, dim=-1)
107
+
108
+ sim_i2t = image_embeds @ text_embeds.T
109
+ sim_t2i = text_embeds @ image_embeds.T
110
+
111
+ logits_per_image = sim_i2t / self.itc_loss_fn.temperature
112
+ logits_per_text = sim_t2i / self.itc_loss_fn.temperature
113
+
114
+ if not return_dict:
115
+ if loss is None:
116
+ return (logits_per_image, logits_per_text, text_embeds, image_embeds)
117
+ return (loss, logits_per_image, logits_per_text, text_embeds, image_embeds)
118
+
119
+ # TODO:
120
+ # - Support vision_model_output and text_model_output
121
+ # - Improve type: torch.Tensor -> torch.FloatTensor
122
+ return CLYPOutput(
123
+ loss=loss,
124
+ logits_per_image=logits_per_image, # type: ignore
125
+ logits_per_text=logits_per_text, # type: ignore
126
+ text_embeds=text_embeds, # type: ignore
127
+ image_embeds=image_embeds, # type: ignore
128
+ )
129
+
130
+ def get_text_features(
131
+ self,
132
+ input_ids: Optional[torch.Tensor] = None,
133
+ attention_mask: Optional[torch.Tensor] = None,
134
+ position_ids: Optional[torch.Tensor] = None,
135
+ output_attentions: Optional[bool] = None,
136
+ output_hidden_states: Optional[bool] = None,
137
+ return_dict: Optional[bool] = None,
138
+ ) -> torch.FloatTensor:
139
+ text_feats = self.text_encoder(
140
+ {
141
+ "input_ids": input_ids,
142
+ "attention_mask": attention_mask,
143
+ "position_ids": position_ids,
144
+ }
145
+ )
146
+ return text_feats
147
+
148
+ def get_image_features(
149
+ self,
150
+ pixel_values: Optional[torch.FloatTensor] = None,
151
+ output_attentions: Optional[bool] = None,
152
+ output_hidden_states: Optional[bool] = None,
153
+ return_dict: Optional[bool] = None,
154
+ ) -> torch.FloatTensor:
155
+ image_feats = self.vision_encoder(pixel_values)
156
+ return image_feats
157
+
158
+
159
+ if __name__ == "__main__":
160
+ model = CLYPModel.from_pretrained(".")
preprocessor_config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoImageProcessor": "image_processing_clyp.CLYPImageProcessor"
4
+ },
5
+ "image_processor_type": "CLYPImageProcessor",
6
+ "image_size": 224,
7
+ "normalization_type": "imagenet"
8
+ }
special_tokens_map.json ADDED
@@ -0,0 +1 @@
 
 
1
+ {}
spiece.model ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b5cbdfa8aa7c54c8c5af85b78c309c54a5f2749a20468bf6f60eee007fe6fec1
3
+ size 805634
tokenization_clyp.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+
3
+ # Copyright 2024 LY Corporation.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ from __future__ import annotations
17
+
18
+ from typing import Optional
19
+
20
+ import torch
21
+ from transformers import BatchEncoding, PreTrainedTokenizer, T5Tokenizer
22
+ from transformers.tokenization_utils_base import (
23
+ PaddingStrategy,
24
+ PreTokenizedInput,
25
+ TextInput,
26
+ TruncationStrategy,
27
+ )
28
+
29
+
30
+ class CLYPTokenizer(PreTrainedTokenizer):
31
+ """CLYPTokenizer based on rinna/japanese-roberta-base
32
+
33
+ This tokenizer is registered as a custom tokenizer to manually add CLS token to each text.
34
+ """
35
+
36
+ def __init__(self, max_length: int, padding: str, truncation: bool, **kwargs):
37
+ # tokenizer
38
+ self.tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-roberta-base")
39
+ self.tokenizer.do_lower_case = True
40
+
41
+ super().__init__(
42
+ max_length=max_length, padding=padding, truncation=truncation, **kwargs
43
+ )
44
+ self.max_length = max_length
45
+ self.padding = padding
46
+ self.truncation = truncation
47
+
48
+ @property
49
+ def vocab_size(self):
50
+ return self.tokenizer.vocab_size
51
+
52
+ def get_vocab(self) -> dict[str, int]:
53
+ return self.tokenizer.get_vocab()
54
+
55
+ def save_vocabulary(
56
+ self, save_directory: str, filename_prefix: Optional[str] = None
57
+ ) -> tuple[str]:
58
+ return self.tokenizer.save_vocabulary(
59
+ save_directory, filename_prefix=filename_prefix
60
+ )
61
+
62
+ def _tokenize(self, text, **kwargs):
63
+ return self.tokenizer._tokenize(text, **kwargs)
64
+
65
+ def _convert_token_to_id(self, token):
66
+ return self.tokenizer._convert_token_to_id(token)
67
+
68
+ def _convert_id_to_token(self, index: int) -> str:
69
+ return self.tokenizer._convert_id_to_token(index)
70
+
71
+ def __call__(
72
+ self,
73
+ text: TextInput | PreTokenizedInput | list[TextInput] | list[PreTokenizedInput],
74
+ add_special_tokens: bool = True,
75
+ padding: bool | str | PaddingStrategy | None = None,
76
+ truncation: bool | str | TruncationStrategy | None = None,
77
+ max_length: Optional[int] = None,
78
+ **kwargs,
79
+ ):
80
+ if max_length is None:
81
+ max_length = self.max_length
82
+ if padding is None:
83
+ padding = self.padding
84
+ if truncation is None:
85
+ truncation = self.truncation
86
+
87
+ if add_special_tokens:
88
+ max_length = max_length - 1
89
+
90
+ if not isinstance(text, list):
91
+ # TODO: Review
92
+ text = [text]
93
+
94
+ out = self.tokenizer(
95
+ text,
96
+ max_length=max_length,
97
+ padding=padding,
98
+ truncation=truncation,
99
+ add_special_tokens=False,
100
+ **kwargs,
101
+ )
102
+
103
+ if add_special_tokens:
104
+ input_ids = [
105
+ [self.tokenizer.cls_token_id] + ids for ids in out["input_ids"]
106
+ ]
107
+ attention_mask = [[1] + am for am in out["attention_mask"]]
108
+ position_ids = [list(range(0, len(input_ids[0])))] * len(input_ids)
109
+ else:
110
+ input_ids = out["input_ids"]
111
+ attention_mask = out["attention_mask"]
112
+ position_ids = [list(range(0, len(input_ids[0])))] * len(input_ids)
113
+
114
+ # tensor
115
+ input_ids = torch.tensor(input_ids, dtype=torch.long)
116
+ attention_mask = torch.tensor(attention_mask, dtype=torch.long)
117
+ position_ids = torch.tensor(position_ids, dtype=torch.long)
118
+
119
+ # retrn
120
+ data = {
121
+ "input_ids": input_ids,
122
+ "attention_mask": attention_mask,
123
+ "position_ids": position_ids,
124
+ }
125
+ return BatchEncoding(data=data)
tokenizer_config.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {},
3
+ "auto_map": {
4
+ "AutoTokenizer": [
5
+ "tokenization_clyp.CLYPTokenizer",
6
+ null
7
+ ]
8
+ },
9
+ "clean_up_tokenization_spaces": true,
10
+ "max_length": 77,
11
+ "model_max_length": 1000000000000000019884624838656,
12
+ "padding": "longest",
13
+ "tokenizer_class": "CLYPTokenizer",
14
+ "truncation": true
15
+ }