TestPaint_AI / app.py
ciover2024's picture
Update app.py
6b98e01 verified
raw
history blame
11.4 kB
import gradio as gr
from PIL import Image
import torch
#from diffusers import FluxControlNetModel
#from diffusers.pipelines import FluxControlNetPipeline
from diffusers import DiffusionPipeline
#from diffusers import FluxControlNetPipeline
#from diffusers import FluxControlNetModel #, FluxMultiControlNetModel
"""
from diffusers import DiffusionPipeline
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")
pipe.load_lora_weights("enhanceaiteam/Flux-Uncensored-V2")
prompt = "nsfw nude woman on beach, sunset, long flowing hair, sensual pose"
image = pipe(prompt).images[0]
"""
#import torch.nn.functional as F
#import torchvision
#import torchvision.transforms as T
#import cv2
from diffusers import StableDiffusionInpaintPipeline
import numpy as np
import os
import shutil
from gradio_client import Client, handle_file
# Load the model once globally to avoid repeated loading
"""
def load_inpainting_model():
# Load pipeline
#model_path = "urpmv13Inpainting.safetensors"
model_path = "uberRealisticPornMerge_v23Inpainting.safetensors"
#model_path = "pornmasterFantasy_v4-inpainting.safetensors"
#model_path = "pornmasterAmateur_v6Vae-inpainting.safetensors"
device = "cpu" # Explicitly use CPU
pipe = StableDiffusionInpaintPipeline.from_single_file(
model_path,
torch_dtype=torch.float32, # Use float32 for CPU
safety_checker=None
).to(device)
return pipe
"""
"""
# Load the model once globally to avoid repeated loading
def load_upscaling_model():
# Load pipeline
device = "cpu" # Explicitly use CPU
controlnet = FluxControlNetModel.from_pretrained(
"jasperai/Flux.1-dev-Controlnet-Upscaler",
torch_dtype=torch.float32
)
pipe = FluxControlNetPipeline.from_pretrained(
"black-forest-labs/FLUX.1-dev",
controlnet=controlnet,
torch_dtype=torch.float32
).to(device)
pipe = DiffusionPipeline.from_pretrained("jasperai/Flux.1-dev-Controlnet-Upscaler")
return pipe
"""
# Preload the model once
#inpaint_pipeline = load_inpainting_model()
# Preload the model once
#upscale_pipeline = load_upscaling_model()
def resize_image(orig_image):
aspect_ratio = orig_image.height / orig_image.width
old_width = orig_image.width
new_width = int(orig_image.width*1.2)
old_height = orig_image.height
new_height = int(new_width * aspect_ratio)
resized_image = orig_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
left_crop = int((new_width - old_width)/2)
right_crop = new_width - int((new_width - old_width) / 2)
top_crop = int((new_height - old_height)/2)
bottom_crop = new_height - int((new_height - old_height) / 2)
cropped_image = resized_image.crop((left_crop,top_crop,right_crop,bottom_crop))
return cropped_image
# Function to resize image (simpler interpolation method for speed)
def resize_to_match(input_image, output_image):
#w, h = output_image.size
#control_image = output_image.resize((w * 4, h * 4))
"""
scaled_image = pipe(
prompt="",
control_image=control_image,
controlnet_conditioning_scale=0.6,
num_inference_steps=28,
guidance_scale=3.5,
height=control_image.size[1],
width=control_image.size[0]
).images[0]
"""
#return scaled_image
#torch_img = pil_to_torch(input_image)
#torch_img_scaled = F.interpolate(torch_img.unsqueeze(0),mode='trilinear').squeeze(0)
#output_image = torchvision.transforms.functional.to_pil_image(torch_img_scaled, mode=None)
return output_image.resize(input_image.size, Image.BICUBIC) # Use BILINEAR for faster resizing
def generate_image(image_path, mask_path, text_prompt="undress"):
result = client.predict(
"", # str in 'parameter_10' Textbox component
"", # str in 'Negative Prompt' Textbox component
["Fooocus V2","Fooocus Enhance","Fooocus Sharp"], # List[str] in 'Selected Styles' Checkboxgroup component
"Quality", # str in 'Performance' Radio component
'704×1408 <span style="color: grey;"> ∣ 1:2</span>', # str in 'Aspect Ratios' Radio component
1, # int | float (numeric value between 1 and 32) in 'Image Number' Slider component
"-1", # str in 'Seed' Textbox component
0, # int | float (numeric value between 0.0 and 30.0) in 'Image Sharpness' Slider component
1, # int | float (numeric value between 1.0 and 30.0) in 'Guidance Scale' Slider component
"juggernautXL_version6Rundiffusion.safetensors", # str (Option from: ['ACertainty.ckpt', 'ACertainty.safetensors', 'juggernautXL_version6Rundiffusion.safetensors']) in 'Base Model (SDXL only)' Dropdown component
"None", # str (Option from: ['None', 'ACertainty.ckpt', 'ACertainty.safetensors', 'juggernautXL_version6Rundiffusion.safetensors']) in 'Refiner (SDXL or SD 1.5)' Dropdown component
0.1, # int | float (numeric value between 0.1 and 1.0) in 'Refiner Switch At' Slider component
"None", # str (Option from: ['None', 'sdxl_lcm_lora.safetensors', 'sd_xl_offset_example-lora_1.0.safetensors']) in 'LoRA 1' Dropdown component
-2, # int | float (numeric value between -2 and 2) in 'Weight' Slider component
"None", # str (Option from: ['None', 'sdxl_lcm_lora.safetensors', 'sd_xl_offset_example-lora_1.0.safetensors']) in 'LoRA 2' Dropdown component
-2, # int | float (numeric value between -2 and 2) in 'Weight' Slider component
"None", # str (Option from: ['None', 'sdxl_lcm_lora.safetensors', 'sd_xl_offset_example-lora_1.0.safetensors']) in 'LoRA 3' Dropdown component
-2, # int | float (numeric value between -2 and 2) in 'Weight' Slider component
"None", # str (Option from: ['None', 'sdxl_lcm_lora.safetensors', 'sd_xl_offset_example-lora_1.0.safetensors']) in 'LoRA 4' Dropdown component
-2, # int | float (numeric value between -2 and 2) in 'Weight' Slider component
"None", # str (Option from: ['None', 'sdxl_lcm_lora.safetensors', 'sd_xl_offset_example-lora_1.0.safetensors']) in 'LoRA 5' Dropdown component
-2, # int | float (numeric value between -2 and 2) in 'Weight' Slider component
True, # bool in 'Input Image' Checkbox component
"", # str in 'parameter_85' Textbox component
"Disabled", # str in 'Upscale or Variation:' Radio component
None, # str (filepath or URL to image) in 'Drag above image to here' Image component
[], # List[str] in 'Outpaint Direction' Checkboxgroup component
image_path, # str (filepath or URL to image) in 'Drag inpaint or outpaint image to here' Image component
"", # str in 'Inpaint Additional Prompt' Textbox component
mask_path, # str (filepath or URL to image) in 'Mask Upload' Image component
image_path, # str (filepath or URL to image) in 'Image' Image component
0, # int | float (numeric value between 0.0 and 1.0) in 'Stop At' Slider component
0, # int | float (numeric value between 0.0 and 2.0) in 'Weight' Slider component
"ImagePrompt", # str in 'Type' Radio component
None, # str (filepath or URL to image) in 'Image' Image component
0, # int | float (numeric value between 0.0 and 1.0) in 'Stop At' Slider component
0, # int | float (numeric value between 0.0 and 2.0) in 'Weight' Slider component
"ImagePrompt", # str in 'Type' Radio component
None, # str (filepath or URL to image) in 'Image' Image component
0, # int | float (numeric value between 0.0 and 1.0) in 'Stop At' Slider component
0, # int | float (numeric value between 0.0 and 2.0) in 'Weight' Slider component
"ImagePrompt", # str in 'Type' Radio component
None, # str (filepath or URL to image) in 'Image' Image component
0, # int | float (numeric value between 0.0 and 1.0) in 'Stop At' Slider component
0, # int | float (numeric value between 0.0 and 2.0) in 'Weight' Slider component
"ImagePrompt", # str in 'Type' Radio component
fn_index=33
)
# Function to generate the mask using Florence SAM Masking API (Replicate)
def generate_mask(image_path, text_prompt="clothing"):
client_sam = Client("SkalskiP/florence-sam-masking")
mask_result = client_sam.predict(
#mode_dropdown = "open vocabulary detection + image masks",
image_input=handle_file(image_path), # Provide your image path here
text_input=text_prompt, # Use "clothing" as the prompt
api_name="/process_image"
)
print("mask_result=", mask_result)
return mask_result # This is the local path to the generated mask
# Save the generated mask
def save_mask(mask_local_path, save_path="generated_mask.png"):
try:
shutil.copy(mask_local_path, save_path)
except Exception as e:
print(f"Failed to save the mask: {e}")
# Function to perform inpainting
"""
def inpaint_image(input_image, mask_image):
prompt = "undress, naked, real skin, detailed nipples, erect nipples, detailed pussy, (detailed nipples), (detailed skin), (detailed pussy), accurate anatomy"
negative_prompt = "bad anatomy, deformed, ugly, disfigured, (extra arms), (extra legs), (extra hands), (extra feet), (extra finger)"
#IMAGE_SIZE = (1024,1024)
#initial_input_image = input_image.resize(IMAGE_SIZE)
#initial_mask_image = mask_image.resize(IMAGE_SIZE)
#blurred_mask_image = inpaint_pipeline.mask_processor.blur(initial_mask_image,blur_factor=10)
#result = inpaint_pipeline(prompt=prompt, negative_prompt=negative_prompt, height=IMAGE_SIZE[0], width=IMAGE_SIZE[0], image=initial_input_image, mask_image=blurred_mask_image, padding_mask_crop=32)
#blurred_mask_image = inpaint_pipeline.mask_processor.blur(mask_image,blur_factor=10)
result = inpaint_pipeline(prompt=prompt, negative_prompt=negative_prompt, image=input_image, mask_image=mask_image, padding_mask_crop=10)
inpainted_image = result.images[0]
#inpainted_image = resize_to_match(input_image, inpainted_image)
return inpainted_image
"""
# Function to process input image and mask
def process_image(input_image):
# Save the input image temporarily to process with Replicate
input_image_path = "temp_input_image.png"
input_image.save(input_image_path)
# Generate the mask using Florence SAM API
mask_local_path = generate_mask(image_path=input_image_path)
#mask_local_path1 = str(mask_local_path)#[0])
# Save the generated mask
mask_image_path = "generated_mask.png"
save_mask(mask_local_path, save_path=mask_image_path)
# Open the mask image and perform inpainting
mask_image = Image.open(mask_image_path)
result_image = resize_image(mask_image)
# Clean up temporary files
os.remove(input_image_path)
os.remove(mask_image_path)
return result_image
# Define Gradio interface using Blocks API
with gr.Blocks() as demo:
with gr.Row():
input_image = gr.Image(label="Upload Input Image", type="pil")
output_image = gr.Image(type="pil", label="Output Image")
# Button to trigger the process
with gr.Row():
btn = gr.Button("Run Inpainting")
# Function to run when button is clicked
btn.click(fn=process_image, inputs=[input_image], outputs=output_image)
# Launch the Gradio app
demo.launch(share=True)