andsteing committed on
Commit
10ff50b
·
2 Parent(s): a6ee350 3cfc2e7

Merge remote-tracking branch 'bv/main'

Browse files
Files changed (3) hide show
  1. app.py +51 -36
  2. big_vision_contrastive_models.py +32 -17
  3. gradio_helpers.py +7 -3
app.py CHANGED
@@ -20,6 +20,7 @@ import urllib.request
20
  import gradio as gr
21
  import PIL.Image
22
 
 
23
  import big_vision_contrastive_models as models
24
  import gradio_helpers
25
 
@@ -38,26 +39,26 @@ LOADING_SECS = {'B/16': 5, 'L/16': 10, 'So400m/14': 10}
38
  MODEL_MAP = {
39
  'lit': {
40
  'B/16': {
41
- 224: 'lit_b16b',
42
  },
43
  'L/16': {
44
- 224: 'lit_l16l',
45
  },
46
  },
47
  'siglip': {
48
  'B/16': {
49
- 224: 'siglip_b16b_224',
50
- 256: 'siglip_b16b_256',
51
- 384: 'siglip_b16b_384',
52
- 512: 'siglip_b16b_512',
53
  },
54
  'L/16': {
55
- 256: 'siglip_l16l_256',
56
- 384: 'siglip_l16l_384',
57
  },
58
  'So400m/14': {
59
- 224: 'siglip_so400m14so440m_224',
60
- 384: 'siglip_so400m14so440m_384',
61
  },
62
  },
63
  }
@@ -73,7 +74,9 @@ def get_cache_status():
73
  )
74
 
75
 
76
- def compute(image_path, prompts, family, variant, res, bias, progress=gr.Progress()):
 
 
77
  """Loads model and computes answers."""
78
 
79
  if image_path is None:
