multimodalart HF staff commited on
Commit
e3d2366
1 Parent(s): d8035da

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -71,14 +71,18 @@ def run_all(prompt, steps, n_images, weight, clip_guided):
71
  torch.manual_seed(seed)
72
  x = torch.randn([n_images, 3, side_y, side_x], device='cuda')
73
  t = torch.linspace(1, 0, steps + 1, device='cuda')[:-1]
74
- step_list = utils.get_spliced_ddpm_cosine_schedule(t)
 
 
 
 
75
  if(not clip_guided):
76
  outs = sampling.plms_sample(cfg_model_fn, x, step_list, {})#, callback=display_callback)
77
  else:
78
  extra_args = {'clip_embed': clip_embed}
79
  cond_fn_ = cond_fn
80
  model_fn = make_cond_model_fn(model, cond_fn_)
81
- outs = sampling.plms_sample(model_fn, x, steps, extra_args)
82
  images_out = []
83
  for i, out in enumerate(outs):
84
  images_out.append(utils.to_pil_image(out))
 
71
  torch.manual_seed(seed)
72
  x = torch.randn([n_images, 3, side_y, side_x], device='cuda')
73
  t = torch.linspace(1, 0, steps + 1, device='cuda')[:-1]
74
+ #step_list = utils.get_spliced_ddpm_cosine_schedule(t)
75
+ if model.min_t == 0:
76
+ step_list = utils.get_spliced_ddpm_cosine_schedule(t)
77
+ else:
78
+ step_list = utils.get_ddpm_schedule(t)
79
  if(not clip_guided):
80
  outs = sampling.plms_sample(cfg_model_fn, x, step_list, {})#, callback=display_callback)
81
  else:
82
  extra_args = {'clip_embed': clip_embed}
83
  cond_fn_ = cond_fn
84
  model_fn = make_cond_model_fn(model, cond_fn_)
85
+ outs = sampling.plms_sample(model_fn, x, step_list, extra_args)
86
  images_out = []
87
  for i, out in enumerate(outs):
88
  images_out.append(utils.to_pil_image(out))