frutiemax commited on
Commit
f6f5f48
1 Parent(s): 1618027

Use accelerate

Browse files
Files changed (2) hide show
  1. rct_diffusion_pipeline.py +1 -1
  2. train_model.py +34 -16
rct_diffusion_pipeline.py CHANGED
@@ -30,7 +30,7 @@ class RCTDiffusionPipeline(DiffusionPipeline):
30
  up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'), cross_attention_dim=160,
31
  block_out_channels=(64, 128, 256), norm_num_groups=32)
32
 
33
- self.unet.to(device='cuda', dtype=torch.float16)
34
 
35
  def load_dictionaries_from_dataset(self):
36
  dataset = load_dataset('frutiemax/rct_dataset')
 
30
  up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'), cross_attention_dim=160,
31
  block_out_channels=(64, 128, 256), norm_num_groups=32)
32
 
33
+ self.unet.to(dtype=torch.float16)
34
 
35
  def load_dictionaries_from_dataset(self):
36
  dataset = load_dataset('frutiemax/rct_dataset')
train_model.py CHANGED
@@ -9,6 +9,7 @@ import torchvision.transforms as T
9
  import torch.nn.functional as F
10
  from diffusers.optimization import get_cosine_schedule_with_warmup
11
  from tqdm.auto import tqdm
 
12
 
13
  def save_and_test(pipeline, epoch):
14
  outputs = pipeline([[('aleppo pine tree', 1.0)]], [[('dark green', 1.0)]])
@@ -19,7 +20,7 @@ def save_and_test(pipeline, epoch):
19
  model_file = f'rct_foliage_{epoch}.pth'
20
  pipeline.save_pretrained(model_file)
21
 
22
- def train_model(batch_size=4, epochs=100, save_model_interval=10, start_learning_rate=1e-3, lr_warmup_steps=500):
23
  dataset = load_dataset('frutiemax/rct_dataset')
24
  dataset = dataset['train']
25
 
