Virtual-Try-On / src /background_processor.py
parokshsaxena
using remove bg to add background to the output image
cbe97f0
raw
history blame
9.6 kB
import logging
import os
import warnings

import cv2
import numpy as np
import requests
from PIL import Image, ImageEnhance

from preprocess.humanparsing.run_parsing import Parsing
from src.image_format_convertor import ImageFormatConvertor
# API key for the remove.bg service, read from the environment only.
# Never hardcode a real key here: anything committed to source control must be
# treated as leaked and rotated. When unset, this is None and
# BackgroundProcessor.remove_bg fails (logs and returns None).
REMOVE_BG_KEY = os.getenv('REMOVE_BG_KEY')
# Human-parsing model, constructed once at import time and used by the
# mask-based compositing methods of BackgroundProcessor.
# NOTE(review): the argument 0 presumably selects a device (e.g. GPU id) — confirm.
parsing_model = Parsing(0)
class BackgroundProcessor:
    """Composite a person image onto a new background image.

    ``replace_background_with_removebg`` (backed by ``remove_bg``) delegates the
    work to the remove.bg web API. The other compositing methods were written
    for experimentation only and emit a ``DeprecationWarning`` when called.
    """

    @staticmethod
    def _warn_deprecated():
        """Emit a DeprecationWarning attributed to the caller of the deprecated method."""
        # stacklevel=3: this helper -> deprecated method -> user code.
        warnings.warn("Created only for testing. Not in use", DeprecationWarning, stacklevel=3)

    @classmethod
    def add_background(cls, human_img: Image.Image, background_img: Image.Image):
        """Place the parsed person from ``human_img`` onto ``background_img``.

        The human-parsing model output is converted to a grey-level mask. Every
        pixel whose mask value is non-zero is taken from ``human_img``. Every
        other pixel is taken from ``background_img``, which is first resized to
        the person image's size.

        Args:
            human_img: Image containing the person; converted to RGB.
            background_img: Replacement background; converted to RGB and resized.

        Returns:
            A new RGB PIL image the size of ``human_img``.

        Deprecated: created only for testing.
        """
        cls._warn_deprecated()
        human_img = human_img.convert("RGB")
        width = human_img.width
        height = human_img.height

        # Build a single-channel mask from the parser output at the person image's size.
        parsed_img, _ = parsing_model(human_img)
        mask_img = parsed_img.convert("L").resize((width, height))
        background_img = background_img.convert("RGB").resize((width, height))

        human_np = np.array(human_img)
        mask_np = np.array(mask_img)
        background_np = np.array(background_img)
        # Replicate the mask over 3 channels so it lines up with the RGB pixels.
        mask_np = np.stack((mask_np,) * 3, axis=-1)

        human_with_background = np.where(mask_np > 0, human_np, background_np)
        return Image.fromarray(human_with_background.astype('uint8'))

    @classmethod
    def add_background_v3(cls, foreground_pil: Image.Image, background_pil: Image.Image):
        """OpenCV variant of ``add_background`` that also blurs the result.

        The parser mask is thresholded to pure black/white. The person is cut
        out with the mask, the background with its inverse, and the two are
        added together. A 5x5 Gaussian blur is then applied to the combined image.

        Args:
            foreground_pil: Image containing the person; converted to RGB.
            background_pil: Replacement background; converted to RGB and resized.

        Returns:
            A new PIL image the size of ``foreground_pil``.

        Deprecated: created only for testing.
        """
        cls._warn_deprecated()
        foreground_pil = foreground_pil.convert("RGB")
        width = foreground_pil.width
        height = foreground_pil.height

        parsed_img, _ = parsing_model(foreground_pil)
        mask_pil = parsed_img.convert("L").resize((width, height))
        background_pil = background_pil.convert("RGB").resize((width, height))

        foreground_cv2 = ImageFormatConvertor.pil_to_cv2(foreground_pil)
        background_cv2 = ImageFormatConvertor.pil_to_cv2(background_pil)
        # Take the mask values as-is; it must not go through a colour conversion.
        mask_cv2 = np.array(mask_pil)
        if mask_cv2.ndim == 3:
            mask_cv2 = cv2.cvtColor(mask_cv2, cv2.COLOR_BGR2GRAY)
        # Any non-zero value becomes 255, giving a pure black/white mask.
        _, mask_cv2 = cv2.threshold(mask_cv2, 0, 255, cv2.THRESH_BINARY)
        mask_inv_cv2 = cv2.bitwise_not(mask_cv2)

        # 3-channel masks so bitwise_and lines up with the colour images.
        mask_3ch_cv2 = cv2.cvtColor(mask_cv2, cv2.COLOR_GRAY2BGR)
        mask_inv_3ch_cv2 = cv2.cvtColor(mask_inv_cv2, cv2.COLOR_GRAY2BGR)

        person_cv2 = cv2.bitwise_and(foreground_cv2, mask_3ch_cv2)
        background_extracted_cv2 = cv2.bitwise_and(background_cv2, mask_inv_3ch_cv2)
        combined_cv2 = cv2.add(person_cv2, background_extracted_cv2)

        # NOTE: this blurs the entire composite, not only the seam around the person.
        blurred_combined_cv2 = cv2.GaussianBlur(combined_cv2, (5, 5), 0)
        return ImageFormatConvertor.cv2_to_pil(blurred_combined_cv2)

    @classmethod
    def replace_background(cls, foreground_img_path: str, background_img_path: str):
        """Alpha-blend a 4-channel foreground file over a background file.

        The result is written to ``path_to_save_combined_image.png`` in the
        current working directory. It is also shown in an OpenCV window that
        blocks until a key is pressed.

        Args:
            foreground_img_path: Path to the foreground image (must carry alpha).
            background_img_path: Path to the background image; resized to match.

        Raises:
            ValueError: if the converted foreground is not a 4-channel image.

        Deprecated: created only for testing.
        """
        cls._warn_deprecated()
        # Close both files deterministically; the pixel data is copied out via pil_to_cv2.
        with Image.open(foreground_img_path) as foreground_img_pil, \
                Image.open(background_img_path) as background_image_pil:
            width = foreground_img_pil.width
            height = foreground_img_pil.height
            resized_background_pil = background_image_pil.resize((width, height))
            # NOTE(review): whether pil_to_cv2 preserves the alpha channel is not
            # visible here — if it does not, the check below always raises.
            input_image = ImageFormatConvertor.pil_to_cv2(foreground_img_pil)
            background_image = ImageFormatConvertor.pil_to_cv2(resized_background_pil)

        # A 2-D (grayscale) array has no channel axis; report it as ValueError as well.
        if input_image.ndim != 3 or input_image.shape[2] != 4:
            raise ValueError("Input image must have an alpha channel")

        alpha_channel = input_image[:, :, 3]
        background_image = cv2.resize(background_image, (input_image.shape[1], input_image.shape[0]))
        # Alpha as 3 channels normalised to 0-1 so it can weight each colour channel.
        alpha_channel_3ch = cv2.cvtColor(alpha_channel, cv2.COLOR_GRAY2BGR) / 255.0

        input_bgr = input_image[:, :, :3]
        background_bgr = background_image[:, :, :3]
        foreground = cv2.multiply(alpha_channel_3ch, input_bgr.astype(float))
        background = cv2.multiply(1.0 - alpha_channel_3ch, background_bgr.astype(float))
        combined_image = cv2.add(foreground, background).astype(np.uint8)

        cv2.imwrite('path_to_save_combined_image.png', combined_image)
        cv2.imshow('Combined Image', combined_image)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    @classmethod
    def replace_background_with_removebg(
        cls,
        foreground_img_pil: Image.Image,
        background_image_pil: Image.Image,
        *,
        preview: bool = True,
    ):
        """Replace the background of ``foreground_img_pil`` using the remove.bg API.

        Both images are converted to RGB and the background is resized to the
        foreground's size. Both are serialised with
        ``ImageFormatConvertor.pil_image_to_binary_data`` and sent to ``remove_bg``.

        Args:
            foreground_img_pil: Image whose background should be replaced.
            background_image_pil: New background image.
            preview: When True (the default, matching previous behaviour), open
                the result with ``Image.show()``. Pass False on servers.

        Returns:
            The composited PIL image, or None if the remove.bg call failed.
        """
        foreground_img_pil = foreground_img_pil.convert("RGB")
        width = foreground_img_pil.width
        height = foreground_img_pil.height
        background_image_pil = background_image_pil.convert("RGB").resize((width, height))

        foreground_binary = ImageFormatConvertor.pil_image_to_binary_data(foreground_img_pil)
        background_binary = ImageFormatConvertor.pil_image_to_binary_data(background_image_pil)
        combined_img_pil = cls.remove_bg(foreground_binary, background_binary)
        # remove_bg returns None on failure; calling .show() on it would raise AttributeError.
        if preview and combined_img_pil is not None:
            combined_img_pil.show()
        return combined_img_pil

    @classmethod
    def remove_bg(cls, foreground_binary: bytes, background_binary: bytes):
        """POST a foreground and a background image to the remove.bg API.

        The images are sent as the ``image_file`` and ``bg_image_file`` form
        fields, labelled as PNG. See https://www.remove.bg/api#api-reference.

        Args:
            foreground_binary: Encoded foreground image bytes.
            background_binary: Encoded background image bytes.

        Returns:
            The PIL image decoded from the response body on HTTP 200. Returns
            None (after logging an error) when no API key is configured, the
            request fails, or the API responds with any other status.
        """
        if not REMOVE_BG_KEY:
            logging.error("failed to use remove bg. REMOVE_BG_KEY is not set")
            return None

        url = "https://api.remove.bg/v1.0/removebg"
        # Multipart form-data, because binary payloads cannot be sent as application/json.
        files = {
            "image_file": ('foreground.png', foreground_binary, 'image/png'),
            "bg_image_file": ('background.png', background_binary, 'image/png'),
        }
        headers = {
            "accept": "image/*",
            'X-Api-Key': REMOVE_BG_KEY,
        }
        try:
            remove_bg_request = requests.post(url, files=files, headers=headers, timeout=20)
        except requests.RequestException as err:
            # Timeouts/connection errors follow the same "log and return None" contract.
            logging.error("failed to use remove bg. Request error: %s", err)
            return None

        if remove_bg_request.status_code == 200:
            return ImageFormatConvertor.binary_data_to_pil_image(remove_bg_request.content)
        logging.error(
            "failed to use remove bg. Status: %s. Resp: %s",
            remove_bg_request.status_code,
            remove_bg_request.content,
        )
        return None