rynmurdock commited on
Commit
2da41ae
1 Parent(s): 37458c6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -36,11 +36,16 @@ sdxl_lightening = "ByteDance/SDXL-Lightning"
36
  ckpt = "sdxl_lightning_2step_unet.safetensors"
37
  unet = UNet2DConditionModel.from_config(model_id, subfolder="unet").to("cuda", torch.float16)
38
  unet.load_state_dict(load_file(hf_hub_download(sdxl_lightening, ckpt), device="cuda"))
39
- pipe = StableDiffusionXLPipeline.from_pretrained(model_id, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
40
  pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16)
41
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
42
  pipe.to(device='cuda')
43
- pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
 
 
 
 
 
44
 
45
  output_hidden_state = False
46
  #######################
@@ -57,7 +62,7 @@ def predict(
57
  im_emb = torch.zeros(1, 1280, dtype=torch.float16, device='cuda')
58
  image = pipe(
59
  prompt=prompt,
60
- image_embeds=[im_emb.to('cuda')],
61
  height=1024,
62
  width=1024,
63
  num_inference_steps=2,
 
36
  ckpt = "sdxl_lightning_2step_unet.safetensors"
37
  unet = UNet2DConditionModel.from_config(model_id, subfolder="unet").to("cuda", torch.float16)
38
  unet.load_state_dict(load_file(hf_hub_download(sdxl_lightening, ckpt), device="cuda"))
39
+ pipe = SDEmb.from_pretrained(model_id, unet=unet, torch_dtype=torch.float16, variant="fp16").to("cuda")
40
  pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16)
41
  pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
42
  pipe.to(device='cuda')
43
+
44
+ image_encoder = CLIPimg.from_pretrained('h94/IP-Adapter', subfolder='sdxl_models/image_encoder', torch_dtype=torch.float16).to(device)
45
+ pipe.unet._load_ip_adapter_weights(torch.load(hf_hub_download('h94/IP-Adapter', 'sdxl_models/ip-adapter_sdxl.bin'), map_location="cpu"))
46
+ pipe.register_modules(image_encoder = image_encoder)
47
+
48
+ # pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl.bin")
49
 
50
  output_hidden_state = False
51
  #######################
 
62
  im_emb = torch.zeros(1, 1280, dtype=torch.float16, device='cuda')
63
  image = pipe(
64
  prompt=prompt,
65
+ ip_adapter_emb=[im_emb.to('cuda')],
66
  height=1024,
67
  width=1024,
68
  num_inference_steps=2,