"""Create train/validate/test datasets for benchmark models from primary and
secondary cohorts of whole-slide-image (WSI) tissue tiles."""

import argparse
import os
import random
import shutil

import pandas as pd
from sklearn.model_selection import train_test_split


def load_and_preprocess_data(metadata_file, data_dir):
    """Load cohort metadata, keep classes 1 and 3 (mapped to binary labels),
    and restrict to patients whose tile folders exist under data_dir."""
    df = pd.read_csv(metadata_file)
    if 'Unnamed: 0' in df.columns:
        del df['Unnamed: 0']

    # Filter to the two classes of interest and map them to binary labels:
    # original class 1 -> 1, original class 3 -> 0.
    classified_df = df[df['Class'].isin([1, 3])].copy()
    classified_df['Class'] = classified_df['Class'].map({1: 1, 3: 0})
    df = classified_df.set_index('PatientID')

    # Keep only patients that have a corresponding WSI tile folder.
    available_patients = set(os.listdir(data_dir))
    df = df.loc[df.index.intersection(available_patients)]

    # Shuffle the rows.
    df = df.sample(frac=1)
    return df


def create_data_splits(df):
    """Split patient IDs into train/validate/test sets, per class.

    Class 1: 70% train; the remaining 30% is split 40/60 into
    validate (12% overall) and test (18% overall).
    Class 0: 80% train; the remaining 20% is split 50/50 into
    validate (10% overall) and test (10% overall).
    """
    class1 = list(df[df['Class'] == 1].index)
    class0 = list(df[df['Class'] == 0].index)

    C1_X_train, C1_X_test = train_test_split(class1, test_size=0.3)
    C0_X_train, C0_X_test = train_test_split(class0, test_size=0.2)
    C1_X_validate, C1_X_test = train_test_split(C1_X_test, test_size=0.6)
    C0_X_validate, C0_X_test = train_test_split(C0_X_test, test_size=0.5)

    X_train = C1_X_train + C0_X_train
    X_test = C1_X_test + C0_X_test
    X_validate = C1_X_validate + C0_X_validate
    random.shuffle(X_train)
    random.shuffle(X_test)
    random.shuffle(X_validate)

    data_info = {'train': X_train, 'test': X_test, 'validate': X_validate}
    print(" C0 - Train : {} , Validate : {} , Test : {} ".format(
        len(C0_X_train), len(C0_X_validate), len(C0_X_test)))
    print(" C1 - Train : {} , Validate : {} , Test : {} ".format(
        len(C1_X_train), len(C1_X_validate), len(C1_X_test)))
    return data_info


def copy_tiles(patient_ids, dest_folder, source_dir, num_tiles_per_patient):
    """Copy up to num_tiles_per_patient randomly sampled tiles per patient
    into dest_folder. Tile filenames are assumed to be unique across
    patients; duplicate names would overwrite each other."""
    for pID in patient_ids:
        patient_dir = os.path.join(source_dir, pID)
        if not os.path.exists(patient_dir):
            print(f"Folder not found for patient {pID}")
            continue
        tiles = os.listdir(patient_dir)
        selected_tiles = random.sample(tiles, min(num_tiles_per_patient, len(tiles)))
        for tile in selected_tiles:
            shutil.copy(os.path.join(patient_dir, tile),
                        os.path.join(dest_folder, tile))


def split_by_class(patient_ids, df):
    """Separate a mixed list of patient IDs into (class-1, class-0) sublists
    so that each patient's tiles land in the subfolder matching their label."""
    ids_c1 = [p for p in patient_ids if df.loc[p, 'Class'] == 1]
    ids_c0 = [p for p in patient_ids if df.loc[p, 'Class'] == 0]
    return ids_c1, ids_c0


def process_cohorts(primary_metadata, secondary_metadata, source_dir,
                    dataset_dir, num_tiles_per_patient):
    # Create the class subdirectories for every split if they don't exist.
    for split in ('train', 'test', 'validate', 'test2'):
        for cls in ('class1', 'class0'):
            os.makedirs(os.path.join(dataset_dir, split, cls), exist_ok=True)

    # Load and preprocess the primary cohort, then split it by patient.
    primary_df = load_and_preprocess_data(primary_metadata, source_dir)
    primary_data_info = create_data_splits(primary_df)

    # Load and preprocess the secondary cohort; it is used in full as an
    # external test set ('test2').
    secondary_df = load_and_preprocess_data(secondary_metadata, source_dir)
    secondary_data_info = {'test2': secondary_df.index.tolist()}

    # Copy tiles for the primary cohort, routing each patient's tiles to the
    # class subfolder matching their label.
    for split in ('train', 'test', 'validate'):
        ids_c1, ids_c0 = split_by_class(primary_data_info[split], primary_df)
        copy_tiles(ids_c1, os.path.join(dataset_dir, split, 'class1'),
                   source_dir, num_tiles_per_patient)
        copy_tiles(ids_c0, os.path.join(dataset_dir, split, 'class0'),
                   source_dir, num_tiles_per_patient)

    # Copy tiles for the secondary cohort.
    ids_c1, ids_c0 = split_by_class(secondary_data_info['test2'], secondary_df)
    copy_tiles(ids_c1, os.path.join(dataset_dir, 'test2', 'class1'),
               source_dir, num_tiles_per_patient)
    copy_tiles(ids_c0, os.path.join(dataset_dir, 'test2', 'class0'),
               source_dir, num_tiles_per_patient)

    print("Tiles copying completed for both cohorts.")


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description='Create dataset for benchmark models from primary and secondary cohorts.')
    parser.add_argument('--primary_metadata', type=str, required=True,
                        help='Path to the primary cohort metadata CSV file.')
    parser.add_argument('--secondary_metadata', type=str, required=True,
                        help='Path to the secondary cohort metadata CSV file.')
    parser.add_argument('--source_dir', type=str, required=True,
                        help='Directory containing raw tissue tiles.')
    parser.add_argument('--dataset_dir', type=str, required=True,
                        help='Directory to save the processed dataset.')
    parser.add_argument('--num_tiles_per_patient', type=int, default=595,
                        help='Number of tiles to select per patient.')
    args = parser.parse_args()

    process_cohorts(args.primary_metadata, args.secondary_metadata,
                    args.source_dir, args.dataset_dir, args.num_tiles_per_patient)
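
# Example invocation; the script name and all paths below are hypothetical
# placeholders, not values from this repository:
#
#   python create_benchmark_dataset.py \
#       --primary_metadata /path/to/primary_cohort.csv \
#       --secondary_metadata /path/to/secondary_cohort.csv \
#       --source_dir /path/to/tiles \
#       --dataset_dir /path/to/output_dataset \
#       --num_tiles_per_patient 595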