import os
import io
import base64
import json
from typing import Callable, Any
from PIL import Image
import gradio as gr
from common import utils as posex
# Resolve this module's own file path. Some loaders exec the file without
# defining `__file__`; in that case recover the path from a code object
# (cf. https://stackoverflow.com/a/53293924). The fallback has to run before
# anything dereferences `__file__`, otherwise a missing `__file__` raises
# NameError and the fallback is never reached.
if '__file__' in globals():
    _module_file = __file__
else:
    import inspect
    _module_file = inspect.getfile(lambda: None)
project_dir = os.path.dirname(os.path.abspath(_module_file))
print(project_dir)
# Point posex's save directory at `saved_poses` next to this module.
posex.set_save_dir(os.path.join(project_dir, 'saved_poses'))
def js2py(
    name: str,
    id: Callable[[str], str],
    js: Callable[[str], str],
    sink: gr.components.IOComponent,
) -> gr.Textbox:
    """Create a hidden channel through which front-end JS pushes a value to Python.

    A click on the hidden ``<name>_set`` button (presumably triggered
    programmatically from JS, since it is invisible) runs ``js(name)``. Its
    outputs fill the returned textbox and a private "after" textbox. A change
    of the latter then runs ``js(f'{name}_after')`` with ``sink`` as output.

    Args:
        name: base name used to derive element ids and JS function names.
        id: maps a base name to the component ``elem_id``.
        js: maps a base name to the JS snippet passed as ``_js``.
        sink: component updated by the ``_after`` callback.

    Returns:
        The hidden textbox (``elem_id=id(name)``) that receives the JS value.
    """
    trigger = gr.Button(elem_id=id(f'{name}_set'), visible=False)
    value_box = gr.Textbox(elem_id=id(name), visible=False)
    after_box = gr.Textbox(visible=False)

    trigger.click(fn=None, _js=js(name), outputs=[value_box, after_box])
    after_box.change(fn=None, _js=js(f'{name}_after'), outputs=[sink])
    return value_box
def py2js(
    name: str,
    fn: Callable[[], str],
    id: Callable[[str], str],
    js: Callable[[str], str],
    sink: gr.components.IOComponent,
) -> None:
    """Create a hidden channel through which front-end JS pulls a value from Python.

    A click on the hidden ``<name>_get`` button calls ``fn`` through
    ``wrap_api``. The payload lands in one hidden textbox and the call counter
    in another. A change of the counter runs ``js(name)`` with the payload as
    input and ``sink`` as output.

    Args:
        name: base name used to derive the element id and JS function name.
        fn: zero-argument callable producing the value handed to JS.
        id: maps a base name to the component ``elem_id``.
        js: maps a base name to the JS snippet passed as ``_js``.
        sink: component updated by the JS callback.
    """
    fire_btn = gr.Button(elem_id=id(f'{name}_get'), visible=False)
    payload_box = gr.Textbox(visible=False)
    counter_box = gr.Textbox(visible=False)

    fire_btn.click(fn=wrap_api(fn), outputs=[payload_box, counter_box])
    counter_box.change(fn=None, _js=js(name), inputs=[payload_box], outputs=[sink])
