# BlurryScenes / app.py
# Author: Niraj70194
# Last commit: 9498f43 ("Update app.py")
import gradio as gr
import numpy as np
from transformers import BeitImageProcessor, BeitForSemanticSegmentation, DPTImageProcessor, DPTForDepthEstimation
from PIL import Image, ImageFilter
import torch
import cv2
# Models are loaded once, at import time, and shared by every request.

# Semantic segmentation checkpoint (BEiT, ADE20K fine-tune) used by apply_gaussian_blur.
SEGMENTATION_CHECKPOINT = "microsoft/beit-base-finetuned-ade-640-640"
segmentation_processor = BeitImageProcessor.from_pretrained(SEGMENTATION_CHECKPOINT)
segmentation_model = BeitForSemanticSegmentation.from_pretrained(SEGMENTATION_CHECKPOINT)

# Monocular depth-estimation checkpoint (DPT-Large) used by apply_lens_blur.
DEPTH_CHECKPOINT = "Intel/dpt-large"
depth_feature_extractor = DPTImageProcessor.from_pretrained(DEPTH_CHECKPOINT)
depth_model = DPTForDepthEstimation.from_pretrained(DEPTH_CHECKPOINT)
def apply_gaussian_blur(image, *, size=512, sigma=15.0):
    """Blur the background while keeping pixels segmented as "person" sharp.

    The image is converted to RGB, resized to ``size`` x ``size``, and segmented
    with the module-level BEiT model. Pixels predicted as the "person" class are
    copied from the original image. Every other pixel is taken from a
    Gaussian-blurred copy.

    Args:
        image: Input ``PIL.Image.Image``; any mode convertible to RGB.
        size: Side length, in pixels, of the square working/output image.
        sigma: Standard deviation of the Gaussian blur applied to the background.

    Returns:
        A ``size`` x ``size`` RGB ``PIL.Image.Image``.
    """
    import torch.nn.functional as F

    # Convert before resizing so palette/grayscale inputs are resampled in RGB.
    image = image.convert("RGB").resize((size, size))
    inputs = segmentation_processor(image, return_tensors="pt")

    with torch.no_grad():
        logits = segmentation_model(**inputs).logits  # (batch, num_labels, h', w')

    # BEiT's logits are at a lower resolution than its input. Upsample the
    # class scores (not the hard argmax labels) so mask edges are smooth rather
    # than blocky nearest-neighbour steps.
    logits = F.interpolate(logits, size=(size, size), mode="bilinear", align_corners=False)
    segmentation = logits.argmax(dim=1)[0].cpu().numpy()  # (size, size) class ids

    # Prefer the model's own label map over a magic number; 12 is the ADE20K
    # "person" index and remains the fallback if the label is not found.
    label2id = {name.lower(): idx for name, idx in segmentation_model.config.label2id.items()}
    person_index = label2id.get("person", 12)
    person_mask = (segmentation == person_index)[..., np.newaxis]  # (size, size, 1), bool

    image_np = np.array(image)
    blurred_np = cv2.GaussianBlur(image_np, (0, 0), sigmaX=sigma, sigmaY=sigma)

    # Person pixels come from the original; everything else from the blurred copy.
    final_image = np.where(person_mask, image_np, blurred_np).astype(np.uint8)
    return Image.fromarray(final_image)
def apply_lens_blur(image, *, size=512, blur_radius=15):
    """Simulate a lens blur by blending sharp and blurred copies using estimated depth.

    The image is converted to RGB and resized to ``size`` x ``size``. The
    module-level DPT model's prediction is min-max normalized and inverted, and
    the result is used as a per-pixel blur weight: 0 keeps the original pixel,
    1 takes the blurred pixel.

    Args:
        image: Input ``PIL.Image.Image``; any mode convertible to RGB.
        size: Side length, in pixels, of the square working/output image.
        blur_radius: Radius passed to ``PIL.ImageFilter.GaussianBlur`` for the
            fully blurred copy.

    Returns:
        A ``size`` x ``size`` RGB ``PIL.Image.Image``.
    """
    import torch.nn.functional as F

    # Convert before resizing so palette/grayscale inputs are resampled in RGB.
    image = image.convert("RGB").resize((size, size))
    depth_inputs = depth_feature_extractor(images=image, return_tensors="pt")

    with torch.no_grad():
        predicted_depth = depth_model(**depth_inputs).predicted_depth  # (batch, h', w')

    # Resize the raw float prediction to the working size instead of going
    # through a uint8 image, which quantized the blend weights to 256 levels.
    depth_map = F.interpolate(
        predicted_depth.unsqueeze(1),
        size=(size, size),
        mode="bicubic",
        align_corners=False,
    )[0, 0].cpu().numpy()

    min_depth = float(depth_map.min())
    depth_range = float(depth_map.max()) - min_depth
    if depth_range > 0:
        normalized_depth = (depth_map - min_depth) / depth_range
    else:
        # A constant prediction carries no depth cue. Keep the whole image in
        # focus instead of dividing by zero (which produced NaNs).
        normalized_depth = np.ones_like(depth_map)

    # NOTE(review): DPT is documented to output relative inverse depth
    # (larger = nearer). Inverting therefore gives ~0 weight (sharp) to near
    # content and ~1 weight (blurred) to far content.
    blur_weight = (1.0 - normalized_depth).astype(np.float32)[..., np.newaxis]

    blurred_image = image.filter(ImageFilter.GaussianBlur(radius=blur_radius))
    original_np = np.asarray(image, dtype=np.float32)
    blurred_np = np.asarray(blurred_image, dtype=np.float32)

    output_np = (1.0 - blur_weight) * original_np + blur_weight * blurred_np
    return Image.fromarray(np.clip(output_np, 0, 255).astype(np.uint8))
def apply_blur(effect, image):
    """Dispatch to the blur implementation selected in the UI.

    Args:
        effect: Dropdown value; either "Gaussian Blur" or "Lens Blur".
        image: The uploaded ``PIL.Image.Image``, or ``None`` if nothing was uploaded.

    Returns:
        The blurred ``PIL.Image.Image``.

    Raises:
        gr.Error: If no image was provided or ``effect`` is not a known option.
            Gradio shows the message to the user instead of an empty output
            (previously ``None`` was returned or an ``AttributeError`` escaped).
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    if effect == "Gaussian Blur":
        return apply_gaussian_blur(image)
    if effect == "Lens Blur":
        return apply_lens_blur(image)
    raise gr.Error(f"Unknown blur effect: {effect!r}. Choose 'Gaussian Blur' or 'Lens Blur'.")
# Gradio UI: choose an effect, upload an image, and display the blurred result.
interface = gr.Interface(
    fn=apply_blur,
    inputs=[
        gr.Dropdown(
            choices=["Gaussian Blur", "Lens Blur"],
            # Pre-select an option so apply_blur never receives None for an untouched dropdown.
            value="Gaussian Blur",
            label="Select Blur Effect",
        ),
        gr.Image(type="pil"),
    ],
    outputs=gr.Image(type="pil"),
    title="Blur Effects with Hugging Face",
    description="Apply Gaussian Blur or Lens Blur to images using semantic segmentation or depth estimation.",
)

# Start the server only when run as a script, so importing this module
# does not block on launching the web app.
if __name__ == "__main__":
    interface.launch()