import argparse
import os
import random

import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import train_test_split


def prepare_data_with_padding(data, max_length=None):
    """Zero-pad each bag of tile features to max_length rows so all bags stack into one array."""
    if max_length is None:
        max_length = max(len(bag) for bag in data)
    padded_data = []
    for bag in data:
        if len(bag) < max_length:
            padding = np.zeros((max_length - len(bag), bag.shape[1]))
            padded_bag = np.vstack((bag, padding))
        else:
            # Truncate bags longer than max_length so every padded bag ends up with the same shape.
            padded_bag = bag[:max_length]
        padded_data.append(padded_bag)
    return np.array(padded_data), max_length
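# Usage sketch (illustrative shapes): for bags = [np.random.rand(150, 512), np.random.rand(80, 512)],
# prepare_data_with_padding(bags, 2000) returns an array of shape (2, 2000, 512) together with the
# maximum length that was used (here 2000).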
def create_bags(data_info, df13, data_dir):
    """Build one bag of tile-level feature vectors per patient, for each data split."""
    data = {'train': {'X': [], 'Y': []}, 'test': {'X': [], 'Y': []}, 'validate': {'X': [], 'Y': []}}
    for split in ['train', 'test', 'validate']:
        for pID in data_info[split]:
            fol_p = os.path.join(data_dir, pID)
            tiles = os.listdir(fol_p)
            tile_data = []
            for tile in tiles:
                tile_p = os.path.join(fol_p, tile)
                # Each tile file is assumed to hold a (1, feature_dim) tensor of extracted features.
                np1 = torch.load(tile_p, map_location='cpu').numpy()
                tile_data.append(np1)
            patient_label = df13.loc[pID, 'Class']
            # Drop the singleton axis so the bag has shape (num_tiles, feature_dim).
            bag = np.squeeze(tile_data, axis=1)
            bag_label = 1 if patient_label == 1 else 0
            data[split]['X'].append(bag)
            data[split]['Y'].append(np.array([bag_label]))
        # Bags can hold different numbers of tiles, so store them as an object array.
        data[split]['X'] = np.array(data[split]['X'], dtype=object)
        data[split]['Y'] = np.array(data[split]['Y'])
        print(f"Data[{split}]['X'] shape: {data[split]['X'].shape}, dtype: {data[split]['X'].dtype}")
    return data
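# Assumed on-disk layout (not checked by the script): data_dir/<PatientID>/<tile file>, where each
# tile file is a torch-saved tensor of shape (1, feature_dim) holding the features of one WSI tile.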
def process_and_save(data_dir, metadata_file, save_dir):
    """Split patients into train/validate/test sets, build bags, pad them, and save one .npz file."""
    dff = pd.read_csv(metadata_file, skiprows=0)
    # Drop the unnamed index column that pandas writes when a CSV is saved with its index.
    dff = dff.drop(columns=['Unnamed: 0'], errors='ignore')

    # Keep only classes 1 and 3 and remap them to the binary labels 1 and 0.
    classified_df = dff[dff['Class'].isin([1, 3])].copy()
    classified_df['Class'] = classified_df['Class'].map({1: 1, 3: 0})
    df13 = classified_df.set_index('PatientID')

    # Restrict the metadata to patients that actually have extracted features on disk.
    there = set(df13.index)
    wsi_there = os.listdir(data_dir)
    use = list(there.intersection(wsi_there))
    df13 = df13.loc[use]
    df13 = df13.sample(frac=1)

    class1 = list(df13[df13['Class'] == 1].index)
    class0 = list(df13[df13['Class'] == 0].index)

    # Split each class separately, then carve the held-out part into validation and test sets.
    # Class 1 ends up roughly 70% train / 12% validate / 18% test; class 0 roughly 80% / 10% / 10%.
    C1_X_train, C1_X_test = train_test_split(class1, test_size=0.3)
    C0_X_train, C0_X_test = train_test_split(class0, test_size=0.2)
    C1_X_validate, C1_X_test = train_test_split(C1_X_test, test_size=0.6)
    C0_X_validate, C0_X_test = train_test_split(C0_X_test, test_size=0.5)

    X_train = C1_X_train + C0_X_train
    X_test = C1_X_test + C0_X_test
    X_validate = C1_X_validate + C0_X_validate

    random.shuffle(X_train)
    random.shuffle(X_test)
    random.shuffle(X_validate)

    data_info = {'train': X_train, 'test': X_test, 'validate': X_validate}

    print("C0 - Train: {}, Validate: {}, Test: {}".format(len(C0_X_train), len(C0_X_validate), len(C0_X_test)))
    print("C1 - Train: {}, Validate: {}, Test: {}".format(len(C1_X_train), len(C1_X_validate), len(C1_X_test)))

    data = create_bags(data_info, df13, data_dir)

    # Pad every bag to 2000 tiles and flatten the label arrays to 1-D vectors.
    train_X, _ = prepare_data_with_padding(data['train']['X'], 2000)
    train_Y = np.array(data['train']['Y']).flatten()
    validate_X, _ = prepare_data_with_padding(data['validate']['X'], 2000)
    validate_Y = np.array(data['validate']['Y']).flatten()
    test_X, _ = prepare_data_with_padding(data['test']['X'], 2000)
    test_Y = np.array(data['test']['Y']).flatten()

    np.savez_compressed(os.path.join(save_dir, 'training_risk_classifier_data.npz'),
                        train_X=train_X, train_Y=train_Y,
                        validate_X=validate_X, validate_Y=validate_Y,
                        test_X=test_X, test_Y=test_Y)

    print("Data saved successfully in:", os.path.join(save_dir, 'training_risk_classifier_data.npz'))
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Process, split, and save the data with padding.')
    parser.add_argument('--data_dir', type=str, required=True, help='Directory containing the extracted features.')
    parser.add_argument('--metadata_file', type=str, default='data/data1.hnsc.p3.csv', help='CSV file containing the metadata for the samples.')
    parser.add_argument('--save_dir', type=str, default='Datasets', help='Directory to save the processed data.')

    args = parser.parse_args()

    if not os.path.exists(args.save_dir):
        os.makedirs(args.save_dir)

    process_and_save(args.data_dir, args.metadata_file, args.save_dir)
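# Example invocation (the script name and feature directory are illustrative):
#   python prepare_risk_classifier_data.py --data_dir <features_dir> --metadata_file data/data1.hnsc.p3.csv --save_dir Datasets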