Fine Tuning on MR dataset

#2
by heeseongE - opened

Hello, mr. baibai, I'm a researcher interested in Multi-Modal Medical AI.

Now I'm trying to use your M3D-CLIP work for fine tuning on my dataset, which is constructed by stroke dwi(adc)-mr images and its radiologic reports.
This is a sample for my input report:
Diffusion restriction at left PVWM and left posterior globus pallidus/putamen~corona radiata.

  • with mild T2 signal change
    --> acute infarction, likely.

When I just did inference on my 20 test samples, it performed 0.1 accuracy.
And after I train it on 220 samples and evaluate with previsou 20 test samples, I found that it has same accuracy with all different similarity scores.
So I think there is better strategy with process my data or parameters etc. And I kindly ask do you have any advice about my process.

This is what i did for my train.
After that, I calculate cos similarity on image feature and report feature to finding real report.

Thanks a lot,
Heeseong Eom

import os
import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np
import pandas as pd
from monai.transforms import (
Compose, LoadImaged, EnsureChannelFirstd, Orientationd, Spacingd,
Resized, NormalizeIntensityd, EnsureTyped
)
from sklearn.model_selection import train_test_split
from collections import Counter
import logging
import sys
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

attempt_num = input("Enter the attempt number: ")

best_model_dir = "/media/data4/hseum/py/rag_task/m3dclip_best_model"
os.makedirs(best_model_dir, exist_ok=True)
best_model_filename = f"best_model_attempt_{attempt_num}.pth"
best_model_path = os.path.join(best_model_dir, best_model_filename)

logging.basicConfig(stream=sys.stdout, level=logging.INFO)

internal_data_path = "/media/data4/hseum/stroke_mr/matched_image/BRMH"
internal_report_path = "/media/data4/hseum/report/matched_report/BRMH_mr_report.csv"

internal_df = pd.read_csv(internal_report_path)

all_labels = internal_df['ADP'].unique()
label_mapping = {label: idx for idx, label in enumerate(all_labels)}
print("# of Total Class: ", len(label_mapping))
print("Classes: ", all_labels, " to ", label_mapping)

internal_df['label'] = internal_df['ADP'].map(label_mapping)

def get_image_paths_and_labels(base_path, report_df):
image_paths = []
labels = []
summaries = []
for _, row in report_df.iterrows():
patient_id = str(row['PatientID']).strip().zfill(8)
date = str(row['Date'])
image_path = os.path.join(base_path, patient_id, date, 'mr.nii.gz')
if os.path.exists(image_path):
image_paths.append(image_path)
labels.append(row['label'])
summaries.append(row['summary'])
return image_paths, labels, summaries

internal_image_paths, internal_labels, internal_summaries = get_image_paths_and_labels(internal_data_path, internal_df)
print("Total internal image paths:", len(internal_image_paths))

class_counts = Counter(internal_df['label'])
total_count = sum(class_counts.values())
class_weights = {label: total_count / count for label, count in class_counts.items()}

internal_image_paths, internal_val_image_paths, internal_labels, internal_val_labels, internal_summaries, internal_val_summaries = train_test_split(
internal_image_paths, internal_labels, internal_summaries, test_size=20, random_state=42
)

def preprocess_image(image_path):
npy_path = image_path.replace('.nii.gz', '.npy')
if os.path.exists(npy_path):
return npy_path

transform = Compose([
    LoadImaged(keys=["img"]),
    EnsureChannelFirstd(keys=["img"]),
    Orientationd(keys=["img"], axcodes="RAS"),
    Spacingd(keys=["img"], pixdim=(0.9375, 0.9375, 4.0), mode='bilinear'),
    Resized(keys=["img"], spatial_size=(256, 256, 32)),
    NormalizeIntensityd(keys=["img"], nonzero=True, channel_wise=True),
    EnsureTyped(keys=["img"], data_type='tensor', track_meta=False)
])

data = {"img": image_path}
img_transformed = transform(data)["img"]

img_transformed = img_transformed[1]

img_min = img_transformed.min()
img_max = img_transformed.max()
img_normalized = (img_transformed - img_min) / (img_max - img_min)

img_numpy = img_normalized.numpy().astype(np.float32)
np.save(npy_path, img_numpy)
return npy_path

train_files = [{"img": preprocess_image(img_path), "text": text}
for img_path, text in zip(internal_image_paths, internal_summaries)]
internal_val_files = [{"img": preprocess_image(img_path), "text": text}
for img_path, text in zip(internal_val_image_paths, internal_val_summaries)]
print(np.shape(np.load(train_files[0]["img"])))
print("Train files:", len(train_files))
print("Internal validation files:", len(internal_val_files))

tokenizer = AutoTokenizer.from_pretrained("GoodBaiBai88/M3D-CLIP", model_max_length=512, padding_side="right", use_fast=False)
model = AutoModel.from_pretrained("GoodBaiBai88/M3D-CLIP", trust_remote_code=True)
model = model.to(device=device)

model.gather_loss = False

class CustomDataset(Dataset):
def init(self, data_files):
self.data_files = data_files

def __len__(self):
    return len(self.data_files)

def __getitem__(self, idx):
    image = torch.from_numpy(np.load(self.data_files[idx]["img"]).astype(np.float32))[None, ...]  # (32, 256, 256) -> (1, 32, 256, 256)
    image = image[:, 1, :, :, :]  # 두 번째 채널 선택 -> (3, 2, 256, 256, 32) -> (3, 1, 256, 256, 32)
    text_tensor = tokenizer(self.data_files[idx]["text"], max_length=512, truncation=True, padding="max_length", return_tensors="pt")
    input_id = text_tensor["input_ids"].squeeze(0)
    attention_mask = text_tensor["attention_mask"].squeeze(0)
    return image, input_id, attention_mask

train_dataset = CustomDataset(train_files)
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True)

optimizer = optim.Adam(model.parameters(), lr=5e-5, betas=(0.9, 0.98), eps=1e-6, weight_decay=0.2)

EPOCHS = 100
best_loss = float('inf')

for epoch in range(EPOCHS):
epoch_loss = 0.0
for batch in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{EPOCHS}", leave=False):
optimizer.zero_grad()

    images, input_ids, attention_masks = batch

    images = images.to(device)
    input_ids = input_ids.to(device)
    attention_masks = attention_masks.to(device)

    labels = torch.arange(len(images), dtype=torch.long, device=device)

    ret = model(images, input_ids, attention_masks, labels)
    total_loss = ret["loss"]
    total_loss.backward()
    optimizer.step()

    epoch_loss += total_loss.item()

if epoch % 10==0:
    print(f"epoch_loss: {epoch_loss} best_loss: {best_loss}")

if epoch_loss < best_loss:
    best_loss = epoch_loss
    
    if os.path.exists(best_model_path):
        os.remove(best_model_path)
    torch.save(model.state_dict(), best_model_path)

print("Training finished.")

Hi,

If you use another dataset, especially MRI data, not CT, it is best to fine-tune this M3D-CLIP model on your dataset, reducing domain bias.
In data preprocessing, we should make the shape 32x256x256 (DxHxW) and do normalization to 0-1.

Best regards,
BAI Fan

Sign up or log in to comment