def jscall(
    name: str,
    fn: Callable[[str], str],
    id: Callable[[str], str],
    js: Callable[[str], str],
    sink: gr.components.IOComponent,
) -> None:
    """Wire a JS -> Python -> JS round trip for a call that takes arguments.

    Argument channel: a click on ``<name>_args_set`` runs ``js(f'{name}_args')``.
    That fills the ``<name>_args`` JSON component plus a private JSON sink,
    whose change runs ``js(f'{name}_args_after')`` into ``sink``.

    Call channel: a click on ``<name>_get`` calls ``fn`` (via ``wrap_api``) with
    the arguments JSON value. When the call counter changes, the result is
    passed to ``js(name)`` with ``sink`` as output.

    Args:
        name: base name used to derive element ids and JS function names.
        fn: callable receiving the arguments JSON value.
        id: maps a base name to the component ``elem_id``.
        js: maps a base name to the JS snippet passed as ``_js``.
        sink: component updated by the JS callbacks.
    """
    # Argument channel (JS -> Python).
    args_trigger = gr.Button(elem_id=id(f'{name}_args_set'), visible=False)
    args_json = gr.JSON(elem_id=id(f'{name}_args'), visible=False)
    args_after = gr.JSON(visible=False)
    args_trigger.click(fn=None, _js=js(f'{name}_args'), outputs=[args_json, args_after])
    args_after.change(fn=None, _js=js(f'{name}_args_after'), outputs=[sink])

    # Call channel (Python -> JS).
    call_trigger = gr.Button(elem_id=id(f'{name}_get'), visible=False)
    result_box = gr.Textbox(visible=False)
    counter_box = gr.Textbox(visible=False)
    call_trigger.click(fn=wrap_api(fn), inputs=[args_json], outputs=[result_box, counter_box])
    counter_box.change(fn=None, _js=js(name), inputs=[result_box], outputs=[sink])
def generatecall(
    name: str,
    fn: Callable[..., Any],
    id: Callable[[str], str],
    js: Callable[[str], str],
    sink: gr.components.IOComponent,
    prompt,
    prompt_n,
    output_img,
) -> None:
    """Wire the hidden components that let the front end trigger generation.

    Argument channel: a click on ``<name>_args_set`` runs ``js(f'{name}_args')``.
    That fills the ``<name>_args`` JSON component plus a private JSON sink,
    whose change runs ``js(f'{name}_args_after')`` into ``sink``.

    Generate channel: a click on ``<name>_get`` calls
    ``fn(args, prompt, prompt_n)`` with the current values of the arguments
    JSON and the two prompt components. The result is written directly to
    ``output_img``. Unlike ``jscall``, ``fn`` is not wrapped with ``wrap_api``.

    Args:
        name: base name used to derive element ids and JS function names.
        fn: callable taking three positional values (arguments JSON, prompt,
            presumably negative prompt). Its return value goes to ``output_img``.
        id: maps a base name to the component ``elem_id``.
        js: maps a base name to the JS snippet passed as ``_js``.
        sink: component updated by the ``_args_after`` JS callback.
        prompt: component supplying ``fn``'s second argument.
        prompt_n: component supplying ``fn``'s third argument.
        output_img: component receiving ``fn``'s result.
    """
    # Argument channel (JS -> Python).
    v_args_set = gr.Button(elem_id=id(f'{name}_args_set'), visible=False)
    v_args = gr.JSON(elem_id=id(f'{name}_args'), visible=False)
    v_args_sink = gr.JSON(visible=False)
    v_args_set.click(fn=None, _js=js(f'{name}_args'), outputs=[v_args, v_args_sink])
    v_args_sink.change(fn=None, _js=js(f'{name}_args_after'), outputs=[sink])

    # Generate channel.
    v_fire = gr.Button(elem_id=id(f'{name}_get'), visible=False)
    # NOTE(review): nothing in this function writes to `v_sink`/`v_sink2` (the
    # click below outputs only to `output_img`), so the `v_sink2.change`
    # handler appears unreachable. Both components are kept to preserve the
    # existing component layout; confirm whether `js(name)` was meant to run
    # after generation.
    v_sink = gr.Textbox(visible=False)
    v_sink2 = gr.Textbox(visible=False)
    v_fire.click(fn=fn, inputs=[v_args, prompt, prompt_n], outputs=[output_img])
    v_sink2.change(fn=None, _js=js(name), inputs=[v_sink], outputs=[sink])
def get_self_extension():
    """Resolve the path of this module file.

    NOTE(review): the resolved path is stored in a local and then discarded,
    so this function always returns ``None``. Its name suggests it was meant
    to look up this extension's record; confirm intent before relying on it.
    """
    if '__file__' not in globals():
        # Some loaders exec this file without `__file__`; recover the path
        # from a code object instead (cf. https://stackoverflow.com/a/53293924).
        import inspect
        filepath = inspect.getfile(lambda: None)
    else:
        filepath = __file__
