Spaces:
Runtime error
Runtime error
add option to return intermediates as tensors
Browse files- climategan/trainer.py +18 -13
climategan/trainer.py
CHANGED
@@ -334,19 +334,24 @@ class Trainer:
|
|
334 |
output_data["smog"] = smog
|
335 |
|
336 |
if return_intermediates:
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
|
|
|
|
|
|
|
|
|
|
350 |
|
351 |
return output_data
|
352 |
|
|
|
334 |
output_data["smog"] = smog
|
335 |
|
336 |
if return_intermediates:
|
337 |
+
if numpy:
|
338 |
+
output_data["mask"] = (
|
339 |
+
((mask > bin_value) * 255).cpu().numpy().astype(np.uint8)
|
340 |
+
)
|
341 |
+
output_data["depth"] = (
|
342 |
+
normalize_tensor(depth).cpu().squeeze(1).numpy().astype(np.uint8) * 255
|
343 |
+
)
|
344 |
+
output_data["segmentation"] = (
|
345 |
+
decode_segmap_merged_labels(segmentation, "r", False)
|
346 |
+
.cpu()
|
347 |
+
.permute(0, 2, 3, 1)
|
348 |
+
.numpy()
|
349 |
+
.astype(np.uint8)
|
350 |
+
)
|
351 |
+
else:
|
352 |
+
output_data["mask"] = mask
|
353 |
+
output_data["depth"] = depth
|
354 |
+
output_data["segmentation"] = segmentation
|
355 |
|
356 |
return output_data
|
357 |
|