frutiemax committed
Commit 054faf7 · 1 Parent(s): 21640c8

Use VAE to speed up inference

Files changed (3):
  1. rct_diffusion_pipeline.py +43 -37
  2. test_pipeline.py +11 -3
  3. train_model.py +63 -33
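
In short: instead of denoising four 256x256 RGB views (12 channels) in pixel space, the UNet now denoises four 64x64 latent views (16 channels) and a pretrained Stable Diffusion VAE decodes them to 512x512 images, which is where the inference speedup comes from. A quick sanity check of the channel arithmetic used throughout the diff (NUM_VIEWS is a hypothetical name; the code hard-codes range(4)):

    NUM_VIEWS = 4             # hypothetical name; the code hard-codes range(4)
    LATENT_NUM_CHANNELS = 4   # channels per view in the VAE's latent space
    SAMPLE_NUM_CHANNELS = 3   # RGB channels per view in pixel space

    assert NUM_VIEWS * LATENT_NUM_CHANNELS == 16  # new UNet in_channels/out_channels
    assert NUM_VIEWS * SAMPLE_NUM_CHANNELS == 12  # old pixel-space UNet channels
    assert 512 // 8 == 64  # the SD VAE downsamples 8x: 512px samples -> 64px latents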
rct_diffusion_pipeline.py CHANGED
@@ -4,6 +4,7 @@ from diffusers import DDPMScheduler, UNet2DConditionModel
 import torch
 import torchvision.transforms as T
 from PIL import Image
+import PIL.Image
 from transformers import AutoTokenizer
 from datasets import load_dataset
 import numpy as np
@@ -11,13 +12,7 @@ import pandas as pd
 from tqdm.auto import tqdm
 
 class RCTDiffusionPipeline(DiffusionPipeline):
-    def get_default_unet(hidden_dim):
-        return UNet2DConditionModel(sample_size=256, in_channels=12, out_channels=12, \
-            down_block_types=('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'),\
-            up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'), cross_attention_dim=hidden_dim,
-            block_out_channels=(64, 128, 256), norm_num_groups=32)
-
-    def __init__(self):
+    def __init__(self, unet, scheduler, vae):
         super().__init__()
 
         # dictionnary that keeps the different classes of object description, color1, color2 and color3
@@ -25,16 +20,13 @@ class RCTDiffusionPipeline(DiffusionPipeline):
         self.color1_dict = {}
         self.color2_dict = {}
         self.color3_dict = {}
-        self.load_dictionaries_from_dataset()
 
-        self.scheduler = None
-        self.unet = None
-
-    def set_unet(self, unet):
-        self.unet = unet
-
-    def set_scheduler(self, scheduler):
         self.scheduler = scheduler
+        self.unet = unet
+        self.vae = vae
+
+        # channels for 1 image
+        self.num_channels = int(self.unet.config.in_channels / 4)
 
     def load_dictionaries_from_dataset(self):
         dataset = load_dataset('frutiemax/rct_dataset')
@@ -127,11 +119,10 @@ class RCTDiffusionPipeline(DiffusionPipeline):
 
         class_labels = torch.reshape(class_labels, (num_images, 1, self.get_class_labels_size()))
         return class_labels
-
-    def __call__(self, object_description : list[list[tuple[str, float]]], color1 : list[list[tuple[str, float]]], \
+
+    def get_class_labels(self, object_description : list[list[tuple[str, float]]], color1 : list[list[tuple[str, float]]], \
         color2 : list[list[tuple[str, float]]] = None, color3 : list[list[tuple[str, float]]] = None, \
-        batch_size=1, num_inference_steps=20, generator=torch.manual_seed(torch.random.seed())):
-
+        batch_size=1):
         # check if the labels are the correct size
         if len(object_description) != batch_size:
             return None
@@ -171,25 +162,29 @@ class RCTDiffusionPipeline(DiffusionPipeline):
             colors3.append(c3)
 
         # now put those weights into a tensor
-        class_labels = self.pack_labels_to_tensor(batch_size, object_descriptions, colors1, colors2, colors3).to(device='cuda',dtype=torch.float16)
+        return self.pack_labels_to_tensor(batch_size, object_descriptions, colors1, colors2, colors3).to(device='cuda',dtype=torch.float16)
 
-        # we need those class labels for the 12 channels
-        #new_class_labels = torch.Tensor(size=(batch_size, 12, self.get_class_labels_size()))
-        #new_class_labels[:, :] = class_labels
-        #class_labels = new_class_labels.to(device='cuda', dtype=torch.float16)
-        #del new_class_labels
-
-        # set the inference steps
-        self.scheduler.set_timesteps(num_inference_steps)
-
-        noise_batches = torch.Tensor(size=(batch_size, 4, 3, 256, 256)).to(dtype=torch.float16, device='cuda')
+    # generate 64x64 latents
+    def generate_noise_batches(self, batch_size):
+        noise_batches = torch.Tensor(size=(batch_size, 4, self.num_channels, 64, 64)).to(dtype=torch.float16, device='cuda')
         for batch_index in range(batch_size):
            for view_index in range(4):
-                noise = torch.randn(3, 256, 256).to(dtype=torch.float16, device='cuda')
+                noise = torch.randn(self.num_channels, 64, 64).to(dtype=torch.float16, device='cuda')
                noise_batches[batch_index, view_index] = noise
 
-        # reshape the data so it's (batch_size, 12, 256, 256)
-        noise_batches = torch.reshape(noise_batches, (batch_size, 1, 12, 256, 256)).to(dtype=torch.float16, device='cuda')
+        return torch.reshape(noise_batches, (batch_size, 1, self.num_channels*4, 64, 64)).to(dtype=torch.float16, device='cuda')
+
+    def __call__(self, object_description : list[list[tuple[str, float]]], color1 : list[list[tuple[str, float]]], \
+        color2 : list[list[tuple[str, float]]] = None, color3 : list[list[tuple[str, float]]] = None, \
+        batch_size=1, num_inference_steps=20, generator=torch.manual_seed(torch.random.seed())):
+
+        class_labels = self.get_class_labels(object_description, color1, color2, color3, batch_size).to(device='cuda', dtype=torch.float16)
+        if class_labels == None:
+            return None
+
+        # set the inference steps
+        self.scheduler.set_timesteps(num_inference_steps)
+        noise_batches = self.generate_noise_batches(batch_size)
 
         # now call the model for the n interations
         progress_bar = tqdm(total=num_inference_steps)
@@ -206,15 +201,26 @@ class RCTDiffusionPipeline(DiffusionPipeline):
             epoch = epoch + 1
 
         # reshape the data so we get back 4 RGB images
-        noise_batches = torch.reshape(noise_batches, (batch_size, 4, 3, 256, 256)).to('cpu')
+        noise_batches = torch.reshape(noise_batches, (batch_size, 4, self.num_channels, 64, 64))
+        images = torch.Tensor(size=(batch_size, 4, 3, 512, 512))
+
+        with torch.no_grad():
+            for image_index in range(4):
+                image = noise_batches[:, image_index]
+                result = self.vae.decode(image).sample
+                images[:, image_index] = result
 
         # convert those tensors to PIL images
         output_images = []
-        tensor_to_pil = T.ToPILImage('RGB')
-
         for batch_index in range(batch_size):
             for image_index in range(4):
-                output_images.append(tensor_to_pil(noise_batches[batch_index, image_index]))
+                # run these into the vae decoder
+                image = images[batch_index, image_index]
+                image = (image / 2 + 0.5).clamp(0, 1).squeeze()
+                image = (image.permute(1, 2, 0) * 255).to(torch.uint8).cpu().numpy()
+                image = (image * 255).round().astype("uint8")
+                image = Image.fromarray(image)
+                output_images.append(image)
 
         # for now just return the images
         return output_images
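
Note: in the tensor-to-PIL loop above, the decoded image is scaled to [0, 255] twice, once before the uint8 cast and once after, and multiplying a uint8 array by 255 wraps around. A minimal corrected sketch of the decode-and-convert step, assuming self.vae is a diffusers AutoencoderKL and latents is one view of shape (batch, 4, 64, 64); note also that the SD VAE conventionally applies a scaling_factor (0.18215) between UNet latents and VAE latents, which this commit does not use:

    import torch
    from PIL import Image

    def latents_to_pil(vae, latents):
        with torch.no_grad():
            decoded = vae.decode(latents).sample      # (batch, 3, 512, 512), roughly in [-1, 1]
        images = (decoded / 2 + 0.5).clamp(0, 1)      # map to [0, 1]
        arrays = (images.permute(0, 2, 3, 1) * 255).round().to(torch.uint8).cpu().numpy()
        return [Image.fromarray(a) for a in arrays]   # scale to [0, 255] exactly once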
test_pipeline.py CHANGED
@@ -1,11 +1,19 @@
 from rct_diffusion_pipeline import RCTDiffusionPipeline
-from diffusers import UNet2DConditionModel
+from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL
 
 
 torch_device = "cuda"
 
-pipeline = RCTDiffusionPipeline()
-pipeline.print_class_tokens_to_csv()
+unet = UNet2DConditionModel(sample_size=64, in_channels=16, out_channels=16, \
+    down_block_types=('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'),\
+    up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'), cross_attention_dim=160,
+    block_out_channels=(64, 128, 256), norm_num_groups=32)
+scheduler = DDPMScheduler(num_train_timesteps=20)
+vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae", use_safetensors=True)
+vae.tile_sample_min_size = 256
+
+
+pipeline = RCTDiffusionPipeline(unet, scheduler, vae)
 output = pipeline([[('aleppo pine tree', 1.0)]], [[('dark green', 1.0)]])
 
 # from PIL import Image
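
Note: __init__ no longer calls load_dictionaries_from_dataset() (train_model.py now calls it explicitly), and this test never calls it either, so the object-description and color dictionaries would presumably still be empty when pipeline(...) packs its class labels. A plausible fix, mirroring what train_model.py does:

    pipeline = RCTDiffusionPipeline(unet, scheduler, vae)
    pipeline.load_dictionaries_from_dataset()  # as in train_model.py; not in the committed test
    output = pipeline([[('aleppo pine tree', 1.0)]], [[('dark green', 1.0)]])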
train_model.py CHANGED
@@ -10,7 +10,12 @@ import torch.nn.functional as F
 from diffusers.optimization import get_cosine_schedule_with_warmup
 from tqdm.auto import tqdm
 from accelerate import Accelerator
-from diffusers import DDPMScheduler, UNet2DConditionModel
+from diffusers import DDPMScheduler, UNet2DConditionModel, AutoencoderKL
+
+SAMPLE_SIZE = 512
+LATENT_SIZE = 64
+SAMPLE_NUM_CHANNELS = 3
+LATENT_NUM_CHANNELS = 4
 
 def save_and_test(pipeline, epoch):
     outputs = pipeline([[('aleppo pine tree', 1.0)]], [[('dark green', 1.0)]])
@@ -21,14 +26,10 @@ def save_and_test(pipeline, epoch):
     model_file = f'rct_foliage_{epoch}.pth'
     pipeline.save_pretrained(model_file)
 
-def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model_interval=10, start_learning_rate=1e-3, lr_warmup_steps=500):
-    dataset = load_dataset('frutiemax/rct_dataset')
-    dataset = dataset['train']
-
-    num_images = int(dataset.num_rows / 4)
-
-    # let's get all the entries for the 4 views split in four lists
+def convert_images(dataset):
+    # let's get all the entries for the 4 views split in four lists
     views = []
+    num_images = int(dataset.num_rows / 4)
 
     for view_index in range(4):
         entries = [entry for entry in dataset if entry['view'] == view_index]
@@ -41,18 +42,18 @@ def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model
         for entry in views[view_index]:
             image = entry['image']
 
-            scale_factor = int(np.minimum(256 / image.width, 256 / image.height))
+            scale_factor = int(np.minimum(SAMPLE_SIZE / image.width, SAMPLE_SIZE / image.height))
             image = Image.resize(image, size=(scale_factor * image.width, scale_factor * image.height), resample=Resampling.NEAREST)
 
-            new_image = PIL.Image.new('RGB', (256, 256))
-            new_image.paste(image, box=(int((256 - image.width)/2), int((256 - image.height)/2)))
+            new_image = PIL.Image.new('RGB', (SAMPLE_SIZE, SAMPLE_SIZE))
+            new_image.paste(image, box=(int((SAMPLE_SIZE - image.width)/2), int((SAMPLE_SIZE - image.height)/2)))
             images.append(new_image)
         image_views.append(images)
 
     del views
 
     # convert those views in tensors
-    targets = torch.Tensor(size=(num_images, 4, 3, 256, 256)).to(dtype=torch.float16)
+    targets = torch.Tensor(size=(num_images, 4, SAMPLE_NUM_CHANNELS, SAMPLE_SIZE, SAMPLE_SIZE)).to(dtype=torch.float16)
     pillow_to_tensor = T.ToTensor()
 
     for image_index in range(num_images):
@@ -61,8 +62,9 @@ def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model
     del image_views
     del entries
 
-    targets = torch.reshape(targets, (num_images, 12, 256, 256))
+    return torch.reshape(targets, (num_images, 4 * SAMPLE_NUM_CHANNELS, SAMPLE_SIZE, SAMPLE_SIZE))
 
+def convert_labels(dataset, model, num_images):
     # get the labels
     view0_entries = [row for row in dataset if row['view'] == 0]
     obj_descriptions = [row['object_description'] for row in view0_entries]
@@ -79,7 +81,7 @@ def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model
     colors3 = [[(color3, 1.0)] for color3 in colors3]
 
     # convert those tuples in numpy arrays using the helper function of the model
-    model = RCTDiffusionPipeline()
+
     obj_descriptions = [model.get_object_description_weights(obj_desc) for obj_desc in obj_descriptions]
     colors1 = [model.get_color1_weights(color1) for color1 in colors1]
     colors2 = [model.get_color2_weights(color2) for color2 in colors2]
@@ -92,51 +94,79 @@ def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model
     del colors2
     del colors3
     del dataset
+    return class_labels.to(dtype=torch.float16, device='cuda')
+
+def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model_interval=10, start_learning_rate=1e-3, lr_warmup_steps=500):
+    dataset = load_dataset('frutiemax/rct_dataset')
+    dataset = dataset['train']
+
+    targets = convert_images(dataset)
+    num_images = int(dataset.num_rows / 4)
+
+    unet = UNet2DConditionModel(sample_size=LATENT_SIZE, in_channels=LATENT_NUM_CHANNELS * 4, out_channels=LATENT_NUM_CHANNELS * 4, \
+        down_block_types=('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'),\
+        up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'), cross_attention_dim=160,
+        block_out_channels=(64, 128, 256), norm_num_groups=32)
+    unet = unet.to(dtype=torch.float16)
+    scheduler = DDPMScheduler(num_train_timesteps=20)
+    vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae", use_safetensors=True, variant='fp16')
+    vae = vae.to(dtype=torch.float16)
 
-    unet = RCTDiffusionPipeline.get_default_unet(160)
     optimizer = torch.optim.Adam(unet.parameters(), lr=start_learning_rate)
     lr_scheduler = get_cosine_schedule_with_warmup(
         optimizer=optimizer,
         num_warmup_steps=lr_warmup_steps,
         num_training_steps=num_images * epochs
     )
+    model = RCTDiffusionPipeline(unet, scheduler, vae)
+    model.load_dictionaries_from_dataset()
+    labels = convert_labels(dataset, model, num_images)
 
     # lets train for 100 epoch for each sprite in the dataset with a random noise level
     progress_bar = tqdm(total=epochs)
+    accelerator = Accelerator(mixed_precision='fp16')
+    unet, scheduler, lr_scheduler, vae = accelerator.prepare(unet, scheduler, lr_scheduler, vae)
 
-    scheduler = DDPMScheduler(scheduler_num_timesteps)
-    unet = unet.to(device='cuda', dtype=torch.float16)
-    scheduler.set_timesteps(scheduler_num_timesteps)
-
     for epoch in range(epochs):
         # create a noisy version of each sprite
         for batch_index in range(0, num_images, batch_size):
             progress_bar.set_description(f'epoch={epoch}, batch_index={batch_index}')
             batch_end = np.minimum(num_images, batch_index + batch_size)
             clean_images = targets[batch_index:batch_end]
-            clean_images = torch.reshape(clean_images, ((batch_end - batch_index), 12, 256, 256)).to(device='cuda', dtype=torch.float16)
+            clean_images = torch.reshape(clean_images, ((batch_end - batch_index), SAMPLE_NUM_CHANNELS * 4, SAMPLE_SIZE, SAMPLE_SIZE)).to(device='cuda', dtype=torch.float16)
 
             noise = torch.randn(clean_images.shape, dtype=torch.float16, device='cuda')
             timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_end - batch_index, )).to(device='cuda')
             #timesteps = timesteps.to(dtype=torch.int, device='cuda')
             noisy_images = scheduler.add_noise(clean_images, noise, timesteps)
-            noise_pred = unet(noisy_images, timesteps, class_labels[batch_index:batch_end].to(device='cuda',dtype=torch.float16), return_dict=False)[0]
-
-            #noise_pred = noise_pred.to(device='cuda', dtype=torch.float16)
-            loss = F.mse_loss(noise_pred, noise)
-            loss.backward()
-            optimizer.step()
-            lr_scheduler.step()
-            optimizer.zero_grad()
+
+            # encode through the vae
+            with accelerator.accumulate(unet):
+                latent_images = torch.Tensor(size=(batch_end - batch_index, LATENT_NUM_CHANNELS * 4, LATENT_SIZE, LATENT_SIZE)).to(device='cuda', dtype=torch.float16)
+                latent_noises = torch.Tensor(size=(batch_end - batch_index, LATENT_NUM_CHANNELS * 4, LATENT_SIZE, LATENT_SIZE)).to(device='cuda', dtype=torch.float16)
+                for view_index in range(4):
+                    images = noisy_images[:, view_index*SAMPLE_NUM_CHANNELS:(view_index+1)*SAMPLE_NUM_CHANNELS]
+                    result = vae.encode(images).latent_dist.sample()
+                    latent_images[:, view_index*LATENT_NUM_CHANNELS:(view_index+1)*LATENT_NUM_CHANNELS] = result
+
+                    images = noise[:, view_index*SAMPLE_NUM_CHANNELS:(view_index+1)*SAMPLE_NUM_CHANNELS]
+                    result = vae.encode(images).latent_dist.sample()
+                    latent_noises[:, view_index*LATENT_NUM_CHANNELS:(view_index+1)*LATENT_NUM_CHANNELS] = result
+
+                unet_results = unet(latent_images, timesteps, labels[batch_index:batch_end])[0]
+                unet_results = unet_results.to(dtype=torch.float16)
+
+                loss = F.mse_loss(unet_results, latent_noises)
+                accelerator.backward(loss)
+                optimizer.step()
+                lr_scheduler.step()
+                optimizer.zero_grad()
 
         if (epoch + 1) % save_model_interval == 0:
-            model.unet = unet
-            model.scheduler = scheduler
+            model = RCTDiffusionPipeline(accelerator.unwrap_model(unet), scheduler, vae)
             save_and_test(model, epoch)
-            del model.unet
-            del model.scheduler
         progress_bar.update(1)
 
 
 if __name__ == '__main__':
-    train_model(8, save_model_interval=1)
+    train_model(1, save_model_interval=1)
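
Note: the loop above adds Gaussian noise in pixel space, then pushes both the noisy images and the raw noise through the VAE encoder and regresses the UNet against the encoded noise. Standard latent-diffusion training instead encodes the clean images once and adds the noise in latent space, so the target is exactly the noise the scheduler injected. A minimal sketch of that alternative, reusing the names from this diff (not what this commit does):

    with accelerator.accumulate(unet):
        # encode each clean 512x512 view to a 64x64 latent, concatenated on the channel axis
        clean_latents = torch.cat(
            [vae.encode(clean_images[:, v*SAMPLE_NUM_CHANNELS:(v+1)*SAMPLE_NUM_CHANNELS]).latent_dist.sample()
             for v in range(4)], dim=1)               # (batch, 16, 64, 64)
        noise = torch.randn_like(clean_latents)       # the target lives in latent space
        noisy_latents = scheduler.add_noise(clean_latents, noise, timesteps)

        noise_pred = unet(noisy_latents, timesteps, labels[batch_index:batch_end])[0]
        loss = F.mse_loss(noise_pred, noise)
        accelerator.backward(loss)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()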