Use vae for encoding and decoding for training
Files changed:
- rct_diffusion_pipeline.py  +21 -10
- test_pipeline.py  +4 -2
- train_model.py  +29 -16
rct_diffusion_pipeline.py
CHANGED
@@ -172,8 +172,12 @@ class RCTDiffusionPipeline(DiffusionPipeline):
 
     def generate_noise_batches(self, batch_size):
         noise_batches = torch.Tensor(size=(batch_size, self.num_channels, self.latent_size, self.latent_size)).to(dtype=torch.float16, device='cuda')
+        seed = int(0)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        torch.cuda.manual_seed(seed)
         for batch_index in range(batch_size):
-            noise = torch.randn(self.num_channels, self.latent_size, self.latent_size).to(dtype=torch.float16, device='cuda')
+            noise = torch.randn((self.num_channels, self.latent_size, self.latent_size)).to(dtype=torch.float16, device='cuda')
             noise_batches[batch_index] = noise
 
         return torch.reshape(noise_batches, (batch_size, self.num_channels, self.latent_size, self.latent_size)).to(dtype=torch.float16, device='cuda')
@@ -237,7 +241,11 @@ class RCTDiffusionPipeline(DiffusionPipeline):
 
     def __call__(self, object_description : list[str], color1 : list[str], \
         color2 : list[str] = None, color3 : list[str] = None, \
-        batch_size=1, num_inference_steps=
+        batch_size=1, num_inference_steps=100, generator=torch.manual_seed(torch.random.seed())):
+
+        self.unet.to(device='cuda', dtype=torch.float16)
+        self.vae.to(device='cuda', dtype=torch.float16)
+        self.text_encoder.to(device='cuda', dtype=torch.float16)
 
         res, object_description, color1, color2, color3 = self.validate_inputs(object_description, color1, color2, color3, batch_size)
         if res == False:
@@ -268,20 +276,23 @@ class RCTDiffusionPipeline(DiffusionPipeline):
         noise_batches = torch.reshape(noise_batches, (batch_size, self.num_channels, self.latent_size, self.latent_size))
         noise_batches = noise_batches.to('cuda')
         images = torch.Tensor(size=(batch_size, 3, self.sample_size, self.sample_size)).to('cuda')
+        images = noise_batches[:, :3]
 
-        with torch.no_grad():
-            image = noise_batches
-            result = self.vae.decode(image).sample
-            images = result
+        #with torch.no_grad():
+            #image = noise_batches
+            #result = self.vae.decode(image).sample
+            #images = result
 
         # convert those tensors to PIL images
+        tensor_to_pil = T.ToPILImage()
         output_images = []
        for batch_index in range(batch_size):
             image = images[batch_index]
-            image = (image / 2 + 0.5).clamp(0, 1)
-            image = (image.permute(1, 2, 0) * 255).to(torch.uint8).cpu().numpy()
-            image = (image * 255).round().astype("uint8")
-            image = Image.fromarray(image)
+            image = (image / 2 + 0.5).clamp(0, 1)
+            #image = (image.permute(1, 2, 0) * 255).to(torch.uint8).cpu().numpy()
+            #image = (image * 255).round().astype("uint8")
+            #image = Image.fromarray(image)
+            image = tensor_to_pil(image)
             image.save(f'test{batch_index}.png')
             output_images.append(image)
 
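Note on the decode path above: this commit bypasses the VAE at inference time (the `vae.decode` block is commented out) and instead reinterprets the first three latent channels as RGB via `ToPILImage`, presumably as a quick visual check. For reference, a minimal sketch of what the disabled path would normally look like with diffusers' `AutoencoderKL`; the `decode_latents_to_pil` helper name and the 0.18215 scaling factor are assumptions, not part of this repository.

```python
# Sketch only, not part of this commit: decode SD-style latents to PIL images.
# Assumes `vae` is a diffusers AutoencoderKL and `latents` has shape
# (batch, 4, latent_size, latent_size); 0.18215 is the usual SD scaling factor.
import torch
import torchvision.transforms as T

def decode_latents_to_pil(vae, latents, scaling_factor=0.18215):
    with torch.no_grad():
        images = vae.decode(latents / scaling_factor).sample  # (B, 3, H, W), roughly in [-1, 1]
    images = (images / 2 + 0.5).clamp(0, 1)                   # rescale to [0, 1] for ToPILImage
    to_pil = T.ToPILImage()
    return [to_pil(img.cpu().float()) for img in images]
```

Also worth noting: `generate_noise_batches` now seeds the global RNGs with a hard-coded 0, so the `generator` argument added to `__call__` does not reach the noise sampling shown in this hunk.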
test_pipeline.py
CHANGED
@@ -38,7 +38,9 @@ scheduler = DDPMScheduler(num_train_timesteps=20)
 vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)
 vae = vae.to('cuda', dtype=torch.float16)
 
-pipeline = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder)
-
+#pipeline = RCTDiffusionPipeline(unet, scheduler, vae, tokenizer, text_encoder)
+pipeline = RCTDiffusionPipeline.from_pretrained('rct_foliage_999')
+output = pipeline(['pagoda pine tree'], ['green'], ['grey'])
+output[0].save('out.png')
 pipeline.save_pretrained('test')
 print('test')
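The test now round-trips the pipeline through `from_pretrained('rct_foliage_999')`, which relies on `DiffusionPipeline` serialization. That only works if the pipeline registers its components; a minimal sketch of the `__init__` this implies (an assumption, since the constructor is not part of this diff):

```python
# Sketch (assumption): DiffusionPipeline.save_pretrained()/from_pretrained()
# round-trip whatever is registered here, one subfolder per module.
from diffusers import DiffusionPipeline

class RCTDiffusionPipeline(DiffusionPipeline):
    def __init__(self, unet, scheduler, vae, tokenizer, text_encoder):
        super().__init__()
        self.register_modules(unet=unet, scheduler=scheduler, vae=vae,
                              tokenizer=tokenizer, text_encoder=text_encoder)
```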
train_model.py
CHANGED
@@ -6,6 +6,7 @@ import numpy as np
 from rct_diffusion_pipeline import RCTDiffusionPipeline
 import torch
 import torchvision.transforms as T
+import torchvision
 import torch.nn.functional as F
 from diffusers.optimization import get_cosine_schedule_with_warmup
 from tqdm.auto import tqdm
@@ -29,16 +30,21 @@ def save_and_test(pipeline, epoch):
     model_file = f'rct_foliage_{epoch}'
     pipeline.save_pretrained(model_file)
 
-def
-    )
+def transform_images(image):
+    res = torch.Tensor(SAMPLE_NUM_CHANNELS, SAMPLE_SIZE, SAMPLE_SIZE)
+    pil_to_tensor = T.PILToTensor()
+
+    res_index = 0
+    scale_factor = np.minimum(SAMPLE_SIZE / image.width, SAMPLE_SIZE / image.height)
+    image = Image.resize(image, size=(int(scale_factor * image.width), int(scale_factor * image.height)), resample=Resampling.NEAREST)
 
+    new_image = PIL.Image.new('RGB', (SAMPLE_SIZE, SAMPLE_SIZE))
+    new_image.paste(image, box=(int((SAMPLE_SIZE - image.width)/2), int((SAMPLE_SIZE - image.height)/2)))
+    res = pil_to_tensor(new_image)
+    return res
+
+def convert_images(dataset):
+    images = [transform_images(image) for image in dataset["image"]]
     object_descriptions = [obj_desc for obj_desc in dataset["object_description"]]
     colors1 = [color1 for color1 in dataset['color1']]
     colors2 = [color1 for color1 in dataset['color2']]
@@ -87,10 +93,10 @@ def create_embeddings(dataset, model):
     return model.test_generate_embeddings(object_descriptions, colors1, colors2, colors3)
 
 
-def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timesteps=
+def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timesteps=100, save_model_interval=10, start_learning_rate=1e-4, lr_warmup_steps=500):
     dataset = load_dataset('frutiemax/rct_dataset', split=f'train[0:{total_images}]')
     dataset.set_transform(convert_images)
-    num_images =
+    num_images = dataset.num_rows
 
     unet = UNet2DConditionModel(sample_size=LATENT_SIZE, in_channels=LATENT_NUM_CHANNELS, out_channels=LATENT_NUM_CHANNELS, \
         down_block_types=("CrossAttnDownBlock2D","CrossAttnDownBlock2D","CrossAttnDownBlock2D", "DownBlock2D"),\
@@ -109,7 +115,7 @@ def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timesteps=
         "CompVis/stable-diffusion-v1-4", subfolder="text_encoder", use_safetensors=True
     ).to('cuda')
     vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)
-    vae = vae.to(dtype=torch.
+    vae = vae.to(dtype=torch.float32, device='cuda')
 
     optimizer = torch.optim.AdamW(unet.parameters(), lr=start_learning_rate)
     lr_scheduler = get_cosine_schedule_with_warmup(
@@ -134,7 +140,7 @@ def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timesteps=
             clean_images = batch['image']
             batch_size = clean_images.size(0)
             embeddings = create_embeddings(batch, model)
-            clean_images = torch.reshape(clean_images, (batch['image'].size(0),
+            clean_images = torch.reshape(clean_images, (batch['image'].size(0), SAMPLE_NUM_CHANNELS, SAMPLE_SIZE, SAMPLE_SIZE)).\
                 to(device='cuda')
 
             noise = torch.randn(clean_images.shape, dtype=torch.float32, device='cuda')
@@ -146,9 +152,16 @@ def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timesteps=
             batch_embeddings = embeddings
             batch_embeddings = batch_embeddings.to('cuda')
 
+            # use the vae to get the latent images
+            latent_images = vae.encode(noisy_images).latent_dist.sample()
+
             optimizer.zero_grad()
-            unet_results = unet(
+            unet_results = unet(latent_images, timesteps, batch_embeddings).sample
+
+            # get back the upscale result
+            noise_pred = vae.decode(unet_results).sample
+
+            loss = loss_fn(noise_pred, noise)
             loss.backward()
             optimizer.step()
             lr_scheduler.step()
@@ -171,4 +184,4 @@ def train_model(batch_size=4, total_images=-1, epochs=100, scheduler_num_timesteps=
 
 
 if __name__ == '__main__':
-    train_model(1,
+    train_model(1, save_model_interval=10, epochs=100)
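The new training step feeds the already-noised pixel images through `vae.encode`, runs the UNet on those latents, then decodes the UNet output back to pixel space before taking the loss against the pixel-space noise. For comparison, the more common latent-diffusion recipe (as in diffusers' text-to-image training examples) stays in latent space: encode the clean images once, add noise to the latents, and compute the loss directly on the UNet's prediction. A rough sketch follows, reusing names from this file (`vae`, `unet`, `clean_images`, `batch_embeddings`, `optimizer`); the `scheduler` object, `F.mse_loss` as the loss, and the 0.18215 scaling factor are assumptions, not taken from this commit.

```python
# Sketch (assumption), not this commit's code: standard latent-space training step.
# Assumes `scheduler` is a diffusers DDPMScheduler and the VAE stays frozen.
with torch.no_grad():
    latents = vae.encode(clean_images).latent_dist.sample() * 0.18215

noise = torch.randn_like(latents)
timesteps = torch.randint(0, scheduler.config.num_train_timesteps,
                          (latents.shape[0],), device=latents.device)
noisy_latents = scheduler.add_noise(latents, noise, timesteps)

optimizer.zero_grad()
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states=batch_embeddings).sample
loss = F.mse_loss(noise_pred, noise)   # loss stays in latent space; no vae.decode in the loop
loss.backward()
```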