Use accelerate
Browse files
- rct_diffusion_pipeline.py +1 -1
- train_model.py +34 -16
rct_diffusion_pipeline.py
CHANGED
@@ -30,7 +30,7 @@ class RCTDiffusionPipeline(DiffusionPipeline):
|
|
30 |
up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'), cross_attention_dim=160,
|
31 |
block_out_channels=(64, 128, 256), norm_num_groups=32)
|
32 |
|
33 |
-
self.unet.to(
|
34 |
|
35 |
def load_dictionaries_from_dataset(self):
|
36 |
dataset = load_dataset('frutiemax/rct_dataset')
|
|
|
30 |
up_block_types=('UpBlock2D', 'CrossAttnUpBlock2D', 'CrossAttnUpBlock2D'), cross_attention_dim=160,
|
31 |
block_out_channels=(64, 128, 256), norm_num_groups=32)
|
32 |
|
33 |
+
self.unet.to(dtype=torch.float16)
|
34 |
|
35 |
def load_dictionaries_from_dataset(self):
|
36 |
dataset = load_dataset('frutiemax/rct_dataset')
|
train_model.py
CHANGED
@@ -9,6 +9,7 @@ import torchvision.transforms as T
|
|
9 |
import torch.nn.functional as F
|
10 |
from diffusers.optimization import get_cosine_schedule_with_warmup
|
11 |
from tqdm.auto import tqdm
|
|
|
12 |
|
13 |
def save_and_test(pipeline, epoch):
|
14 |
outputs = pipeline([[('aleppo pine tree', 1.0)]], [[('dark green', 1.0)]])
|
@@ -19,7 +20,7 @@ def save_and_test(pipeline, epoch):
|
|
19 |
model_file = f'rct_foliage_{epoch}.pth'
|
20 |
pipeline.save_pretrained(model_file)
|
21 |
|
22 |
-
def train_model(batch_size=4, epochs=100, save_model_interval=10, start_learning_rate=1e-3, lr_warmup_steps=500):
|
23 |
dataset = load_dataset('frutiemax/rct_dataset')
|
24 |
dataset = dataset['train']
|
25 |
|
@@ -50,12 +51,12 @@ def train_model(batch_size=4, epochs=100, save_model_interval=10, start_learning
|
|
50 |
del views
|
51 |
|
52 |
# convert those views in tensors
|
53 |
-
targets = torch.Tensor(size=(num_images, 4, 3, 256, 256))
|
54 |
pillow_to_tensor = T.ToTensor()
|
55 |
|
56 |
for image_index in range(num_images):
|
57 |
for view_index in range(4):
|
58 |
-
targets[image_index, view_index] = pillow_to_tensor(image_views[view_index][image_index])
|
59 |
del image_views
|
60 |
del entries
|
61 |
|
@@ -100,33 +101,50 @@ def train_model(batch_size=4, epochs=100, save_model_interval=10, start_learning
|
|
100 |
|
101 |
# lets train for 100 epoch for each sprite in the dataset with a random noise level
|
102 |
progress_bar = tqdm(total=epochs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
103 |
|
104 |
for epoch in range(epochs):
|
105 |
# create a noisy version of each sprite
|
106 |
for batch_index in range(0, num_images, batch_size):
|
107 |
progress_bar.set_description(f'epoch={epoch}, batch_index={batch_index}')
|
108 |
batch_end = np.minimum(num_images, batch_index + batch_size)
|
109 |
-
clean_images = targets[batch_index:batch_end]
|
110 |
clean_images = torch.reshape(clean_images, ((batch_end - batch_index), 12, 256, 256))
|
111 |
|
112 |
-
noise = torch.randn(clean_images.shape
|
113 |
-
timesteps = torch.randint(0,
|
114 |
-
timesteps = timesteps.to(dtype=torch.int, device='cuda')
|
115 |
-
noisy_images =
|
116 |
-
|
|
|
|
|
117 |
|
118 |
-
|
119 |
-
|
120 |
-
|
|
|
121 |
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
|
126 |
if (epoch + 1) % save_model_interval == 0:
|
|
|
|
|
127 |
save_and_test(model, epoch)
|
128 |
progress_bar.update(1)
|
129 |
|
130 |
|
131 |
if __name__ == '__main__':
|
132 |
-
train_model(
|
|
|
9 |
import torch.nn.functional as F
|
10 |
from diffusers.optimization import get_cosine_schedule_with_warmup
|
11 |
from tqdm.auto import tqdm
|
12 |
+
from accelerate import Accelerator
|
13 |
|
14 |
def save_and_test(pipeline, epoch):
|
15 |
outputs = pipeline([[('aleppo pine tree', 1.0)]], [[('dark green', 1.0)]])
|
|
|
20 |
model_file = f'rct_foliage_{epoch}.pth'
|
21 |
pipeline.save_pretrained(model_file)
|
22 |
|
23 |
+
def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model_interval=10, start_learning_rate=1e-3, lr_warmup_steps=500):
|
24 |
dataset = load_dataset('frutiemax/rct_dataset')
|
25 |
dataset = dataset['train']
|
26 |
|
|
|
51 |
del views
|
52 |
|
53 |
# convert those views in tensors
|
54 |
+
targets = torch.Tensor(size=(num_images, 4, 3, 256, 256)).to(dtype=torch.float16)
|
55 |
pillow_to_tensor = T.ToTensor()
|
56 |
|
57 |
for image_index in range(num_images):
|
58 |
for view_index in range(4):
|
59 |
+
targets[image_index, view_index] = pillow_to_tensor(image_views[view_index][image_index]).to(dtype=torch.float16)
|
60 |
del image_views
|
61 |
del entries
|
62 |
|
|
|
101 |
|
102 |
# lets train for 100 epoch for each sprite in the dataset with a random noise level
|
103 |
progress_bar = tqdm(total=epochs)
|
104 |
+
accelerator = Accelerator(
|
105 |
+
mixed_precision='fp16',
|
106 |
+
gradient_accumulation_steps=1,
|
107 |
+
log_with="tensorboard",
|
108 |
+
project_dir='logs',
|
109 |
+
)
|
110 |
+
|
111 |
+
unet, scheduler, optimizer, lr_scheduler = accelerator.prepare(model.unet, model.scheduler, \
|
112 |
+
optimizer, lr_scheduler)
|
113 |
+
|
114 |
+
del model
|
115 |
+
scheduler.set_timesteps(scheduler_num_timesteps)
|
116 |
|
117 |
for epoch in range(epochs):
|
118 |
# create a noisy version of each sprite
|
119 |
for batch_index in range(0, num_images, batch_size):
|
120 |
progress_bar.set_description(f'epoch={epoch}, batch_index={batch_index}')
|
121 |
batch_end = np.minimum(num_images, batch_index + batch_size)
|
122 |
+
clean_images = targets[batch_index:batch_end]
|
123 |
clean_images = torch.reshape(clean_images, ((batch_end - batch_index), 12, 256, 256))
|
124 |
|
125 |
+
noise = torch.randn(clean_images.shape, dtype=torch.float16)
|
126 |
+
timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_end - batch_index, ))
|
127 |
+
#timesteps = timesteps.to(dtype=torch.int, device='cuda')
|
128 |
+
noisy_images = scheduler.add_noise(clean_images, noise, timesteps).to(device='cuda', dtype=torch.float16)
|
129 |
+
|
130 |
+
with accelerator.accumulate(unet):
|
131 |
+
noise_pred = unet(noisy_images, timesteps.to(device='cuda', dtype=torch.float16), class_labels[batch_index:batch_end].to(device='cuda',dtype=torch.float16), return_dict=False)[0]
|
132 |
|
133 |
+
#noise_pred = noise_pred.to(device='cuda', dtype=torch.float16)
|
134 |
+
loss = F.mse_loss(noise_pred, noise.to('cuda', dtype=torch.float16))
|
135 |
+
accelerator.backward(loss)
|
136 |
+
accelerator.clip_grad_norm_(unet.parameters(), 1.0)
|
137 |
|
138 |
+
optimizer.step()
|
139 |
+
lr_scheduler.step()
|
140 |
+
optimizer.zero_grad()
|
141 |
|
142 |
if (epoch + 1) % save_model_interval == 0:
|
143 |
+
model.unet = accelerator.unwrap_model(unet)
|
144 |
+
model.scheduler = scheduler
|
145 |
save_and_test(model, epoch)
|
146 |
progress_bar.update(1)
|
147 |
|
148 |
|
149 |
if __name__ == '__main__':
|
150 |
+
train_model(4)
|