# CLIPSeg / app.py
# Last commit: fa8c892 by sigyllly, "Update app.py" (verified)
from flask import Flask, request, jsonify, render_template
from PIL import Image
import base64
from io import BytesIO
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import torch
import numpy as np
import matplotlib.pyplot as plt
import cv2
app = Flask(__name__)

# The processor (text/image preprocessing) and the segmentation model are
# loaded once at import time from the same named checkpoint and shared by
# every request handled by this module.
_CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"
processor = CLIPSegProcessor.from_pretrained(_CLIPSEG_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(_CLIPSEG_CHECKPOINT)
def process_image(image, prompt, threshold, alpha_value, draw_rectangles):
    """Segment the region of *image* described by *prompt* with CLIPSeg.

    Args:
        image: PIL image to segment. The predicted mask is resized to
            ``image.size`` and the image is cut out onto an RGBA canvas.
        prompt: Text prompt passed to the CLIPSeg processor.
        threshold: Cut-off applied to the min-max normalised mask (0..1);
            pixels strictly above it are kept.
        alpha_value: Opacity of the heat-map overlay in the returned figure.
        draw_rectangles: If true, draw a yellow bounding box around every
            contour of the thresholded mask on the figure.

    Returns:
        ``(fig, result_mask, result_output)``: the matplotlib figure showing
        the overlay, the binary mask as a base64-encoded PNG string, and the
        input image cut out by that mask as a base64-encoded RGBA PNG string.
        The figure is created through pyplot; callers that do not display it
        should ``plt.close(fig)`` to release it.

    Raises:
        ValueError: if the model output cannot be reduced to one 2-D mask.
    """
    mask = _predict_mask(image, prompt)

    bmask = mask > threshold
    # Use the same predicate for the overlay and the binary mask so pixels
    # exactly equal to the threshold are treated consistently.
    mask = np.where(bmask, mask, 0.0)

    fig = _render_overlay(image, mask, bmask, alpha_value, draw_rectangles)

    bmask_image = Image.fromarray(bmask.astype(np.uint8) * 255, "L")
    output_image = Image.new("RGBA", image.size, (0, 0, 0, 0))
    output_image.paste(image, mask=bmask_image)

    return fig, _encode_png(bmask_image), _encode_png(output_image)


def _predict_mask(image, prompt):
    """Run CLIPSeg and return a float mask of ``image.size`` scaled to [0, 1]."""
    inputs = processor(
        text=prompt, images=image, padding="max_length", return_tensors="pt"
    )
    with torch.no_grad():
        logits = model(**inputs).logits
    probs = torch.sigmoid(logits).cpu().numpy()

    # The logits may arrive as (1, H, W) or already squeezed to (H, W)
    # depending on the installed transformers version; squeezing only axis 0
    # raises in the latter case, so drop every singleton axis instead.
    probs = np.squeeze(probs)
    if probs.ndim != 2:
        raise ValueError(f"expected a single 2-D mask, got shape {probs.shape}")

    # Upscale through an 8-bit grayscale image to the input resolution.
    low_res = Image.fromarray(np.uint8(probs * 255), "L")
    resized = np.asarray(low_res.resize(image.size), dtype=np.float64)
    return _normalize_mask(resized)


def _normalize_mask(mask):
    """Min-max scale *mask* to [0, 1]; a constant mask maps to all zeros."""
    lowest = mask.min()
    span = mask.max() - lowest
    if span == 0:
        # Avoid 0/0 -> NaN for a flat prediction.
        return np.zeros(mask.shape, dtype=np.float64)
    return (mask - lowest) / span


def _render_overlay(image, mask, bmask, alpha_value, draw_rectangles):
    """Draw *image* with *mask* as a jet heat map and optional bounding boxes."""
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.imshow(mask, alpha=alpha_value, cmap="jet")
    if draw_rectangles:
        contours, _ = cv2.findContours(
            bmask.astype(np.uint8), cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
        )
        for contour in contours:
            x, y, w, h = cv2.boundingRect(contour)
            ax.add_patch(
                plt.Rectangle(
                    (x, y), w, h, fill=False, edgecolor="yellow", linewidth=2
                )
            )
    ax.axis("off")
    # Lay out this figure explicitly instead of pyplot's global "current"
    # figure, which another request may have replaced in a threaded server.
    fig.tight_layout()
    return fig


def _encode_png(img):
    """Serialise a PIL image as PNG and return it base64-encoded as str."""
    buffer = BytesIO()
    img.save(buffer, format="PNG")
    return base64.b64encode(buffer.getvalue()).decode("utf-8")
# HTTP routes. The API endpoint below calls process_image(), defined above.
@app.route('/')
def index():
    """Render the ``index.html`` template as the landing page."""
    page = 'index.html'
    return render_template(page)
@app.route('/api/mask_image', methods=['POST'])
def mask_image_api():
    """Segment a base64-encoded image with CLIPSeg and return the results.

    JSON body fields:
        base64_image: the image, as a data URL (``data:image/png;base64,...``)
            or as bare base64.
        prompt: text prompt (default ``''``).
        threshold: mask cut-off passed to process_image (default ``0.4``).
        alpha_value: overlay opacity for the figure (default ``0.5``); the
            figure itself is not returned by this endpoint.
        draw_rectangles: draw bounding boxes on the figure (default False).

    Returns:
        200 with ``{'result_mask': ..., 'result_output': ...}`` (both base64
        PNG strings), or 400 with ``{'error': ...}`` when the request body
        or the image cannot be decoded.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'request body must be a JSON object'}), 400

    base64_image = data.get('base64_image', '')
    if not isinstance(base64_image, str) or not base64_image:
        return jsonify({'error': 'base64_image must be a non-empty string'}), 400

    prompt = data.get('prompt', '')
    try:
        threshold = float(data.get('threshold', 0.4))
        alpha_value = float(data.get('alpha_value', 0.5))
    except (TypeError, ValueError):
        return jsonify({'error': 'threshold and alpha_value must be numbers'}), 400
    draw_rectangles = data.get('draw_rectangles', False)

    # Strip an optional "data:...;base64," prefix; bare base64 has no comma
    # and is used as-is (split(',')[1] raised IndexError on it).
    payload = base64_image.split(',', 1)[-1]
    try:
        image_data = base64.b64decode(payload)
        # convert() forces PIL to fully decode the image here, so corrupt
        # data is reported as a 400. It also gives the model consistent
        # 3-channel input whatever the upload mode (RGBA, palette, grayscale).
        image = Image.open(BytesIO(image_data)).convert('RGB')
    except (ValueError, OSError):
        # binascii.Error is a ValueError; PIL's UnidentifiedImageError and
        # truncated-image errors are OSErrors.
        return jsonify({'error': 'base64_image is not a valid base64-encoded image'}), 400

    fig, result_mask, result_output = process_image(
        image, prompt, threshold, alpha_value, draw_rectangles
    )
    # This endpoint does not use the figure; close it so pyplot does not
    # keep one figure alive per request.
    plt.close(fig)
    return jsonify({'result_mask': result_mask, 'result_output': result_output})
if __name__ == '__main__':
    import os

    # The Werkzeug debugger allows arbitrary code execution. With
    # host='0.0.0.0' it would be reachable from the network, so it is off by
    # default. Set FLASK_DEBUG=1 (or true/yes) to opt in for local development.
    debug = os.environ.get('FLASK_DEBUG', '').strip().lower() in ('1', 'true', 'yes')
    app.run(host='0.0.0.0', port=7860, debug=debug)