tastelikefeet commited on
Commit
4112985
·
1 Parent(s): e834a35
app.py CHANGED
@@ -116,36 +116,36 @@ with block:
116
  <a href='https://github.com/AIGCDesignGroup/ReplaceAnything'><img src='https://img.shields.io/badge/Github-Repo-blue'></a>
117
  </div>
118
  </br>
119
- <h3> 我们发现,在严格保持某个“物体ID”不变的情况下生成新的内容有着很大的市场需求,同时也是具有挑战性的。为此,我们提出了ReplaceAnything框架。它可以用于很多场景,比如<b>人体替换、服装替换、物体替换以及背景替换</b>等等。</h3>
120
- <h5 style="margin: 0; color: red">如果你认为该项目有所帮助的话,不妨给我们Github点个Star以便获取最新的项目进展.</h5>
121
  </br>
122
  </div>
123
  """)
124
 
125
  with gr.Tabs(elem_classes=["Tab"]):
126
- with gr.TabItem("作品广场"):
127
  gr.Gallery(value=showcases,
128
  height=800,
129
  columns=4,
130
  object_fit="scale-down"
131
  )
132
- with gr.TabItem("创作图像"):
133
- with gr.Accordion(label="🧭 操作指南:", open=True, elem_id="accordion"):
134
  with gr.Row(equal_height=True):
135
  with gr.Row(elem_id="ShowCase"):
136
  gr.Image(value="showcase/ra.gif")
137
  gr.Markdown("""
138
- - ⭐️ <b>step1:</b>在“输入图像”中上传or选择Example里面的一张图片
139
- - ⭐️ <b>step2:</b>通过点击鼠标选择图像中希望保留的物体
140
- - ⭐️ <b>step3:</b>输入对应的参数,例如prompt等,点击Run进行生成
141
- - ⭐️ <b>step4 (可选):</b>此外支持换背景操作,上传目标风格背景,执行完step3后点击Run进行生成
142
  """)
143
  with gr.Row():
144
  with gr.Column():
145
  with gr.Column(elem_id="Input"):
146
  with gr.Row():
147
  with gr.Tabs(elem_classes=["feedback"]):
148
- with gr.TabItem("输入图像"):
149
  input_image = gr.Image(type="numpy", label="输入图",scale=2)
150
  original_image = gr.State(value=None,label="索引")
151
  original_mask = gr.State(value=None)
 
116
  <a href='https://github.com/AIGCDesignGroup/ReplaceAnything'><img src='https://img.shields.io/badge/Github-Repo-blue'></a>
117
  </div>
118
  </br>
119
+ <h3>OffendingAIGC techniques have attracted lots of attention recently. They have demonstrated strong capabilities in the areas of image editing, image generation and so on. We find that generating new contents while strictly keeping the identity of use-specified object unchanged is of great demand, yet challenging. To this end, we propose ReplaceAnything framework. It can be used in many scenes, such as human replacement, clothing replacement, background replacement, and so on.</h3>
120
+ <h5 style="margin: 0; color: red">If you found the project helpful, you can click a Star on Github to get the latest updates on the project.</h5>
121
  </br>
122
  </div>
123
  """)
124
 
125
  with gr.Tabs(elem_classes=["Tab"]):
126
+ with gr.TabItem("作品广场(Image Gallery)"):
127
  gr.Gallery(value=showcases,
128
  height=800,
129
  columns=4,
130
  object_fit="scale-down"
131
  )
132
+ with gr.TabItem("创作图像(Image Create)"):
133
+ with gr.Accordion(label="🧭 操作指南(Instructions):", open=True, elem_id="accordion"):
134
  with gr.Row(equal_height=True):
135
  with gr.Row(elem_id="ShowCase"):
136
  gr.Image(value="showcase/ra.gif")
