Spaces:
Runtime error
Runtime error
import imp | |
import gradio as gr | |
import io | |
from PIL import Image, PngImagePlugin | |
import base64 | |
import requests | |
import json | |
import ui_functions as uifn | |
from css_and_js import js, call_JS | |
# Default settings for the text-to-image tab.
# In this file only 'submit_on_enter' is read (it chooses whether the img2img
# prompt box is single-line); the remaining keys appear unused here.
txt2img_defaults = dict(
    prompt='',
    ddim_steps=50,
    toggles=[1, 2, 3],
    sampler_name='k_lms',
    ddim_eta=0.0,
    n_iter=1,
    batch_size=1,
    cfg_scale=7.5,
    seed='',
    height=512,
    width=512,
    fp=None,
    variant_amount=0.0,
    variant_seed='',
    submit_on_enter='Yes',
)
# Initial values for the image-to-image controls.
# 'toggles' holds indices into img2img_toggles; 'resize_mode' indexes
# img2img_resize_modes. 'mask_mode' is only referenced by commented-out code.
img2img_defaults = dict(
    prompt='',
    ddim_steps=50,
    toggles=[1, 4, 5],
    sampler_name='k_lms',
    ddim_eta=0.0,
    n_iter=1,
    batch_size=1,
    cfg_scale=5.0,
    denoising_strength=0.75,
    mask_mode=1,
    resize_mode=0,
    seed='',
    height=512,
    width=512,
    fp=None,
)
# Optional integrations, all disabled or stubbed in this deployment:
#   sample_img2img  - initial value of both img2img image inputs
#   job_manager     - when falsy, the img2img handlers are not wrapped in a job UI
sample_img2img = job_manager = None
# RealESRGAN is only compared against None, to decide whether the model
# dropdown is visible.
RealESRGAN = True
# Visibility of the textual-inversion embeddings upload.
show_embeddings = False

# Labels for the (hidden) img2img resize-mode radio.
img2img_resize_modes = ["Just resize", "Crop and resize", "Resize and fill"]

# Labels for the (hidden) img2img option checkboxes.
img2img_toggles = [
    "Create prompt matrix (separate multiple prompts using |, and get all combinations of them)",
    "Normalize Prompt Weights (ensure sum of weights add up to 1.0)",
    "Loopback (use images from previous batch when creating next batch)",
    "Random loopback seed",
    "Save individual images",
    "Save grid",
    "Sort samples by prompt",
    "Write sample info files",
    "Write sample info to one file",
    "jpg samples",
]
# Checkbox labels pre-selected in the img2img options group: the entries of
# img2img_toggles at the positions listed in img2img_defaults['toggles'].
img2img_toggle_defaults = list(map(img2img_toggles.__getitem__, img2img_defaults['toggles']))
def read_content(file_path: str) -> str:
    """Return the entire text of ``file_path``, decoded as UTF-8.

    Used to load the page stylesheet and the header/footer HTML snippets.
    """
    with open(file_path, mode='r', encoding='utf-8') as handle:
        return handle.read()
def base2picture(resbase64):
    """Decode a base64-encoded image string into a PIL image.

    ``resbase64`` is normally a data URL such as
    ``"data:image/png;base64,<payload>"`` (the format ``encode_pil_to_base64``
    produces); everything up to the first comma is treated as the header and
    discarded. A bare base64 payload with no comma is accepted as well, where
    the previous ``split(',')[1]`` raised ``IndexError``.

    Returns the object from ``PIL.Image.open``. Malformed base64 may raise
    ``binascii.Error``; bytes that are not a recognisable image make
    ``Image.open`` raise.
    """
    header, sep, payload = resbase64.partition(',')
    if not sep:
        # No data-URL header: the whole string is the base64 payload.
        payload = header
    img_bytes = base64.b64decode(payload)
    return Image.open(io.BytesIO(img_bytes))
def filter_content(raw_style: str):
    """Strip the parenthesised gloss from a UI option label.

    Labels look like ``"国画(traditional Chinese painting)"``; the text before
    the first ``"("`` is returned (``"国画"``). A label without ``"("`` is
    returned unchanged.
    """
    if "(" not in raw_style:
        return raw_style
    cut = raw_style.index("(")
    return raw_style[:cut]
