Spaces:
Sleeping
Sleeping
denoise
Browse files
app.py
CHANGED
@@ -122,6 +122,22 @@ ab_t[0] = 1
|
|
122 |
# construct model
|
123 |
nn_model = ContextUnet(in_channels=3, n_feat=n_feat, n_cfeat=n_cfeat, height=height).to(device)
|
124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
# sample quickly using DDIM
|
126 |
@torch.no_grad()
|
127 |
def sample_ddim(n_sample, n=20):
|
|
|
122 |
# construct model
|
123 |
nn_model = ContextUnet(in_channels=3, n_feat=n_feat, n_cfeat=n_cfeat, height=height).to(device)
|
124 |
|
125 |
+
# define sampling function for DDIM
|
126 |
+
# removes the noise using ddim
|
127 |
+
def denoise_ddim(x, t, t_prev, pred_noise):
|
128 |
+
ab = ab_t[t]
|
129 |
+
ab_prev = ab_t[t_prev]
|
130 |
+
|
131 |
+
x0_pred = ab_prev.sqrt() / ab.sqrt() * (x - (1 - ab).sqrt() * pred_noise)
|
132 |
+
dir_xt = (1 - ab_prev).sqrt() * pred_noise
|
133 |
+
|
134 |
+
return x0_pred + dir_xt
|
135 |
+
|
136 |
+
# load in model weights and set to eval mode
|
137 |
+
nn_model.load_state_dict(torch.load(f"{save_dir}/model_31.pth", map_location=device))
|
138 |
+
nn_model.eval()
|
139 |
+
print("Loaded in Model without context")
|
140 |
+
|
141 |
# sample quickly using DDIM
|
142 |
@torch.no_grad()
|
143 |
def sample_ddim(n_sample, n=20):
|