# Data exploration

In [None]:
import os
import matplotlib.pyplot as plt
import cv2
import numpy as np
import pandas as pd
from glob import glob
import matplotlib.gridspec as gridspec
import seaborn as sns
from sklearn.utils import class_weight
import random
# Seeds only Python's built-in `random` module. It does not seed NumPy or
# PyTorch, so NumPy/PyTorch-based randomness (weight init, shuffling, dataset
# splits) is not made reproducible by this call alone.
random.seed(123)

In [None]:
# Absolute directories of the train / test / val splits. Note the space before
# the slash in "OCT2017 /" -- it is part of the directory name, keep it.
path1, path2, path3 = (
    f"/kaggle/input/kermany2018/OCT2017 /{split}" for split in ("train", "test", "val")
)

In [None]:
# Each sub-directory of the training folder is one class.
myList = os.listdir(path=path1)
print("Total Number of Classes Detected :", len(myList))


In [None]:
# Number of class folders found in the training directory.
noOfclasses = len(myList)
random.seed(123)
print(myList)

In [None]:
# List the top-level split folders, then the class folders inside each one.
_oct_root = '/kaggle/input/kermany2018/OCT2017 /'
main_dir = os.listdir(_oct_root)
print(main_dir)

for i in main_dir:
    data_dir_list = os.listdir(os.path.join(_oct_root, i))
    print(i, data_dir_list)

In [None]:
# Number of training images per class. The paths are built from `path1` (the
# absolute train directory defined at the top) rather than a CWD-relative
# "../input/..." path, so the counts always come from the same folder the rest
# of the notebook reads, whatever the working directory is.
normal_len = len(os.listdir(os.path.join(path1, "NORMAL")))
drusen_len = len(os.listdir(os.path.join(path1, "DRUSEN")))
cnv_len = len(os.listdir(os.path.join(path1, "CNV")))
dme_len = len(os.listdir(os.path.join(path1, "DME")))

print("length of images with drusen = ", drusen_len)
print("length of images with cnv = ", cnv_len)
print("length of images with dme = ", dme_len)
print("length of images with normal = ", normal_len)