def request_images(raw_text, class_draw="通用(general)", style_draw=(), batch_size=1, sr_option=False):
    """Send a text-to-image request to the FlagArt backend and decode the results.

    The prompt is augmented from the UI selections before sending:

    * ``class_draw`` is a category label; only the text before ``"("`` is used
      (see ``filter_content``).
    * Category ``"国画"``: ``",国画"`` is appended unless the prompt already
      ends with ``"国画"``, style checkboxes are ignored, and the request goes
      to ``/api/guohua/``.
    * Any other category: the category keyword is appended (except for
      ``"通用"``), followed by each selected style keyword, and the request
      goes to ``/api/general/``.

    The extra parameters have defaults so the function also works when called
    with the prompt alone, as ``gr.Examples(inputs=text, ...)`` does.

    Args:
        raw_text: prompt typed by the user.
        class_draw: category dropdown label.
        style_draw: iterable of style checkbox labels (None is treated as empty).
        batch_size: number of images requested; forwarded to the server as given.
        sr_option: super-resolution flag, forwarded to the server.

    Returns:
        A list of images decoded with ``base2picture``, at most ``batch_size``
        long (fewer if the server returned fewer).

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    category = filter_content(class_draw or "通用(general)")
    if category == "国画":
        if not raw_text.endswith("国画"):
            raw_text = raw_text + ",国画"
        url = "http://flagart.baai.ac.cn/api/guohua/"
    else:
        if category != "通用":
            raw_text = raw_text + f",{category}"
        for sty in style_draw or ():
            raw_text = raw_text + f",{filter_content(sty)}"
        print(f"raw text is {raw_text}")
        url = "http://flagart.baai.ac.cn/api/general/"

    d = {"data": [raw_text, batch_size, sr_option]}
    # Accept-Encoding is deliberately left to requests. It only advertises
    # "br" when a Brotli decoder is installed, whereas forcing it could yield
    # an undecodable response body.
    r = requests.post(
        url,
        json=d,
        headers={"Content-Type": "application/json", "Accept": "*/*", "Connection": "keep-alive"},
    )
    r.raise_for_status()
    # The response's "data"[0] is expected to be a list of encoded images.
    content = r.json()["data"][0]
    # int(): gradio sliders may deliver floats. Slicing tolerates short replies.
    return [base2picture(item) for item in content[:int(batch_size)]]
def encode_pil_to_base64(pil_image):
    """Serialise a PIL image as a PNG data URL (``data:image/png;base64,...``).

    Entries of ``pil_image.info`` whose key and value are both ``str`` are
    carried over as PNG text metadata; all other info entries are dropped. If
    there are none, the PNG is written without a ``pnginfo`` argument.
    """
    text_items = [
        (key, value)
        for key, value in pil_image.info.items()
        if isinstance(key, str) and isinstance(value, str)
    ]

    png_info = None
    if text_items:
        png_info = PngImagePlugin.PngInfo()
        for key, value in text_items:
            png_info.add_text(key, value)

    with io.BytesIO() as buffer:
        pil_image.save(buffer, "PNG", pnginfo=png_info)
        raw_png = buffer.getvalue()

    return "data:image/png;base64," + base64.b64encode(raw_png).decode("utf-8")
def img2img(*args):
    """Forward the img2img form values to the FlagArt backend and decode the results.

    ``args`` are the component values in the order of ``img2img_inputs``;
    ``args[8]`` is the batch count. Each dict argument (the sketch-tool image
    input, holding ``'image'`` and ``'mask'`` PIL images) is sent as a copy
    whose non-None images are PNG data URLs, so the payload is
    JSON-serialisable. The caller's dict is no longer modified in place.

    Returns:
        A list of images decoded with ``base2picture``, at most ``args[8]``
        long (fewer if the server returned fewer).

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    payload_args = []
    for value in args:
        if isinstance(value, dict):
            # Copy so the caller's dict keeps its original PIL images.
            value = dict(value)
            for key in ("image", "mask"):
                if value.get(key) is not None:
                    value[key] = encode_pil_to_base64(value[key])
        payload_args.append(value)

    # int(): gradio sliders may deliver floats.
    batch_size = int(args[8])
    url = "http://flagart.baai.ac.cn/api/img2img/"
    # Accept-Encoding is left to requests, which only advertises "br" when it
    # can decode Brotli.
    r = requests.post(
        url,
        json={"data": payload_args},
        headers={"Content-Type": "application/json", "Accept": "*/*", "Connection": "keep-alive"},
    )
    r.raise_for_status()
    # The response's "data"[0] is expected to be a list of encoded images.
    content = r.json()["data"][0]
    return [base2picture(item) for item in content[:batch_size]]
