|
import argparse |
|
import os |
|
import sys |
|
import time |
|
from typing import List, Optional |
|
|
|
import prettytable as pt |
|
import torch |
|
import yaml |
|
from termcolor import cprint |
|
|
|
|
|
def load_dataset_arguments(cfg_path: Optional[str], opt: argparse.Namespace) -> None:
    """Merge the options stored in a YAML config file into ``opt`` in place.

    The config file is ``cfg_path`` when given (``opt.load`` is then updated to
    point at it), otherwise ``opt.load``. Every top-level key of the YAML
    mapping is copied onto ``opt`` with ``setattr`` -- except keys that were
    passed explicitly on the command line (``--key value`` or ``--key=value``),
    so explicit CLI flags take precedence over the config file.

    Args:
        cfg_path: Optional path to a YAML config; overrides ``opt.load``.
        opt: Parsed argument namespace; modified in place.

    Raises:
        FileNotFoundError: If the selected config file does not exist.
        TypeError: If the YAML document is not a mapping.
    """
    if opt.load is None and cfg_path is None:
        return

    if cfg_path is not None:
        opt.load = cfg_path
    # Explicit raise instead of ``assert`` so the check survives ``python -O``,
    # and it now covers the ``cfg_path`` branch too.
    if not os.path.exists(opt.load):
        raise FileNotFoundError(f"Config file not found: {opt.load}")

    # Names of options given explicitly on the command line. Both the
    # "--key value" and "--key=value" spellings are recognised, and hyphenated
    # flags are also recorded in underscore form ("--lr-noise" -> "lr_noise"),
    # which is the attribute name argparse assigns.
    # NOTE(review): single-dash aliases such as "-lis" / "-dr" are not
    # recognised here, so a YAML value still overrides them.
    cli_keys = set()
    for token in sys.argv[1:]:
        if not token.startswith("--"):
            continue
        key = token[2:].split("=", 1)[0]
        if key:
            cli_keys.update((key, key.replace("-", "_")))

    with open(opt.load, "r", encoding="utf-8") as f:
        yaml_arguments = yaml.safe_load(f)

    # An empty YAML file parses to None; treat it as "nothing to merge".
    if yaml_arguments is None:
        yaml_arguments = {}
    if not isinstance(yaml_arguments, dict):
        raise TypeError(
            f"Config file {opt.load} must contain a mapping at the top level, "
            f"got {type(yaml_arguments).__name__}"
        )

    for k, v in yaml_arguments.items():
        if k not in cli_keys:
            setattr(opt, k, v)
|
|
|
|
|
# Options that, when given on the command line, are never encoded into the
# auto-generated experiment name. ``modality`` is handled separately through
# the ``modality=early`` tag.
_NAME_EXCLUDED_ARGS = frozenset(
    {
        "suffix",
        "save_root_path",
        "dataset",
        "source",
        "resume",
        "num_workers",
        "eval_freq",
        "print_freq",
        "lr_steps",
        "rgb_resume",
        "srm_resume",
        "bayar_resume",
        "teacher_resume",
        "occ",
        "load",
        "amp_opt_level",
        "val_shuffle",
        "tile_size",
        "modality",
    }
)

# Attributes left out of the printed option table.
_TABLE_HIDDEN_ARGS = frozenset(
    {"dir_name", "resume", "rgb_resume", "srm_resume", "bayar_resume"}
)

_SUPPORTED_MODALITIES = ("rgb", "srm", "bayar", "early")


def _warn_and_validate(opt: argparse.Namespace) -> None:
    """Print warnings for unusual settings and reject unsupported combinations."""
    # Each warning pauses for 3 s so it is not lost in the console output.
    if opt.decoder.lower() != "c1":
        cprint("Not supported yet! Check if the output use log_softmax!", "red")
        time.sleep(3)

    if opt.map_mask_weight > 0.0 or opt.volume_mask_weight > 0.0:
        cprint("Mask loss is not 0!", "red")
        time.sleep(3)

    if opt.val_set != "val":
        cprint(f"Evaluating on {opt.val_set} set!", "red")
        time.sleep(3)

    # Explicit raises instead of ``assert`` so validation survives ``python -O``.
    if opt.mvc_spixel and opt.loss_on_mid_map:
        raise ValueError("Middle map supervision is not supported with spixel!")

    if "early" in opt.modality and len(opt.modality) != 1:
        raise ValueError("Early fusion is not supported for multi-modality!")
    for modal in opt.modality:
        if modal not in _SUPPORTED_MODALITIES:
            raise ValueError(f"Unsupported modality {modal}!")

    if opt.resume and not os.path.exists(opt.resume):
        raise FileNotFoundError(f"Resume checkpoint not found: {opt.resume}")


