Anton Forsman committed
Commit 098fc8a · 1 parent: d5c8a36

separated model files

Files changed (3):
  1. diffusion.py +233 -0
  2. inference.py +4 -2
  3. model.py → unet.py +5 -239
diffusion.py ADDED
@@ -0,0 +1,233 @@
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from tqdm import tqdm
+ from PIL import Image
+ from einops import rearrange
+ import math
+
+
+ class GaussianDiffusion:
+     def __init__(self, model, noise_steps, beta_0, beta_T, image_size, channels=3, schedule="linear"):
+         """
+         Suggested betas for:
+         * linear schedule: beta_0=1e-4, beta_T=0.02
+
+         model: the noise-prediction model to be trained (nn.Module)
+         noise_steps: the number of steps over which noise is applied (int)
+         beta_0: the initial value of beta (float)
+         beta_T: the final value of beta (float)
+         image_size: the size of the image ((int, int))
+         """
+         self.device = 'cpu'
+         self.channels = channels
+
+         self.model = model
+         self.noise_steps = noise_steps
+         self.beta_0 = beta_0
+         self.beta_T = beta_T
+         self.image_size = image_size
+
+         self.betas = self.beta_schedule(schedule=schedule)
+         self.alphas = 1.0 - self.betas
+         # Cumulative product of alphas, so the forward process has a closed form.
+         self.alpha_hat = torch.cumprod(self.alphas, dim=0)
+
+     def beta_schedule(self, schedule="cosine"):
+         if schedule == "linear":
+             return torch.linspace(self.beta_0, self.beta_T, self.noise_steps).to(self.device)
+         elif schedule == "cosine":
+             return self.betas_for_cosine(self.noise_steps)
+         elif schedule == "sigmoid":
+             return self.betas_for_sigmoid(self.noise_steps)
+         raise ValueError(f"unknown beta schedule: {schedule}")
+
+     @staticmethod
+     def sigmoid(x):
+         return 1 / (1 + np.exp(-x))
+
+     def betas_for_sigmoid(self, num_diffusion_timesteps, start=-3, end=3, tau=1.0, clip_min=1e-9):
+         betas = []
+         v_start = self.sigmoid(start / tau)
+         v_end = self.sigmoid(end / tau)
+         for t in range(num_diffusion_timesteps):
+             t_float = float(t / num_diffusion_timesteps)
+             output0 = self.sigmoid((t_float * (end - start) + start) / tau)
+             output = (v_end - output0) / (v_end - v_start)
+             betas.append(np.clip(output * 0.2, clip_min, 0.2))
+         return torch.flip(torch.tensor(betas).to(self.device), dims=[0]).float()
+
+     def betas_for_cosine(self, num_steps, start=0, end=1, tau=1, clip_min=1e-9):
+         v_start = math.cos(start * math.pi / 2) ** (2 * tau)
+         v_end = math.cos(end * math.pi / 2) ** (2 * tau)
+         betas = []
+         for t in range(num_steps):
+             t_float = float(t) / num_steps
+             output = math.cos((t_float * (end - start) + start) * math.pi / 2) ** (2 * tau)
+             output = (v_end - output) / (v_end - v_start)
+             betas.append(np.clip(output * 0.2, clip_min, 0.2))
+         return torch.flip(torch.tensor(betas).to(self.device), dims=[0]).float()
+
+     def sample_time_steps(self, batch_size=1):
+         return torch.randint(0, self.noise_steps, (batch_size,)).to(self.device)
+
+     def to(self, device):
+         self.device = device
+         self.betas = self.betas.to(device)
+         self.alphas = self.alphas.to(device)
+         self.alpha_hat = self.alpha_hat.to(device)
+
+     def q(self, x, t):
+         """Forward process (not yet implemented)."""
+         pass
+
+     def p(self, x, t):
+         """Backward process (not yet implemented)."""
+         pass
+
+     def apply_noise(self, x, t):
+         # Force x to be (batch_size, channels, height, width).
+         if len(x.shape) == 3:
+             x = x.unsqueeze(0)
+         if isinstance(t, int):
+             t = torch.tensor([t])
+         t = t.to(self.device)
+         sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])
+         sqrt_one_minus_alpha_hat = torch.sqrt(1.0 - self.alpha_hat[t])
+         # Noise drawn from a standard normal distribution.
+         epsilon = torch.randn_like(x).to(self.device)
+
+         # Closed-form forward process from the DDPM paper:
+         # x_t = sqrt(alpha_hat_t) * x_0 + sqrt(1 - alpha_hat_t) * epsilon.
+         # The einsum scales each sample in the batch by its per-timestep coefficient.
+         noisy_image = (
+             torch.einsum("b,bchw->bchw", sqrt_alpha_hat, x.to(self.device))
+             + torch.einsum("b,bchw->bchw", sqrt_one_minus_alpha_hat, epsilon)
+         )
+         # Return the noisy image and the noise that was added to it.
+         return noisy_image, epsilon
+
+     @staticmethod
+     def normalize_image(x):
+         # Normalize image from [0, 255] to [-1, 1].
+         return x / 255.0 * 2.0 - 1.0
+
+     @staticmethod
+     def denormalize_image(x):
+         # Denormalize image from [-1, 1] back to [0, 255].
+         return (x + 1.0) / 2.0 * 255.0
+
+     def sample_step(self, x, t, cond):
+         batch_size = x.shape[0]
+         device = x.device
+         z = torch.randn_like(x) if t >= 1 else torch.zeros_like(x)
+         z = z.to(device)
+         alpha = self.alphas[t]
+         one_over_sqrt_alpha = 1.0 / torch.sqrt(alpha)
+         one_minus_alpha = 1.0 - alpha
+
+         sqrt_one_minus_alpha_hat = torch.sqrt(1.0 - self.alpha_hat[t])
+         beta_hat = (1 - self.alpha_hat[t - 1]) / (1 - self.alpha_hat[t]) * self.betas[t]
+         beta = self.betas[t]
+         # TODO: should the params be reshaped to (batch_size, 1, 1, 1)?
+
+         # Either beta_hat or beta is a valid choice of reverse-process variance.
+         # std = torch.sqrt(beta_hat)
+         std = torch.sqrt(beta)
+
+         if cond is not None:
+             predicted_noise = self.model(x, torch.tensor([t]).repeat(batch_size).to(device), cond)
+         else:
+             predicted_noise = self.model(x, torch.tensor([t]).repeat(batch_size).to(device))
+         # x_{t-1} = mean + std * z
+         mean = one_over_sqrt_alpha * (x - one_minus_alpha / sqrt_one_minus_alpha_hat * predicted_noise)
+         x_t_minus_1 = mean + std * z
+
+         return x_t_minus_1
+
+     def sample(self, num_samples, show_progress=True):
+         """
+         Sample num_samples images from the model.
+         """
+         cond = None
+         if self.model.is_conditional:
+             # Condition on the first num_samples class labels.
+             assert num_samples <= self.model.num_classes, "num_samples must be less than or equal to the number of classes"
+             cond = torch.arange(self.model.num_classes)[:num_samples].to(self.device)
+             cond = rearrange(cond, 'i -> i ()')
+
+         self.model.eval()
+         image_versions = []
+         with torch.no_grad():
+             x = torch.randn(num_samples, self.channels, *self.image_size).to(self.device)
+             it = reversed(range(1, self.noise_steps))
+             if show_progress:
+                 it = tqdm(it)
+             for t in it:
+                 image_versions.append(self.denormalize_image(torch.clip(x, -1, 1)).clone().squeeze(0))
+                 x = self.sample_step(x, t, cond)
+         self.model.train()
+         x = torch.clip(x, -1.0, 1.0)
+         return self.denormalize_image(x), image_versions
+
+     def validate(self, dataloader):
+         """
+         Calculate the average loss on the validation set.
+         """
+         self.model.eval()
+         acc_loss = 0
+         with torch.no_grad():
+             for (image, cond) in dataloader:
+                 t = self.sample_time_steps(batch_size=image.shape[0])
+                 noisy_image, added_noise = self.apply_noise(image, t)
+                 noisy_image = noisy_image.to(self.device)
+                 added_noise = added_noise.to(self.device)
+                 cond = cond.to(self.device)
+                 predicted_noise = self.model(noisy_image, t, cond)
+                 loss = nn.MSELoss()(predicted_noise, added_noise)
+                 acc_loss += loss.item()
+         self.model.train()
+         return acc_loss / len(dataloader)
+
+
+ class DiffusionImageAPI:
+     def __init__(self, diffusion_model):
+         self.diffusion_model = diffusion_model
+
+     def get_noisy_image(self, image, t):
+         x = torch.tensor(np.array(image))
+         x = self.diffusion_model.normalize_image(x)
+         y, _ = self.diffusion_model.apply_noise(x, t)
+         y = self.diffusion_model.denormalize_image(y)
+         return Image.fromarray(y.squeeze(0).numpy().astype(np.uint8))
+
+     def get_noisy_images(self, image, time_steps):
+         """
+         image: the image to add noise to (PIL.Image)
+         time_steps: the time steps at which to apply noise (iterable of int)
+         """
+         return [self.get_noisy_image(image, int(t)) for t in time_steps]
+
+     def tensor_to_image(self, tensor):
+         return Image.fromarray(tensor.cpu().numpy().astype(np.uint8))
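
For orientation, a minimal, self-contained sketch of how the new module is driven. Only GaussianDiffusion and DiffusionImageAPI come from this commit; DummyModel and every constructor argument below are illustrative stand-ins (the real model is the Unet/ConditionalUnet from unet.py):

import torch
import torch.nn as nn
from diffusion import GaussianDiffusion, DiffusionImageAPI

class DummyModel(nn.Module):
    # Stand-in for the real Unet: takes (x, t) and returns a tensor shaped like x.
    def __init__(self):
        super().__init__()
        self.is_conditional = False  # attribute that GaussianDiffusion.sample() checks
        self.conv = nn.Conv2d(3, 3, kernel_size=3, padding=1)

    def forward(self, x, t):
        return self.conv(x)

diffusion = GaussianDiffusion(
    model=DummyModel(),
    noise_steps=1000,
    beta_0=1e-4,
    beta_T=0.02,  # the betas the docstring suggests for the linear schedule
    image_size=(32, 32),
)
x0 = torch.rand(8, 3, 32, 32) * 2 - 1         # fake batch, already in [-1, 1]
t = diffusion.sample_time_steps(batch_size=8)
x_t, eps = diffusion.apply_noise(x0, t)       # closed-form forward process
samples, _ = diffusion.sample(num_samples=4)  # reverse process; floats in [0, 255]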
inference.py CHANGED
@@ -7,7 +7,9 @@ from PIL import Image
  import requests
  import io
 
- from model import Unet, ConditionalUnet, GaussianDiffusion, DiffusionImageAPI
+ from unet import Unet, ConditionalUnet
+
+ from diffusion import GaussianDiffusion, DiffusionImageAPI
 
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -22,7 +24,7 @@ def inference():
      )
      model = ConditionalUnet(
          unet=model,
-         num_classes=14,
+         num_classes=13,
      )
      model.load_state_dict(torch.load("./model_final.pt", map_location=device))
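
One detail worth flagging across this diff and the next: inference.py drops num_classes from 14 to 13, while unet.py (below) grows the class embedding to num_classes + 1 rows with padding_idx=0. Under that reading the embedding table stays at 14 rows, with index 0 now reserved as an always-zero "no class" slot. A quick check of the PyTorch behavior being relied on:

import torch
import torch.nn as nn

emb = nn.Embedding(13 + 1, 64, padding_idx=0)   # 14 rows, same total as before
assert tuple(emb.weight.shape) == (14, 64)
assert emb(torch.tensor([0])).abs().sum() == 0  # padding index maps to the zero vector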
model.py → unet.py RENAMED
@@ -7,243 +7,6 @@ from collections import defaultdict
  import torch as th
  import numpy as np
  import math
- from tqdm import tqdm
- from PIL import Image
-
-
- class GaussianDiffusion:
-     def __init__(self, model, noise_steps, beta_0, beta_T, image_size, channels=3, schedule="linear"):
-         """
-         Suggested betas for:
-         * linear schedule: beta_0=1e-4, beta_T=0.02
-
-         model: the noise-prediction model to be trained (nn.Module)
-         noise_steps: the number of steps over which noise is applied (int)
-         beta_0: the initial value of beta (float)
-         beta_T: the final value of beta (float)
-         image_size: the size of the image ((int, int))
-         """
-         self.device = 'cpu'
-         self.channels = channels
-
-         self.model = model
-         self.noise_steps = noise_steps
-         self.beta_0 = beta_0
-         self.beta_T = beta_T
-         self.image_size = image_size
-
-         self.betas = self.beta_schedule(schedule=schedule)
-         self.alphas = 1.0 - self.betas
-         # Cumulative product of alphas, so the forward process has a closed form.
-         self.alpha_hat = torch.cumprod(self.alphas, dim=0)
-
-     def beta_schedule(self, schedule="cosine"):
-         if schedule == "linear":
-             return torch.linspace(self.beta_0, self.beta_T, self.noise_steps).to(self.device)
-         elif schedule == "cosine":
-             return self.betas_for_cosine(self.noise_steps)
-         elif schedule == "sigmoid":
-             return self.betas_for_sigmoid(self.noise_steps)
-
-     @staticmethod
-     def sigmoid(x):
-         return 1 / (1 + np.exp(-x))
-
-     def betas_for_sigmoid(self, num_diffusion_timesteps, start=-3, end=3, tau=1.0, clip_min=1e-9):
-         betas = []
-         v_start = self.sigmoid(start / tau)
-         v_end = self.sigmoid(end / tau)
-         for t in range(num_diffusion_timesteps):
-             t_float = float(t / num_diffusion_timesteps)
-             output0 = self.sigmoid((t_float * (end - start) + start) / tau)
-             output = (v_end - output0) / (v_end - v_start)
-             betas.append(np.clip(output * 0.2, clip_min, 0.2))
-         return torch.flip(torch.tensor(betas).to(self.device), dims=[0]).float()
-
-     def betas_for_cosine(self, num_steps, start=0, end=1, tau=1, clip_min=1e-9):
-         v_start = math.cos(start * math.pi / 2) ** (2 * tau)
-         v_end = math.cos(end * math.pi / 2) ** (2 * tau)
-         betas = []
-         for t in range(num_steps):
-             t_float = float(t) / num_steps
-             output = math.cos((t_float * (end - start) + start) * math.pi / 2) ** (2 * tau)
-             output = (v_end - output) / (v_end - v_start)
-             betas.append(np.clip(output * 0.2, clip_min, 0.2))
-         return torch.flip(torch.tensor(betas).to(self.device), dims=[0]).float()
-
-     def sample_time_steps(self, batch_size=1):
-         return torch.randint(0, self.noise_steps, (batch_size,)).to(self.device)
-
-     def to(self, device):
-         self.device = device
-         self.betas = self.betas.to(device)
-         self.alphas = self.alphas.to(device)
-         self.alpha_hat = self.alpha_hat.to(device)
-
-     def q(self, x, t):
-         """Forward process (not yet implemented)."""
-         pass
-
-     def p(self, x, t):
-         """Backward process (not yet implemented)."""
-         pass
-
-     def apply_noise(self, x, t):
-         # Force x to be (batch_size, channels, height, width).
-         if len(x.shape) == 3:
-             x = x.unsqueeze(0)
-         if isinstance(t, int):
-             t = torch.tensor([t])
-         t = t.to(self.device)
-         sqrt_alpha_hat = torch.sqrt(self.alpha_hat[t])
-         sqrt_one_minus_alpha_hat = torch.sqrt(1.0 - self.alpha_hat[t])
-         # Noise drawn from a standard normal distribution.
-         epsilon = torch.randn_like(x).to(self.device)
-
-         # Closed-form forward process from the DDPM paper:
-         # x_t = sqrt(alpha_hat_t) * x_0 + sqrt(1 - alpha_hat_t) * epsilon.
-         noisy_image = (
-             torch.einsum("b,bchw->bchw", sqrt_alpha_hat, x.to(self.device))
-             + torch.einsum("b,bchw->bchw", sqrt_one_minus_alpha_hat, epsilon)
-         )
-         # Return the noisy image and the noise that was added to it.
-         return noisy_image, epsilon
-
-     @staticmethod
-     def normalize_image(x):
-         # Normalize image from [0, 255] to [-1, 1].
-         return x / 255.0 * 2.0 - 1.0
-
-     @staticmethod
-     def denormalize_image(x):
-         # Denormalize image from [-1, 1] back to [0, 255].
-         return (x + 1.0) / 2.0 * 255.0
-
-     def sample_step(self, x, t, cond):
-         batch_size = x.shape[0]
-         device = x.device
-         z = torch.randn_like(x) if t >= 1 else torch.zeros_like(x)
-         z = z.to(device)
-         alpha = self.alphas[t]
-         one_over_sqrt_alpha = 1.0 / torch.sqrt(alpha)
-         one_minus_alpha = 1.0 - alpha
-
-         sqrt_one_minus_alpha_hat = torch.sqrt(1.0 - self.alpha_hat[t])
-         beta_hat = (1 - self.alpha_hat[t - 1]) / (1 - self.alpha_hat[t]) * self.betas[t]
-         beta = self.betas[t]
-         # TODO: should the params be reshaped to (batch_size, 1, 1, 1)?
-
-         # Either beta_hat or beta is a valid choice of reverse-process variance.
-         # std = torch.sqrt(beta_hat)
-         std = torch.sqrt(beta)
-
-         if cond is not None:
-             predicted_noise = self.model(x, torch.tensor([t]).repeat(batch_size).to(device), cond)
-         else:
-             predicted_noise = self.model(x, torch.tensor([t]).repeat(batch_size).to(device))
-         # x_{t-1} = mean + std * z
-         mean = one_over_sqrt_alpha * (x - one_minus_alpha / sqrt_one_minus_alpha_hat * predicted_noise)
-         x_t_minus_1 = mean + std * z
-
-         return x_t_minus_1
-
-     def sample(self, num_samples, show_progress=True):
-         """
-         Sample num_samples images from the model.
-         """
-         cond = None
-         if self.model.is_conditional:
-             # Condition on the first num_samples class labels.
-             assert num_samples <= self.model.num_classes, "num_samples must be less than or equal to the number of classes"
-             cond = torch.arange(self.model.num_classes)[:num_samples].to(self.device)
-             cond = rearrange(cond, 'i -> i ()')
-
-         self.model.eval()
-         image_versions = []
-         with torch.no_grad():
-             x = torch.randn(num_samples, self.channels, *self.image_size).to(self.device)
-             it = reversed(range(1, self.noise_steps))
-             if show_progress:
-                 it = tqdm(it)
-             for t in it:
-                 image_versions.append(self.denormalize_image(torch.clip(x, -1, 1)).clone().squeeze(0))
-                 x = self.sample_step(x, t, cond)
-         self.model.train()
-         x = torch.clip(x, -1.0, 1.0)
-         return self.denormalize_image(x), image_versions
-
-     def validate(self, dataloader):
-         """
-         Calculate the average loss on the validation set.
-         """
-         self.model.eval()
-         acc_loss = 0
-         with torch.no_grad():
-             for (image, cond) in dataloader:
-                 t = self.sample_time_steps(batch_size=image.shape[0])
-                 noisy_image, added_noise = self.apply_noise(image, t)
-                 noisy_image = noisy_image.to(self.device)
-                 added_noise = added_noise.to(self.device)
-                 cond = cond.to(self.device)
-                 predicted_noise = self.model(noisy_image, t, cond)
-                 loss = nn.MSELoss()(predicted_noise, added_noise)
-                 acc_loss += loss.item()
-         self.model.train()
-         return acc_loss / len(dataloader)
-
-
- class DiffusionImageAPI:
-     def __init__(self, diffusion_model):
-         self.diffusion_model = diffusion_model
-
-     def get_noisy_image(self, image, t):
-         x = torch.tensor(np.array(image))
-         x = self.diffusion_model.normalize_image(x)
-         y, _ = self.diffusion_model.apply_noise(x, t)
-         y = self.diffusion_model.denormalize_image(y)
-         return Image.fromarray(y.squeeze(0).numpy().astype(np.uint8))
-
-     def get_noisy_images(self, image, time_steps):
-         """
-         image: the image to add noise to (PIL.Image)
-         time_steps: the time steps at which to apply noise (iterable of int)
-         """
-         return [self.get_noisy_image(image, int(t)) for t in time_steps]
-
-     def tensor_to_image(self, tensor):
-         return Image.fromarray(tensor.cpu().numpy().astype(np.uint8))
-
-
 
  str_to_act = defaultdict(lambda: nn.SiLU())
  str_to_act.update({
@@ -547,6 +310,9 @@ class Unet(nn.Module):
      ):
          super().__init__()
          self.is_conditional = False
+         #channel_mults = (1, 2, 2, 2)
+         #attention_layers = (False, False, True, False)
+         #res_block_width=3
 
          self.image_channels = image_channels
          self.starting_channels = starting_channels
@@ -643,7 +409,7 @@ class ConditionalUnet(nn.Module):
          self.unet = unet
          self.num_classes = num_classes
 
-         self.class_embedding = nn.Embedding(num_classes, unet.starting_channels)
+         self.class_embedding = nn.Embedding(num_classes + 1, unet.starting_channels, padding_idx=0)
 
      def forward(self, x, t, cond=None):
          # cond: (batch_size, n), where n is the number of classes that we are conditioning on
@@ -655,4 +421,4 @@ class ConditionalUnet(nn.Module):
          cond = cond.sum(dim=1)
          t += cond
 
-         return self.unet._forward(x, t)
+         return self.unet._forward(x, t)
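
The conditioning path in these last hunks folds class information into the time embedding: each label is embedded, the embeddings are summed over the label dimension, and the result is added to t before the shared _forward. A minimal sketch of that pattern (TinyCondBlock and its sizes are hypothetical; only the embed-sum-add structure mirrors the diff):

import torch
import torch.nn as nn

class TinyCondBlock(nn.Module):
    # Illustrates the embedding-sum conditioning above; not the repo's Unet.
    def __init__(self, num_classes, channels=64):
        super().__init__()
        # num_classes + 1 rows; index 0 is the zero "no class" slot, as in the diff.
        self.class_embedding = nn.Embedding(num_classes + 1, channels, padding_idx=0)

    def forward(self, t_emb, cond):
        # cond: (batch, n) label indices. Summing over n merges multi-label
        # conditions into one vector that is simply added to the time embedding.
        return t_emb + self.class_embedding(cond).sum(dim=1)

block = TinyCondBlock(num_classes=13)
t_emb = torch.zeros(2, 64)       # stand-in time embedding
cond = torch.tensor([[3], [0]])  # second sample uses the padding slot
out = block(t_emb, cond)
assert out[1].abs().sum() == 0   # the padding label leaves t_emb unchanged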