Avijit Ghosh commited on
Commit
de81f33
·
1 Parent(s): ab041ea

add gpu wrapper

Browse files
Files changed (1) hide show
  1. app.py +13 -15
app.py CHANGED
@@ -13,6 +13,7 @@ import os
13
  import spaces
14
 
15
  # Define model initialization functions
 
16
  def load_model(model_name):
17
  if model_name == "stabilityai/sdxl-turbo":
18
  pipeline = DiffusionPipeline.from_pretrained(
@@ -48,24 +49,21 @@ def load_model(model_name):
48
  raise ValueError("Unknown model name")
49
  return pipeline
50
 
51
- choices=[
52
- "stabilityai/sdxl-turbo",
53
- "runwayml/stable-diffusion-v1-5",
54
- "ByteDance/SDXL-Lightning",
55
- "segmind/SSD-1B"
56
- ]
57
-
58
- for model_name in choices:
59
- load_model(model_name)
60
-
61
  # Initialize the default model
62
  default_model = "stabilityai/sdxl-turbo"
63
-
64
  pipeline_text2image = load_model(default_model)
65
 
66
  @spaces.GPU
67
- def getimgen(prompt):
68
- return pipeline_text2image(prompt=prompt, guidance_scale=0.0, num_inference_steps=2).images[0]
 
 
 
 
 
 
 
 
69
 
70
  blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
71
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
@@ -115,7 +113,7 @@ def generate_images_plots(prompt, model_name):
115
  pipeline_text2image = load_model(model_name)
116
  foldername = "temp"
117
  Path(foldername).mkdir(parents=True, exist_ok=True)
118
- images = [getimgen(prompt) for _ in range(10)]
119
  genders = []
120
  skintones = []
121
  for image, i in zip(images, range(10)):
@@ -159,4 +157,4 @@ with gr.Blocks(title="Skin Tone and Gender bias in Text to Image Models") as dem
159
  genplot = gr.Plot(label="Gender")
160
  btn.click(generate_images_plots, inputs=[prompt, model_dropdown], outputs=[gallery, skinplot, genplot])
161
 
162
- demo.launch(debug=True)
 
13
  import spaces
14
 
15
  # Define model initialization functions
16
+ @spaces.GPU
17
  def load_model(model_name):
18
  if model_name == "stabilityai/sdxl-turbo":
19
  pipeline = DiffusionPipeline.from_pretrained(
 
49
  raise ValueError("Unknown model name")
50
  return pipeline
51
 
 
 
 
 
 
 
 
 
 
 
52
  # Initialize the default model
53
  default_model = "stabilityai/sdxl-turbo"
 
54
  pipeline_text2image = load_model(default_model)
55
 
56
  @spaces.GPU
57
+ def getimgen(prompt, model_name):
58
+ if model_name == "stabilityai/sdxl-turbo":
59
+ return pipeline_text2image(prompt=prompt, guidance_scale=0.0, num_inference_steps=2).images[0]
60
+ elif model_name == "runwayml/stable-diffusion-v1-5":
61
+ return pipeline_text2image(prompt).images[0]
62
+ elif model_name == "ByteDance/SDXL-Lightning":
63
+ return pipeline_text2image(prompt, num_inference_steps=4, guidance_scale=0).images[0]
64
+ elif model_name == "segmind/SSD-1B":
65
+ neg_prompt = "ugly, blurry, poor quality"
66
+ return pipeline_text2image(prompt=prompt, negative_prompt=neg_prompt).images[0]
67
 
68
  blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
69
  blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
 
113
  pipeline_text2image = load_model(model_name)
114
  foldername = "temp"
115
  Path(foldername).mkdir(parents=True, exist_ok=True)
116
+ images = [getimgen(prompt, model_name) for _ in range(10)]
117
  genders = []
118
  skintones = []
119
  for image, i in zip(images, range(10)):
 
157
  genplot = gr.Plot(label="Gender")
158
  btn.click(generate_images_plots, inputs=[prompt, model_dropdown], outputs=[gallery, skinplot, genplot])
159
 
160
+ demo.launch(debug=True)