root commited on
Commit
d8a5a4d
1 Parent(s): 8141c9f

modified app.py

Browse files
Files changed (1) hide show
  1. app.py +136 -4
app.py CHANGED
@@ -1,7 +1,139 @@
1
  import gradio as gr
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
 
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import io
3
+ from PIL import Image
4
+ import base64
5
+ import requests
6
+ import json
7
+ from PIL import Image
8
 
9
+ def read_content(file_path: str) -> str:
10
+ """read the content of target file
11
+ """
12
+ with open(file_path, 'r', encoding='utf-8') as f:
13
+ content = f.read()
14
 
15
+ return content
16
+
17
+ def base2picture(resbase64):
18
+ res=resbase64.split(',')[1]
19
+ img_b64decode = base64.b64decode(res)
20
+ image = io.BytesIO(img_b64decode)
21
+ img = Image.open(image)
22
+ return img
23
+
24
+ def filter_content(raw_style: str):
25
+ if "(" in raw_style:
26
+ i = raw_style.index("(")
27
+ else :
28
+ i = -1
29
+
30
+ if i == -1:
31
+ return raw_style
32
+ else :
33
+ return raw_style[:i]
34
+
35
+ def request_images(raw_text, class_draw, style_draw, batch_size):
36
+ if filter_content(class_draw) != "国画":
37
+ if filter_content(class_draw) != "通用":
38
+ raw_text = raw_text + f",{filter_content(class_draw)}"
39
+
40
+ for sty in style_draw:
41
+ raw_text = raw_text + f",{filter_content(sty)}"
42
+ print(f"raw text is {raw_text}")
43
+ url = "http://flagart.baai.ac.cn/api/general/"
44
+ elif filter_content(class_draw) == "国画":
45
+ if raw_text.endswith("国画"):
46
+ pass
47
+ else :
48
+ raw_text = raw_text + ",国画"
49
+ url = "http://flagart.baai.ac.cn/api/guohua/"
50
+
51
+ d = {"data":[raw_text, batch_size]}
52
+ r = requests.post(url, json=d, headers={"Content-Type": "application/json", "Accept": "*/*", "Accept-Encoding": "gzip, deflate, br", "Connection": "keep-alive"})
53
+ result_text = r.text
54
+ content = json.loads(result_text)["data"][0]
55
+ images = []
56
+ for i in range(batch_size):
57
+ # print(content[i])
58
+ images.append(base2picture(content[i]))
59
+
60
+ return images
61
+
62
+ examples = [
63
+ '水墨蝴蝶和牡丹花,国画',
64
+ '苍劲有力的墨竹,国画',
65
+ '暴风雨中的灯塔',
66
+ '机械小松鼠,科学幻想',
67
+ '中国水墨山水画,国画',
68
+ "Lighthouse in the storm",
69
+ "A dog",
70
+ "Landscape by 张大千",
71
+ "A tiger 长了兔子耳朵",
72
+ "A baby bird 铅笔素描",
73
+
74
+ ]
75
+
76
+ if __name__ == "__main__":
77
+ block = gr.Blocks(css=read_content('style.css'))
78
+
79
+ with block:
80
+ gr.HTML(read_content("header.html"))
81
+
82
+ with gr.Group():
83
+ with gr.Box():
84
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
85
+ text = gr.Textbox(
86
+ label="Prompt",
87
+ show_label=False,
88
+ max_lines=1,
89
+ placeholder="Input text(输入文字)",
90
+ interactive=True,
91
+ ).style(
92
+ border=(True, False, True, True),
93
+ rounded=(True, False, False, True),
94
+ container=False,
95
+ )
96
+
97
+ btn = gr.Button("Generate image").style(
98
+ margin=False,
99
+ rounded=(True, True, True, True),
100
+ )
101
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
102
+ class_draw = gr.Dropdown(["通用(general)", "国画(traditional Chinese painting)",
103
+ "照片,摄影(picture photography)", "油画(oil painting)",
104
+ "铅笔素描(pencil sketch)", "CG",
105
+ "水彩画(watercolor painting)", "水墨画(ink and wash)",
106
+ "插画(illustrations)", "3D"],
107
+ label="生成类型(type)",
108
+ show_label=True,
109
+ value="通用(general)")
110
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
111
+ style_draw = gr.CheckboxGroup(["蒸汽朋克(steampunk)", "电影摄影风格(film photography)",
112
+ "概念艺术(concept art)", "Warming lighting",
113
+ "Dramatic lighting", "Natural lighting",
114
+ "虚幻引擎(unreal engine)", "4k", "8k",
115
+ "充满细节(full details)"],
116
+ label="画面风格(style)",
117
+ show_label=True,
118
+ )
119
+ with gr.Row().style(mobile_collapse=False, equal_height=True):
120
+ sample_size = gr.Slider(minimum=1,
121
+ maximum=4,
122
+ step=1,
123
+ label="生成数量(number)",
124
+ show_label=True,
125
+ interactive=True,
126
+ )
127
+
128
+ gallery = gr.Gallery(
129
+ label="Generated images", show_label=False, elem_id="gallery"
130
+ ).style(grid=[2], height="auto")
131
+
132
+ gr.Examples(examples=examples, fn=request_images, inputs=text, outputs=gallery, examples_per_page=100)
133
+ text.submit(request_images, inputs=[text, class_draw, style_draw, sample_size], outputs=gallery)
134
+ btn.click(request_images, inputs=[text, class_draw, style_draw, sample_size], outputs=gallery)
135
+
136
+ gr.HTML(read_content("footer.html"))
137
+ # gr.Image('./contributors.png')
138
+
139
+ block.queue(max_size=50, concurrency_count=20).launch()