# Neural style-transfer training script (encoder/decoder with content + style losses).
import os | |
import glob | |
import numpy as np | |
import wandb | |
import copy | |
import argparse | |
import matplotlib.pyplot as plt | |
import torch | |
import torchvision.transforms as transforms | |
from torchinfo import summary | |
from utils import StyleContentDataset, DataStore, denorm_img | |
from loss import Loss | |
from model import Model | |
# Training hyperparameters; also passed to wandb as the run config.
config = {
    "lr": 1e-4,                # Adam learning rate for the decoder
    "max_iter": 80000,         # total number of training iterations
    "logging_interval": 100,   # print/log losses every N iterations
    "preview_interval": 1000,  # checkpoint + render preview grid every N iterations
    "batch_size": 4,
    "activations": "ReLU",     # NOTE(review): informational only — not read anywhere below; confirm
    "optimizer": "Adam",       # informational; the optimizer is constructed explicitly in main()
    "lambda": 7                # style-loss weight, forwarded to Loss(lamb=...)
}
# Train on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using {device} device")
def prepare_data(style_dir, content_dir, preview_dir, preview_content_idx=8):
    """Build the training and preview data stores.

    Args:
        style_dir: directory containing style ``*.jpg`` images.
        content_dir: directory containing content ``*.jpg`` images.
        preview_dir: directory with ``style/`` and ``content/`` subfolders of ``*.jpg``.
        preview_content_idx: index of the single preview content image that is
            paired with every preview style image (default 8, preserving the
            previous hard-coded behavior; clamped to the available range).

    Returns:
        Tuple ``(datastore, preview_datastore)`` of DataStore objects.

    Raises:
        FileNotFoundError: if no preview content images are found.
    """
    # ImageNet normalization statistics — presumably matches the pretrained
    # encoder's expectations; confirm against the Model/encoder definition.
    norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # Training images: resize then random-crop for augmentation.
    transform = transforms.Compose([transforms.Resize(512), transforms.RandomCrop(256)])
    style_imgs = glob.glob(os.path.join(style_dir, '*.jpg'))
    content_imgs = glob.glob(os.path.join(content_dir, '*.jpg'))
    train_dataset = StyleContentDataset(style_imgs, content_imgs, transform=transform, normalize=norm)
    datastore = DataStore(train_dataset, batch_size=config['batch_size'], shuffle=True)

    # Preview images: deterministic center crop so previews are comparable across runs.
    transform = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(256)])
    preview_style_imgs = glob.glob(os.path.join(preview_dir, 'style/*.jpg'))
    preview_content_imgs = glob.glob(os.path.join(preview_dir, 'content/*.jpg'))
    if not preview_content_imgs:
        raise FileNotFoundError(f"No preview content images found under {os.path.join(preview_dir, 'content')}")
    # Clamp the index: the original unconditionally indexed [8], which raised
    # IndexError whenever fewer than 9 preview content images were present.
    idx = min(preview_content_idx, len(preview_content_imgs) - 1)
    # Pair one fixed content image with every preview style image.
    preview_dataset = StyleContentDataset(
        preview_style_imgs,
        [preview_content_imgs[idx]] * len(preview_style_imgs),
        transform=transform, normalize=norm)
    preview_datastore = DataStore(preview_dataset, batch_size=len(preview_dataset), shuffle=False)
    return datastore, preview_datastore
