|
import os |
|
import random |
|
import torch |
|
from PIL import Image |
|
from concurrent.futures import ThreadPoolExecutor |
|
|
|
class PlipDataProcess(torch.utils.data.Dataset):
    """Dataset that preprocesses image tiles and writes them to disk.

    Each entry of ``files`` names a sub-directory of ``root_dir`` holding the
    tile images of one case.  Indexing the dataset is done for its side
    effect: a random subset of that directory's tiles is run through
    ``img_processor`` and every processed tile is saved, together with the
    case's row from ``df``, as ``<save_dir>/<file>/<tile_name>.pt``.
    """

    def __init__(self, root_dir, files, df, img_processor=None,
                 num_tiles_per_patient=128, max_workers=64,
                 save_dir='processed_tile_data'):
        """Store the configuration and make sure ``save_dir`` exists.

        Args:
            root_dir: Directory containing one sub-directory per entry of
                ``files``.
            files: Sequence of sub-directory names; its length is the
                dataset length.
            df: Table accessed through ``df.loc[f'{file}-01']`` and
                ``df.shape[1]`` (pandas-style interface).
            img_processor: Object exposing ``preprocess(image)`` whose result
                is indexable by ``'pixel_values'``.  Required as soon as a
                tile is actually processed.
            num_tiles_per_patient: Upper bound on the number of tiles sampled
                per directory.
            max_workers: Size of the thread pool used per ``__getitem__``
                call.
            save_dir: Root directory for the saved ``.pt`` files.
        """
        super().__init__()
        self.root_dir = root_dir
        self.files = files
        self.df = df
        self.img_processor = img_processor
        self.num_tiles_per_patient = num_tiles_per_patient
        self.max_workers = max_workers
        self.save_dir = save_dir
        # exist_ok avoids the check-then-create race when several dataset
        # instances (or processes) start at the same time.
        os.makedirs(self.save_dir, exist_ok=True)

    def __len__(self):
        """Return the number of entries in ``files``."""
        return len(self.files)

    def load_and_process_image(self, tile_path):
        """Open the image at ``tile_path`` and return its ``'pixel_values'``.

        Raises:
            ValueError: If no ``img_processor`` was supplied.
        """
        if self.img_processor is None:
            raise ValueError(
                "img_processor must be provided to process image tiles"
            )
        # The context manager releases the file handle as soon as the
        # processor has consumed the image, instead of leaking one handle
        # per tile.
        with Image.open(tile_path) as image:
            return self.img_processor.preprocess(image)['pixel_values']

    def save_individual_tile_data(self, tile_data, file_data, file_name, tile_name):
        """Save one tile to ``<save_dir>/<file_name>/<tile_name>.pt``.

        The saved object is a dict with the keys ``'tile_data'`` and
        ``'file_data'``.
        """
        save_path = os.path.join(self.save_dir, file_name, f"{tile_name}.pt")
        os.makedirs(os.path.dirname(save_path), exist_ok=True)
        torch.save({'tile_data': tile_data, 'file_data': file_data}, save_path)

    def __getitem__(self, idx):
        """Process and save a random sample of tiles for ``files[idx]``.

        Returns:
            ``idx`` unchanged; the processed data is written to disk rather
            than returned.

        Raises:
            Exception: The first error raised while processing or saving a
                tile is re-raised once all submitted tiles have finished.
        """
        file = self.files[idx]
        tiles_path = os.path.join(self.root_dir, file)
        # Skip the Jupyter checkpoint directory that may sit next to tiles.
        tiles = [
            tile for tile in os.listdir(tiles_path)
            if tile != '.ipynb_checkpoints'
        ]
        selected_tiles = random.sample(
            tiles, min(self.num_tiles_per_patient, len(tiles))
        )

        try:
            file_data = torch.tensor(
                self.df.loc[f'{file}-01'].values, dtype=torch.float32
            )
        except KeyError:
            # No row for this entry: fall back to an all-zero vector with
            # one value per column of ``df``.
            file_data = torch.zeros(self.df.shape[1], dtype=torch.float32)

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            futures = [
                executor.submit(
                    self.process_and_save_tile,
                    os.path.join(tiles_path, tile_name),
                    file_data,
                    file,
                    tile_name,
                )
                for tile_name in selected_tiles
            ]

        # Leaving the ``with`` block waits for every task.  Without asking
        # each future for its result, worker exceptions would be dropped
        # silently and missing tiles would go unnoticed.
        for future in futures:
            future.result()

        return idx

    def process_and_save_tile(self, tile_path, file_data, file_name, tile_name):
        """Preprocess the tile at ``tile_path`` and save it with ``file_data``."""
        tile_data = self.load_and_process_image(tile_path)
        self.save_individual_tile_data(tile_data, file_data, file_name, tile_name)
|
|