Spaces:
Runtime error
Runtime error
added vae notebook
Browse files- notebooks/test_vae.ipynb +0 -0
- scripts/train_unconditional.py +6 -6
- scripts/train_vae.py +1 -1
notebooks/test_vae.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
scripts/train_unconditional.py
CHANGED
@@ -11,7 +11,7 @@ from accelerate import Accelerator
|
|
11 |
from accelerate.logging import get_logger
|
12 |
from datasets import load_from_disk, load_dataset
|
13 |
from diffusers import (DDPMPipeline, DDPMScheduler, UNet2DModel, LDMPipeline,
|
14 |
-
DDIMScheduler,
|
15 |
from diffusers.hub_utils import init_git_repo, push_to_hub
|
16 |
from diffusers.optimization import get_scheduler
|
17 |
from diffusers.training_utils import EMAModel
|
@@ -46,11 +46,11 @@ def main(args):
|
|
46 |
vqvae = pretrained.vqvae
|
47 |
model = pretrained.unet
|
48 |
else:
|
49 |
-
vqvae =
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
model = UNet2DModel(
|
55 |
sample_size=args.resolution,
|
56 |
in_channels=1,
|
|
|
11 |
from accelerate.logging import get_logger
|
12 |
from datasets import load_from_disk, load_dataset
|
13 |
from diffusers import (DDPMPipeline, DDPMScheduler, UNet2DModel, LDMPipeline,
|
14 |
+
DDIMScheduler, AutoencoderKL)
|
15 |
from diffusers.hub_utils import init_git_repo, push_to_hub
|
16 |
from diffusers.optimization import get_scheduler
|
17 |
from diffusers.training_utils import EMAModel
|
|
|
46 |
vqvae = pretrained.vqvae
|
47 |
model = pretrained.unet
|
48 |
else:
|
49 |
+
vqvae = AutoencoderKL(sample_size=args.resolution,
|
50 |
+
in_channels=1,
|
51 |
+
out_channels=1,
|
52 |
+
latent_channels=1,
|
53 |
+
layers_per_block=2)
|
54 |
model = UNet2DModel(
|
55 |
sample_size=args.resolution,
|
56 |
in_channels=1,
|
scripts/train_vae.py
CHANGED
@@ -152,7 +152,7 @@ if __name__ == "__main__":
|
|
152 |
trainer_opt,
|
153 |
resume_from_checkpoint=args.resume_from_checkpoint,
|
154 |
callbacks=[
|
155 |
-
ImageLogger(),
|
156 |
HFModelCheckpoint(ldm_config=config,
|
157 |
hf_checkpoint=args.hf_checkpoint_dir,
|
158 |
dirpath=args.ldm_checkpoint_dir,
|
|
|
152 |
trainer_opt,
|
153 |
resume_from_checkpoint=args.resume_from_checkpoint,
|
154 |
callbacks=[
|
155 |
+
ImageLogger(every=10),
|
156 |
HFModelCheckpoint(ldm_config=config,
|
157 |
hf_checkpoint=args.hf_checkpoint_dir,
|
158 |
dirpath=args.ldm_checkpoint_dir,
|