Spaces:
Runtime error
Runtime error
half only if cuda is available
Browse files- climategan_wrapper.py +7 -4
climategan_wrapper.py
CHANGED
@@ -15,6 +15,8 @@ from skimage.transform import resize
|
|
15 |
|
16 |
from climategan.trainer import Trainer
|
17 |
|
|
|
|
|
18 |
|
19 |
def concat_events(output_dict, events, i=None, axis=1):
|
20 |
"""
|
@@ -136,7 +138,8 @@ class ClimateGAN:
|
|
136 |
inference=True,
|
137 |
new_exp=None,
|
138 |
)
|
139 |
-
|
|
|
140 |
|
141 |
def _setup_stable_diffusion(self):
|
142 |
"""
|
@@ -150,8 +153,8 @@ class ClimateGAN:
|
|
150 |
try:
|
151 |
self.sdip_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
152 |
"runwayml/stable-diffusion-inpainting",
|
153 |
-
revision="fp16",
|
154 |
-
torch_dtype=torch.float16,
|
155 |
safety_checker=None,
|
156 |
use_auth_token=os.environ.get("HF_AUTH_TOKEN"),
|
157 |
).to(self.trainer.device)
|
@@ -308,7 +311,7 @@ class ClimateGAN:
|
|
308 |
images,
|
309 |
numpy=True,
|
310 |
bin_value=0.5,
|
311 |
-
half=
|
312 |
ignore_event=ignore_event,
|
313 |
return_masks=True,
|
314 |
)
|
|
|
15 |
|
16 |
from climategan.trainer import Trainer
|
17 |
|
18 |
+
CUDA = torch.cuda.is_available()
|
19 |
+
|
20 |
|
21 |
def concat_events(output_dict, events, i=None, axis=1):
|
22 |
"""
|
|
|
138 |
inference=True,
|
139 |
new_exp=None,
|
140 |
)
|
141 |
+
if CUDA:
|
142 |
+
self.trainer.G.half()
|
143 |
|
144 |
def _setup_stable_diffusion(self):
|
145 |
"""
|
|
|
153 |
try:
|
154 |
self.sdip_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
155 |
"runwayml/stable-diffusion-inpainting",
|
156 |
+
revision="fp16" if CUDA else "main",
|
157 |
+
torch_dtype=torch.float16 if CUDA else torch.float32,
|
158 |
safety_checker=None,
|
159 |
use_auth_token=os.environ.get("HF_AUTH_TOKEN"),
|
160 |
).to(self.trainer.device)
|
|
|
311 |
images,
|
312 |
numpy=True,
|
313 |
bin_value=0.5,
|
314 |
+
half=CUDA,
|
315 |
ignore_event=ignore_event,
|
316 |
return_masks=True,
|
317 |
)
|