# APIs
def wrap_api(fn):
    """Wrap *fn* so every call also reports how many times the wrapper has run.

    The wrapper returns ``(result, str(call_count))``. The counter is private
    to each wrapper and starts at 1. It is bumped before ``fn`` runs, so a call
    that raises still consumes a number. In ``py2js``/``jscall`` the counter is
    routed to a hidden textbox whose ``change`` event fires the JS callback.
    A fresh value on each call presumably guarantees that event fires even when
    ``fn`` returns the same payload twice.
    """
    counter = [0]

    # The inner function keeps the name `f`: it becomes the wrapper's __name__.
    def f(*args, **kwargs):
        counter[0] += 1
        outcome = fn(*args, **kwargs)
        return outcome, str(counter[0])

    return f
def all_pose():
    """Return everything yielded by ``posex.all_poses()`` as a JSON array string."""
    poses = list(posex.all_poses())
    return json.dumps(poses)
def delete_pose(args):
    """Delete a saved pose via ``posex.delete_pose``.

    Args:
        args: JSON text decoding to an indexable value. Its first element is
            forwarded to ``posex.delete_pose`` (presumably the pose name;
            confirm against the JS caller).

    Returns:
        An empty string.
    """
    decoded = json.loads(args)
    posex.delete_pose(decoded[0])
    return ''
def save_pose(args):
    """Save a pose via ``posex.save_pose``.

    Args:
        args: JSON text decoding to an indexable value. Its first element is
            forwarded to ``posex.save_pose`` (its schema is defined by posex;
            not visible here).

    Returns:
        An empty string.
    """
    decoded = json.loads(args)
    posex.save_pose(decoded[0])
    return ''
def load_pose(args):
    """Load a pose via ``posex.load_pose`` and return it re-encoded as JSON text.

    Args:
        args: JSON text decoding to an indexable value. Its first element is
            forwarded to ``posex.load_pose``.
    """
    decoded = json.loads(args)
    pose = posex.load_pose(decoded[0])
    return json.dumps(pose)
# def get_imgs(args):
# return posex.get_img(args)
def generate_imgs(data, image_prompt, image_n_prompt):
    """Delegate to ``posex.generate_img`` with the arguments in the same order."""
    generated = posex.generate_img(data, image_prompt, image_n_prompt)
    return generated
def get_image_sketch(image_prompt, image_n_prompt, image):
    """Delegate to ``posex.get_image_sketch`` and return its result.

    Note the argument order differs between the two functions: this function
    takes ``image`` last, while posex receives it first.
    """
    sketch = posex.get_image_sketch(image, image_prompt, image_n_prompt)
    return sketch
def javascript_html():
    """Build the HTML snippet intended to load this extension's JS into the page head.

    Returns:
        The ``head`` string. As currently written, it is just two newline
        characters (see the review note below).

    Raises:
        OSError: if ``script.js`` or ``javascript/posex-webui.js`` is missing
            under ``project_dir`` (raised by ``os.path.getmtime``).
    """
    # Relative file path plus its modification time as a query string,
    # presumably for cache busting.
    script_js = f'script.js?{os.path.getmtime(os.path.join(project_dir,"script.js"))}'
    path7 = f'javascript/posex-webui.js?{os.path.getmtime(os.path.join(project_dir,"javascript/posex-webui.js"))}'
    # NOTE(review): `script_js` and `path7` are never used. The f-strings below
    # contain no placeholders, so the returned markup is only newlines. This
    # looks like <script> tags were stripped from the literals; restore them
    # before relying on this function.
    head = f'\n'
    head += f'\n'
    return head
def css_html():
    """Return the extra head markup for stylesheets, which is currently empty.

    NOTE(review): an empty string is returned unconditionally. If stylesheet
    tags were intended here, they are missing.
    """
    return ''
def reload_javascript():
js = javascript_html()
css = css_html()
def template_response(*args, **kwargs):
res = GradioTemplateResponseOriginal(*args, **kwargs)
res.body = res.body.replace(b'', f'{js}'.encode("utf8"))
res.body = res.body.replace(b'