Ricercar committed on
Commit 4466817
1 Parent(s): 781d7f4

try gradio state

Files changed (1): app.py (+38 -5)
app.py CHANGED
@@ -218,8 +218,37 @@ class WebApp():
             print(f"Error catched: {e}")
             gr.Markdown(f"**Error catched: {e}**")
 
+    def run_ditail_alt(self, gr_args, *values):
+        gr_args = self.args_base.copy()
+        print(self.args_input.keys())
+        for k, v in zip(list(self.args_input.keys()), values):
+            gr_args[k] = v
+        # quick fix for example
+        gr_args['lora'] = 'none' if not isinstance(gr_args['lora'], str) else gr_state['lora']
+        print('selected lora: ', gr_args['lora'])
+        # map inversion model to url
+        gr_args['pos_prompt'] = ', '.join(LORA_TRIGGER_WORD.get(gr_args['lora'], [])+[gr_args['pos_prompt']])
+        gr_args['inv_model'] = BASE_MODEL[gr_args['inv_model']]
+        gr_args['spl_model'] = BASE_MODEL[gr_args['spl_model']]
+        print('selected model: ', gr_args['inv_model'], gr_args['spl_model'])
+
+        seed_everything(gr_args['seed'])
+        ditail = DitailDemo(gr_args)
+
+        metadata_to_show = ['inv_model', 'spl_model', 'lora', 'lora_scale', 'inv_steps', 'spl_steps', 'pos_prompt', 'alpha', 'neg_prompt', 'beta', 'omega']
+        self.args_to_show = {}
+        for key in metadata_to_show:
+            self.args_to_show[key] = gr_args[key]
+
+        img = ditail.run_ditail()
+
+        # reset ditail
+        ditail = None
+
+        return gr_args, img, self.args_to_show
+
     def run_example(self, img, prompt, inv_model, spl_model, lora):
-        return self.run_ditail(img, prompt, spl_model, gr.State(lora), inv_model)
+        return self.run_ditail_alt(self.gr_state, img, prompt, spl_model, gr.State(lora), inv_model)
 
     def show_credits(self):
         # gr.Markdown(
@@ -240,6 +269,7 @@ class WebApp():
 
     def ui(self):
         with gr.Blocks(css='.input_image img {object-fit: contain;}', head=self.ga_script) as demo:
+
             self.title()
             with gr.Row():
                 # with gr.Column():
@@ -259,9 +289,12 @@ class WebApp():
                 # expected_output_image = gr.Image(label="expected output image", visible=False)
                 metadata = gr.JSON(label='metadata')
 
-            submit_btn.click(self.run_ditail,
-                inputs=list(self.args_input.values()),
-                outputs=[output_image, metadata],
+            # init a gradio state
+            self.gr_state = gr.State()
+
+            submit_btn.click(self.run_ditail_alt,
+                inputs=[self.gr_state] + list(self.args_input.values()),
+                outputs=[self.gr_state, output_image, metadata],
                 scroll_to_output=True,
             )
 
@@ -271,7 +304,7 @@ class WebApp():
                 examples=[[os.path.join(os.path.dirname(__file__), "example", "Lenna.png"), 'a woman called Lenna wearing a feathered hat', list(BASE_MODEL.keys())[1], list(BASE_MODEL.keys())[2], 'none']],
                 inputs=[self.args_input['img'], self.args_input['pos_prompt'], self.args_input['inv_model'], self.args_input['spl_model'], gr.Textbox(label='LoRA', visible=False), ],
                 fn = self.run_example,
-                outputs=[output_image, metadata],
+                outputs=[self.gr_state, output_image, metadata],
                 run_on_click=True,
                 cache_examples=cache_examples,
             )
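For context, the pattern this commit experiments with is Gradio's session state: a gr.State() component is created inside the Blocks context, listed as an input to the click handler, and returned as an output so the updated value persists across calls within a session. Below is a minimal, self-contained sketch of that wiring; the handler and component names are illustrative and are not taken from app.py.

# Minimal sketch of the gr.State wiring this commit tries (hypothetical
# handler and component names, not the actual app.py code).
import gradio as gr

def run(state, prompt):
    # gr.State defaults to None on the first call for a session
    history = state or []
    history.append(prompt)
    result = f"call #{len(history)} with prompt {prompt!r}"
    # return the updated state in the same position it holds in `outputs`
    return history, result

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="prompt")
    output = gr.Textbox(label="result")
    state = gr.State()  # per-session storage, analogous to self.gr_state

    btn = gr.Button("Run")
    btn.click(run, inputs=[state, prompt], outputs=[state, output])

if __name__ == "__main__":
    demo.launch()

Note that the handler returns the state value in the same position it occupies in outputs, which mirrors how run_ditail_alt returns gr_args ahead of the image and metadata in the diff above.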