Fangrui Liu
init repo
3f1124e
raw
history blame
No virus
4.75 kB
import time
import requests
from io import BytesIO
from os import path
from torch.utils.data import Dataset
from PIL import Image
class TestImageSetOnline(Dataset):
    """Image dataset that downloads each image and preprocesses it for CLIP.

    Images are fetched over HTTP at access time and passed through a
    Hugging Face CLIP style processor. A failed download or preprocessing
    step is retried without limit; between attempts the dataset sleeps for
    ``self.timeout`` seconds. That delay is multiplied by ``timeout_mul``
    after every failure and divided by it after every success for as long
    as it is above ``timeout_base``.
    """

    def __init__(self, processor, image_list, timeout_base=0.5, timeout_mul=2):
        """
        Args:
            processor (CLIP preprocessor): callable that turns a PIL image
                into a CLIP digestible format; its result is indexed with
                ``['pixel_values']``.
            image_list: image metadata, indexed by integer position; every
                row must provide ``'coco_url'`` and ``'id'``.
                NOTE(review): earlier documented as ``pandas.DataFrame``,
                but the integer indexing and ``+=`` merging used here look
                like a list of dict-like rows -- confirm against callers.
            timeout_base (float, optional): initial delay in seconds between
                retries. Defaults to 0.5.
            timeout_mul (int, optional): multiplier applied to the delay
                every time a request fails. Defaults to 2.
        """
        self.image_list = image_list
        self.processor = processor
        self.timeout_base = timeout_base
        # Current retry delay; shared by all items and adjusted by _fetch().
        self.timeout = self.timeout_base
        self.timeout_mul = timeout_mul

    def _fetch(self, _id, url):
        """Download ``url`` and preprocess it, retrying until it succeeds.

        Args:
            _id: identifier used in failure messages and returned unchanged.
            url: address of the image to download.

        Returns:
            tuple: ``(_id, url, pixel_values, size)`` where ``pixel_values``
            is the first entry of the processor's ``'pixel_values'`` and
            ``size`` is the PIL size of the downloaded image.
        """
        txt = None
        while True:
            try:
                # NOTE(review): no request timeout is set, so a stalled
                # connection blocks here indefinitely.
                response = requests.get(url)
                # Report HTTP errors as such instead of letting PIL fail on
                # an error page; the retry path below is the same.
                response.raise_for_status()
                img = Image.open(BytesIO(response.content))
                img_s = img.size
                if img.mode in ('L', 'CMYK', 'RGBA'):
                    # L is grayscale, CMYK uses alternative color channels
                    img = img.convert('RGB')
                # NOTE(review): Hugging Face processors name this keyword
                # `return_tensors`; kept as written because changing it
                # would change the type of the returned pixel values.
                ret = self.processor(text=txt, images=img, return_tensor='pt')
                pixels = ret['pixel_values'][0]
            except Exception as e:
                # KeyboardInterrupt is a BaseException, so it is never
                # caught here and still aborts the loop.
                print(f"{_id} {url}: {str(e)}")
                time.sleep(self.timeout)
                # Back off further before the next attempt.
                self.timeout *= self.timeout_mul
            else:
                # Success: relax the delay again, one step per success.
                if self.timeout > self.timeout_base:
                    self.timeout /= self.timeout_mul
                return _id, url, pixels, img_s

    def __getitem__(self, index):
        """Fetch and preprocess the image described by row ``index``.

        Returns:
            tuple: ``(id, url, pixel_values, size)`` with ``id`` and ``url``
            converted to ``str``.
        """
        row = self.image_list[index]
        url = str(row['coco_url'])
        _id = str(row['id'])
        return self._fetch(_id, url)

    def get(self, url):
        """Fetch and preprocess an arbitrary ``url``.

        The URL itself is used as the identifier, so the result is
        ``(url, url, pixel_values, size)``.
        """
        return self._fetch(url, url)

    def __len__(self):
        """Return the number of rows in ``image_list``."""
        return len(self.image_list)

    def __add__(self, other):
        """Merge ``other.image_list`` into this dataset and return ``self``.

        The merge uses ``+=`` on ``self.image_list``, so this dataset is
        modified in place rather than copied.
        """
        self.image_list += other.image_list
        return self
class TestImageSet(TestImageSetOnline):
    """Variant of ``TestImageSetOnline`` that reads images from local disk.

    The file for each row is located by dropping the
    ``http://images.cocodataset.org/`` prefix from its ``coco_url`` and
    joining the remainder onto ``droot``.
    """

    def __init__(self, droot, processor, image_list, timeout_base=0.5, timeout_mul=2):
        """
        Args:
            droot (str): root directory under which the image files are
                looked up.
            processor (CLIP preprocessor): forwarded to the parent class.
            image_list: image metadata, forwarded to the parent class.
            timeout_base (float, optional): forwarded to the parent class;
                not used by this class's ``__getitem__``. Defaults to 0.5.
            timeout_mul (int, optional): forwarded to the parent class;
                not used by this class's ``__getitem__``. Defaults to 2.
        """
        super().__init__(processor, image_list, timeout_base, timeout_mul)
        self.droot = droot

    def __getitem__(self, index):
        """Load and preprocess the image described by row ``index``.

        Returns:
            tuple: ``(id, url, pixel_values, size)`` where ``id`` is the
            second-to-last URL path segment joined with the row id by
            ``'_'`` and ``size`` is the PIL size of the image on disk.

        Raises:
            IndexError: if ``coco_url`` does not contain the
                ``http://images.cocodataset.org/`` prefix.
        """
        row = self.image_list[index]
        url = str(row['coco_url'])
        _id = '_'.join([url.split('/')[-2], str(row['id'])])
        txt = None
        # Read the image from the local copy instead of downloading it.
        rel_path = url.split('http://images.cocodataset.org/')[1]
        # The context manager closes the underlying file even when
        # preprocessing raises.
        with Image.open(path.join(self.droot, rel_path)) as raw:
            img_s = raw.size
            if raw.mode in ('L', 'CMYK', 'RGBA'):
                # L is grayscale, CMYK uses alternative color channels
                img = raw.convert('RGB')
            else:
                img = raw
            # Preprocess while the file is still open.
            # NOTE(review): Hugging Face processors name this keyword
            # `return_tensors`; kept as written because changing it would
            # change the type of the returned pixel values.
            ret = self.processor(text=txt, images=img, return_tensor='pt')
        return _id, url, ret['pixel_values'][0], img_s