In [None]:
print("Normal")
# Read from `path1`, the same train directory used everywhere else, instead of
# a separate "oct2017/OCT2017 /" copy. That copy may not exist, in which case
# glob silently returns nothing. Sorting makes the shown images deterministic.
multipleImages = sorted(glob(os.path.join(path1, "NORMAL", "*")))
plt.rcParams['figure.figsize'] = (15.0, 15.0)
plt.subplots_adjust(wspace=0, hspace=0)
# 5x5 grid of the first 25 images, each resized to 128x128 for display.
for i_, l in enumerate(multipleImages[:25]):
    im = cv2.imread(l)
    if im is None:  # cv2.imread returns None for unreadable files; leave the slot empty
        continue
    im = cv2.resize(im, (128, 128))
    plt.subplot(5, 5, i_ + 1)
    # OpenCV loads images as BGR; matplotlib expects RGB.
    plt.imshow(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
    plt.axis('off')
plt.show()

In [None]:
print("DME")
# Read from `path1`, the same train directory used everywhere else, instead of
# a separate "oct2017/OCT2017 /" copy. That copy may not exist, in which case
# glob silently returns nothing. Sorting makes the shown images deterministic.
multipleImages = sorted(glob(os.path.join(path1, "DME", "*")))
plt.rcParams['figure.figsize'] = (15.0, 15.0)
plt.subplots_adjust(wspace=0, hspace=0)
# Fill the whole 5x5 grid (25 images), matching the Normal cell.
for i_, l in enumerate(multipleImages[:25]):
    im = cv2.imread(l)
    if im is None:  # cv2.imread returns None for unreadable files; leave the slot empty
        continue
    im = cv2.resize(im, (128, 128))
    plt.subplot(5, 5, i_ + 1)
    # OpenCV loads images as BGR; matplotlib expects RGB.
    plt.imshow(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
    plt.axis('off')
plt.show()

# Model checkpoint

In [None]:
# Pretrained SwinV2 checkpoint to fine-tune, and the per-device batch size
# used for both training and evaluation.
model_checkpoint, batch_size = "microsoft/swinv2-tiny-patch4-window8-256", 64

In [None]:
# Install the Hugging Face `datasets` and `transformers` packages quietly
# (IPython `!` shell escape -- this line is not plain Python).
!pip install -q datasets transformers

In [None]:
# Interactive Hugging Face Hub login. Needed because the training arguments
# set push_to_hub=True and the notebook later calls trainer.push_to_hub().
from huggingface_hub import notebook_login

notebook_login()

# Loading the complete dataset

In [None]:
from datasets import load_dataset

# Load only the "OCT2017 /" folder (its train/test/val sub-folders become
# splits and the class folder names become labels). The dataset root appears to
# hold a second copy of the images under "oct2017/OCT2017 /" (the plotting
# cells above read from it). Scanning the whole root could therefore load every
# image twice. After the random re-split below, the two copies of the same image
# could then land in both train and test and inflate test accuracy.
dataset = load_dataset("imagefolder", data_dir="/kaggle/input/kermany2018/OCT2017 /")

In [None]:
# Display the loaded dataset (its splits and features) as the cell output.
dataset

In [None]:
# `datasets.load_metric` was removed in datasets 3.0, and the install cell pulls
# the latest release. When it is unavailable, fall back to a small NumPy
# implementation exposing the same `compute(predictions=..., references=...)`
# interface and `{"accuracy": float}` result.
try:
    from datasets import load_metric

    metric = load_metric("accuracy")
except ImportError:

    class _AccuracyMetric:
        """Minimal stand-in for the `accuracy` metric formerly loaded from `datasets`."""

        def compute(self, predictions, references):
            """Return the fraction of positions where predictions equal references."""
            predictions = np.asarray(predictions)
            references = np.asarray(references)
            return {"accuracy": float((predictions == references).mean())}

    metric = _AccuracyMetric()


In [None]:
# Class names in label-index order, plus lookups in both directions. These are
# passed to the model so its config carries human-readable label names.
labels = dataset["train"].features["label"].names
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}

# Data preprocessing

In [None]:
from transformers import AutoImageProcessor

# Preprocessing config (target size, normalization mean/std) that ships with the
# checkpoint; the torchvision transforms below are built from it.
image_processor = AutoImageProcessor.from_pretrained(
    model_checkpoint,
)
image_processor

In [None]:
from torchvision.transforms import (
    Compose,
    Normalize,
    RandomHorizontalFlip,
    RandomVerticalFlip,
    Resize,
    ToTensor,
    ColorJitter,
)

# Normalize with the per-channel mean/std from the checkpoint's image processor.
normalize = Normalize(mean=image_processor.image_mean, std=image_processor.image_std)

# Target (height, width) taken from the processor config, which comes in one of
# two layouts. Fail loudly on anything else; otherwise `size` would be left
# undefined and the Resize calls below would raise a confusing NameError.
if "height" in image_processor.size:
    size = (image_processor.size["height"], image_processor.size["width"])
elif "shortest_edge" in image_processor.size:
    size = (image_processor.size["shortest_edge"], image_processor.size["shortest_edge"])
else:
    raise ValueError(
        f"Unsupported image_processor.size format: {image_processor.size!r}"
    )

color_jitter = ColorJitter(brightness=0.2, contrast=0.2)

# Training pipeline: random flips and brightness/contrast jitter for
# augmentation, then resize, tensor conversion and normalization.
train_transforms = Compose(
    [
        RandomHorizontalFlip(),
        RandomVerticalFlip(),
        color_jitter,
        Resize(size),
        ToTensor(),
        normalize,
    ]
)

# Evaluation pipeline: deterministic resize + tensor conversion + normalization.
val_transforms = Compose(
    [
        Resize(size),
        ToTensor(),
        normalize,
    ]
)

def _apply_image_transform(transform, example_batch):
    """Add a "pixel_values" column by running `transform` on each RGB-converted image."""
    rgb_images = (picture.convert("RGB") for picture in example_batch["image"])
    example_batch["pixel_values"] = [transform(picture) for picture in rgb_images]
    return example_batch


def preprocess_train(example_batch):
    """Batch transform for training: augmenting `train_transforms` on each image."""
    return _apply_image_transform(train_transforms, example_batch)


def preprocess_val(example_batch):
    """Batch transform for val/test: deterministic `val_transforms` on each image."""
    return _apply_image_transform(val_transforms, example_batch)


In [None]:
# Re-split the loaded "train" split into train / val / test (70% / 10% / 20%).
# `seed` makes the split reproducible; the `random.seed` calls above do not
# control it. `stratify_by_column` keeps each class's share the same in all three
# subsets, which matters when class sizes differ (see the per-class counts
# printed earlier).
# NOTE(review): the dataset's own test/val splits are not used here -- confirm
# that is intended.
splits_1 = dataset["train"].train_test_split(
    test_size=0.2, seed=123, stratify_by_column="label"
)
train_temp = splits_1['train']
test_ds = splits_1['test']

# Now split train_temp (80% of the original) into train (87.5% of it = 70%
# overall) and val (12.5% of it = 10% overall).
splits_2 = train_temp.train_test_split(
    test_size=0.125, seed=123, stratify_by_column="label"
)
train_ds = splits_2['train']
val_ds = splits_2['test']

In [None]:
# Attach on-the-fly preprocessing: the augmenting pipeline for training only,
# the deterministic pipeline for validation and test.
for _split_ds, _preprocess in (
    (train_ds, preprocess_train),
    (val_ds, preprocess_val),
    (test_ds, preprocess_val),
):
    _split_ds.set_transform(_preprocess)

# Model training

In [None]:
from transformers import AutoModelForImageClassification, TrainingArguments, Trainer

# Load the pretrained backbone with a classification head sized for our labels.
# ignore_mismatched_sizes lets the checkpoint's own head (presumably a different
# number of classes) be replaced instead of raising a shape-mismatch error.
model = AutoModelForImageClassification.from_pretrained(
    model_checkpoint,
    id2label=id2label,
    label2id=label2id,
    ignore_mismatched_sizes=True,
)


In [None]:
import inspect

model_name = model_checkpoint.split("/")[-1]

# `evaluation_strategy` was deprecated in favour of `eval_strategy` and later
# removed from transformers, and the install cell pulls whatever release is
# current. Pass the evaluation schedule under whichever name this version accepts.
_eval_strategy_kwarg = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)

