# CTO_TCP_V1 / app.py
# Author: ishworrsubedii
# Commit message: "update: revert to image upload"
# Commit: 1ec0842
import torch
import os
import gradio as gr
import numpy as np
from PIL import Image
from PIL.ImageOps import grayscale
import gc
import spaces
import cv2
import base64
from io import BytesIO
from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline
# Stable Diffusion inpainting checkpoint. It is loaded once, when this module is
# imported, in float16, and moved to the CUDA device; every request reuses it.
model_id = "stabilityai/stable-diffusion-2-inpainting"
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
).to("cuda")
def clear_func():
    """Free memory held by unreachable Python objects and PyTorch's CUDA cache.

    ``gc.collect()`` runs first. Tensors that are kept alive only by reference
    cycles are then actually released before ``torch.cuda.empty_cache()`` hands
    unused cached blocks back to the driver. In the opposite order those tensors
    still occupy their blocks when the cache is emptied, so the blocks stay cached.
    """
    gc.collect()
    torch.cuda.empty_cache()
def process_mask(mask, *, threshold=128):
    """Binarise *mask* into a pure black/white grayscale image.

    Args:
        mask: PIL image of any mode. It is converted to 8-bit grayscale ("L") first.
        threshold: Grayscale pixels strictly greater than this become 255 and all
            others become 0. Defaults to 128, the previously hard-coded cutoff.

    Returns:
        A PIL image built from a uint8 array that contains only 0 and 255.
    """
    gray = np.asarray(mask.convert("L"))
    binary = np.where(gray > threshold, 255, 0).astype(np.uint8)
    return Image.fromarray(binary)
# Positive and negative prompts for the saree inpainting pass.
_SAREE_PROMPT = (
    " South Indian Saree, properly worn, natural setting, elegant, natural look,"
    " neckline without jewellery, simple"
)
_SAREE_NEGATIVE_PROMPT = (
    "necklaces, jewellery, jewelry, necklace, neckpiece, garland, chain, neck wear,"
    " jewelled neck, jeweled neck, necklace on neck, jewellery on neck, accessories,"
    " watermark, text, changed background, wider body, narrower body, bad proportions,"
    " extra limbs, mutated hands, changed sizes, altered proportions,"
    " unnatural body proportions, blury, ugly"
)
# Working resolution for the inpainting pipeline; its output is resized back afterwards.
_PIPELINE_SIZE = (512, 512)


def _extend_mask_downwards(mask_arr):
    """Return a copy of 2-D *mask_arr* set to 255 from its first non-zero row to the bottom.

    Raises:
        ValueError: if *mask_arr* contains no non-zero pixel.
    """
    masked_rows = np.flatnonzero(mask_arr.any(axis=1))
    if masked_rows.size == 0:
        raise ValueError("mask is empty: mark the jewellery with non-zero pixels")
    extended = mask_arr.copy()
    extended[masked_rows[0]:, :] = 255
    return extended


@spaces.GPU
def clothing_try_on(image, mask):
    """Repaint the outfit in *image* as a saree while preserving the masked jewellery.

    Steps:
        1. Keep the original jewellery pixels (``mask & image``).
        2. Erase the jewellery with OpenCV Telea inpainting (radius 15).
        3. Regenerate everything from the first masked row down to the bottom
           of the frame with the diffusion inpainting pipeline, at 512x512.
        4. Resize the result back, black out the generated pixels under the
           jewellery mask, and OR the original jewellery pixels back in.

    Args:
        image: PIL image of the person. It is converted to RGB.
        mask: PIL image marking the jewellery with non-zero (ideally white)
            pixels, the same size as *image*. It is converted to RGB.

    Returns:
        A PIL RGB image with the same size as *image*.

    Raises:
        ValueError: if *mask* has no non-zero pixel.
    """
    # cv2.inpaint does not accept 4-channel images, and the bitwise compositing
    # needs image and mask to have the same channel count. Normalise both inputs
    # (e.g. RGBA/L PNGs decoded from base64) to RGB.
    image = image.convert("RGB")
    mask = mask.convert("RGB")
    image_arr = np.array(image)
    mask_gray = np.array(grayscale(mask))

    # Validate the mask before doing any expensive work.
    inpaint_mask = _extend_mask_downwards(mask_gray)

    # Original jewellery pixels; black wherever the mask is black.
    jewellery = np.bitwise_and(np.array(mask), image_arr)

    # Remove the jewellery so the diffusion model does not see it.
    cleaned = cv2.inpaint(image_arr, mask_gray, 15, cv2.INPAINT_TELEA)

    try:
        output = pipeline(
            prompt=_SAREE_PROMPT,
            negative_prompt=_SAREE_NEGATIVE_PROMPT,
            image=Image.fromarray(cleaned).resize(_PIPELINE_SIZE),
            mask_image=Image.fromarray(inpaint_mask).resize(_PIPELINE_SIZE),
            strength=0.95,
            # The pipeline's guidance keyword is `guidance_scale`. The former
            # `guidance_score` spelling is not a pipeline parameter, so the
            # intended value of 9 was never applied.
            guidance_scale=9,
        ).images[0]
        generated = np.array(output.resize(image.size))
    finally:
        # Release GPU memory even if generation fails.
        clear_func()

    # Black out generated pixels under the jewellery, then OR the original
    # jewellery back in. This is exact for a pure black/white mask.
    outside_mask = np.bitwise_not(np.array(Image.fromarray(mask_gray).convert("RGB")))
    composite = np.bitwise_or(np.bitwise_and(generated, outside_mask), jewellery)
    return Image.fromarray(composite)
def base64_to_image(base64_str):
    """Decode a base64-encoded image into a PIL image.

    Accepts either raw base64 or a data URL such as
    ``data:image/png;base64,iVBOR...``. Without the prefix check, the header
    letters of a data URL would be decoded as image bytes and corrupt the data.
    The image is fully loaded before returning, so corrupt input raises here
    rather than later in the pipeline.
    """
    header, sep, payload = base64_str.strip().partition(",")
    if not (sep and header.startswith("data:")):
        # No data-URL header; the whole string is the base64 payload.
        payload = base64_str
    image = Image.open(BytesIO(base64.b64decode(payload)))
    image.load()
    return image
def image_to_base64(image):
    """Serialise *image* as PNG and return those bytes as a base64 text string."""
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")
def clothing_try_on_base64(input_image_base64, mask_image_base64):
    """Base64-in / base64-out wrapper around ``clothing_try_on`` for the text-box UI.

    Decodes the person image, then the mask, runs the try-on, and returns the
    result re-encoded by ``image_to_base64``.
    """
    person = base64_to_image(input_image_base64)
    jewellery_mask = base64_to_image(mask_image_base64)
    return image_to_base64(clothing_try_on(person, jewellery_mask))
def launch_interface_base64():
    """Build the base64 text-box Gradio UI and launch it with ``debug=True``."""
    with gr.Blocks() as interface:
        with gr.Row():
            image_box = gr.Textbox(label="Input Image (Base64)", lines=4)
            mask_box = gr.Textbox(label="Input Mask (Base64)", lines=4)
            result_box = gr.Textbox(label="Output (Base64)", lines=4)
        apply_button = gr.Button("Apply")
        apply_button.click(
            fn=clothing_try_on_base64,
            inputs=[image_box, mask_box],
            outputs=[result_box],
        )
    interface.launch(debug=True)
# Start the Gradio app only when this file is run directly, not when it is imported.
if __name__ == "__main__":
    launch_interface_base64()