debisoft commited on
Commit
84e1ed4
1 Parent(s): 3f89984
Files changed (1) hide show
  1. app.py +16 -0
app.py CHANGED
@@ -122,6 +122,22 @@ ab_t[0] = 1
122
  # construct model
123
  nn_model = ContextUnet(in_channels=3, n_feat=n_feat, n_cfeat=n_cfeat, height=height).to(device)
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  # sample quickly using DDIM
126
  @torch.no_grad()
127
  def sample_ddim(n_sample, n=20):
 
122
  # construct model
123
  nn_model = ContextUnet(in_channels=3, n_feat=n_feat, n_cfeat=n_cfeat, height=height).to(device)
124
 
125
+ # define sampling function for DDIM
126
+ # removes the noise using ddim
127
+ def denoise_ddim(x, t, t_prev, pred_noise):
128
+ ab = ab_t[t]
129
+ ab_prev = ab_t[t_prev]
130
+
131
+ x0_pred = ab_prev.sqrt() / ab.sqrt() * (x - (1 - ab).sqrt() * pred_noise)
132
+ dir_xt = (1 - ab_prev).sqrt() * pred_noise
133
+
134
+ return x0_pred + dir_xt
135
+
136
+ # load in model weights and set to eval mode
137
+ nn_model.load_state_dict(torch.load(f"{save_dir}/model_31.pth", map_location=device))
138
+ nn_model.eval()
139
+ print("Loaded in Model without context")
140
+
141
  # sample quickly using DDIM
142
  @torch.no_grad()
143
  def sample_ddim(n_sample, n=20):