Spaces:
Running
Running
File size: 3,632 Bytes
9244e51 8fa0841 9244e51 8fa0841 9244e51 e7d1e40 9244e51 8fa0841 9244e51 e7d1e40 9244e51 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 |
import gradio as gr
from PIL import Image
import torch
#import torch.nn.functional as F
#import torchvision
#import torchvision.transforms as T
from diffusers import StableDiffusionInpaintPipeline
import numpy as np
#import cv2
import os
import shutil
from gradio_client import Client, handle_file
# Load the model once globally to avoid repeated loading
def load_inpainting_model():
model_path = "uberRealisticPornMerge_v23Inpainting.safetensors"
model_path = "pornmasterFantasy_v4-inpainting.safetensors"
#model_path = "pornmasterAmateur_v6Vae-inpainting.safetensors"
device = "cpu" # Explicitly use CPU
pipe = StableDiffusionInpaintPipeline.from_single_file(
model_path,
torch_dtype=torch.float32, # Use float32 for CPU
safety_checker=None
).to(device)
return pipe
# Preload the model once
inpaint_pipeline = load_inpainting_model()
# Function to resize image (simpler interpolation method for speed)
def resize_to_match(input_image, output_image):
#torch_img = pil_to_torch(input_image)
#torch_img_scaled = F.interpolate(torch_img.unsqueeze(0),mode='trilinear').squeeze(0)
#output_image = torchvision.transforms.functional.to_pil_image(torch_img_scaled, mode=None)
#return output_image
return output_image.resize(input_image.size, Image.BICUBIC) # Use BILINEAR for faster resizing
# Function to generate the mask using Florence SAM Masking API (Replicate)
def generate_mask(image_path, text_prompt="clothing"):
client_sam = Client("SkalskiP/florence-sam-masking")
mask_result = client_sam.predict(
image_input=handle_file(image_path), # Provide your image path here
text_input=text_prompt, # Use "clothing" as the prompt
api_name="/process_image"
)
return mask_result # This is the local path to the generated mask
# Save the generated mask
def save_mask(mask_local_path, save_path="generated_mask.png"):
try:
shutil.copy(mask_local_path, save_path)
except Exception as e:
print(f"Failed to save the mask: {e}")
# Function to perform inpainting
def inpaint_image(input_image, mask_image):
prompt = "undress, naked"
result = inpaint_pipeline(prompt=prompt, image=input_image, mask_image=mask_image)
inpainted_image = result.images[0]
#inpainted_image = resize_to_match(input_image, inpainted_image)
return inpainted_image
# Function to process input image and mask
def process_image(input_image):
# Save the input image temporarily to process with Replicate
input_image_path = "temp_input_image.png"
input_image.save(input_image_path)
# Generate the mask using Florence SAM API
mask_local_path = generate_mask(image_path=input_image_path)
# Save the generated mask
mask_image_path = "generated_mask.png"
save_mask(mask_local_path, save_path=mask_image_path)
# Open the mask image and perform inpainting
mask_image = Image.open(mask_image_path)
result_image = inpaint_image(input_image, mask_image)
# Clean up temporary files
os.remove(input_image_path)
os.remove(mask_image_path)
return result_image
# Define Gradio interface using Blocks API
with gr.Blocks() as demo:
with gr.Row():
input_image = gr.Image(label="Upload Input Image", type="pil")
output_image = gr.Image(type="pil", label="Output Image")
# Button to trigger the process
with gr.Row():
btn = gr.Button("Run Inpainting")
# Function to run when button is clicked
btn.click(fn=process_image, inputs=[input_image], outputs=output_image)
# Launch the Gradio app
demo.launch(share=True)
|