File size: 2,800 Bytes
70884da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import argparse
import json
import os
import random
import shutil

from sklearn.model_selection import train_test_split

def main(num_tiles_per_patient, source_dir, dataset_dir, test_val_size, val_size,
         patients_json='./data/tcga-hnscc-patients.json'):
    """Split study patients into train/test/validate sets and copy sampled tiles.

    Patient folders found in ``source_dir`` are kept only if their name is
    listed under the ``'study'`` key of the patient JSON file. The kept
    patients are split at the patient level, then up to
    ``num_tiles_per_patient`` randomly chosen tiles per patient are copied into
    ``dataset_dir/<train|test|validate>``.

    Args:
        num_tiles_per_patient: Maximum number of tiles sampled from each
            patient folder; folders with fewer tiles contribute all of them.
        source_dir: Directory whose sub-directories are named after patients
            and contain that patient's tiles.
        dataset_dir: Root directory that receives the ``train``, ``test`` and
            ``validate`` sub-directories (created if missing).
        test_val_size: ``test_size`` passed to the first ``train_test_split``;
            the held-out part is then divided into test and validation.
        val_size: ``test_size`` passed to the second ``train_test_split``;
            that held-out part becomes the validation set.
        patients_json: Path to the JSON file holding the ``'study'`` patient
            list. Defaults to the path the script has always used.

    Raises:
        KeyError: If the JSON file has no ``'study'`` entry.
        ValueError: If no study patient folder exists in ``source_dir`` (or
            from ``train_test_split`` when a split would be empty).
    """
    # Create necessary directories if they don't exist
    for split_name in ('train', 'test', 'validate'):
        os.makedirs(os.path.join(dataset_dir, split_name), exist_ok=True)

    with open(patients_json, 'r', encoding='utf-8') as f:
        patient_data = json.load(f)

    # Only patients from the study group may enter the dataset.
    study_patients = patient_data['study']

    # Keep only study-patient folders; stray files in source_dir are skipped
    # because they cannot be listed for tiles. Sorting makes the split depend
    # only on the RNG state, not on the order os.listdir happens to return.
    study_dirs = sorted(
        entry for entry in os.listdir(source_dir)
        if entry in study_patients
        and os.path.isdir(os.path.join(source_dir, entry))
    )
    if not study_dirs:
        raise ValueError(
            f"No study patient directories found in {source_dir!r}"
        )

    # Split at the patient level so that one patient's tiles never end up in
    # more than one of the train/test/validate sets.
    train, test_val = train_test_split(study_dirs, test_size=test_val_size)
    test, val = train_test_split(test_val, test_size=val_size)

    def process_and_copy(patient_dirs, split_name):
        """Copy a random sample of each patient's tiles into one split folder."""
        target_dir = os.path.join(dataset_dir, split_name)
        for patient in patient_dirs:
            patient_path = os.path.join(source_dir, patient)
            tiles = os.listdir(patient_path)
            sample_size = min(num_tiles_per_patient, len(tiles))
            # NOTE(review): tiles are copied under their own file name, so
            # identically named tiles from different patients would overwrite
            # each other - confirm tile names are unique across patients.
            for tile in random.sample(tiles, sample_size):
                shutil.copy(os.path.join(patient_path, tile),
                            os.path.join(target_dir, tile))

    # Process and copy files for each dataset
    process_and_copy(train, 'train')
    process_and_copy(test, 'test')
    process_and_copy(val, 'validate')

if __name__ == "__main__":
    # Command-line entry point: declare every option in one table, parse the
    # command line, and hand the values to main().
    cli_options = (
        # (flag, value type, default, help text)
        ('--num_tiles_per_patient', int, 595,
         'Number of tiles to select per patient.'),
        ('--source_dir', str, 'plip_preprocess',
         'Directory containing patient folders.'),
        ('--dataset_dir', str, 'Datasets/train_03',
         'Root directory for the train, test, and validate directories.'),
        ('--test_val_size', float, 0.4,
         'Size of the test and validation sets combined.'),
        ('--val_size', float, 0.5,
         'Proportion of validation set in the test-validation split.'),
    )

    parser = argparse.ArgumentParser(
        description='Split data into train, test, and validation sets.'
    )
    for flag, value_type, default_value, help_text in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value,
                            help=help_text)

    args = parser.parse_args()

    main(
        num_tiles_per_patient=args.num_tiles_per_patient,
        source_dir=args.source_dir,
        dataset_dir=args.dataset_dir,
        test_val_size=args.test_val_size,
        val_size=args.val_size,
    )