"""Extract vision features from genomic-aligned tiles with a trained GenomicPLIP model.

For every ``<data_dir>/<patient_id>/*.pt`` file, the first entry of
``data['tile_data']`` is run through the model and the resulting vision
features are saved to ``<save_dir>/<patient_id>/<same file name>``.
"""
import os
import torch
from torch.utils.data import Dataset
from pathlib import Path
import argparse
from scripts.genomic_plip_model import GenomicPLIPModel
from transformers import CLIPVisionModel


class PatientTileDataset(Dataset):
    """Dataset over per-patient ``.pt`` tile files that extracts and saves features.

    Indexing an item runs ``model`` on that file's tile and writes the features
    to disk as a side effect; the value returned is the path that was written.

    Args:
        data_dir: Directory with one sub-directory per patient, each holding
            ``.pt`` files. Non-directory entries in ``data_dir`` are ignored.
        model: Callable as ``model(pixel_values=..., score_vector=...)`` that
            returns a 2-tuple whose first element is the vision features.
        save_dir: Root directory under which features are saved, mirroring the
            ``<patient_id>/<file name>`` layout of ``data_dir``.
    """

    def __init__(self, data_dir, model, save_dir):
        super().__init__()
        self.data_dir = data_dir
        self.model = model
        self.save_dir = Path(save_dir)
        # List of (file path, patient id) pairs.
        self.files = []
        # Sort listings: os.listdir order is arbitrary, and a stable order keeps
        # dataset indices reproducible across runs and machines.
        for patient_id in sorted(os.listdir(data_dir)):
            patient_dir = os.path.join(data_dir, patient_id)
            if not os.path.isdir(patient_dir):
                continue
            for f in sorted(os.listdir(patient_dir)):
                if f.endswith('.pt'):
                    self.files.append((os.path.join(patient_dir, f), patient_id))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Run the model on file ``idx``, save its features, and return the saved path."""
        file_path, patient_id = self.files[idx]
        # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which
        # may reject pickled NumPy arrays; if these files are trusted local
        # preprocessing outputs, weights_only=False may be required — confirm.
        data = torch.load(file_path, map_location="cpu")
        # torch.from_numpy requires 'tile_data'[0] to be a NumPy array;
        # unsqueeze(0) adds the batch dimension.
        tile_data = torch.from_numpy(data['tile_data'][0]).unsqueeze(0)
        with torch.no_grad():
            # All-zero placeholder score vector; presumably only the vision branch
            # output is needed here. Width 4 must match the model — confirm.
            vision_features, _ = self.model(pixel_values=tile_data, score_vector=torch.zeros(1, 4))

        feature_path = self.save_dir / patient_id / os.path.basename(file_path)
        feature_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(vision_features, feature_path)
        return feature_path


def extract_features(data_dir, save_dir, model_path, plip_path="./plip/"):
    """Load the trained GenomicPLIP model and save features for every tile file.

    Args:
        data_dir: Directory of per-patient ``.pt`` tile files.
        save_dir: Directory under which extracted features are written.
        model_path: Path to the trained ``GenomicPLIPModel`` state dict.
        plip_path: Location of the pretrained ``CLIPVisionModel`` backbone.
    """
    original_model = CLIPVisionModel.from_pretrained(plip_path)
    custom_model = GenomicPLIPModel(original_model)
    # Everything here runs on CPU (inputs are never moved to a device), so map
    # any GPU-saved checkpoint tensors to CPU; otherwise loading fails on
    # CPU-only machines.
    custom_model.load_state_dict(torch.load(model_path, map_location="cpu"))
    custom_model.eval()

    dataset = PatientTileDataset(data_dir=data_dir, model=custom_model, save_dir=save_dir)
    # Dataset defines no __iter__; index explicitly instead of relying on the
    # legacy __getitem__/IndexError iteration fallback. Each access saves a file.
    for idx in range(len(dataset)):
        dataset[idx]


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Extract features from genomic aligned tiles.")
    parser.add_argument('--data_dir', type=str, default='plip_preprocess/', help='Directory containing the pre processed patient data.')
    parser.add_argument('--save_dir', type=str, default='omics_align_features/', help='Directory to save the extracted features.')
    parser.add_argument('--model_path', type=str, default='./save_model/omics_plip.pth', help='Path to the trained model file.')
    parser.add_argument('--plip_path', type=str, default='./plip/', help='Path to the pretrained PLIP vision backbone.')

    args = parser.parse_args()
    extract_features(args.data_dir, args.save_dir, args.model_path, plip_path=args.plip_path)