batmangiaicuuthegioi committed on
Commit dca2470 · verified · 1 Parent(s): f631423

Upload 7 files

Files changed (7)
  1. .gitattributes +35 -35
  2. README.md +13 -13
  3. app.py +46 -154
  4. config.py +24 -0
  5. model.py +359 -0
  6. requirements.txt +9 -6
  7. utils.py +29 -0
.gitattributes CHANGED
@@ -1,35 +1,35 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
- *.pt filter=lfs diff=lfs merge=lfs -text
- *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md CHANGED
@@ -1,14 +1,14 @@
- ---
- title: StyleTransferDemo
- emoji: 🖼
- colorFrom: purple
- colorTo: red
- sdk: gradio
- sdk_version: 5.0.1
- app_file: app.py
- pinned: false
- license: mit
- short_description: our deep learning project
- ---
- 
+ ---
+ title: StyleTransferDemo
+ emoji: 🖼
+ colorFrom: purple
+ colorTo: red
+ sdk: gradio
+ sdk_version: 5.0.1
+ app_file: app.py
+ pinned: false
+ license: mit
+ short_description: our deep learning project
+ ---
+ 
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,154 +1,46 @@
- import gradio as gr
- import numpy as np
- import random
- 
- # import spaces #[uncomment to use ZeroGPU]
- from diffusers import DiffusionPipeline
- import torch
- 
- device = "cuda" if torch.cuda.is_available() else "cpu"
- model_repo_id = "stabilityai/sdxl-turbo"  # Replace to the model you would like to use
- 
- if torch.cuda.is_available():
-     torch_dtype = torch.float16
- else:
-     torch_dtype = torch.float32
- 
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
- pipe = pipe.to(device)
- 
- MAX_SEED = np.iinfo(np.int32).max
- MAX_IMAGE_SIZE = 1024
- 
- 
- # @spaces.GPU #[uncomment to use ZeroGPU]
- def infer(
-     prompt,
-     negative_prompt,
-     seed,
-     randomize_seed,
-     width,
-     height,
-     guidance_scale,
-     num_inference_steps,
-     progress=gr.Progress(track_tqdm=True),
- ):
-     if randomize_seed:
-         seed = random.randint(0, MAX_SEED)
- 
-     generator = torch.Generator().manual_seed(seed)
- 
-     image = pipe(
-         prompt=prompt,
-         negative_prompt=negative_prompt,
-         guidance_scale=guidance_scale,
-         num_inference_steps=num_inference_steps,
-         width=width,
-         height=height,
-         generator=generator,
-     ).images[0]
- 
-     return image, seed
- 
- 
- examples = [
-     "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-     "An astronaut riding a green horse",
-     "A delicious ceviche cheesecake slice",
- ]
- 
- css = """
- #col-container {
-     margin: 0 auto;
-     max-width: 640px;
- }
- """
- 
- with gr.Blocks(css=css) as demo:
-     with gr.Column(elem_id="col-container"):
-         gr.Markdown(" # Text-to-Image Gradio Template")
- 
-         with gr.Row():
-             prompt = gr.Text(
-                 label="Prompt",
-                 show_label=False,
-                 max_lines=1,
-                 placeholder="Enter your prompt",
-                 container=False,
-             )
- 
-             run_button = gr.Button("Run", scale=0, variant="primary")
- 
-         result = gr.Image(label="Result", show_label=False)
- 
-         with gr.Accordion("Advanced Settings", open=False):
-             negative_prompt = gr.Text(
-                 label="Negative prompt",
-                 max_lines=1,
-                 placeholder="Enter a negative prompt",
-                 visible=False,
-             )
- 
-             seed = gr.Slider(
-                 label="Seed",
-                 minimum=0,
-                 maximum=MAX_SEED,
-                 step=1,
-                 value=0,
-             )
- 
-             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
- 
-             with gr.Row():
-                 width = gr.Slider(
-                     label="Width",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024,  # Replace with defaults that work for your model
-                 )
- 
-                 height = gr.Slider(
-                     label="Height",
-                     minimum=256,
-                     maximum=MAX_IMAGE_SIZE,
-                     step=32,
-                     value=1024,  # Replace with defaults that work for your model
-                 )
- 
-             with gr.Row():
-                 guidance_scale = gr.Slider(
-                     label="Guidance scale",
-                     minimum=0.0,
-                     maximum=10.0,
-                     step=0.1,
-                     value=0.0,  # Replace with defaults that work for your model
-                 )
- 
-                 num_inference_steps = gr.Slider(
-                     label="Number of inference steps",
-                     minimum=1,
-                     maximum=50,
-                     step=1,
-                     value=2,  # Replace with defaults that work for your model
-                 )
- 
-         gr.Examples(examples=examples, inputs=[prompt])
-     gr.on(
-         triggers=[run_button.click, prompt.submit],
-         fn=infer,
-         inputs=[
-             prompt,
-             negative_prompt,
-             seed,
-             randomize_seed,
-             width,
-             height,
-             guidance_scale,
-             num_inference_steps,
-         ],
-         outputs=[result, seed],
-     )
- 
- if __name__ == "__main__":
-     demo.launch()
+ from flask import Flask
+ import gradio as gr
+ import torch
+ from torchvision import transforms
+ from PIL import Image
+ from config import MODEL_CONFIG
+ from model import CycleGAN
+ 
+ # Load the CycleGAN models
+ model_paths = {
+     "CycleGAN_Cezanne_Unet_300": "/checkpoints/checkpoints/cyclegan_cezanne_unet_300_epochs.ckpt",
+     "CycleGAN_Monet_Unet_250": "/checkpoints/checkpoints/cyclegan_monet_unet_250_epochs.ckpt",
+     "CycleGAN_Vangogh_Resnet_70": "/cyclegan_vangogh_resnet_70_epochs.ckpt",
+     "CycleGAN_Vangogh_Unet_70": "/cyclegan_vangogh_unet_70_epochs.ckpt",
+ }
+ 
+ models = {name: CycleGAN.load_from_checkpoint(path, **MODEL_CONFIG) for name, path in model_paths.items()}
+ 
+ # Define the image transformation
+ transform = transforms.Compose([
+     transforms.Resize((256, 256)),
+     transforms.ToTensor(),
+ ])
+ 
+ # Define the image translation function
+ def translate_image(input_image, style):
+     model = models[style]
+     image = transform(input_image).unsqueeze(0)
+     with torch.no_grad():
+         translated_image = model(image)
+     return transforms.ToPILImage()(translated_image.squeeze(0))
+ 
+ # Initialize the Gradio interface
+ iface = gr.Interface(
+     fn=translate_image,
+     inputs=[
+         gr.Image(type="pil"),
+         gr.Dropdown(choices=list(models.keys()), label="Select Style")
+     ],
+     outputs=gr.Image(type="pil"),
+     title="CycleGAN Image Translation",
+     description="Upload an image and select a style to translate it using CycleGAN."
+ )
+ 
+ if __name__ == "__main__":
+     iface.launch(debug=True)
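
Note on the new app.py: translate_image feeds the resized tensor straight into the generator and converts the Tanh output (which lies in [-1, 1]) directly back to a PIL image, whereas utils.py in this same upload normalizes inputs with mean 0.5 / std 0.5 and denormalizes outputs before display. A minimal sketch of an inference helper that keeps both conventions consistent is below; it assumes the checkpoints were trained on [-1, 1] inputs, which this commit does not state explicitly.

import torch
from PIL import Image
from torchvision import transforms

# Assumption: training inputs were normalized to [-1, 1], as in utils.py.
_preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])

def translate_pil(model, pil_image: Image.Image) -> Image.Image:
    model.eval()                                  # disable the U-Net generator's dropout at inference time
    device = next(model.parameters()).device      # keep the input on the same device as the weights
    x = _preprocess(pil_image).unsqueeze(0).to(device)
    with torch.no_grad():
        y = model(x)                              # Tanh output in [-1, 1]
    y = (y.squeeze(0).cpu() * 0.5 + 0.5).clamp(0, 1)  # back to [0, 1] for PIL conversion
    return transforms.ToPILImage()(y)

If the models were in fact trained on raw [0, 1] tensors, the Normalize/denormalize pair should simply be dropped.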
config.py ADDED
@@ -0,0 +1,24 @@
+ import torch
+ DEBUG = False
+ 
+ MODEL_CONFIG = {
+     # the type of generator, and the number of residual blocks if ResNet generator is used
+     "gen_name": "unet",  # types: 'unet', 'resnet'
+     "num_resblocks": 6,
+     # the number of filters in the first layer for the generators and discriminators
+     "hid_channels": 64,
+     # using DeepSpeed's FusedAdam (currently GPU only) is slightly faster
+     "optimizer": torch.optim.Adam,
+     # the learning rate and beta parameters for the Adam optimizer
+     "lr": 3e-6,
+     "betas": (0.5, 0.999),
+     # the weights used in the identity loss and cycle loss
+     "lambda_idt": 0,
+     "lambda_cycle": (10, 10),  # (MPM direction, PMP direction)
+     # the size of the buffer that stores previously generated images
+     "buffer_size": 100,
+     # the number of epochs for training
+     "num_epochs": 30 if not DEBUG else 70,
+     # the number of epochs before starting the learning rate decay
+     "decay_epochs": 10 if not DEBUG else 70,
+ }
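
Every key in MODEL_CONFIG matches a parameter of CycleGAN.__init__ in model.py below, so the same dictionary can be unpacked both when building a fresh model for training and when restoring a checkpoint, as app.py does. A minimal sketch of the training-side usage; the Trainer settings and the data module are illustrative only, since no dataset ships with this Space.

import pytorch_lightning as pl
from config import MODEL_CONFIG
from model import CycleGAN

model = CycleGAN(**MODEL_CONFIG)   # fresh model; hyperparameters come straight from config.py
model.init_weights()               # the normal(0, 0.02) initialization defined in model.py
trainer = pl.Trainer(max_epochs=MODEL_CONFIG["num_epochs"], accelerator="auto")
# trainer.fit(model, datamodule=...)  # expects batches shaped like {"monet": ..., "photo": ...}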
model.py ADDED
@@ -0,0 +1,359 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import pytorch_lightning as L
+ import numpy as np
+ import wandb  # needed for the wandb.log() call in training_step
+ 
+ class Downsampling(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, norm=True, lrelu=True):
+         super().__init__()
+         self.block = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=not norm),
+         )
+         if norm:
+             self.block.append(nn.InstanceNorm2d(out_channels, affine=True))
+         if lrelu is not None:
+             self.block.append(nn.LeakyReLU(0.2, True) if lrelu else nn.ReLU(True))
+ 
+     def forward(self, x):
+         return self.block(x)
+ 
+ class Upsampling(nn.Module):
+     def __init__(self, in_channels, out_channels, kernel_size=4, stride=2, padding=1, output_padding=0, dropout=False):
+         super().__init__()
+         self.block = nn.Sequential(
+             nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, bias=False),
+             nn.InstanceNorm2d(out_channels, affine=True),
+         )
+         if dropout:
+             self.block.append(nn.Dropout(0.5))
+         self.block.append(nn.ReLU(True))
+ 
+     def forward(self, x):
+         return self.block(x)
+ 
+ class ResBlock(nn.Module):
+     def __init__(self, in_channels, kernel_size=3, padding=1):
+         super().__init__()
+         self.block = nn.Sequential(
+             nn.ReflectionPad2d(padding),
+             Downsampling(in_channels, in_channels, kernel_size=kernel_size, stride=1, padding=0, lrelu=False),
+             nn.ReflectionPad2d(padding),
+             Downsampling(in_channels, in_channels, kernel_size=kernel_size, stride=1, padding=0, lrelu=None),
+         )
+ 
+     def forward(self, x):
+         return x + self.block(x)
+ 
+ class UNetGenerator(nn.Module):
+     def __init__(self, hid_channels, in_channels, out_channels):
+         super().__init__()
+         self.downsampling_path = nn.Sequential(
+             Downsampling(in_channels, hid_channels, norm=False),
+             Downsampling(hid_channels, hid_channels*2),
+             Downsampling(hid_channels*2, hid_channels*4),
+             Downsampling(hid_channels*4, hid_channels*8),
+             Downsampling(hid_channels*8, hid_channels*8),
+             Downsampling(hid_channels*8, hid_channels*8),
+             Downsampling(hid_channels*8, hid_channels*8),
+             Downsampling(hid_channels*8, hid_channels*8, norm=False),
+         )
+         self.upsampling_path = nn.Sequential(
+             Upsampling(hid_channels*8, hid_channels*8, dropout=True),
+             Upsampling(hid_channels*16, hid_channels*8, dropout=True),
+             Upsampling(hid_channels*16, hid_channels*8, dropout=True),
+             Upsampling(hid_channels*16, hid_channels*8),
+             Upsampling(hid_channels*16, hid_channels*4),
+             Upsampling(hid_channels*8, hid_channels*2),
+             Upsampling(hid_channels*4, hid_channels),
+         )
+         self.feature_block = nn.Sequential(
+             nn.ConvTranspose2d(hid_channels*2, out_channels, kernel_size=4, stride=2, padding=1),
+             nn.Tanh(),
+         )
+ 
+     def forward(self, x):
+         skips = []
+         for down in self.downsampling_path:
+             x = down(x)
+             skips.append(x)
+         skips = reversed(skips[:-1])
+ 
+         for up, skip in zip(self.upsampling_path, skips):
+             x = up(x)
+             x = torch.cat([x, skip], dim=1)
+         return self.feature_block(x)
+ 
+ class ResNetGenerator(nn.Module):
+     def __init__(self, hid_channels, in_channels, out_channels, num_resblocks):
+         super().__init__()
+         self.model = nn.Sequential(
+             nn.ReflectionPad2d(3),
+             Downsampling(in_channels, hid_channels, kernel_size=7, stride=1, padding=0, lrelu=False),
+             Downsampling(hid_channels, hid_channels*2, kernel_size=3, lrelu=False),
+             Downsampling(hid_channels*2, hid_channels*4, kernel_size=3, lrelu=False),
+             *[ResBlock(hid_channels*4) for _ in range(num_resblocks)],
+             Upsampling(hid_channels*4, hid_channels*2, kernel_size=3, output_padding=1),
+             Upsampling(hid_channels*2, hid_channels, kernel_size=3, output_padding=1),
+             nn.ReflectionPad2d(3),
+             nn.Conv2d(hid_channels, out_channels, kernel_size=7, stride=1, padding=0),
+             nn.Tanh(),
+         )
+ 
+     def forward(self, x):
+         return self.model(x)
+ 
+ def get_gen(gen_name, hid_channels, num_resblocks, in_channels=3, out_channels=3):
+     if gen_name == "unet":
+         return UNetGenerator(hid_channels, in_channels, out_channels)
+     elif gen_name == "resnet":
+         return ResNetGenerator(hid_channels, in_channels, out_channels, num_resblocks)
+     else:
+         raise NotImplementedError(f"Generator name '{gen_name}' not recognized.")
+ 
+ class Discriminator(nn.Module):
+     def __init__(self, hid_channels, in_channels=3):
+         super().__init__()
+         self.block = nn.Sequential(
+             Downsampling(in_channels, hid_channels, norm=False),
+             Downsampling(hid_channels, hid_channels*2),
+             Downsampling(hid_channels*2, hid_channels*4),
+             Downsampling(hid_channels*4, hid_channels*8, stride=1),
+             nn.Conv2d(hid_channels*8, 1, kernel_size=4, padding=1),
+         )
+ 
+     def forward(self, x):
+         return self.block(x)
+ 
+ class ImageBuffer(object):
+     def __init__(self, buffer_size):
+         self.buffer_size = buffer_size
+         if self.buffer_size > 0:
+             self.curr_cap = 0
+             self.buffer = []
+ 
+     def __call__(self, imgs):
+         if self.buffer_size == 0:
+             return imgs
+ 
+         return_imgs = []
+         for img in imgs:
+             img = img.unsqueeze(dim=0)
+ 
+             if self.curr_cap < self.buffer_size:
+                 self.curr_cap += 1
+                 self.buffer.append(img)
+                 return_imgs.append(img)
+             else:
+                 p = np.random.uniform(low=0., high=1.)
+ 
+                 if p > 0.5:
+                     idx = np.random.randint(low=0, high=self.buffer_size)
+                     tmp = self.buffer[idx].clone()
+                     self.buffer[idx] = img
+                     return_imgs.append(tmp)
+                 else:
+                     return_imgs.append(img)
+         return torch.cat(return_imgs, dim=0)
+ 
+ class CycleGAN(L.LightningModule):
+     def __init__(self, gen_name, num_resblocks, hid_channels, optimizer, lr, lambda_idt, lambda_cycle, buffer_size, num_epochs, decay_epochs, betas):
+         super().__init__()
+         self.save_hyperparameters()
+         self.optimizer = optimizer
+         self.automatic_optimization = False
+ 
+         self.gen_PM = get_gen(gen_name, hid_channels, num_resblocks)
+         self.gen_MP = get_gen(gen_name, hid_channels, num_resblocks)
+         self.disc_M = Discriminator(hid_channels)
+         self.disc_P = Discriminator(hid_channels)
+ 
+         self.buffer_fake_M = ImageBuffer(buffer_size)
+         self.buffer_fake_P = ImageBuffer(buffer_size)
+ 
+     def forward(self, img):
+         return self.gen_PM(img)
+ 
+     def init_weights(self):
+         def init_fn(m):
+             if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.InstanceNorm2d)):
+                 nn.init.normal_(m.weight, 0.0, 0.02)
+                 if m.bias is not None:
+                     nn.init.constant_(m.bias, 0.0)
+ 
+         for net in [self.gen_PM, self.gen_MP, self.disc_M, self.disc_P]:
+             net.apply(init_fn)
+ 
+     def setup(self, stage):
+         if stage == "fit":
+             print("Model initialized.")
+ 
+     def get_lr_scheduler(self, optimizer):
+         def lr_lambda(epoch):
+             len_decay_phase = self.hparams.num_epochs - self.hparams.decay_epochs + 1.0
+             curr_decay_step = max(0, epoch - self.hparams.decay_epochs + 1.0)
+             val = 1.0 - curr_decay_step / len_decay_phase
+             return max(0.0, val)
+ 
+         return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)
+ 
+     def configure_optimizers(self):
+         opt_config = {
+             "lr": self.hparams.lr,
+             "betas": self.hparams.betas,
+         }
+         opt_gen = self.optimizer(
+             list(self.gen_PM.parameters()) + list(self.gen_MP.parameters()),
+             **opt_config,
+         )
+         opt_disc = self.optimizer(
+             list(self.disc_M.parameters()) + list(self.disc_P.parameters()),
+             **opt_config,
+         )
+         optimizers = [opt_gen, opt_disc]
+         schedulers = [self.get_lr_scheduler(opt) for opt in optimizers]
+         return optimizers, schedulers
+ 
+     def adv_criterion(self, y_hat, y):
+         return F.mse_loss(y_hat, y)
+ 
+     def recon_criterion(self, y_hat, y):
+         return F.l1_loss(y_hat, y)
+ 
+     def get_adv_loss(self, fake, disc):
+         fake_hat = disc(fake)
+         real_labels = torch.ones_like(fake_hat)
+         adv_loss = self.adv_criterion(fake_hat, real_labels)
+         return adv_loss
+ 
+     def get_idt_loss(self, real, idt, lambda_cycle):
+         idt_loss = self.recon_criterion(idt, real)
+         return self.hparams.lambda_idt * lambda_cycle * idt_loss  # weighted identity term (0 with the current config)
+ 
+     def get_cycle_loss(self, real, recon, lambda_cycle):
+         cycle_loss = self.recon_criterion(recon, real)
+         return lambda_cycle * cycle_loss
+ 
+     def get_gen_loss(self):
+         adv_loss_PM = self.get_adv_loss(self.fake_M, self.disc_M)
+         adv_loss_MP = self.get_adv_loss(self.fake_P, self.disc_P)
+         total_adv_loss = adv_loss_PM + adv_loss_MP
+ 
+         lambda_cycle = self.hparams.lambda_cycle
+         idt_loss_MM = self.get_idt_loss(self.real_M, self.idt_M, lambda_cycle[0])
+         idt_loss_PP = self.get_idt_loss(self.real_P, self.idt_P, lambda_cycle[1])
+         total_idt_loss = idt_loss_MM + idt_loss_PP
+ 
+         cycle_loss_MPM = self.get_cycle_loss(self.real_M, self.recon_M, lambda_cycle[0])
+         cycle_loss_PMP = self.get_cycle_loss(self.real_P, self.recon_P, lambda_cycle[1])
+         total_cycle_loss = cycle_loss_MPM + cycle_loss_PMP
+ 
+         gen_loss = total_adv_loss + total_idt_loss + total_cycle_loss
+         return gen_loss
+ 
+     def get_disc_loss(self, real, fake, disc):
+         real_hat = disc(real)
+         real_labels = torch.ones_like(real_hat)
+         real_loss = self.adv_criterion(real_hat, real_labels)
+ 
+         fake_hat = disc(fake.detach())
+         fake_labels = torch.zeros_like(fake_hat)
+         fake_loss = self.adv_criterion(fake_hat, fake_labels)
+ 
+         disc_loss = (fake_loss + real_loss) * 0.5
+         return disc_loss
+ 
+     def get_disc_loss_M(self):
+         fake_M = self.buffer_fake_M(self.fake_M)
+         return self.get_disc_loss(self.real_M, fake_M, self.disc_M)
+ 
+     def get_disc_loss_P(self):
+         fake_P = self.buffer_fake_P(self.fake_P)
+         return self.get_disc_loss(self.real_P, fake_P, self.disc_P)
+ 
+     def training_step(self, batch, batch_idx):
+         self.real_M = batch["monet"]
+         self.real_P = batch["photo"]
+         opt_gen, opt_disc = self.optimizers()
+ 
+         self.fake_M = self.gen_PM(self.real_P)
+         self.fake_P = self.gen_MP(self.real_M)
+ 
+         self.idt_M = self.gen_PM(self.real_M)
+         self.idt_P = self.gen_MP(self.real_P)
+ 
+         self.recon_M = self.gen_PM(self.fake_P)
+         self.recon_P = self.gen_MP(self.fake_M)
+ 
+         self.toggle_optimizer(opt_gen)
+         gen_loss = self.get_gen_loss()
+         opt_gen.zero_grad()
+         self.manual_backward(gen_loss)
+         opt_gen.step()
+         self.untoggle_optimizer(opt_gen)
+ 
+         self.toggle_optimizer(opt_disc)
+         disc_loss_M = self.get_disc_loss_M()
+         disc_loss_P = self.get_disc_loss_P()
+         opt_disc.zero_grad()
+         self.manual_backward(disc_loss_M)
+         self.manual_backward(disc_loss_P)
+         opt_disc.step()
+         self.untoggle_optimizer(opt_disc)
+ 
+         metrics = {
+             "gen_loss": gen_loss,
+             "disc_loss_M": disc_loss_M,
+             "disc_loss_P": disc_loss_P,
+         }
+         wandb.log(metrics)
+         self.log_dict(metrics, on_step=False, on_epoch=True, prog_bar=True)
+ 
+     def validation_step(self, batch, batch_idx):
+         self.display_results(batch, batch_idx, "validate")
+ 
+     def test_step(self, batch, batch_idx):
+         self.display_results(batch, batch_idx, "test")
+ 
+     def predict_step(self, batch, batch_idx):
+         return self(batch)
+ 
+     def display_results(self, batch, batch_idx, stage):
+         real_P = batch
+         fake_M = self(real_P)
+ 
+         if stage == "validate":
+             title = f"Epoch {self.current_epoch+1}: Photo-to-Monet Translation"
+         else:
+             title = f"Sample {batch_idx+1}: Photo-to-Monet Translation"
+ 
+         show_img(
+             torch.cat([real_P, fake_M], dim=0),
+             nrow=len(real_P),
+             title=title,
+         )
+ 
+     def on_train_epoch_start(self):
+         curr_lr = self.lr_schedulers()[0].get_last_lr()[0]
+         self.log("lr", curr_lr, on_step=False, on_epoch=True, prog_bar=True)
+ 
+     def on_train_epoch_end(self):
+         for sch in self.lr_schedulers():
+             sch.step()
+ 
+         logged_values = self.trainer.progress_bar_metrics
+         print(
+             f"Epoch {self.current_epoch+1}",
+             *[f"{k}: {v:.5f}" for k, v in logged_values.items()],
+             sep=" - ",
+         )
+ 
+     def on_train_end(self):
+         print("Training ended.")
+ 
+     def on_predict_epoch_end(self):
+         predictions = self.trainer.predict_loop.predictions
+         num_batches = len(predictions)
+         batch_size = predictions[0].shape[0]
+         last_batch_diff = batch_size - predictions[-1].shape[0]
+         print(f"Number of images generated: {num_batches*batch_size-last_batch_diff}.")
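
model.py's validation and test steps call a show_img helper that is not defined in this upload (nor in utils.py). Below is a sketch of what such a helper presumably looks like, inferred only from its call site (a batch of images, an nrow, a title); the make_grid/matplotlib choice and the [-1, 1] value range are assumptions, not something this commit confirms.

import matplotlib.pyplot as plt
import torch
from torchvision.utils import make_grid

def show_img(images: torch.Tensor, nrow: int, title: str = "") -> None:
    # Assumption: the tensors are normalized inputs / Tanh outputs in [-1, 1].
    grid = make_grid(images.detach().cpu(), nrow=nrow, normalize=True, value_range=(-1, 1))
    plt.figure(figsize=(2 * nrow, 4))
    plt.imshow(grid.permute(1, 2, 0))  # CHW -> HWC for matplotlib
    plt.title(title)
    plt.axis("off")
    plt.show()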
requirements.txt CHANGED
@@ -1,6 +1,9 @@
- accelerate
- diffusers
- invisible_watermark
- torch
- transformers
- xformers
+ flask
+ gradio
+ torch
+ torchvision
+ pytorch_lightning
+ numpy
+ Pillow
+ matplotlib
+ wandb
utils.py ADDED
@@ -0,0 +1,29 @@
+ import torch
+ from PIL import Image
+ import torchvision.transforms as T
+ import matplotlib.pyplot as plt
+ from model import CycleGAN
+ from config import MODEL_CONFIG  # needed for the load_from_checkpoint call below
+ 
+ # Load and preprocess the input image
+ def load_image(image_path, device, image_size=(256, 256)):
+     transform = T.Compose([
+         T.Resize(image_size),
+         T.ToTensor(),
+         T.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # Normalize to [-1, 1]
+     ])
+     image = Image.open(image_path).convert("RGB")
+     image = transform(image).unsqueeze(0).to(device)
+     return image
+ 
+ # Display the output image
+ def display_image(tensor_image):
+     tensor_image = tensor_image.squeeze(0).cpu()  # Remove batch dimension
+     tensor_image = (tensor_image * 0.5 + 0.5).clamp(0, 1)  # Denormalize
+     plt.imshow(tensor_image.permute(1, 2, 0))  # CHW to HWC
+     plt.axis("off")
+     plt.show()
+ 
+ # Load the model checkpoint
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model = CycleGAN.load_from_checkpoint("/content/cyclegan_monet_unet_250_epochs.ckpt", **MODEL_CONFIG)
+ 
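
The final lines of utils.py load a checkpoint from a Colab-style path (/content/...) at import time. A short usage sketch tying the helpers together follows, assuming that module-level path is updated to a checkpoint that actually exists so the import succeeds; the image and checkpoint filenames below are placeholders, not files shipped with this commit.

import torch
from config import MODEL_CONFIG
from model import CycleGAN
from utils import load_image, display_image

device = "cuda" if torch.cuda.is_available() else "cpu"
# Placeholder paths for illustration only.
model = CycleGAN.load_from_checkpoint("cyclegan_monet_unet_250_epochs.ckpt", map_location=device, **MODEL_CONFIG)
model = model.to(device).eval()

photo = load_image("photo.jpg", device)   # resized, normalized to [-1, 1], batched
with torch.no_grad():
    monet = model(photo)                  # photo-to-Monet generator (gen_PM)
display_image(monet)                      # denormalizes and shows the result with matplotlib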