CTO_TCP_V1 / app.py
ishworrsubedii's picture
Revert "update: png -> webp"
0ee7719
raw
history blame
3.51 kB
import torch
import os
import gradio as gr
import numpy as np
from PIL import Image
from PIL.ImageOps import grayscale
import gc
import spaces
import cv2
import base64
from io import BytesIO
import uvicorn
from fastapi import app
from diffusers import StableDiffusionPipeline, StableDiffusionInpaintPipeline
# Stable Diffusion 2 inpainting model, loaded in float16 once when this module
# is imported and moved to the CUDA device; clothing_try_on calls it per request.
model_id = "stabilityai/stable-diffusion-2-inpainting"
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
).to("cuda")
def clear_func():
    """Free memory after an inference call.

    First asks PyTorch to release the unused blocks held by its CUDA caching
    allocator, then forces a full garbage-collection pass. Returns ``None``.
    """
    for release in (torch.cuda.empty_cache, gc.collect):
        release()
@spaces.GPU
def clothing_try_on(image, mask):
    """Repaint the garment around a masked region and paste the region back.

    Steps:
      1. Keep the original pixels under ``mask`` (called the jewellery in this
         code) for the final composite.
      2. Erase the masked region from ``image`` with OpenCV Telea inpainting.
      3. Build a generation mask covering every row from the top of ``mask``
         down to the bottom edge. Repaint that area with the Stable Diffusion
         inpainting pipeline at 512x512, using a saree prompt. The aspect
         ratio is not preserved during this resize.
      4. Resize the result back to the input size. Keep generated pixels
         outside ``mask`` and the original pixels inside it.

    Args:
        image: PIL image of the person. Presumably RGB and the same size as
            ``mask`` (``base64_to_image`` yields RGB). The bitwise
            compositing needs matching array shapes.
        mask: PIL image marking the region to preserve. Assumed binary
            (0 / 255) — TODO confirm. Intermediate values blend unpredictably
            in the bitwise compositing.

    Returns:
        A PIL image the same size as ``image``.

    Raises:
        ValueError: if ``mask`` has no non-zero pixel.
    """
    # Original pixels inside the mask, zero elsewhere; OR-ed onto the output.
    jewellery_mask = np.bitwise_and(np.array(mask), np.array(image))

    arr_orig = np.array(grayscale(mask))
    mask_rows = np.flatnonzero(arr_orig.any(axis=1))
    if mask_rows.size == 0:
        raise ValueError("mask contains no non-zero pixels; nothing to repaint")
    # Topmost row touched by the mask (the row of the first non-zero pixel).
    mask_y = mask_rows[0]

    try:
        # Classical inpainting first, so the diffusion model does not see
        # (and re-create) the masked object.
        inpainted = Image.fromarray(
            cv2.inpaint(np.array(image), arr_orig, 15, cv2.INPAINT_TELEA)
        )
        orig_size = inpainted.size

        # Repaint everything from the top of the mask down to the bottom edge.
        generation_mask = arr_orig.copy()
        generation_mask[mask_y:, :] = 255

        prompt = " South Indian Saree, properly worn, natural setting, elegant, natural look, neckline without jewellery, simple"
        negative_prompt = "necklaces, jewellery, jewelry, necklace, neckpiece, garland, chain, neck wear, jewelled neck, jeweled neck, necklace on neck, jewellery on neck, accessories, watermark, text, changed background, wider body, narrower body, bad proportions, extra limbs, mutated hands, changed sizes, altered proportions, unnatural body proportions, blury, ugly"
        output = pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=inpainted.resize((512, 512)),
            mask_image=Image.fromarray(generation_mask).resize((512, 512)),
            strength=0.95,
            # The pipeline's keyword is ``guidance_scale``. The former
            # ``guidance_score`` was not a recognised argument and never set
            # the guidance strength.
            guidance_scale=9,
        ).images[0]
        output = output.resize(orig_size)

        # Generated pixels outside the mask, original pixels inside it.
        outside_mask = np.bitwise_not(
            np.array(Image.fromarray(arr_orig).convert("RGB"))
        )
        generated = np.bitwise_and(np.array(output), outside_mask)
        result = Image.fromarray(np.bitwise_or(generated, jewellery_mask))
    finally:
        # Release GPU memory even when inpainting or generation fails.
        clear_func()
    return result
