Curious about the methodology of finetuning
Thank you for the great work!
Still a little curious about the methodology. May I ask how this was done, i.e. finetuning the weights while keeping the distribution of the latents unchanged? Were the encoder and the decoder finetuned separately or not? Thanks very much.
Yeah, separate, something like this (the fixed encoder outputs are never sent to the fixed decoder):
```python
import torch as th
import torch.nn.functional as F

def compute_loss(ims, ref_vae, fix_vae, train_encoder=True, train_decoder=True):
    # run_model_and_capture_features returns (model output, concatenated intermediate activations)
    ref_latents, _ref_features = run_model_and_capture_features(ref_vae.encoder, ims)
    loss = 0
    if train_encoder:
        fix_latents, fix_features = run_model_and_capture_features(fix_vae.encoder, ims)
        # match the reference latents + penalize activations with magnitude > 100
        loss = loss + F.l1_loss(fix_latents, ref_latents) + 0.0001 * F.relu(th.abs(fix_features) - 100).mean()
    if train_decoder:
        ref_ims, _ref_features = run_model_and_capture_features(ref_vae.decoder, ref_latents)
        fix_ims, fix_features = run_model_and_capture_features(fix_vae.decoder, ref_latents)
        # match the reference reconstructions + penalize activations with magnitude > 100
        loss = loss + F.l1_loss(fix_ims, ref_ims) + 0.0001 * F.relu(th.abs(fix_features) - 100).mean()
    return loss
```
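(`run_model_and_capture_features` isn't shown above; a minimal sketch of what such a helper *might* look like — this is my assumption of its shape, not the actual code — captures intermediate activations with forward hooks and concatenates them into one flat tensor:)

```python
import torch as th
import torch.nn as nn

def run_model_and_capture_features(model, x):
    # Hypothetical helper: run the model and also return its flattened,
    # concatenated intermediate activations, captured via forward hooks.
    features, hooks = [], []
    for module in model.modules():
        if isinstance(module, (nn.Conv2d, nn.Linear)):
            hooks.append(module.register_forward_hook(
                lambda _m, _inp, out: features.append(out.flatten())
            ))
    try:
        out = model(x)
    finally:
        for h in hooks:
            h.remove()  # always clean up hooks, even if the forward pass fails
    return out, th.cat(features)
```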
I trained both enc/dec in the same training run initially, but later switched to training the encoder and decoder individually at a larger batch size.
Got it. Thank you very much for the explanation!
As you say "I watched activation-map magnitudes + output deltas on a test image and manually rebalanced the match-original-output and make-activation-maps-smaller losses occasionally.":
Is "fix_features" the concatenation of all intermediate layers' outputs? And did you rebalance two loss items by adjusting the scale 0.0001 above?
Thank you very much!
@tengjiayan Correct.
Thank you!
hey @madebyollin thanks for your great work.
Could you explain further how you went about "scaling down weights and biases within the network", besides the `F.relu(th.abs(fix_features) - 100).mean()` term?
(please correct me if anything wrong.)
@xyzhang626 I was only training on a small amount of data (~10k images or something), so to make sure I wasn't breaking anything, I froze `quant_conv` and `post_quant_conv` as well as all of the Conv2d / Linear weight matrices, then gave each weight matrix a single trainable scale parameter (initialized to 1). So the finetuning process only used 37802 trainable decoder parameters (and a similarly small number in the encoder).
Thanks for your explanation. Really helpful!
@madebyollin thanks for your great work.
I trained a decoder from scratch with a frozen encoder, but the activations were too large to run in fp16. Luckily I found your solution. I'd like some details. First: does one Conv2d weight matrix ([Cout, Cin, kh, kw]) get a single trainable scale? Why "37802 trainable decoder parameters"? And are the GroupNorm weight and bias trainable? Second, regarding "I was only training on a small amount of data (~10k images or something)": what epochs or lr did you use?
> firstly, one conv2d(weight matrices, [Cout, Cin, kw,kh]) with a single trainable scale? why "37802 trainable decoder parameters"? and is groupnorm weight and bias trainable?
Yeah, single scale per conv weight matrix. The norms are all trainable.
Trainable params by layer
```
Module                                  Trainable Parameters
-------------------------------------------------------------
vae                                     37802
vae.decoder                             37802
vae.decoder.conv_in                     513
vae.decoder.mid                         9224
vae.decoder.mid.block_1                 3074
vae.decoder.mid.block_1.norm1           1024
vae.decoder.mid.block_1.conv1           513
vae.decoder.mid.block_1.norm2           1024
vae.decoder.mid.block_1.conv2           513
vae.decoder.mid.attn_1                  3076
vae.decoder.mid.attn_1.norm             1024
vae.decoder.mid.attn_1.q                513
vae.decoder.mid.attn_1.k                513
vae.decoder.mid.attn_1.v                513
vae.decoder.mid.attn_1.proj_out         513
vae.decoder.mid.block_2                 3074
vae.decoder.mid.block_2.norm1           1024
vae.decoder.mid.block_2.conv1           513
vae.decoder.mid.block_2.norm2           1024
vae.decoder.mid.block_2.conv2           513
vae.decoder.up                          27805
vae.decoder.up.0                        2695
vae.decoder.up.0.block                  2695
vae.decoder.up.0.block.0                1155
vae.decoder.up.0.block.0.norm1          512
vae.decoder.up.0.block.0.conv1          129
vae.decoder.up.0.block.0.norm2          256
vae.decoder.up.0.block.0.conv2          129
vae.decoder.up.0.block.0.nin_shortcut   129
vae.decoder.up.0.block.1                770
vae.decoder.up.0.block.1.norm1          256
vae.decoder.up.0.block.1.conv1          129
vae.decoder.up.0.block.1.norm2          256
vae.decoder.up.0.block.1.conv2          129
vae.decoder.up.0.block.2                770
vae.decoder.up.0.block.2.norm1          256
vae.decoder.up.0.block.2.conv1          129
vae.decoder.up.0.block.2.norm2          256
vae.decoder.up.0.block.2.conv2          129
vae.decoder.up.1                        5640
vae.decoder.up.1.block                  5383
vae.decoder.up.1.block.0                2307
vae.decoder.up.1.block.0.norm1          1024
vae.decoder.up.1.block.0.conv1          257
vae.decoder.up.1.block.0.norm2          512
vae.decoder.up.1.block.0.conv2          257
vae.decoder.up.1.block.0.nin_shortcut   257
vae.decoder.up.1.block.1                1538
vae.decoder.up.1.block.1.norm1          512
vae.decoder.up.1.block.1.conv1          257
vae.decoder.up.1.block.1.norm2          512
vae.decoder.up.1.block.1.conv2          257
vae.decoder.up.1.block.2                1538
vae.decoder.up.1.block.2.norm1          512
vae.decoder.up.1.block.2.conv1          257
vae.decoder.up.1.block.2.norm2          512
vae.decoder.up.1.block.2.conv2          257
vae.decoder.up.1.upsample               257
vae.decoder.up.1.upsample.conv          257
vae.decoder.up.2                        9735
vae.decoder.up.2.block                  9222
vae.decoder.up.2.block.0                3074
vae.decoder.up.2.block.0.norm1          1024
vae.decoder.up.2.block.0.conv1          513
vae.decoder.up.2.block.0.norm2          1024
vae.decoder.up.2.block.0.conv2          513
vae.decoder.up.2.block.1                3074
vae.decoder.up.2.block.1.norm1          1024
vae.decoder.up.2.block.1.conv1          513
vae.decoder.up.2.block.1.norm2          1024
vae.decoder.up.2.block.1.conv2          513
vae.decoder.up.2.block.2                3074
vae.decoder.up.2.block.2.norm1          1024
vae.decoder.up.2.block.2.conv1          513
vae.decoder.up.2.block.2.norm2          1024
vae.decoder.up.2.block.2.conv2          513
vae.decoder.up.2.upsample               513
vae.decoder.up.2.upsample.conv          513
vae.decoder.up.3                        9735
vae.decoder.up.3.block                  9222
vae.decoder.up.3.block.0                3074
vae.decoder.up.3.block.0.norm1          1024
vae.decoder.up.3.block.0.conv1          513
vae.decoder.up.3.block.0.norm2          1024
vae.decoder.up.3.block.0.conv2          513
vae.decoder.up.3.block.1                3074
vae.decoder.up.3.block.1.norm1          1024
vae.decoder.up.3.block.1.conv1          513
vae.decoder.up.3.block.1.norm2          1024
vae.decoder.up.3.block.1.conv2          513
vae.decoder.up.3.block.2                3074
vae.decoder.up.3.block.2.norm1          1024
vae.decoder.up.3.block.2.conv1          513
vae.decoder.up.3.block.2.norm2          1024
vae.decoder.up.3.block.2.conv2          513
vae.decoder.up.3.upsample               513
vae.decoder.up.3.upsample.conv          513
vae.decoder.norm_out                    256
vae.decoder.conv_out                    4
```
You can verify which parameters were changed, and by how much, by comparing the original and fine-tuned weights using a script like this one.
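(The linked script isn't reproduced here; a rough sketch of that kind of comparison — my own, under the assumption that both checkpoints share parameter names — could be:)

```python
import torch as th

def compare_state_dicts(orig_sd, tuned_sd, atol=0.0):
    # For each parameter present in both state dicts, report the
    # largest absolute change between original and fine-tuned values.
    changed = {}
    for name, orig in orig_sd.items():
        if name in tuned_sd:
            delta = (tuned_sd[name].float() - orig.float()).abs().max().item()
            if delta > atol:
                changed[name] = delta
    return changed
```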
> whats the epochs or lr you used
lr 3e-4; not sure about epochs (training was split across a couple of separate finetuning runs with different settings, maybe ~100 epochs total?)
@madebyollin Thanks for your explanation, really helpful! I checked "Trainable params by layer". For example, does "vae.decoder.mid.block_1.conv1 513" mean the conv weight matrix has one trainable scale and the bias has 512 trainable parameters, 513 in total? Or, alternatively, does the weight matrix have Cout trainable scales and the bias one trainable parameter?
Yeah, 512 trainable bias parameters and 1 trainable scale for the weight matrix
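(As a sanity check, here's my own arithmetic for that scheme, spelled out: convs contribute bias + 1 weight scale, norms keep their full weight + bias.)

```python
def trainable_count(out_channels, kind):
    # Per the scheme above: convs get (bias + 1 weight scale),
    # norms keep their full (weight + bias).
    if kind == "conv":
        return out_channels + 1
    if kind == "norm":
        return 2 * out_channels
    raise ValueError(f"unknown layer kind: {kind}")

# These match entries in the "Trainable params by layer" table:
assert trainable_count(512, "conv") == 513   # e.g. vae.decoder.mid.block_1.conv1
assert trainable_count(512, "norm") == 1024  # e.g. vae.decoder.mid.block_1.norm1
assert trainable_count(128, "conv") == 129   # e.g. vae.decoder.up.0.block.1.conv1
assert trainable_count(3, "conv") == 4       # vae.decoder.conv_out (RGB output)
```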