137
  gr.Markdown("""
138
+ - ⭐️ <b>step1:</b>在“输入图像”中上传or选择Example里面的一张图片(Upload or select one image from Example)
139
+ - ⭐️ <b>step2:</b>通过点击鼠标选择图像中希望保留的物体(Click to select the object)
140
+ - ⭐️ <b>step3:</b>输入对应的参数,例如prompt等,点击Run进行生成(Input prompt or reference image)
141
+ - ⭐️ <b>step4 (可选):</b>此外支持换背景操作,上传目标风格背景,执行完step3后点击Run进行生成(Click Run button)
142
  """)
143
  with gr.Row():
144
  with gr.Column():
145
  with gr.Column(elem_id="Input"):
146
  with gr.Row():
147
  with gr.Tabs(elem_classes=["feedback"]):
148
+ with gr.TabItem("输入图像(Input Image)"):
149
  input_image = gr.Image(type="numpy", label="输入图",scale=2)
150
  original_image = gr.State(value=None,label="索引")
151
  original_mask = gr.State(value=None)
requirements.txt CHANGED
@@ -9,3 +9,4 @@ easydict
9
  scikit-image
10
  git+https://github.com/facebookresearch/segment-anything.git
11
  torch
 
 
9
  scikit-image
10
  git+https://github.com/facebookresearch/segment-anything.git
11
  torch
