chaojiemao commited on
Commit
4ba6875
·
verified ·
1 Parent(s): 1dce861

Update ace_inference.py

Browse files
Files changed (1) hide show
  1. ace_inference.py +3 -2
ace_inference.py CHANGED
@@ -154,6 +154,7 @@ class ACEInference(DiffusionInference):
154
  self.diffusion_model['model'] = BACKBONES.build(self.diffusion_model['cfg'], logger=self.logger).eval()
155
  # self.dynamic_load(self.diffusion_model, 'diffusion_model')
156
  self.diffusion_model['model'].load_pretrained_model(pretrained_model)
 
157
  self.diffusion_model['device'] = we.device_id
158
 
159
  def upscale_resize(self, image, interpolation=T.InterpolationMode.BILINEAR):
@@ -326,8 +327,8 @@ class ACEInference(DiffusionInference):
326
  self.dynamic_load(self.diffusion_model, 'diffusion_model')
327
  function_name, dtype = self.get_function_info(self.diffusion_model)
328
  with torch.autocast('cuda',
329
- enabled=dtype in ('float16', 'bfloat16'),
330
- dtype=getattr(torch, dtype)):
331
  latent = self.diffusion.sample(
332
  noise=noise,
333
  sampler=sampler,
 
154
  self.diffusion_model['model'] = BACKBONES.build(self.diffusion_model['cfg'], logger=self.logger).eval()
155
  # self.dynamic_load(self.diffusion_model, 'diffusion_model')
156
  self.diffusion_model['model'].load_pretrained_model(pretrained_model)
157
+ self.diffusion_model['model'] = self.diffusion_model['model'].to(torch.bfloat16)
158
  self.diffusion_model['device'] = we.device_id
159
 
160
  def upscale_resize(self, image, interpolation=T.InterpolationMode.BILINEAR):
 
327
  self.dynamic_load(self.diffusion_model, 'diffusion_model')
328
  function_name, dtype = self.get_function_info(self.diffusion_model)
329
  with torch.autocast('cuda',
330
+ enabled=True,
331
+ dtype=torch.bfloat16):
332
  latent = self.diffusion.sample(
333
  noise=noise,
334
  sampler=sampler,