ZeyuXie commited on
Commit
361d70a
1 Parent(s): 7d00b29

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -20,11 +20,10 @@ class dotdict(dict):
20
 
21
 
22
  class InferRunner:
23
- @spaces.GPU()
24
  def __init__(self, device):
25
  vae_config = json.load(open("ckpts/ldm/vae_config.json"))
26
  self.vae = AutoencoderKL(**vae_config).to(device)
27
- vae_weights = torch.load("ckpts/ldm/pytorch_model_vae.bin", map_location=device)
28
  self.vae.load_state_dict(vae_weights)
29
 
30
  train_args = dotdict(json.loads(open("ckpts/pico_model/summary.jsonl").readlines()[0]))
@@ -37,7 +36,8 @@ class InferRunner:
37
  ).eval().to(device)
38
  self.scheduler = DDPMScheduler.from_pretrained(train_args.scheduler_name, subfolder="scheduler")
39
 
40
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
41
  runner = InferRunner(device)
42
  event_list = get_event()
43
 
 
20
 
21
 
22
  class InferRunner:
 
23
  def __init__(self, device):
24
  vae_config = json.load(open("ckpts/ldm/vae_config.json"))
25
  self.vae = AutoencoderKL(**vae_config).to(device)
26
+ vae_weights = torch.load("ckpts/ldm/pytorch_model_vae.bin")
27
  self.vae.load_state_dict(vae_weights)
28
 
29
  train_args = dotdict(json.loads(open("ckpts/pico_model/summary.jsonl").readlines()[0]))
 
36
  ).eval().to(device)
37
  self.scheduler = DDPMScheduler.from_pretrained(train_args.scheduler_name, subfolder="scheduler")
38
 
39
+ #device = "cuda" if torch.cuda.is_available() else "cpu"
40
+ device = "cuda"
41
  runner = InferRunner(device)
42
  event_list = get_event()
43