def _experiment_tag(opt: argparse.Namespace) -> str:
    """Build the experiment-name fragment from options passed on the CLI.

    Each ``--key value`` / ``--key=value`` flag whose key is not in
    ``_NAME_EXCLUDED_ARGS`` contributes ``"<key without underscores>=<value>"``.
    The value is the final one on ``opt`` (after YAML merging), flattened for
    use in a directory name. Early-fusion runs are always tagged with
    ``modality=early``.
    """
    params = []
    for token in sys.argv[1:]:
        if not token.startswith("--"):
            continue
        # Compare only the key part, so "--resume=/some/path" is excluded just
        # like "--resume /some/path"; hyphenated flags map to argparse dests.
        key = token[2:].split("=", 1)[0].replace("-", "_")
        if not key or key in _NAME_EXCLUDED_ARGS:
            continue
        try:
            raw_value = getattr(opt, key)
        except AttributeError:
            cprint("Unknown argument: {}".format(token[2:]), "red")
            continue
        # Lists such as "[1.0, 2.0]" become "1.0-2.0".
        value = (
            str(raw_value)
            .replace("[", "")
            .replace("]", "")
            .replace(" ", "-")
            .replace(",", "")
        )
        params.append(key.replace("_", "") + "=" + value)

    if "early" in opt.modality:
        params.append("modality=early")
    return "_".join(params)


def _normalize_mvc_single_weight(opt: argparse.Namespace) -> None:
    """Replace ``opt.mvc_single_weight`` with a per-modality weight dict.

    With early fusion the result is ``{"early": 1.0}``. Otherwise the three
    list entries are read as RGB/SRM/Bayar weights, entries for modalities not
    in ``opt.modality`` are zeroed, and the rest are normalised to sum to 1.

    Raises:
        ValueError: If the list does not hold exactly 3 values, or if the
            weights of the selected modalities sum to zero.
    """
    if "early" in opt.modality:
        opt.mvc_single_weight = {"early": 1.0}
    else:
        # Copy so the list stored on ``opt`` is not mutated in place.
        weights = list(opt.mvc_single_weight)
        if len(weights) != 3:
            raise ValueError(
                "mvc_single_weight expects 3 values (RGB, SRM, Bayar), "
                f"got {opt.mvc_single_weight}"
            )
        for idx, modal in enumerate(("rgb", "srm", "bayar")):
            if modal not in opt.modality:
                weights[idx] = 0.0
        weight_sum = sum(weights)
        if weight_sum == 0:
            raise ValueError(
                "mvc_single_weight is zero for all selected modalities "
                f"{opt.modality}; cannot normalise."
            )
        rgb_w, srm_w, bayar_w = (w / weight_sum for w in weights)
        opt.mvc_single_weight = {"rgb": rgb_w, "srm": srm_w, "bayar": bayar_w}
    cprint(
        "Change mvc single modality weight to {}".format(opt.mvc_single_weight), "blue"
    )
    time.sleep(3)


def _print_opt_table(opt: argparse.Namespace) -> None:
    """Print all options (except ``_TABLE_HIDDEN_ARGS``) as a two-column table."""
    tb = pt.PrettyTable(field_names=["Arguments", "Values"])
    for k, v in vars(opt).items():
        if k not in _TABLE_HIDDEN_ARGS:
            tb.add_row([k, v])
    print(tb)


