# make_train_data_for_risk_classification.py (omics-plip-1)
# Build padded train/validate/test bags for risk classification from
# per-tile feature tensors.
import os
import numpy as np
import pandas as pd
import torch
import random
from sklearn.model_selection import train_test_split
import argparse

def prepare_data_with_padding(data, max_length=None):
    """Zero-pad (or truncate) each bag of tile features to a common length."""
    if max_length is None:
        max_length = max(len(bag) for bag in data)
    padded_data = []
    for bag in data:
        if len(bag) > max_length:
            # Truncate oversized bags so every bag stacks to the same shape.
            bag = bag[:max_length]
        if len(bag) < max_length:
            padding = np.zeros((max_length - len(bag), bag.shape[1]))
            padded_bag = np.vstack((bag, padding))
        else:
            padded_bag = bag
        padded_data.append(padded_bag)
    return np.array(padded_data), max_length
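
# A minimal usage sketch (the shapes are hypothetical): two bags of shape
# (3, 512) and (5, 512) pad to a single array of shape (2, 5, 512), with the
# shorter bag zero-padded at the end:
#
#   bags = [np.random.rand(3, 512), np.random.rand(5, 512)]
#   padded, max_len = prepare_data_with_padding(bags)
#   assert padded.shape == (2, 5, 512) and max_len == 5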

def create_bags(data_info, df13, data_dir):
    """Collect each patient's per-tile feature vectors into one bag per patient."""
    data = {'train': {'X': [], 'Y': []}, 'test': {'X': [], 'Y': []}, 'validate': {'X': [], 'Y': []}}
    for split in ['train', 'test', 'validate']:
        for pID in data_info[split]:
            fol_p = os.path.join(data_dir, pID)
            tiles = os.listdir(fol_p)
            tile_data = []
            for tile in tiles:
                tile_p = os.path.join(fol_p, tile)
                np1 = torch.load(tile_p).numpy()
                tile_data.append(np1)
            patient_label = df13.loc[pID, 'Class']
            # Each tile tensor carries a leading singleton dimension; squeeze it
            # so the bag has shape (n_tiles, feature_dim).
            bag = np.squeeze(np.array(tile_data), axis=1)
            bag_label = 1 if patient_label == 1 else 0  # labels already remapped to {0, 1} upstream
            data[split]['X'].append(bag)
            data[split]['Y'].append(np.array([bag_label]))
        # Bags have different tile counts, so dtype=object is needed to build a
        # ragged array without an error on recent NumPy versions.
        data[split]['X'] = np.array(data[split]['X'], dtype=object)
        data[split]['Y'] = np.array(data[split]['Y'])
        print(f"Data[{split}]['X'] shape: {data[split]['X'].shape}, dtype: {data[split]['X'].dtype}")
    return data
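
# Assumed on-disk layout, inferred from create_bags rather than documented:
#
#   <data_dir>/<PatientID>/<tile_file>   # one torch-saved tensor per tile,
#                                        # shape (1, feature_dim)
#
# so each patient bag squeezes to shape (n_tiles, feature_dim).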

def process_and_save(data_dir, metadata_file, save_dir):
    # Load and preprocess the metadata.
    dff = pd.read_csv(metadata_file)
    dff = dff.drop(columns=['Unnamed: 0'], errors='ignore')
    # Keep only classes 1 and 3 and relabel them as binary targets 1 and 0.
    classified_df = dff[dff['Class'].isin([1, 3])].copy()
    classified_df['Class'] = classified_df['Class'].map({1: 1, 3: 0})
    df13 = classified_df.set_index('PatientID')
    # Keep only patients that actually have a feature directory on disk.
    there = set(df13.index)
    wsi_there = os.listdir(data_dir)
    use = list(there.intersection(wsi_there))
    df13 = df13.loc[use]
    df13 = df13.sample(frac=1)  # shuffle rows
    class1 = list(df13[df13['Class'] == 1].index)
    class0 = list(df13[df13['Class'] == 0].index)
    # Per-class splits: class 1 ends up roughly 70/12/18 (train/validate/test)
    # and class 0 roughly 80/10/10.
    C1_X_train, C1_X_test = train_test_split(class1, test_size=0.3)
    C0_X_train, C0_X_test = train_test_split(class0, test_size=0.2)
    C1_X_validate, C1_X_test = train_test_split(C1_X_test, test_size=0.6)
    C0_X_validate, C0_X_test = train_test_split(C0_X_test, test_size=0.5)
    X_train = C1_X_train + C0_X_train
    X_test = C1_X_test + C0_X_test
    X_validate = C1_X_validate + C0_X_validate
    random.shuffle(X_train)
    random.shuffle(X_test)
    random.shuffle(X_validate)
    data_info = {'train': X_train, 'test': X_test, 'validate': X_validate}
    print(" C0 - Train : {} , Validate : {} , Test : {} ".format(len(C0_X_train), len(C0_X_validate), len(C0_X_test)))
    print(" C1 - Train : {} , Validate : {} , Test : {} ".format(len(C1_X_train), len(C1_X_validate), len(C1_X_test)))
    # Create bags and pad every bag to a fixed length of 2000 tiles.
    data = create_bags(data_info, df13, data_dir)
    train_X, _ = prepare_data_with_padding(data['train']['X'], 2000)
    train_Y = np.array(data['train']['Y']).flatten()
    validate_X, _ = prepare_data_with_padding(data['validate']['X'], 2000)
    validate_Y = np.array(data['validate']['Y']).flatten()
    test_X, _ = prepare_data_with_padding(data['test']['X'], 2000)
    test_Y = np.array(data['test']['Y']).flatten()
    # Save all processed arrays to a single compressed file.
    np.savez_compressed(os.path.join(save_dir, 'training_risk_classifier_data.npz'),
                        train_X=train_X, train_Y=train_Y,
                        validate_X=validate_X, validate_Y=validate_Y,
                        test_X=test_X, test_Y=test_Y)
    print("Data saved successfully in:", os.path.join(save_dir, 'training_risk_classifier_data.npz'))

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Process, split, and save the data with padding.')
    parser.add_argument('--data_dir', type=str, required=True, help='Directory containing the extracted features.')
    parser.add_argument('--metadata_file', type=str, default='data/data1.hnsc.p3.csv', help='CSV file containing the metadata for the samples.')
    parser.add_argument('--save_dir', type=str, default='Datasets', help='Directory to save the processed data.')
    args = parser.parse_args()
    os.makedirs(args.save_dir, exist_ok=True)
    process_and_save(args.data_dir, args.metadata_file, args.save_dir)
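
# Example invocation (paths are placeholders):
#   python make_train_data_for_risk_classification.py \
#       --data_dir extracted_features/ \
#       --metadata_file data/data1.hnsc.p3.csv \
#       --save_dir Datasets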