amildravid4292 commited on
Commit
4ef9ffa
·
verified ·
1 Parent(s): 67ad71c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -1
app.py CHANGED
@@ -14,7 +14,6 @@ from PIL import Image
14
  import numpy as np
15
  from utils import load_models
16
  from editing import get_direction, debias
17
- from sampling import sample_weights
18
  from lora_w2w import LoRAw2w
19
  from huggingface_hub import snapshot_download
20
  import spaces
@@ -83,6 +82,31 @@ thick.value = debias(thick.value, "Brown_Hair", df, pinverse, device.value)
83
  thick.value = debias(thick.value, "Pale_Skin", df, pinverse, device.value)
84
  thick.value = debias(thick.value, "Heavy_Makeup", df, pinverse, device.value)
85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  @torch.no_grad()
87
  @spaces.GPU
88
  def sample_model():
 
14
  import numpy as np
15
  from utils import load_models
16
  from editing import get_direction, debias
 
17
  from lora_w2w import LoRAw2w
18
  from huggingface_hub import snapshot_download
19
  import spaces
 
82
  thick.value = debias(thick.value, "Pale_Skin", df, pinverse, device.value)
83
  thick.value = debias(thick.value, "Heavy_Makeup", df, pinverse, device.value)
84
 
85
+
86
+ @torch.no_grad()
87
+ @spaces.GPU
88
+ def sample_weights(unet, proj, mean, std, v, device, factor = 1.0):
89
+ # get mean and standard deviation for each principal component
90
+ m = torch.mean(proj, 0)
91
+ standev = torch.std(proj, 0)
92
+ del proj
93
+ torch.cuda.empty_cache()
94
+ # sample
95
+ sample = torch.zeros([1, 1000]).to(device)
96
+ for i in range(1000):
97
+ sample[0, i] = torch.normal(m[i], factor*standev[i], (1,1))
98
+
99
+ # load weights into network
100
+ net = LoRAw2w( sample, mean, std, v,
101
+ unet,
102
+ rank=1,
103
+ multiplier=1.0,
104
+ alpha=27.0,
105
+ train_method="xattn-strict"
106
+ ).to(device, torch.bfloat16)
107
+
108
+ return net
109
+
110
  @torch.no_grad()
111
  @spaces.GPU
112
  def sample_model():