Fangrui Liu
init repo
3f1124e
import copy
import time
from io import BytesIO
from os import path
from urllib.parse import urlparse

import requests
from PIL import Image
from torch.utils.data import Dataset
class TestImageSetOnline(Dataset):
    """Test image set that downloads images on demand and preprocesses them
    with a Hugging Face CLIP processor.

    Args:
        Dataset (torch.utils.data.Dataset): base class.
    """

    def __init__(self, processor, image_list, timeout_base=0.5, timeout_mul=2,
                 max_retries=None, request_timeout=None):
        """
        Args:
            processor (CLIP processor): callable invoked as
                ``processor(text=None, images=img, return_tensors='pt')`` that
                returns a mapping with a ``'pixel_values'`` entry.
            image_list (sequence of dict-like rows): indexed by position; each
                row must provide ``'coco_url'`` and ``'id'``. NOTE(review): a
                pandas.DataFrame does not work here because ``image_list[i]``
                selects a column -- pass e.g. ``df.to_dict('records')``.
            timeout_base (float, optional): initial back-off in seconds slept
                after a failed attempt. Defaults to 0.5.
            timeout_mul (int, optional): factor the back-off is multiplied by
                after every failure and divided by after a success.
                Defaults to 2.
            max_retries (int, optional): number of retries per image before the
                last error is re-raised. ``None`` (default) retries forever.
            request_timeout (float, optional): ``timeout`` forwarded to
                ``requests.get``. ``None`` (default) waits indefinitely.
        """
        self.image_list = image_list
        self.processor = processor
        self.timeout_base = timeout_base
        # Current back-off. It grows on failures, shrinks on successes, and is
        # shared by every item loaded through this instance.
        self.timeout = self.timeout_base
        self.timeout_mul = timeout_mul
        self.max_retries = max_retries
        self.request_timeout = request_timeout

    def _preprocess(self, img):
        """Convert ``img`` to RGB if needed and return its CLIP pixel values."""
        if img.mode != 'RGB':
            # Grayscale ('L'), CMYK, RGBA, palette ('P'), ... must become 3-channel.
            img = img.convert('RGB')
        ret = self.processor(text=None, images=img, return_tensors='pt')
        return ret['pixel_values'][0]

    def _fetch(self, _id, url):
        """Download ``url`` and preprocess it, retrying with exponential back-off.

        ``_id`` is only used in the failure log line.

        Returns:
            tuple: ``(pixel_values, size)`` where ``size`` is the PIL
            ``(width, height)`` of the downloaded image before preprocessing.

        Raises:
            Exception: the last download/decoding/preprocessing error once
                ``max_retries`` retries are exhausted (never when
                ``max_retries`` is ``None``).
        """
        retries = 0
        while True:
            try:
                response = requests.get(url, timeout=self.request_timeout)
                # Report HTTP errors (404, 5xx) directly instead of letting PIL
                # fail on an HTML error page.
                response.raise_for_status()
                img = Image.open(BytesIO(response.content))
                img_s = img.size
                pixel_values = self._preprocess(img)
            except Exception as e:
                # KeyboardInterrupt is not an Exception subclass, so Ctrl-C
                # already propagates without special handling here.
                print(f"{_id} {url}: {str(e)}")
                if self.max_retries is not None and retries >= self.max_retries:
                    raise
                retries += 1
                time.sleep(self.timeout)
                # Back off further before the next attempt.
                self.timeout *= self.timeout_mul
            else:
                # Relax the back-off after a success.
                if self.timeout > self.timeout_base:
                    self.timeout /= self.timeout_mul
                return pixel_values, img_s

    def __getitem__(self, index):
        """Download and preprocess the image described by ``image_list[index]``.

        Returns:
            tuple: ``(id, url, pixel_values, size)``. ``id`` and ``url`` are
            strings taken from the row. ``size`` is the PIL
            ``(width, height)`` before preprocessing.
        """
        row = self.image_list[index]
        url = str(row['coco_url'])
        _id = str(row['id'])
        img, img_s = self._fetch(_id, url)
        return _id, url, img, img_s

    def get(self, url):
        """Like ``__getitem__`` but for an arbitrary ``url``.

        The url also serves as the returned id: ``(url, url, pixel_values, size)``.
        """
        img, img_s = self._fetch(url, url)
        return url, url, img, img_s

    def __len__(self):
        """Return the number of rows in ``image_list``."""
        return len(self.image_list)

    def __add__(self, other):
        """Return a new dataset with the rows of ``self`` followed by ``other``.

        Neither operand, nor the row sequences they hold, is modified. All
        other attributes (processor, back-off state, ...) are copied shallowly
        from ``self``.
        """
        merged = copy.copy(self)
        merged.image_list = self.image_list + other.image_list
        return merged
class TestImageSet(TestImageSetOnline):
    """Variant of ``TestImageSetOnline`` that reads images from a local copy of
    the image tree instead of downloading them.

    Each image is loaded from ``droot`` joined with the path component of its
    ``coco_url``. For example, ``http://host/val2017/x.jpg`` maps to
    ``droot/val2017/x.jpg``.
    """

    def __init__(self, droot, processor, image_list, timeout_base=0.5, timeout_mul=2):
        """
        Args:
            droot (str): root directory mirroring the URL paths of the images.
            processor, image_list, timeout_base, timeout_mul: forwarded to
                ``TestImageSetOnline``. The back-off settings are not used by
                this class's ``__getitem__``, which reads from disk.
        """
        super().__init__(processor, image_list, timeout_base, timeout_mul)
        self.droot = droot

    def __getitem__(self, index):
        """Load and preprocess the local image for ``image_list[index]``.

        Returns:
            tuple: ``(id, url, pixel_values, size)``. ``id`` is
            ``'<parent directory in url>_<row id>'``. ``size`` is the PIL
            ``(width, height)`` before preprocessing.
        """
        row = self.image_list[index]
        url = str(row['coco_url'])
        _id = '_'.join([url.split('/')[-2], str(row['id'])])
        # Map the URL path onto the local tree; works for http and https alike.
        local_path = path.join(self.droot, urlparse(url).path.lstrip('/'))
        # The context manager closes the file once the pixels have been consumed.
        with Image.open(local_path) as img:
            img_s = img.size
            # Any non-RGB mode (grayscale, CMYK, RGBA, palette, ...) -> 3 channels.
            rgb = img if img.mode == 'RGB' else img.convert('RGB')
            ret = self.processor(text=None, images=rgb, return_tensors='pt')
        return _id, url, ret['pixel_values'][0], img_s