12
+ oss2==2.17.0
src/__init__.py ADDED
File without changes
src/background_generation.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy
3
+ from PIL import Image
4
+ import requests
5
+ import urllib.request
6
+ from http import HTTPStatus
7
+ from datetime import datetime
8
+ import json
9
+ from .log import logger
10
+ import time
11
+ import gradio as gr
12
+ from .util import download_images
13
+
14
+ def call_bg_genration(base_image, ref_img, prompt,ref_prompt_weight=0.5):
15
+ API_KEY = os.getenv("API_KEY_BG_GENERATION")
16
+ BATCH_SIZE=4
17
+ headers = {
18
+ "Content-Type": "application/json",
19
+ "Accept": "application/json",
20
+ "Authorization": f"Bearer {API_KEY}",
21
+ "X-DashScope-Async": "enable",
22
+ }
23
+ data = {
24
+ "model": "wanx-background-generation-v2",
25
+ "input":{
26
+ "base_image_url": base_image,
27
+ 'ref_image_url':ref_img,
28
+ "ref_prompt": prompt,
29
+ },
30
+ "parameters": {
31
+ "ref_prompt_weight": ref_prompt_weight,
32
+ "n": BATCH_SIZE
33
+ }
34
+ }
35
+ url_create_task = 'https://dashscope.aliyuncs.com/api/v1/services/aigc/background-generation/generation'
36
+ res_ = requests.post(url_create_task, data=json.dumps(data), headers=headers)
37
+
38
+ respose_code = res_.status_code
39
+ if 200 == respose_code:
40
+ res = json.loads(res_.content.decode())
41
+ request_id = res['request_id']
42
+ task_id = res['output']['task_id']
43
+ logger.info(f"task_id: {task_id}: Create Background Generation request success. Params: {data}")
44
+
45
+ # 异步查询
46
+ is_running = True
47
+ while is_running:
48
+ url_query = f'https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}'
49
+ res_ = requests.post(url_query, headers=headers)
50
+ respose_code = res_.status_code
51
+ if 200 == respose_code:
52
+ res = json.loads(res_.content.decode())
53
+ if "SUCCEEDED" == res['output']['task_status']:
54
+ logger.info(f"task_id: {task_id}: Background generation task query success.")
55
+ results = res['output']['results']
56
+ img_urls = [x['url'] for x in results]
57
+ logger.info(f"task_id: {task_id}: {res}")
58
+ break
59
+ elif "FAILED" != res['output']['task_status']:
60
+ logger.debug(f"task_id: {task_id}: query result...")
61
+ time.sleep(1)
62
+ else:
63
+ raise gr.Error('Fail to get results from Background Generation task.')
64
+
65
+ else:
66
+ logger.error(f'task_id: {task_id}: Fail to query task result: {res_.content}')
67
+ raise gr.Error("Fail to query task result.")
68
+
69
+ logger.info(f"task_id: {task_id}: download generated images.")
70
+ img_data = download_images(img_urls, BATCH_SIZE)
71
+ logger.info(f"task_id: {task_id}: Generate done.")
72
+ return img_data
73
+ else:
74
+ logger.error(f'Fail to create Background Generation task: {res_.content}')
75
+ raise gr.Error("Fail to create Background Generation task.")
76
+
src/log.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from logging.handlers import RotatingFileHandler
3
+ import os
4
+
5
+ log_file_name = "workdir/log_replaceAnything.log"
6
+ os.makedirs(os.path.dirname(log_file_name), exist_ok=True)
7
+
8
+ format = '[%(levelname)s] %(asctime)s "%(filename)s", line %(lineno)d, %(message)s'
9
+ logging.basicConfig(
10
+ format=format,
11
+ datefmt="%Y-%m-%d %H:%M:%S",
12
+ level=logging.INFO)
13
+ logger = logging.getLogger(name="WordArt_Studio")
14
+
15
+ fh = RotatingFileHandler(log_file_name, maxBytes=20000000, backupCount=3)
16
+ formatter = logging.Formatter(format, datefmt="%Y-%m-%d %H:%M:%S")
17
+ fh.setFormatter(formatter)
18
+ logger.addHandler(fh)
src/person_detect.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy
3
+ from PIL import Image
4
+ import requests
5
+ import urllib.request
6
+ from http import HTTPStatus
7
+ from datetime import datetime
8
+ import json
9
+ from .log import logger
10
+ import time
11
+ import gradio as gr
12
+ from .util import download_images
13
+
14
+ API_KEY = os.getenv("API_KEY_VIRTUALMODEL")
15
+
16
+ def call_person_detect(input_image_url):
17
+ headers = {
18
+ "Content-Type": "application/json",
19
+ "Accept": "application/json",
20
+ "Authorization": f"Bearer {API_KEY}",
21
+ "X-DashScope-DataInspection": "enable",
22
+ }
23
+ data = {
24
+ "model": "body-detection",
25
+ "input":{
26
+ "image_url": input_image_url,
27
+ },
28
+ "parameters": {
29
+ "score_threshold": 0.6,
30
+ }
31
+ }
32
+ url_create_task = 'https://dashscope.aliyuncs.com/api/v1/services/vision/bodydetection/detect'
33
+ res_ = requests.post(url_create_task, data=json.dumps(data), headers=headers)
34
+
35
+
36
+ res = json.loads(res_.content.decode())
37
+ request_id = res['request_id']
38
+ results = res['output']['results']
39
+ return results
src/util.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import numpy as np
4
+ import cv2
5
+ import os
6
+ import io
7
+ import oss2
8
+ from PIL import Image
9
+
10
+ import dashscope
11
+ from dashscope import MultiModalConversation
12
+
13
+ from http import HTTPStatus
14
+ import re
15
+ import requests
16
+ from .log import logger
17
+ import concurrent.futures
18
+
19
+ dashscope.api_key = os.getenv("API_KEY_QW")
20
+ # oss
21
+ access_key_id = os.getenv("ACCESS_KEY_ID")
22
+ access_key_secret = os.getenv("ACCESS_KEY_SECRET")
23
+ bucket_name = os.getenv("BUCKET_NAME")
24
+ endpoint = os.getenv("ENDPOINT")
25
+
26
+ bucket = oss2.Bucket(oss2.Auth(access_key_id, access_key_secret), endpoint, bucket_name)
27
+ oss_path = "ashui"
28
+ oss_path_img_gallery = "ashui_img_gallery"
29
+
30
+ def download_img_pil(index, img_url):
31
+ # print(img_url)
32
+ r = requests.get(img_url, stream=True)
33
+ if r.status_code == 200:
34
+ img = Image.open(io.BytesIO(r.content))
35
+ return (index, img)
36
+ else:
37
+ logger.error(f"Fail to download: {img_url}")
38
+
39
+
40
+ def download_images(img_urls, batch_size):
41
+ imgs_pil = [None] * batch_size
42
+ # worker_results = []
43
+ with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
44
+ to_do = []
45
+ for i, url in enumerate(img_urls):
46
+ future = executor.submit(download_img_pil, i, url)
47
+ to_do.append(future)
48
+
49
+ for future in concurrent.futures.as_completed(to_do):
50
+ ret = future.result()
51
+ # worker_results.append(ret)
52
+ index, img_pil = ret
53
+ imgs_pil[index] = img_pil # 按顺序排列url,后续下载关联的图片或者svg需要使用
54
+
55
+ return imgs_pil
56
+
57
+ def upload_np_2_oss(input_image, name="cache.png", gallery=False):
58
+ imgByteArr = io.BytesIO()
59
+ Image.fromarray(input_image).save(imgByteArr, format="PNG")
60
+ imgByteArr = imgByteArr.getvalue()
61
+
62
+ if gallery:
63
+ path = oss_path_img_gallery
64
+ else:
65
+ path = oss_path
66
+
67
+ bucket.put_object(path+"/"+name, imgByteArr) # data为数据,可以是图片
68
+ ret = bucket.sign_url('GET', path+"/"+name, 60*60*24) # 返回值为链接,参数依次为,方法/oss上文件路径/过期时间(s)
69
+ del imgByteArr
70
+ return ret
71
+
72
+
73
+ def call_with_messages(prompt):
74
+ messages = [
75
+ {'role': 'user', 'content': prompt}]
76
+ response = dashscope.Generation.call(
77
+ 'qwen-14b-chat',
78
+ messages=messages,
79
+ result_format='message', # set the result is message format.
80
+ )
81
+ if response.status_code == HTTPStatus.OK:
82
+ return response['output']["choices"][0]["message"]['content']
83
+ else:
84
+ print('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
85
+ response.request_id, response.status_code,
86
+ response.code, response.message
87
+ ))
88
+ return None
89
+
90
+ def HWC3(x):
91
+ assert x.dtype == np.uint8
92
+ if x.ndim == 2:
93
+ x = x[:, :, None]
94
+ assert x.ndim == 3
95
+ H, W, C = x.shape
96
+ assert C == 1 or C == 3 or C == 4
97
+ if C == 3:
98
+ return x
99
+ if C == 1:
100
+ return np.concatenate([x, x, x], axis=2)
101
+ if C == 4:
102
+ color = x[:, :, 0:3].astype(np.float32)
103
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
104
+ y = color * alpha + 255.0 * (1.0 - alpha)
105
+ y = y.clip(0, 255).astype(np.uint8)
106
+ return y
107
+
108
+
109
+ def resize_image(input_image, resolution):
110
+ H, W, C = input_image.shape
111
+ H = float(H)
112
+ W = float(W)
113
+ k = float(resolution) / min(H, W)
114
+ H *= k
115
+ W *= k
116
+ H = int(np.round(H / 64.0)) * 64
117
+ W = int(np.round(W / 64.0)) * 64
118
+ img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
119
+ return img
120
+
121
+
122
+ def nms(x, t, s):
123
+ x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
124
+
125
+ f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
126
+ f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
127
+ f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
128
+ f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
129
+
130
+ y = np.zeros_like(x)
131
+
132
+ for f in [f1, f2, f3, f4]:
133
+ np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
134
+
135
+ z = np.zeros_like(y, dtype=np.uint8)
136
+ z[y > t] = 255
137
+ return z
138
+
139
+
140
+ def make_noise_disk(H, W, C, F):
141
+ noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
142
+ noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
143
+ noise = noise[F: F + H, F: F + W]
144
+ noise -= np.min(noise)
145
+ noise /= np.max(noise)
146
+ if C == 1:
147
+ noise = noise[:, :, None]
148
+ return noise
149
+
150
+
151
+ def min_max_norm(x):
152
+ x -= np.min(x)
153
+ x /= np.maximum(np.max(x), 1e-5)
154
+ return x
155
+
156
+
157
+ def safe_step(x, step=2):
158
+ y = x.astype(np.float32) * float(step + 1)
159
+ y = y.astype(np.int32).astype(np.float32) / float(step)
160
+ return y
161
+
162
+
163
+ def img2mask(img, H, W, low=10, high=90):
164
+ assert img.ndim == 3 or img.ndim == 2
165
+ assert img.dtype == np.uint8
166
+
167
+ if img.ndim == 3:
168
+ y = img[:, :, random.randrange(0, img.shape[2])]
169
+ else:
170
+ y = img
171
+
172
+ y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC)
173
+
174
+ if random.uniform(0, 1) < 0.5:
175
+ y = 255 - y
176
+
177
+ return y < np.percentile(y, random.randrange(low, high))
src/virtualmodel.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy
3
+ from PIL import Image
4
+ import requests
5
+ import urllib.request
6
+ from http import HTTPStatus
7
+ from datetime import datetime
8
+ import json
9
+ from .log import logger
10
+ import time
11
+ import gradio as gr
12
+ from .util import download_images
13
+
14
+ API_KEY = os.getenv("API_KEY_VIRTUALMODEL")
15
+
16
+ def call_virtualmodel(input_image_url, input_mask_url, source_background_url, prompt, face_prompt):
17
+ BATCH_SIZE=4
18
+ headers = {
19
+ "Content-Type": "application/json",
20
+ "Accept": "application/json",
21
+ "Authorization": f"Bearer {API_KEY}",
22
+ "X-DashScope-Async": "enable",
23
+ }
24
+ data = {
25
+ "model": "wanx-virtualmodel",
26
+ "input":{
27
+ "base_image_url": input_image_url,
28
+ "mask_image_url": input_mask_url,
29
+ "prompt": prompt,
30
+ "face_prompt": face_prompt,
31
+ "background_image_url": source_background_url,
32
+ },
33
+ "parameters": {
34
+ "short_side_size": "512",
35
+ "n": BATCH_SIZE
36
+ }
37
+ }
38
+ url_create_task = 'https://dashscope.aliyuncs.com/api/v1/services/aigc/virtualmodel/generation'
39
+ res_ = requests.post(url_create_task, data=json.dumps(data), headers=headers)
40
+
41
+ respose_code = res_.status_code
42
+ if 200 == respose_code:
43
+ res = json.loads(res_.content.decode())
44
+ request_id = res['request_id']
45
+ task_id = res['output']['task_id']
46
+ logger.info(f"task_id: {task_id}: Create VirtualModel request success. Params: {data}")
47
+
48
+ # 异步查询
49
+ is_running = True
50
+ while is_running:
51
+ url_query = f'https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}'
52
+ res_ = requests.post(url_query, headers=headers)
53
+ respose_code = res_.status_code
54
+ if 200 == respose_code:
55
+ res = json.loads(res_.content.decode())
56
+ if "SUCCEEDED" == res['output']['task_status']:
57
+ logger.info(f"task_id: {task_id}: VirtualModel generation task query success.")
58
+ results = res['output']['results']
59
+ img_urls = []
60
+ for x in results:
61
+ if "url" in x:
62
+ img_urls.append(x['url'])
63
+ logger.info(f"task_id: {task_id}: {res}")
64
+ break
65
+ elif "FAILED" != res['output']['task_status']:
66
+ logger.debug(f"task_id: {task_id}: query result...")
67
+ time.sleep(1)
68
+ else:
69
+ raise gr.Error('Fail to get results from VirtualModel task.')
70
+
71
+ else:
72
+ logger.error(f'task_id: {task_id}: Fail to query task result: {res_.content}')
73
+ raise gr.Error("Fail to query task result.")
74
+
75
+ logger.info(f"task_id: {task_id}: download generated images.")
76
+ img_data = download_images(img_urls, len(img_urls)) if len(img_urls) > 0 else []
77
+ logger.info(f"task_id: {task_id}: Generate done.")
78
+ return img_data
79
+ else:
80
+ logger.error(f'Fail to create VirtualModel task: {res_.content}')
81
+ raise gr.Error("Fail to create VirtualModel task.")