|
import os |
|
import torch |
|
from torch.utils.data import Dataset |
|
from pathlib import Path |
|
import argparse |
|
from scripts.genomic_plip_model import GenomicPLIPModel |
|
from transformers import CLIPVisionModel |
|
|
|
class PatientTileDataset(Dataset):
    """Dataset over per-patient ``.pt`` tile files that extracts and saves features.

    ``data_dir`` is scanned one level deep: every sub-directory is treated as
    a patient, and every ``*.pt`` file directly inside it becomes one item.
    Indexing an item runs ``model`` on the first tile stored in that file and
    writes the resulting vision features to
    ``save_dir/<patient_id>/<original file name>``.
    """

    def __init__(self, data_dir, model, save_dir):
        """Collect all ``(tile file path, patient id)`` pairs under ``data_dir``.

        Args:
            data_dir: Directory whose sub-directories are named by patient id.
            model: Callable invoked as
                ``model(pixel_values=..., score_vector=...)``; it must return
                a pair whose first element is the vision features.
            save_dir: Root directory under which features are written.
        """
        super().__init__()
        self.data_dir = data_dir
        self.model = model
        self.save_dir = Path(save_dir)
        self.files = []

        # os.listdir() returns entries in arbitrary order; sort both levels so
        # that a given index always refers to the same file.
        for patient_id in sorted(os.listdir(data_dir)):
            patient_dir = os.path.join(data_dir, patient_id)
            if not os.path.isdir(patient_dir):
                continue  # stray files next to the patient folders are ignored
            for file_name in sorted(os.listdir(patient_dir)):
                if file_name.endswith('.pt'):
                    self.files.append((os.path.join(patient_dir, file_name), patient_id))

    def __len__(self):
        """Return the number of tile files discovered at construction time."""
        return len(self.files)

    def __getitem__(self, idx):
        """Extract features for item ``idx``, save them, and return their path.

        Args:
            idx: Position in ``self.files``.

        Returns:
            ``pathlib.Path`` of the feature file that was written.

        Raises:
            IndexError: If ``idx`` is outside ``self.files``.
        """
        file_path, patient_id = self.files[idx]

        # The tile files store a NumPy array inside a dict, which the
        # restricted unpickler (default since PyTorch 2.6) rejects, so full
        # unpickling is requested explicitly.
        # WARNING: full unpickling can execute arbitrary code -- only point
        # this dataset at tile files you produced yourself.
        data = torch.load(file_path, weights_only=False)

        # Only the first entry of 'tile_data' is used; unsqueeze adds a
        # leading batch dimension of size 1.
        tile_data = torch.from_numpy(data['tile_data'][0]).unsqueeze(0)

        with torch.no_grad():
            # NOTE(review): the all-zero (1, 4) score vector looks like a
            # placeholder for the model's second input -- confirm against
            # GenomicPLIPModel that it does not influence the vision features.
            vision_features, _ = self.model(
                pixel_values=tile_data,
                score_vector=torch.zeros(1, 4),
            )

        # Mirror the <patient_id>/<file name> layout of the input directory.
        feature_path = self.save_dir / patient_id / os.path.basename(file_path)
        feature_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(vision_features, feature_path)
        return feature_path
|
|
|
def extract_features(data_dir, save_dir, model_path, plip_path="./plip/"):
    """Run the trained GenomicPLIP model over every tile file and save features.

    Args:
        data_dir: Directory containing one sub-directory of ``.pt`` tile
            files per patient.
        save_dir: Directory under which the extracted features are written,
            mirroring the per-patient layout of ``data_dir``.
        model_path: Path to the saved ``GenomicPLIPModel`` state dict.
        plip_path: Location passed to ``CLIPVisionModel.from_pretrained`` for
            the vision backbone. Defaults to the previously hard-coded
            ``"./plip/"``.
    """
    original_model = CLIPVisionModel.from_pretrained(plip_path)
    custom_model = GenomicPLIPModel(original_model)

    # The model is never moved off the CPU and the dataset feeds it CPU
    # tensors, so map the checkpoint to CPU; without this a checkpoint saved
    # from a GPU fails to load on a machine without CUDA.
    state_dict = torch.load(model_path, map_location="cpu")
    custom_model.load_state_dict(state_dict)
    custom_model.eval()

    dataset = PatientTileDataset(data_dir=data_dir, model=custom_model, save_dir=save_dir)

    # Indexing an item is what extracts and writes its features; the returned
    # path is not needed here.
    for idx in range(len(dataset)):
        _ = dataset[idx]
|
|
|
if __name__ == "__main__":
    # Command-line entry point: read the three path options and run extraction.
    cli = argparse.ArgumentParser(
        description="Extract features from genomic aligned tiles."
    )

    # (flag, default value, help text) for every string-valued option.
    option_specs = (
        ('--data_dir', 'plip_preprocess/',
         'Directory containing the pre processed patient data.'),
        ('--save_dir', 'omics_align_features/',
         'Directory to save the extracted features.'),
        ('--model_path', './save_model/omics_plip.pth',
         'Path to the trained model file.'),
    )
    for flag, default_value, help_text in option_specs:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)

    options = cli.parse_args()

    extract_features(options.data_dir, options.save_dir, options.model_path)
|
|