import os

import numpy as np
import torch
from PIL import Image
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
from diffusers import StableDiffusionInpaintPipeline
import gradio as gr

# Hugging Face token for the gated inpainting checkpoint; falls back to the
# locally cached credentials when the env var is not set.
auth_token = os.environ.get("API_TOKEN") or True

# CLIPSeg: zero-shot segmentation from text prompts.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")

# Stable Diffusion inpainting pipeline (fp16 weights).
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",
    revision="fp16",
    torch_dtype=torch.float16,
    use_auth_token=auth_token,
)

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = pipe.to(device)


def pad_image(image):
    """Pad a PIL image with black borders so it becomes square."""
    w, h = image.size
    if w == h:
        return image
    elif w > h:
        new_image = Image.new(image.mode, (w, w), (0, 0, 0))
        new_image.paste(image, (0, (w - h) // 2))
        return new_image
    else:
        new_image = Image.new(image.mode, (h, h), (0, 0, 0))
        new_image.paste(image, ((h - w) // 2, 0))
        return new_image


def process_image(image, prompt):
    """Run CLIPSeg for one prompt and return a float mask normalized to [0, 1]."""
    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )
    # predict
    with torch.no_grad():
        outputs = model(**inputs)
        preds = outputs.logits
    pred = torch.sigmoid(preds)
    mat = pred.cpu().numpy()
    mask = Image.fromarray(np.uint8(mat * 255), "L")
    mask = mask.convert("RGB")
    mask = mask.resize(image.size)
    mask = np.array(mask)[:, :, 0]

    # normalize the mask to [0, 1]
    mask_min = mask.min()
    mask_max = mask.max()
    mask = (mask - mask_min) / (mask_max - mask_min)
    return mask


def get_masks(prompts, img, threshold):
    """Turn a comma-separated prompt string into a list of boolean masks."""
    prompts = prompts.split(",")
    masks = []
    for prompt in prompts:
        mask = process_image(img, prompt)
        mask = mask > threshold
        masks.append(mask)
    return masks


def extract_image(img, pos_prompts, neg_prompts, threshold):
    positive_masks = get_masks(pos_prompts, img, threshold)
    negative_masks = get_masks(neg_prompts, img, threshold)

    # combine each group of masks with a logical OR, then keep the positive
    # region that is not covered by any negative prompt
    pos_mask = np.any(np.stack(positive_masks), axis=0)
    neg_mask = np.any(np.stack(negative_masks), axis=0)
    final_mask = pos_mask & ~neg_mask

    # extract the final image
    final_mask = Image.fromarray(final_mask.astype(np.uint8) * 255, "L")
    inverse_mask = 255 - np.array(final_mask)
    output_image = Image.new("RGBA", img.size, (0, 0, 0, 0))
    output_image.paste(img, mask=final_mask)
    return output_image, final_mask, inverse_mask


def inpaint_image(img, mask, prompt):
    img = pad_image(img).convert("RGB").resize((512, 512))
    # the mask arrives from the Gradio image component as a uint8 array that is
    # already in [0, 255], so it must not be scaled again (scaling would
    # overflow and wreck the mask)
    mask = Image.fromarray(mask)
    mask = pad_image(mask).convert("RGB").resize((512, 512))
    with torch.autocast(device):
        inpainted_image = pipe(prompt=prompt, image=img, mask_image=mask).images[0]
    return inpainted_image
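
# A minimal sketch of exercising the helpers above without the UI. The image
# path and the prompts are hypothetical placeholders, not part of the app:
#
#   img = Image.open("example.jpg").convert("RGB")
#   cutout, mask, inverse = extract_image(img, "cat", "background", 0.4)
#   cutout.save("cutout.png")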

title = "Interactive demo: zero-shot image segmentation with CLIPSeg"
description = "Demo for using CLIPSeg, a CLIP-based model for zero- and one-shot image segmentation. Upload an image, describe what you want to identify (and optionally what to ignore), then click 'Mask'. To inpaint the extracted region with Stable Diffusion, enter a prompt and click 'Run'. Results show up in a few seconds."
article = "[CLIPSeg: Image Segmentation Using Text and Image Prompts](https://arxiv.org/abs/2112.10003) | [HuggingFace docs](https://huggingface.co/docs/transformers/model_doc/clipseg)"

" with gr.Blocks() as demo: gr.Markdown("# CLIPSeg: Image Segmentation Using Text and Image Prompts") gr.Markdown(article) gr.Markdown(description) with gr.Row(): with gr.Column(): input_image = gr.Image(type="pil") positive_prompts = gr.Textbox( label="Please describe what you want to identify (comma separated)" ) negative_prompts = gr.Textbox( label="Please describe what you want to ignore (comma separated)" ) input_slider_T = gr.Slider( minimum=0, maximum=1, value=0.4, label="Threshold" ) btn_mask = gr.Button(label="Mask") with gr.Column(): output_image = gr.Image(label="Result") output_mask = gr.Image(label="Mask") inverse_mask = gr.Image(label="Inverse") with gr.Row(): with gr.Column(): prompt = gr.Textbox( label="Prompt" ) input_slider_S = gr.Slider( minimum=0, maximum=1, value=0.4, label="Image Strength" ) btn_run = gr.Button(label="Run") with gr.Column(): inpainted_image = gr.Image(label="Inpainted") btn_mask.click( extract_image, inputs=[ input_image, positive_prompts, negative_prompts, input_slider_T, ], outputs=[output_image, output_mask, inverse_mask], api_name="mask" ) btn_run.click( inpaint_image, inputs=[ input_image, inverse_mask, prompt, ], outputs=[inpainted_image], api_name="run" ) demo.launch()