Format with black and isort and lint with flake8
- README.md +4 -0
- configuration_hybrid_clip.py +14 -5
- discard_incorrect_files.py +6 -5
- join_datasets_custom_split.py +20 -11
- modeling_hybrid_clip.py +90 -26
- prepare_wit.py +69 -18
- run_hybrid_clip.py +127 -46
- scale_convert.py +8 -8
- test_on_image.py +13 -3
README.md
CHANGED
@@ -7,18 +7,22 @@ tags:
 - vit
 ---
 # CLIP-Spanish
+
 CLIP Spanish is a CLIP-like model for Spanish language. It is composed of [BERTIN](https://huggingface.co/bertin-project/bertin-roberta-base-spanish) as a language encoder and the ViT-B/32 image encoder from [CLIP](https://huggingface.co/openai/clip-vit-base-patch32). The model is implemented in [Flax](https://github.com/google/flax), including training scripts (see `training.md`).
 This is part of the [Flax/Jax Community Week](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104), organised by [HuggingFace](https://huggingface.co/) and TPU usage sponsored by Google.
 
 ## Spanish WIT
+
 We used a subset of 141,230 Spanish captions from the [WIT dataset](https://github.com/google-research-datasets/wit) for training.
 
 ## Team members
+
 - Eduardo González Ponferrada ([edugp](https://huggingface.co/edugp))
 - Manu Romero ([mrm8488](https://huggingface.co/))
 - María Grandury ([mariagrandury](https://huggingface.co/))
 
 ## Useful links
+
 - [Community Week timeline](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/7104#summary-timeline-calendar-6)
 - [Community Week README](https://github.com/huggingface/transformers/blob/master/examples/research_projects/jax-projects/README.md)
 - [Community Week thread](https://discuss.huggingface.co/t/bertin-pretrain-roberta-large-from-scratch-in-spanish/7125)
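As context for the files below, here is a minimal inference sketch for the model described in this README. It assumes a trained checkpoint saved in the working directory and the `modeling_hybrid_clip.py` from this repo; the image preprocessing is replaced by a dummy array (see `test_on_image.py` further down for the real pipeline).

```python
# Minimal sketch, assuming a local checkpoint in "./" and modeling_hybrid_clip.py on the path.
import jax
import numpy as np
from transformers import AutoTokenizer

from modeling_hybrid_clip import FlaxHybridCLIP

model = FlaxHybridCLIP.from_pretrained("./")
tokenizer = AutoTokenizer.from_pretrained("bertin-project/bertin-roberta-base-spanish")

inputs = tokenizer("Fachada del Santuario", return_tensors="np")
# Stand-in for a preprocessed image batch in NHWC layout (see run_hybrid_clip.py's Transform).
pixel_values = np.zeros((1, 224, 224, 3), dtype=np.float32)

outputs = model(
    inputs["input_ids"],
    pixel_values,
    attention_mask=inputs["attention_mask"],
    train=False,
    return_dict=True,
)
score = jax.nn.sigmoid(outputs["logits_per_image"])[0][0]
print(score)
```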
configuration_hybrid_clip.py
CHANGED
@@ -3,7 +3,6 @@ import copy
 from transformers.configuration_utils import PretrainedConfig
 from transformers.utils import logging
 
-
 logger = logging.get_logger(__name__)
 
 
@@ -64,19 +63,25 @@ class HybridCLIPConfig(PretrainedConfig):
         self.text_config = AutoConfig.for_model(text_model_type, **text_config)
 
         if vision_model_type == "clip":
-            self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config).vision_config
+            self.vision_config = AutoConfig.for_model(
+                vision_model_type, **vision_config
+            ).vision_config
         elif vision_model_type == "clip_vision_model":
             from transformers import CLIPVisionConfig
 
             self.vision_config = CLIPVisionConfig(**vision_config)
         else:
-            self.vision_config = AutoConfig.for_model(vision_model_type, **vision_config)
+            self.vision_config = AutoConfig.for_model(
+                vision_model_type, **vision_config
+            )
 
         self.projection_dim = projection_dim
         self.initializer_factor = 1.0
 
     @classmethod
-    def from_text_vision_configs(cls, text_config: PretrainedConfig, vision_config: PretrainedConfig, **kwargs):
+    def from_text_vision_configs(
+        cls, text_config: PretrainedConfig, vision_config: PretrainedConfig, **kwargs
+    ):
         r"""
         Instantiate a :class:`HybridCLIPConfig` (or a derived class) from text model configuration and
         vision model configuration.
@@ -84,7 +89,11 @@ class HybridCLIPConfig(PretrainedConfig):
             :class:`HybridCLIPConfig`: An instance of a configuration object
         """
 
-        return cls(text_config=text_config.to_dict(), vision_config=vision_config.to_dict(), **kwargs)
+        return cls(
+            text_config=text_config.to_dict(),
+            vision_config=vision_config.to_dict(),
+            **kwargs
+        )
 
     def to_dict(self):
         """
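A hedged sketch of how the two constructors reformatted above fit together; the model names are taken from the README, and the snippet is illustrative, not part of the commit.

```python
# Sketch: build a HybridCLIPConfig from a BERTIN text config and the CLIP ViT-B/32 vision config.
from transformers import AutoConfig

from configuration_hybrid_clip import HybridCLIPConfig

text_config = AutoConfig.from_pretrained("bertin-project/bertin-roberta-base-spanish")
# AutoConfig returns a full CLIPConfig here; its vision sub-config is what HybridCLIPConfig expects.
vision_config = AutoConfig.from_pretrained("openai/clip-vit-base-patch32").vision_config

config = HybridCLIPConfig.from_text_vision_configs(text_config, vision_config, projection_dim=512)
print(config.projection_dim, config.vision_config.image_size)
```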
discard_incorrect_files.py
CHANGED
@@ -1,9 +1,7 @@
 import json
 import os
-from tqdm import tqdm
 
-import …
-from torchvision.io import ImageReadMode, read_image
+from tqdm import tqdm
 
 JOINT_JSON_DIRECTORY = f"/home/{os.environ['USER']}/data/wit/all_jsons"
 SCALE_CONVERTED_DIRECTORY = f"/home/{os.environ['USER']}/data/wit_scale_converted"
@@ -16,13 +14,16 @@ for split in ["train", "valid", "test"]:
 
     supported_examples = []
    for example in tqdm(examples):
-        directory, filename = os.path.split(example[…
+        directory, filename = os.path.split(example["image_path"])
         if filename in valid_files:
             example["image_path"] = os.path.join(SCALE_CONVERTED_DIRECTORY, filename)
             supported_examples.append(json.dumps(example, ensure_ascii=False))
 
     print(f"Total {split} examples: {len(supported_examples)}")
-    with open(f"{SCALE_CONVERTED_DIRECTORY}/{split}_dataset_scale_converted_98_1_1_split.json", "w") as f:
+    with open(
+        f"{SCALE_CONVERTED_DIRECTORY}/{split}_dataset_scale_converted_98_1_1_split.json",
+        "w",
+    ) as f:
         f.write("\n".join(supported_examples))
 
 print("DONE!")
join_datasets_custom_split.py
CHANGED
@@ -1,10 +1,7 @@
-import os
 import json
+import os
 import random
 
-import pandas as pd
-
-
 DATA_DIR = f"/home/{os.environ['USER']}/data/wit/all_jsons"
 SEED = 0
 PROPORTION_TRAIN = 0.98
@@ -12,7 +9,9 @@ PROPORTION_VALID = 0.01
 
 random.seed(SEED)
 
-all_files = [f"{DATA_DIR}/{file_}" for file_ in os.listdir(DATA_DIR) if ("all" not in file_)]
+all_files = [
+    f"{DATA_DIR}/{file_}" for file_ in os.listdir(DATA_DIR) if ("all" not in file_)
+]
 
 print(all_files)
 
@@ -20,7 +19,9 @@ examples = []
 for file_ in all_files:
     print(file_)
     with open(file_) as f:
-        file_examples = [json.dumps(json.loads(line), ensure_ascii=False) for line in f.readlines()]
+        file_examples = [
+            json.dumps(json.loads(line), ensure_ascii=False) for line in f.readlines()
+        ]
     print(len(file_examples))
     examples.extend(file_examples)
 
@@ -34,15 +35,23 @@ random.shuffle(examples)
 print(examples[0])
 
 split_dataset = {}
-split_dataset["train"] = examples[:int(len(examples) * PROPORTION_TRAIN)]
-split_dataset["valid"] = examples[int(len(examples) * PROPORTION_TRAIN):int(len(examples) * (PROPORTION_TRAIN + PROPORTION_VALID))]
-split_dataset["test"] = examples[int(len(examples) * (PROPORTION_TRAIN + PROPORTION_VALID)):]
+split_dataset["train"] = examples[: int(len(examples) * PROPORTION_TRAIN)]
+split_dataset["valid"] = examples[
+    int(len(examples) * PROPORTION_TRAIN) : int(
+        len(examples) * (PROPORTION_TRAIN + PROPORTION_VALID)
+    )
+]
+split_dataset["test"] = examples[
+    int(len(examples) * (PROPORTION_TRAIN + PROPORTION_VALID)) :
+]
 
 
 for split in ["train", "valid", "test"]:
     print("-----")
     print(len(split_dataset[split]))
     print("-----")
-    with open(f"/home/{os.environ['USER']}/data/wit/all_jsons/{split}_dataset_all_98_1_1_split.json", "w") as f:
+    with open(
+        f"/home/{os.environ['USER']}/data/wit/all_jsons/{split}_dataset_all_98_1_1_split.json",
+        "w",
+    ) as f:
         f.write("\n".join(split_dataset[split]))
-
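The dataset scripts above (and `prepare_wit.py` below) all exchange the same JSON-lines format: one object per line with an `image_path` and a list of `captions`, which is what `run_hybrid_clip.py`'s `ImageTextDataset` reads. A small sketch of consuming such a split file; the file name is an assumption based on the paths used above.

```python
import json

# Assumed path; the scripts above write files such as train_dataset_all_98_1_1_split.json.
with open("train_dataset_all_98_1_1_split.json") as f:
    examples = [json.loads(line) for line in f if line.strip()]

print(len(examples))
print(examples[0]["image_path"], examples[0]["captions"][0])
```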
modeling_hybrid_clip.py
CHANGED
@@ -18,13 +18,13 @@ from typing import Optional, Tuple
 import flax.linen as nn
 import jax
 import jax.numpy as jnp
-from configuration_hybrid_clip import HybridCLIPConfig
 from flax.core.frozen_dict import FrozenDict
 from transformers import FLAX_MODEL_MAPPING, FlaxCLIPVisionModel
 from transformers.modeling_flax_utils import FlaxPreTrainedModel
 from transformers.models.clip.modeling_flax_clip import FlaxCLIPOutput
 from transformers.utils import logging
 
+from configuration_hybrid_clip import HybridCLIPConfig
 
 logger = logging.get_logger(__name__)
 
@@ -42,7 +42,9 @@ class FlaxHybridCLIPModule(nn.Module):
         self.vision_embed_dim = vision_config.hidden_size
 
         text_module = FLAX_MODEL_MAPPING[self.config.text_config.__class__].module_class
-        vision_module = FLAX_MODEL_MAPPING.get(self.config.vision_config.__class__, FlaxCLIPVisionModel).module_class
+        vision_module = FLAX_MODEL_MAPPING.get(
+            self.config.vision_config.__class__, FlaxCLIPVisionModel
+        ).module_class
 
         self.text_model = text_module(text_config, dtype=self.dtype)
         self.vision_model = vision_module(vision_config, dtype=self.dtype)
@@ -73,7 +75,9 @@ class FlaxHybridCLIPModule(nn.Module):
         output_hidden_states=None,
         return_dict=None,
     ):
-        return_dict = return_dict if return_dict is not None else self.config.return_dict
+        return_dict = (
+            return_dict if return_dict is not None else self.config.return_dict
+        )
 
         vision_outputs = self.vision_model(
             pixel_values=pixel_values,
@@ -101,7 +105,9 @@ class FlaxHybridCLIPModule(nn.Module):
         text_embeds = self.text_projection(text_embeds)
 
         # normalized features
-        image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
+        image_embeds = image_embeds / jnp.linalg.norm(
+            image_embeds, axis=-1, keepdims=True
+        )
         text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
 
         # cosine similarity as logits
@@ -110,7 +116,14 @@ class FlaxHybridCLIPModule(nn.Module):
         logits_per_image = logits_per_text.T
 
         if not return_dict:
-            return (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
+            return (
+                logits_per_image,
+                logits_per_text,
+                text_embeds,
+                image_embeds,
+                text_outputs,
+                vision_outputs,
+            )
 
         return FlaxCLIPOutput(
             logits_per_image=logits_per_image,
@@ -132,18 +145,30 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
         input_shape: Optional[Tuple] = None,
         seed: int = 0,
         dtype: jnp.dtype = jnp.float32,
-        **kwargs
+        **kwargs,
     ):
         if input_shape is None:
-            input_shape = ((1, 1), (1, config.vision_config.image_size, config.vision_config.image_size, 3))
+            input_shape = (
+                (1, 1),
+                (
+                    1,
+                    config.vision_config.image_size,
+                    config.vision_config.image_size,
+                    3,
+                ),
+            )
 
         module = self.module_class(config=config, dtype=dtype, **kwargs)
-        super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)
+        super().__init__(
+            config, module, input_shape=input_shape, seed=seed, dtype=dtype
+        )
 
     def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:
         # init input tensor
         input_ids = jnp.zeros(input_shape[0], dtype="i4")
-        position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0])
+        position_ids = jnp.broadcast_to(
+            jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape[0]
+        )
         token_type_ids = jnp.ones_like(input_ids)
         attention_mask = jnp.ones_like(input_ids)
 
@@ -152,7 +177,9 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
         params_rng, dropout_rng = jax.random.split(rng)
         rngs = {"params": params_rng, "dropout": dropout_rng}
 
-        return self.module.init(rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids)["params"]
+        return self.module.init(
+            rngs, input_ids, pixel_values, attention_mask, position_ids, token_type_ids
+        )["params"]
 
     def __call__(
         self,
@@ -168,14 +195,24 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
     ):
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_attentions = (
+            output_attentions
+            if output_attentions is not None
+            else self.config.output_attentions
+        )
         output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+            output_hidden_states
+            if output_hidden_states is not None
+            else self.config.output_hidden_states
+        )
+        return_dict = (
+            return_dict if return_dict is not None else self.config.return_dict
         )
-        return_dict = return_dict if return_dict is not None else self.config.return_dict
 
         if position_ids is None:
-            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+            position_ids = jnp.broadcast_to(
+                jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape
+            )
 
         if token_type_ids is None:
             token_type_ids = jnp.zeros_like(input_ids)
@@ -225,7 +262,9 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
             obtained by applying the projection layer to the pooled output of text model.
         """
         if position_ids is None:
-            position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)
+            position_ids = jnp.broadcast_to(
+                jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape
+            )
 
         if token_type_ids is None:
             token_type_ids = jnp.zeros_like(input_ids)
@@ -238,7 +277,14 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
         if dropout_rng is not None:
             rngs["dropout"] = dropout_rng
 
-        def _get_features(module, input_ids, attention_mask, position_ids, token_type_ids, deterministic):
+        def _get_features(
+            module,
+            input_ids,
+            attention_mask,
+            position_ids,
+            token_type_ids,
+            deterministic,
+        ):
             text_outputs = module.text_model(
                 input_ids=input_ids,
                 attention_mask=attention_mask,
@@ -261,7 +307,9 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
             rngs=rngs,
         )
 
-    def get_image_features(self, pixel_values, dropout_rng: jax.random.PRNGKey = None, train=False):
+    def get_image_features(
+        self, pixel_values, dropout_rng: jax.random.PRNGKey = None, train=False
+    ):
         r"""
         Args:
             pixel_values (:obj:`numpy.ndarray` of shape :obj:`(batch_size, num_channels, height, width)`):
@@ -279,7 +327,9 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
             rngs["dropout"] = dropout_rng
 
         def _get_features(module, pixel_values, deterministic):
-            vision_outputs = module.vision_model(pixel_values=pixel_values, deterministic=deterministic)
+            vision_outputs = module.vision_model(
+                pixel_values=pixel_values, deterministic=deterministic
+            )
             pooled_output = vision_outputs[1]  # pooled_output
             image_features = module.visual_projection(pooled_output)
             return image_features
@@ -345,11 +395,15 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
         """
 
         kwargs_text = {
-            argument[len("text_") :]: value for argument, value in kwargs.items() if argument.startswith("text_")
+            argument[len("text_") :]: value
+            for argument, value in kwargs.items()
+            if argument.startswith("text_")
         }
 
         kwargs_vision = {
-            argument[len("vision_") :]: value for argument, value in kwargs.items() if argument.startswith("vision_")
+            argument[len("vision_") :]: value
+            for argument, value in kwargs.items()
+            if argument.startswith("vision_")
         }
 
         # remove text, vision kwargs from kwargs
@@ -372,7 +426,9 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
                 text_config = AutoConfig.from_pretrained(text_model_name_or_path)
                 kwargs_text["config"] = text_config
 
-            text_model = FlaxAutoModel.from_pretrained(text_model_name_or_path, *model_args, **kwargs_text)
+            text_model = FlaxAutoModel.from_pretrained(
+                text_model_name_or_path, *model_args, **kwargs_text
+            )
 
         vision_model = kwargs_vision.pop("model", None)
         if vision_model is None:
@@ -387,21 +443,29 @@ class FlaxHybridCLIP(FlaxPreTrainedModel):
                 vision_config = AutoConfig.from_pretrained(vision_model_name_or_path)
                 kwargs_vision["config"] = vision_config
 
-            vision_model = FlaxAutoModel.from_pretrained(vision_model_name_or_path, *model_args, **kwargs_vision)
+            vision_model = FlaxAutoModel.from_pretrained(
+                vision_model_name_or_path, *model_args, **kwargs_vision
+            )
 
         # instantiate config with corresponding kwargs
         dtype = kwargs.pop("dtype", jnp.float32)
-        config = HybridCLIPConfig.from_text_vision_configs(text_model.config, vision_model.config, **kwargs)
+        config = HybridCLIPConfig.from_text_vision_configs(
+            text_model.config, vision_model.config, **kwargs
+        )
 
         # init model
         model = cls(config, *model_args, dtype=dtype, **kwargs)
 
         if vision_config.model_type == "clip":
-            model.params["vision_model"]["vision_model"] = vision_model.params["vision_model"]
-            model.params["visual_projection"]["kernel"] = vision_model.params["visual_projection"]["kernel"]
+            model.params["vision_model"]["vision_model"] = vision_model.params[
+                "vision_model"
+            ]
+            model.params["visual_projection"]["kernel"] = vision_model.params[
+                "visual_projection"
+            ]["kernel"]
         else:
             model.params["vision_model"] = vision_model.params
 
         model.params["text_model"] = text_model.params
 
-        return model
+        return model
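The last hunks above reformat the classmethod that assembles a `FlaxHybridCLIP` from separate text and vision checkpoints. In the upstream Flax `hybrid_clip` example this file is based on, that classmethod is called `from_text_vision_pretrained`; the name is not visible in the hunk, so treat it as an assumption. A hedged usage sketch with the checkpoints named in the README:

```python
from modeling_hybrid_clip import FlaxHybridCLIP

# Assumed method name (from the upstream hybrid_clip example); checkpoint ids from the README.
model = FlaxHybridCLIP.from_text_vision_pretrained(
    "bertin-project/bertin-roberta-base-spanish",  # BERTIN text encoder
    "openai/clip-vit-base-patch32",                # CLIP ViT-B/32 vision encoder
)
model.save_pretrained("clip-spanish")  # writes config.json and flax_model.msgpack
```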
prepare_wit.py
CHANGED
@@ -3,14 +3,13 @@ import json
 import logging
 import os
 import time
-from typing import List
-import urllib.request
 import urllib.error
+import urllib.request
+from typing import List
 
 import pandas as pd
 from tqdm import tqdm
 
-
 logging.basicConfig(
     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
     datefmt="%m/%d/%Y %H:%M:%S",
@@ -18,11 +17,18 @@
 )
 logger = logging.getLogger(__name__)
 
-def split_and_save_datasets(lines: List[str], output_dir: str, train_proportion: float, valid_proportion: float):
+
+def split_and_save_datasets(
+    lines: List[str], output_dir: str, train_proportion: float, valid_proportion: float
+):
     total_lines = len(lines)
-    train_lines = lines[:int(total_lines * train_proportion)]
-    valid_lines = lines[int(total_lines * train_proportion):int(total_lines * (train_proportion + valid_proportion))]
-    test_lines = lines[int(total_lines * (train_proportion + valid_proportion)):]
+    train_lines = lines[: int(total_lines * train_proportion)]
+    valid_lines = lines[
+        int(total_lines * train_proportion) : int(
+            total_lines * (train_proportion + valid_proportion)
+        )
+    ]
+    test_lines = lines[int(total_lines * (train_proportion + valid_proportion)) :]
 
     with open(f"{output_dir}/train_dataset.json", "w") as f:
         f.write("\n".join(train_lines))
@@ -33,14 +39,33 @@ def split_and_save_datasets(lines: List[str], output_dir: str, train_proportion:
     with open(f"{output_dir}/test_dataset.json", "w") as f:
         f.write("\n".join(test_lines))
 
+
 def prepare_wit(
-    tsv: str, language: str, output_dir: str, seed: int, train_proportion: float, valid_proportion: float, backup_period: int, language_col: str = "language", caption_col: str = "caption_reference_description", url_col: str = "image_url", pause=0.875, retries: int = 10):
+    tsv: str,
+    language: str,
+    output_dir: str,
+    seed: int,
+    train_proportion: float,
+    valid_proportion: float,
+    backup_period: int,
+    language_col: str = "language",
+    caption_col: str = "caption_reference_description",
+    url_col: str = "image_url",
+    pause=0.875,
+    retries: int = 10,
+):
     os.makedirs(output_dir, exist_ok=True)
     logger.info("Loading dataset")
     df = pd.read_csv(tsv, sep="\t", engine="python")
     existing_files = set(os.listdir(output_dir))
-    not_exists_condition = ~(df[url_col].map(lambda x: x.split("/")[-1][-100:]).isin(existing_files))
-    df = df[(df["language"] == language) & (~df["caption_reference_description"].isnull()) & not_exists_condition]
+    not_exists_condition = ~(
+        df[url_col].map(lambda x: x.split("/")[-1][-100:]).isin(existing_files)
+    )
+    df = df[
+        (df["language"] == language)
+        & (~df["caption_reference_description"].isnull())
+        & not_exists_condition
+    ]
     # Shuffle
     df = df.sample(frac=1.0, random_state=seed)
     logger.info(f"Trying to downloading {df.shape[0]} files")
@@ -58,14 +83,21 @@
                 try:
                     # Download file
                     urllib.request.urlretrieve(url, image_path)
-                    lines.append(json.dumps({"image_path": image_path, "captions": [caption]}, ensure_ascii=False))
+                    lines.append(
+                        json.dumps(
+                            {"image_path": image_path, "captions": [caption]},
+                            ensure_ascii=False,
+                        )
+                    )
                     count += 1
                     break
-                except urllib.error.HTTPError…
+                except urllib.error.HTTPError:
                     time.sleep(pause * 10)
                 if count % backup_period == 0:
                     logger.info(f"Saving dataset backup: Number of lines {len(lines)}")
-                    split_and_save_datasets(lines, output_dir, train_proportion, valid_proportion)
+                    split_and_save_datasets(
+                        lines, output_dir, train_proportion, valid_proportion
+                    )
                 if retry == retries - 1:
                     logger.info(f"Skipping {image_filename}")
             pbar.update(1)
@@ -73,16 +105,35 @@
         finally:
             split_and_save_datasets(lines, output_dir, train_proportion, valid_proportion)
 
+
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description…
-    parser.add_argument("--tsv", type=str, default=f"/home/{os.environ['USER']}/data/wit/wit_v1.train.all-1percent_sample.tsv")
+    parser = argparse.ArgumentParser(description="Download and prepare the WIT dataset")
+    parser.add_argument(
+        "--tsv",
+        type=str,
+        default=f"/home/{os.environ['USER']}/data/wit/wit_v1.train.all-1percent_sample.tsv",
+    )
     parser.add_argument("--language", type=str, default="es")
-    parser.add_argument("--output_dir", type=str, default=f"/home/{os.environ['USER']}/data/wit/prepared_dataset")
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default=f"/home/{os.environ['USER']}/data/wit/prepared_dataset",
+    )
     parser.add_argument("--random_seed", type=int, default=0)
     parser.add_argument("--train_proportion", type=float, default=0.8)
    parser.add_argument("--valid_proportion", type=float, default=0.1)
     parser.add_argument("--backup_period", type=int, default=1000)
 
     args = parser.parse_args()
-    assert args.train_proportion + args.valid_proportion < 1.0, "The sum of train_proportion and valid_proportion has to be < 1.0"
-    prepare_wit(args.tsv, args.language, args.output_dir, args.random_seed, args.train_proportion, args.valid_proportion, args.backup_period)
+    assert (
+        args.train_proportion + args.valid_proportion < 1.0
+    ), "The sum of train_proportion and valid_proportion has to be < 1.0"
+    prepare_wit(
+        args.tsv,
+        args.language,
+        args.output_dir,
+        args.random_seed,
+        args.train_proportion,
+        args.valid_proportion,
+        args.backup_period,
+    )
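The download helper above can also be driven directly from Python instead of through `argparse`; a hedged sketch using the script's own defaults (the paths are those defaults and may not exist on your machine):

```python
import os

from prepare_wit import prepare_wit

prepare_wit(
    tsv=f"/home/{os.environ['USER']}/data/wit/wit_v1.train.all-1percent_sample.tsv",
    language="es",
    output_dir=f"/home/{os.environ['USER']}/data/wit/prepared_dataset",
    seed=0,
    train_proportion=0.8,
    valid_proportion=0.1,
    backup_period=1000,
)
```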
run_hybrid_clip.py
CHANGED
@@ -32,25 +32,26 @@ from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Callable, Optional
 
-import numpy as np
-import torch
-from torchvision.datasets import VisionDataset
-from torchvision.io import ImageReadMode, read_image
-from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize
-from torchvision.transforms.functional import InterpolationMode
-from tqdm import tqdm
-
 import jax
 import jax.numpy as jnp
+import numpy as np
 import optax
+import torch
 import transformers
 from flax import jax_utils
 from flax.jax_utils import unreplicate
 from flax.training import train_state
 from flax.training.common_utils import get_metrics, shard, shard_prng_key
-from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, is_tensorboard_available, set_seed
-from modeling_hybrid_clip import FlaxHybridCLIP
+from torchvision.datasets import VisionDataset
+from torchvision.io import ImageReadMode, read_image
+from torchvision.transforms import (CenterCrop, ConvertImageDtype, Normalize,
+                                    Resize)
+from torchvision.transforms.functional import InterpolationMode
+from tqdm import tqdm
+from transformers import (AutoTokenizer, HfArgumentParser, TrainingArguments,
+                          is_tensorboard_available, set_seed)
 
+from modeling_hybrid_clip import FlaxHybridCLIP
 
 logger = logging.getLogger(__name__)
 
@@ -61,7 +62,9 @@ if has_tensorboard:
         from flax.metrics.tensorboard import SummaryWriter
     except ImportError as ie:
         has_tensorboard = False
-        print(f"Unable to display metrics through TensorBoard because some package are not installed: {ie}")
+        print(
+            f"Unable to display metrics through TensorBoard because some package are not installed: {ie}"
+        )
 
 else:
     print(
@@ -90,20 +93,33 @@ class ModelArguments:
     )
     from_pt: bool = field(
         default=True,
-        metadata={"help": "whether to load the text and vision model using PyTorch checkpoints."},
+        metadata={
+            "help": "whether to load the text and vision model using PyTorch checkpoints."
+        },
     )
     config_name: Optional[str] = field(
-        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
+        default=None,
+        metadata={
+            "help": "Pretrained config name or path if not the same as model_name"
+        },
     )
     tokenizer_name: Optional[str] = field(
-        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
+        default=None,
+        metadata={
+            "help": "Pretrained tokenizer name or path if not the same as model_name"
+        },
     )
     cache_dir: Optional[str] = field(
-        default=None, metadata={"help": "Where do you want to store the pretrained models downloaded from s3"}
+        default=None,
+        metadata={
+            "help": "Where do you want to store the pretrained models downloaded from s3"
+        },
    )
     use_fast_tokenizer: bool = field(
         default=True,
-        metadata={"help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."},
+        metadata={
+            "help": "Whether to use one of the fast tokenizer (backed by the tokenizers library) or not."
+        },
     )
     dtype: Optional[str] = field(
         default="float32",
@@ -119,9 +135,12 @@ class DataTrainingArguments:
     Arguments pertaining to what data we are going to input our model for training and eval.
     """
 
-    data_dir: Optional[str] = field(default=None, metadata={"help": "The data directory containing input files."})
+    data_dir: Optional[str] = field(
+        default=None, metadata={"help": "The data directory containing input files."}
+    )
     train_file: Optional[str] = field(
-        default=None, metadata={"help": "The input training data file (a jsonlines file)."}
+        default=None,
+        metadata={"help": "The input training data file (a jsonlines file)."},
     )
     validation_file: Optional[str] = field(
         default=None,
@@ -149,10 +168,12 @@
         },
     )
     overwrite_cache: bool = field(
-        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+        default=False,
+        metadata={"help": "Overwrite the cached training and evaluation sets"},
     )
     overwrite_cache: bool = field(
-        default=False, metadata={"help": "Overwrite the cached training and evaluation sets"}
+        default=False,
+        metadata={"help": "Overwrite the cached training and evaluation sets"},
     )
     preprocessing_num_workers: Optional[int] = field(
         default=None,
@@ -161,7 +182,9 @@ class DataTrainingArguments:
 
     def __post_init__(self):
         if self.train_file is None and self.validation_file is None:
-            raise ValueError("Need either a dataset name or a training/validation file.")
+            raise ValueError(
+                "Need either a dataset name or a training/validation file."
+            )
         else:
             if self.train_file is not None:
                 extension = self.train_file.split(".")[-1]
@@ -180,7 +203,10 @@ class Transform(torch.nn.Module):
             Resize([image_size], interpolation=InterpolationMode.BICUBIC),
             CenterCrop(image_size),
             ConvertImageDtype(torch.float),
-            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+            Normalize(
+                (0.48145466, 0.4578275, 0.40821073),
+                (0.26862954, 0.26130258, 0.27577711),
+            ),
         )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -225,7 +251,7 @@ class ImageTextDataset(VisionDataset):
         self.image_paths = []
 
         for example in examples:
-            captions_subset = example["captions"][…
+            captions_subset = example["captions"][:captions_per_image]
             self.captions.extend(captions_subset)
             self.image_paths.extend([example["image_path"]] * len(captions_subset))
 
@@ -253,7 +279,9 @@ class TrainState(train_state.TrainState):
     dropout_rng: jnp.ndarray
 
     def replicate(self):
-        return jax_utils.replicate(self).replace(dropout_rng=shard_prng_key(self.dropout_rng))
+        return jax_utils.replicate(self).replace(
+            dropout_rng=shard_prng_key(self.dropout_rng)
+        )
 
 
 def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
@@ -270,25 +298,39 @@ def write_metric(summary_writer, train_metrics, eval_metrics, train_time, step):
 
 
 def create_learning_rate_fn(
-    train_ds_size: int, train_batch_size: int, num_train_epochs: int, num_warmup_steps: int, learning_rate: float
+    train_ds_size: int,
+    train_batch_size: int,
+    num_train_epochs: int,
+    num_warmup_steps: int,
+    learning_rate: float,
 ) -> Callable[[int], jnp.array]:
     """Returns a linear warmup, linear_decay learning rate function."""
     steps_per_epoch = train_ds_size // train_batch_size
     num_train_steps = steps_per_epoch * num_train_epochs
-    warmup_fn = optax.linear_schedule(init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps)
+    warmup_fn = optax.linear_schedule(
+        init_value=0.0, end_value=learning_rate, transition_steps=num_warmup_steps
+    )
     decay_fn = optax.linear_schedule(
-        init_value=learning_rate, end_value=0, transition_steps=num_train_steps - num_warmup_steps
+        init_value=learning_rate,
+        end_value=0,
+        transition_steps=num_train_steps - num_warmup_steps,
+    )
+    schedule_fn = optax.join_schedules(
+        schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps]
     )
-    schedule_fn = optax.join_schedules(schedules=[warmup_fn, decay_fn], boundaries=[num_warmup_steps])
     return schedule_fn
 
 
 def main():
-    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
+    parser = HfArgumentParser(
+        (ModelArguments, DataTrainingArguments, TrainingArguments)
+    )
     if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
         # If we pass only one argument to the script and it's the path to a json file,
         # let's parse it to get our arguments.
-        model_args, data_args, training_args = parser.parse_json_file(json_file=os.path.abspath(sys.argv[1]))
+        model_args, data_args, training_args = parser.parse_json_file(
+            json_file=os.path.abspath(sys.argv[1])
+        )
     else:
         model_args, data_args, training_args = parser.parse_args_into_dataclasses()
 
@@ -321,11 +363,15 @@
 
     if model_args.tokenizer_name:
         tokenizer = AutoTokenizer.from_pretrained(
-            model_args.tokenizer_name, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+            model_args.tokenizer_name,
+            cache_dir=model_args.cache_dir,
+            use_fast=model_args.use_fast_tokenizer,
        )
     elif model_args.text_model_name_or_path:
         tokenizer = AutoTokenizer.from_pretrained(
-            model_args.text_model_name_or_path, cache_dir=model_args.cache_dir, use_fast=model_args.use_fast_tokenizer
+            model_args.text_model_name_or_path,
+            cache_dir=model_args.cache_dir,
+            use_fast=model_args.use_fast_tokenizer,
         )
     else:
         raise ValueError(
@@ -366,16 +412,28 @@
 
     # Store some constant
     num_epochs = int(training_args.num_train_epochs)
-    train_batch_size = int(training_args.per_device_train_batch_size) * jax.device_count()
+    train_batch_size = (
+        int(training_args.per_device_train_batch_size) * jax.device_count()
+    )
     eval_batch_size = int(training_args.per_device_eval_batch_size) * jax.device_count()
     steps_per_epoch = len(train_dataset) // train_batch_size
     total_train_steps = steps_per_epoch * num_epochs
 
     # Use collate function to tokenizer the text and convert the processed images to numpy
     def collate_fn(examples):
-        pixel_values = torch.stack([example[0] for example in examples]).permute(0, 2, 3, 1).numpy()
+        pixel_values = (
+            torch.stack([example[0] for example in examples])
+            .permute(0, 2, 3, 1)
+            .numpy()
+        )
         captions = [example[1] for example in examples]
-        inputs = tokenizer(captions, max_length=data_args.max_seq_length, padding="max_length", truncation=True, return_tensors="np")
+        inputs = tokenizer(
+            captions,
+            max_length=data_args.max_seq_length,
+            padding="max_length",
+            truncation=True,
+            return_tensors="np",
+        )
 
         batch = {
             "pixel_values": pixel_values,
@@ -408,7 +466,9 @@
 
     # Enable tensorboard only on the master node
     if has_tensorboard and jax.process_index() == 0:
-        summary_writer = SummaryWriter(log_dir=Path(training_args.output_dir).joinpath("logs").as_posix())
+        summary_writer = SummaryWriter(
+            log_dir=Path(training_args.output_dir).joinpath("logs").as_posix()
+        )
 
     # Initialize our training
     rng = jax.random.PRNGKey(training_args.seed)
@@ -433,7 +493,9 @@
     )
 
     # Setup train state
-    state = TrainState.create(apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng)
+    state = TrainState.create(
+        apply_fn=model.__call__, params=model.params, tx=adamw, dropout_rng=dropout_rng
+    )
 
     def cross_entropy(logits, axis):
         logprobs = jax.nn.log_softmax(logits, axis=axis)
@@ -442,7 +504,9 @@
         return ce
 
     def clip_loss(similarity):
-        loss = (cross_entropy(similarity, axis=0) + cross_entropy(similarity, axis=1)) / 2
+        loss = (
+            cross_entropy(similarity, axis=0) + cross_entropy(similarity, axis=1)
+        ) / 2
         return loss
 
     # Define gradient update step fn
@@ -450,7 +514,9 @@
         dropout_rng, new_dropout_rng = jax.random.split(state.dropout_rng)
 
        def compute_loss(params):
-            logits = state.apply_fn(**batch, params=params, dropout_rng=dropout_rng, train=True)[0]
+            logits = state.apply_fn(
+                **batch, params=params, dropout_rng=dropout_rng, train=True
+            )[0]
             loss = clip_loss(logits)
             return loss
 
@@ -460,7 +526,10 @@
 
         new_state = state.apply_gradients(grads=grad, dropout_rng=new_dropout_rng)
 
-        metrics = {"loss": loss, "learning_rate": linear_decay_lr_schedule_fn(state.step)}
+        metrics = {
+            "loss": loss,
+            "learning_rate": linear_decay_lr_schedule_fn(state.step),
+        }
         metrics = jax.lax.pmean(metrics, axis_name="batch")
 
         return new_state, metrics
@@ -485,8 +554,12 @@
     logger.info("***** Running training *****")
     logger.info(f" Num examples = {len(train_dataset)}")
     logger.info(f" Num Epochs = {num_epochs}")
-    logger.info(f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}")
-    logger.info(f" Total train batch size (w. parallel & distributed) = {train_batch_size}")
+    logger.info(
+        f" Instantaneous batch size per device = {training_args.per_device_train_batch_size}"
+    )
+    logger.info(
+        f" Total train batch size (w. parallel & distributed) = {train_batch_size}"
+    )
     logger.info(f" Total optimization steps = {total_train_steps}")
 
     train_time = 0
@@ -504,7 +577,9 @@
         train_metrics = []
 
         steps_per_epoch = len(train_dataset) // train_batch_size
-        train_step_progress_bar = tqdm(total=steps_per_epoch, desc="Training...", position=1, leave=False)
+        train_step_progress_bar = tqdm(
+            total=steps_per_epoch, desc="Training...", position=1, leave=False
+        )
         # train
         for batch in train_loader:
             batch = shard(batch)
@@ -525,7 +600,9 @@
         # ======================== Evaluating ==============================
         eval_metrics = []
         eval_steps = len(eval_dataset) // eval_batch_size
-        eval_step_progress_bar = tqdm(total=eval_steps, desc="Evaluating...", position=2, leave=False)
+        eval_step_progress_bar = tqdm(
+            total=eval_steps, desc="Evaluating...", position=2, leave=False
+        )
         for batch in eval_loader:
             # Model forward
             batch = shard(batch)
@@ -541,14 +618,18 @@
 
         # Print metrics and update progress bar
        eval_step_progress_bar.close()
-        desc = f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
+        desc = (
+            f"Epoch... ({epoch + 1}/{num_epochs} | Eval Loss: {eval_metrics['loss']})"
+        )
         epochs.write(desc)
         epochs.desc = desc
 
         # Save metrics
         if has_tensorboard and jax.process_index() == 0:
             cur_step = epoch * (len(train_dataset) // train_batch_size)
-            write_metric(summary_writer, train_metrics, eval_metrics, train_time, cur_step)
+            write_metric(
+                summary_writer, train_metrics, eval_metrics, train_time, cur_step
+            )
 
         # save checkpoint after each epoch and push checkpoint to the hub
         if jax.process_index() == 0:
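The training objective touched above is the standard symmetric CLIP loss: cross-entropy over the image-text similarity matrix along both axes, averaged. A standalone sketch follows; the body of `cross_entropy` is only partly visible in the hunk, so the diagonal-gather step is filled in from the upstream hybrid_clip example and should be read as an assumption.

```python
import jax
import jax.numpy as jnp


def cross_entropy(logits, axis):
    logprobs = jax.nn.log_softmax(logits, axis=axis)
    nll = jnp.diag(logprobs)  # log-probability assigned to each matching (image, text) pair
    return -jnp.mean(nll)


def clip_loss(similarity):
    # Average the image-to-text and text-to-image losses, as in the hunk above.
    return (cross_entropy(similarity, axis=0) + cross_entropy(similarity, axis=1)) / 2


similarity = jnp.array([[4.0, 0.1], [0.2, 3.5]])  # toy 2x2 logits; matches sit on the diagonal
print(clip_loss(similarity))  # small loss, since the diagonal already dominates
```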
scale_convert.py
CHANGED
@@ -1,12 +1,11 @@
 import glob
 import itertools
-from argparse import ArgumentParser
-from joblib import Parallel, delayed
 import os
 import subprocess
+from argparse import ArgumentParser
 from collections import Counter
-import shutil
 
+from joblib import Parallel, delayed
 
 parser = ArgumentParser()
 parser.add_argument("in_dir")
@@ -26,17 +25,16 @@ files = itertools.chain(
     glob.iglob(f"{args.in_dir}/*/*.SVG"),
 )
 
+
 def process_file(path):
     basename = os.path.basename(path)
-    ext = os.path.splitext(basename)[1]
     name = os.path.splitext(basename)[0]
 
-    dirname = os.path.dirname(path)
     try:
         r = subprocess.run(
             f'convert {path} -resize "224^>" -colorspace RGB -density 1200 {args.out_dir}/{name}.jpg',
             shell=True,
-            timeout=10
+            timeout=10,
         )
         rcode = r.returncode
     except subprocess.TimeoutExpired:
@@ -48,6 +46,8 @@ def process_file(path):
 
     return rcode
 
-codes = Parallel(n_jobs=32, prefer="threads", verbose=1)(delayed(process_file)(f) for f in files)
-print(Counter(codes))
 
+codes = Parallel(n_jobs=32, prefer="threads", verbose=1)(
+    delayed(process_file)(f) for f in files
+)
+print(Counter(codes))
test_on_image.py
CHANGED
@@ -17,13 +17,21 @@ def prepare_image(image_path, model):
     pixel_values = torch.stack([preprocessed_image]).permute(0, 2, 3, 1).numpy()
     return pixel_values
 
+
 def prepare_text(text, tokenizer):
     return tokenizer(text, return_tensors="np")
 
+
 def run_inference(image_path, text, model, tokenizer):
     pixel_values = prepare_image(image_path, model)
     input_text = prepare_text(text, tokenizer)
-    model_output = model(input_text["input_ids"], pixel_values, attention_mask=input_text["attention_mask"], train=False, return_dict=True)
+    model_output = model(
+        input_text["input_ids"],
+        pixel_values,
+        attention_mask=input_text["attention_mask"],
+        train=False,
+        return_dict=True,
+    )
     logits = model_output["logits_per_image"]
     score = jax.nn.sigmoid(logits)[0][0]
     return score
@@ -31,9 +39,11 @@ def run_inference(image_path, text, model, tokenizer):
 
 if __name__ == "__main__":
     model = FlaxHybridCLIP.from_pretrained("./")
-    tokenizer = AutoTokenizer.from_pretrained("bertin-project/bertin-roberta-base-spanish")
+    tokenizer = AutoTokenizer.from_pretrained(
+        "bertin-project/bertin-roberta-base-spanish"
+    )
 
     image_path = f"/home/{os.environ['USER']}/data/wit_scale_converted/Santuar.jpg"
     text = "Fachada del Santuario"
 
-    print(run_inference(image_path, text, model, tokenizer))
+    print(run_inference(image_path, text, model, tokenizer))