cocktailpeanut committed
Commit: c6dd45b
Parent: c655f3a

MPS support

Files changed (1): app.py (+16 -7)
app.py CHANGED
@@ -17,18 +17,27 @@ from share_btn import community_icon_html, loading_icon_html, share_js
 # load pipelines
 # sd_model_id = "runwayml/stable-diffusion-v1-5"
 sd_model_id = "stabilityai/stable-diffusion-2-1-base"
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float16)
-pipe = SemanticStableDiffusionImg2ImgPipeline_DPMSolver.from_pretrained(sd_model_id, vae=vae, torch_dtype=torch.float16, safety_checker=None, requires_safety_checker=False).to(device)
+# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+torch_dtype = torch.float16
+if torch.cuda.is_available():
+    device = "cuda"
+elif torch.backends.mps.is_available():
+    device = "mps"
+    torch_dtype = torch.float32
+else:
+    device = "cpu"
+vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch_dtype)
+pipe = SemanticStableDiffusionImg2ImgPipeline_DPMSolver.from_pretrained(sd_model_id, vae=vae, torch_dtype=torch_dtype, safety_checker=None, requires_safety_checker=False).to(device)
 pipe.scheduler = DPMSolverMultistepSchedulerInject.from_pretrained(sd_model_id, subfolder="scheduler",
                                                                    algorithm_type="sde-dpmsolver++", solver_order=2)
 
 blip_processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
-blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch.float16).to(device)
+blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base", torch_dtype=torch_dtype).to(device)
 
 ## IMAGE CAPTIONING ##
 def caption_image(input_image):
-    inputs = blip_processor(images=input_image, return_tensors="pt").to(device, torch.float16)
+    inputs = blip_processor(images=input_image, return_tensors="pt").to(device, torch_dtype)
     pixel_values = inputs.pixel_values
 
     generated_ids = blip_model.generate(pixel_values=pixel_values, max_length=50)
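
This hunk is the heart of the change: instead of a CUDA-or-CPU choice, the commit picks among "cuda", "mps", and "cpu", and drops to float32 on MPS, presumably because float16 support on the MPS backend has been incomplete and numerically unreliable in a number of PyTorch releases. A minimal sketch of the same selection as a reusable helper follows; pick_device_and_dtype is hypothetical (not part of app.py), and unlike the commit it also uses float32 on the CPU branch, where float16 kernels are missing for many ops:

import torch

def pick_device_and_dtype():
    # Hypothetical helper, not in app.py: same policy as the commit,
    # except the CPU branch also falls back to float32 (the commit
    # leaves the float16 default in place on CPU).
    if torch.cuda.is_available():
        return "cuda", torch.float16       # fp16 is well supported on CUDA
    if torch.backends.mps.is_available():  # torch.backends.mps needs torch >= 1.12
        return "mps", torch.float32        # fp16 on MPS treated as unreliable here
    return "cpu", torch.float32

device, torch_dtype = pick_device_and_dtype()

The resulting pair is then threaded through every component in the hunk above (the VAE, the diffusion pipeline, and the BLIP captioner), so all weights and inputs agree on device and precision.
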
@@ -228,7 +237,7 @@ def randomize_seed_fn(seed, is_random):
 
 def seed_everything(seed):
     torch.manual_seed(seed)
-    torch.cuda.manual_seed(seed)
+    # torch.cuda.manual_seed(seed)
     random.seed(seed)
     np.random.seed(seed)
 
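For context on the hunk above: torch.manual_seed already seeds the generators for all devices (CUDA included, and MPS on recent PyTorch), and torch.cuda.manual_seed is documented as silently ignored when CUDA is unavailable, so commenting it out is a harmless simplification rather than a required fix. A device-agnostic sketch of the same function, assuming the app's existing random/numpy/torch imports:

import random
import numpy as np
import torch

def seed_everything(seed):
    torch.manual_seed(seed)  # seeds CPU, CUDA, and MPS generators
    random.seed(seed)
    np.random.seed(seed)
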
@@ -902,4 +911,4 @@ with gr.Blocks(css="style.css") as demo:
 )
 
 demo.queue(default_concurrency_limit=1)
-demo.launch()
+demo.launch()
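
The launch line itself is unchanged in substance, but the surrounding context is worth noting alongside the MPS switch: default_concurrency_limit=1 makes the Gradio queue process one event at a time, so a single CUDA or MPS device is never asked to run two generations concurrently. A standalone sketch of the pattern, assuming a Gradio 4.x Blocks app bound to `demo`:

demo.queue(default_concurrency_limit=1)  # serialize requests: one generation at a time
demo.launch()                            # Gradio serves on 127.0.0.1:7860 by default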
 