denoising / datasets.py
msong97's picture
remove model choices
3a575e4
raw
history blame
3.14 kB
from pathlib import Path
from typing import Callable, Optional
import os
import torch
from torch.utils.data import Dataset
from PIL import Image
class Preprocessed_fastMRI(torch.utils.data.Dataset):
    """FastMRI from preprocessed data for faster loading.

    Every file ending in ``.pt`` found under ``root`` (searched recursively)
    is treated as one sample. Each file is expected to hold a mapping whose
    ``'data'`` entry is a tensor.

    Args:
        root: Directory that is walked recursively for ``.pt`` files.
        transform: Optional callable applied to the float tensor of a sample.
        preprocess: If ``True``, ``__getitem__`` returns ``(img, name)`` where
            ``name`` is the file name without directory and extension;
            otherwise only ``img`` is returned.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
        preprocess: bool = False,
    ) -> None:
        self.root = root
        self.transform = transform
        self.preprocess = preprocess
        # Should contain all the information to load a data sample from the
        # storage: paths of the .pt files, relative to ``self.root``.
        identifiers = []
        for dirpath, _, filenames in os.walk(self.root):
            for filename in filenames:
                if filename.endswith(".pt"):
                    # Keep the sub-directory part of the path; storing only the
                    # bare file name makes files in nested folders unloadable.
                    full_path = os.path.join(dirpath, filename)
                    identifiers.append(os.path.relpath(full_path, self.root))
        # os.walk yields entries in arbitrary order; sort so that an index
        # always refers to the same sample across runs and machines.
        self.sample_identifiers = sorted(identifiers)

    def __len__(self) -> int:
        """Return the number of ``.pt`` samples that were discovered."""
        return len(self.sample_identifiers)

    def __getitem__(self, idx: int):
        """Load sample ``idx`` and return it, optionally with its name.

        Returns:
            The (optionally transformed) ``'data'`` tensor cast to float, or a
            ``(tensor, name)`` tuple when ``self.preprocess`` is true.
        """
        fname = self.sample_identifiers[idx]
        tensor = torch.load(os.path.join(self.root, fname), weights_only=True)
        img = tensor['data'].float()
        if self.transform is not None:
            img = self.transform(img)
        if not self.preprocess:
            return img
        # Remove directory prefix and extension from the file name.
        return img, Path(fname).stem
class Preprocessed_LIDCIDRI(torch.utils.data.Dataset):
    """LIDC-IDRI from preprocessed data for faster loading.

    Every file ending in ``.pt`` found under ``root`` (searched recursively)
    is treated as one sample. Each file is expected to hold a mapping whose
    ``'data'`` entry is a tensor.

    Args:
        root: Directory that is walked recursively for ``.pt`` files.
        transform: Optional callable applied to the float tensor of a sample
            before the leading channel dimension is added.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
    ) -> None:
        self.root = root
        self.transform = transform
        # Should contain all the information to load a data sample from the
        # storage: paths of the .pt files, relative to ``self.root``.
        identifiers = []
        for dirpath, _, filenames in os.walk(self.root):
            for filename in filenames:
                if filename.endswith(".pt"):
                    # Keep the sub-directory part of the path; storing only the
                    # bare file name makes files in nested folders unloadable.
                    full_path = os.path.join(dirpath, filename)
                    identifiers.append(os.path.relpath(full_path, self.root))
        # os.walk yields entries in arbitrary order; sort so that an index
        # always refers to the same sample across runs and machines.
        self.sample_identifiers = sorted(identifiers)

    def __len__(self) -> int:
        """Return the number of ``.pt`` samples that were discovered."""
        return len(self.sample_identifiers)

    def __getitem__(self, idx: int):
        """Load sample ``idx`` and return it with a leading channel dimension."""
        fname = self.sample_identifiers[idx]
        tensor = torch.load(os.path.join(self.root, fname), weights_only=True)
        img = tensor['data'].float()
        if self.transform is not None:
            img = self.transform(img)
        img = img.unsqueeze(0)  # add channel dim
        return img
class LsdirMiniDataset(torch.utils.data.Dataset):
    """Image dataset over the ``.png`` / ``.jpeg`` files directly inside ``root``.

    Only the top level of ``root`` is listed (no recursion) and the extension
    check is case-insensitive. Files with any other extension are ignored.

    Args:
        root: Directory containing the image files.
        transform: Optional callable applied to each loaded RGB ``PIL`` image.
    """

    def __init__(
        self,
        root: str,
        transform: Optional[Callable] = None,
    ) -> None:
        self.root = root
        # os.listdir returns names in arbitrary order; sort so that an index
        # always refers to the same image across runs and machines.
        self.image_files = sorted(
            f for f in os.listdir(self.root) if f.lower().endswith(('.png', '.jpeg'))
        )
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of image files that were discovered."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and return it, transformed if requested."""
        img_path = os.path.join(self.root, self.image_files[idx])
        # Image.open is lazy and keeps the file open; the context manager
        # releases the handle once the pixel data has been converted.
        with Image.open(img_path) as source:
            img = source.convert("RGB")  # Ensure consistent 3-channel format
        if self.transform is not None:
            img = self.transform(img)
        return img