Spaces:
Sleeping
Sleeping
Merge remote-tracking branch 'bv/main'
Browse files- app.py +51 -36
- big_vision_contrastive_models.py +32 -17
- gradio_helpers.py +7 -3
app.py
CHANGED
@@ -20,6 +20,7 @@ import urllib.request
|
|
20 |
import gradio as gr
|
21 |
import PIL.Image
|
22 |
|
|
|
23 |
import big_vision_contrastive_models as models
|
24 |
import gradio_helpers
|
25 |
|
@@ -38,26 +39,26 @@ LOADING_SECS = {'B/16': 5, 'L/16': 10, 'So400m/14': 10}
|
|
38 |
MODEL_MAP = {
|
39 |
'lit': {
|
40 |
'B/16': {
|
41 |
-
|
42 |
},
|
43 |
'L/16': {
|
44 |
-
|
45 |
},
|
46 |
},
|
47 |
'siglip': {
|
48 |
'B/16': {
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
},
|
54 |
'L/16': {
|
55 |
-
|
56 |
-
|
57 |
},
|
58 |
'So400m/14': {
|
59 |
-
|
60 |
-
|
61 |
},
|
62 |
},
|
63 |
}
|
@@ -73,7 +74,9 @@ def get_cache_status():
|
|
73 |
)
|
74 |
|
75 |
|
76 |
-
def compute(
|
|
|
|
|
77 |
"""Loads model and computes answers."""
|
78 |
|
79 |
if image_path is None:
|
@@ -84,7 +87,7 @@ def compute(image_path, prompts, family, variant, res, bias, progress=gr.Progres
|
|
84 |
model_name = MODEL_MAP[family][variant][res]
|
85 |
config = models.MODEL_CONFIGS[model_name]
|
86 |
local_ckpt = gradio_helpers.get_disk_cache(
|
87 |
-
|
88 |
config = dataclasses.replace(config, ckpt=local_ckpt)
|
89 |
params, model = gradio_helpers.get_memory_cache(
|
90 |
config,
|
@@ -92,11 +95,11 @@ def compute(image_path, prompts, family, variant, res, bias, progress=gr.Progres
|
|
92 |
max_cache_size_bytes=MAX_RAM_CACHE,
|
93 |
progress=progress,
|
94 |
estimated_secs={
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
}.get((family, variant))
|
101 |
)
|
102 |
model: models.ContrastiveModel = model
|
@@ -108,18 +111,19 @@ def compute(image_path, prompts, family, variant, res, bias, progress=gr.Progres
|
|
108 |
image = PIL.Image.open(image_path)
|
109 |
next(it)
|
110 |
with gradio_helpers.timed('image features'):
|
111 |
-
zimg,
|
112 |
params, model.preprocess_images([image])
|
113 |
)
|
114 |
next(it)
|
115 |
with gradio_helpers.timed('text features'):
|
116 |
prompts = prompts.split('\n')
|
117 |
ztxt, out = model.embed_texts(
|
118 |
-
|
119 |
)
|
120 |
next(it)
|
121 |
|
122 |
t = model.get_temperature(out)
|
|
|
123 |
if family == 'lit':
|
124 |
text_probs = list(model.get_probabilities(zimg, ztxt, t, axis=-1)[0])
|
125 |
elif family == 'siglip':
|
@@ -141,7 +145,8 @@ def update_answers(state):
|
|
141 |
"""Generates visible sliders for answers."""
|
142 |
answers = []
|
143 |
for prompt, prob in state[:MAX_ANSWERS]:
|
144 |
-
answers.append(
|
|
|
145 |
while len(answers) < MAX_ANSWERS:
|
146 |
answers.append(gr.Slider(visible=False))
|
147 |
return answers
|
@@ -160,7 +165,10 @@ def create_app():
|
|
160 |
|
161 |
with gr.Blocks(css=css) as demo:
|
162 |
|
163 |
-
gr.Markdown(
|
|
|
|
|
|
|
164 |
|
165 |
status = gr.Markdown(f'Ready ({get_cache_status()})')
|
166 |
|
@@ -169,15 +177,12 @@ def create_app():
|
|
169 |
source = gr.Markdown('', visible=False)
|
170 |
state = gr.State([])
|
171 |
with gr.Column():
|
172 |
-
prompts = gr.Textbox(
|
|
|
173 |
with gr.Row():
|
174 |
|
175 |
family = gr.Dropdown(value='lit', choices=list(MODEL_MAP), label='Model family')
|
176 |
|
177 |
-
def make_variant(family_value, value=None):
|
178 |
-
choices = list(MODEL_MAP[family_value])
|
179 |
-
if value is None:
|
180 |
-
value = choices
|
181 |
make_variant = functools.partial(gr.Dropdown, label='Variant')
|
182 |
variant = make_variant(list(MODEL_MAP['lit']), value='B/16')
|
183 |
|
@@ -187,12 +192,19 @@ def create_app():
|
|
187 |
def make_bias(family, variant, res):
|
188 |
visible = family == 'siglip'
|
189 |
value = {
|
190 |
-
|
191 |
-
|
192 |
-
|
193 |
-
|
194 |
}.get((family, variant, res), -10.0)
|
195 |
-
return gr.Slider(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
196 |
bias = make_bias(family.value, variant.value, res.value)
|
197 |
|
198 |
def update_inputs(family, variant, res):
|
@@ -228,7 +240,10 @@ def create_app():
|
|
228 |
# a single `status` widget here, and store the computed information in
|
229 |
# `state`...
|
230 |
run.click(
|
231 |
-
fn=compute,
|
|
|
|
|
|
|
232 |
# ... then we use `state` to update UI components without showing a
|
233 |
# progress bar in their place.
|
234 |
status.change(fn=update_answers, inputs=state, outputs=answers)
|
@@ -238,9 +253,9 @@ def create_app():
|
|
238 |
gr.Examples(
|
239 |
examples=[
|
240 |
[
|
241 |
-
|
242 |
-
|
243 |
-
|
244 |
]
|
245 |
for ex in info
|
246 |
],
|
@@ -252,7 +267,7 @@ def create_app():
|
|
252 |
return demo
|
253 |
|
254 |
|
255 |
-
if __name__ ==
|
256 |
|
257 |
logging.basicConfig(level=logging.INFO,
|
258 |
format='%(asctime)s - %(levelname)s - %(message)s')
|
|
|
20 |
import gradio as gr
|
21 |
import PIL.Image
|
22 |
|
23 |
+
# pylint: disable=g-bad-import-order
|
24 |
import big_vision_contrastive_models as models
|
25 |
import gradio_helpers
|
26 |
|
|
|
39 |
MODEL_MAP = {
|
40 |
'lit': {
|
41 |
'B/16': {
|
42 |
+
224: 'lit_b16b',
|
43 |
},
|
44 |
'L/16': {
|
45 |
+
224: 'lit_l16l',
|
46 |
},
|
47 |
},
|
48 |
'siglip': {
|
49 |
'B/16': {
|
50 |
+
224: 'siglip_b16b_224',
|
51 |
+
256: 'siglip_b16b_256',
|
52 |
+
384: 'siglip_b16b_384',
|
53 |
+
512: 'siglip_b16b_512',
|
54 |
},
|
55 |
'L/16': {
|
56 |
+
256: 'siglip_l16l_256',
|
57 |
+
384: 'siglip_l16l_384',
|
58 |
},
|
59 |
'So400m/14': {
|
60 |
+
224: 'siglip_so400m14so440m_224',
|
61 |
+
384: 'siglip_so400m14so440m_384',
|
62 |
},
|
63 |
},
|
64 |
}
|
|
|
74 |
)
|
75 |
|
76 |
|
77 |
+
def compute(
|
78 |
+
image_path, prompts, family, variant, res, bias, progress=gr.Progress()
|
79 |
+
):
|
80 |
"""Loads model and computes answers."""
|
81 |
|
82 |
if image_path is None:
|
|
|
87 |
model_name = MODEL_MAP[family][variant][res]
|
88 |
config = models.MODEL_CONFIGS[model_name]
|
89 |
local_ckpt = gradio_helpers.get_disk_cache(
|
90 |
+
config.ckpt, progress=progress, max_cache_size_bytes=MAX_DISK_CACHE)
|
91 |
config = dataclasses.replace(config, ckpt=local_ckpt)
|
92 |
params, model = gradio_helpers.get_memory_cache(
|
93 |
config,
|
|
|
95 |
max_cache_size_bytes=MAX_RAM_CACHE,
|
96 |
progress=progress,
|
97 |
estimated_secs={
|
98 |
+
('lit', 'B/16'): 1,
|
99 |
+
('lit', 'L/16'): 2.5,
|
100 |
+
('siglip', 'B/16'): 9,
|
101 |
+
('siglip', 'L/16'): 28,
|
102 |
+
('siglip', 'So400m/14'): 36,
|
103 |
}.get((family, variant))
|
104 |
)
|
105 |
model: models.ContrastiveModel = model
|
|
|
111 |
image = PIL.Image.open(image_path)
|
112 |
next(it)
|
113 |
with gradio_helpers.timed('image features'):
|
114 |
+
zimg, unused_out = model.embed_images(
|
115 |
params, model.preprocess_images([image])
|
116 |
)
|
117 |
next(it)
|
118 |
with gradio_helpers.timed('text features'):
|
119 |
prompts = prompts.split('\n')
|
120 |
ztxt, out = model.embed_texts(
|
121 |
+
params, model.preprocess_texts(prompts)
|
122 |
)
|
123 |
next(it)
|
124 |
|
125 |
t = model.get_temperature(out)
|
126 |
+
text_probs = []
|
127 |
if family == 'lit':
|
128 |
text_probs = list(model.get_probabilities(zimg, ztxt, t, axis=-1)[0])
|
129 |
elif family == 'siglip':
|
|
|
145 |
"""Generates visible sliders for answers."""
|
146 |
answers = []
|
147 |
for prompt, prob in state[:MAX_ANSWERS]:
|
148 |
+
answers.append(
|
149 |
+
gr.Slider(value=round(100*prob, 2), label=prompt, visible=True))
|
150 |
while len(answers) < MAX_ANSWERS:
|
151 |
answers.append(gr.Slider(visible=False))
|
152 |
return answers
|
|
|
165 |
|
166 |
with gr.Blocks(css=css) as demo:
|
167 |
|
168 |
+
gr.Markdown(
|
169 |
+
'Gradio clone of the original '
|
170 |
+
'[LiT demo](https://google-research.github.io/vision_transformer/lit/).'
|
171 |
+
)
|
172 |
|
173 |
status = gr.Markdown(f'Ready ({get_cache_status()})')
|
174 |
|
|
|
177 |
source = gr.Markdown('', visible=False)
|
178 |
state = gr.State([])
|
179 |
with gr.Column():
|
180 |
+
prompts = gr.Textbox(
|
181 |
+
label='Prompts (press Shift-ENTER to add a prompt)')
|
182 |
with gr.Row():
|
183 |
|
184 |
family = gr.Dropdown(value='lit', choices=list(MODEL_MAP), label='Model family')
|
185 |
|
|
|
|
|
|
|
|
|
186 |
make_variant = functools.partial(gr.Dropdown, label='Variant')
|
187 |
variant = make_variant(list(MODEL_MAP['lit']), value='B/16')
|
188 |
|
|
|
192 |
def make_bias(family, variant, res):
|
193 |
visible = family == 'siglip'
|
194 |
value = {
|
195 |
+
('siglip', 'B/16', 224): -12.9,
|
196 |
+
('siglip', 'L/16', 256): -12.7,
|
197 |
+
('siglip', 'L/16', 256): -16.5,
|
198 |
+
# ...
|
199 |
}.get((family, variant, res), -10.0)
|
200 |
+
return gr.Slider(
|
201 |
+
value=value,
|
202 |
+
minimum=-20,
|
203 |
+
maximum=0,
|
204 |
+
step=0.05,
|
205 |
+
label='Bias',
|
206 |
+
visible=visible,
|
207 |
+
)
|
208 |
bias = make_bias(family.value, variant.value, res.value)
|
209 |
|
210 |
def update_inputs(family, variant, res):
|
|
|
240 |
# a single `status` widget here, and store the computed information in
|
241 |
# `state`...
|
242 |
run.click(
|
243 |
+
fn=compute,
|
244 |
+
inputs=[image, prompts, family, variant, res, bias],
|
245 |
+
outputs=[status, state],
|
246 |
+
)
|
247 |
# ... then we use `state` to update UI components without showing a
|
248 |
# progress bar in their place.
|
249 |
status.change(fn=update_answers, inputs=state, outputs=answers)
|
|
|
253 |
gr.Examples(
|
254 |
examples=[
|
255 |
[
|
256 |
+
IMG_URL_FMT.format(ex['id']),
|
257 |
+
ex['prompts'].replace(', ', '\n'),
|
258 |
+
'[source](%s)' % ex['source'],
|
259 |
]
|
260 |
for ex in info
|
261 |
],
|
|
|
267 |
return demo
|
268 |
|
269 |
|
270 |
+
if __name__ == '__main__':
|
271 |
|
272 |
logging.basicConfig(level=logging.INFO,
|
273 |
format='%(asctime)s - %(levelname)s - %(message)s')
|
big_vision_contrastive_models.py
CHANGED
@@ -27,15 +27,17 @@ import transformers
|
|
27 |
|
28 |
|
29 |
def _clone_git(url, destination_folder, commit_hash=None):
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
|
|
36 |
|
37 |
|
38 |
def setup(commit_hash=None):
|
|
|
39 |
for url, dst_name in (
|
40 |
('https://github.com/google-research/big_vision', 'big_vision_repo'),
|
41 |
('https://github.com/google/flaxformer', 'flaxformer_repo'),
|
@@ -43,11 +45,12 @@ def setup(commit_hash=None):
|
|
43 |
dst_path = os.path.join(tempfile.gettempdir(), dst_name)
|
44 |
if not os.path.exists(dst_path):
|
45 |
_clone_git(url, dst_path, commit_hash)
|
46 |
-
if not
|
47 |
sys.path.insert(0, dst_path)
|
48 |
|
49 |
|
50 |
class ContrastiveModelFamily(enum.Enum):
|
|
|
51 |
LIT = 'lit'
|
52 |
SIGLIP = 'siglip'
|
53 |
|
@@ -96,18 +99,21 @@ class ContrastiveModel:
|
|
96 |
return ztxt, out
|
97 |
|
98 |
def preprocess_texts(self, texts):
|
|
|
99 |
|
100 |
def tokenize_pad(text, seqlen=self.config.seqlen):
|
101 |
|
102 |
if self.config.family == ContrastiveModelFamily.LIT:
|
103 |
-
tokens = self.tokenizer_bert.encode(text, add_special_tokens=True)
|
|
|
104 |
tokens = tokens[:seqlen]
|
105 |
return tokens + [0] * (seqlen - len(tokens))
|
106 |
|
107 |
if self.config.family == ContrastiveModelFamily.SIGLIP:
|
108 |
tokens = self.tokenizer_sp.tokenize(text, add_eos=True)
|
109 |
if len(tokens) >= seqlen:
|
110 |
-
|
|
|
111 |
return tokens + [0] * (seqlen - len(tokens))
|
112 |
|
113 |
return np.array([tokenize_pad(text) for text in texts])
|
@@ -125,7 +131,9 @@ class ContrastiveModel:
|
|
125 |
]) / 127.5 - 1.0
|
126 |
|
127 |
def get_bias(self, out):
|
128 |
-
assert
|
|
|
|
|
129 |
return out['b'].item()
|
130 |
|
131 |
def get_temperature(self, out):
|
@@ -145,7 +153,9 @@ class ContrastiveModel:
|
|
145 |
return jax.nn.sigmoid(zimg @ ztxt.T * temperature + bias)
|
146 |
|
147 |
|
148 |
-
def _make_config(
|
|
|
|
|
149 |
if family == 'lit':
|
150 |
tokenizer = ckpt.replace('.npz', '.txt')
|
151 |
else:
|
@@ -153,11 +163,12 @@ def _make_config(family, variant, res, textvariant, ckpt, embdim, seqlen, vocab_
|
|
153 |
return ContrastiveModelConfig(
|
154 |
family=ContrastiveModelFamily(family), variant=variant, res=res,
|
155 |
textvariant=textvariant, embdim=embdim, seqlen=seqlen,
|
156 |
-
tokenizer=tokenizer, vocab_size=
|
157 |
ckpt=ckpt,
|
158 |
)
|
159 |
|
160 |
|
|
|
161 |
MODEL_CONFIGS = dict(
|
162 |
lit_b16b=_make_config('lit', 'B/16', 224, 'B', 'gs://vit_models/lit/LiT-B16B.npz', 768, 16, 32_000),
|
163 |
lit_l16l=_make_config('lit', 'L/16', 224, 'L', 'gs://vit_models/lit/LiT-L16L.npz', 1024, 16, 32_000),
|
@@ -173,6 +184,7 @@ MODEL_CONFIGS = dict(
|
|
173 |
siglip_so400m14so440m_224=_make_config('siglip', 'So400m/14', 224, 'So400m', 'gs://big_vision/siglip/webli_en_so400m_224_57633886.npz', 1152, 16, 32_000),
|
174 |
siglip_so400m14so400m_384=_make_config('siglip', 'So400m/14', 384, 'So400m', 'gs://big_vision/siglip/webli_en_so400m_384_58765454.npz', 1152, 64, 32_000),
|
175 |
)
|
|
|
176 |
|
177 |
|
178 |
@functools.cache
|
@@ -187,7 +199,6 @@ def load_tokenizer_sp(name_or_path):
|
|
187 |
|
188 |
@functools.cache
|
189 |
def load_tokenizer_bert(path):
|
190 |
-
tok = sentencepiece.SentencePieceProcessor()
|
191 |
if path.startswith('gs://'):
|
192 |
dst = tempfile.mktemp()
|
193 |
gfile.copy(path, dst)
|
@@ -203,7 +214,9 @@ def load_model(config, check_params=False):
|
|
203 |
cfg.image_model = 'vit' # TODO(lbeyer): remove later, default
|
204 |
if config.family == ContrastiveModelFamily.LIT:
|
205 |
cfg.text_model = 'proj.flaxformer.bert'
|
206 |
-
cfg.image = dict(
|
|
|
|
|
207 |
bert_config = {'B': 'base', 'L': 'large'}[config.textvariant]
|
208 |
cfg.text = dict(config=bert_config, head_zeroinit=False)
|
209 |
tokenizer_bert = load_tokenizer_bert(config.tokenizer)
|
@@ -211,10 +224,12 @@ def load_model(config, check_params=False):
|
|
211 |
if config.variant == 'L/16':
|
212 |
cfg.out_dim = (None, config.embdim) # (image_out_dim, text_out_dim)
|
213 |
else:
|
214 |
-
|
|
|
215 |
else:
|
216 |
cfg.image = dict(variant=config.variant, pool_type='map')
|
217 |
-
|
|
|
218 |
cfg.text = dict(variant=config.textvariant, vocab_size=config.vocab_size)
|
219 |
cfg.bias_init = -10.0
|
220 |
tokenizer_sp = load_tokenizer_sp(config.tokenizer)
|
@@ -223,7 +238,7 @@ def load_model(config, check_params=False):
|
|
223 |
cfg.temperature_init = 10.0
|
224 |
|
225 |
model_mod = importlib.import_module(
|
226 |
-
|
227 |
model = model_mod.Model(**cfg)
|
228 |
|
229 |
init_params = None # Faster but bypasses loading sanity-checks.
|
|
|
27 |
|
28 |
|
29 |
def _clone_git(url, destination_folder, commit_hash=None):
|
30 |
+
subprocess.run(
|
31 |
+
['git', 'clone', '--depth=1', url, destination_folder], check=True
|
32 |
+
)
|
33 |
+
if commit_hash:
|
34 |
+
subprocess.run(
|
35 |
+
['git', '-C', destination_folder, 'checkout', commit_hash], check=True
|
36 |
+
)
|
37 |
|
38 |
|
39 |
def setup(commit_hash=None):
|
40 |
+
"""Checks out required non-pypi code from Github."""
|
41 |
for url, dst_name in (
|
42 |
('https://github.com/google-research/big_vision', 'big_vision_repo'),
|
43 |
('https://github.com/google/flaxformer', 'flaxformer_repo'),
|
|
|
45 |
dst_path = os.path.join(tempfile.gettempdir(), dst_name)
|
46 |
if not os.path.exists(dst_path):
|
47 |
_clone_git(url, dst_path, commit_hash)
|
48 |
+
if dst_path not in sys.path:
|
49 |
sys.path.insert(0, dst_path)
|
50 |
|
51 |
|
52 |
class ContrastiveModelFamily(enum.Enum):
|
53 |
+
"""Defines a contrastive model family."""
|
54 |
LIT = 'lit'
|
55 |
SIGLIP = 'siglip'
|
56 |
|
|
|
99 |
return ztxt, out
|
100 |
|
101 |
def preprocess_texts(self, texts):
|
102 |
+
"""Converts texts to padded tokens."""
|
103 |
|
104 |
def tokenize_pad(text, seqlen=self.config.seqlen):
|
105 |
|
106 |
if self.config.family == ContrastiveModelFamily.LIT:
|
107 |
+
tokens = self.tokenizer_bert.encode(text, add_special_tokens=True)
|
108 |
+
tokens = tokens[:-1] # removes [SEP]
|
109 |
tokens = tokens[:seqlen]
|
110 |
return tokens + [0] * (seqlen - len(tokens))
|
111 |
|
112 |
if self.config.family == ContrastiveModelFamily.SIGLIP:
|
113 |
tokens = self.tokenizer_sp.tokenize(text, add_eos=True)
|
114 |
if len(tokens) >= seqlen:
|
115 |
+
eos_id = self.tokenizer_sp.eos_id()
|
116 |
+
return tokens[:seqlen - 1] + [eos_id] # "sticky" eos
|
117 |
return tokens + [0] * (seqlen - len(tokens))
|
118 |
|
119 |
return np.array([tokenize_pad(text) for text in texts])
|
|
|
131 |
]) / 127.5 - 1.0
|
132 |
|
133 |
def get_bias(self, out):
|
134 |
+
assert (
|
135 |
+
self.config.family == ContrastiveModelFamily.SIGLIP
|
136 |
+
), self.config.family
|
137 |
return out['b'].item()
|
138 |
|
139 |
def get_temperature(self, out):
|
|
|
153 |
return jax.nn.sigmoid(zimg @ ztxt.T * temperature + bias)
|
154 |
|
155 |
|
156 |
+
def _make_config(
|
157 |
+
family, variant, res, textvariant, ckpt, embdim, seqlen, vocab_size
|
158 |
+
):
|
159 |
if family == 'lit':
|
160 |
tokenizer = ckpt.replace('.npz', '.txt')
|
161 |
else:
|
|
|
163 |
return ContrastiveModelConfig(
|
164 |
family=ContrastiveModelFamily(family), variant=variant, res=res,
|
165 |
textvariant=textvariant, embdim=embdim, seqlen=seqlen,
|
166 |
+
tokenizer=tokenizer, vocab_size=vocab_size,
|
167 |
ckpt=ckpt,
|
168 |
)
|
169 |
|
170 |
|
171 |
+
# pylint: disable=line-too-long
|
172 |
MODEL_CONFIGS = dict(
|
173 |
lit_b16b=_make_config('lit', 'B/16', 224, 'B', 'gs://vit_models/lit/LiT-B16B.npz', 768, 16, 32_000),
|
174 |
lit_l16l=_make_config('lit', 'L/16', 224, 'L', 'gs://vit_models/lit/LiT-L16L.npz', 1024, 16, 32_000),
|
|
|
184 |
siglip_so400m14so440m_224=_make_config('siglip', 'So400m/14', 224, 'So400m', 'gs://big_vision/siglip/webli_en_so400m_224_57633886.npz', 1152, 16, 32_000),
|
185 |
siglip_so400m14so400m_384=_make_config('siglip', 'So400m/14', 384, 'So400m', 'gs://big_vision/siglip/webli_en_so400m_384_58765454.npz', 1152, 64, 32_000),
|
186 |
)
|
187 |
+
# pylint: enable=line-too-long
|
188 |
|
189 |
|
190 |
@functools.cache
|
|
|
199 |
|
200 |
@functools.cache
|
201 |
def load_tokenizer_bert(path):
|
|
|
202 |
if path.startswith('gs://'):
|
203 |
dst = tempfile.mktemp()
|
204 |
gfile.copy(path, dst)
|
|
|
214 |
cfg.image_model = 'vit' # TODO(lbeyer): remove later, default
|
215 |
if config.family == ContrastiveModelFamily.LIT:
|
216 |
cfg.text_model = 'proj.flaxformer.bert'
|
217 |
+
cfg.image = dict(
|
218 |
+
variant=config.variant, pool_type='tok', head_zeroinit=False
|
219 |
+
)
|
220 |
bert_config = {'B': 'base', 'L': 'large'}[config.textvariant]
|
221 |
cfg.text = dict(config=bert_config, head_zeroinit=False)
|
222 |
tokenizer_bert = load_tokenizer_bert(config.tokenizer)
|
|
|
224 |
if config.variant == 'L/16':
|
225 |
cfg.out_dim = (None, config.embdim) # (image_out_dim, text_out_dim)
|
226 |
else:
|
227 |
+
# (image_out_dim, text_out_dim)
|
228 |
+
cfg.out_dim = (config.embdim, config.embdim)
|
229 |
else:
|
230 |
cfg.image = dict(variant=config.variant, pool_type='map')
|
231 |
+
# TODO(lbeyer): remove later, default
|
232 |
+
cfg.text_model = 'proj.image_text.text_transformer'
|
233 |
cfg.text = dict(variant=config.textvariant, vocab_size=config.vocab_size)
|
234 |
cfg.bias_init = -10.0
|
235 |
tokenizer_sp = load_tokenizer_sp(config.tokenizer)
|
|
|
238 |
cfg.temperature_init = 10.0
|
239 |
|
240 |
model_mod = importlib.import_module(
|
241 |
+
'big_vision.models.proj.image_text.two_towers')
|
242 |
model = model_mod.Model(**cfg)
|
243 |
|
244 |
init_params = None # Faster but bypasses loading sanity-checks.
|
gradio_helpers.py
CHANGED
@@ -30,8 +30,9 @@ def timed(name):
|
|
30 |
logging.info('Timed %s: %.1f secs', name, timing['secs'])
|
31 |
|
32 |
|
33 |
-
|
34 |
-
|
|
|
35 |
"""Copies a file with progress bar.
|
36 |
|
37 |
Args:
|
@@ -39,6 +40,7 @@ def copy_file(src, dst, *, progress=None, block_size=1024 * 1024 * 10, overwrite
|
|
39 |
dst: Destination file. Path must be readable by `tf.io.gfile`.
|
40 |
progress: An object with a `.tqdm` attribute, or `None`.
|
41 |
block_size: Size of individual blocks to be read/written.
|
|
|
42 |
"""
|
43 |
if os.path.dirname(dst):
|
44 |
os.makedirs(os.path.dirname(dst), exist_ok=True)
|
@@ -87,7 +89,9 @@ def _get_array_sizes(tree):
|
|
87 |
return [getattr(x, 'nbytes', 0) for x in jax.tree_leaves(tree)]
|
88 |
|
89 |
|
90 |
-
def get_memory_cache(
|
|
|
|
|
91 |
"""Keeps cache below specified size by removing elements not last accessed."""
|
92 |
if key in _memory_cache:
|
93 |
_memory_cache[key] = _memory_cache.pop(key) # updated "last accessed" order
|
|
|
30 |
logging.info('Timed %s: %.1f secs', name, timing['secs'])
|
31 |
|
32 |
|
33 |
+
def copy_file(
|
34 |
+
src, dst, *, progress=None, block_size=1024 * 1024 * 10, overwrite=False
|
35 |
+
):
|
36 |
"""Copies a file with progress bar.
|
37 |
|
38 |
Args:
|
|
|
40 |
dst: Destination file. Path must be readable by `tf.io.gfile`.
|
41 |
progress: An object with a `.tqdm` attribute, or `None`.
|
42 |
block_size: Size of individual blocks to be read/written.
|
43 |
+
overwrite: If `True`, overwrite `dst` if it exists.
|
44 |
"""
|
45 |
if os.path.dirname(dst):
|
46 |
os.makedirs(os.path.dirname(dst), exist_ok=True)
|
|
|
89 |
return [getattr(x, 'nbytes', 0) for x in jax.tree_leaves(tree)]
|
90 |
|
91 |
|
92 |
+
def get_memory_cache(
|
93 |
+
key, getter, max_cache_size_bytes, progress=None, estimated_secs=None
|
94 |
+
):
|
95 |
"""Keeps cache below specified size by removing elements not last accessed."""
|
96 |
if key in _memory_cache:
|
97 |
_memory_cache[key] = _memory_cache.pop(key) # updated "last accessed" order
|