---
library_name: transformers
---

# Fine-tuned Vision Transformer for Alzheimer's Detection

This repository hosts a Vision Transformer (ViT) model fine-tuned on the OASIS MRI dataset to classify brain MRI images by stage of Alzheimer's disease. The model assigns each image to one of four classes: non-demented, very mild demented, mild demented, or demented.

## Model Description

The model is a Vision Transformer that relies on self-attention over image patches to capture complex patterns in image data. It has been fine-tuned to classify MRI images into stages of Alzheimer's disease, as an example of applying ViT to medical image analysis.

## Dataset

The OASIS MRI dataset consists of 80,000 brain MRI images from 461 patients. The scans are stored in NIfTI (.nii) format and were converted to JPEG for model training. The images represent the following stages of Alzheimer's disease:

- Non-Demented
- Very Mild Demented
- Mild Demented
- Demented

The conversion standardized the image format so that every image is suitable as input to a deep learning model.

## Preprocessing Techniques

During preprocessing:

- MRI scans were converted from NIfTI format to JPEG to simplify handling and reduce storage requirements.
- Each image was resized to 128x128 pixels, ensuring uniformity across the dataset.
- Pixel values were normalized to a [0, 1] scale to facilitate model training.
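
The steps listed above can be sketched roughly as follows. This is an illustration under stated assumptions, not the original preprocessing script: the helper names `slice_to_image` and `to_unit_range`, the min-max scaling, the bilinear resampling, and the random stand-in slice are all hypothetical. Reading an actual `.nii` volume would additionally require a NIfTI reader such as nibabel, which is left out here.

```python
import numpy as np
from PIL import Image


def slice_to_image(slice_2d: np.ndarray, size: int = 128) -> Image.Image:
    """Convert one 2D MRI slice to an 8-bit grayscale image of size x size."""
    data = slice_2d.astype(np.float32)
    lo, hi = float(data.min()), float(data.max())
    # Min-max scale to [0, 1]; a constant slice maps to all zeros.
    scaled = (data - lo) / (hi - lo) if hi > lo else np.zeros_like(data)
    # JPEG stores 8-bit values, so rescale to [0, 255] before encoding.
    image = Image.fromarray((scaled * 255).round().astype(np.uint8))
    return image.resize((size, size), Image.BILINEAR)


def to_unit_range(image: Image.Image) -> np.ndarray:
    """Return the pixel values as float32 in [0, 1]."""
    return np.asarray(image, dtype=np.float32) / 255.0


# Hypothetical stand-in for a slice read from a .nii volume.
rng = np.random.default_rng(0)
fake_slice = rng.normal(loc=100.0, scale=30.0, size=(176, 208))

image = slice_to_image(fake_slice)
# image.save("slice_0.jpg")  # hypothetical output path
pixels = to_unit_range(image)
print(image.size, float(pixels.min()), float(pixels.max()))
```

Scaling to 8 bits before resizing keeps the intermediate image in a form that can be written straight to JPEG, and the division by 255 is then applied only when the image is loaded for training.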

## How to Use This Model

The example below loads the model with `ViTForImageClassification`, preprocesses a single image, and prints the predicted class:

```python
import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision.transforms import Compose, Normalize, Resize, ToTensor
from transformers import ViTForImageClassification

# Map class indices to human-readable labels
id2label = {
    0: "Non-Demented",
    1: "Very Mild Demented",
    2: "Mild Demented",
    3: "Demented",
}

# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model
model = ViTForImageClassification.from_pretrained('fawadkhan/ViT_FineTuned_on_ImagesOASIS')
model.to(device)
model.eval()

# Define the image path
image_path = 'your image path.jpg'
image = Image.open(image_path).convert("RGB")

# Define the transformations
transform = Compose([
    Resize((224, 224)),  # or the original input size of your model
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # Standard normalization for ImageNet
])

# Preprocess the image
input_tensor = transform(image).unsqueeze(0)  # Create a mini-batch as expected by the model
input_tensor = input_tensor.to(device)

# Predict
with torch.no_grad():
    outputs = model(input_tensor)
    _, predicted = torch.max(outputs.logits, 1)

# Retrieve the class name
predicted_class = id2label[predicted[0].item()]
print("Predicted class:", predicted_class)

# Plot the image and the prediction
plt.imshow(image)
plt.title(f'Predicted class: {predicted_class}')
plt.axis('off')  # Turn off axis numbers and ticks
plt.show()
```
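
The example above reports only the top class. If a per-class confidence is useful, the logits can be passed through a softmax. The sketch below uses placeholder logits so that it runs on its own; in practice `logits` would be `outputs.logits` from the model, and the values shown here are made up for illustration.

```python
import torch

# Same label mapping as in the inference example.
id2label = {
    0: "Non-Demented",
    1: "Very Mild Demented",
    2: "Mild Demented",
    3: "Demented",
}

# Placeholder logits with shape (batch, num_classes); use outputs.logits in practice.
logits = torch.tensor([[2.0, 0.5, -1.0, -3.0]])

# Softmax over the class dimension turns logits into probabilities that sum to 1.
probs = torch.softmax(logits, dim=1)[0]

# List the classes from most to least likely.
ranked = sorted(
    ((id2label[i], p.item()) for i, p in enumerate(probs)),
    key=lambda item: item[1],
    reverse=True,
)
for label, prob in ranked:
    print(f"{label}: {prob:.3f}")
```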

## Training Procedure

The model was trained with the AdamW optimizer at a learning rate of 5e-5 for 10 epochs, a setting chosen to balance accuracy against the risk of overfitting.
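
A training loop with these settings might look roughly like the sketch below. Only the optimizer, learning rate, and epoch count come from this model card. The batch size, the cross-entropy loss, the random placeholder data, and the small linear classifier standing in for the ViT are assumptions made so that the example runs quickly on its own.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

NUM_CLASSES = 4
NUM_EPOCHS = 10       # from the model card
LEARNING_RATE = 5e-5  # from the model card

# Placeholder data (assumption): 32 random feature vectors with random labels.
features = torch.randn(32, 16)
labels = torch.randint(0, NUM_CLASSES, (32,))
loader = DataLoader(TensorDataset(features, labels), batch_size=8, shuffle=True)

# Placeholder classifier (assumption); the actual run fine-tuned a ViT.
model = nn.Linear(16, NUM_CLASSES)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.CrossEntropyLoss()  # assumed loss for 4-way classification

epoch_losses = []
model.train()
for epoch in range(NUM_EPOCHS):
    running = 0.0
    for batch_x, batch_y in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(batch_x), batch_y)
        loss.backward()
        optimizer.step()
        # Weight each batch loss by its size to get a per-sample mean.
        running += loss.item() * batch_x.size(0)
    epoch_losses.append(running / len(loader.dataset))
    print(f"epoch {epoch + 1}: mean loss {epoch_losses[-1]:.4f}")
```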

## Evaluation Results

On a validation set, the model achieved an accuracy of 99% in classifying MRI scans into the four stages of Alzheimer's disease.
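
For reference, accuracy is the fraction of validation images whose predicted class matches the true class. The sketch below computes it on made-up labels, not on the actual validation set. The helper names `accuracy` and `per_class_recall` are hypothetical, and the per-class breakdown is an optional extra that this model card does not report.

```python
from collections import Counter


def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    if not y_true:
        raise ValueError("cannot compute accuracy on an empty set")
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


def per_class_recall(y_true, y_pred):
    """For each true class, the fraction of its samples predicted correctly."""
    totals = Counter(y_true)
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return {label: hits[label] / totals[label] for label in totals}


# Made-up labels for illustration (class indices 0..3).
y_true = [0, 0, 0, 0, 0, 0, 1, 1, 2, 3]
y_pred = [0, 0, 0, 0, 0, 0, 1, 0, 2, 3]

print(accuracy(y_true, y_pred))          # 0.9
print(per_class_recall(y_true, y_pred))  # {0: 1.0, 1: 0.5, 2: 1.0, 3: 1.0}
```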