cocktailpeanut commited on
Commit
003e981
1 Parent(s): 31dfa92
Files changed (1) hide show
  1. app.py +17 -5
app.py CHANGED
@@ -23,7 +23,13 @@ import gradio as gr
23
 
24
  # global variable
25
  MAX_SEED = np.iinfo(np.int32).max
26
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
27
  STYLE_NAMES = list(styles.keys())
28
  DEFAULT_STYLE_NAME = "Watercolor"
29
 
@@ -53,10 +59,16 @@ pipe = StableDiffusionXLInstantIDPipeline.from_pretrained(
53
  safety_checker=None,
54
  feature_extractor=None,
55
  )
56
- pipe.cuda()
 
 
 
57
  pipe.load_ip_adapter_instantid(face_adapter)
58
- pipe.image_proj_model.to('cuda')
59
- pipe.unet.to('cuda')
 
 
 
60
 
61
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
62
  if randomize_seed:
@@ -407,4 +419,4 @@ with gr.Blocks(css=css) as demo:
407
 
408
  gr.Markdown(article)
409
 
410
- demo.launch()
 
23
 
24
  # global variable
25
  MAX_SEED = np.iinfo(np.int32).max
26
+ #device = "cuda" if torch.cuda.is_available() else "cpu"
27
+ if torch.backends.mps.is_available():
28
+ device = "mps"
29
+ elif torch.cuda.is_available():
30
+ device = "cuda"
31
+ else:
32
+ device = "cpu"
33
  STYLE_NAMES = list(styles.keys())
34
  DEFAULT_STYLE_NAME = "Watercolor"
35
 
 
59
  safety_checker=None,
60
  feature_extractor=None,
61
  )
62
+ if device == 'mps':
63
+ pipe.to("mps")
64
+ elif device == 'cuda':
65
+ pipe.cuda()
66
  pipe.load_ip_adapter_instantid(face_adapter)
67
+ #pipe.image_proj_model.to('cuda')
68
+ #pipe.unet.to('cuda')
69
+ if device == 'mps' or device == 'cuda':
70
+ pipe.image_proj_model.to(device)
71
+ pipe.unet.to(device)
72
 
73
  def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
74
  if randomize_seed:
 
419
 
420
  gr.Markdown(article)
421
 
422
+ demo.launch()