File size: 2,386 Bytes
70884da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import os
import torch
from torch.utils.data import Dataset
from pathlib import Path
import argparse
from scripts.genomic_plip_model import GenomicPLIPModel
from transformers import CLIPVisionModel

class PatientTileDataset(Dataset):
    """Dataset over per-patient ``.pt`` tile files that extracts features on access.

    ``data_dir`` is scanned one level deep: every sub-directory is treated as
    a patient, and every ``*.pt`` file inside it becomes one item. Indexing an
    item loads the file, runs ``model`` on its first tile, saves the returned
    vision features under ``save_dir/<patient_id>/<file name>`` and returns
    that output path.
    """

    def __init__(self, data_dir, model, save_dir):
        """Index every ``.pt`` file found in the patient sub-directories.

        Args:
            data_dir: Root directory holding one sub-directory per patient.
            model: Callable invoked as
                ``model(pixel_values=..., score_vector=...)``; it must return
                a 2-tuple whose first element is the tensor to save.
            save_dir: Root directory the extracted features are written to.
        """
        super().__init__()
        self.data_dir = data_dir
        self.model = model
        self.save_dir = Path(save_dir)
        self.files = []
        # os.listdir() yields entries in arbitrary, filesystem-dependent
        # order; sort both levels so the index -> file mapping is reproducible.
        for patient_id in sorted(os.listdir(data_dir)):
            patient_dir = os.path.join(data_dir, patient_id)
            if not os.path.isdir(patient_dir):
                # Stray files directly under data_dir are not patients.
                continue
            self.files.extend(
                (os.path.join(patient_dir, name), patient_id)
                for name in sorted(os.listdir(patient_dir))
                if name.endswith('.pt')
            )

    def __len__(self):
        """Return the number of indexed ``.pt`` files."""
        return len(self.files)

    def __getitem__(self, idx):
        """Extract, save and return the feature path for item ``idx``.

        Args:
            idx: Position in ``self.files``.

        Returns:
            ``pathlib.Path`` of the feature file that was written.

        Raises:
            IndexError: If ``idx`` is outside ``self.files``.
        """
        file_path, patient_id = self.files[idx]
        # NOTE(review): torch.load unpickles the file, so only point this at
        # trusted data. Newer PyTorch releases may default to
        # weights_only=True and reject the NumPy payload -- confirm against
        # the installed version.
        data = torch.load(file_path)
        # Only the first entry of 'tile_data' is used; from_numpy requires it
        # to be a NumPy array.
        tile_data = torch.from_numpy(data['tile_data'][0]).unsqueeze(0)  # Add batch dimension
        with torch.no_grad():
            # A zero (1, 4) score vector is passed and the second output is
            # discarded -- presumably a placeholder because only the vision
            # branch matters here; TODO confirm against GenomicPLIPModel.
            vision_features, _ = self.model(pixel_values=tile_data, score_vector=torch.zeros(1, 4))
        # Mirror the input layout: <save_dir>/<patient_id>/<original file name>.
        feature_path = self.save_dir / patient_id / os.path.basename(file_path)
        feature_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(vision_features, feature_path)
        return feature_path

def extract_features(data_dir, save_dir, model_path):
    """Run the trained GenomicPLIP model over every tile file and save its features.

    Builds a ``GenomicPLIPModel`` around the CLIP vision model loaded from
    ``./plip/``, restores the weights stored at ``model_path``, then visits
    every item of a ``PatientTileDataset``; accessing an item performs the
    extraction and writes the result below ``save_dir``.

    Args:
        data_dir: Directory containing the per-patient ``.pt`` tile files.
        save_dir: Directory the extracted features are written to.
        model_path: Path to the saved ``state_dict`` of the trained model.
    """
    original_model = CLIPVisionModel.from_pretrained("./plip/")
    custom_model = GenomicPLIPModel(original_model)
    # Nothing in this function moves the model to a GPU, so map the checkpoint
    # storages to CPU: a state dict saved from a CUDA device would otherwise
    # fail to load on a machine without CUDA. load_state_dict copies into the
    # model's own parameters, so the resulting weights are unchanged.
    state_dict = torch.load(model_path, map_location="cpu")
    custom_model.load_state_dict(state_dict)
    custom_model.eval()

    dataset = PatientTileDataset(data_dir=data_dir, model=custom_model, save_dir=save_dir)
    # Indexing an item triggers extraction and saving; the returned path is
    # not needed. Iterate by explicit index rather than relying on the
    # implicit __getitem__/IndexError iteration fallback.
    for idx in range(len(dataset)):
        _ = dataset[idx]

if __name__ == "__main__":
    # Command-line entry point: every option is a string path with a default,
    # so the options are declared as (flag, default, help) rows and registered
    # in a loop.
    cli = argparse.ArgumentParser(description="Extract features from genomic aligned tiles.")
    option_rows = (
        ('--data_dir', 'plip_preprocess/', 'Directory containing the pre processed patient data.'),
        ('--save_dir', 'omics_align_features/', 'Directory to save the extracted features.'),
        ('--model_path', './save_model/omics_plip.pth', 'Path to the trained model file.'),
    )
    for flag, default_value, help_text in option_rows:
        cli.add_argument(flag, type=str, default=default_value, help=help_text)

    options = cli.parse_args()

    extract_features(options.data_dir, options.save_dir, options.model_path)