File size: 5,700 Bytes
1714b7c
d8a5a4d
 
 
 
 
 
1714b7c
d8a5a4d
 
 
 
 
1714b7c
d8a5a4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import gradio as gr
import io 
from PIL import Image 
import base64
import requests
import json 
from PIL import Image

def read_content(file_path: str) -> str:
    """Return the whole text of ``file_path``, decoded as UTF-8.

    Used to load the CSS and the header/footer HTML fragments for the UI.
    """
    with open(file_path, encoding='utf-8') as handle:
        return handle.read()

def base2picture(resbase64):
    """Decode a base64-encoded image string into a PIL image.

    Accepts either a data URI such as ``"data:image/png;base64,<payload>"``
    or a bare base64 payload with no header.

    Args:
        resbase64: Base64 image text, optionally prefixed by a data-URI
            header that ends in a comma.

    Returns:
        The image opened with ``PIL.Image.open`` from the decoded bytes.

    Raises:
        Whatever ``base64.b64decode`` / ``Image.open`` raise for malformed
        payloads or undecodable image data.
    """
    # Everything after the first comma is the payload. Without a comma the
    # whole string is treated as the payload instead of raising IndexError.
    header, sep, payload = resbase64.partition(",")
    if not sep:
        payload = header
    img = Image.open(io.BytesIO(base64.b64decode(payload)))
    return img

def filter_content(raw_style: str):
    """Strip the parenthesised English gloss from a UI label.

    Everything from the first ASCII "(" onward is dropped, so
    "国画(traditional Chinese painting)" becomes "国画". A label with no "(" is
    returned unchanged.
    """
    if "(" not in raw_style:
        return raw_style
    return raw_style[:raw_style.index("(")]

def request_images(raw_text, class_draw="通用(general)", style_draw=None, batch_size=1):
    """Generate images for a prompt via the FlagArt HTTP API.

    The prompt is extended with the selected painting type and style tags.
    Only the part of each label before "(" is used (see ``filter_content``).
    The request goes to the general endpoint, or to the guohua endpoint when
    the type is 国画.

    Args:
        raw_text: User prompt.
        class_draw: Painting-type label, e.g. "国画(traditional Chinese
            painting)". Defaults to the UI's default "通用(general)" so the
            function can also be called with a prompt alone.
        style_draw: Iterable of selected style labels. ``None`` means no
            styles. Styles are ignored for the 国画 endpoint.
        batch_size: Number of images to request. Coerced with ``int()`` in
            case the slider delivers a float.

    Returns:
        A list of PIL images, at most ``batch_size`` long.

    Raises:
        requests.HTTPError: if the API answers with an error status.
    """
    draw_class = filter_content(class_draw)
    batch_size = int(batch_size)

    if draw_class == "国画":
        # Guohua endpoint: ensure the prompt carries the 国画 tag exactly once.
        if not raw_text.endswith("国画"):
            raw_text = raw_text + ",国画"
        url = "http://flagart.baai.ac.cn/api/guohua/"
    else:
        # "通用" (general) adds no type tag; any other type is appended.
        if draw_class != "通用":
            raw_text = raw_text + f",{draw_class}"
        for sty in style_draw or ():
            raw_text = raw_text + f",{filter_content(sty)}"
        print(f"raw text is {raw_text}")
        url = "http://flagart.baai.ac.cn/api/general/"

    payload = {"data": [raw_text, batch_size]}
    # Content-Type is set by ``json=``. Accept-Encoding is left to requests:
    # explicitly advertising "br" may get a Brotli response that requests can
    # only decode when an optional Brotli package is installed.
    r = requests.post(url, json=payload, headers={"Accept": "*/*", "Connection": "keep-alive"})
    r.raise_for_status()
    content = r.json()["data"][0]
    # Slice rather than index, so a short response does not raise IndexError.
    return [base2picture(item) for item in content[:batch_size]]

# Sample prompts listed under the input box. They mix Chinese, English and
# code-switched prompts; several already end with the ",国画" tag.
examples = [
    "水墨蝴蝶和牡丹花,国画",
    "苍劲有力的墨竹,国画",
    "暴风雨中的灯塔",
    "机械小松鼠,科学幻想",
    "中国水墨山水画,国画",
    "Lighthouse in the storm",
    "A dog",
    "Landscape by 张大千",
    "A tiger 长了兔子耳朵",
    "A baby bird 铅笔素描",
]

if __name__ == "__main__":
    # Build the Gradio UI. Component creation order inside the ``with``
    # blocks determines the on-page layout, so statements must stay in order.
    block = gr.Blocks(css=read_content('style.css'))

    with block:
        gr.HTML(read_content("header.html"))

        with gr.Group():
            with gr.Box():
                # Prompt textbox and the generate button share one row.
                with gr.Row().style(mobile_collapse=False, equal_height=True):
                    text = gr.Textbox(
                        label="Prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="Input text(输入文字)",
                        interactive=True,
                    ).style(
                        border=(True, False, True, True),
                        rounded=(True, False, False, True),
                        container=False,
                    )

                    btn = gr.Button("Generate image").style(
                        margin=False,
                        rounded=(True, True, True, True),
                    )
                # Painting type. request_images appends the part of the label
                # before "(" to the prompt, or routes 国画 to its own endpoint.
                with gr.Row().style(mobile_collapse=False, equal_height=True):
                    class_draw = gr.Dropdown(
                        [
                            "通用(general)",
                            "国画(traditional Chinese painting)",
                            "照片,摄影(picture photography)",
                            "油画(oil painting)",
                            "铅笔素描(pencil sketch)",
                            "CG",
                            "水彩画(watercolor painting)",
                            "水墨画(ink and wash)",
                            "插画(illustrations)",
                            "3D",
                        ],
                        label="生成类型(type)",
                        show_label=True,
                        value="通用(general)",
                    )
                # Optional style tags, each appended to the prompt.
                with gr.Row().style(mobile_collapse=False, equal_height=True):
                    style_draw = gr.CheckboxGroup(
                        [
                            "蒸汽朋克(steampunk)",
                            "电影摄影风格(film photography)",
                            "概念艺术(concept art)",
                            "Warming lighting",
                            "Dramatic lighting",
                            "Natural lighting",
                            "虚幻引擎(unreal engine)",
                            "4k",
                            "8k",
                            "充满细节(full details)",
                        ],
                        label="画面风格(style)",
                        show_label=True,
                    )
                # Number of images per request (1-4).
                with gr.Row().style(mobile_collapse=False, equal_height=True):
                    sample_size = gr.Slider(
                        minimum=1,
                        maximum=4,
                        step=1,
                        value=1,
                        label="生成数量(number)",
                        show_label=True,
                        interactive=True,
                    )

            gallery = gr.Gallery(
                label="Generated images", show_label=False, elem_id="gallery"
            ).style(grid=[2], height="auto")

            # Clicking an example only fills the prompt box. fn/outputs are not
            # passed because request_images is wired to all four inputs below.
            # If example caching were enabled, Gradio would call it with the
            # prompt alone.
            gr.Examples(examples=examples, inputs=text, examples_per_page=100)

            generate_inputs = [text, class_draw, style_draw, sample_size]
            text.submit(request_images, inputs=generate_inputs, outputs=gallery)
            btn.click(request_images, inputs=generate_inputs, outputs=gallery)

        gr.HTML(read_content("footer.html"))

    block.queue(max_size=50, concurrency_count=20).launch()