try gradio state
Browse files
app.py
CHANGED
@@ -218,8 +218,37 @@ class WebApp():
|
|
218 |
print(f"Error catched: {e}")
|
219 |
gr.Markdown(f"**Error catched: {e}**")
|
220 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
221 |
def run_example(self, img, prompt, inv_model, spl_model, lora):
|
222 |
-
return self.
|
223 |
|
224 |
def show_credits(self):
|
225 |
# gr.Markdown(
|
@@ -240,6 +269,7 @@ class WebApp():
|
|
240 |
|
241 |
def ui(self):
|
242 |
with gr.Blocks(css='.input_image img {object-fit: contain;}', head=self.ga_script) as demo:
|
|
|
243 |
self.title()
|
244 |
with gr.Row():
|
245 |
# with gr.Column():
|
@@ -259,9 +289,12 @@ class WebApp():
|
|
259 |
# expected_output_image = gr.Image(label="expected output image", visible=False)
|
260 |
metadata = gr.JSON(label='metadata')
|
261 |
|
262 |
-
|
263 |
-
|
264 |
-
|
|
|
|
|
|
|
265 |
scroll_to_output=True,
|
266 |
)
|
267 |
|
@@ -271,7 +304,7 @@ class WebApp():
|
|
271 |
examples=[[os.path.join(os.path.dirname(__file__), "example", "Lenna.png"), 'a woman called Lenna wearing a feathered hat', list(BASE_MODEL.keys())[1], list(BASE_MODEL.keys())[2], 'none']],
|
272 |
inputs=[self.args_input['img'], self.args_input['pos_prompt'], self.args_input['inv_model'], self.args_input['spl_model'], gr.Textbox(label='LoRA', visible=False), ],
|
273 |
fn = self.run_example,
|
274 |
-
outputs=[output_image, metadata],
|
275 |
run_on_click=True,
|
276 |
cache_examples=cache_examples,
|
277 |
)
|
|
|
218 |
print(f"Error catched: {e}")
|
219 |
gr.Markdown(f"**Error catched: {e}**")
|
220 |
|
221 |
+
def run_ditail_alt(self, gr_args, *values):
|
222 |
+
gr_args = self.args_base.copy()
|
223 |
+
print(self.args_input.keys())
|
224 |
+
for k, v in zip(list(self.args_input.keys()), values):
|
225 |
+
gr_args[k] = v
|
226 |
+
# quick fix for example
|
227 |
+
gr_args['lora'] = 'none' if not isinstance(gr_args['lora'], str) else gr_state['lora']
|
228 |
+
print('selected lora: ', gr_args['lora'])
|
229 |
+
# map inversion model to url
|
230 |
+
gr_args['pos_prompt'] = ', '.join(LORA_TRIGGER_WORD.get(gr_args['lora'], [])+[gr_args['pos_prompt']])
|
231 |
+
gr_args['inv_model'] = BASE_MODEL[gr_args['inv_model']]
|
232 |
+
gr_args['spl_model'] = BASE_MODEL[gr_args['spl_model']]
|
233 |
+
print('selected model: ', gr_args['inv_model'], gr_args['spl_model'])
|
234 |
+
|
235 |
+
seed_everything(gr_args['seed'])
|
236 |
+
ditail = DitailDemo(gr_args)
|
237 |
+
|
238 |
+
metadata_to_show = ['inv_model', 'spl_model', 'lora', 'lora_scale', 'inv_steps', 'spl_steps', 'pos_prompt', 'alpha', 'neg_prompt', 'beta', 'omega']
|
239 |
+
self.args_to_show = {}
|
240 |
+
for key in metadata_to_show:
|
241 |
+
self.args_to_show[key] = gr_args[key]
|
242 |
+
|
243 |
+
img = ditail.run_ditail()
|
244 |
+
|
245 |
+
# reset ditail
|
246 |
+
ditail = None
|
247 |
+
|
248 |
+
return gr_args, img, self.args_to_show
|
249 |
+
|
250 |
def run_example(self, img, prompt, inv_model, spl_model, lora):
|
251 |
+
return self.run_ditail_alt(self.gr_state, img, prompt, spl_model, gr.State(lora), inv_model)
|
252 |
|
253 |
def show_credits(self):
|
254 |
# gr.Markdown(
|
|
|
269 |
|
270 |
def ui(self):
|
271 |
with gr.Blocks(css='.input_image img {object-fit: contain;}', head=self.ga_script) as demo:
|
272 |
+
|
273 |
self.title()
|
274 |
with gr.Row():
|
275 |
# with gr.Column():
|
|
|
289 |
# expected_output_image = gr.Image(label="expected output image", visible=False)
|
290 |
metadata = gr.JSON(label='metadata')
|
291 |
|
292 |
+
# init a gradio state
|
293 |
+
self.gr_state = gr.State()
|
294 |
+
|
295 |
+
submit_btn.click(self.run_ditail_alt,
|
296 |
+
inputs=[self.gr_state] + list(self.args_input.values()),
|
297 |
+
outputs=[self.gr_state, output_image, metadata],
|
298 |
scroll_to_output=True,
|
299 |
)
|
300 |
|
|
|
304 |
examples=[[os.path.join(os.path.dirname(__file__), "example", "Lenna.png"), 'a woman called Lenna wearing a feathered hat', list(BASE_MODEL.keys())[1], list(BASE_MODEL.keys())[2], 'none']],
|
305 |
inputs=[self.args_input['img'], self.args_input['pos_prompt'], self.args_input['inv_model'], self.args_input['spl_model'], gr.Textbox(label='LoRA', visible=False), ],
|
306 |
fn = self.run_example,
|
307 |
+
outputs=[self.gr_state, output_image, metadata],
|
308 |
run_on_click=True,
|
309 |
cache_examples=cache_examples,
|
310 |
)
|