"""Gradio demo client for the text-to-image HTTP API at flagart.baai.ac.cn.

The page collects a prompt, a picture type, optional style tags and an image
count, posts them to the remote service and shows the returned images in a
gallery.
"""

import base64
import io
import json

import gradio as gr
import requests
from PIL import Image

# Remote endpoints: one for traditional Chinese painting, one for everything else.
GENERAL_API_URL = "http://flagart.baai.ac.cn/api/general/"
GUOHUA_API_URL = "http://flagart.baai.ac.cn/api/guohua/"

REQUEST_HEADERS = {
    "Content-Type": "application/json",
    "Accept": "*/*",
    "Accept-Encoding": "gzip, deflate, br",
    "Connection": "keep-alive",
}

# (connect, read) timeout in seconds. The read timeout is deliberately generous
# because the server has to generate the images before it answers; tune it to
# the real service latency.
REQUEST_TIMEOUT = (10, 300)


def read_content(file_path: str) -> str:
    """Return the whole content of ``file_path`` decoded as UTF-8."""
    with open(file_path, "r", encoding="utf-8") as f:
        return f.read()


def base2picture(resbase64):
    """Decode a base64-encoded image string into a PIL image.

    Args:
        resbase64: Either a data URI (``"<header>,<base64 payload>"``) or a
            bare base64 payload without any header.

    Returns:
        The ``PIL.Image.Image`` opened from the decoded bytes.
    """
    # Drop the "<header>," prefix when present; a string without a comma is
    # treated as the payload itself instead of raising IndexError.
    _, separator, payload = resbase64.partition(",")
    if not separator:
        payload = resbase64
    return Image.open(io.BytesIO(base64.b64decode(payload)))


def filter_content(raw_style: str) -> str:
    """Return the part of ``raw_style`` that precedes the first ``"("``.

    UI labels look like ``"油画(oil painting)"``; only the text before the
    parenthesis is used in the prompt. A label without ``"("`` is returned
    unchanged.
    """
    return raw_style.partition("(")[0]


def _build_request(raw_text, class_draw, style_draw):
    """Return ``(prompt, url)`` for the selected picture type and styles."""
    category = filter_content(class_draw)

    if category == "国画":
        # The guohua endpoint gets a prompt ending in the "国画" tag; style
        # tags are not appended in this mode.
        if not raw_text.endswith("国画"):
            raw_text += ",国画"
        return raw_text, GUOHUA_API_URL

    # "通用" (general) is the neutral choice and adds no tag of its own.
    if category != "通用":
        raw_text += f",{category}"
    # NOTE(review): the source file had lost its indentation; the style loop is
    # assumed to apply to every non-guohua type, including "通用" - confirm.
    for sty in style_draw or ():
        raw_text += f",{filter_content(sty)}"
    print(f"raw text is {raw_text}")
    return raw_text, GENERAL_API_URL


def request_images(raw_text, class_draw="通用(general)", style_draw=None, batch_size=1):
    """Request generated images for a prompt from the remote service.

    Args:
        raw_text: The user's prompt.
        class_draw: Picture-type label as shown in the dropdown, for example
            ``"国画(traditional Chinese painting)"``. Defaults to the general
            type so the function can be called with a prompt only, which is
            how ``gr.Examples`` wires it (``inputs=text``).
        style_draw: Iterable of style labels from the checkbox group, or
            ``None`` for no styles.
        batch_size: Number of images to request; coerced to ``int`` because
            UI sliders may deliver a float.

    Returns:
        A list of PIL images, at most ``batch_size`` long.

    Raises:
        requests.RequestException: On connection failure, timeout or a
            non-2xx HTTP status.
    """
    batch_size = int(batch_size)
    prompt, url = _build_request(raw_text, class_draw, style_draw)

    response = requests.post(
        url,
        json={"data": [prompt, batch_size]},
        headers=REQUEST_HEADERS,
        timeout=REQUEST_TIMEOUT,
    )
    # Fail with a clear HTTP error instead of a confusing JSON/Key error when
    # the service answers with an error page.
    response.raise_for_status()

    # Presumably a list of base64 image strings, one per generated image.
    content = json.loads(response.text)["data"][0]
    # Slice rather than index so a short reply does not raise IndexError.
    return [base2picture(encoded) for encoded in content[:batch_size]]


examples = [
    '水墨蝴蝶和牡丹花,国画',
    '苍劲有力的墨竹,国画',
    '暴风雨中的灯塔',
    '机械小松鼠,科学幻想',
    '中国水墨山水画,国画',
    "Lighthouse in the storm",
    "A dog",
    "Landscape by 张大千",
    "A tiger 长了兔子耳朵",
    "A baby bird 铅笔素描",
]


if __name__ == "__main__":
    block = gr.Blocks(css=read_content('style.css'))

    # NOTE(review): the layout nesting below was reconstructed from a
    # whitespace-collapsed source - confirm it against the rendered page.
    with block:
        gr.HTML(read_content("header.html"))
        with gr.Group():
            with gr.Box():
                # Prompt box and submit button share one row.
                with gr.Row().style(mobile_collapse=False, equal_height=True):
                    text = gr.Textbox(
                        label="Prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="Input text(输入文字)",
                        interactive=True,
                    ).style(
                        border=(True, False, True, True),
                        rounded=(True, False, False, True),
                        container=False,
                    )

                    btn = gr.Button("Generate image").style(
                        margin=False,
                        rounded=(True, True, True, True),
                    )

                # Picture type: selects the endpoint and an optional prompt tag.
                with gr.Row().style(mobile_collapse=False, equal_height=True):
                    class_draw = gr.Dropdown(
                        [
                            "通用(general)",
                            "国画(traditional Chinese painting)",
                            "照片,摄影(picture photography)",
                            "油画(oil painting)",
                            "铅笔素描(pencil sketch)",
                            "CG",
                            "水彩画(watercolor painting)",
                            "水墨画(ink and wash)",
                            "插画(illustrations)",
                            "3D",
                        ],
                        label="生成类型(type)",
                        show_label=True,
                        value="通用(general)",
                    )

                # Style tags appended to the prompt.
                with gr.Row().style(mobile_collapse=False, equal_height=True):
                    style_draw = gr.CheckboxGroup(
                        [
                            "蒸汽朋克(steampunk)",
                            "电影摄影风格(film photography)",
                            "概念艺术(concept art)",
                            "Warming lighting",
                            "Dramatic lighting",
                            "Natural lighting",
                            "虚幻引擎(unreal engine)",
                            "4k",
                            "8k",
                            "充满细节(full details)",
                        ],
                        label="画面风格(style)",
                        show_label=True,
                    )

                # Number of images requested per submission.
                with gr.Row().style(mobile_collapse=False, equal_height=True):
                    sample_size = gr.Slider(
                        minimum=1,
                        maximum=4,
                        step=1,
                        label="生成数量(number)",
                        show_label=True,
                        interactive=True,
                    )

            gallery = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery"
            ).style(grid=[2], height="auto")

            gr.Examples(
                examples=examples,
                fn=request_images,
                inputs=text,
                outputs=gallery,
                examples_per_page=100,
            )

            # Pressing Enter in the prompt box and clicking the button both
            # trigger the same request.
            text.submit(
                request_images,
                inputs=[text, class_draw, style_draw, sample_size],
                outputs=gallery,
            )
            btn.click(
                request_images,
                inputs=[text, class_draw, style_draw, sample_size],
                outputs=gallery,
            )
        gr.HTML(read_content("footer.html"))
        # gr.Image('./contributors.png')

    block.queue(max_size=50, concurrency_count=20).launch()