def base64_to_image(base64_str):
    """Decode base64 text into an RGB PIL image.

    Accepts either a bare base64 payload or a data URI such as
    ``data:image/png;base64,<payload>``.

    Args:
        base64_str: base64 encoding of an image file PIL can open
            (``str`` or ``bytes``).

    Returns:
        The decoded image, converted to mode ``"RGB"``.

    Raises:
        binascii.Error: if the payload is incorrectly padded.
        PIL.UnidentifiedImageError: if the decoded bytes are not an image.
    """
    # b64decode silently drops non-alphabet characters such as ':' ';' ','.
    # It would keep the letters of a data-URI header ("data", "image",
    # "base64") and corrupt the payload, so strip the header first.
    # ':' is not in the base64 alphabet, so a bare payload never starts
    # with "data:".
    if isinstance(base64_str, str) and base64_str.lstrip().startswith("data:"):
        base64_str = base64_str.split(",", 1)[-1]
    image_data = base64.b64decode(base64_str)
    # convert() forces PIL's lazy loader to read the whole buffer.
    return Image.open(BytesIO(image_data)).convert("RGB")
def image_to_base64(image):
    """Serialize ``image`` as PNG and return the bytes as base64 text.

    Args:
        image: object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        The base64 encoding of the PNG bytes, as a ``str``.
    """
    with BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode("utf-8")
def clothing_try_on_base64(input_image_base64, mask_image_base64):
    """Run ``clothing_try_on`` on base64 inputs and return base64 output.

    Args:
        input_image_base64: base64 text of the person image.
        mask_image_base64: base64 text of the mask image.

    Returns:
        Base64 text of the PNG-encoded result.
    """
    # Decoded in order: the person image first, then the mask.
    decoded_image, decoded_mask = (
        base64_to_image(encoded)
        for encoded in (input_image_base64, mask_image_base64)
    )
    return image_to_base64(clothing_try_on(decoded_image, decoded_mask))
def run_fastapi():
    """Serve ``app`` with uvicorn on all interfaces, port 7860.

    ``uvicorn.run`` does not return until the server stops.

    NOTE(review): ``app`` is the name bound by ``from fastapi import app``.
    fastapi does not appear to export an ASGI application under that name;
    an app is normally created with ``fastapi.FastAPI()`` — confirm.
    Port 7860 is also Gradio's default launch port, so running this in the
    same process as the Gradio UI will likely collide — confirm.
    """
    bind = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **bind)
def launch_interface_base64():
    """Build and launch the Gradio UI for base64-in / base64-out try-on.

    Two textboxes take the base64 person image and mask. Pressing "Apply"
    runs ``clothing_try_on_base64`` and writes the base64 result into the
    output textbox. ``launch(debug=True)`` blocks the calling thread.
    """
    with gr.Blocks() as interface:
        with gr.Row():
            input_image = gr.Textbox(label="Input Image (Base64)", lines=4)
            mask_image = gr.Textbox(label="Input Mask (Base64)", lines=4)
            output_text = gr.Textbox(label="Output (Base64)", lines=4)
        submit = gr.Button("Apply")
        # Only one handler. The button previously also called run_fastapi().
        # That started a blocking uvicorn server on port 7860, the port Gradio
        # itself serves on by default, every time "Apply" was pressed.
        submit.click(
            fn=clothing_try_on_base64,
            inputs=[input_image, mask_image],
            outputs=[output_text],
        )
    interface.launch(debug=True)
if __name__ == "__main__":
    # Running this file directly starts the Gradio UI.
    launch_interface_base64()