MedMNIST Active Learning Model

Overview

This model is designed for image classification tasks within the medical imaging domain, specifically targeting the MedMNIST dataset. It employs a ResNet-50 architecture tailored for 28x28 pixel images and incorporates active learning strategies to enhance performance with limited labeled data.

Model Architecture

  • Base Model: ResNet-50
  • Modifications:
    • Adjusted initial convolution layer to accommodate 28x28 input images.
    • Removed max pooling layer to preserve spatial dimensions.
    • Customized fully connected layer to output predictions for 9 classes.

Training Procedure

Training Hyperparameters

Hyperparameter          Value
Batch Size              53
Initial Labeled Size    3559
Learning Rate           0.01332344940133225
MC Dropout Passes       6
Samples to Label        4430
Weight Decay            0.00021921795989143406

Optimizer Settings

The optimizer used during training was stochastic gradient descent (SGD) with the following settings, combined with a ReduceLROnPlateau learning rate scheduler:

  • learning_rate = 0.01332344940133225
  • momentum = 0.9
  • weight_decay = 0.00021921795989143406

The model was trained with float32 precision.
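The optimizer settings above map onto PyTorch roughly as follows. The card does not say which metric the scheduler monitors or which `factor` and `patience` it uses, so `mode="min"` on validation loss, `factor=0.1`, and `patience=5` are assumptions, as is the helper name `make_optimizer_and_scheduler`.

```python
import torch
import torch.nn as nn


def make_optimizer_and_scheduler(model: nn.Module):
    """Hypothetical helper: SGD with the card's values plus an assumed plateau scheduler."""
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=0.01332344940133225,
        momentum=0.9,
        weight_decay=0.00021921795989143406,
    )
    # Assumed scheduler settings: cut the LR by 10x after 5 validations without improvement
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.1, patience=5
    )
    return optimizer, scheduler
```

Unlike step-based schedulers, `ReduceLROnPlateau` is stepped with a metric, e.g. `scheduler.step(val_loss)` after each validation pass.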

Dataset

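The model was trained and evaluated on PathMNIST, the colon pathology subset of MedMNIST, which contains 28x28 RGB images across 9 tissue classes.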

Data Augmentation

  • Random resized cropping
  • Horizontal flipping
  • Random rotations
  • Color jittering
  • Gaussian blur
  • RandAugment

Active Learning Strategy

The active learning process was based on a mixed sampling strategy:

  • Uncertainty Sampling: Monte Carlo (MC) dropout was used to estimate uncertainty.
  • Diversity Sampling: K-means clustering was employed to ensure diverse samples.
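The mixed strategy can be sketched as a two-stage selection: MC dropout produces several stochastic predictions per unlabeled image, predictive entropy ranks images by uncertainty, and K-means over feature embeddings spreads the final picks across clusters. The two-stage order, the `candidate_factor` pool size, the entropy score, the one-pick-per-cluster rule, and all function names are assumptions; the card does not describe how the two signals were combined. The default `passes=6` mirrors the MC Dropout Passes hyperparameter.

```python
import numpy as np
import torch
import torch.nn as nn
from sklearn.cluster import KMeans


@torch.no_grad()
def mc_dropout_probs(model: nn.Module, x: torch.Tensor, passes: int = 6) -> torch.Tensor:
    """Softmax outputs of shape (passes, batch, classes) with dropout kept active."""
    model.eval()  # BatchNorm and other layers stay in inference mode
    for m in model.modules():
        if isinstance(m, (nn.Dropout, nn.Dropout2d)):
            m.train()  # re-enable stochastic dropout only
    return torch.stack([torch.softmax(model(x), dim=1) for _ in range(passes)])


def predictive_entropy(mc_probs: np.ndarray) -> np.ndarray:
    """Entropy of the mean predictive distribution, one score per sample."""
    mean = mc_probs.mean(axis=0)
    return -(mean * np.log(mean + 1e-12)).sum(axis=1)


def select_mixed(mc_probs, embeddings, n_select, candidate_factor=5, seed=0):
    """Hypothetical mixed sampler: top-entropy pool, then one pick per K-means cluster."""
    entropy = predictive_entropy(mc_probs)
    n_candidates = min(len(entropy), n_select * candidate_factor)
    candidates = np.argsort(-entropy)[:n_candidates]  # most uncertain first
    n_clusters = min(n_select, n_candidates)
    labels = KMeans(n_clusters=n_clusters, n_init=10,
                    random_state=seed).fit_predict(embeddings[candidates])
    chosen = []
    for c in range(n_clusters):
        members = candidates[labels == c]
        if members.size == 0:  # possible when embeddings contain duplicates
            continue
        chosen.append(members[np.argmax(entropy[members])])  # most uncertain in cluster
    return np.array(chosen, dtype=np.int64)
```

Because clusters can come back empty when embeddings contain duplicates, `select_mixed` may return fewer than `n_select` indices.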

Evaluation

The model was evaluated on the validation set of PathMNIST. Key performance metrics include:

  • Accuracy: 94.72%
  • Loss: 0.2397
  • AUC: 99.73%
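A sketch of how these three metrics can be computed from predicted class probabilities, assuming the loss is cross-entropy and the AUC is the macro-averaged one-vs-rest AUC (the convention MedMNIST's own evaluator follows). The card does not state either definition, and the `evaluate` helper is hypothetical. MedMNIST stores labels with shape (N, 1), so they are flattened first.

```python
import numpy as np
from sklearn.metrics import roc_auc_score


def evaluate(probs, labels):
    """Hypothetical helper: accuracy, cross-entropy loss, and macro one-vs-rest AUC."""
    probs = np.asarray(probs, dtype=np.float64)  # (N, num_classes), rows sum to 1
    labels = np.asarray(labels).ravel()          # (N, 1) -> (N,)
    accuracy = float((probs.argmax(axis=1) == labels).mean())
    # Cross-entropy: mean negative log-probability of the true class (clipped for safety)
    true_class_probs = np.clip(probs[np.arange(len(labels)), labels], 1e-12, 1.0)
    loss = float(-np.log(true_class_probs).mean())
    auc = float(roc_auc_score(labels, probs, multi_class="ovr"))
    return {"accuracy": accuracy, "loss": loss, "auc": auc}
```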

Graphs

The following plots illustrate the validation loss, validation accuracy, and validation AUC over the number of training batches processed during the active learning process.

  • Validation Loss
  • Validation Accuracy
  • Validation AUC

Usage

All code for this model is available in Allen Cheung's GitHub repository, Determined_AI_Hackathon.

To utilize this model:

  1. Install Dependencies: Ensure the following Python packages are installed:

    • torch
    • torchvision
    • medmnist
    • scikit-learn
    • determined

    Install them using pip:

    pip install torch torchvision medmnist scikit-learn determined
    
  2. Load the Model:

    import torch
    from model import ResNet50_28
    
    model = ResNet50_28(num_classes=9)
    # map_location lets a checkpoint saved on GPU load on a CPU-only machine
    model.load_state_dict(torch.load('pytorch_model.bin', map_location='cpu'))
    model.eval()
    
  3. Inference:

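    import torch
    from torchvision import transforms
    from PIL import Image
    
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5])
    ])
    
    # PathMNIST images are RGB; convert so grayscale or RGBA files still yield 3 channels
    image = Image.open('path_to_image.jpg').convert('RGB')
    input_tensor = transform(image).unsqueeze(0)  # add a batch dimension
    with torch.no_grad():  # gradients are not needed for inference
        output = model(input_tensor)
    prediction = output.argmax(dim=1).item()
    print(f"Predicted class: {prediction}")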
    

License

This project is licensed under the MIT License.

Acknowledgements