def get_opt(cfg_path: Optional[str] = None, additional_parsers: Optional[List] = None):
    """Parse, merge, validate and post-process all run options.

    Steps:
      1. Parse ``sys.argv`` with ``get_arguments_parser()`` plus any
         ``additional_parsers`` (used as argparse parents). Unknown flags are
         ignored (``parse_known_args``).
      2. Merge the YAML config (``cfg_path`` or ``--load``) through
         ``load_dataset_arguments``; explicit CLI flags win.
      3. Warn about unusual settings and reject unsupported ones.
      4. Set ``opt.time_stamp``, ``opt.dir_name`` and ``opt.device``.
      5. Force ``val_shuffle`` in debug/wholetest mode, zero ``mvc_weight``
         for single-modality runs, and turn ``mvc_single_weight`` into a
         normalised per-modality dict.
      6. Print the resulting options as a table.

    Args:
        cfg_path: Optional YAML config path overriding ``--load``.
        additional_parsers: Extra argparse parsers whose arguments are merged
            in. They must not define ``-h``, because the combined parser adds
            its own help option.

    Returns:
        The populated ``argparse.Namespace``.

    Raises:
        ValueError: On unsupported option combinations.
        FileNotFoundError: If the config file or ``--resume`` path is missing.
    """
    parents = [get_arguments_parser()]
    if additional_parsers:
        parents.extend(additional_parsers)
    parser = argparse.ArgumentParser(
        "Options for training and evaluation", parents=parents, allow_abbrev=False
    )
    opt = parser.parse_known_args()[0]

    load_dataset_arguments(cfg_path, opt)

    _warn_and_validate(opt)

    test_name = _experiment_tag(opt)
    time_stamp = time.strftime("%b-%d-%H-%M-%S", time.localtime())
    # ``train_datalist`` is presumably a dict supplied by the YAML config
    # (it is not defined by the argument parser); its keys prefix the name.
    # Empty fragments leave "__", which is collapsed to "_".
    dir_name = "{}_{}{}_{}".format(
        "-".join(opt.train_datalist.keys()).upper(),
        test_name,
        opt.suffix,
        time_stamp,
    ).replace("__", "_")

    opt.time_stamp = time_stamp
    opt.dir_name = dir_name
    opt.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    if opt.debug or opt.wholetest:
        opt.val_shuffle = True
        cprint("Setting val_shuffle to True in debug and wholetest mode!", "red")
        time.sleep(3)

    # Multi-view consistency needs at least two modalities.
    if len(opt.modality) < 2 and opt.mvc_weight != 0.0:
        opt.mvc_weight = 0.0
        cprint(
            "Setting multi-view consistency weight to 0. for single modality training",
            "red",
        )
        time.sleep(3)

    _normalize_mvc_single_weight(opt)

    _print_opt_table(opt)

    return opt
