Push unet and scheduler outside of pipeline

- rct_diffusion_pipeline.py +14 -10
- train_model.py +10 -5
rct_diffusion_pipeline.py
CHANGED
@@ -11,6 +11,12 @@ import pandas as pd
 from tqdm.auto import tqdm
 
 class RCTDiffusionPipeline(DiffusionPipeline):
+    def get_default_unet(hidden_dim):
+        return UNet2DConditionModel(sample_size=256, in_channels=12, out_channels=12, \
+            down_block_types=('CrossAttnDownBlock2D', 'CrossAttnDownBlock2D', 'DownBlock2D'),\
+            up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'), cross_attention_dim=hidden_dim,
+            block_out_channels=(64, 128, 256), norm_num_groups=32)
+
     def __init__(self):
         super().__init__()
 
@@ -21,16 +27,14 @@ class RCTDiffusionPipeline(DiffusionPipeline):
         self.color3_dict = {}
         self.load_dictionaries_from_dataset()
 
-        self.scheduler = …
-        …
-        …
-        …
-        self.unet = …
-        …
-        …
-        …
-        …
-        self.unet.to(dtype=torch.float16)
+        self.scheduler = None
+        self.unet = None
+
+    def set_unet(self, unet):
+        self.unet = unet
+
+    def set_scheduler(self, scheduler):
+        self.scheduler = scheduler
 
     def load_dictionaries_from_dataset(self):
         dataset = load_dataset('frutiemax/rct_dataset')
train_model.py
CHANGED
@@ -10,6 +10,7 @@ import torch.nn.functional as F
 from diffusers.optimization import get_cosine_schedule_with_warmup
 from tqdm.auto import tqdm
 from accelerate import Accelerator
+from diffusers import DDPMScheduler, UNet2DConditionModel
 
 def save_and_test(pipeline, epoch):
     outputs = pipeline([[('aleppo pine tree', 1.0)]], [[('dark green', 1.0)]])
@@ -92,7 +93,8 @@ def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model
     del colors3
     del dataset
 
-    …
+    unet = RCTDiffusionPipeline.get_default_unet(160)
+    optimizer = torch.optim.Adam(unet.parameters(), lr=start_learning_rate)
     lr_scheduler = get_cosine_schedule_with_warmup(
         optimizer=optimizer,
         num_warmup_steps=lr_warmup_steps,
@@ -108,10 +110,11 @@ def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model
         project_dir='logs',
     )
 
-    …
+    scheduler = DDPMScheduler(scheduler_num_timesteps)
+    unet, scheduler, optimizer, lr_scheduler = accelerator.prepare(unet, scheduler, \
         optimizer, lr_scheduler)
-    …
-    …
+
+    unet = unet.to(dtype=torch.float16)
     scheduler.set_timesteps(scheduler_num_timesteps)
 
     for epoch in range(epochs):
@@ -128,7 +131,7 @@ def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model
             noisy_images = scheduler.add_noise(clean_images, noise, timesteps).to(device='cuda', dtype=torch.float16)
 
             with accelerator.accumulate(unet):
-                noise_pred = unet(noisy_images, timesteps.to(device='cuda'…
+                noise_pred = unet(noisy_images, timesteps.to(device='cuda'), class_labels[batch_index:batch_end].to(device='cuda',dtype=torch.float16), return_dict=False)[0]
 
                 #noise_pred = noise_pred.to(device='cuda', dtype=torch.float16)
                 loss = F.mse_loss(noise_pred, noise.to('cuda', dtype=torch.float16))
@@ -143,6 +146,8 @@ def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model
         model.unet = accelerator.unwrap_model(unet)
         model.scheduler = scheduler
         save_and_test(model, epoch)
+        del model.unet
+        del model.scheduler
         progress_bar.update(1)
 
 
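The end-of-epoch del statements are the least obvious part of the change, so here is that checkpoint sequence condensed, with explanatory comments; the statements are taken from the diff, and the rationale in the comments is an interpretation rather than something the commit states.

model.unet = accelerator.unwrap_model(unet)  # strip the Accelerate wrapper before saving
model.scheduler = scheduler
save_and_test(model, epoch)                  # test generation + pipeline save
del model.unet                               # detach again so the pipeline keeps no
del model.scheduler                          # reference to the training copies between saves

Deleting the attributes outright, rather than resetting them to None, keeps the pipeline free of the training objects except for the duration of each save, which is consistent with the commit's goal of holding the unet and scheduler outside the pipeline.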