Anupam251272 commited on
Commit
ea2329d
·
verified ·
1 Parent(s): 85e37ff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +482 -0
app.py CHANGED
@@ -0,0 +1,482 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision
3
+ import gradio as gr
4
+ import numpy as np
5
+ import pandas as pd
6
+ from PIL import Image
7
+ import torch.nn as nn
8
+ from pathlib import Path
9
+ import cv2
10
+ from torchvision import transforms
11
+ from efficientnet_pytorch import EfficientNet
12
+ import logging
13
+ import warnings
14
+ from sklearn.preprocessing import StandardScaler
15
+ from typing import Optional, Dict, Any, Tuple
16
+ import json
17
+ import os
18
+ from datetime import datetime
19
+ import albumentations as A
20
+ from transformers import MarianMTModel, MarianTokenizer
21
+ import matplotlib.pyplot as plt
22
+ import seaborn as sns
23
+ import smtplib
24
+ from email.mime.text import MIMEText
25
+ from email.mime.multipart import MIMEMultipart
26
+ warnings.filterwarnings('ignore')
27
+
28
+ # Set up logging with more detailed configuration
29
+ logging.basicConfig(
30
+ level=logging.INFO,
31
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
32
+ handlers=[
33
+ logging.FileHandler('skin_diagnostic.log'),
34
+ logging.StreamHandler()
35
+ ]
36
+ )
37
+ logger = logging.getLogger(__name__)
38
+
39
+ class ImageValidator:
40
+ """Class for image validation and quality checking"""
41
+
42
+ @staticmethod
43
+ def validate_image(image: np.ndarray) -> Tuple[bool, str]:
44
+ """
45
+ Validate image quality and characteristics
46
+ Returns: (is_valid, message)
47
+ """
48
+ try:
49
+ # Check image dimensions
50
+ if image.shape[0] < 224 or image.shape[1] < 224:
51
+ return False, "Image resolution too low. Minimum 224x224 required."
52
+
53
+ # Check if image is too dark or too bright
54
+ brightness = np.mean(image)
55
+ if brightness < 30:
56
+ return False, "Image too dark. Please capture in better lighting."
57
+ if brightness > 240:
58
+ return False, "Image too bright. Please reduce exposure."
59
+
60
+ # Check for blur
61
+ laplacian_var = cv2.Laplacian(cv2.cvtColor(image, cv2.COLOR_RGB2GRAY), cv2.CV_64F).var()
62
+ if laplacian_var < 100:
63
+ return False, "Image is too blurry. Please provide a clearer image."
64
+
65
+ # Check for color consistency
66
+ color_std = np.std(image, axis=(0,1))
67
+ if np.mean(color_std) < 20:
68
+ return False, "Image lacks color variation. Please ensure proper lighting."
69
+
70
+ return True, "Image validation successful"
71
+
72
+ except Exception as e:
73
+ logger.error(f"Image validation error: {str(e)}")
74
+ return False, "Error during image validation"
75
+
76
+ class AdvancedImageAnalysis:
77
+ """Class for sophisticated image analysis techniques"""
78
+
79
+ def __init__(self):
80
+ self.scaler = StandardScaler()
81
+
82
+ def analyze_lesion(self, image: np.ndarray) -> Dict[str, float]:
83
+ """
84
+ Perform advanced analysis of skin lesion characteristics
85
+ """
86
+ try:
87
+ # Convert to different color spaces
88
+ hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
89
+ lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
90
+
91
+ # Extract features
92
+ features = {
93
+ 'asymmetry': self._calculate_asymmetry(image),
94
+ 'border_irregularity': self._analyze_border(image),
95
+ 'color_variation': self._analyze_color(hsv),
96
+ 'diameter': self._estimate_diameter(image),
97
+ 'texture': self._analyze_texture(lab),
98
+ 'vascularity': self._analyze_vascularity(image),
99
+ }
100
+
101
+ return features
102
+
103
+ except Exception as e:
104
+ logger.error(f"Error in lesion analysis: {str(e)}")
105
+ return {}
106
+
107
+ def _calculate_asymmetry(self, image: np.ndarray) -> float:
108
+ """Calculate asymmetry score of the lesion"""
109
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
110
+ _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
111
+
112
+ # Find contours
113
+ contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
114
+ if not contours:
115
+ return 0.0
116
+
117
+ # Get largest contour
118
+ largest_contour = max(contours, key=cv2.contourArea)
119
+
120
+ # Calculate moments
121
+ moments = cv2.moments(largest_contour)
122
+ if moments['m00'] == 0:
123
+ return 0.0
124
+
125
+ # Calculate center of mass
126
+ cx = moments['m10'] / moments['m00']
127
+ cy = moments['m01'] / moments['m00']
128
+
129
+ return float(cv2.matchShapes(largest_contour, cv2.flip(largest_contour, 1), 1, 0.0))
130
+
131
+ def _analyze_border(self, image: np.ndarray) -> float:
132
+ """Analyze border irregularity"""
133
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
134
+ _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
135
+
136
+ contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
137
+ if not contours:
138
+ return 0.0
139
+
140
+ largest_contour = max(contours, key=cv2.contourArea)
141
+ perimeter = cv2.arcLength(largest_contour, True)
142
+ area = cv2.contourArea(largest_contour)
143
+
144
+ if area == 0:
145
+ return 0.0
146
+
147
+ circularity = 4 * np.pi * area / (perimeter * perimeter)
148
+ return 1 - circularity
149
+
150
+ def _analyze_color(self, hsv: np.ndarray) -> float:
151
+ """Analyze color variation in the lesion"""
152
+ return float(np.std(hsv[:,:,0]))
153
+
154
+ def _estimate_diameter(self, image: np.ndarray) -> float:
155
+ """Estimate lesion diameter"""
156
+ gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
157
+ _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
158
+
159
+ contours, _ = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
160
+ if not contours:
161
+ return 0.0
162
+
163
+ largest_contour = max(contours, key=cv2.contourArea)
164
+ _, _, w, h = cv2.boundingRect(largest_contour)
165
+ return max(w, h)
166
+
167
+ def _analyze_texture(self, lab: np.ndarray) -> float:
168
+ """Analyze texture patterns"""
169
+ gray = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
170
+ gray = cv2.cvtColor(gray, cv2.COLOR_BGR2GRAY)
171
+
172
+ # Calculate GLCM features
173
+ glcm = cv2.calcHist([gray], [0], None, [16], [0,256])
174
+ glcm = glcm.flatten() / glcm.sum()
175
+
176
+ # Calculate entropy
177
+ entropy = -np.sum(glcm * np.log2(glcm + 1e-7))
178
+ return float(entropy)
179
+
180
+ def _analyze_vascularity(self, image: np.ndarray) -> float:
181
+ """Analyze vascular patterns"""
182
+ # Extract red channel
183
+ red_channel = image[:,:,0]
184
+ return float(np.percentile(red_channel, 95) - np.percentile(red_channel, 5))
185
+
186
+ class SkinDiagnosticSystem:
187
+ def __init__(self, model_path: Optional[str] = None):
188
+ # Define classes and risk levels
189
+ self.classes = [
190
+ 'Melanocytic nevi',
191
+ 'Melanoma',
192
+ 'Benign keratosis-like lesions',
193
+ 'Basal cell carcinoma',
194
+ 'Actinic keratoses',
195
+ 'Vascular lesions',
196
+ 'Dermatofibroma'
197
+ ]
198
+
199
+ self.risk_levels = {
200
+ 'Melanoma': 'High',
201
+ 'Basal cell carcinoma': 'High',
202
+ 'Actinic keratoses': 'Moderate',
203
+ 'Vascular lesions': 'Low to Moderate',
204
+ 'Benign keratosis-like lesions': 'Low',
205
+ 'Melanocytic nevi': 'Low',
206
+ 'Dermatofibroma': 'Low'
207
+ }
208
+
209
+ # Initialize components
210
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
211
+ self.image_validator = ImageValidator()
212
+ self.image_analyzer = AdvancedImageAnalysis()
213
+
214
+ # Load model
215
+ self.model = self._load_model(model_path)
216
+ self.transform = self._get_transforms()
217
+
218
+ # Load medical context
219
+ self.medical_context = self._load_medical_context()
220
+
221
+ def _load_model(self, model_path: Optional[str]) -> nn.Module:
222
+ """Load model with checkpointing support"""
223
+ try:
224
+ model = EfficientNet.from_pretrained('efficientnet-b4')
225
+ num_ftrs = model._fc.in_features
226
+ model._fc = nn.Sequential(
227
+ nn.Linear(num_ftrs, 512),
228
+ nn.ReLU(),
229
+ nn.Dropout(0.2),
230
+ nn.Linear(512, len(self.classes))
231
+ )
232
+
233
+ if model_path and os.path.exists(model_path):
234
+ logger.info(f"Loading model checkpoint from {model_path}")
235
+ checkpoint = torch.load(model_path, map_location=self.device)
236
+ model.load_state_dict(checkpoint['model_state_dict'])
237
+ logger.info(f"Model checkpoint loaded. Epoch: {checkpoint['epoch']}")
238
+
239
+ model = model.to(self.device)
240
+ model.eval()
241
+ return model
242
+
243
+ except Exception as e:
244
+ logger.error(f"Error loading model: {str(e)}")
245
+ raise
246
+
247
+ def _get_transforms(self) -> transforms.Compose:
248
+ """Get image transformations"""
249
+ return transforms.Compose([
250
+ transforms.Resize((224, 224)),
251
+ transforms.ToTensor(),
252
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
253
+ std=[0.229, 0.224, 0.225])
254
+ ])
255
+
256
+ def _load_medical_context(self) -> Dict[str, Any]:
257
+ """Load medical context and warnings"""
258
+ return {
259
+ 'Melanoma': {
260
+ 'description': 'A serious form of skin cancer that begins in melanocytes.',
261
+ 'warning': 'URGENT: Immediate medical attention required. This is a potentially serious condition.',
262
+ 'risk_factors': [
263
+ 'UV exposure',
264
+ 'Fair skin',
265
+ 'Family history',
266
+ 'Multiple moles'
267
+ ],
268
+ 'follow_up': 'Immediate dermatologist consultation required'
269
+ },
270
+ 'Basal cell carcinoma': {
271
+ 'description': 'The most common type of skin cancer.',
272
+ 'warning': 'Medical attention required. While typically slow-growing, treatment is necessary.',
273
+ 'risk_factors': [
274
+ 'Sun exposure',
275
+ 'Fair skin',
276
+ 'Age over 50',
277
+ 'Prior radiation therapy'
278
+ ],
279
+ 'follow_up': 'Schedule dermatologist appointment within 1-2 weeks'
280
+ },
281
+ # Add entries for other conditions...
282
+ }
283
+
284
+ def save_checkpoint(self, epoch: int, optimizer: torch.optim.Optimizer, loss: float) -> None:
285
+ """Save model checkpoint"""
286
+ checkpoint_dir = Path('checkpoints')
287
+ checkpoint_dir.mkdir(exist_ok=True)
288
+
289
+ checkpoint_path = checkpoint_dir / f'model_checkpoint_epoch_{epoch}.pth'
290
+ torch.save({
291
+ 'epoch': epoch,
292
+ 'model_state_dict': self.model.state_dict(),
293
+ 'optimizer_state_dict': optimizer.state_dict(),
294
+ 'loss': loss,
295
+ }, checkpoint_path)
296
+
297
+ logger.info(f"Checkpoint saved: {checkpoint_path}")
298
+
299
+ def analyze_image(self, image: np.ndarray) -> Dict[str, Any]:
300
+ """Main analysis function with validation and advanced analysis"""
301
+ try:
302
+ # Validate image
303
+ is_valid, validation_message = self.image_validator.validate_image(image)
304
+ if not is_valid:
305
+ return {'error': validation_message}
306
+
307
+ # Convert to PIL Image
308
+ pil_image = Image.fromarray(image)
309
+
310
+ # Prepare image for model
311
+ img_tensor = self.transform(pil_image).unsqueeze(0).to(self.device)
312
+
313
+ # Get model predictions
314
+ with torch.no_grad():
315
+ outputs = self.model(img_tensor)
316
+ probs = torch.nn.functional.softmax(outputs, dim=1)
317
+
318
+ # Get predicted class and probability
319
+ pred_prob, pred_idx = torch.max(probs, 1)
320
+ condition = self.classes[pred_idx]
321
+ confidence = pred_prob.item() * 100
322
+
323
+ # Perform advanced image analysis
324
+ analysis_results = self.image_analyzer.analyze_lesion(image)
325
+
326
+ # Get medical context
327
+ medical_info = self.medical_context.get(condition, {})
328
+
329
+ # Prepare response
330
+ response = {
331
+ 'condition': condition,
332
+ 'confidence': confidence,
333
+ 'risk_level': self.risk_levels.get(condition, 'Unknown'),
334
+ 'analysis': analysis_results,
335
+ 'medical_context': medical_info,
336
+ 'warning': medical_info.get('warning', ''),
337
+ 'timestamp': datetime.now().isoformat()
338
+ }
339
+
340
+ # Log analysis results
341
+ logger.info(f"Analysis completed for condition: {condition} (confidence: {confidence:.2f}%)")
342
+
343
+ return response
344
+
345
+ except Exception as e:
346
+ logger.error(f"Error in image analysis: {str(e)}")
347
+ return {'error': 'Analysis failed. Please try again.'}
348
+
349
+ def create_gradio_interface():
350
+ system = SkinDiagnosticSystem()
351
+
352
+ # Load translation models
353
+ translation_models = {
354
+ 'hi': ('Helsinki-NLP/opus-mt-en-hi', MarianTokenizer, MarianMTModel),
355
+ 'ta': ('Helsinki-NLP/opus-mt-en-ta', MarianTokenizer, MarianMTModel),
356
+ 'te': ('Helsinki-NLP/opus-mt-en-te', MarianTokenizer, MarianMTModel),
357
+ 'bn': ('Helsinki-NLP/opus-mt-en-bn', MarianTokenizer, MarianMTModel),
358
+ 'mr': ('Helsinki-NLP/opus-mt-en-mr', MarianTokenizer, MarianMTModel),
359
+ 'pa': ('Helsinki-NLP/opus-mt-en-pa', MarianTokenizer, MarianMTModel),
360
+ 'gu': ('Helsinki-NLP/opus-mt-en-gu', MarianTokenizer, MarianMTModel),
361
+ 'kn': ('Helsinki-NLP/opus-mt-en-kn', MarianTokenizer, MarianMTModel),
362
+ 'ml': ('Helsinki-NLP/opus-mt-en-ml', MarianTokenizer, MarianMTModel),
363
+ }
364
+
365
+ def process_image(image, language, email=None):
366
+ result = system.analyze_image(image)
367
+
368
+ if 'error' in result:
369
+ return f"Error: {result['error']}"
370
+
371
+ # Format detailed output
372
+ output = "ANALYSIS RESULTS\n" + "="*50 + "\n\n"
373
+
374
+ # Condition and Risk Level
375
+ output += f"Detected Condition: {result['condition']}\n"
376
+ output += f"Confidence: {result['confidence']:.2f}%\n"
377
+ output += f"Risk Level: {result['risk_level']}\n\n"
378
+
379
+ # Warning (if any)
380
+ if result['warning']:
381
+ output += f"⚠️ WARNING ⚠️\n{result['warning']}\n\n"
382
+
383
+ # Detailed Analysis
384
+ output += "Detailed Analysis:\n" + "-"*20 + "\n"
385
+ for metric, value in result['analysis'].items():
386
+ output += f"{metric}: {value:.2f}\n"
387
+
388
+ # Medical Context
389
+ if 'medical_context' in result and result['medical_context']:
390
+ output += "\nMedical Context:\n" + "-"*20 + "\n"
391
+ context = result['medical_context']
392
+ output += f"Description: {context.get('description', 'N/A')}\n"
393
+
394
+ if 'risk_factors' in context:
395
+ output += "\nRisk Factors:\n"
396
+ for factor in context['risk_factors']:
397
+ output += f"- {factor}\n"
398
+
399
+ if 'follow_up' in context:
400
+ output += f"\nRecommended Follow-up:\n{context['follow_up']}\n"
401
+
402
+ # Timestamp
403
+ output += f"\nAnalysis Timestamp: {result['timestamp']}\n"
404
+
405
+ # Disclaimer
406
+ output += "\n" + "="*50 + "\n"
407
+ output += "DISCLAIMER: This analysis is for informational purposes only and should not replace professional medical advice. Please consult a qualified healthcare provider for proper diagnosis and treatment."
408
+
409
+ # Translate output to the selected language
410
+ if language != 'en':
411
+ model_name, tokenizer_class, model_class = translation_models[language]
412
+ tokenizer = tokenizer_class.from_pretrained(model_name)
413
+ model = model_class.from_pretrained(model_name)
414
+ inputs = tokenizer(output, return_tensors="pt", padding=True, truncation=True)
415
+ translated = model.generate(**inputs)
416
+ translated_output = tokenizer.decode(translated[0], skip_special_tokens=True)
417
+ else:
418
+ translated_output = output
419
+
420
+ # Send email if provided
421
+ if email:
422
+ send_email(email, translated_output)
423
+
424
+ return translated_output
425
+
426
+ def send_email(to_email, message):
427
+ from_email = "your_email@example.com"
428
+ password = "your_password"
429
+
430
+ msg = MIMEMultipart()
431
+ msg['From'] = from_email
432
+ msg['To'] = to_email
433
+ msg['Subject'] = "Skin Lesion Analysis Results"
434
+
435
+ msg.attach(MIMEText(message, 'plain'))
436
+
437
+ server = smtplib.SMTP('smtp.example.com', 587)
438
+ server.starttls()
439
+ server.login(from_email, password)
440
+ server.sendmail(from_email, to_email, msg.as_string())
441
+ server.quit()
442
+
443
+ # Create enhanced Gradio interface with additional features
444
+ iface = gr.Interface(
445
+ fn=process_image,
446
+ inputs=[
447
+ gr.Image(type="numpy", label="Upload Skin Image"),
448
+ gr.Dropdown(choices=["en", "hi", "ta", "te", "bn", "mr", "pa", "gu", "kn", "ml"], label="Select Language"),
449
+ gr.Textbox(label="Email (optional)", placeholder="Enter your email to receive results")
450
+ ],
451
+ outputs=[
452
+ gr.Textbox(label="Analysis Results", lines=20)
453
+ ],
454
+ title="Advanced Skin Lesion Analysis System",
455
+ description="""
456
+ This system analyzes skin lesions using advanced computer vision and deep learning techniques.
457
+
458
+ Key Features:
459
+ - Lesion classification based on the HAM10000 dataset
460
+ - Advanced image quality validation
461
+ - Detailed analysis of lesion characteristics
462
+ - Medical context and risk assessment
463
+ - Option to receive results via email
464
+
465
+ Important: This tool is for educational purposes only and should not replace professional medical diagnosis.
466
+ """,
467
+ examples=[
468
+ ["example_melanoma.jpg", "en", ""],
469
+ ["example_nevus.jpg", "hi", ""],
470
+ ["example_bcc.jpg", "ta", ""]
471
+ ],
472
+ analytics_enabled=False,
473
+ )
474
+
475
+ return iface
476
+
477
+ iface = create_gradio_interface()
478
+ iface.launch(
479
+ server_name="0.0.0.0",
480
+ server_port=7860,
481
+ share=True,
482
+ )