nightfury commited on
Commit
60dcdf1
·
1 Parent(s): 46a79dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -71,7 +71,8 @@ transform = transforms.Compose([
71
 
72
  def predict(radio, dict, word_mask, prompt=""):
73
  if(radio == "draw a mask above"):
74
- with autocast(device): #"cuda"
 
75
  init_image = dict["image"].convert("RGB").resize((imgRes, imgRes))
76
  mask = dict["mask"].convert("RGB").resize((imgRes, imgRes))
77
  elif(radio == "type what to keep"):
@@ -137,7 +138,8 @@ def predict(radio, dict, word_mask, prompt=""):
137
  os.remove(filename)
138
 
139
  #with autocast(device): #"cuda"
140
- with autocast(device_type="cpu", dtype=torch.bfloat16):
 
141
  images = pipe(prompt = prompt, init_image=init_image, mask_image=mask, strength=0.8)["sample"]
142
  return images[0]
143
 
 
71
 
72
  def predict(radio, dict, word_mask, prompt=""):
73
  if(radio == "draw a mask above"):
74
+ #with autocast(device): #"cuda"
75
+ with autocast(enable=(False if device=='cpu' else True)):
76
  init_image = dict["image"].convert("RGB").resize((imgRes, imgRes))
77
  mask = dict["mask"].convert("RGB").resize((imgRes, imgRes))
78
  elif(radio == "type what to keep"):
 
138
  os.remove(filename)
139
 
140
  #with autocast(device): #"cuda"
141
+ with autocast(enable=(False if device=='cpu' else True)):
142
+ #with autocast(device_type="cpu", dtype=torch.bfloat16):
143
  images = pipe(prompt = prompt, init_image=init_image, mask_image=mask, strength=0.8)["sample"]
144
  return images[0]
145