abreza commited on
Commit
ccc4eeb
1 Parent(s): 7979b4d

trying to add sdxl

Browse files
Files changed (1) hide show
  1. app.py +13 -0
app.py CHANGED
@@ -184,6 +184,19 @@ model.load_state_dict(state_dict, strict=True)
184
 
185
  model = model.to(device)
186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  print('Loading Finished!')
188
 
189
  # Gradio UI
 
184
 
185
  model = model.to(device)
186
 
187
+ # Load text-to-image model
188
+ print('Loading text-to-image model ...')
189
+
190
+ pipe = StableDiffusionXLPipeline.from_pretrained(
191
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.bfloat16)
192
+ pipe.to(device="cuda", dtype=torch.bfloat16)
193
+
194
+ unet_state = load_file(hf_hub_download(
195
+ "ByteDance/Hyper-SD", "Hyper-SDXL-1step-Unet.safetensors"), device="cuda")
196
+ pipe.unet.load_state_dict(unet_state)
197
+ pipe.scheduler = LCMScheduler.from_config(
198
+ pipe.scheduler.config, timestep_spacing="trailing")
199
+
200
  print('Loading Finished!')
201
 
202
  # Gradio UI