|
|
|
|
|
import argparse |
|
import os, shutil, sys |
|
import time |
|
import warnings |
|
|
|
# Suppress every warning for the whole process. Note this also hides
# warnings that could point at real problems during training.
warnings.filterwarnings("ignore")

# Make modules in the current working directory importable (``opt`` below and,
# presumably, the ``train_*`` modules that ``process`` imports lazily).
# This resolves against the CWD, not this file's location, so the script
# presumably has to be launched from the project root -- the relative paths
# used elsewhere ("runs/", "saved_models/", "datasets/") assume the same.
# NOTE(review): if opt.py lives next to this script,
# os.path.dirname(os.path.abspath(__file__)) would be location-independent;
# confirm the repo layout before changing.
root_path = os.path.abspath('.')
sys.path.append(root_path)

# Options mapping handed to process(); it must provide an 'architecture' key.
from opt import opt
|
|
|
|
|
def storage_manage(src="runs/", dst_root="runs_last/"):
    """Archive a copy of the run directory before it gets wiped.

    Copies *src* into a new subdirectory of *dst_root* named after the
    current Unix timestamp (e.g. ``runs_last/1700000000``). If that name is
    already taken (two archives within the same second), a numeric suffix
    is appended (``1700000000_1``, ``1700000000_2``, ...) instead of letting
    ``shutil.copytree`` fail with ``FileExistsError``.

    Args:
        src: directory to archive; defaults to ``"runs/"``.
        dst_root: parent directory holding the archives; defaults to
            ``"runs_last/"``. Created if missing.

    Returns:
        The path of the newly created archive directory.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(dst_root, exist_ok=True)

    base_address = os.path.join(dst_root, str(int(time.time())))
    new_address = base_address
    suffix = 1
    # Disambiguate same-second archives rather than crashing in copytree.
    while os.path.exists(new_address):
        new_address = "{}_{}".format(base_address, suffix)
        suffix += 1

    shutil.copytree(src, new_address)
    return new_address
|
|
|
|
|
def parse_args(argv=None):
    """Parse command-line flags into the module-level ``args`` and prepare ``./runs``.

    Flags:
        --auto_resume_closest / --auto_resume_best: resume modes; at most one
            may be given (argparse rejects both together with exit status 2).
        --pretrained_path: string, defaults to ``""``.

    When neither resume flag is set, an existing ``./runs`` directory is
    archived via ``storage_manage()`` and then deleted, so a fresh run
    starts from an empty directory.

    Args:
        argv: argument list to parse; ``None`` (default) reads ``sys.argv``.

    Returns:
        The parsed namespace (also stored in the global ``args``).
    """
    parser = argparse.ArgumentParser()
    # The two resume modes are mutually exclusive. Letting argparse enforce
    # this prints usage to stderr and exits with status 2, instead of the
    # previous os._exit(0), which signalled success and skipped flushing
    # stdout (so the error message could be lost when piped).
    resume_mode = parser.add_mutually_exclusive_group()
    resume_mode.add_argument('--auto_resume_closest', action='store_true')
    resume_mode.add_argument('--auto_resume_best', action='store_true')
    parser.add_argument('--pretrained_path', type=str, default="")

    global args
    args = parser.parse_args(argv)

    # Not resuming: keep a copy of the previous run's outputs, then clear them.
    if not args.auto_resume_closest and not args.auto_resume_best:
        if os.path.exists("./runs"):
            storage_manage()
            shutil.rmtree("./runs")

    return args
|
|
|
|
|
def folder_prepare():
    """Ensure the working directories used for training exist.

    ``saved_models/``, ``saved_models/checkpoints/`` and ``datasets/`` are
    created when absent; anything already at those paths is left untouched.
    A second group of folders is wiped and recreated on every call, but that
    group is currently empty.
    """
    keep_dirs = ("saved_models/", "saved_models/checkpoints/", "datasets/")
    fresh_dirs = ()

    for path in keep_dirs:
        if os.path.exists(path):
            continue
        os.makedirs(path)

    for path in fresh_dirs:
        # Remove previous contents so the folder starts out empty.
        if os.path.exists(path):
            shutil.rmtree(path)
        os.makedirs(path)
|
|
|
|
|
|
|
def _format_duration(seconds):
    """Render *seconds* as ``"<h> hour <m> min <s> s"`` using whole units."""
    hours, remainder = divmod(int(seconds), 3600)
    minutes, secs = divmod(remainder, 60)
    return "{} hour {} min {} s".format(hours, minutes, secs)


# Supported options['architecture'] values -> trainer module name. Each
# module exposes a callable with the same name as the module.
_TRAINER_MODULES = {
    "ESRNET": "train_esrnet",
    "ESRGAN": "train_esrgan",
    "GRL": "train_grl",
    "GRLGAN": "train_grlgan",
    "CUNET": "train_cunet",
    "CUGAN": "train_cugan",
}


def process(options):
    """Build the trainer selected by ``options['architecture']`` and run it.

    Prints the module-level ``args`` (set by ``parse_args``) first, then
    constructs the trainer with ``(options, args)``, calls its ``run()``,
    and finally prints the total wall-clock time.

    Args:
        options: mapping that must contain an ``'architecture'`` key whose
            value is one of the keys of ``_TRAINER_MODULES``.

    Raises:
        NotImplementedError: if the architecture is not supported.
    """
    import importlib

    print(args)
    start = time.time()

    module_name = _TRAINER_MODULES.get(options['architecture'])
    if module_name is None:
        raise NotImplementedError("This is not a supported model architecture")

    # Imported lazily so only the chosen architecture's dependencies load.
    trainer_cls = getattr(importlib.import_module(module_name), module_name)
    obj = trainer_cls(options, args)

    obj.run()

    total_time = time.time() - start
    # Previously the seconds field used total_time % 3600 (the whole
    # sub-hour remainder) instead of % 60, and all fields were floats.
    print("All programs spent {}".format(_format_duration(total_time)))
|
|
|
|
|
def main():
    """Run the launcher: parse flags, prepare folders, then train.

    The order matters: ``parse_args`` sets the module-level ``args`` that
    ``process`` reads, and may archive/remove ``./runs`` beforehand.
    """
    setup_steps = (parse_args, folder_prepare)
    for step in setup_steps:
        step()
    process(opt)


if __name__ == "__main__":
    main()