@@ -84,7 +87,7 @@ def compute(image_path, prompts, family, variant, res, bias, progress=gr.Progres
84
  model_name = MODEL_MAP[family][variant][res]
85
  config = models.MODEL_CONFIGS[model_name]
86
  local_ckpt = gradio_helpers.get_disk_cache(
87
- config.ckpt, progress=progress, max_cache_size_bytes=MAX_DISK_CACHE)
88
  config = dataclasses.replace(config, ckpt=local_ckpt)
89
  params, model = gradio_helpers.get_memory_cache(
90
  config,
@@ -92,11 +95,11 @@ def compute(image_path, prompts, family, variant, res, bias, progress=gr.Progres
92
  max_cache_size_bytes=MAX_RAM_CACHE,
93
  progress=progress,
94
  estimated_secs={
95
- ('lit', 'B/16'): 1,
96
- ('lit', 'L/16'): 2.5,
97
- ('siglip', 'B/16'): 9,
98
- ('siglip', 'L/16'): 28,
99
- ('siglip', 'So400m/14'): 36,
100
  }.get((family, variant))
101
  )
102
  model: models.ContrastiveModel = model
@@ -108,18 +111,19 @@ def compute(image_path, prompts, family, variant, res, bias, progress=gr.Progres
108
  image = PIL.Image.open(image_path)
109
  next(it)
110
  with gradio_helpers.timed('image features'):
111
- zimg, out = model.embed_images(
112
  params, model.preprocess_images([image])
113
  )
114
  next(it)
115
  with gradio_helpers.timed('text features'):
116
  prompts = prompts.split('\n')
117
  ztxt, out = model.embed_texts(
118
- params, model.preprocess_texts(prompts)
119
  )
120
  next(it)
121
 
122
  t = model.get_temperature(out)
 
123
  if family == 'lit':
124
  text_probs = list(model.get_probabilities(zimg, ztxt, t, axis=-1)[0])
125
  elif family == 'siglip':
@@ -141,7 +145,8 @@ def update_answers(state):
141
  """Generates visible sliders for answers."""
142
  answers = []
143
  for prompt, prob in state[:MAX_ANSWERS]:
144
- answers.append(gr.Slider(value=round(100*prob, 2), label=prompt, visible=True))
 
145
  while len(answers) < MAX_ANSWERS:
146
  answers.append(gr.Slider(visible=False))
147
  return answers
@@ -160,7 +165,10 @@ def create_app():
160
 
161
  with gr.Blocks(css=css) as demo:
162
 
163
- gr.Markdown('Gradio clone of the original [LiT demo](https://google-research.github.io/vision_transformer/lit/).')
 
 
 
164
 
165
  status = gr.Markdown(f'Ready ({get_cache_status()})')
166
 
@@ -169,15 +177,12 @@ def create_app():
169
  source = gr.Markdown('', visible=False)
170
  state = gr.State([])
171
  with gr.Column():
172
- prompts = gr.Textbox(label='Prompts (press Shift-ENTER to add a prompt)')
 
173
  with gr.Row():
174
 
175
  family = gr.Dropdown(value='lit', choices=list(MODEL_MAP), label='Model family')
176
 
177
- def make_variant(family_value, value=None):
178
- choices = list(MODEL_MAP[family_value])
179
- if value is None:
180
- value = choices
181
  make_variant = functools.partial(gr.Dropdown, label='Variant')
182
  variant = make_variant(list(MODEL_MAP['lit']), value='B/16')
183
 
@@ -187,12 +192,19 @@ def create_app():
187
  def make_bias(family, variant, res):
188
  visible = family == 'siglip'
189
  value = {
190
- ('siglip', 'B/16', 224): -12.9,
191
- ('siglip', 'L/16', 256): -12.7,
192
- ('siglip', 'L/16', 256): -16.5,
193
- # ...
194
  }.get((family, variant, res), -10.0)
195
- return gr.Slider(value=value, minimum=-20, maximum=0, step=0.05, label='Bias', visible=visible)
 
 
 
 
 
 
 
196
  bias = make_bias(family.value, variant.value, res.value)
197
 
198
  def update_inputs(family, variant, res):
@@ -228,7 +240,10 @@ def create_app():
228
  # a single `status` widget here, and store the computed information in
229
  # `state`...
230
  run.click(
231
- fn=compute, inputs=[image, prompts, family, variant, res, bias], outputs=[status, state])
 
 
 
232
  # ... then we use `state` to update UI components without showing a
233
  # progress bar in their place.
234
  status.change(fn=update_answers, inputs=state, outputs=answers)
@@ -238,9 +253,9 @@ def create_app():
238
  gr.Examples(
239
  examples=[
240
  [
241
- IMG_URL_FMT.format(ex['id']),
242
- ex['prompts'].replace(', ', '\n'),
243
- '[source](%s)' % ex['source'],
244
  ]
245
  for ex in info
246
  ],
@@ -252,7 +267,7 @@ def create_app():
252
  return demo
253
 
254
 
255
- if __name__ == "__main__":
256
 
257
  logging.basicConfig(level=logging.INFO,
258
  format='%(asctime)s - %(levelname)s - %(message)s')
 
20
  import gradio as gr
21
  import PIL.Image
22
 
23
+ # pylint: disable=g-bad-import-order
24
  import big_vision_contrastive_models as models
25
  import gradio_helpers
26
 
 
39
  MODEL_MAP = {
40
  'lit': {
41
  'B/16': {
42
+ 224: 'lit_b16b',
43
  },
44
  'L/16': {
45
+ 224: 'lit_l16l',
46
  },
47
  },
48
  'siglip': {
49
  'B/16': {
50
+ 224: 'siglip_b16b_224',
51
+ 256: 'siglip_b16b_256',
52
+ 384: 'siglip_b16b_384',
53
+ 512: 'siglip_b16b_512',
54
  },
55
  'L/16': {
56
+ 256: 'siglip_l16l_256',
57
+ 384: 'siglip_l16l_384',
58
  },
59
  'So400m/14': {
60
+ 224: 'siglip_so400m14so440m_224',
61
+ 384: 'siglip_so400m14so440m_384',
62
  },
63
  },
64
  }
 
74
  )
75
 
76
 
77
+ def compute(
78
+ image_path, prompts, family, variant, res, bias, progress=gr.Progress()
79
+ ):
80
  """Loads model and computes answers."""
81
 
82
  if image_path is None:
 
87
  model_name = MODEL_MAP[family][variant][res]
88
  config = models.MODEL_CONFIGS[model_name]
89
  local_ckpt = gradio_helpers.get_disk_cache(
90
+ config.ckpt, progress=progress, max_cache_size_bytes=MAX_DISK_CACHE)
91
  config = dataclasses.replace(config, ckpt=local_ckpt)
92
  params, model = gradio_helpers.get_memory_cache(
93
  config,
 
95
  max_cache_size_bytes=MAX_RAM_CACHE,
96
  progress=progress,
97
  estimated_secs={
98
+ ('lit', 'B/16'): 1,
99
+ ('lit', 'L/16'): 2.5,
100
+ ('siglip', 'B/16'): 9,
101
+ ('siglip', 'L/16'): 28,
102
+ ('siglip', 'So400m/14'): 36,
103
  }.get((family, variant))
104
  )
105
  model: models.ContrastiveModel = model
 
111
  image = PIL.Image.open(image_path)
112
  next(it)
113
  with gradio_helpers.timed('image features'):
114
+ zimg, unused_out = model.embed_images(
115
  params, model.preprocess_images([image])
116
  )
117
  next(it)
118
  with gradio_helpers.timed('text features'):
119
  prompts = prompts.split('\n')
120
  ztxt, out = model.embed_texts(
121
+ params, model.preprocess_texts(prompts)
122
  )
123
  next(it)
124
 
125
  t = model.get_temperature(out)
126
+ text_probs = []
127
  if family == 'lit':
128
  text_probs = list(model.get_probabilities(zimg, ztxt, t, axis=-1)[0])
129
  elif family == 'siglip':
 
145
  """Generates visible sliders for answers."""
146
  answers = []
147
  for prompt, prob in state[:MAX_ANSWERS]:
148
+ answers.append(
149
+ gr.Slider(value=round(100*prob, 2), label=prompt, visible=True))
150
  while len(answers) < MAX_ANSWERS:
151
  answers.append(gr.Slider(visible=False))
152
  return answers
 
165
 
166
  with gr.Blocks(css=css) as demo:
167
 
168
+ gr.Markdown(
169
+ 'Gradio clone of the original '
170
+ '[LiT demo](https://google-research.github.io/vision_transformer/lit/).'
171
+ )
172
 
173
  status = gr.Markdown(f'Ready ({get_cache_status()})')
174
 
 
177
  source = gr.Markdown('', visible=False)
178
  state = gr.State([])
179
  with gr.Column():
180
+ prompts = gr.Textbox(
181
+ label='Prompts (press Shift-ENTER to add a prompt)')
182
  with gr.Row():
183
 
184
  family = gr.Dropdown(value='lit', choices=list(MODEL_MAP), label='Model family')
185
 
 
 
 
 
186
  make_variant = functools.partial(gr.Dropdown, label='Variant')
187
  variant = make_variant(list(MODEL_MAP['lit']), value='B/16')
188
 
 
192
  def make_bias(family, variant, res):
193
  visible = family == 'siglip'
194
  value = {
195
+ ('siglip', 'B/16', 224): -12.9,
196
+ ('siglip', 'L/16', 256): -12.7,
197
+ ('siglip', 'L/16', 256): -16.5,
198
+ # ...
199
  }.get((family, variant, res), -10.0)
200
+ return gr.Slider(
201
+ value=value,
202
+ minimum=-20,
203
+ maximum=0,
204
+ step=0.05,
205
+ label='Bias',
206
+ visible=visible,
207
+ )
208
  bias = make_bias(family.value, variant.value, res.value)
209
 
210
  def update_inputs(family, variant, res):
 
240
  # a single `status` widget here, and store the computed information in
241
  # `state`...
242
  run.click(
243
+ fn=compute,
244
+ inputs=[image, prompts, family, variant, res, bias],
245
+ outputs=[status, state],
246
+ )
247
  # ... then we use `state` to update UI components without showing a
248
  # progress bar in their place.
249
  status.change(fn=update_answers, inputs=state, outputs=answers)
 
253
  gr.Examples(
254
  examples=[
255
  [
256
+ IMG_URL_FMT.format(ex['id']),
257
+ ex['prompts'].replace(', ', '\n'),
258
+ '[source](%s)' % ex['source'],
259
  ]
260
  for ex in info
261
  ],
 
267
  return demo
268
 
269
 
270
+ if __name__ == '__main__':
271
 
272
  logging.basicConfig(level=logging.INFO,
273
  format='%(asctime)s - %(levelname)s - %(message)s')
big_vision_contrastive_models.py CHANGED
@@ -27,15 +27,17 @@ import transformers
27
 
28
 
29
  def _clone_git(url, destination_folder, commit_hash=None):
30
- subprocess.run([
31
- 'git', 'clone', '--depth=1',
32
- url, destination_folder
33
- ], check=True)
34
- if commit_hash:
35
- subprocess.run(['git', '-C', destination_folder, 'checkout', commit_hash], check=True)
 
36
 
37
 
38
  def setup(commit_hash=None):
 
39
  for url, dst_name in (
40
  ('https://github.com/google-research/big_vision', 'big_vision_repo'),
41
  ('https://github.com/google/flaxformer', 'flaxformer_repo'),
@@ -43,11 +45,12 @@ def setup(commit_hash=None):
43
  dst_path = os.path.join(tempfile.gettempdir(), dst_name)
44
  if not os.path.exists(dst_path):
45
  _clone_git(url, dst_path, commit_hash)
46
- if not dst_path in sys.path:
47
  sys.path.insert(0, dst_path)
48
 
49
 
50
  class ContrastiveModelFamily(enum.Enum):
 
51
  LIT = 'lit'
52
  SIGLIP = 'siglip'
53
 
@@ -96,18 +99,21 @@ class ContrastiveModel:
96
  return ztxt, out
97
 
98
  def preprocess_texts(self, texts):
 
99
 
100
  def tokenize_pad(text, seqlen=self.config.seqlen):
101
 
102
  if self.config.family == ContrastiveModelFamily.LIT:
103
- tokens = self.tokenizer_bert.encode(text, add_special_tokens=True)[:-1] # removes [SEP]
 
104
  tokens = tokens[:seqlen]
105
  return tokens + [0] * (seqlen - len(tokens))
106
 
107
  if self.config.family == ContrastiveModelFamily.SIGLIP:
108
  tokens = self.tokenizer_sp.tokenize(text, add_eos=True)
109
  if len(tokens) >= seqlen:
110
- return tokens[:seqlen - 1] + [tok.eos_id()] # "sticky" eos
 
111
  return tokens + [0] * (seqlen - len(tokens))
112
 
113
  return np.array([tokenize_pad(text) for text in texts])
@@ -125,7 +131,9 @@ class ContrastiveModel:
125
  ]) / 127.5 - 1.0
126
 
127
  def get_bias(self, out):
128
- assert self.config.family == ContrastiveModelFamily.SIGLIP, self.config.family
 
 
129
  return out['b'].item()
130
 
131
  def get_temperature(self, out):
@@ -145,7 +153,9 @@ class ContrastiveModel:
145
  return jax.nn.sigmoid(zimg @ ztxt.T * temperature + bias)
146
 
147
 
148
- def _make_config(family, variant, res, textvariant, ckpt, embdim, seqlen, vocab_size):
 
 
149
  if family == 'lit':
150
  tokenizer = ckpt.replace('.npz', '.txt')
151
  else:
@@ -153,11 +163,12 @@ def _make_config(family, variant, res, textvariant, ckpt, embdim, seqlen, vocab_
153
  return ContrastiveModelConfig(
154
  family=ContrastiveModelFamily(family), variant=variant, res=res,
155
  textvariant=textvariant, embdim=embdim, seqlen=seqlen,
156
- tokenizer=tokenizer, vocab_size=32_000,
157
  ckpt=ckpt,
158
  )
159
 
160
 
 
161
  MODEL_CONFIGS = dict(
162
  lit_b16b=_make_config('lit', 'B/16', 224, 'B', 'gs://vit_models/lit/LiT-B16B.npz', 768, 16, 32_000),
163
  lit_l16l=_make_config('lit', 'L/16', 224, 'L', 'gs://vit_models/lit/LiT-L16L.npz', 1024, 16, 32_000),
@@ -173,6 +184,7 @@ MODEL_CONFIGS = dict(
173
  siglip_so400m14so440m_224=_make_config('siglip', 'So400m/14', 224, 'So400m', 'gs://big_vision/siglip/webli_en_so400m_224_57633886.npz', 1152, 16, 32_000),
174
  siglip_so400m14so400m_384=_make_config('siglip', 'So400m/14', 384, 'So400m', 'gs://big_vision/siglip/webli_en_so400m_384_58765454.npz', 1152, 64, 32_000),
175
  )
 
176
 
177
 
178
  @functools.cache
@@ -187,7 +199,6 @@ def load_tokenizer_sp(name_or_path):
187
 
188
  @functools.cache
189
  def load_tokenizer_bert(path):
190
- tok = sentencepiece.SentencePieceProcessor()
191
  if path.startswith('gs://'):
192
  dst = tempfile.mktemp()
193
  gfile.copy(path, dst)
@@ -203,7 +214,9 @@ def load_model(config, check_params=False):
203
  cfg.image_model = 'vit' # TODO(lbeyer): remove later, default
204
  if config.family == ContrastiveModelFamily.LIT:
205
  cfg.text_model = 'proj.flaxformer.bert'
206
- cfg.image = dict(variant=config.variant, pool_type='tok', head_zeroinit=False)
 
 
207
  bert_config = {'B': 'base', 'L': 'large'}[config.textvariant]
208
  cfg.text = dict(config=bert_config, head_zeroinit=False)
209
  tokenizer_bert = load_tokenizer_bert(config.tokenizer)
@@ -211,10 +224,12 @@ def load_model(config, check_params=False):
211
  if config.variant == 'L/16':
212
  cfg.out_dim = (None, config.embdim) # (image_out_dim, text_out_dim)
213
  else:
214
- cfg.out_dim = (config.embdim, config.embdim) # (image_out_dim, text_out_dim)
 
215
  else:
216
  cfg.image = dict(variant=config.variant, pool_type='map')
217
- cfg.text_model = 'proj.image_text.text_transformer' # TODO(lbeyer): remove later, default
 
218
  cfg.text = dict(variant=config.textvariant, vocab_size=config.vocab_size)
219
  cfg.bias_init = -10.0
220
  tokenizer_sp = load_tokenizer_sp(config.tokenizer)
@@ -223,7 +238,7 @@ def load_model(config, check_params=False):
223
  cfg.temperature_init = 10.0
224
 
225
  model_mod = importlib.import_module(
226
- 'big_vision.models.proj.image_text.two_towers')
227
  model = model_mod.Model(**cfg)
228
 
229
  init_params = None # Faster but bypasses loading sanity-checks.
 
27
 
28
 
29
  def _clone_git(url, destination_folder, commit_hash=None):
30
+ subprocess.run(
31
+ ['git', 'clone', '--depth=1', url, destination_folder], check=True
32
+ )
33
+ if commit_hash:
34
+ subprocess.run(
35
+ ['git', '-C', destination_folder, 'checkout', commit_hash], check=True
36
+ )
37
 
38
 
39
  def setup(commit_hash=None):
40
+ """Checks out required non-pypi code from Github."""
41
  for url, dst_name in (
42
  ('https://github.com/google-research/big_vision', 'big_vision_repo'),
43
  ('https://github.com/google/flaxformer', 'flaxformer_repo'),
 
45
  dst_path = os.path.join(tempfile.gettempdir(), dst_name)
46
  if not os.path.exists(dst_path):
47
  _clone_git(url, dst_path, commit_hash)
48
+ if dst_path not in sys.path:
49
  sys.path.insert(0, dst_path)
50
 
51
 
52
  class ContrastiveModelFamily(enum.Enum):
53
+ """Defines a contrastive model family."""
54
  LIT = 'lit'
55
  SIGLIP = 'siglip'
56
 
 
99
  return ztxt, out
100
 
101
  def preprocess_texts(self, texts):
102
+ """Converts texts to padded tokens."""
103
 
104
  def tokenize_pad(text, seqlen=self.config.seqlen):
105
 
106
  if self.config.family == ContrastiveModelFamily.LIT:
107
+ tokens = self.tokenizer_bert.encode(text, add_special_tokens=True)
108
+ tokens = tokens[:-1] # removes [SEP]
109
  tokens = tokens[:seqlen]
110
  return tokens + [0] * (seqlen - len(tokens))
111
 
112
  if self.config.family == ContrastiveModelFamily.SIGLIP:
113
  tokens = self.tokenizer_sp.tokenize(text, add_eos=True)
114
  if len(tokens) >= seqlen:
115
+ eos_id = self.tokenizer_sp.eos_id()
116
+ return tokens[:seqlen - 1] + [eos_id] # "sticky" eos
117
  return tokens + [0] * (seqlen - len(tokens))
118
 
119
  return np.array([tokenize_pad(text) for text in texts])
 
131
  ]) / 127.5 - 1.0
132
 
133
  def get_bias(self, out):
134
+ assert (
135
+ self.config.family == ContrastiveModelFamily.SIGLIP
136
+ ), self.config.family
137
  return out['b'].item()
138
 
139
  def get_temperature(self, out):
 
153
  return jax.nn.sigmoid(zimg @ ztxt.T * temperature + bias)
154
 
155
 
156
+ def _make_config(
157
+ family, variant, res, textvariant, ckpt, embdim, seqlen, vocab_size
158
+ ):
159
  if family == 'lit':
160
  tokenizer = ckpt.replace('.npz', '.txt')
161
  else:
 
163
  return ContrastiveModelConfig(
164
  family=ContrastiveModelFamily(family), variant=variant, res=res,
165
  textvariant=textvariant, embdim=embdim, seqlen=seqlen,
166
+ tokenizer=tokenizer, vocab_size=vocab_size,
167
  ckpt=ckpt,
168
  )
169
 
170
 
171
+ # pylint: disable=line-too-long
172
  MODEL_CONFIGS = dict(
173
  lit_b16b=_make_config('lit', 'B/16', 224, 'B', 'gs://vit_models/lit/LiT-B16B.npz', 768, 16, 32_000),
174
  lit_l16l=_make_config('lit', 'L/16', 224, 'L', 'gs://vit_models/lit/LiT-L16L.npz', 1024, 16, 32_000),
 
184
  siglip_so400m14so440m_224=_make_config('siglip', 'So400m/14', 224, 'So400m', 'gs://big_vision/siglip/webli_en_so400m_224_57633886.npz', 1152, 16, 32_000),
185
  siglip_so400m14so400m_384=_make_config('siglip', 'So400m/14', 384, 'So400m', 'gs://big_vision/siglip/webli_en_so400m_384_58765454.npz', 1152, 64, 32_000),
186
  )
187
+ # pylint: enable=line-too-long
188
 
189
 
190
  @functools.cache
 
199
 
200
  @functools.cache
201
  def load_tokenizer_bert(path):
 
202
  if path.startswith('gs://'):
203
  dst = tempfile.mktemp()
204
  gfile.copy(path, dst)
 
214
  cfg.image_model = 'vit' # TODO(lbeyer): remove later, default
215
  if config.family == ContrastiveModelFamily.LIT:
216
  cfg.text_model = 'proj.flaxformer.bert'
217
+ cfg.image = dict(
218
+ variant=config.variant, pool_type='tok', head_zeroinit=False
219
+ )
220
  bert_config = {'B': 'base', 'L': 'large'}[config.textvariant]
221
  cfg.text = dict(config=bert_config, head_zeroinit=False)
222
  tokenizer_bert = load_tokenizer_bert(config.tokenizer)
 
224
  if config.variant == 'L/16':
225
  cfg.out_dim = (None, config.embdim) # (image_out_dim, text_out_dim)
226
  else:
227
+ # (image_out_dim, text_out_dim)
228
+ cfg.out_dim = (config.embdim, config.embdim)
229
  else:
230
  cfg.image = dict(variant=config.variant, pool_type='map')
231
+ # TODO(lbeyer): remove later, default
232
+ cfg.text_model = 'proj.image_text.text_transformer'
233
  cfg.text = dict(variant=config.textvariant, vocab_size=config.vocab_size)
234
  cfg.bias_init = -10.0
235
  tokenizer_sp = load_tokenizer_sp(config.tokenizer)
 
238
  cfg.temperature_init = 10.0
239
 
240
  model_mod = importlib.import_module(
241
+ 'big_vision.models.proj.image_text.two_towers')
242
  model = model_mod.Model(**cfg)
243
 
244
  init_params = None # Faster but bypasses loading sanity-checks.
gradio_helpers.py CHANGED
@@ -30,8 +30,9 @@ def timed(name):
30
  logging.info('Timed %s: %.1f secs', name, timing['secs'])
31
 
32
 
33
-
34
- def copy_file(src, dst, *, progress=None, block_size=1024 * 1024 * 10, overwrite=False):
 
35
  """Copies a file with progress bar.
36
 
37
  Args:
@@ -39,6 +40,7 @@ def copy_file(src, dst, *, progress=None, block_size=1024 * 1024 * 10, overwrite
39
  dst: Destination file. Path must be readable by `tf.io.gfile`.
40
  progress: An object with a `.tqdm` attribute, or `None`.
41
  block_size: Size of individual blocks to be read/written.
 
42
  """
43
  if os.path.dirname(dst):
44
  os.makedirs(os.path.dirname(dst), exist_ok=True)
@@ -87,7 +89,9 @@ def _get_array_sizes(tree):
87
  return [getattr(x, 'nbytes', 0) for x in jax.tree_leaves(tree)]
88
 
89
 
90
- def get_memory_cache(key, getter, max_cache_size_bytes, progress=None, estimated_secs=None):
 
 
91
  """Keeps cache below specified size by removing elements not last accessed."""
92
  if key in _memory_cache:
93
  _memory_cache[key] = _memory_cache.pop(key) # updated "last accessed" order
 
30
  logging.info('Timed %s: %.1f secs', name, timing['secs'])
31
 
32
 
33
+ def copy_file(
34
+ src, dst, *, progress=None, block_size=1024 * 1024 * 10, overwrite=False
35
+ ):
36
  """Copies a file with progress bar.
37
 
38
  Args:
 
40
  dst: Destination file. Path must be readable by `tf.io.gfile`.
41
  progress: An object with a `.tqdm` attribute, or `None`.
42
  block_size: Size of individual blocks to be read/written.
43
+ overwrite: If `True`, overwrite `dst` if it exists.
44
  """
45
  if os.path.dirname(dst):
46
  os.makedirs(os.path.dirname(dst), exist_ok=True)
 
89
  return [getattr(x, 'nbytes', 0) for x in jax.tree_leaves(tree)]
90
 
91
 
92
+ def get_memory_cache(
93
+ key, getter, max_cache_size_bytes, progress=None, estimated_secs=None
94
+ ):
95
  """Keeps cache below specified size by removing elements not last accessed."""
96
  if key in _memory_cache:
97
  _memory_cache[key] = _memory_cache.pop(key) # updated "last accessed" order