Spaces:
Runtime error
Runtime error
tastelikefeet
commited on
Commit
·
4112985
1
Parent(s):
e834a35
up4
Browse files- app.py +10 -10
- requirements.txt +1 -0
- src/__init__.py +0 -0
- src/background_generation.py +76 -0
- src/log.py +18 -0
- src/person_detect.py +39 -0
- src/util.py +177 -0
- src/virtualmodel.py +81 -0
app.py
CHANGED
@@ -116,36 +116,36 @@ with block:
|
|
116 |
<a href='https://github.com/AIGCDesignGroup/ReplaceAnything'><img src='https://img.shields.io/badge/Github-Repo-blue'></a>
|
117 |
</div>
|
118 |
</br>
|
119 |
-
<h3>
|
120 |
-
<h5 style="margin: 0; color: red"
|
121 |
</br>
|
122 |
</div>
|
123 |
""")
|
124 |
|
125 |
with gr.Tabs(elem_classes=["Tab"]):
|
126 |
-
with gr.TabItem("作品广场"):
|
127 |
gr.Gallery(value=showcases,
|
128 |
height=800,
|
129 |
columns=4,
|
130 |
object_fit="scale-down"
|
131 |
)
|
132 |
-
with gr.TabItem("创作图像"):
|
133 |
-
with gr.Accordion(label="🧭
|
134 |
with gr.Row(equal_height=True):
|
135 |
with gr.Row(elem_id="ShowCase"):
|
136 |
gr.Image(value="showcase/ra.gif")
|
137 |
gr.Markdown("""
|
138 |
-
- ⭐️ <b>step1:</b>在“输入图像”中上传or选择Example里面的一张图片
|
139 |
-
- ⭐️ <b>step2:</b>通过点击鼠标选择图像中希望保留的物体
|
140 |
-
- ⭐️ <b>step3:</b>输入对应的参数,例如prompt等,点击Run进行生成
|
141 |
-
- ⭐️ <b>step4 (可选):</b>此外支持换背景操作,上传目标风格背景,执行完step3后点击Run进行生成
|
142 |
""")
|
143 |
with gr.Row():
|
144 |
with gr.Column():
|
145 |
with gr.Column(elem_id="Input"):
|
146 |
with gr.Row():
|
147 |
with gr.Tabs(elem_classes=["feedback"]):
|
148 |
-
with gr.TabItem("输入图像"):
|
149 |
input_image = gr.Image(type="numpy", label="输入图",scale=2)
|
150 |
original_image = gr.State(value=None,label="索引")
|
151 |
original_mask = gr.State(value=None)
|
|
|
116 |
<a href='https://github.com/AIGCDesignGroup/ReplaceAnything'><img src='https://img.shields.io/badge/Github-Repo-blue'></a>
|
117 |
</div>
|
118 |
</br>
|
119 |
+
<h3>OffendingAIGC techniques have attracted lots of attention recently. They have demonstrated strong capabilities in the areas of image editing, image generation and so on. We find that generating new contents while strictly keeping the identity of use-specified object unchanged is of great demand, yet challenging. To this end, we propose ReplaceAnything framework. It can be used in many scenes, such as human replacement, clothing replacement, background replacement, and so on.</h3>
|
120 |
+
<h5 style="margin: 0; color: red">If you found the project helpful, you can click a Star on Github to get the latest updates on the project.</h5>
|
121 |
</br>
|
122 |
</div>
|
123 |
""")
|
124 |
|
125 |
with gr.Tabs(elem_classes=["Tab"]):
|
126 |
+
with gr.TabItem("作品广场(Image Gallery)"):
|
127 |
gr.Gallery(value=showcases,
|
128 |
height=800,
|
129 |
columns=4,
|
130 |
object_fit="scale-down"
|
131 |
)
|
132 |
+
with gr.TabItem("创作图像(Image Create)"):
|
133 |
+
with gr.Accordion(label="🧭 操作指南(Instructions):", open=True, elem_id="accordion"):
|
134 |
with gr.Row(equal_height=True):
|
135 |
with gr.Row(elem_id="ShowCase"):
|
136 |
gr.Image(value="showcase/ra.gif")
|
137 |
gr.Markdown("""
|
138 |
+
- ⭐️ <b>step1:</b>在“输入图像”中上传or选择Example里面的一张图片(Upload or select one image from Example)
|
139 |
+
- ⭐️ <b>step2:</b>通过点击鼠标选择图像中希望保留的物体(Click to select the object)
|
140 |
+
- ⭐️ <b>step3:</b>输入对应的参数,例如prompt等,点击Run进行生成(Input prompt or reference image)
|
141 |
+
- ⭐️ <b>step4 (可选):</b>此外支持换背景操作,上传目标风格背景,执行完step3后点击Run进行生成(Click Run button)
|
142 |
""")
|
143 |
with gr.Row():
|
144 |
with gr.Column():
|
145 |
with gr.Column(elem_id="Input"):
|
146 |
with gr.Row():
|
147 |
with gr.Tabs(elem_classes=["feedback"]):
|
148 |
+
with gr.TabItem("输入图像(Input Image)"):
|
149 |
input_image = gr.Image(type="numpy", label="输入图",scale=2)
|
150 |
original_image = gr.State(value=None,label="索引")
|
151 |
original_mask = gr.State(value=None)
|
requirements.txt
CHANGED
@@ -9,3 +9,4 @@ easydict
|
|
9 |
scikit-image
|
10 |
git+https://github.com/facebookresearch/segment-anything.git
|
11 |
torch
|
|
|
|
9 |
scikit-image
|
10 |
git+https://github.com/facebookresearch/segment-anything.git
|
11 |
torch
|
12 |
+
oss2==2.17.0
|
src/__init__.py
ADDED
File without changes
|
src/background_generation.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy
|
3 |
+
from PIL import Image
|
4 |
+
import requests
|
5 |
+
import urllib.request
|
6 |
+
from http import HTTPStatus
|
7 |
+
from datetime import datetime
|
8 |
+
import json
|
9 |
+
from .log import logger
|
10 |
+
import time
|
11 |
+
import gradio as gr
|
12 |
+
from .util import download_images
|
13 |
+
|
14 |
+
def call_bg_genration(base_image, ref_img, prompt,ref_prompt_weight=0.5):
|
15 |
+
API_KEY = os.getenv("API_KEY_BG_GENERATION")
|
16 |
+
BATCH_SIZE=4
|
17 |
+
headers = {
|
18 |
+
"Content-Type": "application/json",
|
19 |
+
"Accept": "application/json",
|
20 |
+
"Authorization": f"Bearer {API_KEY}",
|
21 |
+
"X-DashScope-Async": "enable",
|
22 |
+
}
|
23 |
+
data = {
|
24 |
+
"model": "wanx-background-generation-v2",
|
25 |
+
"input":{
|
26 |
+
"base_image_url": base_image,
|
27 |
+
'ref_image_url':ref_img,
|
28 |
+
"ref_prompt": prompt,
|
29 |
+
},
|
30 |
+
"parameters": {
|
31 |
+
"ref_prompt_weight": ref_prompt_weight,
|
32 |
+
"n": BATCH_SIZE
|
33 |
+
}
|
34 |
+
}
|
35 |
+
url_create_task = 'https://dashscope.aliyuncs.com/api/v1/services/aigc/background-generation/generation'
|
36 |
+
res_ = requests.post(url_create_task, data=json.dumps(data), headers=headers)
|
37 |
+
|
38 |
+
respose_code = res_.status_code
|
39 |
+
if 200 == respose_code:
|
40 |
+
res = json.loads(res_.content.decode())
|
41 |
+
request_id = res['request_id']
|
42 |
+
task_id = res['output']['task_id']
|
43 |
+
logger.info(f"task_id: {task_id}: Create Background Generation request success. Params: {data}")
|
44 |
+
|
45 |
+
# 异步查询
|
46 |
+
is_running = True
|
47 |
+
while is_running:
|
48 |
+
url_query = f'https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}'
|
49 |
+
res_ = requests.post(url_query, headers=headers)
|
50 |
+
respose_code = res_.status_code
|
51 |
+
if 200 == respose_code:
|
52 |
+
res = json.loads(res_.content.decode())
|
53 |
+
if "SUCCEEDED" == res['output']['task_status']:
|
54 |
+
logger.info(f"task_id: {task_id}: Background generation task query success.")
|
55 |
+
results = res['output']['results']
|
56 |
+
img_urls = [x['url'] for x in results]
|
57 |
+
logger.info(f"task_id: {task_id}: {res}")
|
58 |
+
break
|
59 |
+
elif "FAILED" != res['output']['task_status']:
|
60 |
+
logger.debug(f"task_id: {task_id}: query result...")
|
61 |
+
time.sleep(1)
|
62 |
+
else:
|
63 |
+
raise gr.Error('Fail to get results from Background Generation task.')
|
64 |
+
|
65 |
+
else:
|
66 |
+
logger.error(f'task_id: {task_id}: Fail to query task result: {res_.content}')
|
67 |
+
raise gr.Error("Fail to query task result.")
|
68 |
+
|
69 |
+
logger.info(f"task_id: {task_id}: download generated images.")
|
70 |
+
img_data = download_images(img_urls, BATCH_SIZE)
|
71 |
+
logger.info(f"task_id: {task_id}: Generate done.")
|
72 |
+
return img_data
|
73 |
+
else:
|
74 |
+
logger.error(f'Fail to create Background Generation task: {res_.content}')
|
75 |
+
raise gr.Error("Fail to create Background Generation task.")
|
76 |
+
|
src/log.py
ADDED
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
from logging.handlers import RotatingFileHandler
|
3 |
+
import os
|
4 |
+
|
5 |
+
log_file_name = "workdir/log_replaceAnything.log"
|
6 |
+
os.makedirs(os.path.dirname(log_file_name), exist_ok=True)
|
7 |
+
|
8 |
+
format = '[%(levelname)s] %(asctime)s "%(filename)s", line %(lineno)d, %(message)s'
|
9 |
+
logging.basicConfig(
|
10 |
+
format=format,
|
11 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
12 |
+
level=logging.INFO)
|
13 |
+
logger = logging.getLogger(name="WordArt_Studio")
|
14 |
+
|
15 |
+
fh = RotatingFileHandler(log_file_name, maxBytes=20000000, backupCount=3)
|
16 |
+
formatter = logging.Formatter(format, datefmt="%Y-%m-%d %H:%M:%S")
|
17 |
+
fh.setFormatter(formatter)
|
18 |
+
logger.addHandler(fh)
|
src/person_detect.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy
|
3 |
+
from PIL import Image
|
4 |
+
import requests
|
5 |
+
import urllib.request
|
6 |
+
from http import HTTPStatus
|
7 |
+
from datetime import datetime
|
8 |
+
import json
|
9 |
+
from .log import logger
|
10 |
+
import time
|
11 |
+
import gradio as gr
|
12 |
+
from .util import download_images
|
13 |
+
|
14 |
+
API_KEY = os.getenv("API_KEY_VIRTUALMODEL")
|
15 |
+
|
16 |
+
def call_person_detect(input_image_url):
|
17 |
+
headers = {
|
18 |
+
"Content-Type": "application/json",
|
19 |
+
"Accept": "application/json",
|
20 |
+
"Authorization": f"Bearer {API_KEY}",
|
21 |
+
"X-DashScope-DataInspection": "enable",
|
22 |
+
}
|
23 |
+
data = {
|
24 |
+
"model": "body-detection",
|
25 |
+
"input":{
|
26 |
+
"image_url": input_image_url,
|
27 |
+
},
|
28 |
+
"parameters": {
|
29 |
+
"score_threshold": 0.6,
|
30 |
+
}
|
31 |
+
}
|
32 |
+
url_create_task = 'https://dashscope.aliyuncs.com/api/v1/services/vision/bodydetection/detect'
|
33 |
+
res_ = requests.post(url_create_task, data=json.dumps(data), headers=headers)
|
34 |
+
|
35 |
+
|
36 |
+
res = json.loads(res_.content.decode())
|
37 |
+
request_id = res['request_id']
|
38 |
+
results = res['output']['results']
|
39 |
+
return results
|
src/util.py
ADDED
@@ -0,0 +1,177 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import cv2
|
5 |
+
import os
|
6 |
+
import io
|
7 |
+
import oss2
|
8 |
+
from PIL import Image
|
9 |
+
|
10 |
+
import dashscope
|
11 |
+
from dashscope import MultiModalConversation
|
12 |
+
|
13 |
+
from http import HTTPStatus
|
14 |
+
import re
|
15 |
+
import requests
|
16 |
+
from .log import logger
|
17 |
+
import concurrent.futures
|
18 |
+
|
19 |
+
dashscope.api_key = os.getenv("API_KEY_QW")
|
20 |
+
# oss
|
21 |
+
access_key_id = os.getenv("ACCESS_KEY_ID")
|
22 |
+
access_key_secret = os.getenv("ACCESS_KEY_SECRET")
|
23 |
+
bucket_name = os.getenv("BUCKET_NAME")
|
24 |
+
endpoint = os.getenv("ENDPOINT")
|
25 |
+
|
26 |
+
bucket = oss2.Bucket(oss2.Auth(access_key_id, access_key_secret), endpoint, bucket_name)
|
27 |
+
oss_path = "ashui"
|
28 |
+
oss_path_img_gallery = "ashui_img_gallery"
|
29 |
+
|
30 |
+
def download_img_pil(index, img_url):
|
31 |
+
# print(img_url)
|
32 |
+
r = requests.get(img_url, stream=True)
|
33 |
+
if r.status_code == 200:
|
34 |
+
img = Image.open(io.BytesIO(r.content))
|
35 |
+
return (index, img)
|
36 |
+
else:
|
37 |
+
logger.error(f"Fail to download: {img_url}")
|
38 |
+
|
39 |
+
|
40 |
+
def download_images(img_urls, batch_size):
|
41 |
+
imgs_pil = [None] * batch_size
|
42 |
+
# worker_results = []
|
43 |
+
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
|
44 |
+
to_do = []
|
45 |
+
for i, url in enumerate(img_urls):
|
46 |
+
future = executor.submit(download_img_pil, i, url)
|
47 |
+
to_do.append(future)
|
48 |
+
|
49 |
+
for future in concurrent.futures.as_completed(to_do):
|
50 |
+
ret = future.result()
|
51 |
+
# worker_results.append(ret)
|
52 |
+
index, img_pil = ret
|
53 |
+
imgs_pil[index] = img_pil # 按顺序排列url,后续下载关联的图片或者svg需要使用
|
54 |
+
|
55 |
+
return imgs_pil
|
56 |
+
|
57 |
+
def upload_np_2_oss(input_image, name="cache.png", gallery=False):
|
58 |
+
imgByteArr = io.BytesIO()
|
59 |
+
Image.fromarray(input_image).save(imgByteArr, format="PNG")
|
60 |
+
imgByteArr = imgByteArr.getvalue()
|
61 |
+
|
62 |
+
if gallery:
|
63 |
+
path = oss_path_img_gallery
|
64 |
+
else:
|
65 |
+
path = oss_path
|
66 |
+
|
67 |
+
bucket.put_object(path+"/"+name, imgByteArr) # data为数据,可以是图片
|
68 |
+
ret = bucket.sign_url('GET', path+"/"+name, 60*60*24) # 返回值为链接,参数依次为,方法/oss上文件路径/过期时间(s)
|
69 |
+
del imgByteArr
|
70 |
+
return ret
|
71 |
+
|
72 |
+
|
73 |
+
def call_with_messages(prompt):
|
74 |
+
messages = [
|
75 |
+
{'role': 'user', 'content': prompt}]
|
76 |
+
response = dashscope.Generation.call(
|
77 |
+
'qwen-14b-chat',
|
78 |
+
messages=messages,
|
79 |
+
result_format='message', # set the result is message format.
|
80 |
+
)
|
81 |
+
if response.status_code == HTTPStatus.OK:
|
82 |
+
return response['output']["choices"][0]["message"]['content']
|
83 |
+
else:
|
84 |
+
print('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
|
85 |
+
response.request_id, response.status_code,
|
86 |
+
response.code, response.message
|
87 |
+
))
|
88 |
+
return None
|
89 |
+
|
90 |
+
def HWC3(x):
|
91 |
+
assert x.dtype == np.uint8
|
92 |
+
if x.ndim == 2:
|
93 |
+
x = x[:, :, None]
|
94 |
+
assert x.ndim == 3
|
95 |
+
H, W, C = x.shape
|
96 |
+
assert C == 1 or C == 3 or C == 4
|
97 |
+
if C == 3:
|
98 |
+
return x
|
99 |
+
if C == 1:
|
100 |
+
return np.concatenate([x, x, x], axis=2)
|
101 |
+
if C == 4:
|
102 |
+
color = x[:, :, 0:3].astype(np.float32)
|
103 |
+
alpha = x[:, :, 3:4].astype(np.float32) / 255.0
|
104 |
+
y = color * alpha + 255.0 * (1.0 - alpha)
|
105 |
+
y = y.clip(0, 255).astype(np.uint8)
|
106 |
+
return y
|
107 |
+
|
108 |
+
|
109 |
+
def resize_image(input_image, resolution):
|
110 |
+
H, W, C = input_image.shape
|
111 |
+
H = float(H)
|
112 |
+
W = float(W)
|
113 |
+
k = float(resolution) / min(H, W)
|
114 |
+
H *= k
|
115 |
+
W *= k
|
116 |
+
H = int(np.round(H / 64.0)) * 64
|
117 |
+
W = int(np.round(W / 64.0)) * 64
|
118 |
+
img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
|
119 |
+
return img
|
120 |
+
|
121 |
+
|
122 |
+
def nms(x, t, s):
|
123 |
+
x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
|
124 |
+
|
125 |
+
f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
|
126 |
+
f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
|
127 |
+
f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
|
128 |
+
f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
|
129 |
+
|
130 |
+
y = np.zeros_like(x)
|
131 |
+
|
132 |
+
for f in [f1, f2, f3, f4]:
|
133 |
+
np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
|
134 |
+
|
135 |
+
z = np.zeros_like(y, dtype=np.uint8)
|
136 |
+
z[y > t] = 255
|
137 |
+
return z
|
138 |
+
|
139 |
+
|
140 |
+
def make_noise_disk(H, W, C, F):
|
141 |
+
noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
|
142 |
+
noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
|
143 |
+
noise = noise[F: F + H, F: F + W]
|
144 |
+
noise -= np.min(noise)
|
145 |
+
noise /= np.max(noise)
|
146 |
+
if C == 1:
|
147 |
+
noise = noise[:, :, None]
|
148 |
+
return noise
|
149 |
+
|
150 |
+
|
151 |
+
def min_max_norm(x):
|
152 |
+
x -= np.min(x)
|
153 |
+
x /= np.maximum(np.max(x), 1e-5)
|
154 |
+
return x
|
155 |
+
|
156 |
+
|
157 |
+
def safe_step(x, step=2):
|
158 |
+
y = x.astype(np.float32) * float(step + 1)
|
159 |
+
y = y.astype(np.int32).astype(np.float32) / float(step)
|
160 |
+
return y
|
161 |
+
|
162 |
+
|
163 |
+
def img2mask(img, H, W, low=10, high=90):
|
164 |
+
assert img.ndim == 3 or img.ndim == 2
|
165 |
+
assert img.dtype == np.uint8
|
166 |
+
|
167 |
+
if img.ndim == 3:
|
168 |
+
y = img[:, :, random.randrange(0, img.shape[2])]
|
169 |
+
else:
|
170 |
+
y = img
|
171 |
+
|
172 |
+
y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC)
|
173 |
+
|
174 |
+
if random.uniform(0, 1) < 0.5:
|
175 |
+
y = 255 - y
|
176 |
+
|
177 |
+
return y < np.percentile(y, random.randrange(low, high))
|
src/virtualmodel.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy
|
3 |
+
from PIL import Image
|
4 |
+
import requests
|
5 |
+
import urllib.request
|
6 |
+
from http import HTTPStatus
|
7 |
+
from datetime import datetime
|
8 |
+
import json
|
9 |
+
from .log import logger
|
10 |
+
import time
|
11 |
+
import gradio as gr
|
12 |
+
from .util import download_images
|
13 |
+
|
14 |
+
API_KEY = os.getenv("API_KEY_VIRTUALMODEL")
|
15 |
+
|
16 |
+
def call_virtualmodel(input_image_url, input_mask_url, source_background_url, prompt, face_prompt):
|
17 |
+
BATCH_SIZE=4
|
18 |
+
headers = {
|
19 |
+
"Content-Type": "application/json",
|
20 |
+
"Accept": "application/json",
|
21 |
+
"Authorization": f"Bearer {API_KEY}",
|
22 |
+
"X-DashScope-Async": "enable",
|
23 |
+
}
|
24 |
+
data = {
|
25 |
+
"model": "wanx-virtualmodel",
|
26 |
+
"input":{
|
27 |
+
"base_image_url": input_image_url,
|
28 |
+
"mask_image_url": input_mask_url,
|
29 |
+
"prompt": prompt,
|
30 |
+
"face_prompt": face_prompt,
|
31 |
+
"background_image_url": source_background_url,
|
32 |
+
},
|
33 |
+
"parameters": {
|
34 |
+
"short_side_size": "512",
|
35 |
+
"n": BATCH_SIZE
|
36 |
+
}
|
37 |
+
}
|
38 |
+
url_create_task = 'https://dashscope.aliyuncs.com/api/v1/services/aigc/virtualmodel/generation'
|
39 |
+
res_ = requests.post(url_create_task, data=json.dumps(data), headers=headers)
|
40 |
+
|
41 |
+
respose_code = res_.status_code
|
42 |
+
if 200 == respose_code:
|
43 |
+
res = json.loads(res_.content.decode())
|
44 |
+
request_id = res['request_id']
|
45 |
+
task_id = res['output']['task_id']
|
46 |
+
logger.info(f"task_id: {task_id}: Create VirtualModel request success. Params: {data}")
|
47 |
+
|
48 |
+
# 异步查询
|
49 |
+
is_running = True
|
50 |
+
while is_running:
|
51 |
+
url_query = f'https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}'
|
52 |
+
res_ = requests.post(url_query, headers=headers)
|
53 |
+
respose_code = res_.status_code
|
54 |
+
if 200 == respose_code:
|
55 |
+
res = json.loads(res_.content.decode())
|
56 |
+
if "SUCCEEDED" == res['output']['task_status']:
|
57 |
+
logger.info(f"task_id: {task_id}: VirtualModel generation task query success.")
|
58 |
+
results = res['output']['results']
|
59 |
+
img_urls = []
|
60 |
+
for x in results:
|
61 |
+
if "url" in x:
|
62 |
+
img_urls.append(x['url'])
|
63 |
+
logger.info(f"task_id: {task_id}: {res}")
|
64 |
+
break
|
65 |
+
elif "FAILED" != res['output']['task_status']:
|
66 |
+
logger.debug(f"task_id: {task_id}: query result...")
|
67 |
+
time.sleep(1)
|
68 |
+
else:
|
69 |
+
raise gr.Error('Fail to get results from VirtualModel task.')
|
70 |
+
|
71 |
+
else:
|
72 |
+
logger.error(f'task_id: {task_id}: Fail to query task result: {res_.content}')
|
73 |
+
raise gr.Error("Fail to query task result.")
|
74 |
+
|
75 |
+
logger.info(f"task_id: {task_id}: download generated images.")
|
76 |
+
img_data = download_images(img_urls, len(img_urls)) if len(img_urls) > 0 else []
|
77 |
+
logger.info(f"task_id: {task_id}: Generate done.")
|
78 |
+
return img_data
|
79 |
+
else:
|
80 |
+
logger.error(f'Fail to create VirtualModel task: {res_.content}')
|
81 |
+
raise gr.Error("Fail to create VirtualModel task.")
|