diff --git a/README.md b/README.md
index 5c0c69382c0d85e57152b7553610e0635d999ca5..f117aed0ee105bd3675f5891962582744b302b3e 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,20 @@
---
-title: SAM SDXL Inpainting
-emoji: 🦀
-colorFrom: red
-colorTo: gray
+title: ReplaceAnything Using SAM + SDXL Inpainting
+emoji: 📚
+colorFrom: yellow
+colorTo: blue
sdk: gradio
-sdk_version: 4.14.0
+sdk_version: 3.50.2
app_file: app.py
pinned: false
+license: apache-2.0
---
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
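+
+To run the app locally (a minimal sketch, assuming the SAM ViT-H checkpoint has been downloaded to `models/sam_vit_h_4b8939.pth` as described in `models/DOWNLOAD_MODEL_HERE.txt`, and that a CUDA GPU is available, since the app loads SDXL in fp16):
+
+```bash
+pip install -r requirements.txt
+python app.py
+```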
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a04fad64b1dcbccce7371c3d257a2187274813c
--- /dev/null
+++ b/app.py
@@ -0,0 +1,311 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+# @Time    : 2023-06-01
+# @Author  : ashui(Binghui Chen)
+import time
+import cv2
+import gradio as gr
+import numpy as np
+import random
+import math
+import uuid
+import torch
+from torch import autocast
+
+from src.util import resize_image, upload_np_2_oss
+from diffusers import AutoPipelineForInpainting, UNet2DConditionModel
+import diffusers
+import sys, os
+
+from PIL import Image, ImageFilter, ImageOps, ImageDraw
+
+from segment_anything import SamPredictor, sam_model_registry
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+pipe = AutoPipelineForInpainting.from_pretrained("diffusers/stable-diffusion-xl-1.0-inpainting-0.1", torch_dtype=torch.float16, variant="fp16").to(device)
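+# NOTE: the fp16 variant assumes a CUDA device; a CPU-only fallback sketch would
+# load float32 weights instead, e.g.
+#   AutoPipelineForInpainting.from_pretrained(
+#       "diffusers/stable-diffusion-xl-1.0-inpainting-0.1", torch_dtype=torch.float32)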
+
+mobile_sam = sam_model_registry['vit_h'](checkpoint='models/sam_vit_h_4b8939.pth').to(device)
+mobile_sam.eval()
+mobile_predictor = SamPredictor(mobile_sam)
+colors = [(255, 0, 0), (0, 255, 0)]
+markers = [1, 5]
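+# indexed by SAM point label: 0 = background (red, tilted-cross marker),
+# 1 = foreground (green, triangle marker)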
+
+# - - - - - examples - - - - - #
+# columns: input image path, prompt text, background image path, example index, selected points
+image_examples = [
+ ["imgs/000.jpg", "A young woman in short sleeves shows off a mobile phone", None, 0, []],
+ ["imgs/001.jpg", "A young woman wears short sleeves, her hand is holding a bottle.", None, 1, []],
+ ["imgs/003.png", "A woman is wearing a black suit against a blue background", "imgs/003_bg.jpg", 2, []],
+ ["imgs/002.png", "A young woman poses in a dress, she stands in front of a blue background", "imgs/002_bg.png", 3, []],
+ ["imgs/bg_gen/base_imgs/1cdb9b1e6daea6a1b85236595d3e43d6.png", "water splash", None, 4, []],
+ ["imgs/bg_gen/base_imgs/1cdb9b1e6daea6a1b85236595d3e43d6.png", "", "imgs/bg_gen/ref_imgs/df9a93ac2bca12696a9166182c4bf02ad9679aa5.jpg", 5, []],
+ ["imgs/bg_gen/base_imgs/IMG_2941.png", "On the desert floor", None, 6, []],
+ ["imgs/bg_gen/base_imgs/b2b1ed243364473e49d2e478e4f24413.png","White ground, white background, light coming in, Canon",None,7,[]],
+ ]
+
+img = "image_gallery/"
+files = os.listdir(img)
+files = sorted(files)
+showcases = []
+for idx, name in enumerate(files):
+ temp = os.path.join(os.path.dirname(__file__), img, name)
+ showcases.append(temp)
+
+def process(original_image, original_mask, input_mask, selected_points, prompt, negative_prompt, guidance_scale, steps, strength, scheduler):
+    if original_image is None:
+        raise gr.Error('Please upload the input image')
+    if (original_mask is None or len(selected_points) == 0) and input_mask is None:
+        raise gr.Error("Please click the region you want to keep unchanged, or upload a black-and-white mask image in which white marks the region to be retained.")
+
+    # examples pass an integer index instead of an image array
+    if isinstance(original_image, int):
+        image_name = image_examples[original_image][0]
+        original_image = cv2.imread(image_name)
+        original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
+
+    # cap portrait images at a height of 1000 px, keeping the aspect ratio
+    if original_image.shape[0] > original_image.shape[1]:
+        original_image = cv2.resize(original_image, (int(original_image.shape[1]*1000/original_image.shape[0]), 1000))
+    if original_mask is not None and original_mask.shape[0] > original_mask.shape[1]:
+        original_mask = cv2.resize(original_mask, (int(original_mask.shape[1]*1000/original_mask.shape[0]), 1000))
+
+    if input_mask is not None:
+        H, W = original_image.shape[:2]
+        original_mask = cv2.resize(input_mask, (W, H))
+    else:
+        # the click-derived mask is black on the selected region and white elsewhere;
+        # invert it to the diffusers convention, where white pixels are repainted
+        original_mask = np.clip(255 - original_mask, 0, 255).astype(np.uint8)
+
+ request_id = str(uuid.uuid4())
+ # input_image_url = upload_np_2_oss(original_image, request_id+".png")
+ # input_mask_url = upload_np_2_oss(original_mask, request_id+"_mask.png")
+ # source_background_url = "" if source_background is None else upload_np_2_oss(source_background, request_id+"_bg.png")
+ if negative_prompt == "":
+ negative_prompt = None
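+    # scheduler names are encoded as "ClassName[-Karras[-SDE]]"; the suffixes
+    # enable Karras sigmas and the SDE variant of DPM-Solver++, respectively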
+ scheduler_class_name = scheduler.split("-")[0]
+
+ add_kwargs = {}
+ if len(scheduler.split("-")) > 1:
+        add_kwargs["use_karras_sigmas"] = True
+ if len(scheduler.split("-")) > 2:
+ add_kwargs["algorithm_type"] = "sde-dpmsolver++"
+
+ scheduler = getattr(diffusers, scheduler_class_name)
+ pipe.scheduler = scheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler", **add_kwargs)
+
+ # Image.fromarray(original_mask).save("original_mask.png")
+ init_image = Image.fromarray(original_image).convert("RGB")
+ mask = Image.fromarray(original_mask).convert("RGB")
+ output = pipe(prompt = prompt, negative_prompt=negative_prompt, image=init_image, mask_image=mask, guidance_scale=guidance_scale, num_inference_steps=int(steps), strength=strength)
+ # person detect: [[x1,y1,x2,y2,score],]
+ # det_res = call_person_detect(input_image_url)
+
+ res = []
+ # if len(det_res)>0:
+ # if len(prompt)==0:
+ # raise gr.Error('Please input the prompt')
+ # # res = call_virtualmodel(input_image_url, input_mask_url, source_background_url, prompt, face_prompt)
+ # else:
+ # ###
+ # if len(prompt)==0:
+ # prompt=None
+ # ref_image_url=None if source_background_url =='' else source_background_url
+ # original_mask=original_mask[:,:,:1]
+ # base_image=np.concatenate([original_image, original_mask],axis=2)
+ # base_image_url=upload_np_2_oss(base_image, request_id+"_base.png")
+ # res=call_bg_genration(base_image_url,ref_image_url,prompt,ref_prompt_weight=0.5)
+ # Image.fromarray(input_mask).save("input_mask.png")
+    res = output.images[0].convert("RGB")
+    # resize the output back to the (possibly downscaled) input size
+    res = res.resize((original_image.shape[1], original_image.shape[0]), Image.LANCZOS)
+    return [res], request_id, True
+
+block = gr.Blocks(
+ css="css/style.css",
+ theme=gr.themes.Soft(
+ radius_size=gr.themes.sizes.radius_none,
+ text_size=gr.themes.sizes.text_md
+ )
+ ).queue(concurrency_count=2)
+with block:
+ with gr.Row():
+ with gr.Column():
+            gr.HTML("""
+                <div style="text-align: center;">
+                    <h1>SAM + SDXL Inpainting</h1>
+                    <h3>ReplaceAnything using SAM + SDXL Inpainting as you want: Ultra-high quality content replacement</h3>
+                </div>
+            """)
+
+ with gr.Tabs(elem_classes=["Tab"]):
+ with gr.TabItem("Image Create"):
+ with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"):
+ with gr.Row(equal_height=True):
+                    gr.Markdown("""
+                    - ⭐️ Step 1: Upload an image or select one from the examples
+                    - ⭐️ Step 2: Click the input image to mark the object to keep (or upload a black-and-white mask image, in which white marks the region to keep unchanged)
+                    - ⭐️ Step 3: Enter a prompt or a reference image (highly recommended) to generate the new content
+                    - ⭐️ Step 4: Click the Run button
+                    """)
+ with gr.Row():
+ with gr.Column():
+ with gr.Column(elem_id="Input"):
+ with gr.Row():
+ with gr.Tabs(elem_classes=["feedback"]):
+ with gr.TabItem("Input Image"):
+ input_image = gr.Image(type="numpy", label="input",scale=2)
+ original_image = gr.State(value=None,label="index")
+ original_mask = gr.State(value=None)
+ selected_points = gr.State([],label="click points")
+ with gr.Row(elem_id="Seg"):
+ radio = gr.Radio(['foreground', 'background'], label='Click to seg: ', value='foreground',scale=2)
+ undo_button = gr.Button('Undo seg', elem_id="btnSEG",scale=1)
+ input_mask = gr.Image(type="numpy", label="Mask Image")
+ prompt = gr.Textbox(label="Prompt", placeholder="Please input your prompt",value='',lines=1)
+                        negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Please input your negative prompt", value='hand,blur,face,bad', lines=1)
+ guidance_scale = gr.Number(value=7.5, minimum=1.0, maximum=20.0, step=0.1, label="guidance_scale")
+ steps = gr.Number(value=20, minimum=10, maximum=30, step=1, label="steps")
+ strength = gr.Number(value=0.99, minimum=0.01, maximum=1.0, step=0.01, label="strength")
+ with gr.Row(mobile_collapse=False, equal_height=True):
+ schedulers = ["DEISMultistepScheduler", "HeunDiscreteScheduler", "EulerDiscreteScheduler", "DPMSolverMultistepScheduler", "DPMSolverMultistepScheduler-Karras", "DPMSolverMultistepScheduler-Karras-SDE"]
+ scheduler = gr.Dropdown(label="Schedulers", choices=schedulers, value="EulerDiscreteScheduler")
+
+ run_button = gr.Button("Run",elem_id="btn")
+
+ with gr.Column():
+ with gr.Tabs(elem_classes=["feedback"]):
+ with gr.TabItem("Outputs"):
+ result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", preview=True)
+ # recommend=gr.Button("Recommend results to Image Gallery",elem_id="recBut")
+ request_id=gr.State(value="")
+ gallery_flag=gr.State(value=False)
+
+ # once user upload an image, the original image is stored in `original_image`
+ def store_img(img):
+ # image upload is too slow
+ # if min(img.shape[0], img.shape[1]) > 896:
+ # img = resize_image(img, 896)
+ # if max(img.shape[0], img.shape[1])*1.0/min(img.shape[0], img.shape[1])>2.0:
+ # raise gr.Error('image aspect ratio cannot be larger than 2.0')
+        return img, img, [], None  # reset the click points and the click-derived mask whenever a new image is uploaded
+
+ input_image.upload(
+ store_img,
+ [input_image],
+        [input_image, original_image, selected_points, original_mask]
+ )
+
+ # user click the image to get points, and show the points on the image
+ def segmentation(img, sel_pix):
+ # online show seg mask
+ points = []
+ labels = []
+ for p, l in sel_pix:
+ points.append(p)
+ labels.append(l)
+ mobile_predictor.set_image(img if isinstance(img, np.ndarray) else np.array(img))
+ with torch.no_grad():
+            with autocast(device):
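+                # multimask_output=False, so `masks` has shape (1, H, W) of bools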
+ masks, _, _ = mobile_predictor.predict(point_coords=np.array(points), point_labels=np.array(labels), multimask_output=False)
+
+        # white everywhere except the selected region, which is set to black
+        output_mask = np.ones((masks.shape[1], masks.shape[2], 3)) * 255
+        output_mask[masks[0]] = 0.0
+
+        # overlay a random color on the selected region for the on-screen preview
+        mask_all = np.ones((masks.shape[1], masks.shape[2], 3))
+        color_mask = np.random.random(3)
+        mask_all[masks[0]] = color_mask
+        masked_img = img / 255 * 0.3 + mask_all * 0.7
+        masked_img = masked_img * 255
+ ## draw points
+ for point, label in sel_pix:
+ cv2.drawMarker(masked_img, point, colors[label], markerType=markers[label], markerSize=20, thickness=5)
+ return masked_img, output_mask
+
+ def get_point(img, sel_pix, point_type, evt: gr.SelectData):
+
+ if point_type == 'foreground':
+ sel_pix.append((evt.index, 1)) # append the foreground_point
+ elif point_type == 'background':
+ sel_pix.append((evt.index, 0)) # append the background_point
+ else:
+ sel_pix.append((evt.index, 1)) # default foreground_point
+
+ if isinstance(img, int):
+ image_name = image_examples[img][0]
+ img = cv2.imread(image_name)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+
+ # online show seg mask
+ if img.shape[0]>img.shape[1]:
+ img=cv2.resize(img,(int(img.shape[1]*1000/img.shape[0]),1000))
+ masked_img, output_mask = segmentation(img, sel_pix)
+
+ return masked_img.astype(np.uint8), output_mask
+
+ input_image.select(
+ get_point,
+ [original_image, selected_points, radio],
+ [input_image, original_mask],
+ )
+
+ # undo the selected point
+ def undo_points(orig_img, sel_pix):
+ # draw points
+ output_mask = None
+ if len(sel_pix) != 0:
+            if isinstance(orig_img, int):  # an int index means the image comes from the examples
+ temp = cv2.imread(image_examples[orig_img][0])
+ temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB)
+ else:
+ temp = orig_img.copy()
+ sel_pix.pop()
+ # online show seg mask
+ if len(sel_pix) !=0:
+ temp, output_mask = segmentation(temp, sel_pix)
+ return temp.astype(np.uint8), output_mask
+ else:
+            raise gr.Error("Nothing to undo")  # gr.Error only surfaces in the UI when raised
+
+ undo_button.click(
+ undo_points,
+ [original_image, selected_points],
+ [input_image, original_mask]
+ )
+
+ def upload_to_img_gallery(img, res, re_id, flag):
+ if flag:
+            gr.Info("Uploading images")
+ if isinstance(img, int):
+ image_name = image_examples[img][0]
+ img = cv2.imread(image_name)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ _ = upload_np_2_oss(img, name=re_id+"_ori.jpg", gallery=True)
+ for idx, r in enumerate(res):
+ r = cv2.imread(r['name'])
+ r = cv2.cvtColor(r, cv2.COLOR_BGR2RGB)
+ _ = upload_np_2_oss(r, name=re_id+f"_res_{idx}.jpg", gallery=True)
+ flag=False
+            gr.Info("Images have been uploaded and are pending review")
+ else:
+            gr.Info("Nothing to do")
+ return flag
+
+ # recommend.click(
+ # upload_to_img_gallery,
+ # [original_image, result_gallery, request_id, gallery_flag],
+ # [gallery_flag]
+ # )
+ # ips=[input_image, original_image, original_mask, input_mask, selected_points, prompt,negative_prompt,guidance_scale,steps,strength,scheduler]
+ ips=[original_image, original_mask, input_mask, selected_points, prompt,negative_prompt,guidance_scale,steps,strength,scheduler]
+ run_button.click(fn=process, inputs=ips, outputs=[result_gallery, request_id, gallery_flag])
+
+
+block.launch(share=True)
diff --git a/css/0.png b/css/0.png
new file mode 100644
index 0000000000000000000000000000000000000000..1f71aacffbb3adda52c1e86b0b3064533e6fb65b
Binary files /dev/null and b/css/0.png differ
diff --git a/css/style.css b/css/style.css
new file mode 100644
index 0000000000000000000000000000000000000000..e21f344b481a57cd5e50a87e6b111caaea9b7222
--- /dev/null
+++ b/css/style.css
@@ -0,0 +1,59 @@
+
+.baselayout{
+ background: url('https://img.alicdn.com/imgextra/i1/O1CN016hd0V91ilWY5Xr24B_!!6000000004453-2-tps-2882-256.png') no-repeat;
+}
+#btn {
+ background-color: #336699;
+ color: white;
+}
+#recBut {
+ background-color: #bb5252;
+ color: white;
+ width: 30%;
+ margin: auto;
+}
+#btnSEG {
+ background-color: #D5F3F4;
+ color: black;
+}
+#btnCHAT {
+ background-color: #B6DBF2;
+ color: black;
+}
+#accordion {
+ background-color: transparent;
+}
+#accordion1 {
+ background-color: #ecedee;
+}
+.feedback button.selected{
+ background-color: #6699CC;
+ color: white !important;
+}
+.feedback1 button.selected{
+ background-color: #839ab2;
+ color: white !important;
+}
+.Tab button.selected{
+ color: red;
+ font-weight: bold;
+}
+#Image {
+ width: 80%;
+ margin:auto;
+}
+#ShowCase {
+ width: 30%;
+ flex:none !important;
+}
+
+#Input {
+ border-style:solid;
+ border-width:1px;
+ border-color:#000000
+}
+#Seg {
+ min-width: min(100px, 100%) !important;
+ width: 100%;
+ margin:auto;
+}
diff --git a/image_gallery/00.png b/image_gallery/00.png
new file mode 100644
index 0000000000000000000000000000000000000000..5085f085aebca32adca69602cc44c5616652bdc8
Binary files /dev/null and b/image_gallery/00.png differ
diff --git a/image_gallery/01.png b/image_gallery/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..3bc5574ca6f5a700a20b23fe81b59298fc289629
Binary files /dev/null and b/image_gallery/01.png differ
diff --git a/image_gallery/02.png b/image_gallery/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..b5579a57d6d25c9897ff04c2bdc887b36127727e
Binary files /dev/null and b/image_gallery/02.png differ
diff --git a/image_gallery/a001.jpg b/image_gallery/a001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0958da232bc2adc083183cb8e08aab910cc5eac3
Binary files /dev/null and b/image_gallery/a001.jpg differ
diff --git a/image_gallery/a002.jpg b/image_gallery/a002.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..95abdbac66d0bcf2ca6739c0fce703838311b313
Binary files /dev/null and b/image_gallery/a002.jpg differ
diff --git a/image_gallery/a003.jpg b/image_gallery/a003.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cc4b86d937f0e8d735d6f5a8494bdc3063b518c5
Binary files /dev/null and b/image_gallery/a003.jpg differ
diff --git a/image_gallery/a004.jpg b/image_gallery/a004.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..06697167bd9f1338347007951aa78612871f71c1
Binary files /dev/null and b/image_gallery/a004.jpg differ
diff --git a/image_gallery/a005.jpg b/image_gallery/a005.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c267cebb616865728e6eed74a4ff723c34a82c11
Binary files /dev/null and b/image_gallery/a005.jpg differ
diff --git a/image_gallery/a006.jpg b/image_gallery/a006.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..21a3762d1f1b60c8dabed2632c1dfafada56ad68
Binary files /dev/null and b/image_gallery/a006.jpg differ
diff --git a/image_gallery/a007.jpg b/image_gallery/a007.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ba3204af588e6b1be7dce38d5bfff816ccc528dc
Binary files /dev/null and b/image_gallery/a007.jpg differ
diff --git a/image_gallery/a009.jpg b/image_gallery/a009.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8b973739cb5677c783e8a572fe79a0dd49b88f7d
Binary files /dev/null and b/image_gallery/a009.jpg differ
diff --git a/image_gallery/bg_001.jpg b/image_gallery/bg_001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c05697e0c5b0649fb70ee44066c8d31123adb7ae
Binary files /dev/null and b/image_gallery/bg_001.jpg differ
diff --git a/image_gallery/bg_002.jpg b/image_gallery/bg_002.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e3bc03054ca5f1f1afa48c277d5f86640839da77
Binary files /dev/null and b/image_gallery/bg_002.jpg differ
diff --git a/image_gallery/bg_003.jpg b/image_gallery/bg_003.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..51f375bb74521625bea9d1b44de2fcfbc53520f1
Binary files /dev/null and b/image_gallery/bg_003.jpg differ
diff --git a/image_gallery/bg_004.jpg b/image_gallery/bg_004.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8d72dbd040c0f22ca7dca5501906f3ae83d2d82f
Binary files /dev/null and b/image_gallery/bg_004.jpg differ
diff --git a/image_gallery/bg_005.jpg b/image_gallery/bg_005.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d86a9c922e484c608d8287e77478689a3e74c7d6
Binary files /dev/null and b/image_gallery/bg_005.jpg differ
diff --git a/image_gallery/bg_006.jpg b/image_gallery/bg_006.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9694d3ae04de57ac65fb05b9af59982923012194
Binary files /dev/null and b/image_gallery/bg_006.jpg differ
diff --git a/image_gallery/bg_007.jpg b/image_gallery/bg_007.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d5e8c0a39420bb04c1d32360eabd29d0d8298468
Binary files /dev/null and b/image_gallery/bg_007.jpg differ
diff --git a/image_gallery/bg_008.jpg b/image_gallery/bg_008.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b335a97afd7edff664219a3debc540180a9752f0
Binary files /dev/null and b/image_gallery/bg_008.jpg differ
diff --git a/image_gallery/bg_009.jpg b/image_gallery/bg_009.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..924dbac8340420e35fe428de31069179dfb816d3
Binary files /dev/null and b/image_gallery/bg_009.jpg differ
diff --git a/image_gallery/bg_010.jpg b/image_gallery/bg_010.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1b0035c9f14004a3efb73289c41b787e330707ce
Binary files /dev/null and b/image_gallery/bg_010.jpg differ
diff --git a/image_gallery/bg_012.jpg b/image_gallery/bg_012.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..77ceb469a7991e6fcbcdaaf96b7600735558c1c3
Binary files /dev/null and b/image_gallery/bg_012.jpg differ
diff --git a/imgs/000.jpg b/imgs/000.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..47d05a3266c2d87452d0b589bab3082a611d9785
Binary files /dev/null and b/imgs/000.jpg differ
diff --git a/imgs/001.jpg b/imgs/001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c32bfa118314574f6c2c90acd3ffd00e0708f0c5
Binary files /dev/null and b/imgs/001.jpg differ
diff --git a/imgs/002.png b/imgs/002.png
new file mode 100644
index 0000000000000000000000000000000000000000..3c35d0b5230c1875a1ff2720dcd5a44cd7ca3da1
Binary files /dev/null and b/imgs/002.png differ
diff --git a/imgs/002_bg.png b/imgs/002_bg.png
new file mode 100644
index 0000000000000000000000000000000000000000..a15feedee1e36555163c729f1d653ff08dd510ea
Binary files /dev/null and b/imgs/002_bg.png differ
diff --git a/imgs/003.png b/imgs/003.png
new file mode 100644
index 0000000000000000000000000000000000000000..9903793c03fb477d5432df658c1342515a706d8a
Binary files /dev/null and b/imgs/003.png differ
diff --git a/imgs/003_bg.jpg b/imgs/003_bg.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a6443f1076b0fc4631f30f1b29220dc713247f47
Binary files /dev/null and b/imgs/003_bg.jpg differ
diff --git a/imgs/bg_gen/base_imgs/1cdb9b1e6daea6a1b85236595d3e43d6.png b/imgs/bg_gen/base_imgs/1cdb9b1e6daea6a1b85236595d3e43d6.png
new file mode 100644
index 0000000000000000000000000000000000000000..5ea82ebddcae135dccedcef526a73cd2f24c0e47
Binary files /dev/null and b/imgs/bg_gen/base_imgs/1cdb9b1e6daea6a1b85236595d3e43d6.png differ
diff --git a/imgs/bg_gen/base_imgs/IMG_2941.png b/imgs/bg_gen/base_imgs/IMG_2941.png
new file mode 100644
index 0000000000000000000000000000000000000000..37f84d452ae2f175b5e4f9412bd057bfdfa9e8d7
Binary files /dev/null and b/imgs/bg_gen/base_imgs/IMG_2941.png differ
diff --git a/imgs/bg_gen/base_imgs/b2b1ed243364473e49d2e478e4f24413.png b/imgs/bg_gen/base_imgs/b2b1ed243364473e49d2e478e4f24413.png
new file mode 100644
index 0000000000000000000000000000000000000000..db679530a008829ae5d412ebd313bddb7858599b
Binary files /dev/null and b/imgs/bg_gen/base_imgs/b2b1ed243364473e49d2e478e4f24413.png differ
diff --git a/imgs/bg_gen/ref_imgs/df9a93ac2bca12696a9166182c4bf02ad9679aa5.jpg b/imgs/bg_gen/ref_imgs/df9a93ac2bca12696a9166182c4bf02ad9679aa5.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..87389876dd825fd733ebb1cb822707da23cf2388
Binary files /dev/null and b/imgs/bg_gen/ref_imgs/df9a93ac2bca12696a9166182c4bf02ad9679aa5.jpg differ
diff --git a/models/DOWNLOAD_MODEL_HERE.txt b/models/DOWNLOAD_MODEL_HERE.txt
new file mode 100644
index 0000000000000000000000000000000000000000..af74612ff809a40bf66785523e513702961ccb7b
--- /dev/null
+++ b/models/DOWNLOAD_MODEL_HERE.txt
@@ -0,0 +1,5 @@
+Model download link:
+https://vision-poster.oss-cn-shanghai.aliyuncs.com/ashui/sam_vit_h_4b8939.pth?OSSAccessKeyId=LTAI5tSPYbksBzcmooNHCYif&Expires=3599001703148669&Signature=TYznO77DKFjGNn92SnR9RbucOlU%3D
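+
+The same ViT-H checkpoint is also published in the official segment-anything release:
+https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth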
diff --git a/models/sam_vit_h_4b8939.pth b/models/sam_vit_h_4b8939.pth
new file mode 100644
index 0000000000000000000000000000000000000000..8523acce9ddab1cf7e355628a08b1aab8ce08a72
--- /dev/null
+++ b/models/sam_vit_h_4b8939.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
+size 2564550879
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5510fc3924454682a00f8c19624699cc1b8bfc8a
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,19 @@
+dashscope
+Pillow==9.5.0
+gradio==3.50.2
+opencv-python
+omegaconf
+sentencepiece
+easydict
+scikit-image
+git+https://github.com/facebookresearch/segment-anything.git
+--extra-index-url https://download.pytorch.org/whl/cu118
+torch
+torchvision
+git+https://github.com/huggingface/diffusers.git
+transformers
+accelerate
+ftfy
+numpy
+matplotlib
+oss2==2.17.0
diff --git a/sdxl.txt b/sdxl.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e0faa98a67af475254c38577f67cd0441e7dd1a1
--- /dev/null
+++ b/sdxl.txt
@@ -0,0 +1,9 @@
+--extra-index-url https://download.pytorch.org/whl/cu118
+torch
+git+https://github.com/huggingface/diffusers.git
+transformers
+accelerate
+ftfy
+numpy
+matplotlib
+opencv-python
\ No newline at end of file
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/__pycache__/__init__.cpython-38.pyc b/src/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8dc73cc49c64b072034ceca2a70e452d9cddb19b
Binary files /dev/null and b/src/__pycache__/__init__.cpython-38.pyc differ
diff --git a/src/__pycache__/background_generation.cpython-38.pyc b/src/__pycache__/background_generation.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..00bb8b82e0c95b352bddf1364aa4338cd8efcc9c
Binary files /dev/null and b/src/__pycache__/background_generation.cpython-38.pyc differ
diff --git a/src/__pycache__/log.cpython-38.pyc b/src/__pycache__/log.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a0c1693fb0ca520a26e24e96bbb177b2d504f363
Binary files /dev/null and b/src/__pycache__/log.cpython-38.pyc differ
diff --git a/src/__pycache__/person_detect.cpython-38.pyc b/src/__pycache__/person_detect.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0ef32d9362deafc1ee2a477a23709461d7194970
Binary files /dev/null and b/src/__pycache__/person_detect.cpython-38.pyc differ
diff --git a/src/__pycache__/util.cpython-38.pyc b/src/__pycache__/util.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d2a5a9d5022bfa6ee781b4f08d172fba5b571c2b
Binary files /dev/null and b/src/__pycache__/util.cpython-38.pyc differ
diff --git a/src/__pycache__/virtualmodel.cpython-38.pyc b/src/__pycache__/virtualmodel.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ef33df1a408ac8fc709ee7ec445437b0b9dc81eb
Binary files /dev/null and b/src/__pycache__/virtualmodel.cpython-38.pyc differ
diff --git a/src/background_generation.py b/src/background_generation.py
new file mode 100644
index 0000000000000000000000000000000000000000..387202d1cd77a2304b81af1dd1296b219c029e81
--- /dev/null
+++ b/src/background_generation.py
@@ -0,0 +1,76 @@
+import os
+import numpy
+from PIL import Image
+import requests
+import urllib.request
+from http import HTTPStatus
+from datetime import datetime
+import json
+from .log import logger
+import time
+import gradio as gr
+from .util import download_images
+
+def call_bg_genration(base_image, ref_img, prompt,ref_prompt_weight=0.5):
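+    # create an asynchronous DashScope background-generation task, poll until it
+    # finishes, then download the generated images as PIL images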
+ API_KEY = os.getenv("API_KEY_BG_GENERATION")
+ BATCH_SIZE=4
+ headers = {
+ "Content-Type": "application/json",
+ "Accept": "application/json",
+ "Authorization": f"Bearer {API_KEY}",
+ "X-DashScope-Async": "enable",
+ }
+ data = {
+ "model": "wanx-background-generation-v2",
+ "input":{
+ "base_image_url": base_image,
+ 'ref_image_url':ref_img,
+ "ref_prompt": prompt,
+ },
+ "parameters": {
+ "ref_prompt_weight": ref_prompt_weight,
+ "n": BATCH_SIZE
+ }
+ }
+ url_create_task = 'https://dashscope.aliyuncs.com/api/v1/services/aigc/background-generation/generation'
+ res_ = requests.post(url_create_task, data=json.dumps(data), headers=headers)
+
+ respose_code = res_.status_code
+ if 200 == respose_code:
+ res = json.loads(res_.content.decode())
+ request_id = res['request_id']
+ task_id = res['output']['task_id']
+ logger.info(f"task_id: {task_id}: Create Background Generation request success. Params: {data}")
+
+ # 异步查询
+        # poll the async task until it completes
+ while is_running:
+ url_query = f'https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}'
+ res_ = requests.post(url_query, headers=headers)
+ respose_code = res_.status_code
+ if 200 == respose_code:
+ res = json.loads(res_.content.decode())
+ if "SUCCEEDED" == res['output']['task_status']:
+ logger.info(f"task_id: {task_id}: Background generation task query success.")
+ results = res['output']['results']
+ img_urls = [x['url'] for x in results]
+ logger.info(f"task_id: {task_id}: {res}")
+ break
+ elif "FAILED" != res['output']['task_status']:
+ logger.debug(f"task_id: {task_id}: query result...")
+ time.sleep(1)
+ else:
+ raise gr.Error('Fail to get results from Background Generation task.')
+
+ else:
+ logger.error(f'task_id: {task_id}: Fail to query task result: {res_.content}')
+ raise gr.Error("Fail to query task result.")
+
+ logger.info(f"task_id: {task_id}: download generated images.")
+ img_data = download_images(img_urls, BATCH_SIZE)
+ logger.info(f"task_id: {task_id}: Generate done.")
+ return img_data
+ else:
+ logger.error(f'Fail to create Background Generation task: {res_.content}')
+ raise gr.Error("Fail to create Background Generation task.")
+
diff --git a/src/log.py b/src/log.py
new file mode 100644
index 0000000000000000000000000000000000000000..a02a89fbb45fd99be98de8c5f0a17a51aa5588cf
--- /dev/null
+++ b/src/log.py
@@ -0,0 +1,18 @@
+import logging
+from logging.handlers import RotatingFileHandler
+import os
+
+log_file_name = "workdir/log_replaceAnything.log"
+os.makedirs(os.path.dirname(log_file_name), exist_ok=True)
+
+format = '[%(levelname)s] %(asctime)s "%(filename)s", line %(lineno)d, %(message)s'
+logging.basicConfig(
+ format=format,
+ datefmt="%Y-%m-%d %H:%M:%S",
+ level=logging.INFO)
+logger = logging.getLogger(name="WordArt_Studio")
+
+fh = RotatingFileHandler(log_file_name, maxBytes=20000000, backupCount=3)
+formatter = logging.Formatter(format, datefmt="%Y-%m-%d %H:%M:%S")
+fh.setFormatter(formatter)
+logger.addHandler(fh)
\ No newline at end of file
diff --git a/src/person_detect.py b/src/person_detect.py
new file mode 100644
index 0000000000000000000000000000000000000000..d88c7f1a0311568d6ccbf42dffe888df4295a40d
--- /dev/null
+++ b/src/person_detect.py
@@ -0,0 +1,39 @@
+import os
+import numpy
+from PIL import Image
+import requests
+import urllib.request
+from http import HTTPStatus
+from datetime import datetime
+import json
+from .log import logger
+import time
+import gradio as gr
+from .util import download_images
+
+API_KEY = os.getenv("API_KEY_VIRTUALMODEL")
+
+def call_person_detect(input_image_url):
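+    # synchronous DashScope body-detection request; returns the raw `results`
+    # list from the response (bounding boxes with confidence scores)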
+ headers = {
+ "Content-Type": "application/json",
+ "Accept": "application/json",
+ "Authorization": f"Bearer {API_KEY}",
+ "X-DashScope-DataInspection": "enable",
+ }
+ data = {
+ "model": "body-detection",
+ "input":{
+ "image_url": input_image_url,
+ },
+ "parameters": {
+ "score_threshold": 0.6,
+ }
+ }
+ url_create_task = 'https://dashscope.aliyuncs.com/api/v1/services/vision/bodydetection/detect'
+ res_ = requests.post(url_create_task, data=json.dumps(data), headers=headers)
+
+
+ res = json.loads(res_.content.decode())
+ request_id = res['request_id']
+ results = res['output']['results']
+ return results
\ No newline at end of file
diff --git a/src/util.py b/src/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..650cf92de7a2755a2961c0936ff7ffd83c2e51a9
--- /dev/null
+++ b/src/util.py
@@ -0,0 +1,177 @@
+import random
+
+import numpy as np
+import cv2
+import os
+import io
+import oss2
+from PIL import Image
+
+import dashscope
+from dashscope import MultiModalConversation
+
+from http import HTTPStatus
+import re
+import requests
+from .log import logger
+import concurrent.futures
+
+# dashscope.api_key = os.getenv("API_KEY_QW")
+# oss
+# access_key_id = os.getenv("ACCESS_KEY_ID")
+# access_key_secret = os.getenv("ACCESS_KEY_SECRET")
+# bucket_name = os.getenv("BUCKET_NAME")
+# endpoint = os.getenv("ENDPOINT")
+
+# bucket = oss2.Bucket(oss2.Auth(access_key_id, access_key_secret), endpoint, bucket_name)
+oss_path = "ashui"
+oss_path_img_gallery = "ashui_img_gallery"
+
+def download_img_pil(index, img_url):
+ # print(img_url)
+ r = requests.get(img_url, stream=True)
+ if r.status_code == 200:
+ img = Image.open(io.BytesIO(r.content))
+ return (index, img)
+ else:
+ logger.error(f"Fail to download: {img_url}")
+
+
+def download_images(img_urls, batch_size):
+ imgs_pil = [None] * batch_size
+ # worker_results = []
+ with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
+ to_do = []
+ for i, url in enumerate(img_urls):
+ future = executor.submit(download_img_pil, i, url)
+ to_do.append(future)
+
+ for future in concurrent.futures.as_completed(to_do):
+ ret = future.result()
+ # worker_results.append(ret)
+ index, img_pil = ret
+            imgs_pil[index] = img_pil  # keep results in URL order; downstream code relies on this ordering to fetch associated images or SVGs
+
+ return imgs_pil
+
+def upload_np_2_oss(input_image, name="cache.png", gallery=False):
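+    # OSS upload is stubbed out here (the bucket calls below are commented out),
+    # so the function keeps its signature but returns an empty URL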
+ imgByteArr = io.BytesIO()
+ Image.fromarray(input_image).save(imgByteArr, format="PNG")
+ imgByteArr = imgByteArr.getvalue()
+
+ if gallery:
+ path = oss_path_img_gallery
+ else:
+ path = oss_path
+
+    # bucket.put_object(path+"/"+name, imgByteArr)  # upload the image bytes to OSS
+    # ret = bucket.sign_url('GET', path+"/"+name, 60*60*24)  # signed URL: method, object path on OSS, expiry in seconds
+ del imgByteArr
+ return ""
+
+
+def call_with_messages(prompt):
+ messages = [
+ {'role': 'user', 'content': prompt}]
+ response = dashscope.Generation.call(
+ 'qwen-14b-chat',
+ messages=messages,
+ result_format='message', # set the result is message format.
+ )
+ if response.status_code == HTTPStatus.OK:
+ return response['output']["choices"][0]["message"]['content']
+ else:
+ print('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
+ response.request_id, response.status_code,
+ response.code, response.message
+ ))
+ return None
+
+def HWC3(x):
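+    # normalize a uint8 image to HxWx3 RGB: grayscale is stacked to three
+    # channels, and RGBA is alpha-composited over a white background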
+ assert x.dtype == np.uint8
+ if x.ndim == 2:
+ x = x[:, :, None]
+ assert x.ndim == 3
+ H, W, C = x.shape
+ assert C == 1 or C == 3 or C == 4
+ if C == 3:
+ return x
+ if C == 1:
+ return np.concatenate([x, x, x], axis=2)
+ if C == 4:
+ color = x[:, :, 0:3].astype(np.float32)
+ alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+ y = color * alpha + 255.0 * (1.0 - alpha)
+ y = y.clip(0, 255).astype(np.uint8)
+ return y
+
+
+def resize_image(input_image, resolution):
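+    # scale the shorter side to `resolution`, then snap both dimensions to
+    # multiples of 64 (convenient for Stable Diffusion UNet inputs)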
+ H, W, C = input_image.shape
+ H = float(H)
+ W = float(W)
+ k = float(resolution) / min(H, W)
+ H *= k
+ W *= k
+ H = int(np.round(H / 64.0)) * 64
+ W = int(np.round(W / 64.0)) * 64
+ img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
+ return img
+
+
+def nms(x, t, s):
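+    # thin a soft edge map: Gaussian-blur with sigma s, keep pixels that survive
+    # four directional dilations as local maxima, then binarize at threshold t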
+ x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
+
+ f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
+ f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
+ f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
+ f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
+
+ y = np.zeros_like(x)
+
+ for f in [f1, f2, f3, f4]:
+ np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
+
+ z = np.zeros_like(y, dtype=np.uint8)
+ z[y > t] = 255
+ return z
+
+
+def make_noise_disk(H, W, C, F):
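+    # low-frequency noise: sample a coarse (H//F, W//F) grid, upsample bicubically,
+    # crop to (H, W), and normalize to [0, 1]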
+ noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
+ noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
+ noise = noise[F: F + H, F: F + W]
+ noise -= np.min(noise)
+ noise /= np.max(noise)
+ if C == 1:
+ noise = noise[:, :, None]
+ return noise
+
+
+def min_max_norm(x):
+ x -= np.min(x)
+ x /= np.maximum(np.max(x), 1e-5)
+ return x
+
+
+def safe_step(x, step=2):
+ y = x.astype(np.float32) * float(step + 1)
+ y = y.astype(np.int32).astype(np.float32) / float(step)
+ return y
+
+
+def img2mask(img, H, W, low=10, high=90):
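+    # build a random binary mask from one channel of `img`: resize to (H, W),
+    # randomly invert, then threshold at a random percentile drawn from [low, high)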
+ assert img.ndim == 3 or img.ndim == 2
+ assert img.dtype == np.uint8
+
+ if img.ndim == 3:
+ y = img[:, :, random.randrange(0, img.shape[2])]
+ else:
+ y = img
+
+ y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC)
+
+ if random.uniform(0, 1) < 0.5:
+ y = 255 - y
+
+ return y < np.percentile(y, random.randrange(low, high))
diff --git a/src/virtualmodel.py b/src/virtualmodel.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9a111fda6e4e5fe8cd85b55dbec11a74648780f
--- /dev/null
+++ b/src/virtualmodel.py
@@ -0,0 +1,81 @@
+import os
+import numpy
+from PIL import Image
+import requests
+import urllib.request
+from http import HTTPStatus
+from datetime import datetime
+import json
+from .log import logger
+import time
+import gradio as gr
+from .util import download_images
+
+API_KEY = os.getenv("API_KEY_VIRTUALMODEL")
+
+def call_virtualmodel(input_image_url, input_mask_url, source_background_url, prompt, face_prompt):
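+    # same flow as call_bg_genration: create an asynchronous DashScope task,
+    # poll until completion, then download the generated images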
+ BATCH_SIZE=4
+ headers = {
+ "Content-Type": "application/json",
+ "Accept": "application/json",
+ "Authorization": f"Bearer {API_KEY}",
+ "X-DashScope-Async": "enable",
+ }
+ data = {
+ "model": "wanx-virtualmodel",
+ "input":{
+ "base_image_url": input_image_url,
+ "mask_image_url": input_mask_url,
+ "prompt": prompt,
+ "face_prompt": face_prompt,
+ "background_image_url": source_background_url,
+ },
+ "parameters": {
+ "short_side_size": "512",
+ "n": BATCH_SIZE
+ }
+ }
+ url_create_task = 'https://dashscope.aliyuncs.com/api/v1/services/aigc/virtualmodel/generation'
+ res_ = requests.post(url_create_task, data=json.dumps(data), headers=headers)
+
+ respose_code = res_.status_code
+ if 200 == respose_code:
+ res = json.loads(res_.content.decode())
+ request_id = res['request_id']
+ task_id = res['output']['task_id']
+ logger.info(f"task_id: {task_id}: Create VirtualModel request success. Params: {data}")
+
+        # poll the async task until it completes
+ is_running = True
+ while is_running:
+ url_query = f'https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}'
+ res_ = requests.post(url_query, headers=headers)
+ respose_code = res_.status_code
+ if 200 == respose_code:
+ res = json.loads(res_.content.decode())
+ if "SUCCEEDED" == res['output']['task_status']:
+ logger.info(f"task_id: {task_id}: VirtualModel generation task query success.")
+ results = res['output']['results']
+ img_urls = []
+ for x in results:
+ if "url" in x:
+ img_urls.append(x['url'])
+ logger.info(f"task_id: {task_id}: {res}")
+ break
+ elif "FAILED" != res['output']['task_status']:
+ logger.debug(f"task_id: {task_id}: query result...")
+ time.sleep(1)
+ else:
+ raise gr.Error('Fail to get results from VirtualModel task.')
+
+ else:
+ logger.error(f'task_id: {task_id}: Fail to query task result: {res_.content}')
+ raise gr.Error("Fail to query task result.")
+
+ logger.info(f"task_id: {task_id}: download generated images.")
+ img_data = download_images(img_urls, len(img_urls)) if len(img_urls) > 0 else []
+ logger.info(f"task_id: {task_id}: Generate done.")
+ return img_data
+ else:
+ logger.error(f'Fail to create VirtualModel task: {res_.content}')
+ raise gr.Error("Fail to create VirtualModel task.")
\ No newline at end of file