File size: 4,750 Bytes
3f1124e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import time
import requests
from io import BytesIO
from os import path
from torch.utils.data import Dataset
from PIL import Image

class TestImageSetOnline(Dataset):
    """Test image set that downloads images and feeds them to a CLIP processor.

    Each item is fetched over HTTP with ``requests``, decoded with PIL,
    converted to RGB when needed and passed through ``processor`` (a Hugging
    Face CLIP style preprocessing callable).  Failed attempts are retried
    after a sleep whose length grows multiplicatively on every failure and
    shrinks again on every success.

    Args:
        Dataset (torch.utils.data.Dataset): base dataset class.
    """

    # PIL modes that are converted to RGB before preprocessing
    # ('L' is grayscale, 'CMYK' uses alternative color channels).
    _NON_RGB_MODES = ('L', 'CMYK', 'RGBA')

    def __init__(self, processor, image_list, timeout_base=0.5, timeout_mul=2,
                 max_retries=None):
        """
        Args:
            processor (CLIP preprocessor): callable invoked as
                ``processor(text=None, images=img, ...)`` whose result is
                indexed with ``'pixel_values'``.
            image_list: image metadata, indexed by integer position; every
                record is subscripted with ``'coco_url'`` and ``'id'``.
                NOTE(review): previously documented as a pandas.DataFrame, but
                positional indexing and the ``+=`` concatenation in
                ``__add__`` look like a list of dicts -- confirm with callers.
            timeout_base (float, optional): initial (and minimum) back-off
                sleep in seconds. Defaults to 0.5.
            timeout_mul (int, optional): multiplier applied to the back-off
                sleep every time a request fails. Defaults to 2.
            max_retries (int, optional): number of retries allowed after a
                failed attempt before the last exception is re-raised.
                Defaults to None, which retries forever.
        """
        self.image_list = image_list
        self.processor = processor
        self.timeout_base = timeout_base
        # Despite its name this is the sleep between retries, not a request
        # timeout; it is shared across items so back-off persists between calls.
        self.timeout = self.timeout_base
        self.timeout_mul = timeout_mul
        self.max_retries = max_retries

    def _fetch(self, _id, url):
        """Download ``url``, preprocess the image and retry on any failure.

        Args:
            _id (str): identifier used in log lines and returned unchanged.
            url (str): address of the image to download.

        Returns:
            tuple: ``(_id, url, pixel_values, size)`` where ``pixel_values`` is
            the first entry of the processor's ``'pixel_values'`` output and
            ``size`` is the ``size`` attribute of the decoded PIL image.

        Raises:
            Exception: the last error, only when ``max_retries`` is set and
            has been exhausted.
        """
        failures = 0
        while True:
            try:
                # Get images online
                response = requests.get(url)
                # Surface HTTP errors (404, 503, ...) directly instead of
                # letting PIL fail on an error page body.
                response.raise_for_status()
                picture = Image.open(BytesIO(response.content))
                img_s = picture.size
                if picture.mode in self._NON_RGB_MODES:
                    picture = picture.convert('RGB')
                # Preprocess image
                # NOTE(review): 'return_tensor' looks like a typo for the
                # Hugging Face 'return_tensors' keyword; left unchanged because
                # fixing it would alter the returned array type -- confirm.
                ret = self.processor(text=None, images=picture, return_tensor='pt')
                pixels = ret['pixel_values'][0]
            except Exception as e:
                print(f"{_id} {url}: {str(e)}")
                failures += 1
                if self.max_retries is not None and failures > self.max_retries:
                    raise
                time.sleep(self.timeout)
                # Tension the timeout param and turn into a new request
                self.timeout *= self.timeout_mul
            else:
                # Relieve the timeout param after a success
                if self.timeout > self.timeout_base:
                    self.timeout /= self.timeout_mul
                return _id, url, pixels, img_s

    def __getitem__(self, index):
        """Fetch and preprocess the image described by ``image_list[index]``.

        Returns:
            tuple: ``(id, url, pixel_values, size)`` with ``id`` and ``url``
            taken from the record's ``'id'`` and ``'coco_url'`` entries.
        """
        row = self.image_list[index]
        return self._fetch(str(row['id']), str(row['coco_url']))

    def get(self, url):
        """Fetch and preprocess a single image by URL.

        The URL doubles as the identifier, so the returned tuple is
        ``(url, url, pixel_values, size)``.
        """
        return self._fetch(url, url)

    def __len__(self):
        """Return the number of records in ``image_list``."""
        return len(self.image_list)

    def __add__(self, other):
        """Append ``other.image_list`` to this dataset and return ``self``.

        The operation is in place: the left operand (and, for a list, the
        list object it holds) is mutated; no new dataset is created.
        """
        self.image_list += other.image_list
        return self
    
class TestImageSet(TestImageSetOnline):
    """Variant of the online test set whose items are read from local disk.

    The image file is looked up under ``droot`` using the path component of
    the record's ``'coco_url'`` (e.g. ``<droot>/val2017/xxx.jpg``).
    """

    def __init__(self, droot, processor, image_list, timeout_base=0.5, timeout_mul=2):
        """
        Args:
            droot (str): root directory that holds the local image files.
            processor (CLIP preprocessor): forwarded to the parent class.
            image_list: image metadata, forwarded to the parent class.
            timeout_base (float, optional): forwarded to the parent class.
                Defaults to 0.5.
            timeout_mul (int, optional): forwarded to the parent class.
                Defaults to 2.
        """
        super().__init__(processor, image_list, timeout_base, timeout_mul)
        self.droot = droot

    def __getitem__(self, index):
        """Load and preprocess the local copy of ``image_list[index]``.

        Returns:
            tuple: ``(id, url, pixel_values, size)`` where ``id`` is
            ``"<parent url segment>_<row id>"`` and ``size`` is the ``size``
            attribute of the opened PIL image.
        """
        # Function-scope import keeps this block self-contained.
        from urllib.parse import urlparse

        row = self.image_list[index]
        url = str(row['coco_url'])
        _id = '_'.join([url.split('/')[-2], str(row['id'])])
        # Use the URL's path component so both http:// and https:// (or a
        # mirror host) map to the same relative file under droot.
        rel_path = urlparse(url).path.lstrip('/')
        # Read the image from local disk; `with` closes the file handle even
        # if preprocessing raises.
        with Image.open(path.join(self.droot, rel_path)) as source:
            img_s = source.size
            if source.mode in ['L', 'CMYK', 'RGBA']:
                # L is grayscale, CMYK uses alternative color channels
                picture = source.convert('RGB')
            else:
                picture = source
            # Preprocess image
            # NOTE(review): 'return_tensor' looks like a typo for the Hugging
            # Face 'return_tensors' keyword; left unchanged -- confirm.
            ret = self.processor(text=None, images=picture, return_tensor='pt')
        return _id, url, ret['pixel_values'][0], img_s