Use float16 for inference and float32 for training
Files changed:
- rct_diffusion_pipeline.py  +24 -27
- test_pipeline.py  +11 -83
- train_model.py  +60 -97
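
The change keeps the UNet, VAE and text encoder weights in float32 while training and drops them to float16 only around inference. As a minimal sketch of that pattern (the function and model here are illustrative placeholders, not code from this repo):

import torch
import torch.nn as nn

def sample_in_fp16(model: nn.Module, latents: torch.Tensor) -> torch.Tensor:
    # cast weights to float16 for a cheap, gradient-free forward pass
    model.half()
    with torch.no_grad():
        out = model(latents.half())
    # restore float32 so the next optimizer step accumulates in full precision
    model.float()
    return out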
rct_diffusion_pipeline.py
CHANGED
@@ -30,7 +30,7 @@ class RCTDiffusionPipeline(DiffusionPipeline):
        self.text_tokenizer = text_tokenizer

        # channels for 1 image
-        self.num_channels = int(self.unet.config.in_channels
+        self.num_channels = int(self.unet.config.in_channels)
        self.load_dictionaries_from_dataset()
        self.register_modules(unet=unet, scheduler=scheduler, vae=vae, text_tokenizer=text_tokenizer, text_encoder=text_encoder)

@@ -171,13 +171,12 @@ class RCTDiffusionPipeline(DiffusionPipeline):
        return self.pack_labels_to_tensor(batch_size, object_descriptions, colors1, colors2, colors3).to(device='cuda',dtype=torch.float16)

    def generate_noise_batches(self, batch_size):
-        noise_batches = torch.Tensor(size=(batch_size,
+        noise_batches = torch.Tensor(size=(batch_size, self.num_channels, self.latent_size, self.latent_size)).to(dtype=torch.float16, device='cuda')
        for batch_index in range(batch_size):
-
-
-            noise_batches[batch_index, view_index] = noise
+            noise = torch.randn(self.num_channels, self.latent_size, self.latent_size).to(dtype=torch.float16, device='cuda')
+            noise_batches[batch_index] = noise

-        return torch.reshape(noise_batches, (batch_size,
+        return torch.reshape(noise_batches, (batch_size, self.num_channels, self.latent_size, self.latent_size)).to(dtype=torch.float16, device='cuda')

    def test_generate_embeddings(self, object_description, color1, color2, color3) -> torch.Tensor:
        batch_size = len(object_description)

@@ -190,7 +189,7 @@ class RCTDiffusionPipeline(DiffusionPipeline):
            with torch.no_grad():
                embeddings[batch_index] = self.text_encoder(tokens.input_ids.to('cuda'))[0]

-        return embeddings
+        return embeddings

    def generate_embeddings(self, object_description, color1, color2, color3) -> torch.Tensor:
        batch_size = len(object_description)

@@ -244,11 +243,11 @@ class RCTDiffusionPipeline(DiffusionPipeline):
        if res == False:
            return None
        embeddings = self.test_generate_embeddings(object_description, color1, color2, color3)
-        embeddings = embeddings.to('cuda')
+        embeddings = embeddings.to(device='cuda', dtype=torch.float16)

        # set the inference steps
        self.scheduler.set_timesteps(num_inference_steps)
-        noise_batches = self.generate_noise_batches(batch_size)
+        noise_batches = self.generate_noise_batches(batch_size).to(dtype=torch.float16)

        # now call the model for the n interations
        progress_bar = tqdm(total=num_inference_steps)

@@ -257,36 +256,34 @@ class RCTDiffusionPipeline(DiffusionPipeline):
            progress_bar.set_description(f'Inference step {epoch}')

            for batch_index in range(batch_size):
-
+                noise_batch = self.scheduler.scale_model_input(noise_batches, timestep=t)
                with torch.no_grad():
-                    noise_residual = self.unet(
-                    previous_noisy_sample = self.scheduler.step(noise_residual, t,
+                    noise_residual = self.unet(noise_batch, t, encoder_hidden_states=embeddings).sample
+                    previous_noisy_sample = self.scheduler.step(noise_residual, t, noise_batch).prev_sample
                noise_batches[batch_index] = previous_noisy_sample
            progress_bar.update(1)
            epoch = epoch + 1

        # reshape the data so we get back 4 RGB images
-        noise_batches = torch.reshape(noise_batches, (batch_size,
-
+        noise_batches = torch.reshape(noise_batches, (batch_size, self.num_channels, self.latent_size, self.latent_size))
+        noise_batches = noise_batches.to('cuda')
+        images = torch.Tensor(size=(batch_size, 3, self.sample_size, self.sample_size)).to('cuda')

        with torch.no_grad():
-
-
-
-            images[:, image_index] = result
+            image = noise_batches
+            result = self.vae.decode(image).sample
+            images = result

        # convert those tensors to PIL images
        output_images = []
        for batch_index in range(batch_size):
-
-
-
-
-
-
-
-            image.save(f'test{image_index}.png')
-            output_images.append(image)
+            image = images[batch_index]
+            image = (image / 2 + 0.5).clamp(0, 1).squeeze()
+            image = (image.permute(1, 2, 0) * 255).to(torch.uint8).cpu().numpy()
+            image = (image * 255).round().astype("uint8")
+            image = Image.fromarray(image)
+            image.save(f'test{batch_index}.png')
+            output_images.append(image)

        # for now just return the images
        return output_images
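
For reference, the usual diffusers recipe for turning decoded latents into PIL images looks roughly like the sketch below. It assumes an AutoencoderKL `vae` and float16 latents shaped (batch, 4, 32, 32); the names are illustrative, not the pipeline's exact code:

import torch
from PIL import Image

@torch.no_grad()
def latents_to_pil(vae, latents):
    # SD-style VAEs are trained on latents scaled by 0.18215, so undo that before decoding
    images = vae.decode(latents / 0.18215).sample
    # map from [-1, 1] to [0, 1], then to uint8 HWC arrays for PIL
    images = (images / 2 + 0.5).clamp(0, 1)
    images = (images.permute(0, 2, 3, 1) * 255).round().to(torch.uint8).cpu().numpy()
    return [Image.fromarray(img) for img in images]

Note that the scaling to 255 happens exactly once here; scaling twice, as in the loop above, saturates every pixel.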
test_pipeline.py
CHANGED
@@ -2,6 +2,7 @@ from rct_diffusion_pipeline import RCTDiffusionPipeline
from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL
import torch
from transformers import CLIPTextModel, CLIPTokenizer
+import torch.nn as nn

torch_device = "cuda"

@@ -22,11 +23,17 @@ test2 = tokenizer('dark green', padding="max_length", max_length=tokenizer.model
with torch.no_grad():
    test2 = text_encoder(test2.input_ids.to('cuda'))[0]

-unet = UNet2DConditionModel(sample_size=32, in_channels=
-    down_block_types=(
-    up_block_types=(
-    block_out_channels=(
+unet = UNet2DConditionModel(sample_size=32, in_channels=4, out_channels=4, \
+    down_block_types=("CrossAttnDownBlock2D","CrossAttnDownBlock2D","CrossAttnDownBlock2D", "DownBlock2D"),\
+    up_block_types=("UpBlock2D","CrossAttnUpBlock2D","CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), cross_attention_dim=768,
+    block_out_channels=(320, 640, 1280, 1280), norm_num_groups=32)
unet = unet.to('cuda', dtype=torch.float16)
+
+# put float32 for the accumulation
+for layer in unet.modules():
+    if isinstance(layer, nn.BatchNorm2d):
+        layer.float()
+
scheduler = DDPMScheduler(num_train_timesteps=20)
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)
vae = vae.to('cuda', dtype=torch.float16)

@@ -34,83 +41,4 @@ vae = vae.to('cuda', dtype=torch.float16)
pipeline = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder)
output = pipeline(['aleppo pine tree'], ['dark green'])
pipeline.save_pretrained('test')
-
-# from PIL import Image
-# import torch
-# from transformers import CLIPTextModel, CLIPTokenizer
-# from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler
-
-# vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae", use_safetensors=True)
-# tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
-# text_encoder = CLIPTextModel.from_pretrained(
-#     "CompVis/stable-diffusion-v1-4", subfolder="text_encoder", use_safetensors=True
-# )
-# unet = UNet2DConditionModel.from_pretrained(
-#     "CompVis/stable-diffusion-v1-4", subfolder="unet", use_safetensors=True
-# )
-
-# from diffusers import UniPCMultistepScheduler
-
-# scheduler = UniPCMultistepScheduler.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="scheduler")
-# torch_device = "cuda"
-# vae.to(torch_device)
-# text_encoder.to(torch_device)
-# unet.to(torch_device)
-
-# prompt = ["a photograph of an astronaut riding a horse"]
-# height = 512  # default height of Stable Diffusion
-# width = 512  # default width of Stable Diffusion
-# num_inference_steps = 25  # Number of denoising steps
-# guidance_scale = 7.5  # Scale for classifier-free guidance
-# generator = torch.manual_seed(0)  # Seed generator to create the inital latent noise
-# batch_size = len(prompt)
-
-# text_input = tokenizer(
-#     prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
-# )
-
-# with torch.no_grad():
-#     text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
-
-# text_input = tokenizer(
-#     prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt"
-# )
-
-# with torch.no_grad():
-#     text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
-
-# max_length = text_input.input_ids.shape[-1]
-# uncond_input = tokenizer([""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt")
-# uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
-
-# text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-
-# latents = torch.randn(
-#     (batch_size, unet.in_channels, height // 8, width // 8),
-#     generator=generator,
-# )
-# latents = latents.to(torch_device)
-
-# latents = latents * scheduler.init_noise_sigma
-
-# from tqdm.auto import tqdm
-
-# scheduler.set_timesteps(num_inference_steps)
-
-# for t in tqdm(scheduler.timesteps):
-#     # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
-#     latent_model_input = torch.cat([latents] * 2)
-
-#     latent_model_input = scheduler.scale_model_input(latent_model_input, timestep=t)
-
-#     # predict the noise residual
-#     with torch.no_grad():
-#         noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-
-#     # perform guidance
-#     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-#     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-#     # compute the previous noisy sample x_t -> x_t-1
-#     latents = scheduler.step(noise_pred, t, latents).prev_sample
print('test')
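
The `# put float32 for the accumulation` loop follows the usual half-precision advice (see the discuss.pytorch.org thread referenced in train_model.py below): keep normalization layers in float32 so their reductions do not overflow. One caveat: UNet2DConditionModel is built from GroupNorm blocks (hence the `norm_num_groups` argument), so a check on `nn.BatchNorm2d` alone may match nothing. A broader version of the same idea, as a hedged sketch rather than the script's code:

import torch.nn as nn

def keep_norms_in_float32(model: nn.Module) -> None:
    # run normalization layers in float32 even when the rest of the model is float16
    norm_types = (nn.BatchNorm2d, nn.GroupNorm, nn.LayerNorm)
    for module in model.modules():
        if isinstance(module, norm_types):
            module.float()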
train_model.py
CHANGED
@@ -12,11 +12,13 @@ from tqdm.auto import tqdm
from accelerate import Accelerator
from diffusers import DDPMScheduler, UNet2DConditionModel, AutoencoderKL
from transformers import CLIPTextModel, CLIPTokenizer
+import torch.nn as nn

SAMPLE_SIZE = 256
LATENT_SIZE = 32
SAMPLE_NUM_CHANNELS = 3
LATENT_NUM_CHANNELS = 4
+from torchvision import transforms

def save_and_test(pipeline, epoch):
    outputs = pipeline(['aleppo pine tree'], ['dark green'])

@@ -28,42 +30,22 @@ def save_and_test(pipeline, epoch):
    pipeline.save_pretrained(model_file)

def convert_images(dataset):
-
-
-
-
-
-
-
-
-    # convert those images to 256x256 by cropping and scaling up the image
-    image_views = []
-    for view_index in range(4):
-        images = []
-        for entry in views[view_index]:
-            image = entry['image']
-
-            scale_factor = np.minimum(LATENT_SIZE / image.width, LATENT_SIZE / image.height)
-            image = Image.resize(image, size=(int(scale_factor * image.width), int(scale_factor * image.height)), resample=Resampling.NEAREST)
-
-            new_image = PIL.Image.new('RGBA', (LATENT_SIZE, LATENT_SIZE))
-            new_image.paste(image, box=(int((LATENT_SIZE - image.width)/2), int((LATENT_SIZE - image.height)/2)))
-            images.append(new_image)
-        image_views.append(images)
-
-    del views
+    preprocess = transforms.Compose(
+        [
+            transforms.Resize((LATENT_SIZE, LATENT_SIZE)),
+            transforms.ToTensor(),
+            transforms.Normalize([0.5], [0.5]),
+        ]
+    )

-
-
-
+    images = [preprocess(image.convert("RGBA")) for image in dataset["image"]]
+    object_descriptions = [obj_desc for obj_desc in dataset["object_description"]]
+    colors1 = [color1 for color1 in dataset['color1']]
+    colors2 = [color1 for color1 in dataset['color2']]
+    colors3 = [color1 for color1 in dataset['color3']]

-
-
-            targets[image_index, view_index] = pillow_to_tensor(image_views[view_index][image_index]).to(dtype=torch.float16)
-    del image_views
-    del entries
-
-    return torch.reshape(targets, (num_images, 4 * LATENT_NUM_CHANNELS, LATENT_SIZE, LATENT_SIZE))
+    return {"image": images, 'object_description':object_descriptions, 'color1':colors1, \
+        'color2':colors2, 'color3':colors3}

def convert_labels(dataset, model, num_images):
    # get the labels

@@ -97,115 +79,96 @@ def convert_labels(dataset, model, num_images):
    del dataset
    return class_labels.to(dtype=torch.float16, device='cuda')

-def
-
-
+def create_embeddings(dataset, model):
+    object_descriptions = dataset['object_description']
+    colors1 = dataset['color1']
+    colors2 = dataset['color2']
+    colors3 = dataset['color3']
+    return model.test_generate_embeddings(object_descriptions, colors1, colors2, colors3)
+

-
+def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timesteps=20, save_model_interval=10, start_learning_rate=1e-3, lr_warmup_steps=1):
+    dataset = load_dataset('frutiemax/rct_dataset', split=f'train[0:{total_images}]')
+    dataset.set_transform(convert_images)
    num_images = int(dataset.num_rows / 4) if total_images == None else int(total_images / 4)

-    unet = UNet2DConditionModel(sample_size=LATENT_SIZE, in_channels=LATENT_NUM_CHANNELS
+    unet = UNet2DConditionModel(sample_size=LATENT_SIZE, in_channels=LATENT_NUM_CHANNELS, out_channels=LATENT_NUM_CHANNELS, \
        down_block_types=("CrossAttnDownBlock2D","CrossAttnDownBlock2D","CrossAttnDownBlock2D", "DownBlock2D"),\
        up_block_types=("UpBlock2D","CrossAttnUpBlock2D","CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), cross_attention_dim=768,
-        block_out_channels=(
-    unet = unet.to(dtype=torch.
+        block_out_channels=(128, 256, 512, 512), norm_num_groups=32)
+    unet = unet.to(dtype=torch.float32)
+
+    #https://discuss.pytorch.org/t/training-with-half-precision/11815
+    for layer in unet.modules():
+        if isinstance(layer, nn.BatchNorm2d):
+            layer.float()
+
    scheduler = DDPMScheduler(num_train_timesteps=scheduler_num_timesteps)
    tokenizer = CLIPTokenizer.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(
        "CompVis/stable-diffusion-v1-4", subfolder="text_encoder", use_safetensors=True
    ).to('cuda')
    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)
-    vae = vae.to(dtype=torch.float16)
+    vae = vae.to(dtype=torch.float16, device='cuda')

-    optimizer = torch.optim.
+    optimizer = torch.optim.AdamW(unet.parameters(), lr=start_learning_rate)
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=lr_warmup_steps,
        num_training_steps=num_images * epochs
    )
    model = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder)
-
-
-
-    colors1 = dataset['color1']
-    colors2 = dataset['color2']
-    colors3 = dataset['color3']
-
-    # we only need 1 of the 4 views
-    object_descriptions = [object_descriptions[desc_index] for desc_index in range(0, len(object_descriptions), 4)]
-    colors1 = [colors1[desc_index] for desc_index in range(0, len(colors1), 4)]
-    colors2 = [colors2[desc_index] for desc_index in range(0, len(colors2), 4)]
-    colors3 = [colors3[desc_index] for desc_index in range(0, len(colors3), 4)]
-    #embeddings = model.generate_embeddings(object_descriptions, colors1, colors2, colors3)
-    embeddings = model.test_generate_embeddings(object_descriptions, colors1, colors2, colors3)
-
-    labels = convert_labels(dataset, model, num_images)
-    del model
-
-    if total_images != None:
-        targets = targets[:int(total_images/4)]
-        label_indices = [index for index in range(0, total_images, 4)]
-        labels = labels[label_indices]
+    unet = unet.to('cuda')
+
+    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # lets train for 100 epoch for each sprite in the dataset with a random noise level
    progress_bar = tqdm(total=epochs)
-    accelerator = Accelerator(mixed_precision='fp16')
-    accelerator.clip_grad_norm_(unet.parameters(), 1.0)
-    unet, scheduler, lr_scheduler, vae = accelerator.prepare(unet, scheduler, lr_scheduler, vae)

    loss_fn = torch.nn.MSELoss()

    tensor_to_pillow = T.ToPILImage()
    for epoch in range(epochs):
        # create a noisy version of each sprite
-        for
-
-
-
-
+        for step, batch in enumerate(train_dataloader):
+            clean_images = batch['image']
+            batch_size = clean_images.size(0)
+            embeddings = create_embeddings(batch, model)
+            clean_images = torch.reshape(clean_images, (batch['image'].size(0), LATENT_NUM_CHANNELS, LATENT_SIZE, LATENT_SIZE)).\
+                to(device='cuda')

-            noise = torch.randn(clean_images.shape, dtype=torch.
-            timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (
+            noise = torch.randn(clean_images.shape, dtype=torch.float32, device='cuda')
+            timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_size, )).to(device='cuda')

            #timesteps = timesteps.to(dtype=torch.int, device='cuda')
            noisy_images = scheduler.add_noise(clean_images, noise, timesteps)
-
-            # with accelerator.accumulate(unet):
-            # assert not torch.any(torch.isnan(timesteps))
-
-            # batch_embeddings = embeddings[batch_index:batch_end]
-            # batch_embeddings = batch_embeddings.to('cuda')
-
-            # optimizer.zero_grad()
-            # unet_results = unet(noisy_images, timesteps, batch_embeddings).sample
-            # unet_results = unet_results.to(dtype=torch.float16)
-
-            # loss = loss_fn(unet_results, noise)
-            # accelerator.backward(loss)
-
-            # optimizer.step()
-            # lr_scheduler.step()
-            # optimizer.zero_grad()

-            batch_embeddings = embeddings
+            batch_embeddings = embeddings
            batch_embeddings = batch_embeddings.to('cuda')

            optimizer.zero_grad()
            unet_results = unet(noisy_images, timesteps, batch_embeddings).sample
-            unet_results = unet_results.to(dtype=torch.float16)
            loss = loss_fn(unet_results, noise)
            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

-            progress_bar.set_description(f'epoch={epoch}, batch_index={
+            progress_bar.set_description(f'epoch={epoch}, batch_index={step}, last_loss={loss.item()}')

        if (epoch + 1) % save_model_interval == 0:
-
+            # inference in float16
+            model = RCTDiffusionPipeline(unet.to(dtype=torch.float16), scheduler, \
+                vae.to(dtype=torch.float16), tokenizer, text_encoder.to(dtype=torch.float16))
            save_and_test(model, epoch)
+
+            # training in float32
+            unet.to(dtype=torch.float32)
+            vae.to(dtype=torch.float32)
+            text_encoder.to(dtype=torch.float32)
+
        progress_bar.update(1)


if __name__ == '__main__':
-    train_model(1, total_images=4, save_model_interval=
+    train_model(1, total_images=4, save_model_interval=100, epochs=1000)