amildravid4292 commited on
Commit
945f95c
·
verified ·
1 Parent(s): 564edc4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -2
app.py CHANGED
@@ -12,7 +12,6 @@ import warnings
12
  warnings.filterwarnings("ignore")
13
  from PIL import Image
14
  import numpy as np
15
- from utils import load_models
16
  from editing import get_direction, debias
17
  from sampling import sample_weights
18
  from lora_w2w import LoRAw2w
@@ -21,6 +20,43 @@ import spaces
21
 
22
  models_path = snapshot_download(repo_id="Snapchat/w2w")
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  class main():
25
  def __init__(self):
26
  super(main, self).__init__()
@@ -43,7 +79,7 @@ class main():
43
  self.weight_dimensions = weight_dimensions
44
  self.pinverse = pinverse
45
 
46
- self.unet, self.vae, self.text_encoder, self.tokenizer, self.noise_scheduler = load_models(device)
47
  print(self.text_encoder.device)
48
  self.network = None
49
 
@@ -89,6 +125,8 @@ class main():
89
  thick = debias(thick, "Heavy_Makeup", df, pinverse, device)
90
  self.thick = thick
91
 
 
 
92
  def sample_model(self):
93
  self.unet, _, _, _, _ = load_models(self.device)
94
  self.network = sample_weights(self.unet, self.proj, self.mean, self.std, self.v[:, :1000], self.device, factor = 1.00)
 
12
  warnings.filterwarnings("ignore")
13
  from PIL import Image
14
  import numpy as np
 
15
  from editing import get_direction, debias
16
  from sampling import sample_weights
17
  from lora_w2w import LoRAw2w
 
20
 
21
  models_path = snapshot_download(repo_id="Snapchat/w2w")
22
 
23
+ @spaces.GPU
24
+ def load_models(device):
25
+ pretrained_model_name_or_path = "stablediffusionapi/realistic-vision-v51"
26
+
27
+ revision = None
28
+ rank = 1
29
+ weight_dtype = torch.bfloat16
30
+
31
+ # Load scheduler, tokenizer and models.
32
+ pipe = StableDiffusionPipeline.from_pretrained("stablediffusionapi/realistic-vision-v51",
33
+ torch_dtype=torch.float16,safety_checker = None,
34
+ requires_safety_checker = False).to(device)
35
+ noise_scheduler = pipe.scheduler
36
+ del pipe
37
+ tokenizer = AutoTokenizer.from_pretrained(
38
+ pretrained_model_name_or_path, subfolder="tokenizer", revision=revision
39
+ )
40
+ text_encoder = CLIPTextModel.from_pretrained(
41
+ pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
42
+ )
43
+ vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae", revision=revision)
44
+ unet = UNet2DConditionModel.from_pretrained(
45
+ pretrained_model_name_or_path, subfolder="unet", revision=revision
46
+ )
47
+
48
+ unet.requires_grad_(False)
49
+ unet.to(device, dtype=weight_dtype)
50
+ vae.requires_grad_(False)
51
+
52
+ text_encoder.requires_grad_(False)
53
+ vae.requires_grad_(False)
54
+ vae.to(device, dtype=weight_dtype)
55
+ text_encoder.to(device, dtype=weight_dtype)
56
+ print("")
57
+
58
+ return unet, vae, text_encoder, tokenizer, noise_scheduler
59
+
60
  class main():
61
  def __init__(self):
62
  super(main, self).__init__()
 
79
  self.weight_dimensions = weight_dimensions
80
  self.pinverse = pinverse
81
 
82
+ self.unet, self.vae, self.text_encoder, self.tokenizer, self.noise_scheduler = self.load_models(self.device)
83
  print(self.text_encoder.device)
84
  self.network = None
85
 
 
125
  thick = debias(thick, "Heavy_Makeup", df, pinverse, device)
126
  self.thick = thick
127
 
128
+
129
+
130
  def sample_model(self):
131
  self.unet, _, _, _, _ = load_models(self.device)
132
  self.network = sample_weights(self.unet, self.proj, self.mean, self.std, self.v[:, :1000], self.device, factor = 1.00)