args = TrainingArguments(
    "SwinMark2",  # output directory (also the name used when pushing to the Hub)
    remove_unused_columns=False,  # keep the "image" column that set_transform reads
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=batch_size,
    gradient_accumulation_steps=4,  # effective train batch = 4 * batch_size per device
    per_device_eval_batch_size=batch_size,
    num_train_epochs=3,
    warmup_ratio=0.1,
    logging_steps=10,
    load_best_model_at_end=True,  # requires matching eval/save strategies ("epoch")
    metric_for_best_model="accuracy",
    push_to_hub=True,
    **{_eval_strategy_kwarg: "epoch"},
)

In [None]:
import numpy as np


def compute_metrics(eval_pred):
    """Compute top-1 accuracy for `Trainer` evaluation.

    Args:
        eval_pred: object exposing `predictions` (per-class scores, class axis 1)
            and `label_ids` (integer class ids), as passed in by `Trainer`.

    Returns:
        ``{"accuracy": float}`` -- the key named by ``metric_for_best_model``.
    """
    logits = eval_pred.predictions
    # Some models return a tuple (logits, extra outputs...); logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    predictions = np.argmax(logits, axis=1)
    references = np.asarray(eval_pred.label_ids)
    # Computed directly so this does not depend on `datasets.load_metric`,
    # which newer `datasets` releases no longer provide.
    return {"accuracy": float((predictions == references).mean())}

In [None]:
import torch