def preview(model: Model, datastore: DataStore, iteration, save=False, use_wandb=False):
    """Render a grid of (style, content, output) triples for the preview batch.

    Args:
        model: the style-transfer model; switched to eval mode for inference.
        datastore: preview DataStore whose batch covers the whole preview set.
        iteration: current training iteration, used in the output filename.
        save: when True, write the figure to ``outputs/{iteration}_preview.png``.
        use_wandb: when True (and ``save`` is True), log the saved image to wandb.
    """
    model.eval()
    with torch.no_grad():
        style, content = datastore.get()
        style, content = style.to(device), content.to(device)
        out = model(content, style)
    n_triples = len(out)
    # Two triples (6 axes) per row. The original hard-coded an 8x6 grid, which
    # raised IndexError for batches of more than 16 images and left unused axes
    # with default ticks for smaller batches.
    n_rows = max(1, (n_triples + 1) // 2)
    fig, axs = plt.subplots(n_rows, 6, figsize=(20, 3.25 * n_rows), squeeze=False)
    axs = axs.flatten()
    for j, (s, c, o) in enumerate(zip(style, content, out)):
        base = 3 * j
        for k, (img, title) in enumerate(((s, 'style'), (c, 'content'), (o, 'output'))):
            axs[base + k].imshow(denorm_img(img.cpu()).permute(1, 2, 0))
            axs[base + k].axis('off')
            axs[base + k].set_title(title)
    # Blank any leftover axes in the last row.
    for ax in axs[3 * n_triples:]:
        ax.axis('off')
    if save:
        fig.savefig(f'outputs/{iteration}_preview.png')
        if use_wandb:
            # wandb.Image reads the file just written above.
            wandb.log({'preview': wandb.Image(f'outputs/{iteration}_preview.png')}, step=iteration)
    # Always close the figure — the original leaked it whenever save=False.
    plt.close(fig)
def train_one_iter(datastore: DataStore, model: Model, optimizer: torch.optim.Adam, loss_fn: Loss):
    """Run one optimization step on a single (style, content) batch.

    Returns a 3-tuple of floats: (total loss, content loss, style loss).

    NOTE(review): this function depends on Model side effects — the forward
    pass appears to populate `model.activations` and `model.t`; confirm against
    the Model implementation before reordering any statements here.
    """
    model.train()
    style, content = datastore.get()
    style, content = style.to(device), content.to(device)
    optimizer.zero_grad()
    # Forward pass: stylize `content` with `style`. As a side effect this fills
    # model.activations with the encoder activations of the *style* image.
    out = model(content, style)
    # Snapshot the style activations before the second encoder pass below
    # overwrites model.activations in place.
    # NOTE(review): deepcopy of activation tensors may also copy autograd
    # graph references — verify Model detaches these, or memory may grow.
    style_activations = copy.deepcopy(model.activations)
    # Re-encode the stylized output; this overwrites model.activations with
    # the output's activations.
    enc_out = model.encoder(out)
    out_activations = model.activations
    # Compute combined loss; model.t is presumably the AdaIN target feature
    # from the forward pass — confirm in Model.
    loss = loss_fn(enc_out, model.t, out_activations, style_activations)
    # Backprop and update (only parameters the optimizer was built with —
    # the decoder, per main()).
    loss.backward()
    optimizer.step()
    # loss_fn stores its last component losses as attributes.
    return loss.item(), loss_fn.loss_c.item(), loss_fn.loss_s.item()
def train(datastore, preview_datastore, model: Model, optimizer: torch.optim.Adam, use_wandb=False):
    """Main training loop.

    Optimizes the model for config['max_iter'] iterations, printing/logging
    losses every logging_interval steps and writing a checkpoint plus preview
    grid every preview_interval steps.

    Returns:
        dict with per-iteration 'loss', 'style_loss', and 'content_loss' lists.
    """
    history = {'style_loss': [], 'content_loss': [], 'loss': []}
    criterion = Loss(lamb=config['lambda'])
    log_every = config['logging_interval']
    preview_every = config['preview_interval']

    for step in range(config['max_iter']):
        total, c_loss, s_loss = train_one_iter(datastore, model, optimizer, criterion)
        history['loss'].append(total)
        history['style_loss'].append(s_loss)
        history['content_loss'].append(c_loss)

        # Periodic console/wandb logging (includes step 0).
        if step % log_every == 0:
            print(f'iter: {step}')
            print(f'loss: {total:>5f}, style loss: {s_loss:>5f}, content loss: {c_loss:>5f}')
            print('-------------------------------')
            if use_wandb:
                wandb.log({
                    'iter': step, 'loss': total, 'style_loss': s_loss, 'content_loss': c_loss
                })

        # Periodic checkpoint + preview render (includes step 0).
        if step % preview_every == 0:
            snapshot = {
                'iter': step, 'model_state': model.state_dict(), 'optimizer_state': optimizer.state_dict()
            }
            torch.save(snapshot, 'outputs/checkpoint.pt')
            preview(model, preview_datastore, step, save=True, use_wandb=use_wandb)

    return history
def main():
    """CLI entry point: parse arguments, prepare data, train, save the model."""
    parser = argparse.ArgumentParser()
    parser.add_argument('--content_path', type=str, help='path to content dataset')
    # Fixed help text: previously this wrongly read 'path to content dataset'.
    parser.add_argument('--style_path', type=str, help='path to style dataset')
    parser.add_argument('--preview_path', type=str, help='path to preview dataset')
    parser.add_argument('--wandb', type=str, help='wandb id')
    parser.add_argument('--model_path', type=str, help='path to model')
    args = parser.parse_args()

    # Optional wandb experiment tracking.
    use_wandb = False
    wandb_key = args.wandb
    if wandb_key:
        wandb.login(key=wandb_key)
        wandb.init(project="assignment-3", name="", reinit=True, config=config)
        use_wandb = True

    # All three dataset paths are required.
    if args.content_path and args.style_path and args.preview_path:
        content_dir = args.content_path
        style_dir = args.style_path
        preview_dir = args.preview_path
    else:
        print('You didnt specify the data path >:(')
        return

    if not os.path.isdir('outputs'):
        os.mkdir('outputs')

    datastore, preview_datastore = prepare_data(style_dir, content_dir, preview_dir)
    model = Model()
    # Only the decoder is trained; the encoder stays frozen.
    optimizer = torch.optim.Adam(model.decoder.parameters(), lr=config['lr'])

    if args.model_path:
        # Resume from the checkpoint the user pointed at. The original loaded
        # the hard-coded 'outputs/checkpoint.pt' and ignored --model_path.
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        checkpoint = torch.load(args.model_path, map_location=device)
        model.load_state_dict(checkpoint['model_state'])
        optimizer.load_state_dict(checkpoint['optimizer_state'])
        # Shrink the remaining iteration budget by what was already trained.
        config['max_iter'] -= checkpoint['iter']

    model.to(device)
    train(datastore, preview_datastore, model, optimizer, use_wandb)
    torch.save(model.state_dict(), 'outputs/model.pt')

    if use_wandb:
        artifact = wandb.Artifact('model', type='model')
        artifact.add_file('outputs/model.pt')
        wandb.log_artifact(artifact)
        wandb.finish()

if __name__ == '__main__':
    main()