Spaces:
Sleeping
Sleeping
from flask import Flask, request, send_file, Response, jsonify | |
from flask_cors import CORS | |
import numpy as np | |
import io | |
import torch | |
import cv2 | |
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator | |
from PIL import Image | |
import zipfile | |
app = Flask(__name__)
CORS(app)

# Use the GPU when PyTorch can see one; otherwise run SAM on the CPU.
if torch.cuda.is_available():
    cudaOrNah = "cuda"
else:
    cudaOrNah = "cpu"
print(cudaOrNah)

# The SAM model is built once at import time and shared by every request.
# ViT-L is used here; the larger ViT-H variant ("sam_vit_h_4b8939.pth",
# model type "vit_h") was swapped out after running out of memory.
model_type = "vit_l"
checkpoint = "sam_vit_l_0b3195.pth"
sam = sam_model_registry[model_type](checkpoint=checkpoint)
sam.to(device=cudaOrNah)

# NOTE(review): SamAutomaticMaskGenerator documents min_mask_region_area as an
# int pixel area; any value > 0 turns on small-region postprocessing. A value
# of 0.0015 pixels removes no regions, so a fraction of the image area was
# presumably intended -- confirm before tuning.
mask_generator = SamAutomaticMaskGenerator(model=sam, min_mask_region_area=0.0015)
print('Setup SAM model')
def hello():
    """Return a fixed greeting payload as a plain dict."""
    greeting = "Shredded to peices"
    return dict(hei=greeting)
def health_check():
    """Liveness probe: always answers with ``{"status": "ok"}`` and HTTP 200."""
    payload = {"status": "ok"}
    return jsonify(payload), 200
def get_masks():
    """Segment an uploaded image with SAM and return the masks as a ZIP of PNGs.

    Expects a multipart upload with the image under the ``image`` field.
    Processing steps:
      1. decode the upload to an RGB array and run ``mask_generator``;
      2. drop masks that cover a probe pixel near any image corner
         (treated as background);
      3. make masks disjoint so each pixel belongs to the smallest mask
         containing it;
      4. drop masks covering less than ``_MIN_MASK_FRACTION`` of the image;
      5. return the rest, largest first, as ``mask_1.png``, ``mask_2.png``, ...

    Returns:
        A ``send_file`` response carrying ``masks.zip`` on success, or a JSON
        error body with HTTP status 400 for a missing/unreadable image or any
        other processing error.

    NOTE(review): no ``@app.route`` decorator is visible for this handler in
    this copy of the file; confirm how it is registered.
    """
    try:
        print('received image from frontend')
        if 'image' not in request.files:
            return jsonify({"error": "No image file provided"}), 400
        image_file = request.files['image']
        if image_file.filename == '':
            return jsonify({"error": "No image file provided"}), 400

        image = _decode_upload(image_file)

        if cudaOrNah == "cuda":
            torch.cuda.empty_cache()
        try:
            masks = mask_generator.generate(image)
        finally:
            # Release cached GPU memory even if generation fails (e.g. OOM).
            if cudaOrNah == "cuda":
                torch.cuda.empty_cache()

        masks = sorted(masks, key=lambda m: m['area'], reverse=True)
        masks = [m for m in masks if not _touches_corner(m['segmentation'])]
        _make_disjoint(masks)
        masks = [
            m for m in masks
            if not _covers_less_than(m['segmentation'], _MIN_MASK_FRACTION)
        ]
        # Areas changed in _make_disjoint, so re-sort largest first.
        masks.sort(key=lambda m: m['area'], reverse=True)

        zip_buffer = _zip_masks(masks)
        return send_file(zip_buffer, mimetype='application/zip', as_attachment=True, download_name='masks.zip')
    except Exception as e:
        print(f"Error processing the image: {e}")
        return jsonify({"error": "Error processing the image", "details": str(e)}), 400


# Masks covering less than this fraction of the image are discarded.
_MIN_MASK_FRACTION = 0.0015
# Distance (in pixels, via +/- indexing) of the background probe pixels from
# each image corner.
_BACKGROUND_PROBE_OFFSET = 10


def _decode_upload(image_file):
    """Decode an uploaded image file into an RGB ``uint8`` array.

    Raises:
        ValueError: if the upload is empty or OpenCV cannot decode it.
    """
    raw = image_file.read()
    if not raw:
        raise ValueError("Image not found or unable to read.")
    # np.frombuffer replaces the deprecated binary mode of np.fromstring.
    file_bytes = np.frombuffer(raw, np.uint8)
    image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
    # imdecode returns None for undecodable data; check before cvtColor,
    # which would otherwise fail with an opaque OpenCV assertion.
    if image is None:
        raise ValueError("Image not found or unable to read.")
    # OpenCV decodes to BGR; convert to RGB before handing the image to SAM.
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)


def _touches_corner(segmentation, offset=_BACKGROUND_PROBE_OFFSET):
    """Return True if the mask covers a probe pixel near any image corner.

    Such masks are treated as background. Raises IndexError for images
    smaller than the probe offset.
    """
    return bool(
        segmentation[offset, offset]
        or segmentation[-offset, offset]
        or segmentation[offset, -offset]
        or segmentation[-offset, -offset]
    )


def _make_disjoint(masks):
    """Remove from each mask the pixels of every later mask, in place.

    ``masks`` is expected to be sorted by descending area, so larger masks
    lose the pixels claimed by smaller ones. Walking from the smallest mask
    upward with a running union of the masks seen so far gives the same
    result as subtracting each later mask pairwise, in O(n) array
    operations instead of O(n^2). The last (smallest) mask is left
    untouched, including its ``area``; every other mask gets ``area``
    recomputed from its new segmentation.
    """
    if len(masks) < 2:
        return
    claimed = np.array(masks[-1]['segmentation'], dtype=bool)
    for mask in reversed(masks[:-1]):
        segmentation = mask['segmentation']
        mask['segmentation'] = np.logical_and(segmentation, np.logical_not(claimed))
        mask['area'] = mask['segmentation'].sum()
        np.logical_or(claimed, segmentation, out=claimed)


def _covers_less_than(segmentation, fraction):
    """Return True if the mask covers less than ``fraction`` of its image."""
    return segmentation.sum() / segmentation.size < fraction


def _zip_masks(masks):
    """Encode each mask as a grayscale PNG (0/255) inside an in-memory ZIP.

    Returns a ``BytesIO`` rewound to the start, with entries named
    ``mask_1.png``, ``mask_2.png``, ... in the order of ``masks``.
    """
    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, 'w', zipfile.ZIP_DEFLATED) as zip_file:
        for idx, mask in enumerate(masks, start=1):
            alpha = mask['segmentation'].astype('uint8') * 255
            png_buffer = io.BytesIO()
            Image.fromarray(alpha).save(png_buffer, format="PNG")
            zip_file.writestr(f'mask_{idx}.png', png_buffer.getvalue())
    zip_buffer.seek(0)
    return zip_buffer
if __name__ == '__main__':
    # debug=True normally enables Werkzeug's reloader, which re-executes this
    # module in a child process. That would load the SAM checkpoint a second
    # time while the parent still holds the first copy. Keep debug mode but
    # disable the reloader so the model is loaded only once.
    app.run(debug=True, use_reloader=False)