zhuguangbin commited on
Commit
9c7472b
1 Parent(s): 94c4287
Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +77 -0
  3. requirements.txt +3 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__/
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import requests
3
+ import time
4
+ from PIL import Image
5
+ import os, io, json
6
+ import base64
7
+
8
+ sd_api_base = os.environ["SD_API_BASE"]
9
+ sd_api_key = os.environ["SD_API_KEY"]
10
+
11
+ # 发送POST请求的函数
12
+ def send_post_request(input_json_string):
13
+
14
+ try:
15
+ # 尝试将输入的字符串转换为JSON对象
16
+ data = json.loads(input_json_string)
17
+ except json.JSONDecodeError as e:
18
+ return f"输入的字符串不是有效的JSON格式: {e}"
19
+
20
+ url = f"{sd_api_base}/txt2img/run/"
21
+ headers = {
22
+ 'Content-Type': 'application/json',
23
+ 'Authorization': f'Bearer {sd_api_key}',
24
+ }
25
+ response = requests.post(url, headers=headers, json=data)
26
+ if response.status_code == 200:
27
+ return response.json()
28
+ else:
29
+ raise Exception(f"Error in POST request: {response.text}")
30
+
31
+ # 轮询GET请求,直到异步操作完成
32
+ def poll_status(id):
33
+ url = f"{sd_api_base}/txt2img/status/{id}"
34
+ headers = {
35
+ 'Content-Type': 'application/json',
36
+ 'Authorization': f'Bearer {sd_api_key}',
37
+ }
38
+ while True:
39
+ response = requests.get(url, headers=headers)
40
+ if response.status_code == 200:
41
+ result = response.json()
42
+ if result['status'] == 'COMPLETED':
43
+ return result
44
+ else:
45
+ time.sleep(1) # 等待1秒后再次尝试
46
+ else:
47
+ raise Exception(f"Error in GET request: {response.text}")
48
+
49
+ # 将Base64编码的图片数据转换为可显示的图片
50
+ def display_images(output_json):
51
+ images_data = output_json['output']['images']
52
+ images = []
53
+ for base64_data in images_data:
54
+ image_data = base64.b64decode(base64_data)
55
+ image = Image.open(io.BytesIO(image_data))
56
+ images.append(image)
57
+ return images
58
+
59
+ # Gradio界面的函数
60
+ def gradio_interface(input_json):
61
+ post_response = send_post_request(input_json)
62
+ print(post_response)
63
+ status_response = poll_status(post_response['id'])
64
+ images = display_images(status_response)
65
+
66
+ return images
67
+
68
+ # 设置Gradio界面
69
+ iface = gr.Interface(
70
+ fn=gradio_interface,
71
+ inputs=gr.Textbox(lines=2, placeholder="Type something here..."),
72
+ outputs="gallery"
73
+ # examples=[{"prompt": "a dog"}]
74
+ )
75
+
76
+ # 启动Gradio应用程序
77
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ gradio==4.3.0
2
+ requests==2.31.0
3
+ Pillow==10.0.1