Spaces: Running on A10G

Commit 385fb5f · rynmurdock committed · 1 Parent(s): 93b9a94
device changes

app.py CHANGED
@@ -34,18 +34,18 @@ start_time = time.time()
 model_id = "stabilityai/stable-diffusion-xl-base-1.0"
 sdxl_lightening = "ByteDance/SDXL-Lightning"
 ckpt = "sdxl_lightning_2step_unet.safetensors"
-unet = UNet2DConditionModel.from_config(model_id, subfolder="unet").to(
-unet.load_state_dict(load_file(hf_hub_download(sdxl_lightening, ckpt), device=
+unet = UNet2DConditionModel.from_config(model_id, subfolder="unet").to(DEVICE, torch.float16)
+unet.load_state_dict(load_file(hf_hub_download(sdxl_lightening, ckpt), device=DEVICE))
 
-image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16,).to(
-pipe = AutoPipelineForText2Image.from_pretrained(model_id, unet=unet, torch_dtype=torch.float16, variant="fp16", image_encoder=image_encoder).to(
+image_encoder = CLIPVisionModelWithProjection.from_pretrained("h94/IP-Adapter", subfolder="models/image_encoder", torch_dtype=torch.float16,).to(DEVICE)
+pipe = AutoPipelineForText2Image.from_pretrained(model_id, unet=unet, torch_dtype=torch.float16, variant="fp16", image_encoder=image_encoder).to(DEVICE)
 pipe.unet._load_ip_adapter_weights(torch.load(hf_hub_download('h94/IP-Adapter', 'sdxl_models/ip-adapter_sdxl_vit-h.bin')))
 pipe.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter_sdxl_vit-h.bin")
 pipe.register_modules(image_encoder = image_encoder)
 
 pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesdxl", torch_dtype=torch.float16)
 pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config, timestep_spacing="trailing")
-pipe.to(device=
+pipe.to(device=DEVICE)
 
 
 output_hidden_state = False
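Every changed line in this hunk swaps a hard-coded device (the old lines are truncated in the source diff) for a module-level DEVICE constant. That constant is defined outside the changed hunks, so the following is a sketch only, assuming the usual CUDA-first fallback; the diff does not show the actual definition:

import torch

# Assumed definition of DEVICE; not shown in this diff.
# On an A10G Space this resolves to 'cuda'.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'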
@@ -60,13 +60,13 @@ def predict(
     """Run a single prediction on the model"""
     with torch.no_grad():
         if im_emb == None:
-            im_emb = torch.zeros(1, 1024, dtype=torch.float16, device=
+            im_emb = torch.zeros(1, 1024, dtype=torch.float16, device=DEVICE)
 
-        im_emb = [im_emb.to(
+        im_emb = [im_emb.to(DEVICE).unsqueeze(0)]
         if prompt == '':
             image = pipe(
-                prompt_embeds=torch.zeros(1, 1, 2048, dtype=torch.float16, device=
-                pooled_prompt_embeds=torch.zeros(1, 1280, dtype=torch.float16, device=
+                prompt_embeds=torch.zeros(1, 1, 2048, dtype=torch.float16, device=DEVICE),
+                pooled_prompt_embeds=torch.zeros(1, 1280, dtype=torch.float16, device=DEVICE),
                 ip_adapter_image_embeds=im_emb,
                 height=1024,
                 width=1024,
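In the empty-prompt branch the text conditioning is zeroed out and generation is steered purely by the IP-Adapter image embedding; recent diffusers releases expect ip_adapter_image_embeds as a list of tensors with an extra leading dimension, hence the [im_emb.to(DEVICE).unsqueeze(0)] wrapping. A self-contained sketch of that call, assuming the pipeline and DEVICE from the first hunk:

# Prompt-free, IP-Adapter-only generation. Zeroed prompt embeddings disable
# text conditioning; the image embedding alone drives the sampler.
embed = torch.zeros(1, 1024, dtype=torch.float16, device=DEVICE)
image = pipe(
    prompt_embeds=torch.zeros(1, 1, 2048, dtype=torch.float16, device=DEVICE),
    pooled_prompt_embeds=torch.zeros(1, 1280, dtype=torch.float16, device=DEVICE),
    ip_adapter_image_embeds=[embed.unsqueeze(0)],
    height=1024,
    width=1024,
    num_inference_steps=2,  # assumed to match the 2-step Lightning ckpt; elided in the diff
    guidance_scale=0,
).images[0]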
@@ -83,9 +83,9 @@ def predict(
                 guidance_scale=0,
             ).images[0]
         im_emb, _ = pipe.encode_image(
-            image,
+            image, DEVICE, 1, output_hidden_state
         )
-        return image, im_emb.to(
+        return image, im_emb.to('cpu')
 
 # TODO add to state instead of shared across all
 glob_idx = 0
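The new positional arguments line up with the encode_image helper on recent diffusers pipelines, roughly encode_image(image, device, num_images_per_prompt, output_hidden_states=None), which returns the projected CLIP image embedding together with its unconditional counterpart (verify against your installed diffusers version; this signature has shifted across releases):

# How the call reads with the new arguments; the second return value is the
# unconditional embedding, discarded here.
im_emb, _ = pipe.encode_image(
    image,                # PIL image to embed
    DEVICE,               # device for the CLIP vision tower
    1,                    # num_images_per_prompt
    output_hidden_state,  # False -> projected embeds rather than hidden states
)

Returning im_emb.to('cpu') keeps the growing embedding history off the GPU, which is also why the torch.cat in the next hunk moves each stored embedding to CPU first.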
@@ -128,7 +128,7 @@ def next_image(embs, ys, calibrate_prompts):
         if has_0 and has_1:
             break
 
-    feature_embs = np.array(torch.cat([embs[i] for i in indices]).to('cpu'))
+    feature_embs = np.array(torch.cat([embs[i].to('cpu') for i in indices]).to('cpu'))
     scaler = preprocessing.StandardScaler().fit(feature_embs)
     feature_embs = scaler.transform(feature_embs)
 
@@ -138,7 +138,7 @@ def next_image(embs, ys, calibrate_prompts):
 
     rng_prompt = random.choice(prompt_list)
     w = 1# if len(embs) % 2 == 0 else 0
-    im_emb = w * lin_class.coef_.to(
+    im_emb = w * lin_class.coef_.to(dtype=torch.float16)
     prompt= 'an image' if glob_idx % 2 == 0 else rng_prompt
     print(prompt, len(ys))
     image, im_emb = predict(prompt, im_emb)
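For context on what these last two hunks feed: next_image standardizes the stored CLIP embeddings, fits a linear classifier on the like/dislike labels, and reuses the classifier's coefficient vector as the IP-Adapter embedding for the next generation. A hypothetical standalone sketch of that loop (the fitting code sits outside the shown hunks, and lin_class may be built differently in the real app):

import numpy as np
import torch
from sklearn import preprocessing, svm

# embs: list of (1, 1024) CLIP image embeddings; ys: 0/1 dislike/like labels.
feature_embs = torch.cat([e.to('cpu') for e in embs]).numpy()
scaler = preprocessing.StandardScaler().fit(feature_embs)
lin_class = svm.LinearSVC(max_iter=50_000).fit(scaler.transform(feature_embs), ys)

# The normal vector of the decision boundary points toward 'liked' images;
# it becomes the next query embedding. sklearn's coef_ is numpy, so it is
# wrapped back into a float16 tensor here (the diff's coef_.to(...) implies
# the app keeps a torch tensor instead).
im_emb = torch.tensor(lin_class.coef_, dtype=torch.float16)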