"""Gradio demo for skin-lesion image analysis.

Pipeline: validate the uploaded photo, classify it with an EfficientNet-B4
head over the seven HAM10000 lesion classes, compute hand-crafted lesion
features, attach medical context, then optionally translate and e-mail the
report.
"""
import json
import logging
import os
import re
import smtplib
import warnings
from datetime import datetime
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from pathlib import Path
from typing import Optional, Dict, Any, Tuple

import albumentations as A
import cv2
import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torchvision
from efficientnet_pytorch import EfficientNet
from PIL import Image
from sklearn.preprocessing import StandardScaler
from torchvision import transforms
from transformers import MarianMTModel, MarianTokenizer

# NOTE(review): json, pandas, albumentations, matplotlib, seaborn and
# torchvision (the top-level module) are not used in this file.

# Suppresses *all* warnings process-wide (including library deprecations).
warnings.filterwarnings('ignore')

# Set up logging with more detailed configuration
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('skin_diagnostic.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# Minimal syntactic e-mail check. ``\s`` also excludes CR/LF, so a user-typed
# address cannot smuggle extra headers into the outgoing message.
_EMAIL_RE = re.compile(r"[^@\s]+@[^@\s]+\.[^@\s]+")

# Hugging Face model ids used to translate the English report, keyed by the
# language code offered in the UI dropdown.
TRANSLATION_MODELS: Dict[str, str] = {
    'hi': 'Helsinki-NLP/opus-mt-en-hi',
    'ta': 'Helsinki-NLP/opus-mt-en-ta',
    'te': 'Helsinki-NLP/opus-mt-en-te',
    'bn': 'Helsinki-NLP/opus-mt-en-bn',
    'mr': 'Helsinki-NLP/opus-mt-en-mr',
    'pa': 'Helsinki-NLP/opus-mt-en-pa',
    'gu': 'Helsinki-NLP/opus-mt-en-gu',
    'kn': 'Helsinki-NLP/opus-mt-en-kn',
    'ml': 'Helsinki-NLP/opus-mt-en-ml',
}


def _ensure_rgb_uint8(image: np.ndarray) -> np.ndarray:
    """Return ``image`` as a contiguous (H, W, 3) uint8 RGB array.

    Uploads may arrive as grayscale (H, W), single-channel (H, W, 1) or RGBA
    (H, W, 4) arrays; the RGB OpenCV conversions and the model's 3-channel
    normalisation both require exactly three channels.

    Raises:
        ValueError: if the array cannot be interpreted as an image.
    """
    image = np.asarray(image)
    if image.dtype != np.uint8:
        # NOTE(review): assumes non-uint8 input is already on a 0-255 scale.
        image = np.clip(image, 0, 255).astype(np.uint8)
    if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
    elif image.ndim == 3 and image.shape[2] == 4:
        image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    elif not (image.ndim == 3 and image.shape[2] == 3):
        raise ValueError(f"Unsupported image shape {image.shape}; expected an RGB image.")
    return np.ascontiguousarray(image)


class ImageValidator:
    """Class for image validation and quality checking"""

    @staticmethod
    def validate_image(image: Optional[np.ndarray]) -> Tuple[bool, str]:
        """Check resolution, exposure, sharpness and colour spread of an image.

        Args:
            image: RGB image array of shape (H, W, 3), or ``None`` when nothing
                was uploaded.

        Returns:
            ``(is_valid, message)`` where ``message`` describes the first check
            that failed, or confirms success.
        """
        if image is None:
            return False, "No image provided. Please upload an image."
        try:
            # Check image dimensions
            if image.shape[0] < 224 or image.shape[1] < 224:
                return False, "Image resolution too low. Minimum 224x224 required."

            # Reject images whose mean intensity is too dark or too bright.
            brightness = np.mean(image)
            if brightness < 30:
                return False, "Image too dark. Please capture in better lighting."
            if brightness > 240:
                return False, "Image too bright. Please reduce exposure."

            # Low variance of the Laplacian means few sharp edges, i.e. blur.
            gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
            if cv2.Laplacian(gray, cv2.CV_64F).var() < 100:
                return False, "Image is too blurry. Please provide a clearer image."

            # Per-channel spatial std; a flat, washed-out image scores low.
            color_std = np.std(image, axis=(0, 1))
            if np.mean(color_std) < 20:
                return False, "Image lacks color variation. Please ensure proper lighting."

            return True, "Image validation successful"
        except Exception:
            logger.exception("Image validation error")
            return False, "Error during image validation"


class AdvancedImageAnalysis:
    """Hand-crafted lesion features: asymmetry, border, colour, diameter, texture, vascularity."""

    def __init__(self):
        # NOTE(review): not used by any method of this class.
        self.scaler = StandardScaler()

    def analyze_lesion(self, image: np.ndarray) -> Dict[str, float]:
        """Compute all lesion features for an RGB uint8 image.

        Returns:
            Mapping of feature name to value, or ``{}`` if any step fails
            (the error is logged).
        """
        try:
            hsv = cv2.cvtColor(image, cv2.COLOR_RGB2HSV)
            lab = cv2.cvtColor(image, cv2.COLOR_RGB2LAB)
            return {
                'asymmetry': self._calculate_asymmetry(image),
                'border_irregularity': self._analyze_border(image),
                'color_variation': self._analyze_color(hsv),
                'diameter': self._estimate_diameter(image),
                'texture': self._analyze_texture(lab),
                'vascularity': self._analyze_vascularity(image),
            }
        except Exception:
            logger.exception("Error in lesion analysis")
            return {}

    @staticmethod
    def _largest_contour(image: np.ndarray) -> Optional[np.ndarray]:
        """Segment the image with Otsu thresholding and return the largest outer contour.

        ``THRESH_BINARY_INV`` makes the darker pixels the foreground, i.e. the
        lesion is assumed darker than the surrounding skin. Returns ``None``
        when no contour is found.
        """
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        _, thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
        # [-2] selects the contour list under both the OpenCV 3 (3-tuple) and
        # OpenCV 4 (2-tuple) return conventions.
        contours = cv2.findContours(thresh, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
        if not contours:
            return None
        return max(contours, key=cv2.contourArea)

    def _calculate_asymmetry(self, image: np.ndarray) -> float:
        """Left/right asymmetry of the lesion in [0, 1] (0 = mirror-symmetric).

        The filled lesion mask is reflected about the vertical line through
        its centroid; the score is the area covered by exactly one of the two
        masks divided by twice the lesion area.
        """
        contour = self._largest_contour(image)
        if contour is None:
            return 0.0
        moments = cv2.moments(contour)
        if moments['m00'] == 0:
            return 0.0
        cx = moments['m10'] / moments['m00']

        height, width = image.shape[:2]
        mask = np.zeros((height, width), dtype=np.uint8)
        cv2.drawContours(mask, [contour], -1, 255, thickness=cv2.FILLED)
        area = np.count_nonzero(mask)
        if area == 0:
            return 0.0

        # x -> 2*cx - x: reflection about the vertical axis through the centroid.
        reflection = np.float32([[-1, 0, 2 * cx], [0, 1, 0]])
        mirrored = cv2.warpAffine(mask, reflection, (width, height), flags=cv2.INTER_NEAREST)
        non_overlap = np.count_nonzero(np.logical_xor(mask > 0, mirrored > 0))
        return float(non_overlap / (2 * area))

    def _analyze_border(self, image: np.ndarray) -> float:
        """Border irregularity as ``1 - circularity`` (0 for a perfect circle)."""
        contour = self._largest_contour(image)
        if contour is None:
            return 0.0
        perimeter = cv2.arcLength(contour, True)
        area = cv2.contourArea(contour)
        if area == 0:
            return 0.0
        circularity = 4 * np.pi * area / (perimeter * perimeter)
        return float(1 - circularity)

    def _analyze_color(self, hsv: np.ndarray) -> float:
        """Circular standard deviation of hue, in OpenCV hue units (0-179 scale).

        Hue is an angle: for uint8 HSV, 0 and 179 are neighbouring reds, so a
        linear std would wildly overstate the spread of reddish skin tones.
        """
        angles = hsv[:, :, 0].astype(np.float64) * (2.0 * np.pi / 180.0)
        resultant = np.hypot(np.cos(angles).mean(), np.sin(angles).mean())
        # Clip so rounding (R slightly > 1) or a uniform spread (R == 0) stays finite.
        resultant = min(max(resultant, 1e-12), 1.0)
        circ_std = np.sqrt(-2.0 * np.log(resultant))
        return float(circ_std * 180.0 / (2.0 * np.pi))

    def _estimate_diameter(self, image: np.ndarray) -> float:
        """Longer side, in pixels, of the lesion contour's axis-aligned bounding box."""
        contour = self._largest_contour(image)
        if contour is None:
            return 0.0
        _, _, w, h = cv2.boundingRect(contour)
        return float(max(w, h))

    def _analyze_texture(self, lab: np.ndarray) -> float:
        """Shannon entropy (bits) of a 16-bin grey-level histogram.

        This is a first-order intensity statistic, not a GLCM texture feature.
        """
        # ``lab`` was produced from RGB, so it must be converted back to RGB
        # (not BGR) for the RGB->GRAY weights to apply to the right channels.
        rgb = cv2.cvtColor(lab, cv2.COLOR_LAB2RGB)
        gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
        hist = cv2.calcHist([gray], [0], None, [16], [0, 256]).flatten()
        hist /= hist.sum()
        return float(-np.sum(hist * np.log2(hist + 1e-7)))

    def _analyze_vascularity(self, image: np.ndarray) -> float:
        """Spread of the red channel: 95th minus 5th percentile intensity."""
        red_channel = image[:, :, 0]
        return float(np.percentile(red_channel, 95) - np.percentile(red_channel, 5))


class SkinDiagnosticSystem:
    """Classify a skin-lesion photo and attach hand-crafted features and medical context."""

    def __init__(self, model_path: Optional[str] = None):
        # Define classes and risk levels (order must match the model's output head).
        self.classes = [
            'Melanocytic nevi',
            'Melanoma',
            'Benign keratosis-like lesions',
            'Basal cell carcinoma',
            'Actinic keratoses',
            'Vascular lesions',
            'Dermatofibroma'
        ]

        self.risk_levels = {
            'Melanoma': 'High',
            'Basal cell carcinoma': 'High',
            'Actinic keratoses': 'Moderate',
            'Vascular lesions': 'Low to Moderate',
            'Benign keratosis-like lesions': 'Low',
            'Melanocytic nevi': 'Low',
            'Dermatofibroma': 'Low'
        }

        # Initialize components
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.image_validator = ImageValidator()
        self.image_analyzer = AdvancedImageAnalysis()

        # Flipped to True by _load_model only when trained weights are restored.
        self.checkpoint_loaded = False
        self.model = self._load_model(model_path)
        self.transform = self._get_transforms()

        # Load medical context
        self.medical_context = self._load_medical_context()

    def _load_model(self, model_path: Optional[str]) -> nn.Module:
        """Build the EfficientNet-B4 classifier and optionally restore a checkpoint.

        Without a checkpoint, the replacement classification head keeps its
        random ``nn.Linear`` initialisation, so predictions are not meaningful;
        this is logged as a warning and recorded in ``self.checkpoint_loaded``.

        Args:
            model_path: path to a checkpoint containing ``model_state_dict``
                (as written by :meth:`save_checkpoint`) or a bare state dict.

        Raises:
            Exception: any model-construction or checkpoint-loading error is
                logged and re-raised.
        """
        has_checkpoint = bool(model_path) and os.path.exists(model_path)
        if model_path and not has_checkpoint:
            logger.warning("Model checkpoint not found: %s", model_path)
        try:
            # A checkpoint overwrites every weight, so skip downloading the
            # pretrained backbone weights in that case.
            if has_checkpoint:
                model = EfficientNet.from_name('efficientnet-b4')
            else:
                model = EfficientNet.from_pretrained('efficientnet-b4')
            num_ftrs = model._fc.in_features
            model._fc = nn.Sequential(
                nn.Linear(num_ftrs, 512),
                nn.ReLU(),
                nn.Dropout(0.2),
                nn.Linear(512, len(self.classes))
            )

            if has_checkpoint:
                logger.info("Loading model checkpoint from %s", model_path)
                # SECURITY: torch.load unpickles arbitrary objects; only load
                # checkpoints from a trusted source.
                checkpoint = torch.load(model_path, map_location=self.device)
                model.load_state_dict(checkpoint.get('model_state_dict', checkpoint))
                self.checkpoint_loaded = True
                logger.info("Model checkpoint loaded. Epoch: %s", checkpoint.get('epoch', 'unknown'))
            else:
                logger.warning(
                    "No trained checkpoint loaded: the classification head is randomly "
                    "initialised and predictions are NOT meaningful."
                )

            model = model.to(self.device)
            model.eval()
            return model
        except Exception:
            logger.exception("Error loading model")
            raise

    def _get_transforms(self) -> transforms.Compose:
        """Resize to 224x224, convert to a tensor and apply ImageNet mean/std normalisation."""
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

    def _load_medical_context(self) -> Dict[str, Any]:
        """Return per-condition description, warning, risk factors and follow-up advice."""
        # TODO: add entries for the remaining classes; conditions without an
        # entry are reported without medical context.
        return {
            'Melanoma': {
                'description': 'A serious form of skin cancer that begins in melanocytes.',
                'warning': 'URGENT: Immediate medical attention required. This is a potentially serious condition.',
                'risk_factors': [
                    'UV exposure',
                    'Fair skin',
                    'Family history',
                    'Multiple moles'
                ],
                'follow_up': 'Immediate dermatologist consultation required'
            },
            'Basal cell carcinoma': {
                'description': 'The most common type of skin cancer.',
                'warning': 'Medical attention required. While typically slow-growing, treatment is necessary.',
                'risk_factors': [
                    'Sun exposure',
                    'Fair skin',
                    'Age over 50',
                    'Prior radiation therapy'
                ],
                'follow_up': 'Schedule dermatologist appointment within 1-2 weeks'
            },
        }

    def save_checkpoint(self, epoch: int, optimizer: torch.optim.Optimizer, loss: float) -> None:
        """Write model/optimizer state to ``checkpoints/model_checkpoint_epoch_<epoch>.pth``."""
        checkpoint_dir = Path('checkpoints')
        checkpoint_dir.mkdir(exist_ok=True)

        checkpoint_path = checkpoint_dir / f'model_checkpoint_epoch_{epoch}.pth'
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
        }, checkpoint_path)
        logger.info("Checkpoint saved: %s", checkpoint_path)

    def analyze_image(self, image: Optional[np.ndarray]) -> Dict[str, Any]:
        """Validate, classify and analyse one image.

        Args:
            image: image array (grayscale, RGB or RGBA), or ``None``.

        Returns:
            ``{'error': message}`` on failure; otherwise a dict with
            ``condition``, ``confidence`` (percent), ``risk_level``,
            ``analysis``, ``medical_context``, ``warning``, ``timestamp`` and
            ``model_checkpoint_loaded``.
        """
        if image is None:
            return {'error': 'No image provided. Please upload an image.'}
        try:
            image = _ensure_rgb_uint8(image)
        except ValueError as e:
            return {'error': str(e)}

        try:
            is_valid, validation_message = self.image_validator.validate_image(image)
            if not is_valid:
                return {'error': validation_message}

            img_tensor = self.transform(Image.fromarray(image)).unsqueeze(0).to(self.device)
            with torch.no_grad():
                probs = torch.softmax(self.model(img_tensor), dim=1)

            pred_prob, pred_idx = torch.max(probs, dim=1)
            condition = self.classes[pred_idx.item()]
            confidence = pred_prob.item() * 100

            analysis_results = self.image_analyzer.analyze_lesion(image)
            medical_info = self.medical_context.get(condition, {})

            response = {
                'condition': condition,
                'confidence': confidence,
                'risk_level': self.risk_levels.get(condition, 'Unknown'),
                'analysis': analysis_results,
                'medical_context': medical_info,
                'warning': medical_info.get('warning', ''),
                'timestamp': datetime.now().isoformat(),
                'model_checkpoint_loaded': self.checkpoint_loaded,
            }

            logger.info("Analysis completed for condition: %s (confidence: %.2f%%)",
                        condition, confidence)
            return response
        except Exception:
            logger.exception("Error in image analysis")
            return {'error': 'Analysis failed. Please try again.'}