# Example prompts listed under the text-to-image tab (passed to gr.Examples):
# Chinese prompts first, then English and mixed Chinese/English ones.
examples = [
    "水墨蝴蝶和牡丹花,国画",
    "苍劲有力的墨竹,国画",
    "暴风雨中的灯塔",
    "机械小松鼠,科学幻想",
    "中国水墨山水画,国画",
] + [
    "Lighthouse in the storm",
    "A dog",
    "Landscape by 张大千",
    "A tiger 长了兔子耳朵",
    "A baby bird 铅笔素描",
]
# Script entry point: build the two-tab Gradio UI (text-to-image and
# image-to-image), wire the event handlers, and launch it with a request queue.
# NOTE(review): the source this was recovered from had lost all indentation;
# the layout nesting below is reconstructed from the `with` structure and
# should be confirmed against the upstream file.
# NOTE(review): `.style(...)`, `gr.Box`, `_js=` and `queue(concurrency_count=...)`
# look like Gradio 3.x APIs; confirm the pinned gradio version provides them.
if __name__ == "__main__":
    block = gr.Blocks(css=read_content('style.css'))
    with block:
        gr.HTML(read_content("header.html"))
        with gr.Tabs(elem_id='tabss') as tabs:
            # ---------------- Text-to-image tab ----------------
            with gr.TabItem("文生图(Text-to-img)", id='txt2img_tab'):
                with gr.Group():
                    with gr.Box():
                        # Prompt box and Generate button on a single row.
                        with gr.Row().style(mobile_collapse=False, equal_height=True):
                            text = gr.Textbox(
                                label="Prompt",
                                show_label=False,
                                max_lines=1,
                                placeholder="Input text(输入文字)",
                                interactive=True,
                            ).style(
                                border=(True, False, True, True),
                                rounded=(True, False, False, True),
                                container=False,
                            )
                            btn = gr.Button("Generate image").style(
                                margin=False,
                                rounded=(True, True, True, True),
                            )
                    # Generation category; request_images uses the text before "("
                    # to pick the endpoint and the keyword appended to the prompt.
                    with gr.Row().style(mobile_collapse=False, equal_height=True):
                        class_draw = gr.Dropdown(["通用(general)", "国画(traditional Chinese painting)",
                                                  "照片,摄影(picture photography)", "油画(oil painting)",
                                                  "铅笔素描(pencil sketch)", "CG",
                                                  "水彩画(watercolor painting)", "水墨画(ink and wash)",
                                                  "插画(illustrations)", "3D", "图生图(img2img)"],
                                                 label="生成类型(type)",
                                                 show_label=True,
                                                 value="通用(general)")
                    # Style keywords appended to the prompt by request_images
                    # (not used for the 国画 category).
                    with gr.Row().style(mobile_collapse=False, equal_height=True):
                        style_draw = gr.CheckboxGroup(["蒸汽朋克(steampunk)", "电影摄影风格(film photography)",
                                                       "概念艺术(concept art)", "Warming lighting",
                                                       "Dramatic lighting", "Natural lighting",
                                                       "虚幻引擎(unreal engine)", "4k", "8k",
                                                       "充满细节(full details)"],
                                                      label="画面风格(style)",
                                                      show_label=True,
                                                      )
                    # Number of images to request (passed as request_images' batch_size).
                    with gr.Row().style(mobile_collapse=False, equal_height=True):
                        sample_size = gr.Slider(minimum=1,
                                                maximum=4,
                                                step=1,
                                                label="生成数量(number)",
                                                show_label=True,
                                                interactive=True,
                                                )
                    # Super-resolution flag forwarded to the backend.
                    with gr.Row().style(mobile_collapse=False, equal_height=True):
                        sr_option = gr.Checkbox(value=False, label="是否使用超分(Whether to use super-resolution)")
                    gallery = gr.Gallery(
                        label="Generated images", show_label=False, elem_id="gallery"
                    ).style(grid=[2], height="auto")
                    # Wired further below to uifn.copy_img_to_input.
                    output_txt2img_copy_to_input_btn = gr.Button("Go to img2img")
                    # NOTE(review): request_images takes five inputs but only `text`
                    # is wired here; this matters only if gradio calls fn for the
                    # examples (e.g. when caching them) - confirm.
                    gr.Examples(examples=examples, fn=request_images, inputs=text, outputs=gallery, examples_per_page=100)
                    # Pressing Enter in the prompt box and clicking the button both
                    # call request_images with the same inputs.
                    text.submit(request_images, inputs=[text, class_draw, style_draw, sample_size, sr_option], outputs=gallery)
                    btn.click(request_images, inputs=[text, class_draw, style_draw, sample_size, sr_option], outputs=gallery)
            # ---------------- Image-to-image tab ----------------
            with gr.TabItem("图生图(Img-to-Img)", id="img2img_tab"):
                with gr.Row(elem_id="prompt_row"):
                    # Single-line prompt box when submit-on-enter is enabled.
                    img2img_prompt = gr.Textbox(label="Prompt",
                                                elem_id='img2img_prompt_input',
                                                placeholder="神奇的森林,流淌的河流.",
                                                lines=1,
                                                max_lines=1 if txt2img_defaults['submit_on_enter'] == 'Yes' else 25,
                                                value=img2img_defaults['prompt'],
                                                show_label=False).style()
                    # Two Generate buttons; the "mask" one starts hidden.
                    # uifn.change_image_editor_mode (wired below) outputs to both,
                    # presumably to switch which one is visible.
                    img2img_btn_mask = gr.Button("Generate", variant="primary", visible=False,
                                                 elem_id="img2img_mask_btn")
                    img2img_btn_editor = gr.Button("Generate", variant="primary", elem_id="img2img_edit_btn")
                with gr.Row().style(equal_height=False):
                    # Left column: input image and edit settings.
                    with gr.Column():
                        gr.Markdown('#### 输入图像')
                        # Sketch-tool input; this is the image passed to img2img.
                        img2img_image_mask = gr.Image(
                            value=sample_img2img,
                            source="upload",
                            interactive=True,
                            tool="sketch",
                            type='pil',
                            elem_id="img2img_mask",
                            image_mode="RGBA"
                        )
                        # Alternative editor input, hidden initially.
                        img2img_image_editor = gr.Image(
                            value=sample_img2img,
                            source="upload",
                            interactive=True,
                            tool="select",
                            type='pil',
                            visible=False,
                            image_mode="RGBA",
                            elem_id="img2img_editor"
                        )
                        with gr.Tabs():
                            with gr.TabItem("编辑设置"):
                                with gr.Row():
                                    # disable Uncrop for now
                                    # NOTE(review): `choices` is never used; the Radio
                                    # below offers only "Mask".
                                    choices=["Mask", "Crop", "Uncrop"]
                                    img2img_image_editor_mode = gr.Radio(choices=["Mask"],
                                                                         label="编辑模式",
                                                                         value="Mask", elem_id='edit_mode_select',
                                                                         visible=True)
                                    # Mask handling: keep the masked area / regenerate the masked area.
                                    img2img_mask = gr.Radio(choices=["保留mask区域", "生成mask区域"],
                                                            label="Mask 方式",
                                                            #value=img2img_mask_modes[img2img_defaults['mask_mode']],
                                                            value = "生成mask区域",
                                                            visible=True)
                                    img2img_mask_blur_strength = gr.Slider(minimum=1, maximum=10, step=1,
                                                                           label="How much blurry should the mask be? (to avoid hard edges)",
                                                                           value=3, visible=False)
                                    img2img_resize = gr.Radio(label="Resize mode",
                                                              choices=["Just resize", "Crop and resize",
                                                                       "Resize and fill"],
                                                              value=img2img_resize_modes[
                                                                  img2img_defaults['resize_mode']], visible=False)
                                    img2img_painterro_btn = gr.Button("Advanced Editor",visible=False)
                            # with gr.TabItem("Hints",visible=False):
                            #     img2img_help = gr.Markdown(visible=False, value=uifn.help_text)
                    # Right column: generated images.
                    with gr.Column():
                        gr.Markdown('#### 编辑后的图片')
                        output_img2img_gallery = gr.Gallery(label="Images", elem_id="img2img_gallery_output").style(
                            grid=[4, 4, 4])
                        # None unless a module-level job_manager is configured.
                        img2img_job_ui = job_manager.draw_gradio_ui() if job_manager else None
                        # Hidden output actions and generation-info tabs.
                        with gr.Column(visible=False):
                            with gr.Tabs(visible=False):
                                with gr.TabItem("", id="img2img_actions_tab",visible=False):
                                    gr.Markdown("Select an image, then press one of the buttons below")
                                    with gr.Row():
                                        output_img2img_copy_to_clipboard_btn = gr.Button("Copy to clipboard")
                                        output_img2img_copy_to_input_btn = gr.Button("Push to img2img input")
                                        output_img2img_copy_to_mask_btn = gr.Button("Push to img2img input mask")
                                    gr.Markdown("Warning: This will clear your current image and mask settings!")
                                with gr.TabItem("", id="img2img_output_info_tab",visible=False):
                                    output_img2img_params = gr.Textbox(label="Generation parameters")
                                    with gr.Row():
                                        # Client-side only (fn=None): copies the params text to the clipboard.
                                        output_img2img_copy_params = gr.Button("Copy full parameters").click(
                                            inputs=output_img2img_params, outputs=[],
                                            _js='(x) => {navigator.clipboard.writeText(x.replace(": ",":"))}', fn=None,
                                            show_progress=False)
                                        output_img2img_seed = gr.Number(label='Seed', interactive=False, visible=False)
                                        output_img2img_copy_seed = gr.Button("Copy only seed").click(
                                            inputs=output_img2img_seed, outputs=[],
                                            _js=call_JS("gradioInputToClipboard"), fn=None, show_progress=False)
                                    output_img2img_stats = gr.HTML(label='Stats')
                gr.Markdown('# 编辑设置',visible=False)
                # Hidden advanced settings; their values are still sent to img2img as inputs.
                with gr.Row(visible=False):
                    with gr.Column():
                        img2img_width = gr.Slider(minimum=64, maximum=2048, step=64, label="图片宽度",
                                                  value=img2img_defaults["width"])
                        img2img_height = gr.Slider(minimum=64, maximum=2048, step=64, label="图片高度",
                                                   value=img2img_defaults["height"])
                        img2img_cfg = gr.Slider(minimum=-40.0, maximum=30.0, step=0.5,
                                                label='文本引导强度',
                                                value=img2img_defaults['cfg_scale'], elem_id='cfg_slider')
                        img2img_seed = gr.Textbox(label="随机种子", lines=1, max_lines=1,
                                                  value=img2img_defaults["seed"])
                        img2img_batch_count = gr.Slider(minimum=1, maximum=50, step=1,
                                                        label='生成数量',
                                                        value=img2img_defaults['n_iter'])
                        img2img_dimensions_info_text_box = gr.Textbox(
                            label="长宽比设置")
                    with gr.Column():
                        img2img_steps = gr.Slider(minimum=1, maximum=250, step=1, label="采样步数",
                                                  value=img2img_defaults['ddim_steps'])
                        img2img_sampling = gr.Dropdown(label='采样方式',
                                                       choices=["DDIM", 'k_dpm_2_a', 'k_dpm_2', 'k_euler_a', 'k_euler',
                                                                'k_heun', 'k_lms'],
                                                       value=img2img_defaults['sampler_name'])
                        img2img_denoising = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising Strength',
                                                      value=img2img_defaults['denoising_strength'],visible=False)
                        # NOTE: this rebinds the module-level img2img_toggles list to the
                        # CheckboxGroup component; the right-hand side still reads the list.
                        img2img_toggles = gr.CheckboxGroup(label='', choices=img2img_toggles,
                                                           value=img2img_toggle_defaults,visible=False)
                        img2img_realesrgan_model_name = gr.Dropdown(label='RealESRGAN model',
                                                                    choices=['RealESRGAN_x4plus',
                                                                             'RealESRGAN_x4plus_anime_6B'],
                                                                    value='RealESRGAN_x4plus',
                                                                    visible=RealESRGAN is not None)  # TODO: Feels like I shouldnt slot it in here.
                        img2img_embeddings = gr.File(label="Embeddings file for textual inversion",
                                                     visible=show_embeddings)
                img2img_image_editor_mode.change(
                    uifn.change_image_editor_mode,
                    [img2img_image_editor_mode,
                     img2img_image_editor,
                     img2img_image_mask,
                     img2img_resize,
                     img2img_width,
                     img2img_height
                     ],
                    [img2img_image_editor, img2img_image_mask, img2img_btn_editor, img2img_btn_mask,
                     img2img_painterro_btn, img2img_mask, img2img_mask_blur_strength]
                )
                # This handler was previously commented out with no visible bug;
                # worth checking what it actually does.
                img2img_image_editor_mode.change(
                    uifn.update_image_mask,
                    [img2img_image_editor, img2img_resize, img2img_width, img2img_height],
                    img2img_image_mask
                )
                # Commenting out the handler above stops the one below from working;
                # the reason is not understood.
                # output_txt2img_copy_to_input_btn.click(
                #     uifn.copy_img_to_input,
                #     [gallery],
                #     [tabs, img2img_image_editor, img2img_image_mask],
                #     _js=call_JS("moveImageFromGallery",
                #                 fromId="txt2img_gallery_output",
                #                 toId="img2img_mask")
                # )
                # "Go to img2img": copy the txt2img gallery into the img2img inputs
                # and switch tabs.
                output_txt2img_copy_to_input_btn.click(
                    uifn.copy_img_to_input,
                    [gallery],
                    [tabs, img2img_image_editor, img2img_image_mask],
                )
                # The handlers below currently serve no purpose: their buttons live
                # inside the hidden (visible=False) output column.
                output_img2img_copy_to_input_btn.click(
                    uifn.copy_img_to_edit,
                    [output_img2img_gallery],
                    [img2img_image_editor, tabs, img2img_image_editor_mode],
                    _js=call_JS("moveImageFromGallery",
                                fromId="gallery",
                                toId="img2img_editor")
                )
                output_img2img_copy_to_mask_btn.click(
                    uifn.copy_img_to_mask,
                    [output_img2img_gallery],
                    [img2img_image_mask, tabs, img2img_image_editor_mode],
                    _js=call_JS("moveImageFromGallery",
                                fromId="img2img_gallery_output",
                                toId="img2img_editor")
                )
                output_img2img_copy_to_clipboard_btn.click(fn=None, inputs=output_img2img_gallery, outputs=[],
                                                           _js=call_JS("copyImageFromGalleryToClipboard",
                                                                       fromId="img2img_gallery_output")
                                                           )
                img2img_func = img2img
                # Order matters: img2img reads args[8] (img2img_batch_count) as the batch size.
                img2img_inputs = [img2img_prompt, img2img_image_editor_mode, img2img_mask,
                                  img2img_mask_blur_strength, img2img_steps, img2img_sampling, img2img_toggles,
                                  img2img_realesrgan_model_name, img2img_batch_count, img2img_cfg,
                                  img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize,
                                  img2img_image_mask]
                # img2img_outputs = [output_img2img_gallery, output_img2img_seed, output_img2img_params,
                #                    output_img2img_stats]
                img2img_outputs = [output_img2img_gallery]
                # If a JobManager was passed in then wrap the Generate functions
                if img2img_job_ui:
                    img2img_func, img2img_inputs, img2img_outputs = img2img_job_ui.wrap_func(
                        func=img2img_func,
                        inputs=img2img_inputs,
                        outputs=img2img_outputs,
                    )
                img2img_btn_mask.click(
                    img2img_func,
                    img2img_inputs,
                    img2img_outputs
                )

                def img2img_submit_params():
                    """Return the (fn, inputs, outputs) triple for the editor Generate button."""
                    # print([img2img_prompt, img2img_image_editor_mode, img2img_mask,
                    #        img2img_mask_blur_strength, img2img_steps, img2img_sampling, img2img_toggles,
                    #        img2img_realesrgan_model_name, img2img_batch_count, img2img_cfg,
                    #        img2img_denoising, img2img_seed, img2img_height, img2img_width, img2img_resize,
                    #        img2img_image_editor, img2img_image_mask, img2img_embeddings])
                    return (img2img_func,
                            img2img_inputs,
                            img2img_outputs)

                img2img_btn_editor.click(*img2img_submit_params())
                # GENERATE ON ENTER: client-side JS, presumably clicking whichever
                # Generate button in prompt_row is visible.
                img2img_prompt.submit(None, None, None,
                                      _js=call_JS("clickFirstVisibleButton",
                                                  rowId="prompt_row"))
                img2img_painterro_btn.click(None,
                                            [img2img_image_editor, img2img_image_mask, img2img_image_editor_mode],
                                            [img2img_image_editor, img2img_image_mask],
                                            _js=call_JS("Painterro.init", toId="img2img_editor")
                                            )
                # Refresh the dimensions info box whenever width or height changes.
                img2img_width.change(fn=uifn.update_dimensions_info, inputs=[img2img_width, img2img_height],
                                     outputs=img2img_dimensions_info_text_box)
                img2img_height.change(fn=uifn.update_dimensions_info, inputs=[img2img_width, img2img_height],
                                      outputs=img2img_dimensions_info_text_box)
        gr.HTML(read_content("footer.html"))
        # gr.Image('./contributors.png')
    # NOTE(review): in Gradio 3.x, max_size caps pending requests and
    # concurrency_count sets the number of workers; confirm for the pinned version.
    block.queue(max_size=50, concurrency_count=20).launch()