ZeyuXie commited on
Commit
f4c1365
1 Parent(s): 9a7456a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -21,10 +21,11 @@ class dotdict(dict):
21
 
22
  class InferRunner:
23
  def __init__(self, device):
 
24
  vae_config = json.load(open("ckpts/ldm/vae_config.json"))
25
  vae_weights = torch.load("ckpts/ldm/pytorch_model_vae.bin", map_location="cpu")
26
  self.vae.load_state_dict(vae_weights)
27
- self.vae = AutoencoderKL(**vae_config).to(device)
28
 
29
  train_args = dotdict(json.loads(open("ckpts/pico_model/summary.jsonl").readlines()[0]))
30
  self.pico_model = PicoDiffusion(
 
21
 
22
  class InferRunner:
23
  def __init__(self, device):
24
+ self.vae = AutoencoderKL(**vae_config)
25
  vae_config = json.load(open("ckpts/ldm/vae_config.json"))
26
  vae_weights = torch.load("ckpts/ldm/pytorch_model_vae.bin", map_location="cpu")
27
  self.vae.load_state_dict(vae_weights)
28
+ self.vae = self.vae.to(device)
29
 
30
  train_args = dotdict(json.loads(open("ckpts/pico_model/summary.jsonl").readlines()[0]))
31
  self.pico_model = PicoDiffusion(