IbrahimaThioye committed on
Commit 0578219
1 Parent(s): 4bb79b0

Create app.py

Files changed (1): app.py (+604, -0)
app.py ADDED
from flask import Flask, request, jsonify, render_template, url_for
from flask_socketio import SocketIO
import threading
from ultralytics import YOLO
import numpy as np
import cv2
import matplotlib.pyplot as plt
import os
from werkzeug.utils import secure_filename
import logging
import json
import shutil
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

app = Flask(__name__)
socketio = SocketIO(app)

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Configuration
class Config:
    BASE_DIR = os.path.abspath(os.path.dirname(__file__))
    UPLOAD_FOLDER = os.path.join(BASE_DIR, 'static', 'uploads')
    SAM_RESULT_FOLDER = os.path.join(BASE_DIR, 'static', 'sam', 'sam_results')
    YOLO_RESULT_FOLDER = os.path.join(BASE_DIR, 'static', 'yolo', 'yolo_results')
    YOLO_TRAIN_IMAGE_FOLDER = os.path.join(BASE_DIR, 'static', 'yolo', 'dataset_yolo', 'train', 'images')
    YOLO_TRAIN_LABEL_FOLDER = os.path.join(BASE_DIR, 'static', 'yolo', 'dataset_yolo', 'train', 'labels')
    AREA_DATA_FOLDER = os.path.join(BASE_DIR, 'static', 'yolo', 'area_data')
    ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}
    MAX_CONTENT_LENGTH = 16 * 1024 * 1024  # 16MB max upload size
    SAM_CHECKPOINT = os.path.join(BASE_DIR, 'static', 'sam', 'sam_vit_h_4b8939.pth')
    SAM_2 = os.path.join(BASE_DIR, 'static', 'sam', 'sam2.1_hiera_large.pt')
    YOLO_PATH = os.path.join(BASE_DIR, 'static', 'yolo', 'model_yolo.pt')
    RETRAINED_MODEL_PATH = os.path.join(BASE_DIR, 'static', 'yolo', 'model_retrained.pt')
    DATA_PATH = os.path.join(BASE_DIR, 'static', 'yolo', 'dataset_yolo', 'data.yaml')

app.config.from_object(Config)

# Ensure directories exist
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(app.config['SAM_RESULT_FOLDER'], exist_ok=True)
os.makedirs(app.config['YOLO_RESULT_FOLDER'], exist_ok=True)
os.makedirs(app.config['YOLO_TRAIN_IMAGE_FOLDER'], exist_ok=True)
os.makedirs(app.config['YOLO_TRAIN_LABEL_FOLDER'], exist_ok=True)
os.makedirs(app.config['AREA_DATA_FOLDER'], exist_ok=True)

# Initialize the YOLO model
try:
    model = YOLO(app.config['YOLO_PATH'])
except Exception as e:
    logger.error(f"Failed to initialize YOLO model: {str(e)}")
    raise

# Initialize the SAM2 predictor
try:
    sam2_checkpoint = app.config['SAM_2']
    model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"

    sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cpu")
    predictor = SAM2ImagePredictor(sam2_model)
except Exception as e:
    logger.error(f"Failed to initialize SAM2 model: {str(e)}")
    raise

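# Note: both models above are loaded on CPU. If a CUDA GPU is available, a
# common pattern (assuming torch ships alongside the ultralytics/sam2 install) is:
#     import torch
#     device = "cuda" if torch.cuda.is_available() else "cpu"
# and passing that device to build_sam2 (and later to model.train).
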
def allowed_file(filename):
    return '.' in filename and filename.rsplit('.', 1)[1].lower() in app.config['ALLOWED_EXTENSIONS']

def scale_coordinates(coords, original_dims, target_dims):
    """
    Scale coordinates from one dimension space to another.

    Args:
        coords: List of [x, y] coordinates
        original_dims: Tuple of (width, height) of the original space
        target_dims: Tuple of (width, height) of the target space

    Returns:
        Scaled coordinates
    """
    scale_x = target_dims[0] / original_dims[0]
    scale_y = target_dims[1] / original_dims[1]

    return [
        [int(coord[0] * scale_x), int(coord[1] * scale_y)]
        for coord in coords
    ]

def scale_box(box, original_dims, target_dims):
    """
    Scale bounding box coordinates from one dimension space to another.

    Args:
        box: List of [x1, y1, x2, y2] coordinates
        original_dims: Tuple of (width, height) of the original space
        target_dims: Tuple of (width, height) of the target space

    Returns:
        Scaled box coordinates
    """
    scale_x = target_dims[0] / original_dims[0]
    scale_y = target_dims[1] / original_dims[1]

    return [
        int(box[0] * scale_x),  # x1
        int(box[1] * scale_y),  # y1
        int(box[2] * scale_x),  # x2
        int(box[3] * scale_y)   # y2
    ]

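# Example (hypothetical values): mapping annotations made on an 800x600 canvas
# back onto a 1600x1200 source image doubles every coordinate:
#     scale_coordinates([[400, 300]], (800, 600), (1600, 1200))  # -> [[800, 600]]
#     scale_box([100, 50, 300, 200], (800, 600), (1600, 1200))   # -> [200, 100, 600, 400]
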
def retrain_model_fn():
    # Parameters for retraining
    data_path = app.config['DATA_PATH']
    epochs = 5
    img_size = 640
    batch_size = 8

    # Train one epoch per call so the client can be notified after each epoch
    for epoch in range(epochs):
        model.train(
            data=data_path,
            epochs=1,  # one epoch per call to report individual progress
            imgsz=img_size,
            batch=batch_size,
            device="cpu"  # adjust based on system capabilities
        )

        # Emit an update to the client after each epoch
        socketio.emit('training_update', {
            'epoch': epoch + 1,
            'status': f"Epoch {epoch + 1} complete"
        })

    # Save the updated weights, then tell the client training is done
    model.save(app.config['YOLO_PATH'])
    logger.info("Model retrained successfully")
    socketio.emit('training_complete', {'status': "Retraining complete"})

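# Note: looping model.train(epochs=1) restarts the trainer (optimizer state and
# LR schedule) on every iteration. A sketch of a single multi-epoch run that
# still reports per-epoch progress, assuming the Ultralytics callback API:
#     def _on_epoch_end(trainer):
#         socketio.emit('training_update', {'epoch': trainer.epoch + 1,
#                                           'status': f"Epoch {trainer.epoch + 1} complete"})
#     model.add_callback("on_train_epoch_end", _on_epoch_end)
#     model.train(data=data_path, epochs=epochs, imgsz=img_size, batch=batch_size, device="cpu")
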
@app.route('/')
def index():
    return render_template('index.html')

@app.route('/yolo')
def yolo():
    return render_template('yolo.html')

@app.route('/upload_sam', methods=['POST'])
def upload_sam_file():
    """
    Handle a SAM image upload and embed the image into the predictor instance.

    Returns:
        JSON response with 'message', 'image_url', 'filename', and 'dimensions' keys
        on success, or an 'error' key with an appropriate error message on failure.
    """
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file part'}), 400

        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': 'No selected file'}), 400

        if not allowed_file(file.filename):
            return jsonify({'error': 'Invalid file type. Allowed types: PNG, JPG, JPEG'}), 400

        filename = secure_filename(file.filename)
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(filepath)

        # Set the image for the predictor right after upload
        image = cv2.imread(filepath)
        if image is None:
            return jsonify({'error': 'Failed to load uploaded image'}), 500

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        predictor.set_image(image)
        logger.info("Image embedded successfully")

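        # set_image computes the image embedding once; the predict() calls in
        # /generate_mask reuse that embedding rather than re-encoding the image.
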
        # Get image dimensions
        height, width = image.shape[:2]

        image_url = url_for('static', filename=f'uploads/{filename}')
        logger.info(f"File uploaded successfully: {filepath}")

        return jsonify({
            'message': 'File uploaded successfully',
            'image_url': image_url,
            'filename': filename,
            'dimensions': {
                'width': width,
                'height': height
            }
        })

    except Exception as e:
        logger.error(f"Upload error: {str(e)}")
        return jsonify({'error': 'Server error during upload'}), 500

@app.route('/upload_yolo', methods=['POST'])
def upload_yolo_file():
    """
    Upload an image for YOLO classification.

    This endpoint accepts a POST request containing a single image file. The file
    is saved to the uploads folder for later classification by the YOLO model.

    Returns a JSON response with the following keys:
    - message: a success message
    - image_url: the URL of the uploaded image
    - filename: the name of the uploaded file

    If an error occurs, the JSON response will contain an 'error' key with a
    descriptive error message.
    """
    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file part'}), 400

        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': 'No selected file'}), 400

        if not allowed_file(file.filename):
            return jsonify({'error': 'Invalid file type. Allowed types: PNG, JPG, JPEG'}), 400

        filename = secure_filename(file.filename)
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(filepath)

        image_url = url_for('static', filename=f'uploads/{filename}')
        logger.info(f"File uploaded successfully: {filepath}")

        return jsonify({
            'message': 'File uploaded successfully',
            'image_url': image_url,
            'filename': filename,
        })

    except Exception as e:
        logger.error(f"Upload error: {str(e)}")
        return jsonify({'error': 'Server error during upload'}), 500

@app.route('/generate_mask', methods=['POST'])
def generate_mask():
    """
    Generate segmentation masks for a given image using the SAM2 predictor.

    @param data: a JSON object containing the following keys:
    - filename: the name of the image file
    - void_points: a list of normalized 2D points (x, y) marking the voids
    - component_boxes: a list of normalized bounding boxes (x1, y1, x2, y2) marking the components
    @return: a JSON object containing the following keys:
    - status: a string indicating the status of the request
    - train_image_url: the URL of the saved training image
    - result_path: the URL of the saved result image
    """
    try:
        data = request.json
        normalized_void_points = data.get('void_points', [])
        normalized_component_boxes = data.get('component_boxes', [])
        filename = data.get('filename', '')

        if not filename:
            return jsonify({'error': 'No filename provided'}), 400

        image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        if not os.path.exists(image_path):
            return jsonify({'error': 'Image file not found'}), 404

        # Read image
        image = cv2.imread(image_path)
        if image is None:
            return jsonify({'error': 'Failed to load image'}), 500

        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image_height, image_width = image.shape[:2]

        # Denormalize coordinates back to pixel values
        void_points = [
            [int(point[0] * image_width), int(point[1] * image_height)]
            for point in normalized_void_points
        ]
        logger.info(f"Void points: {void_points}")

        component_boxes = [
            [
                int(box[0] * image_width),
                int(box[1] * image_height),
                int(box[2] * image_width),
                int(box[3] * image_height)
            ]
            for box in normalized_component_boxes
        ]
        logger.info(f"Component boxes: {component_boxes}")

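        # Example (hypothetical values): a normalized point [0.5, 0.25] on a
        # 640x480 image denormalizes to pixel coordinates [320, 120].
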
        # Create a list to store individual void masks
        void_masks = []

        # Process void points one by one
        for point in void_points:
            # Point prompt as an [N, 2] array with a matching label array
            point_coord = np.array([[point[0], point[1]]])
            point_label = np.array([1])  # 1 = foreground point

            masks, scores, _ = predictor.predict(
                point_coords=point_coord,
                point_labels=point_label,
                multimask_output=True  # get multiple candidate masks
            )

            if len(masks) > 0:  # Check if any masks were generated
                # Keep the mask with the highest score
                best_mask_idx = np.argmax(scores)
                void_masks.append(masks[best_mask_idx])
                logger.info(f"Processed void point {point} with score {scores[best_mask_idx]}")

        # Process component boxes
        component_masks = []
        if component_boxes:
            for box in component_boxes:
                # Box prompt as a length-4 array in XYXY format
                box_np = np.array(box)
                masks, scores, _ = predictor.predict(
                    box=box_np,
                    multimask_output=True
                )
                if len(masks) > 0:
                    best_mask_idx = np.argmax(scores)
                    component_masks.append(masks[best_mask_idx])
                    logger.info(f"Processed component box {box}")

        # Create visualization with different colors for each void
        combined_image = image.copy()

        # Font settings for labels
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.6
        font_color = (0, 0, 0)  # Black text
        font_thickness = 1
        background_color = (255, 255, 255)  # White background behind text

        # Helper function to get bounding box coordinates.
        # np.where returns (row, col) pairs, i.e. (y, x), so unpack accordingly.
        def get_bounding_box(mask):
            coords = np.column_stack(np.where(mask))
            y_min, x_min = coords.min(axis=0)
            y_max, x_max = coords.max(axis=0)
            return (x_min, y_min, x_max, y_max)

        # Helper function to add text with a background rectangle
        def put_text_with_background(img, text, pos):
            # Calculate text size
            (text_w, text_h), _ = cv2.getTextSize(text, font, font_scale, font_thickness)
            # Define the rectangle coordinates for the background
            background_tl = (pos[0], pos[1] - text_h - 2)
            background_br = (pos[0] + text_w, pos[1] + 2)
            # Draw a white rectangle as the background
            cv2.rectangle(img, background_tl, background_br, background_color, -1)
            # Put the text over the background rectangle
            cv2.putText(img, text, pos, font, font_scale, font_color, font_thickness, cv2.LINE_AA)

        def get_safe_label_position(x_min, y_min, x_max, y_max, text_w, text_h, img_width, img_height):
            # Default to the top-right of the bounding box, clamped inside the image
            x_pos = min(x_max, img_width - text_w - 10)  # keep a 10px margin from the right
            y_pos = max(y_min + text_h + 5, text_h + 5)  # keep a 5px margin from the top
            return int(x_pos), int(y_pos)

        # Apply void masks as a red overlay
        for mask in void_masks:
            mask = mask.astype(bool)
            combined_image[mask, 0] = np.clip(0.5 * image[mask, 0] + 0.5 * 255, 0, 255)  # boost red
            combined_image[mask, 1] = np.clip(0.5 * image[mask, 1], 0, 255)              # reduce green
            combined_image[mask, 2] = np.clip(0.5 * image[mask, 2], 0, 255)              # reduce blue
            logger.info("Void mask drawn")

        # Apply component masks as a green overlay, skipping pixels already marked red
        void_union = np.any(void_masks, axis=0) if void_masks else np.zeros(image.shape[:2], dtype=bool)
        for mask in component_masks:
            mask = mask.astype(bool)
            non_red_area = mask & ~void_union
            combined_image[non_red_area, 0] = np.clip(0.5 * image[non_red_area, 0], 0, 255)
            combined_image[non_red_area, 1] = np.clip(0.5 * image[non_red_area, 1] + 0.5 * 255, 0, 255)
            combined_image[non_red_area, 2] = np.clip(0.5 * image[non_red_area, 2], 0, 255)
            logger.info("Component mask drawn")

        # Add labels on top of masks
        for i, mask in enumerate(void_masks, start=1):
            x_min, y_min, x_max, y_max = get_bounding_box(mask)
            (text_w, text_h), _ = cv2.getTextSize("Void", font, font_scale, font_thickness)
            label_position = get_safe_label_position(x_min, y_min, x_max, y_max, text_w, text_h, combined_image.shape[1], combined_image.shape[0])
            put_text_with_background(combined_image, f"Void {i}", label_position)

        for i, mask in enumerate(component_masks, start=1):
            x_min, y_min, x_max, y_max = get_bounding_box(mask)
            (text_w, text_h), _ = cv2.getTextSize("Component", font, font_scale, font_thickness)
            label_position = get_safe_label_position(x_min, y_min, x_max, y_max, text_w, text_h, combined_image.shape[1], combined_image.shape[0])
            put_text_with_background(combined_image, f"Component {i}", label_position)

        # Collect the mask outlines in YOLO segmentation label format
        mask_coordinates = []

        for mask in void_masks:
            # Get contours from the mask
            contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            # Image dimensions
            height, width = mask.shape

            # For each contour, extract the normalized coordinates
            for contour in contours:
                contour_points = contour.reshape(-1, 2)  # Flatten to (N, 2) where N is the number of points
                normalized_points = contour_points / [width, height]  # Normalize to (0, 1)

                class_id = 1  # 1 for voids
                row = [class_id] + normalized_points.flatten().tolist()  # Prepend the class id
                mask_coordinates.append(row)

        for mask in component_masks:
            # Get contours from the mask
            contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            # Keep only the largest contour
            contours = sorted(contours, key=cv2.contourArea, reverse=True)
            largest_contour = [contours[0]] if contours else []
            # Image dimensions
            height, width = mask.shape

            for contour in largest_contour:
                contour_points = contour.reshape(-1, 2)  # Flatten to (N, 2)
                normalized_points = contour_points / [width, height]  # Normalize to (0, 1)

                class_id = 0  # 0 for components
                row = [class_id] + normalized_points.flatten().tolist()
                mask_coordinates.append(row)

        # YOLO expects the label file to share the image's basename with a .txt extension
        mask_coordinates_filename = f'{os.path.splitext(filename)[0]}.txt'
        mask_coordinates_path = os.path.join(app.config['YOLO_TRAIN_LABEL_FOLDER'], mask_coordinates_filename)

        with open(mask_coordinates_path, "w") as label_file:
            for row in mask_coordinates:
                # Join elements of the row with spaces and write one object per line
                label_file.write(" ".join(map(str, row)) + "\n")

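        # Each written line follows the YOLO segmentation label convention:
        #     <class_id> <x1> <y1> <x2> <y2> ...
        # with all polygon coordinates normalized to [0, 1], e.g. (illustrative):
        #     1 0.41 0.22 0.44 0.25 0.42 0.27
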
        # Save the training image alongside its label file
        train_image_filepath = os.path.join(app.config['YOLO_TRAIN_IMAGE_FOLDER'], filename)
        shutil.copy(image_path, train_image_filepath)
        train_image_url = url_for('static', filename=f'yolo/dataset_yolo/train/images/{filename}')

        # Save the visualization result
        result_filename = f'segmented_{filename}'
        result_path = os.path.join(app.config['SAM_RESULT_FOLDER'], result_filename)
        plt.imsave(result_path, combined_image)
        logger.info("Mask generation completed successfully")

        return jsonify({
            'status': 'success',
            'train_image_url': train_image_url,
            'result_path': url_for('static', filename=f'sam/sam_results/{result_filename}')
        })

    except Exception as e:
        logger.error(f"Mask generation error: {str(e)}")
        return jsonify({'error': str(e)}), 500

@app.route('/classify', methods=['POST'])
def classify():
    """
    Run the YOLO segmentation model on an image and return the detection result,
    per-component area data, and the annotated image.

    The request body should contain a JSON object with a single key 'filename'
    specifying the image file to be classified.

    Returns a JSON object with the following keys:

    - status: 'success' if the classification succeeds.
    - result_path: URL of the annotated image.
    - area_data: a list of dictionaries containing the area and void-overlap
      statistics for each detected component.
    - area_data_path: URL of the JSON file containing the area data.

    If there is an error, returns a JSON object with a single key 'error'
    containing the error message.
    """

    try:
        data = request.json
        filename = data.get('filename', '')
        if not filename:
            return jsonify({'error': 'No filename provided'}), 400

        image_path = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        if not os.path.exists(image_path):
            return jsonify({'error': 'Image file not found'}), 404

        # Read image
        image = cv2.imread(image_path)
        if image is None:
            return jsonify({'error': 'Failed to load image'}), 500

        results = model(image)
        result = results[0]

        component_masks = []
        void_masks = []

        # Extract masks and labels from the results
        # (result.masks is None when nothing was detected)
        if result.masks is not None:
            for mask, label in zip(result.masks.data, result.boxes.cls):
                mask_array = mask.cpu().numpy().astype(bool)  # Convert to a binary mask (boolean array)
                if label == 1:  # label '1' represents a void
                    void_masks.append(mask_array)
                elif label == 0:  # label '0' represents a component
                    component_masks.append(mask_array)

        # Calculate area and overlap statistics
        area_data = []
        for i, component_mask in enumerate(component_masks):
            component_area = np.sum(component_mask).item()  # Total component area in pixels
            void_area_within_component = 0
            max_void_area_percentage = 0

            # Calculate the overlap of each void mask with the component mask
            for void_mask in void_masks:
                overlap_area = np.sum(void_mask & component_mask).item()  # Overlapping area
                void_area_within_component += overlap_area
                void_area_percentage = (overlap_area / component_area) * 100 if component_area > 0 else 0
                max_void_area_percentage = max(max_void_area_percentage, void_area_percentage)

            # Append data for this component
            area_data.append({
                'Image': filename,
                'Component': f'Component {i+1}',
                'Area': component_area,
                'Void Area (pixels)': void_area_within_component,
                'Void Area %': void_area_within_component / component_area * 100 if component_area > 0 else 0,
                'Max Void Area %': max_void_area_percentage
            })

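        # Example (hypothetical values): a component covering 10,000 px with two
        # voids overlapping 800 px and 400 px gives Void Area % = 12.0 and
        # Max Void Area % = 8.0.
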
        # Save the per-component statistics next to the YOLO results
        area_data_filename = f'area_data_{os.path.splitext(filename)[0]}.json'
        area_data_path = os.path.join(app.config['AREA_DATA_FOLDER'], area_data_filename)

        with open(area_data_path, 'w') as json_file:
            json.dump(area_data, json_file, indent=4)

        # result.plot() returns a BGR image; convert before saving with matplotlib
        annotated_image = cv2.cvtColor(result.plot(), cv2.COLOR_BGR2RGB)

        output_filename = f'output_{filename}'
        output_image_path = os.path.join(app.config['YOLO_RESULT_FOLDER'], output_filename)
        plt.imsave(output_image_path, annotated_image)
        logger.info("Classification completed successfully")

        return jsonify({
            'status': 'success',
            'result_path': url_for('static', filename=f'yolo/yolo_results/{output_filename}'),
            'area_data': area_data,
            'area_data_path': url_for('static', filename=f'yolo/area_data/{area_data_filename}')
        })
    except Exception as e:
        logger.error(f"Classification error: {str(e)}")
        return jsonify({'error': str(e)}), 500

retraining_status = {
    'status': 'idle',
    'progress': None,
    'message': None
}

@app.route('/start_retraining', methods=['GET', 'POST'])
def start_retraining():
    """
    Start the model retraining process.

    If the request is a POST, start retraining in a separate thread.
    If the request is a GET, render the retraining page.

    Returns:
        A JSON response with the status of the retraining process, or a rendered HTML page.
    """
    if request.method == 'POST':
        # Reset status
        global retraining_status
        retraining_status['status'] = 'in_progress'
        retraining_status['progress'] = 'Initializing'

        # Start retraining in a separate thread
        threading.Thread(target=retrain_model_fn).start()
        return jsonify({'status': 'started'})
    else:
        # GET request - render the retraining page
        return render_template('retrain.html')

# Event handler for client connection
@socketio.on('connect')
def handle_connect():
    logger.info('Client connected')


if __name__ == '__main__':
    # Run through Flask-SocketIO so WebSocket events are served;
    # a plain app.run() would bypass the SocketIO server.
    socketio.run(app, port=5001, debug=True)