import torch import torch.nn as nn from torch.utils import data import os from PIL import Image from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize try: from torchvision.transforms import InterpolationMode BICUBIC = InterpolationMode.BICUBIC except ImportError: BICUBIC = Image.BICUBIC import glob def image_transform(n_px): return Compose([ Resize(n_px, interpolation=BICUBIC), CenterCrop(n_px), ToTensor(), Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), ]) class Image_dataset(data.Dataset): def __init__(self,dataset_folder="/data1/haolin/datasets",categories=['03001627'],n_px=224): self.dataset_folder=dataset_folder self.image_folder=os.path.join(self.dataset_folder,'other_data') self.preprocess=image_transform(n_px) self.image_path=[] for cat in categories: subpath=os.path.join(self.image_folder,cat,"6_images") model_list=os.listdir(subpath) for folder in model_list: model_folder=os.path.join(subpath,folder) image_list=os.listdir(model_folder) for image_filename in image_list: image_filepath=os.path.join(model_folder,image_filename) self.image_path.append(image_filepath) def __len__(self): return len(self.image_path) def __getitem__(self,index): path=self.image_path[index] basename=os.path.basename(path)[:-4] model_id=path.split(os.sep)[-2] category=path.split(os.sep)[-4] image=Image.open(path) image_tensor=self.preprocess(image) return {"images":image_tensor,"image_name":basename,"model_id":model_id,"category":category} class Image_InTheWild_dataset(data.Dataset): def __init__(self,dataset_dir="/data1/haolin/data/real_scene_process_data",scene_id="letian-310",n_px=224): self.dataset_dir=dataset_dir self.preprocess = image_transform(n_px) self.image_path = [] if scene_id=="all": scene_list=os.listdir(self.dataset_dir) for id in scene_list: image_folder=os.path.join(self.dataset_dir,id,"6_images") self.image_path+=glob.glob(image_folder+"/*/*jpg") else: image_folder = os.path.join(self.dataset_dir, scene_id, "6_images") self.image_path += glob.glob(image_folder + "/*/*jpg") def __len__(self): return len(self.image_path) def __getitem__(self,index): path=self.image_path[index] basename=os.path.basename(path)[:-4] model_id=path.split(os.sep)[-2] scene_id=path.split(os.sep)[-4] image=Image.open(path) image_tensor=self.preprocess(image) return {"images":image_tensor,"image_name":basename,"model_id":model_id,"scene_id":scene_id}