Linoy Tsaban committed on
Commit
e5a71e9
1 Parent(s): 773bed3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -4
app.py CHANGED
@@ -8,9 +8,16 @@ from preprocess_utils import *
8
  from tokenflow_utils import *
9
  # load sd model
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
- # model_id = "stabilityai/stable-diffusion-2-1-base"
12
- # inv_pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to(device)
13
- # inv_pipe.scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
 
 
 
 
 
 
 
14
 
15
  def randomize_seed_fn():
16
  seed = random.randint(0, np.iinfo(np.int32).max)
@@ -71,7 +78,12 @@ def prep(config):
71
  else:
72
  save_path = None
73
 
74
- model = Preprocess(device, config)
 
 
 
 
 
75
  print(type(model.config["batch_size"]))
76
  frames, latents, total_inverted_latents, rgb_reconstruction = model.extract_latents(
77
  num_steps=model.config["steps"],
 
8
  from tokenflow_utils import *
9
  # load sd model
10
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+ model_id = "stabilityai/stable-diffusion-2-1-base"
12
+
13
+ scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
14
+ vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", revision="fp16",
15
+ torch_dtype=torch.float16).to(device)
16
+ tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
17
+ text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", revision="fp16",
18
+ torch_dtype=torch.float16).to(device)
19
+ unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet", revision="fp16",
20
+ torch_dtype=torch.float16).to(device)
21
 
22
  def randomize_seed_fn():
23
  seed = random.randint(0, np.iinfo(np.int32).max)
 
78
  else:
79
  save_path = None
80
 
81
+ model = Preprocess(device, config,
82
+ vae=vae,
83
+ text_encoder=text_encoder,
84
+ scheduler=scheduler,
85
+ tokenizer=tokenizer,
86
+ unet=unet)
87
  print(type(model.config["batch_size"]))
88
  frames, latents, total_inverted_latents, rgb_reconstruction = model.extract_latents(
89
  num_steps=model.config["steps"],