Spaces:
Sleeping
Sleeping
import numpy as np | |
import os, time, random | |
import argparse | |
import json | |
import torch.nn.functional as F | |
import torch | |
from torch.utils.data import Dataset, DataLoader | |
from torch.optim import lr_scheduler | |
from model.model import InvISPNet | |
from dataset.FiveK_dataset import FiveKDatasetTrain | |
from config.config import get_arguments | |
from utils.JPEG import DiffJPEG | |
os.system("nvidia-smi -q -d Memory |grep -A4 GPU|grep Free >tmp") | |
os.environ["CUDA_VISIBLE_DEVICES"] = str( | |
np.argmax([int(x.split()[2]) for x in open("tmp", "r").readlines()]) | |
) | |
# os.environ['CUDA_VISIBLE_DEVICES'] = "1" | |
os.system("rm tmp") | |
DiffJPEG = DiffJPEG(differentiable=True, quality=90).cuda() | |
parser = get_arguments() | |
parser.add_argument( | |
"--out_path", type=str, default="./exps/", help="Path to save checkpoint. " | |
) | |
parser.add_argument( | |
"--resume", dest="resume", action="store_true", help="Resume training. " | |
) | |
parser.add_argument( | |
"--loss", | |
type=str, | |
default="L1", | |
choices=["L1", "L2"], | |
help="Choose which loss function to use. ", | |
) | |
parser.add_argument("--lr", type=float, default=0.0001, help="Learning rate") | |
parser.add_argument( | |
"--aug", dest="aug", action="store_true", help="Use data augmentation." | |
) | |
args = parser.parse_args() | |
print("Parsed arguments: {}".format(args)) | |
os.makedirs(args.out_path, exist_ok=True) | |
os.makedirs(args.out_path + "%s" % args.task, exist_ok=True) | |
os.makedirs(args.out_path + "%s/checkpoint" % args.task, exist_ok=True) | |
with open(args.out_path + "%s/commandline_args.yaml" % args.task, "w") as f: | |
json.dump(args.__dict__, f, indent=2) | |
def main(args): | |
# ======================================define the model====================================== | |
net = InvISPNet(channel_in=3, channel_out=3, block_num=8) | |
net.cuda() | |
# load the pretrained weight if there exists one | |
if args.resume: | |
net.load_state_dict( | |
torch.load(args.out_path + "%s/checkpoint/latest.pth" % args.task) | |
) | |
print("[INFO] loaded " + args.out_path + "%s/checkpoint/latest.pth" % args.task) | |
optimizer = torch.optim.Adam(net.parameters(), lr=args.lr) | |
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[50, 80], gamma=0.5) | |
print("[INFO] Start data loading and preprocessing") | |
RAWDataset = FiveKDatasetTrain(opt=args) | |
dataloader = DataLoader( | |
RAWDataset, | |
batch_size=args.batch_size, | |
shuffle=True, | |
num_workers=0, | |
drop_last=True, | |
) | |
print("[INFO] Start to train") | |
step = 0 | |
for epoch in range(0, 300): | |
epoch_time = time.time() | |
for i_batch, sample_batched in enumerate(dataloader): | |
step_time = time.time() | |
input, target_rgb, target_raw = ( | |
sample_batched["input_raw"].cuda(), | |
sample_batched["target_rgb"].cuda(), | |
sample_batched["target_raw"].cuda(), | |
) | |
reconstruct_rgb = net(input) | |
reconstruct_rgb = torch.clamp(reconstruct_rgb, 0, 1) | |
rgb_loss = F.l1_loss(reconstruct_rgb, target_rgb) | |
reconstruct_rgb = DiffJPEG(reconstruct_rgb) | |
reconstruct_raw = net(reconstruct_rgb, rev=True) | |
raw_loss = F.l1_loss(reconstruct_raw, target_raw) | |
loss = args.rgb_weight * rgb_loss + raw_loss | |
optimizer.zero_grad() | |
loss.backward() | |
optimizer.step() | |
print( | |
"task: %s Epoch: %d Step: %d || loss: %.5f raw_loss: %.5f rgb_loss: %.5f || lr: %f time: %f" | |
% ( | |
args.task, | |
epoch, | |
step, | |
loss.detach().cpu().numpy(), | |
raw_loss.detach().cpu().numpy(), | |
rgb_loss.detach().cpu().numpy(), | |
optimizer.param_groups[0]["lr"], | |
time.time() - step_time, | |
) | |
) | |
step += 1 | |
torch.save( | |
net.state_dict(), args.out_path + "%s/checkpoint/latest.pth" % args.task | |
) | |
if (epoch + 1) % 10 == 0: | |
# os.makedirs(args.out_path+"%s/checkpoint/%04d"%(args.task,epoch), exist_ok=True) | |
torch.save( | |
net.state_dict(), | |
args.out_path + "%s/checkpoint/%04d.pth" % (args.task, epoch), | |
) | |
print( | |
"[INFO] Successfully saved " | |
+ args.out_path | |
+ "%s/checkpoint/%04d.pth" % (args.task, epoch) | |
) | |
scheduler.step() | |
print("[INFO] Epoch time: ", time.time() - epoch_time, "task: ", args.task) | |
if __name__ == "__main__": | |
torch.set_num_threads(4) | |
main(args) | |