Spaces: IbrahimaThioye (Running)
Commit 0578219: Create app.py
Parent(s): 4bb79b0
app.py
ADDED
@@ -0,0 +1,604 @@
+ from flask import Flask, request, jsonify, render_template, url_for
+ from flask_socketio import SocketIO
+ import threading
+ from ultralytics import YOLO
+ import numpy as np
+ import cv2
+ import matplotlib.pyplot as plt
+ import importlib
+ from segment_anything import sam_model_registry, SamPredictor
+ import os
+ from werkzeug.utils import secure_filename
+ import logging
+ import json
+ import shutil
+ import sys
+ from sam2.build_sam import build_sam2
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
+ app = Flask(__name__)
+ socketio = SocketIO(app)
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO)
+ logger = logging.getLogger(__name__)
+
+ # Configuration
+ class Config:
+     BASE_DIR = os.path.abspath(os.path.dirname(__file__))
+     UPLOAD_FOLDER = os.path.join(BASE_DIR, 'static', 'uploads')
+     SAM_RESULT_FOLDER = os.path.join(BASE_DIR, 'static', 'sam', 'sam_results')
+     YOLO_RESULT_FOLDER = os.path.join(BASE_DIR, 'static', 'yolo', 'yolo_results')
+     YOLO_TRAIN_IMAGE_FOLDER = os.path.join(BASE_DIR, 'static', 'yolo', 'dataset_yolo', 'train', 'images')
+     YOLO_TRAIN_LABEL_FOLDER = os.path.join(BASE_DIR, 'static', 'yolo', 'dataset_yolo', 'train', 'labels')
+     AREA_DATA_FOLDER = os.path.join(BASE_DIR, 'static', 'yolo', 'area_data')
+     ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
+     MAX_CONTENT_LENGTH = 16 * 1024 * 1024  # 16 MB max file size
+     SAM_CHECKPOINT = os.path.join(BASE_DIR, 'static', 'sam', "sam_vit_h_4b8939.pth")
+     SAM_2 = os.path.join(BASE_DIR, 'static', 'sam', "sam2.1_hiera_large.pt")
+     YOLO_PATH = os.path.join(BASE_DIR, 'static', 'yolo', "model_yolo.pt")
+     RETRAINED_MODEL_PATH = os.path.join(BASE_DIR, 'static', 'yolo', "model_retrained.pt")
+     DATA_PATH = os.path.join(BASE_DIR, 'static', 'yolo', 'dataset_yolo', "data.yaml")
+
+ app.config.from_object(Config)
+
+ # Ensure directories exist
+ os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
+ os.makedirs(app.config['SAM_RESULT_FOLDER'], exist_ok=True)
+ os.makedirs(app.config['YOLO_RESULT_FOLDER'], exist_ok=True)
+ os.makedirs(app.config['YOLO_TRAIN_IMAGE_FOLDER'], exist_ok=True)
+ os.makedirs(app.config['YOLO_TRAIN_LABEL_FOLDER'], exist_ok=True)
+ os.makedirs(app.config['AREA_DATA_FOLDER'], exist_ok=True)
+
+
+ # Initialize YOLO model
+ try:
+     model = YOLO(app.config['YOLO_PATH'])
+ except Exception as e:
+     logger.error(f"Failed to initialize YOLO model: {str(e)}")
+     raise
+
+ # Initialize SAM 2 predictor
+ try:
+     sam2_checkpoint = app.config['SAM_2']
+     model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
+
+     sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cpu")
+     predictor = SAM2ImagePredictor(sam2_model)
+ except Exception as e:
+     logger.error(f"Failed to initialize SAM 2 model: {str(e)}")
+     raise
+
+ def allowed_file(filename):
+     return '.' in filename and filename.rsplit('.', 1)[1].lower() in app.config['ALLOWED_EXTENSIONS']
+
+ def scale_coordinates(coords, original_dims, target_dims):
+     """
+     Scale coordinates from one dimension space to another.
+
+     Args:
+         coords: List of [x, y] coordinates
+         original_dims: Tuple of (width, height) of original space
+         target_dims: Tuple of (width, height) of target space
+
+     Returns:
+         Scaled coordinates
+     """
+     scale_x = target_dims[0] / original_dims[0]
+     scale_y = target_dims[1] / original_dims[1]
+
+     return [
+         [int(coord[0] * scale_x), int(coord[1] * scale_y)]
+         for coord in coords
+     ]
+
+ def scale_box(box, original_dims, target_dims):
+     """
+     Scale bounding box coordinates from one dimension space to another.
+
+     Args:
+         box: List of [x1, y1, x2, y2] coordinates
+         original_dims: Tuple of (width, height) of original space
+         target_dims: Tuple of (width, height) of target space
+
+     Returns:
+         Scaled box coordinates
+     """
+     scale_x = target_dims[0] / original_dims[0]
+     scale_y = target_dims[1] / original_dims[1]
+
+     return [
+         int(box[0] * scale_x),  # x1
+         int(box[1] * scale_y),  # y1
+         int(box[2] * scale_x),  # x2
+         int(box[3] * scale_y)   # y2
+     ]
+
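+ # Illustrative use of the scaling helpers above (they appear not to be called elsewhere
+ # in this file); the canvas and image sizes here are hypothetical:
+ #   scale_box([100, 50, 200, 150], (640, 480), (1920, 1080))   # -> [300, 112, 600, 337]
+ #   scale_coordinates([[100, 50]], (640, 480), (1920, 1080))   # -> [[300, 112]]
+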
+ def retrain_model_fn():
+     # Parameters for retraining
+     data_path = app.config['DATA_PATH']
+     epochs = 5
+     img_size = 640
+     batch_size = 8
+
+     # Start training with YOLO, emitting a progress event after each epoch
+     for epoch in range(epochs):
+         # Train the model for a single epoch per call so progress can be reported
+         model.train(
+             data=data_path,
+             epochs=1,  # Use 1 epoch per call to get individual progress
+             imgsz=img_size,
+             batch=batch_size,
+             device="cpu"  # Adjust based on system capabilities
+         )
+
+         # Emit an update to the client after each epoch
+         socketio.emit('training_update', {
+             'epoch': epoch + 1,
+             'status': f"Epoch {epoch + 1} complete"
+         })
+
+     # Emit a message once training is complete
+     socketio.emit('training_complete', {'status': "Retraining complete"})
+     model.save(app.config['YOLO_PATH'])
+     logger.info("Model retrained successfully")
+
+ @app.route('/')
+ def index():
+     return render_template('index.html')
+
+ @app.route('/yolo')
+ def yolo():
+     return render_template('yolo.html')
+
+ @app.route('/upload_sam', methods=['POST'])
+ def upload_sam_file():
+     """
+     Handles SAM image upload and embeds the image into the predictor instance.
+
+     Returns:
+         JSON response with 'message', 'image_url', 'filename', and 'dimensions' keys
+         on success, or 'error' key with an appropriate error message on failure.
+     """
+
+     try:
+         if 'file' not in request.files:
+             return jsonify({'error': 'No file part'}), 400
+
+         file = request.files['file']
+         if file.filename == '':
+             return jsonify({'error': 'No selected file'}), 400
+
+         if not allowed_file(file.filename):
+             return jsonify({'error': 'Invalid file type. Allowed types: PNG, JPG, JPEG'}), 400
+
+         filename = secure_filename(file.filename)
+         filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
+         file.save(filepath)
+
+         # Set the image for predictor right after upload
+         image = cv2.imread(filepath)
+         if image is None:
+             return jsonify({'error': 'Failed to load uploaded image'}), 500
+
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         predictor.set_image(image)
+         logger.info("Image embedded successfully")
+
+         # Get image dimensions
+         height, width = image.shape[:2]
+
+         image_url = url_for('static', filename=f'uploads/{filename}')
+         logger.info(f"File uploaded successfully: {filepath}")
+
+         return jsonify({
+             'message': 'File uploaded successfully',
+             'image_url': image_url,
+             'filename': filename,
+             'dimensions': {
+                 'width': width,
+                 'height': height
+             }
+         })
+
+     except Exception as e:
+         logger.error(f"Upload error: {str(e)}")
+         return jsonify({'error': 'Server error during upload'}), 500
+
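+ # Example client call for /upload_sam (a sketch; "board.jpg" and the port are assumptions
+ # based on the dev server at the bottom of this file):
+ #   import requests
+ #   with open("board.jpg", "rb") as f:
+ #       r = requests.post("http://localhost:5001/upload_sam", files={"file": f})
+ #   r.json()  # {'message': ..., 'image_url': ..., 'filename': ..., 'dimensions': {...}}
+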
+ @app.route('/upload_yolo', methods=['POST'])
+ def upload_yolo_file():
+     """
+     Upload an image file for YOLO classification.
+
+     This endpoint accepts a POST request containing a single image file. The file is
+     saved to the uploads folder so it can later be classified by the YOLO model.
+
+     Returns a JSON response with the following keys:
+     - message: a success message
+     - image_url: the URL of the uploaded image
+     - filename: the name of the uploaded file
+
+     If an error occurs, the JSON response will contain an 'error' key with a
+     descriptive error message.
+     """
+     try:
+         if 'file' not in request.files:
+             return jsonify({'error': 'No file part'}), 400
+
+         file = request.files['file']
+         if file.filename == '':
+             return jsonify({'error': 'No selected file'}), 400
+
+         if not allowed_file(file.filename):
+             return jsonify({'error': 'Invalid file type. Allowed types: PNG, JPG, JPEG'}), 400
+
+         filename = secure_filename(file.filename)
+         filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
+         file.save(filepath)
+
+         image_url = url_for('static', filename=f'uploads/{filename}')
+         logger.info(f"File uploaded successfully: {filepath}")
+
+         return jsonify({
+             'message': 'File uploaded successfully',
+             'image_url': image_url,
+             'filename': filename,
+         })
+
+     except Exception as e:
+         logger.error(f"Upload error: {str(e)}")
+         return jsonify({'error': 'Server error during upload'}), 500
+
+ @app.route('/generate_mask', methods=['POST'])
+ def generate_mask():
+     """
+     Generate segmentation masks for a given image using the SAM 2 predictor.
+     @param data: a JSON object containing the following keys:
+         - filename: the name of the image file
+         - void_points: a list of normalized 2D points (x, y) marking the voids
+         - component_boxes: a list of normalized bounding boxes (x1, y1, x2, y2) around the components
+     @return: a JSON object containing the following keys:
+         - status: a string indicating the status of the request
+         - train_image_url: the URL of the saved train image
+         - result_path: the URL of the saved result image
+     """
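+     # Example request body (a sketch; the filename and coordinates are hypothetical,
+     # with all coordinates normalized to [0, 1]):
+     #   {"filename": "board.jpg",
+     #    "void_points": [[0.42, 0.37]],
+     #    "component_boxes": [[0.10, 0.20, 0.55, 0.80]]}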
+     try:
+         data = request.json
+         normalized_void_points = data.get('void_points', [])
+         normalized_component_boxes = data.get('component_boxes', [])
+         filename = data.get('filename', '')
+
+         if not filename:
+             return jsonify({'error': 'No filename provided'}), 400
+
+         image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
+         if not os.path.exists(image_path):
+             return jsonify({'error': 'Image file not found'}), 404
+
+         # Read image
+         image = cv2.imread(image_path)
+         if image is None:
+             return jsonify({'error': 'Failed to load image'}), 500
+
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         image_height, image_width = image.shape[:2]
+
+         # Denormalize coordinates back to pixel values
+         void_points = [
+             [int(point[0] * image_width), int(point[1] * image_height)]
+             for point in normalized_void_points
+         ]
+         logger.info(f"Void points: {void_points}")
+
+         component_boxes = [
+             [
+                 int(box[0] * image_width),
+                 int(box[1] * image_height),
+                 int(box[2] * image_width),
+                 int(box[3] * image_height)
+             ]
+             for box in normalized_component_boxes
+         ]
+         logger.info(f"Component boxes: {component_boxes}")
+
+         # Create a list to store individual void masks
+         void_masks = []
+
+         # Process void points one by one
+         for point in void_points:
+             # Convert point to correct format: [N, 2] array
+             point_coord = np.array([[point[0], point[1]]])
+             point_label = np.array([1])  # Single label
+
+             masks, scores, _ = predictor.predict(
+                 point_coords=point_coord,
+                 point_labels=point_label,
+                 multimask_output=True  # Get multiple masks
+             )
+
+             if len(masks) > 0:  # Check if any masks were generated
+                 # Get the mask with highest score
+                 best_mask_idx = np.argmax(scores)
+                 void_masks.append(masks[best_mask_idx])
+                 logger.info(f"Processed void point {point} with score {scores[best_mask_idx]}")
+
+         # Process component boxes
+         component_masks = []
+         if component_boxes:
+             for box in component_boxes:
+                 # Convert box to correct format: [2, 2] array
+                 box_np = np.array([[box[0], box[1]], [box[2], box[3]]])
+                 masks, scores, _ = predictor.predict(
+                     box=box_np,
+                     multimask_output=True
+                 )
+                 if len(masks) > 0:
+                     best_mask_idx = np.argmax(scores)
+                     component_masks.append(masks[best_mask_idx])
+                     logger.info(f"Processed component box {box}")
+
+         # Create visualization overlaying the masks on the image
+         combined_image = image.copy()
+
+         # Font settings for labels
+         font = cv2.FONT_HERSHEY_SIMPLEX
+         font_scale = 0.6
+         font_color = (0, 0, 0)  # Black text color
+         font_thickness = 1
+         background_color = (255, 255, 255)  # White background for text
+
+         # Helper function to get bounding box coordinates of a binary mask
+         def get_bounding_box(mask):
+             # np.where returns (rows, cols), i.e. (y, x)
+             ys, xs = np.where(mask)
+             return (xs.min(), ys.min(), xs.max(), ys.max())
+
+         # Helper function to add text with background
+         def put_text_with_background(img, text, pos):
+             # Calculate text size
+             (text_w, text_h), _ = cv2.getTextSize(text, font, font_scale, font_thickness)
+             # Define the rectangle coordinates for background
+             background_tl = (pos[0], pos[1] - text_h - 2)
+             background_br = (pos[0] + text_w, pos[1] + 2)
+             # Draw white rectangle as background
+             cv2.rectangle(img, background_tl, background_br, background_color, -1)
+             # Put the text over the background rectangle
+             cv2.putText(img, text, pos, font, font_scale, font_color, font_thickness, cv2.LINE_AA)
+
+         def get_safe_label_position(x_min, y_min, x_max, y_max, text_w, text_h, img_width, img_height):
+             # Default to top-right of bounding box, clamped inside the image
+             x_pos = min(x_max, img_width - text_w - 10)  # Keep 10px margin from the right
+             y_pos = max(y_min + text_h + 5, text_h + 5)  # Keep 5px margin from the top
+             return int(x_pos), int(y_pos)
+
+
+         # Apply void masks in red
+         for mask in void_masks:
+             mask = mask.astype(bool)
+             combined_image[mask, 0] = np.clip(0.5 * image[mask, 0] + 0.5 * 255, 0, 255)  # Red channel with transparency
+             combined_image[mask, 1] = np.clip(0.5 * image[mask, 1], 0, 255)  # Green channel reduced
+             combined_image[mask, 2] = np.clip(0.5 * image[mask, 2], 0, 255)
+             logger.info("Mask Drawn")
+
+         # Apply component masks in green
+         for mask in component_masks:
+             mask = mask.astype(bool)
+             # Only apply green where there is no red overlay
+             non_red_area = mask & ~np.any([void_mask for void_mask in void_masks], axis=0)
+             combined_image[non_red_area, 0] = np.clip(0.5 * image[non_red_area, 0], 0, 255)  # Reduced red channel
+             combined_image[non_red_area, 1] = np.clip(0.5 * image[non_red_area, 1] + 0.5 * 255, 0, 255)  # Green channel
+             combined_image[non_red_area, 2] = np.clip(0.5 * image[non_red_area, 2], 0, 255)
+             logger.info("Mask Drawn")
+
+         # Add labels on top of masks
+         for i, mask in enumerate(void_masks):
+             x_min, y_min, x_max, y_max = get_bounding_box(mask)
+             (text_w, text_h), _ = cv2.getTextSize("Void", font, font_scale, font_thickness)
+             label_position = get_safe_label_position(x_min, y_min, x_max, y_max, text_w, text_h, combined_image.shape[1], combined_image.shape[0])
+             put_text_with_background(combined_image, f"Void {i+1}", label_position)
+
+         for i, mask in enumerate(component_masks):
+             i = i + 1
+             x_min, y_min, x_max, y_max = get_bounding_box(mask)
+             (text_w, text_h), _ = cv2.getTextSize("Component", font, font_scale, font_thickness)
+             label_position = get_safe_label_position(x_min, y_min, x_max, y_max, text_w, text_h, combined_image.shape[1], combined_image.shape[0])
+             put_text_with_background(combined_image, f"Component {i}", label_position)
+
+         # Prepare an empty list to store the output in the required format
+         mask_coordinates = []
+
+         for mask in void_masks:
+             # Get contours from the mask
+             contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+             # Image dimensions
+             height, width = mask.shape
+
+             # For each contour, extract the normalized coordinates
+             for contour in contours:
+                 contour_points = contour.reshape(-1, 2)  # Flatten to (N, 2) where N is the number of points
+                 normalized_points = contour_points / [width, height]  # Normalize to (0, 1)
+
+                 class_id = 1  # 1 for voids
+                 row = [class_id] + normalized_points.flatten().tolist()  # Flatten and add the class
+                 mask_coordinates.append(row)
+
+         for mask in component_masks:
+             # Get contours from the mask
+             contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+             # Filter to keep only the largest contour
+             contours = sorted(contours, key=cv2.contourArea, reverse=True)
+             largest_contour = [contours[0]] if contours else []
+             # Image dimensions
+             height, width = mask.shape
+
+             # For each contour, extract the normalized coordinates
+             for contour in largest_contour:
+                 contour_points = contour.reshape(-1, 2)  # Flatten to (N, 2) where N is the number of points
+                 normalized_points = contour_points / [width, height]  # Normalize to (0, 1)
+
+                 class_id = 0  # 0 for components
+                 row = [class_id] + normalized_points.flatten().tolist()  # Flatten and add the class
+                 mask_coordinates.append(row)
+
+         # Label file shares the image's base name so YOLO can pair image and label
+         mask_coordinates_filename = f'{os.path.splitext(filename)[0]}.txt'
+         mask_coordinates_path = os.path.join(app.config['YOLO_TRAIN_LABEL_FOLDER'], mask_coordinates_filename)
+
+         with open(mask_coordinates_path, "w") as file:
+             for row in mask_coordinates:
+                 # Join elements of the row into a string with spaces in between and write to the file
+                 file.write(" ".join(map(str, row)) + "\n")
+
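+         # Each row written above follows the YOLO segmentation label convention:
+         #   <class_id> x1 y1 x2 y2 ... xn yn  (polygon coordinates normalized to [0, 1]),
+         #   e.g. "1 0.12 0.30 0.15 0.33 ..." for a void; the numbers here are illustrative.
+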
+         # Save train image
+         train_image_filepath = os.path.join(app.config['YOLO_TRAIN_IMAGE_FOLDER'], filename)
+         shutil.copy(image_path, train_image_filepath)
+         train_image_url = url_for('static', filename=f'yolo/dataset_yolo/train/images/{filename}')
+
+         # Save result
+         result_filename = f'segmented_{filename}'
+         result_path = os.path.join(app.config['SAM_RESULT_FOLDER'], result_filename)
+         plt.imsave(result_path, combined_image)
+         logger.info("Mask generation completed successfully")
+
+         return jsonify({
+             'status': 'success',
+             'train_image_url': train_image_url,
+             'result_path': url_for('static', filename=f'sam/sam_results/{result_filename}')
+         })
+
+     except Exception as e:
+         logger.error(f"Mask generation error: {str(e)}")
+         return jsonify({'error': str(e)}), 500
+
+ @app.route('/classify', methods=['POST'])
+ def classify():
+     """
+     Classify an image and return the classification result, area data, and the annotated image.
+
+     The request body should contain a JSON object with a single key 'filename' specifying the image file to be classified.
+
+     Returns a JSON object with the following keys:
+
+     - status: 'success' if the classification is successful, 'error' if there is an error.
+     - result_path: URL of the annotated image.
+     - area_data: a list of dictionaries containing the area and overlap statistics for each component.
+     - area_data_path: URL of the JSON file containing the area data.
+
+     If there is an error, returns a JSON object with a single key 'error' containing the error message.
+     """
+
+     try:
+         data = request.json
+         filename = data.get('filename', '')
+         if not filename:
+             return jsonify({'error': 'No filename provided'}), 400
+
+         image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
+         if not os.path.exists(image_path):
+             return jsonify({'error': 'Image file not found'}), 404
+
+         # Read image
+         image = cv2.imread(image_path)
+         if image is None:
+             return jsonify({'error': 'Failed to load image'}), 500
+
+         results = model(image)
+         result = results[0]
+
+         component_masks = []
+         void_masks = []
+
+         # Extract masks and labels from results (result.masks is None when nothing is detected)
+         if result.masks is not None:
+             for mask, label in zip(result.masks.data, result.boxes.cls):
+                 mask_array = mask.cpu().numpy().astype(bool)  # Convert to a binary mask (boolean array)
+                 if label == 1:  # Assuming label '1' represents void
+                     void_masks.append(mask_array)
+                 elif label == 0:  # Assuming label '0' represents component
+                     component_masks.append(mask_array)
+
+         # Calculate area and overlap statistics
+         area_data = []
+         for i, component_mask in enumerate(component_masks):
+             component_area = np.sum(component_mask).item()  # Total component area in pixels
+             void_area_within_component = 0
+             max_void_area_percentage = 0
+
+             # Calculate overlap of each void mask with the component mask
+             for void_mask in void_masks:
+                 overlap_area = np.sum(void_mask & component_mask).item()  # Overlapping area
+                 void_area_within_component += overlap_area
+                 void_area_percentage = (overlap_area / component_area) * 100 if component_area > 0 else 0
+                 max_void_area_percentage = max(max_void_area_percentage, void_area_percentage)
+
+             # Append data for this component
+             area_data.append({
+                 "Image": filename,
+                 'Component': f'Component {i+1}',
+                 'Area': component_area,
+                 'Void Area (pixels)': void_area_within_component,
+                 'Void Area %': void_area_within_component / component_area * 100 if component_area > 0 else 0,
+                 'Max Void Area %': max_void_area_percentage
+             })
+
+         area_data_filename = f'area_data_{filename.split("/")[-1]}.json'  # Create a unique filename
+         area_data_path = os.path.join(app.config['AREA_DATA_FOLDER'], area_data_filename)
+
+         with open(area_data_path, 'w') as json_file:
+             json.dump(area_data, json_file, indent=4)
+
+         # result.plot() returns a BGR image; convert to RGB so matplotlib saves correct colors
+         annotated_image = cv2.cvtColor(result.plot(), cv2.COLOR_BGR2RGB)
+
+         output_filename = f'output_{filename}'
+         output_image_path = os.path.join(app.config['YOLO_RESULT_FOLDER'], output_filename)
+         plt.imsave(output_image_path, annotated_image)
+         logger.info("Classification completed successfully")
+
+         return jsonify({
+             'status': 'success',
+             'result_path': url_for('static', filename=f'yolo/yolo_results/{output_filename}'),
+             'area_data': area_data,
+             'area_data_path': url_for('static', filename=f'yolo/area_data/{area_data_filename}')
+         })
+     except Exception as e:
+         logger.error(f"Classification error: {str(e)}")
+         return jsonify({'error': str(e)}), 500
+
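+ # Illustrative shape of a single area_data entry returned by /classify (values hypothetical):
+ #   {"Image": "board.jpg", "Component": "Component 1", "Area": 15230,
+ #    "Void Area (pixels)": 512, "Void Area %": 3.36, "Max Void Area %": 2.1}
+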
+ retraining_status = {
+     'status': 'idle',
+     'progress': None,
+     'message': None
+ }
+
+ @app.route('/start_retraining', methods=['GET', 'POST'])
+ def start_retraining():
+     """
+     Start the model retraining process.
+
+     If the request is a POST, start the model retraining process in a separate thread.
+     If the request is a GET, render the retraining page.
+
+     Returns:
+         A JSON response with the status of the retraining process, or a rendered HTML page.
+     """
+     if request.method == 'POST':
+         # Reset status
+         global retraining_status
+         retraining_status['status'] = 'in_progress'
+         retraining_status['progress'] = 'Initializing'
+
+         # Start retraining in a separate thread
+         threading.Thread(target=retrain_model_fn).start()
+         return jsonify({'status': 'started'})
+     else:
+         # GET request - render the retraining page
+         return render_template('retrain.html')
+
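+ # Example trigger for retraining (a sketch, assuming the dev server below on port 5001):
+ #   curl -X POST http://localhost:5001/start_retraining
+ # responds with {"status": "started"} while training continues in the background thread.
+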
+ # Event handler for client connection
+ @socketio.on('connect')
+ def handle_connect():
+     print('Client connected')
+
+
+ if __name__ == '__main__':
+     # Use socketio.run so the Socket.IO server wraps the Flask app
+     socketio.run(app, port=5001, debug=True)
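+ # To run locally (sketch): python app.py, then open http://localhost:5001/ in a browser.
+ # The retrain.html template is expected to include a Socket.IO client that listens for the
+ # 'training_update' and 'training_complete' events emitted during retraining.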