Spaces:
Sleeping
Sleeping
File size: 4,278 Bytes
ca25718 dd8f929 b66671b dd8f929 ca25718 dd8f929 b66671b dd8f929 ca25718 dd8f929 b66671b dd8f929 ca25718 dd8f929 b66671b dd8f929 ca25718 dd8f929 ca25718 dd8f929 ca25718 dd8f929 ca25718 dd8f929 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import argparse
def parse_args(argv=None) -> argparse.Namespace:
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what a plain ``parse_args()`` call did before this parameter
            existed.

    Returns:
        The populated ``argparse.Namespace``. Note that three switches store
        under a destination that differs from (or inverts) the flag name:
        ``--disable_aesthetic`` -> ``enable_aesthetic``,
        ``--disable_reg`` -> ``enable_reg``, and ``--nesterov`` which *clears*
        ``nesterov`` when passed.
    """
    parser = argparse.ArgumentParser(description="Process Reward Optimization.")

    # update paths here!
    parser.add_argument(
        "--cache_dir",
        type=str,
        help="HF cache directory",
        default="/shared-local/aoq951/HF_CACHE/",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        help="Directory to save images",
        default="/shared-local/aoq951/ReNO/outputs",
    )

    # model and optim
    parser.add_argument("--model", type=str, help="Model to use", default="sdxl-turbo")
    parser.add_argument("--lr", type=float, help="Learning rate", default=5.0)
    parser.add_argument("--n_iters", type=int, help="Number of iterations", default=50)
    parser.add_argument(
        "--n_inference_steps",
        type=int,
        help="Number of inference steps",
        default=1,
    )
    parser.add_argument(
        "--optim",
        choices=["sgd", "adam", "lbfgs"],
        default="sgd",
        help="Optimizer to be used",
    )
    # NOTE: store_false with default=True means the value is True unless the
    # flag is given; passing --nesterov turns it OFF.
    parser.add_argument(
        "--nesterov",
        default=True,
        action="store_false",
        help="Pass this flag to set nesterov to False (it is True by default)",
    )
    parser.add_argument(
        "--grad_clip", type=float, help="Gradient clipping", default=0.1
    )
    parser.add_argument("--seed", type=int, help="Seed to use", default=0)

    # reward losses
    parser.add_argument(
        "--enable_hps",
        default=False,
        action="store_true",
        help="Enable HPS (off by default)",
    )
    parser.add_argument(
        "--hps_weighting", type=float, help="Weighting for HPS", default=5.0
    )
    parser.add_argument(
        "--enable_imagereward",
        default=False,
        action="store_true",
        help="Enable ImageReward (off by default)",
    )
    parser.add_argument(
        "--imagereward_weighting",
        type=float,
        help="Weighting for ImageReward",
        default=1.0,
    )
    parser.add_argument(
        "--enable_clip",
        default=False,
        action="store_true",
        help="Enable CLIP (off by default)",
    )
    parser.add_argument(
        "--clip_weighting", type=float, help="Weighting for CLIP", default=0.01
    )
    parser.add_argument(
        "--enable_pickscore",
        default=False,
        action="store_true",
        help="Enable PickScore (off by default)",
    )
    parser.add_argument(
        "--pickscore_weighting",
        type=float,
        help="Weighting for PickScore",
        default=0.05,
    )
    # NOTE(review): default=False combined with store_false means
    # args.enable_aesthetic is False whether or not the flag is passed, so
    # this switch is currently a no-op. Kept as-is to preserve behavior;
    # confirm whether the aesthetic term is meant to be switchable from the
    # command line.
    parser.add_argument(
        "--disable_aesthetic",
        default=False,
        action="store_false",
        dest="enable_aesthetic",
        help="Sets enable_aesthetic to False (which is already its default)",
    )
    parser.add_argument(
        "--aesthetic_weighting",
        type=float,
        help="Weighting for Aesthetic",
        default=0.0,
    )
    parser.add_argument(
        "--disable_reg",
        default=True,
        action="store_false",
        dest="enable_reg",
        help="Sets enable_reg to False (it is True by default)",
    )
    parser.add_argument(
        "--reg_weight", type=float, help="Regularization weight", default=0.01
    )

    # task specific
    parser.add_argument(
        "--task",
        type=str,
        help="Task to run",
        default="single",
        choices=[
            "t2i-compbench",
            "single",
            "parti-prompts",
            "geneval",
            "example-prompts",
        ],
    )
    parser.add_argument(
        "--prompt",
        type=str,
        help="Prompt to run",
        default="A red dog and a green cat",
    )
    parser.add_argument(
        "--benchmark_reward",
        help="Reward to benchmark on",
        default="total",
        choices=["ImageReward", "PickScore", "HPS", "CLIP", "total"],
    )

    # general
    parser.add_argument("--save_all_images", default=False, action="store_true")
    parser.add_argument("--no_optim", default=False, action="store_true")
    parser.add_argument("--imageselect", default=False, action="store_true")
    parser.add_argument("--memsave", default=False, action="store_true")
    parser.add_argument("--dtype", type=str, help="Data type to use", default="float16")
    parser.add_argument("--device_id", type=str, help="Device ID to use", default=None)
    parser.add_argument(
        "--cpu_offloading",
        help="Enable CPU offloading",
        default=False,
        action="store_true",
    )

    # optional multi-step model
    parser.add_argument("--enable_multi_apply", default=False, action="store_true")
    parser.add_argument(
        "--multi_step_model", type=str, help="Model to use", default="flux"
    )

    # argv=None makes argparse read sys.argv[1:], matching the old behavior.
    args = parser.parse_args(argv)
    return args