amildravid4292 committed
Commit 9f713c2 · verified · 1 Parent(s): fc54821

Update app.py

Files changed (1)
  1. app.py +134 -197
app.py CHANGED
@@ -6,210 +6,97 @@ from torch.utils.data import Dataset, DataLoader
import gradio as gr
import sys
import tqdm
- import uuid
sys.path.append(os.path.abspath(os.path.join("", "..")))
import gc
import warnings
warnings.filterwarnings("ignore")
from PIL import Image
import numpy as np
+ from utils import load_models
from editing import get_direction, debias
+ from sampling import sample_weights
from lora_w2w import LoRAw2w
from huggingface_hub import snapshot_download
import spaces
- from transformers import CLIPTextModel
- from lora_w2w import LoRAw2w
- from diffusers import AutoencoderKL, DDPMScheduler, DiffusionPipeline, UNet2DConditionModel, LMSDiscreteScheduler
- from transformers import AutoTokenizer, PretrainedConfig
- import warnings
- warnings.filterwarnings("ignore")
- from diffusers import (
-     AutoencoderKL,
-     DDPMScheduler,
-     DiffusionPipeline,
-     DPMSolverMultistepScheduler,
-     UNet2DConditionModel,
-     PNDMScheduler,
-     StableDiffusionPipeline
- )
-
- device = gr.State("cuda")
- unet = gr.State()
- vae = gr.State()
- text_encoder = gr.State()
- tokenizer = gr.State()
- noise_scheduler = gr.State()
- network = gr.State()
-
- pretrained_model_name_or_path = "stablediffusionapi/realistic-vision-v51"
- revision = None
- rank = 1
- weight_dtype = torch.bfloat16
- # Load scheduler, tokenizer and models.
- pipe = StableDiffusionPipeline.from_pretrained("stablediffusionapi/realistic-vision-v51",
-                                                torch_dtype=torch.float16, safety_checker = None,
-                                                requires_safety_checker = False).to(device.value)
- noise_scheduler.value = pipe.scheduler
- del pipe
- tokenizer.value = AutoTokenizer.from_pretrained(
-     pretrained_model_name_or_path, subfolder="tokenizer", revision=revision
- )
- text_encoder.value = CLIPTextModel.from_pretrained(
-     pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
- )
- vae.value = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae", revision=revision)
- unet.value = UNet2DConditionModel.from_pretrained(
-     pretrained_model_name_or_path, subfolder="unet", revision=revision
- )
-
- unet.value.requires_grad_(False)
- unet.value.to(device.value, dtype=weight_dtype)
- vae.value.requires_grad_(False)
-
- text_encoder.value.requires_grad_(False)
- vae.value.requires_grad_(False)
- vae.value.to(device.value, dtype=weight_dtype)
- text_encoder.value.to(device.value, dtype=weight_dtype)
- print("")
+ import uuid
+
+ global device
+ global generator
+ global unet
+ global vae
+ global text_encoder
+ global tokenizer
+ global noise_scheduler
+ global network
+ device = "cuda"
+ #generator = torch.Generator(device=device)

models_path = snapshot_download(repo_id="Snapchat/w2w")

- mean = torch.load(f"{models_path}/files/mean.pt", map_location=torch.device('cpu')).bfloat16().to(device.value)
- std = torch.load(f"{models_path}/files/std.pt", map_location=torch.device('cpu')).bfloat16().to(device.value)
- v = torch.load(f"{models_path}/files/V.pt", map_location=torch.device('cpu')).bfloat16().to(device.value)
- proj = torch.load(f"{models_path}/files/proj_1000pc.pt", map_location=torch.device('cpu')).bfloat16().to(device.value)
+ mean = torch.load(f"{models_path}/files/mean.pt", map_location=torch.device('cpu')).bfloat16().to(device)
+ std = torch.load(f"{models_path}/files/std.pt", map_location=torch.device('cpu')).bfloat16().to(device)
+ v = torch.load(f"{models_path}/files/V.pt", map_location=torch.device('cpu')).bfloat16().to(device)
+ proj = torch.load(f"{models_path}/files/proj_1000pc.pt", map_location=torch.device('cpu')).bfloat16().to(device)
df = torch.load(f"{models_path}/files/identity_df.pt")
weight_dimensions = torch.load(f"{models_path}/files/weight_dimensions.pt")
- pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt", map_location=torch.device('cpu')).bfloat16().to(device.value)
+ pinverse = torch.load(f"{models_path}/files/pinverse_1000pc.pt", map_location=torch.device('cpu')).bfloat16().to(device)
-
- young = gr.State()
- young.value = get_direction(df, "Young", pinverse, 1000, device.value)
- young.value = debias(young.value, "Male", df, pinverse, device.value)
- young.value = debias(young.value, "Pointy_Nose", df, pinverse, device.value)
- young.value = debias(young.value, "Wavy_Hair", df, pinverse, device.value)
- young.value = debias(young.value, "Chubby", df, pinverse, device.value)
- young.value = debias(young.value, "No_Beard", df, pinverse, device.value)
- young.value = debias(young.value, "Mustache", df, pinverse, device.value)
-
- pointy = gr.State()
- pointy.value = get_direction(df, "Pointy_Nose", pinverse, 1000, device.value)
- pointy.value = debias(pointy.value, "Young", df, pinverse, device.value)
- pointy.value = debias(pointy.value, "Male", df, pinverse, device.value)
- pointy.value = debias(pointy.value, "Wavy_Hair", df, pinverse, device.value)
- pointy.value = debias(pointy.value, "Chubby", df, pinverse, device.value)
- pointy.value = debias(pointy.value, "Heavy_Makeup", df, pinverse, device.value)
-
- wavy = gr.State()
- wavy.value = get_direction(df, "Wavy_Hair", pinverse, 1000, device.value)
- wavy.value = debias(wavy.value, "Young", df, pinverse, device.value)
- wavy.value = debias(wavy.value, "Male", df, pinverse, device.value)
- wavy.value = debias(wavy.value, "Pointy_Nose", df, pinverse, device.value)
- wavy.value = debias(wavy.value, "Chubby", df, pinverse, device.value)
- wavy.value = debias(wavy.value, "Heavy_Makeup", df, pinverse, device.value)
-
- thick = gr.State()
- thick.value = get_direction(df, "Bushy_Eyebrows", pinverse, 1000, device.value)
- thick.value = debias(thick.value, "Male", df, pinverse, device.value)
- thick.value = debias(thick.value, "Young", df, pinverse, device.value)
- thick.value = debias(thick.value, "Pointy_Nose", df, pinverse, device.value)
- thick.value = debias(thick.value, "Wavy_Hair", df, pinverse, device.value)
- thick.value = debias(thick.value, "Mustache", df, pinverse, device.value)
- thick.value = debias(thick.value, "No_Beard", df, pinverse, device.value)
- thick.value = debias(thick.value, "Sideburns", df, pinverse, device.value)
- thick.value = debias(thick.value, "Big_Nose", df, pinverse, device.value)
- thick.value = debias(thick.value, "Big_Lips", df, pinverse, device.value)
- thick.value = debias(thick.value, "Black_Hair", df, pinverse, device.value)
- thick.value = debias(thick.value, "Brown_Hair", df, pinverse, device.value)
- thick.value = debias(thick.value, "Pale_Skin", df, pinverse, device.value)
- thick.value = debias(thick.value, "Heavy_Makeup", df, pinverse, device.value)
-
- @torch.no_grad()
- @spaces.GPU
- def sample_weights(unet, proj, mean, std, v, device, factor = 1.0):
-     # get mean and standard deviation for each principal component
-     m = torch.mean(proj, 0)
-     standev = torch.std(proj, 0)
-     del proj
-     torch.cuda.empty_cache()
-     # sample
-     sample = torch.zeros([1, 1000]).to(device)
-     for i in range(1000):
-         sample[0, i] = torch.normal(m[i], factor*standev[i], (1,1))
-
-     # load weights into network
-     net = LoRAw2w(sample, mean, std, v,
-                   unet,
-                   rank=1,
-                   multiplier=1.0,
-                   alpha=27.0,
-                   train_method="xattn-strict"
-                   ).to(device, torch.bfloat16)
-
-     return net
-
- @torch.no_grad()
- @spaces.GPU
+ unet, vae, text_encoder, tokenizer, noise_scheduler = load_models(device)
+
def sample_model():
-     unet.value = UNet2DConditionModel.from_pretrained(
-         pretrained_model_name_or_path, subfolder="unet", revision=revision
-     )
-     unet.value.requires_grad_(False)
-     unet.value.to(device.value, dtype=weight_dtype)
-     network.value = sample_weights(unet.value, proj, mean, std, v[:, :1000], device.value, factor = 1.00)
+     global unet
+     del unet
+     global network
+     mean.to(device)
+     std.to(device)
+     v.to(device)
+     proj.to(device)
+     unet, _, _, _, _ = load_models(device)
+     network = sample_weights(unet, proj, mean, std, v[:, :1000], device, factor = 1.00)

@torch.no_grad()
@spaces.GPU
def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
-
-     generator = torch.Generator(device=device.value).manual_seed(seed)
+     global device
+     #global generator
+     global unet
+     global vae
+     global text_encoder
+     global tokenizer
+     global noise_scheduler
+     generator = torch.Generator(device=device).manual_seed(seed)
    latents = torch.randn(
-         (1, unet.value.in_channels, 512 // 8, 512 // 8),
+         (1, unet.in_channels, 512 // 8, 512 // 8),
        generator = generator,
-         device = device.value
+         device = device
    ).bfloat16()

-     text_input = tokenizer.value(prompt, padding="max_length", max_length=tokenizer.value.model_max_length, truncation=True, return_tensors="pt")
+     text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")

-     text_embeddings = text_encoder.value(text_input.input_ids.to(device.value))[0]
+     text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    max_length = text_input.input_ids.shape[-1]
-     uncond_input = tokenizer.value(
+     uncond_input = tokenizer(
        [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
    )
-     uncond_embeddings = text_encoder.value(uncond_input.input_ids.to(device.value))[0]
+     uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-     noise_scheduler.value.set_timesteps(ddim_steps)
-     latents = latents * noise_scheduler.value.init_noise_sigma
+     noise_scheduler.set_timesteps(ddim_steps)
+     latents = latents * noise_scheduler.init_noise_sigma

-     for i,t in enumerate(tqdm.tqdm(noise_scheduler.value.timesteps)):
+     for i,t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
        latent_model_input = torch.cat([latents] * 2)
-         latent_model_input = noise_scheduler.value.scale_model_input(latent_model_input, timestep=t)
-         with network.value:
-             noise_pred = unet.value(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample
+         latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)
+         with network:
+             noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample
        #guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-         latents = noise_scheduler.value.step(noise_pred, t, latents).prev_sample
+         latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    latents = 1 / 0.18215 * latents
-     image = vae.value.decode(latents).sample
+     image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
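Note: `utils.load_models` itself is not part of this diff. Judging from the inline setup removed in the hunk above, it plausibly wraps the same loading logic; a minimal sketch, assuming it returns the five objects in the order unpacked above (the `path` and `dtype` defaults are illustrative):

    # Hypothetical reconstruction of utils.load_models, inferred from the inline
    # setup this commit removes; the real helper lives in utils.py.
    import torch
    from diffusers import AutoencoderKL, StableDiffusionPipeline, UNet2DConditionModel
    from transformers import AutoTokenizer, CLIPTextModel

    def load_models(device, path="stablediffusionapi/realistic-vision-v51", dtype=torch.bfloat16):
        # Borrow the scheduler from a throwaway pipeline, as the removed code did.
        pipe = StableDiffusionPipeline.from_pretrained(
            path, torch_dtype=torch.float16, safety_checker=None, requires_safety_checker=False
        )
        noise_scheduler = pipe.scheduler
        del pipe
        tokenizer = AutoTokenizer.from_pretrained(path, subfolder="tokenizer")
        text_encoder = CLIPTextModel.from_pretrained(path, subfolder="text_encoder")
        vae = AutoencoderKL.from_pretrained(path, subfolder="vae")
        unet = UNet2DConditionModel.from_pretrained(path, subfolder="unet")
        for m in (unet, vae, text_encoder):
            m.requires_grad_(False)  # inference only, no gradients needed
            m.to(device, dtype=dtype)
        return unet, vae, text_encoder, tokenizer, noise_scheduler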
 
@@ -221,67 +108,78 @@ def inference( prompt, negative_prompt, guidance_scale, ddim_steps, seed):
@torch.no_grad()
@spaces.GPU
def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, start_noise, a1, a2, a3, a4):
+     start_items()
+     global device
+     #global generator
+     global unet
+     global vae
+     global text_encoder
+     global tokenizer
+     global noise_scheduler
+     global young
+     global pointy
+     global wavy
+     global thick

-     original_weights = network.value.proj.clone()
+     original_weights = network.proj.clone()

    #pad to same number of PCs
    pcs_original = original_weights.shape[1]
-     pcs_edits = young.value.shape[1]
-     padding = torch.zeros((1,pcs_original-pcs_edits)).to(device.value)
-     young_pad = torch.cat((young.value, padding), 1)
-     pointy_pad = torch.cat((pointy.value, padding), 1)
-     wavy_pad = torch.cat((wavy.value, padding), 1)
-     thick_pad = torch.cat((thick.value, padding), 1)
+     pcs_edits = young.shape[1]
+     padding = torch.zeros((1,pcs_original-pcs_edits)).to(device)
+     young_pad = torch.cat((young, padding), 1)
+     pointy_pad = torch.cat((pointy, padding), 1)
+     wavy_pad = torch.cat((wavy, padding), 1)
+     thick_pad = torch.cat((thick, padding), 1)

    edited_weights = original_weights+a1*1e6*young_pad+a2*1e6*pointy_pad+a3*1e6*wavy_pad+a4*2e6*thick_pad

-     generator = torch.Generator(device=device.value).manual_seed(seed)
+     generator = torch.Generator(device=device).manual_seed(seed)
    latents = torch.randn(
-         (1, unet.value.in_channels, 512 // 8, 512 // 8),
+         (1, unet.in_channels, 512 // 8, 512 // 8),
        generator = generator,
-         device = device.value
+         device = device
    ).bfloat16()

-     text_input = tokenizer.value(prompt, padding="max_length", max_length=tokenizer.value.model_max_length, truncation=True, return_tensors="pt")
+     text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")

-     text_embeddings = text_encoder.value(text_input.input_ids.to(device.value))[0]
+     text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    max_length = text_input.input_ids.shape[-1]
-     uncond_input = tokenizer.value(
+     uncond_input = tokenizer(
        [negative_prompt], padding="max_length", max_length=max_length, return_tensors="pt"
    )
-     uncond_embeddings = text_encoder.value(uncond_input.input_ids.to(device.value))[0]
+     uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-     noise_scheduler.value.set_timesteps(ddim_steps)
-     latents = latents * noise_scheduler.value.init_noise_sigma
+     noise_scheduler.set_timesteps(ddim_steps)
+     latents = latents * noise_scheduler.init_noise_sigma

-     for i,t in enumerate(tqdm.tqdm(noise_scheduler.value.timesteps)):
+     for i,t in enumerate(tqdm.tqdm(noise_scheduler.timesteps)):
        latent_model_input = torch.cat([latents] * 2)
-         latent_model_input = noise_scheduler.value.scale_model_input(latent_model_input, timestep=t)
+         latent_model_input = noise_scheduler.scale_model_input(latent_model_input, timestep=t)

        if t>start_noise:
            pass
        elif t<=start_noise:
-             network.value.proj = torch.nn.Parameter(edited_weights)
-             network.value.reset()
+             network.proj = torch.nn.Parameter(edited_weights)
+             network.reset()

        with network:
-             noise_pred = unet.value(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample
+             noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings, timestep_cond= None).sample

        #guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-         latents = noise_scheduler.value.step(noise_pred, t, latents).prev_sample
+         latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

    latents = 1 / 0.18215 * latents
-     image = vae.value.decode(latents).sample
+     image = vae.decode(latents).sample
    image = (image / 2 + 0.5).clamp(0, 1)

    image = image.detach().cpu().float().permute(0, 2, 3, 1).numpy()[0]
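The weight-space edit in this hunk pads each 1000-PC direction out to the identity's full number of principal components, then adds a slider-scaled copy to the sampled coefficients. A toy version of that step on dummy tensors (sizes are illustrative; the real directions come from get_direction/debias):

    import torch
    original = torch.zeros(1, 10000)              # identity coefficients over all PCs (size illustrative)
    young = torch.randn(1, 1000)                  # edit direction over the first 1000 PCs
    padding = torch.zeros(1, original.shape[1] - young.shape[1])
    young_pad = torch.cat((young, padding), 1)    # zero contribution on the remaining PCs
    a1 = 0.5                                      # slider value
    edited = original + a1 * 1e6 * young_pad      # same scaling as edit_inference uses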
@@ -289,12 +187,11 @@ def edit_inference(prompt, negative_prompt, guidance_scale, ddim_steps, seed, st
    image = Image.fromarray((image * 255).round().astype("uint8"))

    #reset weights back to original
-     network.value.proj = torch.nn.Parameter(original_weights)
-     network.value.reset()
+     network.proj = torch.nn.Parameter(original_weights)
+     network.reset()

    return image
-
- @torch.no_grad()
+
@spaces.GPU
def sample_then_run():
    sample_model()
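This restore pairs with the `original_weights = network.proj.clone()` snapshot taken at the top of edit_inference. The pattern in isolation (dummy size; `reset()` is omitted since it is LoRAw2w-specific):

    import torch
    proj = torch.nn.Parameter(torch.randn(1, 10000))  # stand-in for network.proj
    original_weights = proj.clone()                   # snapshot before applying edits
    proj = torch.nn.Parameter(original_weights)       # restore by re-wrapping the snapshot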
@@ -304,12 +201,52 @@ def sample_then_run():
    cfg = 3.0
    steps = 25
    image = inference( prompt, negative_prompt, cfg, steps, seed)
-     torch.save(network.value.proj.detach(), "model.pt" )
-     # net = torch.load("model.pt").cpu()
-     network.value.proj.detach().cpu()
-
-     return image, "model.pt", network.value #net #, network.value.cpu()
+     torch.save(network.proj, "model.pt" )
+     return image, "model.pt"

+ #@spaces.GPU
+ def start_items():
+     print("Starting items")
+     global young
+     global pointy
+     global wavy
+     global thick
+     young = get_direction(df, "Young", pinverse, 1000, device)
+     young = debias(young, "Male", df, pinverse, device)
+     young = debias(young, "Pointy_Nose", df, pinverse, device)
+     young = debias(young, "Wavy_Hair", df, pinverse, device)
+     young = debias(young, "Chubby", df, pinverse, device)
+     young = debias(young, "No_Beard", df, pinverse, device)
+     young = debias(young, "Mustache", df, pinverse, device)
+
+     pointy = get_direction(df, "Pointy_Nose", pinverse, 1000, device)
+     pointy = debias(pointy, "Young", df, pinverse, device)
+     pointy = debias(pointy, "Male", df, pinverse, device)
+     pointy = debias(pointy, "Wavy_Hair", df, pinverse, device)
+     pointy = debias(pointy, "Chubby", df, pinverse, device)
+     pointy = debias(pointy, "Heavy_Makeup", df, pinverse, device)
+
+     wavy = get_direction(df, "Wavy_Hair", pinverse, 1000, device)
+     wavy = debias(wavy, "Young", df, pinverse, device)
+     wavy = debias(wavy, "Male", df, pinverse, device)
+     wavy = debias(wavy, "Pointy_Nose", df, pinverse, device)
+     wavy = debias(wavy, "Chubby", df, pinverse, device)
+     wavy = debias(wavy, "Heavy_Makeup", df, pinverse, device)
+
+     thick = get_direction(df, "Bushy_Eyebrows", pinverse, 1000, device)
+     thick = debias(thick, "Male", df, pinverse, device)
+     thick = debias(thick, "Young", df, pinverse, device)
+     thick = debias(thick, "Pointy_Nose", df, pinverse, device)
+     thick = debias(thick, "Wavy_Hair", df, pinverse, device)
+     thick = debias(thick, "Mustache", df, pinverse, device)
+     thick = debias(thick, "No_Beard", df, pinverse, device)
+     thick = debias(thick, "Sideburns", df, pinverse, device)
+     thick = debias(thick, "Big_Nose", df, pinverse, device)
+     thick = debias(thick, "Big_Lips", df, pinverse, device)
+     thick = debias(thick, "Black_Hair", df, pinverse, device)
+     thick = debias(thick, "Brown_Hair", df, pinverse, device)
+     thick = debias(thick, "Pale_Skin", df, pinverse, device)
+     thick = debias(thick, "Heavy_Makeup", df, pinverse, device)

class CustomImageDataset(Dataset):
    def __init__(self, images, transform=None):
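Two behavioral changes in this hunk: sample_then_run now saves network.proj directly (a torch.nn.Parameter) instead of a detached copy, and it no longer returns the network object. A minimal sketch of reading the saved file back (the detach mirrors what the old code did before saving):

    import torch
    coeffs = torch.load("model.pt", map_location="cpu")  # Parameter saved by sample_then_run
    coeffs = coeffs.detach()                             # drop autograd bookkeeping before reuse
    print(coeffs.shape)                                  # one row of principal-component coefficients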
@@ -542,7 +479,7 @@ with gr.Blocks(css="style.css") as demo:
                  outputs = [input_image, file_output])

-     sample.click(fn=sample_then_run, outputs=[input_image, file_output, network])
+     sample.click(fn=sample_then_run, outputs=[input_image, file_output])

    submit.click(
        fn=edit_inference, inputs=[prompt, negative_prompt, cfg, steps, seed, injection_step, a1, a2, a3, a4], outputs=[gallery]
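With the gr.State objects gone, the click handler maps the two values returned by sample_then_run onto two output components. A stripped-down sketch of the same wiring (component choices are illustrative; sample_then_run_stub stands in for the app's function):

    import gradio as gr

    def sample_then_run_stub():
        # stand-in returning (image, path-to-saved-weights), like sample_then_run
        return None, "model.pt"

    with gr.Blocks() as demo:
        sample = gr.Button("Sample")
        input_image = gr.Image()
        file_output = gr.File()
        sample.click(fn=sample_then_run_stub, outputs=[input_image, file_output])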
 
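For reference, the guidance step that appears unchanged in both denoising loops combines the unconditional and text-conditional UNet predictions; a self-contained illustration with dummy tensors:

    import torch
    guidance_scale = 3.0                            # the cfg value sample_then_run uses
    noise_pred = torch.randn(2, 4, 64, 64)          # stacked [uncond, cond] UNet outputs
    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)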