def _format_report(result: Dict[str, Any]) -> str:
    """Render a successful :meth:`SkinDiagnosticSystem.analyze_image` result as plain text."""
    lines = [
        "ANALYSIS RESULTS",
        "=" * 50,
        "",
        f"Detected Condition: {result['condition']}",
        f"Confidence: {result['confidence']:.2f}%",
        f"Risk Level: {result['risk_level']}",
        "",
    ]

    if result.get('warning'):
        lines += ["⚠️ WARNING ⚠️", result['warning'], ""]

    lines += ["Detailed Analysis:", "-" * 20]
    lines += [f"{metric}: {value:.2f}" for metric, value in result['analysis'].items()]

    context = result.get('medical_context') or {}
    if context:
        lines += ["", "Medical Context:", "-" * 20,
                  f"Description: {context.get('description', 'N/A')}"]
        if 'risk_factors' in context:
            lines += ["", "Risk Factors:"]
            lines += [f"- {factor}" for factor in context['risk_factors']]
        if 'follow_up' in context:
            lines += ["", "Recommended Follow-up:", context['follow_up']]

    lines += [
        "",
        f"Analysis Timestamp: {result['timestamp']}",
        "",
        "=" * 50,
        "DISCLAIMER: This analysis is for informational purposes only and should not "
        "replace professional medical advice. Please consult a qualified healthcare "
        "provider for proper diagnosis and treatment.",
    ]
    if not result.get('model_checkpoint_loaded', True):
        lines += ["", "NOTE: No trained model checkpoint is loaded, so the detected "
                      "condition and confidence are not meaningful."]
    return "\n".join(lines)


