Switch to stabilityai/sd-vae-ft-mse
train_model.py CHANGED (+7 -3)
@@ -12,8 +12,8 @@ from tqdm.auto import tqdm
 from accelerate import Accelerator
 from diffusers import DDPMScheduler, UNet2DConditionModel, AutoencoderKL
 
-SAMPLE_SIZE =
-LATENT_SIZE =
+SAMPLE_SIZE = 256
+LATENT_SIZE = 32
 SAMPLE_NUM_CHANNELS = 3
 LATENT_NUM_CHANNELS = 4
 
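The new constants match the 8x spatial downsampling of the Stable Diffusion autoencoder family: a 256x256 RGB sample encodes to a 4-channel 32x32 latent, which is exactly the SAMPLE_SIZE/LATENT_SIZE and channel pairing above. A minimal sanity check of that relationship, not part of the commit, assuming only the diffusers AutoencoderKL API:

import torch
from diffusers import AutoencoderKL

# Encode a dummy 256x256 RGB batch and confirm the latent shape.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
with torch.no_grad():
    latents = vae.encode(torch.zeros(1, 3, 256, 256)).latent_dist.sample()
print(latents.shape)  # torch.Size([1, 4, 32, 32])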
@@ -109,7 +109,7 @@ def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model
     block_out_channels=(64, 128, 256), norm_num_groups=32)
     unet = unet.to(dtype=torch.float16)
     scheduler = DDPMScheduler(num_train_timesteps=20)
-    vae = AutoencoderKL.from_pretrained("
+    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)
     vae = vae.to(dtype=torch.float16)
 
     optimizer = torch.optim.Adam(unet.parameters(), lr=start_learning_rate)
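Once the VAE is cast to float16, tensors passed to vae.encode must use the same dtype and device or PyTorch raises a mismatch error. Stable Diffusion-style pipelines also conventionally rescale sampled latents by vae.config.scaling_factor before handing them to a UNet; the training loop below samples latent_dist without that scaling. A hedged sketch of the conventional encode step, assuming a CUDA device:

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)
vae = vae.to(device="cuda", dtype=torch.float16)

# Inputs must match the VAE's dtype/device after the float16 cast.
images = torch.zeros(4, 3, 256, 256, device="cuda", dtype=torch.float16)
with torch.no_grad():
    latents = vae.encode(images).latent_dist.sample()
latents = latents * vae.config.scaling_factor  # 0.18215 by default in diffusers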
@@ -121,6 +121,7 @@ def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model
     model = RCTDiffusionPipeline(unet, scheduler, vae)
     model.load_dictionaries_from_dataset()
     labels = convert_labels(dataset, model, num_images)
+    del model
 
     # lets train for 100 epoch for each sprite in the dataset with a random noise level
     progress_bar = tqdm(total=epochs)
@@ -139,6 +140,7 @@ def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model
     timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_end - batch_index, )).to(device='cuda')
     #timesteps = timesteps.to(dtype=torch.int, device='cuda')
     noisy_images = scheduler.add_noise(clean_images, noise, timesteps)
+    del clean_images
 
     # encode through the vae
     with accelerator.accumulate(unet):
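For reference, DDPMScheduler.add_noise computes the closed-form forward process x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps, so once noisy_images exists the clean batch is no longer needed, which is what the added del exploits. A standalone sketch of the same call:

import torch
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=20)
clean = torch.randn(4, 3, 256, 256)
noise = torch.randn_like(clean)
t = torch.randint(0, scheduler.config.num_train_timesteps, (4,))
noisy = scheduler.add_noise(clean, noise, t)  # per-sample noise level set by t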
@@ -153,6 +155,8 @@ def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model
     result = vae.encode(images).latent_dist.sample()
     latent_noises[:, view_index*LATENT_NUM_CHANNELS:(view_index+1)*LATENT_NUM_CHANNELS] = result
 
+    del noise
+    del noisy_images
     unet_results = unet(latent_images, timesteps, labels[batch_index:batch_end])[0]
     unet_results = unet_results.to(dtype=torch.float16)
 
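The added del statements only drop Python references so the large float16 tensors become collectable between steps; on CUDA the freed blocks return to PyTorch's caching allocator, not to the driver. A small standalone illustration, assuming a CUDA device:

import torch

x = torch.zeros(4, 3, 256, 256, dtype=torch.float16, device="cuda")
print(torch.cuda.memory_allocated())  # includes x
del x                                 # drop the last reference; the block is freed
print(torch.cuda.memory_allocated())  # smaller once x is gone
torch.cuda.empty_cache()              # optionally return cached blocks to the driver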