Mehdi Cherti commited on
Commit
301a4a3
1 Parent(s): e30ec05

make app.py lazy

Browse files
Files changed (1) hide show
  1. app.py +14 -2
app.py CHANGED
@@ -14,13 +14,25 @@ def download(filename):
14
  return "models/" + filename
15
 
16
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
 
 
 
 
 
 
 
 
 
17
  models = {
18
- "diffusion_db_128ch_1timesteps_openclip_vith14": load_model(get_model_config('ddgan_ddb_v2'), download('diffusion_db_128ch_1timesteps_openclip_vith14.th'), device=device),
19
- #"diffusion_db_192ch_2timesteps_openclip_vith14": load_model(get_model_config('ddgan_ddb_v3'), download('diffusion_db_192ch_2timesteps_openclip_vith14.th'), device=device),
20
  }
21
  default = "diffusion_db_128ch_1timesteps_openclip_vith14"
22
 
23
  def gen(md, model_name, md2, text, seed, nb_samples, width, height):
 
24
  torch.manual_seed(int(seed))
25
  model = models[model_name]
26
  nb_samples = int(nb_samples)
 
14
  return "models/" + filename
15
 
16
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
17
+ cache = {}
18
+
19
+ def load(name):
20
+ if name in cache:
21
+ return cache[name]
22
+ else:
23
+ model_config, model_path = models[name]
24
+ model = load_model(model_config, model_path, device=device)
25
+ cache[name] = model
26
+ return model
27
+
28
  models = {
29
+ "diffusion_db_128ch_1timesteps_openclip_vith14": (get_model_config('ddgan_ddb_v2'), download('diffusion_db_128ch_1timesteps_openclip_vith14.th')),
30
+ "diffusion_db_192ch_2timesteps_openclip_vith14": (get_model_config('ddgan_ddb_v3'), download('diffusion_db_192ch_2timesteps_openclip_vith14.th')),
31
  }
32
  default = "diffusion_db_128ch_1timesteps_openclip_vith14"
33
 
34
  def gen(md, model_name, md2, text, seed, nb_samples, width, height):
35
+ model = load(model_name)
36
  torch.manual_seed(int(seed))
37
  model = models[model_name]
38
  nb_samples = int(nb_samples)