|
|
|
|
|
def get_arguments_parser():
    """Build the base argument parser holding every model/training option.

    The parser is created with ``add_help=False`` so it can be used as a
    parent of another ``argparse.ArgumentParser`` (which adds ``-h`` itself).
    Help strings that mention a default use ``%(default)s`` so they always
    match the actual default value.

    Returns:
        A configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        "CVPR2022 image manipulation detection model", add_help=False
    )
    parser.add_argument("--debug", action="store_true", default=False)
    parser.add_argument("--wholetest", action="store_true", default=False)

    parser.add_argument(
        "--load", default="configs/final.yaml", help="Load configuration YAML file."
    )
    parser.add_argument("--num_class", type=int, default=1, help="Use sigmoid.")

    # ---- Loss terms and their weights ----
    parser.add_argument("--map_label_weight", type=float, default=1.0)
    parser.add_argument("--volume_label_weight", type=float, default=1.0)
    parser.add_argument(
        "--map_mask_weight",
        type=float,
        default=0.0,
        help="Only use this for debug purpose.",
    )
    parser.add_argument(
        "--volume_mask_weight",
        type=float,
        default=0.0,
        help="Only use this for debug purpose.",
    )
    parser.add_argument(
        "--consistency_weight",
        type=float,
        default=0.0,
        help="Consistency between output map and volume within a single view.",
    )
    parser.add_argument(
        "--consistency_type", type=str, default="l2", choices=["l1", "l2"]
    )
    parser.add_argument(
        "--consistency_kmeans",
        action="store_true",
        default=False,
        help="Perform k-means on the volume to determine pristine and modified areas.",
    )
    parser.add_argument(
        "--consistency_stop_map_grad",
        action="store_true",
        default=False,
        help="Stop gradient for the map.",
    )
    parser.add_argument(
        "--consistency_source", type=str, default="self", choices=["self", "ensemble"]
    )
    parser.add_argument("--map_entropy_weight", type=float, default=0.0)
    parser.add_argument("--volume_entropy_weight", type=float, default=0.0)
    parser.add_argument("--mvc_weight", type=float, default=0.0)
    parser.add_argument(
        "--mvc_time_dependent",
        action="store_true",
        default=False,
        help="Use Gaussian smooth on the MVCW weight.",
    )
    parser.add_argument("--mvc_soft", action="store_true", default=False)
    parser.add_argument("--mvc_zeros_on_au", action="store_true", default=False)
    parser.add_argument(
        "--mvc_single_weight",
        type=float,
        nargs="+",
        default=[1.0, 1.0, 1.0],
        help="Weight for the RGB, SRM and Bayar modality for MVC training.",
    )
    parser.add_argument(
        "--mvc_steepness", type=float, default=5.0, help="The larger, the slower."
    )
    parser.add_argument("--mvc_spixel", action="store_true", default=False)
    parser.add_argument("--mvc_num_spixel", type=int, default=100)
    parser.add_argument(
        "--loss_on_mid_map",
        action="store_true",
        default=False,
        help="This only applies for the output map, but not for the consistency volume.",
    )
    parser.add_argument(
        "--label_loss_on_whole_map",
        action="store_true",
        default=False,
        help="Apply cls loss on the avg(map) for pristine images, instead of max(map).",
    )

    # ---- Input modalities and network architecture ----
    parser.add_argument("--modality", type=str, default=["rgb"], nargs="+")
    parser.add_argument("--srm_clip", type=float, default=5.0)
    parser.add_argument("--bayar_magnitude", type=float, default=1.0)
    parser.add_argument("--encoder", type=str, default="ResNet50")
    parser.add_argument("--encoder_weight", type=str, default="")
    parser.add_argument("--decoder", type=str, default="C1")
    parser.add_argument("--decoder_weight", type=str, default="")
    parser.add_argument(
        "--fc_dim",
        type=int,
        default=2048,
        help="Changing this might lead to errors in the conjunction between encoder and decoder.",
    )
    parser.add_argument(
        "--volume_block_idx",
        type=int,
        default=1,
        choices=[0, 1, 2, 3],
        help="Compute the consistency volume at certain block.",
    )
    parser.add_argument("--share_embed_head", action="store_true", default=False)
    parser.add_argument(
        "--fcn_up",
        type=int,
        default=32,
        choices=[8, 16, 32],
        help="FCN architecture, 32s, 16s, or 8s.",
    )
    parser.add_argument("--gem", action="store_true", default=False)
    parser.add_argument("--gem_coef", type=float, default=100)
    parser.add_argument("--gsm", action="store_true", default=False)
    parser.add_argument(
        "--map_portion",
        type=float,
        default=0,
        help="Select topk portion of the output map for the image-level classification. 0 means use max.",
    )
    parser.add_argument("--otsu_sel", action="store_true", default=False)
    parser.add_argument("--otsu_portion", type=float, default=1.0)

    # ---- Data augmentation, training loop and evaluation ----
    parser.add_argument("--no_gaussian_blur", action="store_true", default=False)
    parser.add_argument("--no_color_jitter", action="store_true", default=False)
    parser.add_argument("--no_jpeg_compression", action="store_true", default=False)
    parser.add_argument("--resize_aug", action="store_true", default=False)
    parser.add_argument(
        "--uncorrect_label",
        action="store_true",
        default=False,
        help="This will not correct image-level labels caused by image cropping.",
    )
    parser.add_argument("--input_size", type=int, default=224)
    parser.add_argument("--dropout", type=float, default=0.0)
    parser.add_argument(
        "--optimizer", type=str, default="adamw", choices=["sgd", "adamw"]
    )
    parser.add_argument("--resume", type=str, default="")
    parser.add_argument("--eval", action="store_true", default=False)
    parser.add_argument(
        "--val_set",
        type=str,
        default="val",
        choices=["train", "val"],
        help="Change to train for debug purpose.",
    )
    parser.add_argument(
        "--val_shuffle", action="store_true", default=False, help="Shuffle val set."
    )
    parser.add_argument("--save_figure", action="store_true", default=False)
    parser.add_argument("--figure_path", type=str, default="figures")
    parser.add_argument("--batch_size", type=int, default=36)
    parser.add_argument("--epochs", type=int, default=60)
    parser.add_argument("--eval_freq", type=int, default=3)
    parser.add_argument("--weight_decay", type=float, default=5e-4)
    parser.add_argument("--num_workers", type=int, default=36)
    parser.add_argument("--grad_clip", type=float, default=0.0)

    # ---- Learning-rate schedule ----
    # Help texts use %(default)s so the documented default cannot drift from
    # the real one (several previously stated wrong values).
    parser.add_argument(
        "--sched",
        default="cosine",
        type=str,
        metavar="SCHEDULER",
        help='LR scheduler (default: "%(default)s")',
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=1e-4,
        metavar="LR",
        help="learning rate (default: %(default)s)",
    )
    parser.add_argument(
        "--lr-noise",
        type=float,
        nargs="+",
        default=None,
        metavar="pct, pct",
        help="learning rate noise on/off epoch percentages",
    )
    parser.add_argument(
        "--lr-noise-pct",
        type=float,
        default=0.67,
        metavar="PERCENT",
        help="learning rate noise limit percent (default: %(default)s)",
    )
    parser.add_argument(
        "--lr-noise-std",
        type=float,
        default=1.0,
        metavar="STDDEV",
        help="learning rate noise std-dev (default: %(default)s)",
    )
    parser.add_argument(
        "--warmup-lr",
        type=float,
        default=2e-7,
        metavar="LR",
        help="warmup learning rate (default: %(default)s)",
    )
    parser.add_argument(
        "--min-lr",
        type=float,
        default=2e-6,
        metavar="LR",
        help="lower lr bound for cyclic schedulers that hit 0 (default: %(default)s)",
    )
    parser.add_argument(
        "--decay-epochs",
        type=float,
        default=20,
        metavar="N",
        help="epoch interval to decay LR",
    )
    parser.add_argument(
        "--warmup-epochs",
        type=int,
        default=5,
        metavar="N",
        help="epochs to warmup LR, if scheduler supports",
    )
    parser.add_argument(
        "--cooldown-epochs",
        type=int,
        default=5,
        metavar="N",
        help="epochs to cooldown LR at min_lr, after cyclic schedule ends",
    )
    parser.add_argument(
        "--patience-epochs",
        type=int,
        default=5,
        metavar="N",
        help="patience epochs for Plateau LR scheduler (default: %(default)s)",
    )
    parser.add_argument(
        "--decay-rate",
        "-dr",
        type=float,
        default=0.5,
        metavar="RATE",
        help="LR decay rate (default: %(default)s)",
    )
    parser.add_argument("--lr_cycle_limit", "-lcl", type=int, default=1)
    parser.add_argument("--lr_cycle_mul", "-lcm", type=float, default=1)

    # ---- Inference and post-processing ----
    parser.add_argument("--mask_threshold", type=float, default=0.5)
    parser.add_argument(
        "-lis",
        "--large_image_strategy",
        choices=["rescale", "slide", "none"],
        default="slide",
        help="Slide will get better performance than rescale.",
    )
    parser.add_argument(
        "--tile_size",
        type=int,
        default=768,
        help="If the testing image is larger than tile_size, sliding-window inference is used.",
    )
    parser.add_argument("--tile_overlap", type=float, default=0.1)
    parser.add_argument("--spixel_postproc", action="store_true", default=False)
    parser.add_argument("--convcrf_postproc", action="store_true", default=False)
    parser.add_argument("--convcrf_shape", type=int, default=512)
    parser.add_argument("--crf_postproc", action="store_true", default=False)
    parser.add_argument("--max_pool_postproc", type=int, default=1)
    parser.add_argument("--crf_downsample", type=int, default=1)
    parser.add_argument("--crf_iter_max", type=int, default=5)
    parser.add_argument("--crf_pos_w", type=int, default=3)
    parser.add_argument("--crf_pos_xy_std", type=int, default=1)
    parser.add_argument("--crf_bi_w", type=int, default=4)
    parser.add_argument("--crf_bi_xy_std", type=int, default=67)
    parser.add_argument("--crf_bi_rgb_std", type=int, default=3)

    # ---- Output and logging ----
    parser.add_argument("--save_root_path", type=str, default="tmp")
    parser.add_argument("--suffix", type=str, default="")
    parser.add_argument("--print_freq", type=int, default=100)

    # ---- Reproducibility ----
    parser.add_argument("--seed", type=int, default=1)

    return parser
|
|