amildravid4292 commited on
Commit
f5c27d3
·
verified ·
1 Parent(s): 99e4caa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -2
app.py CHANGED
@@ -92,8 +92,39 @@ class main():
92
  self.weight_dimensions = weight_dimensions
93
  self.pinverse = pinverse
94
 
95
- self.unet, self.vae, self.text_encoder, self.tokenizer, self.noise_scheduler = load_models(self.device)
96
- print(self.text_encoder.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  self.network = None
98
 
99
  young = get_direction(df, "Young", pinverse, 1000, device)
 
92
  self.weight_dimensions = weight_dimensions
93
  self.pinverse = pinverse
94
 
95
+ pretrained_model_name_or_path = "stablediffusionapi/realistic-vision-v51"
96
+
97
+ revision = None
98
+ rank = 1
99
+ weight_dtype = torch.bfloat16
100
+
101
+ # Load scheduler, tokenizer and models.
102
+ pipe = StableDiffusionPipeline.from_pretrained("stablediffusionapi/realistic-vision-v51",
103
+ torch_dtype=torch.float16,safety_checker = None,
104
+ requires_safety_checker = False).to(device)
105
+ self.noise_scheduler = pipe.scheduler
106
+ del pipe
107
+ self.tokenizer = AutoTokenizer.from_pretrained(
108
+ pretrained_model_name_or_path, subfolder="tokenizer", revision=revision
109
+ )
110
+ self.text_encoder = CLIPTextModel.from_pretrained(
111
+ pretrained_model_name_or_path, subfolder="text_encoder", revision=revision
112
+ )
113
+ self.vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae", revision=revision)
114
+ self.unet = UNet2DConditionModel.from_pretrained(
115
+ pretrained_model_name_or_path, subfolder="unet", revision=revision
116
+ )
117
+
118
+ self.unet.requires_grad_(False)
119
+ self.unet.to(device, dtype=weight_dtype)
120
+ self.vae.requires_grad_(False)
121
+
122
+ self.text_encoder.requires_grad_(False)
123
+ self.vae.requires_grad_(False)
124
+ self.vae.to(device, dtype=weight_dtype)
125
+ self.text_encoder.to(device, dtype=weight_dtype)
126
+ print("")
127
+
128
  self.network = None
129
 
130
  young = get_direction(df, "Young", pinverse, 1000, device)