vict0rsch commited on
Commit
cc3d7ef
·
1 Parent(s): 9da944e

add option to return intermediates as tensors

Browse files
Files changed (1) hide show
  1. climategan/trainer.py +18 -13
climategan/trainer.py CHANGED
@@ -334,19 +334,24 @@ class Trainer:
334
  output_data["smog"] = smog
335
 
336
  if return_intermediates:
337
- output_data["mask"] = (
338
- ((mask > bin_value) * 255).cpu().numpy().astype(np.uint8)
339
- )
340
- output_data["depth"] = (
341
- normalize_tensor(depth).cpu().squeeze(1).numpy().astype(np.uint8) * 255
342
- )
343
- output_data["segmentation"] = (
344
- decode_segmap_merged_labels(segmentation, "r", False)
345
- .cpu()
346
- .permute(0, 2, 3, 1)
347
- .numpy()
348
- .astype(np.uint8)
349
- )
 
 
 
 
 
350
 
351
  return output_data
352
 
 
334
  output_data["smog"] = smog
335
 
336
  if return_intermediates:
337
+ if numpy:
338
+ output_data["mask"] = (
339
+ ((mask > bin_value) * 255).cpu().numpy().astype(np.uint8)
340
+ )
341
+ output_data["depth"] = (
342
+ normalize_tensor(depth).cpu().squeeze(1).numpy().astype(np.uint8) * 255
343
+ )
344
+ output_data["segmentation"] = (
345
+ decode_segmap_merged_labels(segmentation, "r", False)
346
+ .cpu()
347
+ .permute(0, 2, 3, 1)
348
+ .numpy()
349
+ .astype(np.uint8)
350
+ )
351
+ else:
352
+ output_data["mask"] = mask
353
+ output_data["depth"] = depth
354
+ output_data["segmentation"] = segmentation
355
 
356
  return output_data
357