@@ -50,12 +51,12 @@ def train_model(batch_size=4, epochs=100, save_model_interval=10, start_learning
50
  del views
51
 
52
  # convert those views in tensors
53
- targets = torch.Tensor(size=(num_images, 4, 3, 256, 256))
54
  pillow_to_tensor = T.ToTensor()
55
 
56
  for image_index in range(num_images):
57
  for view_index in range(4):
58
- targets[image_index, view_index] = pillow_to_tensor(image_views[view_index][image_index])
59
  del image_views
60
  del entries
61
 
@@ -100,33 +101,50 @@ def train_model(batch_size=4, epochs=100, save_model_interval=10, start_learning
100
 
101
  # lets train for 100 epoch for each sprite in the dataset with a random noise level
102
  progress_bar = tqdm(total=epochs)
 
 
 
 
 
 
 
 
 
 
 
 
103
 
104
  for epoch in range(epochs):
105
  # create a noisy version of each sprite
106
  for batch_index in range(0, num_images, batch_size):
107
  progress_bar.set_description(f'epoch={epoch}, batch_index={batch_index}')
108
  batch_end = np.minimum(num_images, batch_index + batch_size)
109
- clean_images = targets[batch_index:batch_end].to(device='cuda', dtype=torch.float16)
110
  clean_images = torch.reshape(clean_images, ((batch_end - batch_index), 12, 256, 256))
111
 
112
- noise = torch.randn(clean_images.shape).to(device='cuda', dtype=torch.float16)
113
- timesteps = torch.randint(0, model.scheduler.config.num_train_timesteps, (batch_end - batch_index, ))
114
- timesteps = timesteps.to(dtype=torch.int, device='cuda')
115
- noisy_images = model.scheduler.add_noise(clean_images, noise, timesteps).to(device='cuda', dtype=torch.float16)
116
- noise_pred = model.unet(noisy_images, timesteps, class_labels[batch_index:batch_end].to(device='cuda',dtype=torch.float16), return_dict=False)[0]
 
 
117
 
118
- noise_pred = noise_pred.to(device='cuda', dtype=torch.float16)
119
- loss = F.mse_loss(noise_pred, noise).to(device='cuda', dtype=torch.float16)
120
- loss.backward()
 
121
 
122
- optimizer.step()
123
- lr_scheduler.step()
124
- optimizer.zero_grad()
125
 
126
  if (epoch + 1) % save_model_interval == 0:
 
 
127
  save_and_test(model, epoch)
128
  progress_bar.update(1)
129
 
130
 
131
  if __name__ == '__main__':
132
- train_model(8)
 
9
  import torch.nn.functional as F
10
  from diffusers.optimization import get_cosine_schedule_with_warmup
11
  from tqdm.auto import tqdm
12
+ from accelerate import Accelerator
13
 
14
  def save_and_test(pipeline, epoch):
15
  outputs = pipeline([[('aleppo pine tree', 1.0)]], [[('dark green', 1.0)]])
 
20
  model_file = f'rct_foliage_{epoch}.pth'
21
  pipeline.save_pretrained(model_file)
22
 
23
+ def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model_interval=10, start_learning_rate=1e-3, lr_warmup_steps=500):
24
  dataset = load_dataset('frutiemax/rct_dataset')
25
  dataset = dataset['train']
26
 
 
51
  del views
52
 
53
  # convert those views in tensors
54
+ targets = torch.Tensor(size=(num_images, 4, 3, 256, 256)).to(dtype=torch.float16)
55
  pillow_to_tensor = T.ToTensor()
56
 
57
  for image_index in range(num_images):
58
  for view_index in range(4):
59
+ targets[image_index, view_index] = pillow_to_tensor(image_views[view_index][image_index]).to(dtype=torch.float16)
60
  del image_views
61
  del entries
62
 
 
101
 
102
  # lets train for 100 epoch for each sprite in the dataset with a random noise level
103
  progress_bar = tqdm(total=epochs)
104
+ accelerator = Accelerator(
105
+ mixed_precision='fp16',
106
+ gradient_accumulation_steps=1,
107
+ log_with="tensorboard",
108
+ project_dir='logs',
109
+ )
110
+
111
+ unet, scheduler, optimizer, lr_scheduler = accelerator.prepare(model.unet, model.scheduler, \
112
+ optimizer, lr_scheduler)
113
+
114
+ del model
115
+ scheduler.set_timesteps(scheduler_num_timesteps)
116
 
117
  for epoch in range(epochs):
118
  # create a noisy version of each sprite
119
  for batch_index in range(0, num_images, batch_size):
120
  progress_bar.set_description(f'epoch={epoch}, batch_index={batch_index}')
121
  batch_end = np.minimum(num_images, batch_index + batch_size)
122
+ clean_images = targets[batch_index:batch_end]
123
  clean_images = torch.reshape(clean_images, ((batch_end - batch_index), 12, 256, 256))
124
 
125
+ noise = torch.randn(clean_images.shape, dtype=torch.float16)
126
+ timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_end - batch_index, ))
127
+ #timesteps = timesteps.to(dtype=torch.int, device='cuda')
128
+ noisy_images = scheduler.add_noise(clean_images, noise, timesteps).to(device='cuda', dtype=torch.float16)
129
+
130
+ with accelerator.accumulate(unet):
131
+ noise_pred = unet(noisy_images, timesteps.to(device='cuda', dtype=torch.float16), class_labels[batch_index:batch_end].to(device='cuda',dtype=torch.float16), return_dict=False)[0]
132
 
133
+ #noise_pred = noise_pred.to(device='cuda', dtype=torch.float16)
134
+ loss = F.mse_loss(noise_pred, noise.to('cuda', dtype=torch.float16))
135
+ accelerator.backward(loss)
136
+ accelerator.clip_grad_norm_(unet.parameters(), 1.0)
137
 
138
+ optimizer.step()
139
+ lr_scheduler.step()
140
+ optimizer.zero_grad()
141
 
142
  if (epoch + 1) % save_model_interval == 0:
143
+ model.unet = accelerator.unwrap_model(unet)
144
+ model.scheduler = scheduler
145
  save_and_test(model, epoch)
146
  progress_bar.update(1)
147
 
148
 
149
  if __name__ == '__main__':
150
+ train_model(4)