# feature_extractor/build_graph_utils.py
import os

import numpy as np
import torch
import torchvision.transforms.functional as VF
from PIL import Image
from torch.utils.data import DataLoader, Dataset

class ToPIL(object):
    """Convert a tensor or ndarray to a PIL image."""
    def __call__(self, sample):
        return VF.to_pil_image(sample)

class BagDataset(Dataset):
    """Dataset over a list of patch image paths, each resized to 224x224."""
    def __init__(self, csv_file, transform=None):
        self.files_list = csv_file  # list of patch image paths
        self.transform = transform

    def __len__(self):
        return len(self.files_list)

    def __getitem__(self, idx):
        img = Image.open(self.files_list[idx])
        img = img.resize((224, 224))
        sample = {'input': img}
        if self.transform:
            sample = self.transform(sample)
        return sample

class ToTensor(object):
    """Convert the PIL image in a sample dict to a CHW float tensor."""
    def __call__(self, sample):
        return {'input': VF.to_tensor(sample['input'])}

class Compose(object):
    """Apply a sequence of transforms to a sample in order."""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            img = t(img)
        return img
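
# Illustrative example of the pipeline used in bag_dataset() below:
# a 224x224 RGB PIL image becomes a (3, 224, 224) float tensor.
# pipeline = Compose([ToTensor()])
# sample = pipeline({'input': Image.new('RGB', (224, 224))})
# sample['input'].shape  # torch.Size([3, 224, 224])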

def save_coords(txt_file, csv_file_path):
    """Write the (x, y) grid coordinates encoded in each patch filename
    (expected pattern <x>_<y>.<ext>) to an open text file, one pair per line."""
    for path in csv_file_path:
        x, y = os.path.splitext(os.path.basename(path))[0].split('_')
        txt_file.write(x + '\t' + y + '\n')
    txt_file.close()
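
# Illustrative example (hypothetical paths): a patch saved as
# 'patches/slide1/3_7.jpeg' appends the line '3\t7' to the file.
# f = open('coords.txt', 'w')
# save_coords(f, ['patches/slide1/3_7.jpeg'])  # save_coords closes f itself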

def adj_matrix(csv_file_path, output, device='cpu'):
    """Build a binary spatial-adjacency matrix over patches: two patches are
    connected when their grid coordinates differ by at most 1 in both x and y
    (8-neighborhood). `output` is unused in this function."""
    total = len(csv_file_path)
    adj_s = np.zeros((total, total))
    for i in range(total - 1):
        x_i, y_i = os.path.splitext(os.path.basename(csv_file_path[i]))[0].split('_')
        for j in range(i + 1, total):
            # spatial neighbors: Chebyshev distance <= 1 on the patch grid
            x_j, y_j = os.path.splitext(os.path.basename(csv_file_path[j]))[0].split('_')
            if abs(int(x_i) - int(x_j)) <= 1 and abs(int(y_i) - int(y_j)) <= 1:
                adj_s[i][j] = 1
                adj_s[j][i] = 1
    return torch.from_numpy(adj_s).to(device)
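
# Illustrative example (hypothetical filenames): patches 0_0 and 0_1 are grid
# neighbors, 5_5 is isolated. Note the diagonal stays 0 (no self-loops).
# adj_matrix(['0_0.jpeg', '0_1.jpeg', '5_5.jpeg'], None)
# -> tensor([[0., 1., 0.],
#            [1., 0., 0.],
#            [0., 0., 0.]], dtype=torch.float64)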

def bag_dataset(args, csv_file_path):
    """Wrap a list of patch paths in a BagDataset and return an
    order-preserving DataLoader plus the number of patches."""
    transformed_dataset = BagDataset(csv_file=csv_file_path,
                                     transform=Compose([ToTensor()]))
    dataloader = DataLoader(transformed_dataset, batch_size=args.batch_size,
                            shuffle=False, num_workers=args.num_workers,
                            drop_last=False)
    return dataloader, len(transformed_dataset)
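
# Minimal usage sketch (hypothetical paths and args; assumes patches on disk
# named <x>_<y>.jpeg). Any object with batch_size/num_workers attributes works:
# import glob
# from types import SimpleNamespace
# args = SimpleNamespace(batch_size=64, num_workers=4)
# patch_paths = sorted(glob.glob('patches/slide1/*.jpeg'))
# loader, n_patches = bag_dataset(args, patch_paths)
# adj = adj_matrix(patch_paths, None)  # (n_patches, n_patches) tensor
# for batch in loader:
#     imgs = batch['input']  # (B, 3, 224, 224) if patches are RGB
#     ...  # feed to a feature extractor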