aka7774 commited on
Commit
8b7ddb2
1 Parent(s): fff5ee4

Update fn.py

Browse files
Files changed (1) hide show
  1. fn.py +8 -3
fn.py CHANGED
@@ -14,18 +14,23 @@ def load_model(_model = None, _vae = None, loras = []):
14
 
15
  _model = _model or 'cagliostrolab/animagine-xl-3.0'
16
 
 
 
 
 
 
17
  if _vae:
18
  # "stabilityai/sdxl-vae"
19
- vae = AutoencoderKL.from_pretrained(_vae, torch_dtype=torch.float16)
20
  pipe = AutoPipelineForText2Image.from_pretrained(
21
  _model,
22
- torch_dtype=torch.float16,
23
  vae=vae,
24
  )
25
  else:
26
  pipe = AutoPipelineForText2Image.from_pretrained(
27
  _model,
28
- torch_dtype=torch.float16,
29
  )
30
 
31
  # DPM++ 2M Karras
 
14
 
15
  _model = _model or 'cagliostrolab/animagine-xl-3.0'
16
 
17
+ if torch.cuda.is_available():
18
+ torch_dtype = torch.float16
19
+ else:
20
+ torch_dtype = torch.float32
21
+
22
  if _vae:
23
  # "stabilityai/sdxl-vae"
24
+ vae = AutoencoderKL.from_pretrained(_vae, torch_dtype=torch_dtype)
25
  pipe = AutoPipelineForText2Image.from_pretrained(
26
  _model,
27
+ torch_dtype=torch_dtype,
28
  vae=vae,
29
  )
30
  else:
31
  pipe = AutoPipelineForText2Image.from_pretrained(
32
  _model,
33
+ torch_dtype=torch_dtype,
34
  )
35
 
36
  # DPM++ 2M Karras