def send_email(to_email: str, message: str) -> None:
    """Send ``message`` as a plain-text e-mail to ``to_email`` over SMTP with STARTTLS.

    SMTP settings come from the environment: ``SMTP_HOST``, ``SMTP_PORT``
    (default 587), ``SMTP_USER``, ``SMTP_PASSWORD`` and optionally
    ``SMTP_FROM`` (defaults to ``SMTP_USER``). The report contains health
    information, so only configure this with a trusted mail provider.

    Raises:
        ValueError: if ``to_email`` is not a plausible single address.
        RuntimeError: if SMTP is not configured.
        smtplib.SMTPException / OSError: on delivery failure.
    """
    to_email = to_email.strip()
    if not _EMAIL_RE.fullmatch(to_email):
        raise ValueError("Invalid email address.")

    smtp_host = os.environ.get("SMTP_HOST")
    smtp_user = os.environ.get("SMTP_USER")
    password = os.environ.get("SMTP_PASSWORD")
    from_email = os.environ.get("SMTP_FROM") or smtp_user
    if not (smtp_host and smtp_user and password and from_email):
        raise RuntimeError(
            "Email delivery is not configured (set SMTP_HOST, SMTP_USER and SMTP_PASSWORD)."
        )
    smtp_port = int(os.environ.get("SMTP_PORT", "587"))

    msg = MIMEMultipart()
    msg['From'] = from_email
    msg['To'] = to_email
    msg['Subject'] = "Skin Lesion Analysis Results"
    msg.attach(MIMEText(message, 'plain', 'utf-8'))

    # The context manager closes the connection even if login/sendmail fails.
    with smtplib.SMTP(smtp_host, smtp_port, timeout=30) as server:
        server.starttls()
        server.login(smtp_user, password)
        server.sendmail(from_email, [to_email], msg.as_string())


