Files changed (2) hide show
  1. app.py +33 -17
  2. arguments.py +8 -10
app.py CHANGED
@@ -78,25 +78,36 @@ def setup_model(loaded_model_setup, prompt, model, seed, num_iterations, enable_
78
  args.save_all_images = True
79
 
80
  if enable_hps is True:
81
- args.disable_hps = False
82
  args.hps_weighting = hps_w
 
 
83
 
84
  if enable_imagereward is True:
85
- args.disable_imagereward = False
86
  args.imagereward_weighting = imgrw_w
 
 
87
 
88
  if enable_pickscore is True:
89
- args.disable_pickscore = False
90
  args.pickscore_weighting = pcks_w
 
 
91
 
92
  if enable_clip is True:
93
- args.disable_clip = False
94
  args.clip_weighting = clip_w
 
 
95
 
96
  if model == "flux":
97
  args.cpu_offloading = True
98
  args.enable_multi_apply = True
99
  args.multi_step_model = "flux"
 
 
 
100
 
101
  # Check if args are the same as the loaded_model_setup except for the prompt
102
  if loaded_model_setup and hasattr(loaded_model_setup[0], '__dict__'):
@@ -264,7 +275,12 @@ def combined_function(gallery_state, loaded_model_setup, prompt, chosen_model, s
264
 
265
  # Create Gradio interface
266
  title="# ReNO: Enhancing One-step Text-to-Image Models through Reward-based Noise Optimization"
267
- description="Enter a prompt to generate an image using ReNO. Adjust the model and parameters as needed."
 
 
 
 
 
268
 
269
  css="""
270
  #model-status-id{
@@ -299,28 +315,28 @@ with gr.Blocks(css=css, analytics_enabled=False) as demo:
299
  with gr.Column():
300
  prompt = gr.Textbox(label="Prompt")
301
  with gr.Row():
302
- chosen_model = gr.Dropdown(["sd-turbo", "sdxl-turbo", "pixart", "hyper-sd", "flux"], label="Model", value="sd-turbo")
303
  seed = gr.Number(label="seed", value=0)
304
 
305
  model_status = gr.Textbox(label="model status", visible=True, elem_id="model-status-id")
306
 
307
  with gr.Row():
308
- n_iter = gr.Slider(minimum=10, maximum=100, step=10, value=10, label="Number of Iterations")
309
- learning_rate = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=5.0, label="Learning Rate")
310
 
311
  with gr.Accordion("Advanced Settings", open=True):
312
  with gr.Column():
313
  with gr.Row():
314
- enable_hps = gr.Checkbox(label="HPS ON", value=False, scale=1)
315
  hps_w = gr.Slider(label="HPS weight", step=0.1, minimum=0.0, maximum=10.0, value=5.0, interactive=False, scale=3)
316
  with gr.Row():
317
- enable_imagereward = gr.Checkbox(label="ImageReward ON", value=False, scale=1)
318
  imgrw_w = gr.Slider(label="ImageReward weight", step=0.1, minimum=0, maximum=5.0, value=1.0, interactive=False, scale=3)
319
  with gr.Row():
320
- enable_pickscore = gr.Checkbox(label="PickScore ON", value=False, scale=1)
321
- pcks_w = gr.Slider(label="PickScore weight", step=0.01, minimum=0, maximum=5.0, value=0.05, interactive=False, scale=3)
322
  with gr.Row():
323
- enable_clip = gr.Checkbox(label="CLIP ON", value=False, scale=1)
324
  clip_w = gr.Slider(label="CLIP weight", step=0.01, minimum=0, maximum=0.1, value=0.01, interactive=False, scale=3)
325
 
326
  submit_btn = gr.Button("Submit")
@@ -328,11 +344,11 @@ with gr.Blocks(css=css, analytics_enabled=False) as demo:
328
  gr.Examples(
329
  examples = [
330
  "A red dog and a green cat",
331
- "A pink elephant and a grey cow",
332
- "A toaster riding a bike",
333
- "Dwayne Johnson depicted as a philosopher king in an academic painting by Greg Rutkowski",
334
  "A curious, orange fox and a fluffy, white rabbit, playing together in a lush, green meadow filled with yellow dandelions",
335
- "An epic oil painting: a red portal infront of a cityscape, a solitary figure, and a colorful sky over snowy mountains"
 
 
336
  ],
337
  inputs = [prompt]
338
  )
 
78
  args.save_all_images = True
79
 
80
  if enable_hps is True:
81
+ args.enable_hps = True
82
  args.hps_weighting = hps_w
83
+ else:
84
+ args.enable_hps = False
85
 
86
  if enable_imagereward is True:
87
+ args.enable_imagereward = True
88
  args.imagereward_weighting = imgrw_w
89
+ else:
90
+ args.enable_imagereward = False
91
 
92
  if enable_pickscore is True:
93
+ args.enable_pickscore = True
94
  args.pickscore_weighting = pcks_w
95
+ else:
96
+ args.enable_pickscore = False
97
 
98
  if enable_clip is True:
99
+ args.enable_clip = True
100
  args.clip_weighting = clip_w
101
+ else:
102
+ args.enable_clip = False
103
 
104
  if model == "flux":
105
  args.cpu_offloading = True
106
  args.enable_multi_apply = True
107
  args.multi_step_model = "flux"
108
+
109
+ if model == "hyper-sd":
110
+ args.cpu_offloading = True
111
 
112
  # Check if args are the same as the loaded_model_setup except for the prompt
113
  if loaded_model_setup and hasattr(loaded_model_setup[0], '__dict__'):
 
275
 
276
  # Create Gradio interface
277
  title="# ReNO: Enhancing One-step Text-to-Image Models through Reward-based Noise Optimization"
278
+ description = "Enter a prompt to generate an image using ReNO. The method enhances text-to-image generation by optimizing \
279
+ the initial noise using reward models as detailed in the paper. The demo uses a lower learning rate (2.5) compared to the paper's default (5.0) \
280
+ for smoother trajectories - if you are looking for more dramatic changes, you can increase this value. You can also \
281
+ adjust the reward weights to e.g. prioritize either prompt following (increase ImageReward) or aesthetic quality \
282
+ (increase HPS/PickScore) based on your preferences.\n\nThe first time you load this demo, it will take a bit \
283
+ to download and initialize the required model. Once loaded, each optimization run takes about 25-60 seconds."
284
 
285
  css="""
286
  #model-status-id{
 
315
  with gr.Column():
316
  prompt = gr.Textbox(label="Prompt")
317
  with gr.Row():
318
+ chosen_model = gr.Dropdown(["sd-turbo", "sdxl-turbo", "pixart", "hyper-sd", "flux"], label="Model", value="sdxl-turbo")
319
  seed = gr.Number(label="seed", value=0)
320
 
321
  model_status = gr.Textbox(label="model status", visible=True, elem_id="model-status-id")
322
 
323
  with gr.Row():
324
+ n_iter = gr.Slider(minimum=10, maximum=100, step=10, value=50, label="Number of Iterations")
325
+ learning_rate = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, value=2.5, label="Learning Rate")
326
 
327
  with gr.Accordion("Advanced Settings", open=True):
328
  with gr.Column():
329
  with gr.Row():
330
+ enable_hps = gr.Checkbox(label="HPS ON", value=True, scale=1)
331
  hps_w = gr.Slider(label="HPS weight", step=0.1, minimum=0.0, maximum=10.0, value=5.0, interactive=False, scale=3)
332
  with gr.Row():
333
+ enable_imagereward = gr.Checkbox(label="ImageReward ON", value=True, scale=1)
334
  imgrw_w = gr.Slider(label="ImageReward weight", step=0.1, minimum=0, maximum=5.0, value=1.0, interactive=False, scale=3)
335
  with gr.Row():
336
+ enable_pickscore = gr.Checkbox(label="PickScore ON", value=True, scale=1)
337
+ pcks_w = gr.Slider(label="PickScore weight", step=0.01, minimum=0, maximum=0.5, value=0.05, interactive=False, scale=3)
338
  with gr.Row():
339
+ enable_clip = gr.Checkbox(label="CLIP ON", value=True, scale=1)
340
  clip_w = gr.Slider(label="CLIP weight", step=0.01, minimum=0, maximum=0.1, value=0.01, interactive=False, scale=3)
341
 
342
  submit_btn = gr.Button("Submit")
 
344
  gr.Examples(
345
  examples = [
346
  "A red dog and a green cat",
347
+ "A blue scooter is parked near a curb in front of a green vintage car",
 
 
348
  "A curious, orange fox and a fluffy, white rabbit, playing together in a lush, green meadow filled with yellow dandelions",
349
+ "An orange chair to the right of a black airplane"
350
+ "A toaster riding a bike",
351
+ "A brain riding a rocketship towards the moon",
352
  ],
353
  inputs = [prompt]
354
  )
arguments.py CHANGED
@@ -39,16 +39,15 @@ def parse_args():
39
 
40
  # reward losses
41
  parser.add_argument(
42
- "--disable_hps", default=True, action="store_false", dest="enable_hps"
43
  )
44
  parser.add_argument(
45
  "--hps_weighting", type=float, help="Weighting for HPS", default=5.0
46
  )
47
  parser.add_argument(
48
- "--disable_imagereward",
49
- default=True,
50
- action="store_false",
51
- dest="enable_imagereward",
52
  )
53
  parser.add_argument(
54
  "--imagereward_weighting",
@@ -57,16 +56,15 @@ def parse_args():
57
  default=1.0,
58
  )
59
  parser.add_argument(
60
- "--disable_clip", default=True, action="store_false", dest="enable_clip"
61
  )
62
  parser.add_argument(
63
  "--clip_weighting", type=float, help="Weighting for CLIP", default=0.01
64
  )
65
  parser.add_argument(
66
- "--disable_pickscore",
67
- default=True,
68
- action="store_false",
69
- dest="enable_pickscore",
70
  )
71
  parser.add_argument(
72
  "--pickscore_weighting",
 
39
 
40
  # reward losses
41
  parser.add_argument(
42
+ "--enable_hps", default=False, action="store_true",
43
  )
44
  parser.add_argument(
45
  "--hps_weighting", type=float, help="Weighting for HPS", default=5.0
46
  )
47
  parser.add_argument(
48
+ "--enable_imagereward",
49
+ default=False,
50
+ action="store_true",
 
51
  )
52
  parser.add_argument(
53
  "--imagereward_weighting",
 
56
  default=1.0,
57
  )
58
  parser.add_argument(
59
+ "--enable_clip", default=False, action="store_true"
60
  )
61
  parser.add_argument(
62
  "--clip_weighting", type=float, help="Weighting for CLIP", default=0.01
63
  )
64
  parser.add_argument(
65
+ "--enable_pickscore",
66
+ default=False,
67
+ action="store_true",
 
68
  )
69
  parser.add_argument(
70
  "--pickscore_weighting",