---
library_name: transformers
---

# Fine-tuned Vision Transformer for Alzheimer's Detection

This repository hosts a Vision Transformer (ViT) model fine-tuned on the OASIS MRI dataset to classify brain MRI images by the progression of Alzheimer's disease. The model assigns each image to one of four classes: demented, very mild demented, mild demented, and non-demented.

## Model Description

The Vision Transformer is applied here to medical image analysis, where its attention mechanisms can capture complex patterns in image data. It has been fine-tuned to classify MRI images into stages of Alzheimer's disease, demonstrating the model's applicability to medical diagnostics.

## Dataset

The OASIS MRI dataset consists of 80,000 brain MRI images from 461 patients. The scans were originally stored in NIfTI (.nii) format and were converted to JPEG for model training. The images represent the following stages of Alzheimer's disease:

- Non-Demented
- Very Mild Demented
- Mild Demented
- Demented

The conversion standardized the image format so that each image is suitable as input to deep learning models.

## Preprocessing Techniques

During preprocessing:

- MRI scans were converted from NIfTI format to JPEG to simplify handling and reduce storage requirements.
- Each image was resized to 128x128 pixels, ensuring uniformity across the dataset.
- Pixel values were normalized to a [0, 1] scale to facilitate model training.
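The preprocessing steps above can be sketched roughly as follows. This is a minimal illustration, not the pipeline actually used for this model: the helper names `slice_to_image` and `preprocess` are hypothetical, the min-max intensity scaling is an assumption, and a synthetic array stands in for a 2D slice of a NIfTI volume (reading the `.nii` file itself is not shown).

```python
import numpy as np
from PIL import Image


def slice_to_image(slice_2d: np.ndarray) -> Image.Image:
    """Rescale a 2D intensity array to an 8-bit grayscale image.

    Min-max scaling is an assumption; the original conversion may differ.
    """
    data = slice_2d.astype(np.float32)
    span = float(data.max() - data.min())
    if span == 0.0:
        scaled = np.zeros_like(data)
    else:
        scaled = (data - data.min()) / span
    return Image.fromarray((scaled * 255).astype(np.uint8))


def preprocess(image: Image.Image, size: int = 128) -> np.ndarray:
    """Resize to size x size and normalize pixel values to [0, 1]."""
    resized = image.convert("L").resize((size, size))
    return np.asarray(resized, dtype=np.float32) / 255.0


# Synthetic stand-in for one 2D slice of an MRI volume (not real data)
rng = np.random.default_rng(0)
fake_slice = rng.normal(loc=100.0, scale=20.0, size=(176, 208))

pixels = preprocess(slice_to_image(fake_slice))
print(pixels.shape, pixels.dtype)
```

To write the intermediate image to disk as a JPEG, `Image.save(path, format="JPEG")` could be called on the result of `slice_to_image`.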
## How to Use This Model

You can load the model with `ViTForImageClassification` and classify a single image as follows:

```python
import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision.transforms import Compose, Normalize, Resize, ToTensor
from transformers import ViTForImageClassification

# Map class indices to human-readable labels
id2label = {
    0: "Mild Dementia",
    1: "Moderate Dementia",
    2: "Non Demented",
    3: "Very mild Dementia",
}

# Set the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the model and switch to inference mode
model = ViTForImageClassification.from_pretrained("fawadkhan/ViT_FineTuned_on_ImagesOASIS")
model.to(device)
model.eval()

# Load the image and convert it to three-channel RGB
image_path = "your image path.jpg"
image = Image.open(image_path).convert("RGB")

# Define the transformations
transform = Compose([
    Resize((224, 224)),  # or the original input size of your model
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),  # standard ImageNet normalization
])

# Preprocess the image and add a batch dimension, as expected by the model
input_tensor = transform(image).unsqueeze(0).to(device)

# Predict the class index with the highest logit
with torch.no_grad():
    outputs = model(input_tensor)
predicted = outputs.logits.argmax(dim=-1)

# Retrieve the class name
predicted_class = id2label[predicted[0].item()]
print("Predicted class:", predicted_class)

# Plot the image and the prediction
plt.imshow(image)
plt.title(f"Predicted class: {predicted_class}")
plt.axis("off")  # hide axis numbers and ticks
plt.show()
```

## Training Procedure

The model was trained using the AdamW optimizer with a learning rate of 5e-5 for 10 epochs, balancing the need for accuracy with the risk of overfitting.
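The optimizer settings above could be wired into a training loop along the following lines. This is only a sketch under stated assumptions: the tiny linear classifier, the random tensors, the batch size, and the cross-entropy loss are stand-ins chosen so the example runs quickly, and none of them come from the original training setup. Only the optimizer (AdamW), the learning rate (5e-5), and the epoch count (10) are taken from the description.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

NUM_CLASSES = 4
EPOCHS = 10           # from the model card
LEARNING_RATE = 5e-5  # from the model card

# Synthetic stand-in data: 32 random "images" with random labels (not real MRI data)
images = torch.randn(32, 3, 32, 32)
labels = torch.randint(0, NUM_CLASSES, (32,))
loader = DataLoader(TensorDataset(images, labels), batch_size=8, shuffle=True)

# Tiny stand-in classifier; the actual model is a ViT
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, NUM_CLASSES))
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.CrossEntropyLoss()  # assumed loss for multi-class classification

epoch_losses = []
model.train()
for epoch in range(EPOCHS):
    running = 0.0
    for batch_images, batch_labels in loader:
        optimizer.zero_grad()
        loss = loss_fn(model(batch_images), batch_labels)
        loss.backward()
        optimizer.step()
        running += loss.item()
    epoch_losses.append(running / len(loader))
    print(f"epoch {epoch + 1}: mean loss {epoch_losses[-1]:.4f}")
```

In practice the stand-in model would be replaced by a `ViTForImageClassification` instance and the random tensors by the preprocessed MRI images.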
## Evaluation Results

On a validation set, the model achieved an accuracy of 99%, showing that it distinguishes the stages of Alzheimer's disease in MRI scans with high reliability on this data.
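For reference, accuracy is the fraction of validation images whose predicted class matches the true class. The sketch below shows how it could be computed, along with per-class recall, which is a common complement when class sizes differ. The label lists are made-up illustrative values, not results from this model, and the function names are hypothetical.

```python
from collections import Counter


def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true labels."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    correct = sum(t == p for t, p in zip(y_true, y_pred))
    return correct / len(y_true)


def per_class_recall(y_true, y_pred):
    """For each true class, the fraction of its samples predicted correctly."""
    totals = Counter(y_true)
    hits = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    return {label: hits[label] / totals[label] for label in sorted(totals)}


# Illustrative class indices only (0-3), not real validation results
y_true = [0, 1, 2, 2, 3, 3, 2, 0]
y_pred = [0, 1, 2, 2, 3, 2, 2, 0]

print(accuracy(y_true, y_pred))          # 0.875
print(per_class_recall(y_true, y_pred))  # {0: 1.0, 1: 1.0, 2: 1.0, 3: 0.5}
```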