diff --git a/README.md b/README.md
index 5c0c69382c0d85e57152b7553610e0635d999ca5..f117aed0ee105bd3675f5891962582744b302b3e 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,13 @@
 ---
-title: SAM SDXL Inpainting
-emoji: 🦀
-colorFrom: red
-colorTo: gray
+title: ReplaceAnything Using SAM + SDXL Inpainting
+emoji: 📚
+colorFrom: yellow
+colorTo: blue
 sdk: gradio
-sdk_version: 4.14.0
+sdk_version: 3.50.2
 app_file: app.py
 pinned: false
+license: apache-2.0
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a04fad64b1dcbccce7371c3d257a2187274813c
--- /dev/null
+++ b/app.py
@@ -0,0 +1,311 @@
+#!/usr/bin/python3
+# -*- coding: utf-8 -*-
+# @Time   : 2023-06-01
+# @Author : ashui(Binghui Chen)
+import time
+import cv2
+import gradio as gr
+import numpy as np
+import random
+import math
+import uuid
+import torch
+from torch import autocast
+
+from src.util import resize_image, upload_np_2_oss
+from diffusers import AutoPipelineForInpainting, UNet2DConditionModel
+import diffusers
+import sys, os
+
+from PIL import Image, ImageFilter, ImageOps, ImageDraw
+
+from segment_anything import SamPredictor, sam_model_registry
+
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+pipe = AutoPipelineForInpainting.from_pretrained("diffusers/stable-diffusion-xl-1.0-inpainting-0.1", torch_dtype=torch.float16, variant="fp16").to(device)
+
+mobile_sam = sam_model_registry['vit_h'](checkpoint='models/sam_vit_h_4b8939.pth').to(device)
+mobile_sam.eval()
+mobile_predictor = SamPredictor(mobile_sam)
+colors = [(255, 0, 0), (0, 255, 0)]  # marker colors for background / foreground clicks
+markers = [1, 5]                     # cv2 marker types for background / foreground clicks
+
+# - - - - - examples - - - - - #
+# input image path, prompt, background image path, example index, selected points
+image_examples = [
+    ["imgs/000.jpg", "A young woman in short sleeves shows off a mobile phone", None, 0, []],
+    ["imgs/001.jpg", "A young woman wears short sleeves, her hand is holding a bottle.", None, 1, []],
+    ["imgs/003.png", "A woman is wearing a black suit against a blue background", "imgs/003_bg.jpg", 2, []],
+    ["imgs/002.png", "A young woman poses in a dress, she stands in front of a blue background", "imgs/002_bg.png", 3, []],
+    ["imgs/bg_gen/base_imgs/1cdb9b1e6daea6a1b85236595d3e43d6.png", "water splash", None, 4, []],
+    ["imgs/bg_gen/base_imgs/1cdb9b1e6daea6a1b85236595d3e43d6.png", "", "imgs/bg_gen/ref_imgs/df9a93ac2bca12696a9166182c4bf02ad9679aa5.jpg", 5, []],
+    ["imgs/bg_gen/base_imgs/IMG_2941.png", "On the desert floor", None, 6, []],
+    ["imgs/bg_gen/base_imgs/b2b1ed243364473e49d2e478e4f24413.png", "White ground, white background, light coming in, Canon", None, 7, []],
+]
+
+img = "image_gallery/"
+files = sorted(os.listdir(img))
+showcases = []
+for idx, name in enumerate(files):
+    temp = os.path.join(os.path.dirname(__file__), img, name)
+    showcases.append(temp)
+
+def process(original_image, original_mask, input_mask, selected_points, prompt, negative_prompt, guidance_scale, steps, strength, scheduler):
+    if original_image is None:
+        raise gr.Error('Please upload the input image')
+    if (original_mask is None or len(selected_points) == 0) and input_mask is None:
+        raise gr.Error("Please click the region you want to keep unchanged, or upload a black-and-white mask image in which white marks the region to be retained.")
+
+    # load example image
+    if isinstance(original_image, int):
+        image_name = image_examples[original_image][0]
+        original_image = cv2.imread(image_name)
+        original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2RGB)
+
+    # cap the longer side of portrait inputs at 1000 px
+    if original_image.shape[0] > original_image.shape[1]:
+        original_image = cv2.resize(original_image, (int(original_image.shape[1] * 1000 / original_image.shape[0]), 1000))
+    if original_mask is not None and original_mask.shape[0] > original_mask.shape[1]:
+        original_mask = cv2.resize(original_mask, (int(original_mask.shape[1] * 1000 / original_mask.shape[0]), 1000))
+
+    if input_mask is not None:
+        H, W = original_image.shape[:2]
+        original_mask = cv2.resize(input_mask, (W, H))
+    else:
+        # the clicked object is black in the SAM preview mask; invert it so that
+        # white marks the area the inpainting pipeline will repaint
+        original_mask = np.clip(255 - original_mask, 0, 255).astype(np.uint8)
+
+    request_id = str(uuid.uuid4())
+    # input_image_url = upload_np_2_oss(original_image, request_id+".png")
+    # input_mask_url = upload_np_2_oss(original_mask, request_id+"_mask.png")
+    # source_background_url = "" if source_background is None else upload_np_2_oss(source_background, request_id+"_bg.png")
+    if negative_prompt == "":
+        negative_prompt = None
+
+    # scheduler strings like "DPMSolverMultistepScheduler-Karras-SDE" select the
+    # diffusers class plus Karras sigmas and the SDE variant of the solver
+    scheduler_class_name = scheduler.split("-")[0]
+    add_kwargs = {}
+    if len(scheduler.split("-")) > 1:
+        add_kwargs["use_karras"] = True
+    if len(scheduler.split("-")) > 2:
+        add_kwargs["algorithm_type"] = "sde-dpmsolver++"
+    scheduler = getattr(diffusers, scheduler_class_name)
+    pipe.scheduler = scheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler", **add_kwargs)
+
+    init_image = Image.fromarray(original_image).convert("RGB")
+    mask = Image.fromarray(original_mask).convert("RGB")
+    output = pipe(prompt=prompt, negative_prompt=negative_prompt, image=init_image, mask_image=mask, guidance_scale=guidance_scale, num_inference_steps=int(steps), strength=strength)
+
+    # person detect: [[x1,y1,x2,y2,score],]
+    # det_res = call_person_detect(input_image_url)
+
+    res = []
+    # if len(det_res)>0:
+    #     if len(prompt)==0:
+    #         raise gr.Error('Please input the prompt')
+    #     res = call_virtualmodel(input_image_url, input_mask_url, source_background_url, prompt, face_prompt)
+    # else:
+    #     if len(prompt)==0:
+    #         prompt=None
+    #     ref_image_url = None if source_background_url == '' else source_background_url
+    #     original_mask = original_mask[:,:,:1]
+    #     base_image = np.concatenate([original_image, original_mask], axis=2)
+    #     base_image_url = upload_np_2_oss(base_image, request_id+"_base.png")
+    #     res = call_bg_genration(base_image_url, ref_image_url, prompt, ref_prompt_weight=0.5)
+    res = output.images[0].convert("RGB")
+    # resize the output image back to the input image size
+    res = res.resize((original_image.shape[1], original_image.shape[0]), Image.LANCZOS)
+    return [res], request_id, True
+
+block = gr.Blocks(
+        css="css/style.css",
+        theme=gr.themes.Soft(
+            radius_size=gr.themes.sizes.radius_none,
+            text_size=gr.themes.sizes.text_md
+        )
+    ).queue(concurrency_count=2)
+with block:
+    with gr.Row():
+        with gr.Column():
+            gr.HTML(f"""
+                <div style="text-align: center;">
+                    <h1 style="font-weight: 900; margin-bottom: 7px;">
+                        SAM + SDXL Inpainting
+                    </h1>
+                    <h3 style="font-weight: 450; font-size: 1rem; margin-top: 0;">
+                        ReplaceAnything using SAM + SDXL Inpainting as you want: Ultra-high quality content replacement
+                    </h3>
+                </div>
+ """) + + with gr.Tabs(elem_classes=["Tab"]): + with gr.TabItem("Image Create"): + with gr.Accordion(label="🧭 Instructions:", open=True, elem_id="accordion"): + with gr.Row(equal_height=True): + gr.Markdown(""" + - ⭐️ step1:Upload or select one image from Example + - ⭐️ step2:Click on Input-image to select the object to be retained (or upload a white-black Mask image, in which white color indicates the region you want to keep unchanged) + - ⭐️ step3:Input prompt or reference image (highly-recommended) for generating new contents + - ⭐️ step4:Click Run button + """) + with gr.Row(): + with gr.Column(): + with gr.Column(elem_id="Input"): + with gr.Row(): + with gr.Tabs(elem_classes=["feedback"]): + with gr.TabItem("Input Image"): + input_image = gr.Image(type="numpy", label="input",scale=2) + original_image = gr.State(value=None,label="index") + original_mask = gr.State(value=None) + selected_points = gr.State([],label="click points") + with gr.Row(elem_id="Seg"): + radio = gr.Radio(['foreground', 'background'], label='Click to seg: ', value='foreground',scale=2) + undo_button = gr.Button('Undo seg', elem_id="btnSEG",scale=1) + input_mask = gr.Image(type="numpy", label="Mask Image") + prompt = gr.Textbox(label="Prompt", placeholder="Please input your prompt",value='',lines=1) + negative_prompt = gr.Textbox(label="Negative Prompt", placeholder="Please input your prompt",value='hand,blur,face,bad',lines=1) + guidance_scale = gr.Number(value=7.5, minimum=1.0, maximum=20.0, step=0.1, label="guidance_scale") + steps = gr.Number(value=20, minimum=10, maximum=30, step=1, label="steps") + strength = gr.Number(value=0.99, minimum=0.01, maximum=1.0, step=0.01, label="strength") + with gr.Row(mobile_collapse=False, equal_height=True): + schedulers = ["DEISMultistepScheduler", "HeunDiscreteScheduler", "EulerDiscreteScheduler", "DPMSolverMultistepScheduler", "DPMSolverMultistepScheduler-Karras", "DPMSolverMultistepScheduler-Karras-SDE"] + scheduler = gr.Dropdown(label="Schedulers", choices=schedulers, value="EulerDiscreteScheduler") + + run_button = gr.Button("Run",elem_id="btn") + + with gr.Column(): + with gr.Tabs(elem_classes=["feedback"]): + with gr.TabItem("Outputs"): + result_gallery = gr.Gallery(label='Output', show_label=False, elem_id="gallery", preview=True) + # recommend=gr.Button("Recommend results to Image Gallery",elem_id="recBut") + request_id=gr.State(value="") + gallery_flag=gr.State(value=False) + + # once user upload an image, the original image is stored in `original_image` + def store_img(img): + # image upload is too slow + # if min(img.shape[0], img.shape[1]) > 896: + # img = resize_image(img, 896) + # if max(img.shape[0], img.shape[1])*1.0/min(img.shape[0], img.shape[1])>2.0: + # raise gr.Error('image aspect ratio cannot be larger than 2.0') + return img, img, [], None # when new image is uploaded, `selected_points` should be empty + + input_image.upload( + store_img, + [input_image], + [input_image, original_image, selected_points] + ) + + # user click the image to get points, and show the points on the image + def segmentation(img, sel_pix): + print("segmentation") + # online show seg mask + points = [] + labels = [] + for p, l in sel_pix: + points.append(p) + labels.append(l) + mobile_predictor.set_image(img if isinstance(img, np.ndarray) else np.array(img)) + with torch.no_grad(): + with autocast("cuda"): + masks, _, _ = mobile_predictor.predict(point_coords=np.array(points), point_labels=np.array(labels), multimask_output=False) + + output_mask = np.ones((masks.shape[1], 
masks.shape[2], 3))*255 + for i in range(3): + output_mask[masks[0] == True, i] = 0.0 + + mask_all = np.ones((masks.shape[1], masks.shape[2], 3)) + color_mask = np.random.random((1, 3)).tolist()[0] + for i in range(3): + mask_all[masks[0] == True, i] = color_mask[i] + masked_img = img / 255 * 0.3 + mask_all * 0.7 + masked_img = masked_img*255 + ## draw points + for point, label in sel_pix: + cv2.drawMarker(masked_img, point, colors[label], markerType=markers[label], markerSize=20, thickness=5) + return masked_img, output_mask + + def get_point(img, sel_pix, point_type, evt: gr.SelectData): + + if point_type == 'foreground': + sel_pix.append((evt.index, 1)) # append the foreground_point + elif point_type == 'background': + sel_pix.append((evt.index, 0)) # append the background_point + else: + sel_pix.append((evt.index, 1)) # default foreground_point + + if isinstance(img, int): + image_name = image_examples[img][0] + img = cv2.imread(image_name) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + + # online show seg mask + if img.shape[0]>img.shape[1]: + img=cv2.resize(img,(int(img.shape[1]*1000/img.shape[0]),1000)) + masked_img, output_mask = segmentation(img, sel_pix) + + return masked_img.astype(np.uint8), output_mask + + input_image.select( + get_point, + [original_image, selected_points, radio], + [input_image, original_mask], + ) + + # undo the selected point + def undo_points(orig_img, sel_pix): + # draw points + output_mask = None + if len(sel_pix) != 0: + if isinstance(orig_img, int): # if orig_img is int, the image if select from examples + temp = cv2.imread(image_examples[orig_img][0]) + temp = cv2.cvtColor(temp, cv2.COLOR_BGR2RGB) + else: + temp = orig_img.copy() + sel_pix.pop() + # online show seg mask + if len(sel_pix) !=0: + temp, output_mask = segmentation(temp, sel_pix) + return temp.astype(np.uint8), output_mask + else: + gr.Error("Nothing to Undo") + + undo_button.click( + undo_points, + [original_image, selected_points], + [input_image, original_mask] + ) + + def upload_to_img_gallery(img, res, re_id, flag): + if flag: + gr.Info("Image uploading") + if isinstance(img, int): + image_name = image_examples[img][0] + img = cv2.imread(image_name) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + _ = upload_np_2_oss(img, name=re_id+"_ori.jpg", gallery=True) + for idx, r in enumerate(res): + r = cv2.imread(r['name']) + r = cv2.cvtColor(r, cv2.COLOR_BGR2RGB) + _ = upload_np_2_oss(r, name=re_id+f"_res_{idx}.jpg", gallery=True) + flag=False + gr.Info("Images have beend uploaded and are under check") + else: + gr.Info("Nothing to to") + return flag + + # recommend.click( + # upload_to_img_gallery, + # [original_image, result_gallery, request_id, gallery_flag], + # [gallery_flag] + # ) + # ips=[input_image, original_image, original_mask, input_mask, selected_points, prompt,negative_prompt,guidance_scale,steps,strength,scheduler] + ips=[original_image, original_mask, input_mask, selected_points, prompt,negative_prompt,guidance_scale,steps,strength,scheduler] + run_button.click(fn=process, inputs=ips, outputs=[result_gallery, request_id, gallery_flag]) + + +block.launch(share=True) diff --git a/css/0.png b/css/0.png new file mode 100644 index 0000000000000000000000000000000000000000..1f71aacffbb3adda52c1e86b0b3064533e6fb65b Binary files /dev/null and b/css/0.png differ diff --git a/css/style.css b/css/style.css new file mode 100644 index 0000000000000000000000000000000000000000..e21f344b481a57cd5e50a87e6b111caaea9b7222 --- /dev/null +++ b/css/style.css @@ -0,0 +1,59 @@ + +.baselayout{ + 
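+    /* header banner background (a 2882x256 strip, judging by the image name) */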
+    background: url('https://img.alicdn.com/imgextra/i1/O1CN016hd0V91ilWY5Xr24B_!!6000000004453-2-tps-2882-256.png') no-repeat;
+}
+#btn {
+    background-color: #336699;
+    color: white;
+}
+#recBut {
+    background-color: #bb5252;
+    color: white;
+    width: 30%;
+    margin: auto;
+}
+#btnSEG {
+    background-color: #D5F3F4;
+    color: black;
+}
+#btnCHAT {
+    background-color: #B6DBF2;
+    color: black;
+}
+#accordion {
+    background-color: transparent;
+}
+#accordion1 {
+    background-color: #ecedee;
+}
+.feedback button.selected{
+    background-color: #6699CC;
+    color: white !important;
+}
+.feedback1 button.selected{
+    background-color: #839ab2;
+    color: white !important;
+}
+.Tab button.selected{
+    color: red;
+    font-weight: bold;
+}
+#Image {
+    width: 80%;
+    margin: auto;
+}
+#ShowCase {
+    width: 30%;
+    flex: none !important;
+}
+
+#Input {
+    border-style: solid;
+    border-width: 1px;
+    border-color: #000000;
+}
+#Seg {
+    min-width: min(100px, 100%) !important;
+    width: 100%;
+    margin: auto;
+}
diff --git a/image_gallery/00.png b/image_gallery/00.png
new file mode 100644
index 0000000000000000000000000000000000000000..5085f085aebca32adca69602cc44c5616652bdc8
Binary files /dev/null and b/image_gallery/00.png differ
diff --git a/image_gallery/01.png b/image_gallery/01.png
new file mode 100644
index 0000000000000000000000000000000000000000..3bc5574ca6f5a700a20b23fe81b59298fc289629
Binary files /dev/null and b/image_gallery/01.png differ
diff --git a/image_gallery/02.png b/image_gallery/02.png
new file mode 100644
index 0000000000000000000000000000000000000000..b5579a57d6d25c9897ff04c2bdc887b36127727e
Binary files /dev/null and b/image_gallery/02.png differ
diff --git a/image_gallery/a001.jpg b/image_gallery/a001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..0958da232bc2adc083183cb8e08aab910cc5eac3
Binary files /dev/null and b/image_gallery/a001.jpg differ
diff --git a/image_gallery/a002.jpg b/image_gallery/a002.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..95abdbac66d0bcf2ca6739c0fce703838311b313
Binary files /dev/null and b/image_gallery/a002.jpg differ
diff --git a/image_gallery/a003.jpg b/image_gallery/a003.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..cc4b86d937f0e8d735d6f5a8494bdc3063b518c5
Binary files /dev/null and b/image_gallery/a003.jpg differ
diff --git a/image_gallery/a004.jpg b/image_gallery/a004.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..06697167bd9f1338347007951aa78612871f71c1
Binary files /dev/null and b/image_gallery/a004.jpg differ
diff --git a/image_gallery/a005.jpg b/image_gallery/a005.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c267cebb616865728e6eed74a4ff723c34a82c11
Binary files /dev/null and b/image_gallery/a005.jpg differ
diff --git a/image_gallery/a006.jpg b/image_gallery/a006.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..21a3762d1f1b60c8dabed2632c1dfafada56ad68
Binary files /dev/null and b/image_gallery/a006.jpg differ
diff --git a/image_gallery/a007.jpg b/image_gallery/a007.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ba3204af588e6b1be7dce38d5bfff816ccc528dc
Binary files /dev/null and b/image_gallery/a007.jpg differ
diff --git a/image_gallery/a009.jpg b/image_gallery/a009.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8b973739cb5677c783e8a572fe79a0dd49b88f7d
Binary files /dev/null and b/image_gallery/a009.jpg differ
diff --git a/image_gallery/bg_001.jpg b/image_gallery/bg_001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c05697e0c5b0649fb70ee44066c8d31123adb7ae
Binary files /dev/null and b/image_gallery/bg_001.jpg differ
diff --git a/image_gallery/bg_002.jpg b/image_gallery/bg_002.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e3bc03054ca5f1f1afa48c277d5f86640839da77
Binary files /dev/null and b/image_gallery/bg_002.jpg differ
diff --git a/image_gallery/bg_003.jpg b/image_gallery/bg_003.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..51f375bb74521625bea9d1b44de2fcfbc53520f1
Binary files /dev/null and b/image_gallery/bg_003.jpg differ
diff --git a/image_gallery/bg_004.jpg b/image_gallery/bg_004.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..8d72dbd040c0f22ca7dca5501906f3ae83d2d82f
Binary files /dev/null and b/image_gallery/bg_004.jpg differ
diff --git a/image_gallery/bg_005.jpg b/image_gallery/bg_005.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d86a9c922e484c608d8287e77478689a3e74c7d6
Binary files /dev/null and b/image_gallery/bg_005.jpg differ
diff --git a/image_gallery/bg_006.jpg b/image_gallery/bg_006.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..9694d3ae04de57ac65fb05b9af59982923012194
Binary files /dev/null and b/image_gallery/bg_006.jpg differ
diff --git a/image_gallery/bg_007.jpg b/image_gallery/bg_007.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..d5e8c0a39420bb04c1d32360eabd29d0d8298468
Binary files /dev/null and b/image_gallery/bg_007.jpg differ
diff --git a/image_gallery/bg_008.jpg b/image_gallery/bg_008.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..b335a97afd7edff664219a3debc540180a9752f0
Binary files /dev/null and b/image_gallery/bg_008.jpg differ
diff --git a/image_gallery/bg_009.jpg b/image_gallery/bg_009.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..924dbac8340420e35fe428de31069179dfb816d3
Binary files /dev/null and b/image_gallery/bg_009.jpg differ
diff --git a/image_gallery/bg_010.jpg b/image_gallery/bg_010.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..1b0035c9f14004a3efb73289c41b787e330707ce
Binary files /dev/null and b/image_gallery/bg_010.jpg differ
diff --git a/image_gallery/bg_012.jpg b/image_gallery/bg_012.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..77ceb469a7991e6fcbcdaaf96b7600735558c1c3
Binary files /dev/null and b/image_gallery/bg_012.jpg differ
diff --git a/imgs/000.jpg b/imgs/000.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..47d05a3266c2d87452d0b589bab3082a611d9785
Binary files /dev/null and b/imgs/000.jpg differ
diff --git a/imgs/001.jpg b/imgs/001.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..c32bfa118314574f6c2c90acd3ffd00e0708f0c5
Binary files /dev/null and b/imgs/001.jpg differ
diff --git a/imgs/002.png b/imgs/002.png
new file mode 100644
index 0000000000000000000000000000000000000000..3c35d0b5230c1875a1ff2720dcd5a44cd7ca3da1
Binary files /dev/null and b/imgs/002.png differ
diff --git a/imgs/002_bg.png b/imgs/002_bg.png
new file mode 100644
index 0000000000000000000000000000000000000000..a15feedee1e36555163c729f1d653ff08dd510ea
Binary files /dev/null and b/imgs/002_bg.png differ
diff --git a/imgs/003.png b/imgs/003.png
new file mode 100644
index 0000000000000000000000000000000000000000..9903793c03fb477d5432df658c1342515a706d8a
Binary files /dev/null and b/imgs/003.png differ
diff --git a/imgs/003_bg.jpg b/imgs/003_bg.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..a6443f1076b0fc4631f30f1b29220dc713247f47
Binary files /dev/null and b/imgs/003_bg.jpg differ
diff --git a/imgs/bg_gen/base_imgs/1cdb9b1e6daea6a1b85236595d3e43d6.png b/imgs/bg_gen/base_imgs/1cdb9b1e6daea6a1b85236595d3e43d6.png
new file mode 100644
index 0000000000000000000000000000000000000000..5ea82ebddcae135dccedcef526a73cd2f24c0e47
Binary files /dev/null and b/imgs/bg_gen/base_imgs/1cdb9b1e6daea6a1b85236595d3e43d6.png differ
diff --git a/imgs/bg_gen/base_imgs/IMG_2941.png b/imgs/bg_gen/base_imgs/IMG_2941.png
new file mode 100644
index 0000000000000000000000000000000000000000..37f84d452ae2f175b5e4f9412bd057bfdfa9e8d7
Binary files /dev/null and b/imgs/bg_gen/base_imgs/IMG_2941.png differ
diff --git a/imgs/bg_gen/base_imgs/b2b1ed243364473e49d2e478e4f24413.png b/imgs/bg_gen/base_imgs/b2b1ed243364473e49d2e478e4f24413.png
new file mode 100644
index 0000000000000000000000000000000000000000..db679530a008829ae5d412ebd313bddb7858599b
Binary files /dev/null and b/imgs/bg_gen/base_imgs/b2b1ed243364473e49d2e478e4f24413.png differ
diff --git a/imgs/bg_gen/ref_imgs/df9a93ac2bca12696a9166182c4bf02ad9679aa5.jpg b/imgs/bg_gen/ref_imgs/df9a93ac2bca12696a9166182c4bf02ad9679aa5.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..87389876dd825fd733ebb1cb822707da23cf2388
Binary files /dev/null and b/imgs/bg_gen/ref_imgs/df9a93ac2bca12696a9166182c4bf02ad9679aa5.jpg differ
diff --git a/models/DOWNLOAD_MODEL_HERE.txt b/models/DOWNLOAD_MODEL_HERE.txt
new file mode 100644
index 0000000000000000000000000000000000000000..af74612ff809a40bf66785523e513702961ccb7b
--- /dev/null
+++ b/models/DOWNLOAD_MODEL_HERE.txt
@@ -0,0 +1,2 @@
+Model link:
+https://vision-poster.oss-cn-shanghai.aliyuncs.com/ashui/sam_vit_h_4b8939.pth?OSSAccessKeyId=LTAI5tSPYbksBzcmooNHCYif&Expires=3599001703148669&Signature=TYznO77DKFjGNn92SnR9RbucOlU%3D
\ No newline at end of file
diff --git a/models/sam_vit_h_4b8939.pth b/models/sam_vit_h_4b8939.pth
new file mode 100644
index 0000000000000000000000000000000000000000..8523acce9ddab1cf7e355628a08b1aab8ce08a72
--- /dev/null
+++ b/models/sam_vit_h_4b8939.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a7bf3b02f3ebf1267aba913ff637d9a2d5c33d3173bb679e46d9f338c26f262e
+size 2564550879
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5510fc3924454682a00f8c19624699cc1b8bfc8a
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,19 @@
+dashscope
+Pillow==9.5.0
+gradio==3.50.0
+opencv-python
+omegaconf
+sentencepiece
+easydict
+scikit-image
+git+https://github.com/facebookresearch/segment-anything.git
+--extra-index-url https://download.pytorch.org/whl/cu118
+torch
+torchvision
+oss2==2.17.0
+git+https://github.com/huggingface/diffusers.git
+transformers
+accelerate
+ftfy
+numpy
+matplotlib
diff --git a/sdxl.txt b/sdxl.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e0faa98a67af475254c38577f67cd0441e7dd1a1
--- /dev/null
+++ b/sdxl.txt
@@ -0,0 +1,10 @@
+--extra-index-url https://download.pytorch.org/whl/cu118
+torch
+git+https://github.com/huggingface/diffusers.git
+transformers
+accelerate
+ftfy
+numpy
+matplotlib
+uuid
+opencv-python
\ No newline at end of file
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/__pycache__/__init__.cpython-38.pyc b/src/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8dc73cc49c64b072034ceca2a70e452d9cddb19b
Binary files /dev/null and b/src/__pycache__/__init__.cpython-38.pyc differ
diff --git a/src/__pycache__/background_generation.cpython-38.pyc b/src/__pycache__/background_generation.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..00bb8b82e0c95b352bddf1364aa4338cd8efcc9c
Binary files /dev/null and b/src/__pycache__/background_generation.cpython-38.pyc differ
diff --git a/src/__pycache__/log.cpython-38.pyc b/src/__pycache__/log.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a0c1693fb0ca520a26e24e96bbb177b2d504f363
Binary files /dev/null and b/src/__pycache__/log.cpython-38.pyc differ
diff --git a/src/__pycache__/person_detect.cpython-38.pyc b/src/__pycache__/person_detect.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..0ef32d9362deafc1ee2a477a23709461d7194970
Binary files /dev/null and b/src/__pycache__/person_detect.cpython-38.pyc differ
diff --git a/src/__pycache__/util.cpython-38.pyc b/src/__pycache__/util.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d2a5a9d5022bfa6ee781b4f08d172fba5b571c2b
Binary files /dev/null and b/src/__pycache__/util.cpython-38.pyc differ
diff --git a/src/__pycache__/virtualmodel.cpython-38.pyc b/src/__pycache__/virtualmodel.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ef33df1a408ac8fc709ee7ec445437b0b9dc81eb
Binary files /dev/null and b/src/__pycache__/virtualmodel.cpython-38.pyc differ
diff --git a/src/background_generation.py b/src/background_generation.py
new file mode 100644
index 0000000000000000000000000000000000000000..387202d1cd77a2304b81af1dd1296b219c029e81
--- /dev/null
+++ b/src/background_generation.py
@@ -0,0 +1,76 @@
+import os
+import numpy
+from PIL import Image
+import requests
+import urllib.request
+from http import HTTPStatus
+from datetime import datetime
+import json
+from .log import logger
+import time
+import gradio as gr
+from .util import download_images
+
+def call_bg_genration(base_image, ref_img, prompt, ref_prompt_weight=0.5):
+    API_KEY = os.getenv("API_KEY_BG_GENERATION")
+    BATCH_SIZE = 4
+    headers = {
+        "Content-Type": "application/json",
+        "Accept": "application/json",
+        "Authorization": f"Bearer {API_KEY}",
+        "X-DashScope-Async": "enable",
+    }
+    data = {
+        "model": "wanx-background-generation-v2",
+        "input": {
+            "base_image_url": base_image,
+            "ref_image_url": ref_img,
+            "ref_prompt": prompt,
+        },
+        "parameters": {
+            "ref_prompt_weight": ref_prompt_weight,
+            "n": BATCH_SIZE
+        }
+    }
+    url_create_task = 'https://dashscope.aliyuncs.com/api/v1/services/aigc/background-generation/generation'
+    res_ = requests.post(url_create_task, data=json.dumps(data), headers=headers)
+
+    response_code = res_.status_code
+    if 200 == response_code:
+        res = json.loads(res_.content.decode())
+        request_id = res['request_id']
+        task_id = res['output']['task_id']
+        logger.info(f"task_id: {task_id}: Create Background Generation request success. Params: {data}")
+
+        # poll the async task until it finishes
+        is_running = True
+        while is_running:
+            url_query = f'https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}'
+            res_ = requests.post(url_query, headers=headers)
+            response_code = res_.status_code
+            if 200 == response_code:
+                res = json.loads(res_.content.decode())
+                if "SUCCEEDED" == res['output']['task_status']:
+                    logger.info(f"task_id: {task_id}: Background generation task query success.")
+                    results = res['output']['results']
+                    img_urls = [x['url'] for x in results]
+                    logger.info(f"task_id: {task_id}: {res}")
+                    break
+                elif "FAILED" != res['output']['task_status']:
+                    logger.debug(f"task_id: {task_id}: query result...")
+                    time.sleep(1)
+                else:
+                    raise gr.Error('Failed to get results from the Background Generation task.')
+            else:
+                logger.error(f'task_id: {task_id}: Failed to query task result: {res_.content}')
+                raise gr.Error("Failed to query task result.")
+
+        logger.info(f"task_id: {task_id}: download generated images.")
+        img_data = download_images(img_urls, BATCH_SIZE)
+        logger.info(f"task_id: {task_id}: Generate done.")
+        return img_data
+    else:
+        logger.error(f'Failed to create Background Generation task: {res_.content}')
+        raise gr.Error("Failed to create Background Generation task.")
+
diff --git a/src/log.py b/src/log.py
new file mode 100644
index 0000000000000000000000000000000000000000..a02a89fbb45fd99be98de8c5f0a17a51aa5588cf
--- /dev/null
+++ b/src/log.py
@@ -0,0 +1,18 @@
+import logging
+from logging.handlers import RotatingFileHandler
+import os
+
+log_file_name = "workdir/log_replaceAnything.log"
+os.makedirs(os.path.dirname(log_file_name), exist_ok=True)
+
+format = '[%(levelname)s] %(asctime)s "%(filename)s", line %(lineno)d, %(message)s'
+logging.basicConfig(
+    format=format,
+    datefmt="%Y-%m-%d %H:%M:%S",
+    level=logging.INFO)
+logger = logging.getLogger(name="WordArt_Studio")
+
+fh = RotatingFileHandler(log_file_name, maxBytes=20000000, backupCount=3)
+formatter = logging.Formatter(format, datefmt="%Y-%m-%d %H:%M:%S")
+fh.setFormatter(formatter)
+logger.addHandler(fh)
\ No newline at end of file
diff --git a/src/person_detect.py b/src/person_detect.py
new file mode 100644
index 0000000000000000000000000000000000000000..d88c7f1a0311568d6ccbf42dffe888df4295a40d
--- /dev/null
+++ b/src/person_detect.py
@@ -0,0 +1,39 @@
+import os
+import numpy
+from PIL import Image
+import requests
+import urllib.request
+from http import HTTPStatus
+from datetime import datetime
+import json
+from .log import logger
+import time
+import gradio as gr
+from .util import download_images
+
+API_KEY = os.getenv("API_KEY_VIRTUALMODEL")
+
+def call_person_detect(input_image_url):
+    headers = {
+        "Content-Type": "application/json",
+        "Accept": "application/json",
+        "Authorization": f"Bearer {API_KEY}",
+        "X-DashScope-DataInspection": "enable",
+    }
+    data = {
+        "model": "body-detection",
+        "input": {
+            "image_url": input_image_url,
+        },
+        "parameters": {
+            "score_threshold": 0.6,
+        }
+    }
+    url_create_task = 'https://dashscope.aliyuncs.com/api/v1/services/vision/bodydetection/detect'
+    res_ = requests.post(url_create_task, data=json.dumps(data), headers=headers)
+
+    res = json.loads(res_.content.decode())
+    request_id = res['request_id']
+    results = res['output']['results']
+    return results
\ No newline at end of file
diff --git a/src/util.py b/src/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..650cf92de7a2755a2961c0936ff7ffd83c2e51a9
--- /dev/null
+++ b/src/util.py
@@ -0,0 +1,177 @@
+import random
+
+import numpy as np
+import cv2
+import os
+import io
+import oss2
+from PIL import Image
+
+import dashscope
+from dashscope import MultiModalConversation
+
+from http import HTTPStatus
+import re
+import requests
+from .log import logger
+import concurrent.futures
+
+# dashscope.api_key = os.getenv("API_KEY_QW")
+# OSS credentials
+# access_key_id = os.getenv("ACCESS_KEY_ID")
+# access_key_secret = os.getenv("ACCESS_KEY_SECRET")
+# bucket_name = os.getenv("BUCKET_NAME")
+# endpoint = os.getenv("ENDPOINT")
+
+# bucket = oss2.Bucket(oss2.Auth(access_key_id, access_key_secret), endpoint, bucket_name)
+oss_path = "ashui"
+oss_path_img_gallery = "ashui_img_gallery"
+
+def download_img_pil(index, img_url):
+    # print(img_url)
+    r = requests.get(img_url, stream=True)
+    if r.status_code == 200:
+        img = Image.open(io.BytesIO(r.content))
+        return (index, img)
+    else:
+        logger.error(f"Failed to download: {img_url}")
+
+
+def download_images(img_urls, batch_size):
+    imgs_pil = [None] * batch_size
+    with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
+        to_do = []
+        for i, url in enumerate(img_urls):
+            future = executor.submit(download_img_pil, i, url)
+            to_do.append(future)
+
+        for future in concurrent.futures.as_completed(to_do):
+            index, img_pil = future.result()
+            imgs_pil[index] = img_pil  # keep results in URL order; later downloads of related images or SVGs rely on it
+
+    return imgs_pil
+
+def upload_np_2_oss(input_image, name="cache.png", gallery=False):
+    imgByteArr = io.BytesIO()
+    Image.fromarray(input_image).save(imgByteArr, format="PNG")
+    imgByteArr = imgByteArr.getvalue()
+
+    if gallery:
+        path = oss_path_img_gallery
+    else:
+        path = oss_path
+
+    # bucket.put_object(path+"/"+name, imgByteArr)  # the payload can be raw bytes, e.g. an image
+    # ret = bucket.sign_url('GET', path+"/"+name, 60*60*24)  # returns a signed URL; args: method / object path on OSS / expiry (s)
+    del imgByteArr
+    return ""
+
+
+def call_with_messages(prompt):
+    messages = [
+        {'role': 'user', 'content': prompt}]
+    response = dashscope.Generation.call(
+        'qwen-14b-chat',
+        messages=messages,
+        result_format='message',  # return the result in message format
+    )
+    if response.status_code == HTTPStatus.OK:
+        return response['output']["choices"][0]["message"]['content']
+    else:
+        print('Request id: %s, Status code: %s, error code: %s, error message: %s' % (
+            response.request_id, response.status_code,
+            response.code, response.message
+        ))
+        return None
+
+def HWC3(x):
+    assert x.dtype == np.uint8
+    if x.ndim == 2:
+        x = x[:, :, None]
+    assert x.ndim == 3
+    H, W, C = x.shape
+    assert C == 1 or C == 3 or C == 4
+    if C == 3:
+        return x
+    if C == 1:
+        return np.concatenate([x, x, x], axis=2)
+    if C == 4:
+        color = x[:, :, 0:3].astype(np.float32)
+        alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+        y = color * alpha + 255.0 * (1.0 - alpha)
+        y = y.clip(0, 255).astype(np.uint8)
+        return y
+
+
+def resize_image(input_image, resolution):
+    H, W, C = input_image.shape
+    H = float(H)
+    W = float(W)
+    k = float(resolution) / min(H, W)
+    H *= k
+    W *= k
+    H = int(np.round(H / 64.0)) * 64
+    W = int(np.round(W / 64.0)) * 64
+    img = cv2.resize(input_image, (W, H), interpolation=cv2.INTER_LANCZOS4 if k > 1 else cv2.INTER_AREA)
+    return img
+
+
+def nms(x, t, s):
+    x = cv2.GaussianBlur(x.astype(np.float32), (0, 0), s)
+
+    f1 = np.array([[0, 0, 0], [1, 1, 1], [0, 0, 0]], dtype=np.uint8)
+    f2 = np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0]], dtype=np.uint8)
+    f3 = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.uint8)
+    f4 = np.array([[0, 0, 1], [0, 1, 0], [1, 0, 0]], dtype=np.uint8)
+
+    y = np.zeros_like(x)
+
+    for f in [f1, f2, f3, f4]:
+        np.putmask(y, cv2.dilate(x, kernel=f) == x, x)
+
+    z = np.zeros_like(y, dtype=np.uint8)
+    z[y > t] = 255
+    return z
+
+
+def make_noise_disk(H, W, C, F):
+    noise = np.random.uniform(low=0, high=1, size=((H // F) + 2, (W // F) + 2, C))
+    noise = cv2.resize(noise, (W + 2 * F, H + 2 * F), interpolation=cv2.INTER_CUBIC)
+    noise = noise[F: F + H, F: F + W]
+    noise -= np.min(noise)
+    noise /= np.max(noise)
+    if C == 1:
+        noise = noise[:, :, None]
+    return noise
+
+
+def min_max_norm(x):
+    x -= np.min(x)
+    x /= np.maximum(np.max(x), 1e-5)
+    return x
+
+
+def safe_step(x, step=2):
+    y = x.astype(np.float32) * float(step + 1)
+    y = y.astype(np.int32).astype(np.float32) / float(step)
+    return y
+
+
+def img2mask(img, H, W, low=10, high=90):
+    assert img.ndim == 3 or img.ndim == 2
+    assert img.dtype == np.uint8
+
+    if img.ndim == 3:
+        y = img[:, :, random.randrange(0, img.shape[2])]
+    else:
+        y = img
+
+    y = cv2.resize(y, (W, H), interpolation=cv2.INTER_CUBIC)
+
+    if random.uniform(0, 1) < 0.5:
+        y = 255 - y
+
+    return y < np.percentile(y, random.randrange(low, high))
diff --git a/src/virtualmodel.py b/src/virtualmodel.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9a111fda6e4e5fe8cd85b55dbec11a74648780f
--- /dev/null
+++ b/src/virtualmodel.py
@@ -0,0 +1,81 @@
+import os
+import numpy
+from PIL import Image
+import requests
+import urllib.request
+from http import HTTPStatus
+from datetime import datetime
+import json
+from .log import logger
+import time
+import gradio as gr
+from .util import download_images
+
+API_KEY = os.getenv("API_KEY_VIRTUALMODEL")
+
+def call_virtualmodel(input_image_url, input_mask_url, source_background_url, prompt, face_prompt):
+    BATCH_SIZE = 4
+    headers = {
+        "Content-Type": "application/json",
+        "Accept": "application/json",
+        "Authorization": f"Bearer {API_KEY}",
+        "X-DashScope-Async": "enable",
+    }
+    data = {
+        "model": "wanx-virtualmodel",
+        "input": {
+            "base_image_url": input_image_url,
+            "mask_image_url": input_mask_url,
+            "prompt": prompt,
+            "face_prompt": face_prompt,
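+            # optional reference background; app.py passes "" when no background image is selected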
"background_image_url": source_background_url, + }, + "parameters": { + "short_side_size": "512", + "n": BATCH_SIZE + } + } + url_create_task = 'https://dashscope.aliyuncs.com/api/v1/services/aigc/virtualmodel/generation' + res_ = requests.post(url_create_task, data=json.dumps(data), headers=headers) + + respose_code = res_.status_code + if 200 == respose_code: + res = json.loads(res_.content.decode()) + request_id = res['request_id'] + task_id = res['output']['task_id'] + logger.info(f"task_id: {task_id}: Create VirtualModel request success. Params: {data}") + + # 异步查询 + is_running = True + while is_running: + url_query = f'https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}' + res_ = requests.post(url_query, headers=headers) + respose_code = res_.status_code + if 200 == respose_code: + res = json.loads(res_.content.decode()) + if "SUCCEEDED" == res['output']['task_status']: + logger.info(f"task_id: {task_id}: VirtualModel generation task query success.") + results = res['output']['results'] + img_urls = [] + for x in results: + if "url" in x: + img_urls.append(x['url']) + logger.info(f"task_id: {task_id}: {res}") + break + elif "FAILED" != res['output']['task_status']: + logger.debug(f"task_id: {task_id}: query result...") + time.sleep(1) + else: + raise gr.Error('Fail to get results from VirtualModel task.') + + else: + logger.error(f'task_id: {task_id}: Fail to query task result: {res_.content}') + raise gr.Error("Fail to query task result.") + + logger.info(f"task_id: {task_id}: download generated images.") + img_data = download_images(img_urls, len(img_urls)) if len(img_urls) > 0 else [] + logger.info(f"task_id: {task_id}: Generate done.") + return img_data + else: + logger.error(f'Fail to create VirtualModel task: {res_.content}') + raise gr.Error("Fail to create VirtualModel task.") \ No newline at end of file