# File size: 2,480 bytes (revision 27486b3)
# Autoencoder model: an sgm AutoencodingEngine trained with an LPIPS +
# discriminator loss and a DiagonalGaussianRegularizer (weighted by kl_loss).
model:
  base_learning_rate: 4.5e-6
  target: sgm.models.autoencoder.AutoencodingEngine
  params:
    # Batch key holding the images; the data pipeline's transforms use key "jpg".
    input_key: jpg
    monitor: val/rec_loss

    loss_config:
      target: sgm.modules.autoencoding.losses.GeneralLPIPSWithDiscriminator
      params:
        perceptual_weight: 0.25
        disc_start: 20001
        disc_weight: 0.5
        learn_logvar: true
        regularization_weights:
          kl_loss: 1.0

    regularizer_config:
      target: sgm.modules.autoencoding.regularizers.DiagonalGaussianRegularizer

    encoder_config:
      target: sgm.modules.diffusionmodules.model.Encoder
      params:
        # The plain string "none" (YAML null would be `null` or `~`);
        # presumably selects no attention, consistent with attn_resolutions: [].
        attn_type: none
        double_z: true
        z_channels: 4
        resolution: 256
        in_channels: 3
        out_ch: 3
        ch: 128
        ch_mult: [1, 2, 4]
        num_res_blocks: 4
        attn_resolutions: []
        dropout: 0.0

    decoder_config:
      target: sgm.modules.diffusionmodules.model.Decoder
      # OmegaConf interpolation: the decoder reuses the encoder's params mapping.
      params: ${model.params.encoder_config.params}
# Training data: a StableDataModuleFromConfig whose train split is built from
# sharded URLs, decoded with PIL, and transformed by the postprocessors below.
data:
  target: sgm.data.dataset.StableDataModuleFromConfig
  params:
    train:
      datapipeline:
        urls:
          # NOTE(review): placeholder — replace with the real shard URL(s)/path(s).
          - DATA-PATH
        pipeline_config:
          shardshuffle: 10000
          sample_shuffle: 10000
        decoders:
          - pil
        postprocessors:
          - target: sdata.mappers.TorchVisionImageTransforms
            params:
              key: jpg
              transforms:
                - target: torchvision.transforms.Resize
                  params:
                    size: 256
                    # Pillow integer resampling constant; 3 == bicubic.
                    interpolation: 3
                - target: torchvision.transforms.ToTensor
          - target: sdata.mappers.Rescaler
          - target: sdata.mappers.AddOriginalImageSizeAsTupleAndCropToSquare
            params:
              h_key: height
              w_key: width
      loader:
        batch_size: 8
        num_workers: 4
# PyTorch Lightning settings: distributed strategy, checkpoint and callback
# overrides, and keyword arguments for the Trainer.
lightning:
  strategy:
    target: pytorch_lightning.strategies.DDPStrategy
    params:
      find_unused_parameters: true

  # Overrides params only; the target is presumably supplied by the training
  # script's defaults — confirm against the consumer.
  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    # No target given here either; presumably merged with a default definition.
    metrics_over_trainsteps_checkpoint:
      params:
        every_n_train_steps: 50000

    image_logger:
      target: main.ImageLogger
      params:
        enable_autocast: false
        batch_frequency: 1000
        max_images: 8
        increase_log_steps: true

  trainer:
    # Deliberately the string "0," (the trailing comma is significant), not the
    # integer 0; presumably read by Lightning as the device-index list [0].
    devices: "0,"
    limit_val_batches: 50
    benchmark: true
    accumulate_grad_batches: 1
    # An integer: run validation every 10000 training batches.
    val_check_interval: 10000