frutiemax committed
Commit 2104644
1 Parent(s): 054faf7

Switch to stabilityai/sd-vae-ft-mse

Files changed (1):
1. train_model.py +7 -3
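In short: the commit swaps the fp16 SD 1.5 VAE for stabilityai/sd-vae-ft-mse and halves the sample and latent resolution. Below is a minimal sketch of the new load, assuming only the model ID from the diff and the standard 8x spatial downsampling of the Stable Diffusion autoencoder, which is what keeps SAMPLE_SIZE = 256 consistent with LATENT_SIZE = 32:

```python
# Sketch of the swapped-in VAE load from the diff below; run in fp32 on CPU
# for portability (the training script itself casts to fp16 for CUDA).
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)

# The SD autoencoder downsamples 8x spatially and yields 4 latent channels,
# hence SAMPLE_SIZE = 256 pairs with LATENT_SIZE = 32 in the diff.
with torch.no_grad():
    latents = vae.encode(torch.randn(1, 3, 256, 256)).latent_dist.sample()
print(latents.shape)  # torch.Size([1, 4, 32, 32])
```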
train_model.py CHANGED

@@ -12,8 +12,8 @@ from tqdm.auto import tqdm
 from accelerate import Accelerator
 from diffusers import DDPMScheduler, UNet2DConditionModel, AutoencoderKL
 
-SAMPLE_SIZE = 512
-LATENT_SIZE = 64
+SAMPLE_SIZE = 256
+LATENT_SIZE = 32
 SAMPLE_NUM_CHANNELS = 3
 LATENT_NUM_CHANNELS = 4
 
@@ -109,7 +109,7 @@ def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model
         block_out_channels=(64, 128, 256), norm_num_groups=32)
     unet = unet.to(dtype=torch.float16)
     scheduler = DDPMScheduler(num_train_timesteps=20)
-    vae = AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae", use_safetensors=True, variant='fp16')
+    vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", use_safetensors=True)
     vae = vae.to(dtype=torch.float16)
 
     optimizer = torch.optim.Adam(unet.parameters(), lr=start_learning_rate)
@@ -121,6 +121,7 @@ def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model
     model = RCTDiffusionPipeline(unet, scheduler, vae)
     model.load_dictionaries_from_dataset()
     labels = convert_labels(dataset, model, num_images)
+    del model
 
     # lets train for 100 epoch for each sprite in the dataset with a random noise level
     progress_bar = tqdm(total=epochs)
@@ -139,6 +140,7 @@ def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model
         timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (batch_end - batch_index, )).to(device='cuda')
         #timesteps = timesteps.to(dtype=torch.int, device='cuda')
         noisy_images = scheduler.add_noise(clean_images, noise, timesteps)
+        del clean_images
 
         # encode through the vae
         with accelerator.accumulate(unet):
@@ -153,6 +155,8 @@ def train_model(batch_size=4, epochs=100, scheduler_num_timesteps=20, save_model
             result = vae.encode(images).latent_dist.sample()
             latent_noises[:, view_index*LATENT_NUM_CHANNELS:(view_index+1)*LATENT_NUM_CHANNELS] = result
 
+            del noise
+            del noisy_images
             unet_results = unet(latent_images, timesteps, labels[batch_index:batch_end])[0]
             unet_results = unet_results.to(dtype=torch.float16)
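A side note on the added del statements (an interpretation, not part of the commit): dropping Python references to large per-batch tensors as soon as they are no longer needed lets PyTorch's CUDA caching allocator reuse that memory before the UNet forward pass allocates its activations. A hypothetical minimal illustration of the pattern; tensor names and sizes here are made up, not from train_model.py:

```python
# Early-free pattern for large per-iteration tensors.
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

for _ in range(3):
    clean_images = torch.randn(4, 3, 256, 256, device=device)
    noise = torch.randn_like(clean_images)
    noisy_images = clean_images + noise   # stand-in for scheduler.add_noise
    del clean_images                      # drop each reference once the tensor is
    result = noisy_images.mean()          # no longer needed, so the CUDA caching
    del noise, noisy_images               # allocator can reuse those blocks
```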