carlosgomes98 commited on
Commit
cc6ba7e
1 Parent(s): 64ad424

Update to use more parseable config

Browse files
Files changed (1) hide show
  1. app.py +13 -33
app.py CHANGED
@@ -261,25 +261,18 @@ def predict_on_images(data_files: list, mask_ratio: float, yaml_file_path: str,
261
  params = yaml.safe_load(f)
262
 
263
  # data related
264
- num_frames = params['num_frames']
265
- img_size = params['img_size']
266
- bands = params['bands']
267
- mean = params['data_mean']
268
- std = params['data_std']
269
-
270
- # model related
271
- depth = params['depth']
272
- patch_size = params['patch_size']
273
- embed_dim = params['embed_dim']
274
- num_heads = params['num_heads']
275
- tubelet_size = params['tubelet_size']
276
- decoder_embed_dim = params['decoder_embed_dim']
277
- decoder_num_heads = params['decoder_num_heads']
278
- decoder_depth = params['decoder_depth']
279
-
280
- batch_size = params['batch_size']
281
-
282
- mask_ratio = params['mask_ratio'] if mask_ratio is None else mask_ratio
283
 
284
  # We must have *num_frames* files to build one example!
285
  assert len(data_files) == num_frames, "File list must be equal to expected number of frames."
@@ -298,20 +291,7 @@ def predict_on_images(data_files: list, mask_ratio: float, yaml_file_path: str,
298
  # Create model and load checkpoint -------------------------------------------------------------
299
 
300
  model = MaskedAutoencoderViT(
301
- img_size=img_size,
302
- patch_size=patch_size,
303
- num_frames=num_frames,
304
- tubelet_size=tubelet_size,
305
- in_chans=len(bands),
306
- embed_dim=embed_dim,
307
- depth=depth,
308
- num_heads=num_heads,
309
- decoder_embed_dim=decoder_embed_dim,
310
- decoder_depth=decoder_depth,
311
- decoder_num_heads=decoder_num_heads,
312
- mlp_ratio=4.,
313
- norm_layer=functools.partial(torch.nn.LayerNorm, eps=1e-6),
314
- norm_pix_loss=False)
315
 
316
  total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
317
  print(f"\n--> Model has {total_params:,} parameters.\n")
 
261
  params = yaml.safe_load(f)
262
 
263
  # data related
264
+ train_params = params["train_params"]
265
+ num_frames = train_params['num_frames']
266
+ img_size = train_params['img_size']
267
+ bands = train_params['bands']
268
+ mean = train_params['data_mean']
269
+ std = train_params['data_std']
270
+
271
+ model_params = params["model_args"]
272
+
273
+ batch_size = 8
274
+
275
+ mask_ratio = train_params['mask_ratio'] if mask_ratio is None else mask_ratio
 
 
 
 
 
 
 
276
 
277
  # We must have *num_frames* files to build one example!
278
  assert len(data_files) == num_frames, "File list must be equal to expected number of frames."
 
291
  # Create model and load checkpoint -------------------------------------------------------------
292
 
293
  model = MaskedAutoencoderViT(
294
+ **model_params)
 
 
 
 
 
 
 
 
 
 
 
 
 
295
 
296
  total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
297
  print(f"\n--> Model has {total_params:,} parameters.\n")