frutiemax committed on
Commit
a4c8091
·
1 Parent(s): f6f5f48

Push unet and scheduler outside of pipeline

Browse files
Files changed (2) hide show
  1. rct_diffusion_pipeline.py +14 -10
  2. train_model.py +10 -5
rct_diffusion_pipeline.py CHANGED
@@ -11,6 +11,12 @@ import pandas as pd
11
  from tqdm.auto import tqdm
12
 
13
  class RCTDiffusionPipeline(DiffusionPipeline):
 
 
 
 
 
 
14
  def __init__(self):
15
  super().__init__()
16
 
@@ -21,16 +27,14 @@ class RCTDiffusionPipeline(DiffusionPipeline):
21
  self.color3_dict = {}
22
  self.load_dictionaries_from_dataset()
23
 
24
- self.scheduler = DDPMScheduler()
25
-
26
- # the number of hidden features is dependant on the loaded dictionaries!
27
- hidden_dim = self.get_class_labels_size()
28
- self.unet = UNet2DConditionModel(sample_size=256, in_channels=12, out_channels=12, \
29
- down_block_types=('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'),\
30
- up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'), cross_attention_dim=160,
31
- block_out_channels=(64, 128, 256), norm_num_groups=32)
32
-
33
- self.unet.to(dtype=torch.float16)
34
 
35
  def load_dictionaries_from_dataset(self):
36
  dataset = load_dataset('frutiemax/rct_dataset')
 
11
  from tqdm.auto import tqdm
12
 
13
  class RCTDiffusionPipeline(DiffusionPipeline):
14
+ def get_default_unet(hidden_dim):
15
+ return UNet2DConditionModel(sample_size=256, in_channels=12, out_channels=12, \
16
+ down_block_types=('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'),\
17
+ up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'), cross_attention_dim=hidden_dim,
18
+ block_out_channels=(64, 128, 256), norm_num_groups=32)
19
+
20
  def __init__(self):
21
  super().__init__()
22
 
 
27
  self.color3_dict = {}
28
  self.load_dictionaries_from_dataset()
29
 
30
+ self.scheduler = None
31
+ self.unet = None
32
+
33
+ def set_unet(self, unet):
34
+ self.unet = unet
35
+
36
+ def set_scheduler(self, scheduler):
37
+ self.scheduler = scheduler
 
 
38
 
39
  def load_dictionaries_from_dataset(self):
40
  dataset = load_dataset('frutiemax/rct_dataset')
train_model.py CHANGED
@@ -10,6 +10,7 @@ import torch.nn.functional as F
10
  from diffusers.optimization import get_cosine_schedule_with_warmup
11
  from tqdm.auto import tqdm
12
  from accelerate import Accelerator
 
13
 
14
  def save_and_test(pipeline, epoch):
15
  outputs = pipeline([[('aleppo pine tree', 1.0)]], [[('dark green', 1.0)]])
@@ -92,7 +93,8 @@ def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model
92
  del colors3
93
  del dataset
94
 
95
- optimizer = torch.optim.Adam(model.unet.parameters(), lr=start_learning_rate)
 
96
  lr_scheduler = get_cosine_schedule_with_warmup(
97
  optimizer=optimizer,
98
  num_warmup_steps=lr_warmup_steps,
@@ -108,10 +110,11 @@ def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model
108
  project_dir='logs',
109
  )
110
 
111
- unet, scheduler, optimizer, lr_scheduler = accelerator.prepare(model.unet, model.scheduler, \
 
112
  optimizer, lr_scheduler)
113
-
114
- del model
115
  scheduler.set_timesteps(scheduler_num_timesteps)
116
 
117
  for epoch in range(epochs):
@@ -128,7 +131,7 @@ def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model
128
  noisy_images = scheduler.add_noise(clean_images, noise, timesteps).to(device='cuda', dtype=torch.float16)
129
 
130
  with accelerator.accumulate(unet):
131
- noise_pred = unet(noisy_images, timesteps.to(device='cuda', dtype=torch.float16), class_labels[batch_index:batch_end].to(device='cuda',dtype=torch.float16), return_dict=False)[0]
132
 
133
  #noise_pred = noise_pred.to(device='cuda', dtype=torch.float16)
134
  loss = F.mse_loss(noise_pred, noise.to('cuda', dtype=torch.float16))
@@ -143,6 +146,8 @@ def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model
143
  model.unet = accelerator.unwrap_model(unet)
144
  model.scheduler = scheduler
145
  save_and_test(model, epoch)
 
 
146
  progress_bar.update(1)
147
 
148
 
 
10
  from diffusers.optimization import get_cosine_schedule_with_warmup
11
  from tqdm.auto import tqdm
12
  from accelerate import Accelerator
13
+ from diffusers import DDPMScheduler, UNet2DConditionModel
14
 
15
  def save_and_test(pipeline, epoch):
16
  outputs = pipeline([[('aleppo pine tree', 1.0)]], [[('dark green', 1.0)]])
 
93
  del colors3
94
  del dataset
95
 
96
+ unet = RCTDiffusionPipeline.get_default_unet(160)
97
+ optimizer = torch.optim.Adam(unet.parameters(), lr=start_learning_rate)
98
  lr_scheduler = get_cosine_schedule_with_warmup(
99
  optimizer=optimizer,
100
  num_warmup_steps=lr_warmup_steps,
 
110
  project_dir='logs',
111
  )
112
 
113
+ scheduler = DDPMScheduler(scheduler_num_timesteps)
114
+ unet, scheduler, optimizer, lr_scheduler = accelerator.prepare(unet, scheduler, \
115
  optimizer, lr_scheduler)
116
+
117
+ unet = unet.to(dtype=torch.float16)
118
  scheduler.set_timesteps(scheduler_num_timesteps)
119
 
120
  for epoch in range(epochs):
 
131
  noisy_images = scheduler.add_noise(clean_images, noise, timesteps).to(device='cuda', dtype=torch.float16)
132
 
133
  with accelerator.accumulate(unet):
134
+ noise_pred = unet(noisy_images, timesteps.to(device='cuda'), class_labels[batch_index:batch_end].to(device='cuda',dtype=torch.float16), return_dict=False)[0]
135
 
136
  #noise_pred = noise_pred.to(device='cuda', dtype=torch.float16)
137
  loss = F.mse_loss(noise_pred, noise.to('cuda', dtype=torch.float16))
 
146
  model.unet = accelerator.unwrap_model(unet)
147
  model.scheduler = scheduler
148
  save_and_test(model, epoch)
149
+ del model.unet
150
+ del model.scheduler
151
  progress_bar.update(1)
152
 
153