vict0rsch commited on
Commit
87ed0da
1 Parent(s): fe58151

half only if cuda is available

Browse files
Files changed (1) hide show
  1. climategan_wrapper.py +7 -4
climategan_wrapper.py CHANGED
@@ -15,6 +15,8 @@ from skimage.transform import resize
15
 
16
  from climategan.trainer import Trainer
17
 
 
 
18
 
19
  def concat_events(output_dict, events, i=None, axis=1):
20
  """
@@ -136,7 +138,8 @@ class ClimateGAN:
136
  inference=True,
137
  new_exp=None,
138
  )
139
- self.trainer.G.half()
 
140
 
141
  def _setup_stable_diffusion(self):
142
  """
@@ -150,8 +153,8 @@ class ClimateGAN:
150
  try:
151
  self.sdip_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
152
  "runwayml/stable-diffusion-inpainting",
153
- revision="fp16",
154
- torch_dtype=torch.float16,
155
  safety_checker=None,
156
  use_auth_token=os.environ.get("HF_AUTH_TOKEN"),
157
  ).to(self.trainer.device)
@@ -308,7 +311,7 @@ class ClimateGAN:
308
  images,
309
  numpy=True,
310
  bin_value=0.5,
311
- half=True,
312
  ignore_event=ignore_event,
313
  return_masks=True,
314
  )
 
15
 
16
  from climategan.trainer import Trainer
17
 
18
+ CUDA = torch.cuda.is_available()
19
+
20
 
21
  def concat_events(output_dict, events, i=None, axis=1):
22
  """
 
138
  inference=True,
139
  new_exp=None,
140
  )
141
+ if CUDA:
142
+ self.trainer.G.half()
143
 
144
  def _setup_stable_diffusion(self):
145
  """
 
153
  try:
154
  self.sdip_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
155
  "runwayml/stable-diffusion-inpainting",
156
+ revision="fp16" if CUDA else "main",
157
+ torch_dtype=torch.float16 if CUDA else torch.float32,
158
  safety_checker=None,
159
  use_auth_token=os.environ.get("HF_AUTH_TOKEN"),
160
  ).to(self.trainer.device)
 
311
  images,
312
  numpy=True,
313
  bin_value=0.5,
314
+ half=CUDA,
315
  ignore_event=ignore_event,
316
  return_masks=True,
317
  )