def create_gradio_interface(model_path: Optional[str] = None) -> gr.Interface:
    """Build the Gradio UI around a :class:`SkinDiagnosticSystem`.

    Args:
        model_path: trained checkpoint to load; defaults to the
            ``SKIN_MODEL_PATH`` environment variable when not given.
    """
    system = SkinDiagnosticSystem(model_path or os.environ.get("SKIN_MODEL_PATH"))

    # Tokenizer/model pairs are loaded once per language, not on every request.
    translator_cache: Dict[str, Tuple[MarianTokenizer, MarianMTModel]] = {}

    def get_translator(language: str) -> Tuple[MarianTokenizer, MarianMTModel]:
        """Return the (tokenizer, model) pair for ``language``, loading it on first use."""
        if language not in translator_cache:
            model_name = TRANSLATION_MODELS[language]
            translator_cache[language] = (
                MarianTokenizer.from_pretrained(model_name),
                MarianMTModel.from_pretrained(model_name),
            )
        return translator_cache[language]

    def translate(text: str, language: str) -> str:
        """Translate ``text`` line by line, keeping blank and separator lines as-is.

        Feeding the whole multi-line report as one sequence would lose the
        layout and be truncated by the tokenizer; per-line batches avoid both.
        """
        tokenizer, model = get_translator(language)
        lines = text.split("\n")
        targets = [i for i, line in enumerate(lines) if any(ch.isalpha() for ch in line)]
        if not targets:
            return text
        inputs = tokenizer([lines[i] for i in targets], return_tensors="pt",
                           padding=True, truncation=True)
        with torch.no_grad():
            generated = model.generate(**inputs)
        for i, translated_line in zip(targets, tokenizer.batch_decode(generated, skip_special_tokens=True)):
            lines[i] = translated_line
        return "\n".join(lines)

    def process_image(image, language, email=None):
        """Gradio callback: analyse, format, optionally translate and e-mail the report."""
        result = system.analyze_image(image)
        if 'error' in result:
            return f"Error: {result['error']}"

        output = _format_report(result)

        # The dropdown may submit None if the user never touches it.
        language = language or 'en'
        if language != 'en':
            if language in TRANSLATION_MODELS:
                try:
                    output = translate(output, language)
                except Exception:
                    logger.exception("Translation to %s failed", language)
                    output += "\n\nNote: Translation is currently unavailable; results are shown in English."
            else:
                output += "\n\nNote: Translation is currently unavailable; results are shown in English."

        # Delivery problems must not cost the user the on-screen results.
        if email and email.strip():
            try:
                send_email(email, output)
            except ValueError:
                output += "\n\nNote: The email address appears to be invalid; results were not sent."
            except Exception:
                logger.exception("Failed to send results email")
                output += "\n\nNote: The results could not be sent by email."

        return output

    # Only offer example files that actually exist next to the app.
    examples = [
        example for example in [
            ["example_melanoma.jpg", "en", ""],
            ["example_nevus.jpg", "hi", ""],
            ["example_bcc.jpg", "ta", ""],
        ]
        if os.path.exists(example[0])
    ]

    # Create enhanced Gradio interface with additional features
    iface = gr.Interface(
        fn=process_image,
        inputs=[
            gr.Image(type="numpy", label="Upload Skin Image"),
            gr.Dropdown(choices=["en", "hi", "ta", "te", "bn", "mr", "pa", "gu", "kn", "ml"],
                        value="en", label="Select Language"),
            gr.Textbox(label="Email (optional)", placeholder="Enter your email to receive results")
        ],
        outputs=[
            gr.Textbox(label="Analysis Results", lines=20)
        ],
        title="Advanced Skin Lesion Analysis System",
        description="""
This system analyzes skin lesions using advanced computer vision and deep learning techniques.

Key Features:
- Lesion classification based on the HAM10000 dataset
- Advanced image quality validation
- Detailed analysis of lesion characteristics
- Medical context and risk assessment
- Option to receive results via email

Important: This tool is for educational purposes only and should not replace professional medical diagnosis.
""",
        examples=examples or None,
        analytics_enabled=False,
    )

    return iface


if __name__ == "__main__":
    # Guarded so importing this module does not build the model and start a
    # blocking server. NOTE: share=True publishes a public URL for an app that
    # handles health images — review before deploying.
    iface = create_gradio_interface()
    iface.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )