Model Card for Acne Severity Classification
This model is designed to classify the severity of acne into four categories: mild, moderate, severe, and very severe. It uses a pre-trained EfficientNet-b0 architecture fine-tuned on a custom dataset.
Model Details
Model Description
The model classifies acne severity using an EfficientNet-b0 backbone fine-tuned on a dataset of acne images. Before reaching the network, each image is preprocessed with histogram equalization, median filtering, and color, contrast, and sharpness enhancement. High accuracy is reported on both the validation and test sets; exact figures: [More Information Needed].
- Developed by: Your Name
- Funded by [optional]: [More Information Needed]
- Shared by [optional]: [More Information Needed]
- Model type: Image Classification
- Language(s) (NLP): Not Applicable
- License: MIT License
- Finetuned from model [optional]: EfficientNet-b0
Model Sources [optional]
- Repository: [More Information Needed]
- Paper [optional]: [More Information Needed]
- Demo [optional]: [More Information Needed]
Uses
Direct Use
This model can be directly used to classify the severity of acne from images.
Downstream Use [optional]
[More Information Needed]
Out-of-Scope Use
This model should not be used for diagnosing other skin conditions or for any medical advice without further validation.
Bias, Risks, and Limitations
The model is trained on a specific dataset and may not generalize well to other datasets or real-world scenarios without further fine-tuning.
Recommendations
Users (both direct and downstream) should be made aware of the risks, biases, and limitations of the model. It is recommended to validate the model on additional datasets and consider further training for specific use cases.
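One way to act on the recommendation to validate on additional datasets is to look at per-class results rather than a single accuracy number, since a four-class severity model can score well overall while missing the rarer classes. The following is a minimal sketch, not the evaluation used for this model. It assumes labels are integer class indices 0 to 3, and the helper names (`confusion_matrix`, `per_class_recall`, `accuracy`) and the sample labels are hypothetical and made up for illustration.

```python
from collections import Counter

NUM_CLASSES = 4  # mild, moderate, severe, very severe


def confusion_matrix(y_true, y_pred, num_classes=NUM_CLASSES):
    """Rows are true labels, columns are predicted labels."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    counts = Counter(zip(y_true, y_pred))
    return [[counts[(t, p)] for p in range(num_classes)]
            for t in range(num_classes)]


def per_class_recall(cm):
    """Correct predictions divided by true samples, per class (None if no samples)."""
    recalls = []
    for i, row in enumerate(cm):
        total = sum(row)
        recalls.append(row[i] / total if total else None)
    return recalls


def accuracy(cm):
    """Fraction of all samples that lie on the diagonal."""
    total = sum(sum(row) for row in cm)
    correct = sum(cm[i][i] for i in range(len(cm)))
    return correct / total if total else 0.0


# Made-up labels for illustration only; these are NOT results from this model.
y_true = [0, 0, 1, 1, 2, 2, 3, 3]
y_pred = [0, 1, 1, 1, 2, 3, 3, 3]

cm = confusion_matrix(y_true, y_pred)
for row in cm:
    print(row)
print("per-class recall:", per_class_recall(cm))
print("accuracy:", accuracy(cm))
```

In practice, `y_true` and `y_pred` would come from running the model over a held-out set drawn from the target population, for example images taken with different cameras, lighting, or skin tones than the training data.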
How to Get Started with the Model
Use the code below to get started with the model.
import cv2
import numpy as np
import torch
import torch.nn as nn  # needed for the custom classification head below
import torchvision.transforms as transforms
from efficientnet_pytorch import EfficientNet
from PIL import Image, ImageEnhance

def histogram_equalization(img):
    # Equalize only the luminance (Y) channel so colors are preserved
    img = np.array(img)
    img_yuv = cv2.cvtColor(img, cv2.COLOR_RGB2YUV)
    img_yuv[:, :, 0] = cv2.equalizeHist(img_yuv[:, :, 0])
    img = cv2.cvtColor(img_yuv, cv2.COLOR_YUV2RGB)
    return Image.fromarray(img)


def median_filter(img, kernel_size=3):
    # Median blur removes salt-and-pepper noise; kernel_size must be odd
    img = np.array(img)
    img = cv2.medianBlur(img, kernel_size)
    return Image.fromarray(img)


def enhance_image(img):
    img = ImageEnhance.Color(img).enhance(1.2)      # Adjust color balance
    img = ImageEnhance.Contrast(img).enhance(1.2)   # Adjust contrast
    img = ImageEnhance.Sharpness(img).enhance(1.2)  # Adjust sharpness
    return img

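# Load the model: EfficientNet-b0 with a 4-class classification head
model = EfficientNet.from_pretrained('efficientnet-b0')
num_classes = 4
model._fc = nn.Sequential(
    nn.Dropout(p=0.5),
    nn.Linear(model._fc.in_features, num_classes)
)
# map_location='cpu' lets the checkpoint load on machines without a GPU
model.load_state_dict(torch.load('path_to_your_model.pth', map_location='cpu'))
model.eval()
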
# Define transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Lambda(histogram_equalization),
    transforms.Lambda(median_filter),
    transforms.Lambda(enhance_image),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

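# Load and preprocess the image (convert to RGB so grayscale or RGBA files
# still give the 3 channels the OpenCV steps and Normalize expect)
img = Image.open('path_to_your_image.jpg').convert('RGB')
img = transform(img).unsqueeze(0)  # add a batch dimension

# Predict
with torch.no_grad():
    output = model(img)
    _, predicted = torch.max(output, 1)
    severity = predicted.item()  # class index in the range 0-3

print(f'Predicted Acne Severity: {severity}')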
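
The snippet above prints a raw class index. To report a readable label and a confidence score instead, the logits can be passed through a softmax and the index looked up in a list of class names. The sketch below is in plain Python so it runs without PyTorch. The label order in `CLASS_NAMES` is an assumption based on the four categories named in this card and must be checked against the class-to-index mapping actually used during training. The helper names and the example logits are hypothetical.

```python
import math

# ASSUMPTION: index order matches the training labels; verify before use.
CLASS_NAMES = ["mild", "moderate", "severe", "very severe"]


def softmax(logits):
    """Convert raw scores to probabilities (max is subtracted for numerical stability)."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def describe_prediction(logits, class_names=CLASS_NAMES):
    """Return (label, probability) for the highest-scoring class."""
    probs = softmax(logits)
    idx = max(range(len(probs)), key=probs.__getitem__)
    return class_names[idx], probs[idx]


# Made-up logits for illustration; not an output of this model.
label, confidence = describe_prediction([0.2, 2.1, 0.4, -1.0])
print(f"Predicted Acne Severity: {label} (p={confidence:.2f})")
```

With the PyTorch code above, the logits for a single image could be obtained with `output.squeeze(0).tolist()`. A softmax probability is not a calibrated clinical confidence, so it is best treated as a rough signal for flagging uncertain predictions.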