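# LumiereIQ / Pages / imageBB.py
# Captured from the Hugging Face file viewer: author samcoding5854,
# commit b2cb3e9 ("added Save image code"), 7.21 kB, scan status "No virus".
# (Viewer controls in the capture: raw · history · blame.)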
from segment_anything import sam_model_registry, SamPredictor
import torch
import streamlit as st
from Pages.streamlit_img_label import st_img_label
from Pages.streamlit_img_label.manage import ImageManager
import os
from PIL import Image
import cv2
import numpy as np
@st.cache_resource
def _load_sam_model():
    """Load the SAM ViT-H model once per server process and share it.

    ``st.cache_resource`` keeps a single model instance alive across reruns,
    so the ``vit_h`` checkpoint is read from disk once instead of on every new
    (box, image) pair handled by ``get_masks``. Only the model is shared:
    ``get_masks`` wraps it in a fresh ``SamPredictor`` per call because the
    predictor keeps the image passed to ``set_image`` as internal state that
    ``predict`` then reads.
    """
    checkpoint_path = os.path.join("weights", "sam_vit_h_4b8939.pth")
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model_type = "vit_h"
    return sam_model_registry[model_type](checkpoint=checkpoint_path).to(device=device)


@st.cache_data
def get_masks(rect, img_path):
    """Segment the object inside one bounding box of an image with SAM.

    Results are memoised by ``st.cache_data`` per (rect, img_path), so a page
    rerun with the same box does not run the model again.

    Args:
        rect: Box as a dict with ``left``, ``top``, ``width`` and ``height``
            entries. NOTE(review): assumed to be pixel coordinates of the
            image stored at ``img_path`` -- confirm against the caller.
        img_path: Path of the image to segment; read with ``cv2.imread``.

    Returns:
        The ``masks`` array returned by ``SamPredictor.predict`` with a single
        mask requested (``multimask_output=False``); callers use ``masks[0]``.
        Presumably shaped (1, H, W) -- confirm against segment_anything.

    Raises:
        FileNotFoundError: If ``img_path`` cannot be read as an image.
    """
    # Convert (left, top, width, height) to corner form (x0, y0, x1, y1).
    box = np.array([
        rect['left'],
        rect['top'],
        rect['left'] + rect['width'],
        rect['top'] + rect['height'],
    ])

    image_bgr = cv2.imread(img_path)
    if image_bgr is None:
        # cv2.imread reports an unreadable/missing file by returning None.
        raise FileNotFoundError(f"Could not read image: {img_path}")
    image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)

    mask_predictor = SamPredictor(_load_sam_model())
    mask_predictor.set_image(image_rgb)
    masks, _scores, _logits = mask_predictor.predict(
        box=box,
        multimask_output=False,
    )
    return masks
def _save_inverted_masks(masks, save_dir="Assets/saved_masks"):
    """Write every mask in ``masks`` to ``save_dir`` as an inverted 8-bit PNG.

    Assuming binary (bool or 0/1) masks, set pixels become 0 (black) and the
    rest 255 (white). Files are named ``inverted_mask_<i>.png`` by mask index.
    """
    os.makedirs(save_dir, exist_ok=True)
    for i, mask in enumerate(masks):
        inverted_mask = 255 - (mask * 255).astype(np.uint8)
        cv2.imwrite(os.path.join(save_dir, f"inverted_mask_{i}.png"), inverted_mask)
    print(f"Inverted masks saved to directory: {save_dir}")


def _composite_onto_background(image_bgr, background_bgr, mask):
    """Blend the part of ``image_bgr`` selected by ``mask`` onto a background.

    Args:
        image_bgr: Source image as loaded by ``cv2.imread``; only its first
            three channels are used.
        background_bgr: 3-channel BGR background image.
        mask: 2-D mask with the same height/width as ``image_bgr``; non-zero
            pixels mark the foreground.

    Returns:
        A BGRA ``uint8`` image the size of ``background_bgr`` with the
        foreground centred on it (shrunk to fit if it is larger).
    """
    # Binary mask -> 0/255 alpha, lightly blurred to soften the cut-out edge.
    alpha = cv2.GaussianBlur((mask > 0).astype(np.uint8) * 255, (3, 3), 0)
    foreground = np.dstack((image_bgr[:, :, :3], alpha))

    bg_h, bg_w = background_bgr.shape[:2]
    fg_h, fg_w = foreground.shape[:2]
    if fg_w > bg_w or fg_h > bg_h:
        scale = min(bg_w / fg_w, bg_h / fg_h)
        foreground = cv2.resize(
            foreground,
            (int(fg_w * scale), int(fg_h * scale)),
            interpolation=cv2.INTER_AREA,
        )
        fg_h, fg_w = foreground.shape[:2]
    x_off = (bg_w - fg_w) // 2
    y_off = (bg_h - fg_h) // 2

    background_bgra = cv2.cvtColor(background_bgr, cv2.COLOR_BGR2BGRA)
    overlay = np.zeros_like(background_bgra)
    overlay[y_off:y_off + fg_h, x_off:x_off + fg_w] = foreground

    # Square-root the alpha to favour the foreground on soft edges, and take
    # the complement of the *adjusted* weight so the two weights sum to 1
    # (pairing sqrt(a) with 1 - a over-weights edges and can exceed 255).
    fg_weight = np.sqrt(overlay[:, :, 3] / 255.0)[:, :, np.newaxis]
    blended = fg_weight * overlay[:, :, :3] + (1.0 - fg_weight) * background_bgra[:, :, :3]

    composite = background_bgra.copy()
    composite[:, :, :3] = np.clip(blended, 0, 255).astype(np.uint8)
    # Widen before adding: uint8 + uint8 wraps around before np.clip sees it.
    composite[:, :, 3] = np.clip(
        overlay[:, :, 3].astype(np.uint16) + background_bgra[:, :, 3], 0, 255
    ).astype(np.uint8)
    return composite


def _save_composite(overlay_image, filename):
    """Save ``overlay_image`` as ``Assets/output/images/<filename>.png`` and show it."""
    # The name is typed by the user: keep only the final path component so it
    # cannot point outside the output directory (e.g. "../../app").
    safe_name = os.path.basename(filename.strip())
    if not safe_name:
        st.error("Please enter a valid file name.")
        return
    output_dir = "Assets/output/images"
    os.makedirs(output_dir, exist_ok=True)
    output_path = f'{output_dir}/{safe_name}.png'
    # cv2.imwrite reports failure through its return value instead of raising.
    if not cv2.imwrite(output_path, overlay_image):
        st.error(f"Could not save the image to {output_path}.")
        return
    st.image(output_path, caption="Created Image", use_column_width=True, width=300)


def _process_rect(idx, rect, img_path):
    """Extract the product inside one box and drive the template/compose UI for it.

    ``idx`` is the box's position in the list. It gives this box's widgets
    unique keys, because identically-configured widgets rendered once per box
    would otherwise collide (Streamlit rejects duplicate widget IDs).
    """
    with st.spinner('Please wait while the product image is being extracted...'):
        masks = get_masks(rect, img_path)
    # NOTE(review): mask files are named by mask index only, so with several
    # boxes the last box's masks overwrite the earlier ones.
    _save_inverted_masks(masks)

    bg_dir = "Assets/bgImages"
    image_files = []
    if os.path.isdir(bg_dir):
        image_files = sorted(
            f for f in os.listdir(bg_dir) if os.path.isfile(os.path.join(bg_dir, f))
        )

    st.header("Template Selection")
    selected_image = st.selectbox("Select an image file", image_files, key=f"bg_template_{idx}")
    image_pathBG = None
    if selected_image:
        image_pathBG = os.path.join(bg_dir, selected_image)
        with Image.open(image_pathBG) as image:
            st.image(image, width=300, caption=f"Selected image: {selected_image}")

    if st.button("Create Image", key=f"create_image_btn_{idx}"):
        st.session_state.create_image = True
    if not st.session_state.get("create_image"):
        return
    if image_pathBG is None:
        # Empty or missing template folder: there is nothing to composite onto.
        st.warning("Please add or select a background template first.")
        return

    image_bgr = cv2.imread(img_path)
    background_bgr = cv2.imread(image_pathBG)
    if image_bgr is None or background_bgr is None:
        st.error("Could not read the product image or the selected template.")
        return
    # Match the template to the product image size so the mask lines up.
    background_bgr = cv2.resize(background_bgr, (image_bgr.shape[1], image_bgr.shape[0]))
    overlay_image = _composite_onto_background(image_bgr, background_bgr, masks[0])

    filename = st.text_input("Enter a name to save the image:", key=f"save_name_{idx}")
    if filename and st.button("Save Image", key=f"save_image_btn_{idx}"):
        _save_composite(overlay_image, filename)


def run(img_path):
    """Render the box-labelling and product-extraction page for ``img_path``.

    Until the user clicks "Save", an editable box labeller is shown on a
    resized preview. After that, every box is segmented with ``get_masks``.
    Its inverted masks are written to ``Assets/saved_masks``. The user can then
    composite the cut-out onto a template from ``Assets/bgImages`` and save the
    result under ``Assets/output/images``.

    Session state used: ``rects`` (current boxes), ``saved`` (set by the Save
    button) and ``create_image`` (set by a "Create Image" button).
    """
    st.set_option("deprecation.showfileUploaderEncoding", False)
    im = ImageManager(img_path)
    resized_img = im.resizing_img()
    resized_rects = im.get_resized_rects()

    if "rects" not in st.session_state:
        st.session_state.rects = resized_rects

    if not st.session_state.get("saved"):
        # Only display the labeller while the Save button hasn't been clicked.
        st.session_state.rects = st_img_label(
            resized_img, box_color="red", rects=st.session_state.rects
        )
    else:
        st.image(resized_img, caption="Uploaded Image", width=300, use_column_width=True)
        # NOTE(review): the boxes were drawn on the resized preview but are
        # applied to the full-size file at img_path; if ImageManager shrank
        # the image they may need scaling back first -- confirm.
        for idx, rect in enumerate(st.session_state.rects):
            _process_rect(idx, rect, img_path)

    def annotate():
        st.session_state.saved = True

    if st.session_state.rects:
        st.button(label="Save", on_click=annotate)
# Example of calling the function
# run("path/to/your/image.jpg")