def collate_fn(examples):
    """Batch preprocessed examples for `Trainer`.

    Stacks each example's "pixel_values" tensor along a new leading batch
    dimension. Gathers the integer "label" fields into a tensor stored under
    "labels", the argument name the model's forward pass uses for targets.
    """
    batch_pixels = [sample["pixel_values"] for sample in examples]
    batch_labels = [sample["label"] for sample in examples]
    return {
        "pixel_values": torch.stack(batch_pixels),
        "labels": torch.tensor(batch_labels),
    }

In [None]:
import inspect

# `Trainer(tokenizer=...)` is deprecated in recent transformers releases in
# favour of `processing_class=...`. Pass the image processor under whichever name
# the installed version accepts, so it is saved/pushed alongside the model.
_processor_kwarg = (
    "processing_class"
    if "processing_class" in inspect.signature(Trainer).parameters
    else "tokenizer"
)

trainer = Trainer(
    model,
    args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=compute_metrics,
    data_collator=collate_fn,
    **{_processor_kwarg: image_processor},
)

In [None]:
# Fine-tune, write the final model to the output directory, then log and
# persist the training metrics and the trainer state.
train_results = trainer.train()
trainer.save_model()
for _record in (trainer.log_metrics, trainer.save_metrics):
    _record("train", train_results.metrics)
trainer.save_state()

In [None]:
# Evaluate on the validation set (the trainer's eval_dataset), then log and
# save the results.
metrics = trainer.evaluate()
for _record in (trainer.log_metrics, trainer.save_metrics):
    _record("eval", metrics)

In [None]:
# Run the trained model over the held-out test split (predictions, labels, metrics).
outputs = trainer.predict(test_dataset=test_ds)

In [None]:
# Metrics collected by trainer.predict on the test split.
print(outputs.metrics)

In [None]:
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

# Ground-truth ids and argmax predictions for the test split.
y_true = outputs.label_ids
y_pred = outputs.predictions.argmax(1)

labels = train_ds.features['label'].names
# Pin the matrix to every class index so it is always len(labels) x len(labels)
# and lines up with display_labels, even if a class is absent from y_true and y_pred.
cm = confusion_matrix(y_true, y_pred, labels=list(range(len(labels))))
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=labels)
disp.plot(xticks_rotation=45)

In [None]:
# Upload the trainer's saved output to the Hugging Face Hub (the training
# arguments set push_to_hub=True with output directory "SwinMark2").
trainer.push_to_hub()

In [None]:
from sklearn.metrics import classification_report

# Per-class precision/recall/F1 on the test split. `labels` pins the report to
# every class index so it always matches `target_names`; without it sklearn
# raises if some class is absent from both y_true and y_pred.
report = classification_report(
    y_true,
    y_pred,
    labels=list(range(len(labels))),
    target_names=labels,
    digits=4,
)
print(report)

In [None]:
# Display the held-out test split (its features and row count) as the cell output.
test_ds

# Inference

In [None]:
from torchvision.datasets import ImageFolder  # NOTE(review): currently unused

# One CNV scan from the dataset's test folder, used as the sample input for the
# Hub pipeline further down.
image_path = "/kaggle/input/kermany2018/OCT2017 /test/CNV/CNV-1016042-4.jpeg"

from PIL import Image
image = Image.open(image_path)
# `Image.show()` hands the file to an external viewer, which does not render
# inside a notebook. Draw the image inline with matplotlib instead. cmap only
# takes effect for single-channel images.
plt.imshow(image, cmap="gray")
plt.axis("off")
plt.show()


In [None]:
# IPython shell escape; redundant if the earlier install cell already ran, but
# lets this inference section be run on its own.
!pip install transformers
from transformers import pipeline
# Image-classification pipeline backed by the Hub repo "Abhiram4/SwinMark2"
# (presumably the model pushed by this notebook -- confirm the repo owner).
pipe = pipeline("image-classification", "Abhiram4/SwinMark2")

In [None]:
# Classify the sample image opened in the inference cell above.
pipe(image)