File size: 2,921 Bytes
3df3a47
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
# Luke Melas-Kyriazi's code. (https://twitter.com/lukemelas)

#%% 
import sys
import os
from datetime import datetime
import pandas as pd
import contexttimer
from urllib.request import urlopen
import requests
from PIL import Image
import torch
from torchvision.transforms import functional as TF
from multiprocessing import Pool
from tqdm import tqdm
import logging

# Setup: record per-URL failures in download.log (overwritten on every run),
# and mute the InsecureRequestWarning that urllib3 would otherwise emit for
# every request made with verify=False.
logging.basicConfig(
    level=logging.INFO,
    filemode='w',
    filename='download.log',
)
requests.packages.urllib3.disable_warnings(
    category=requests.packages.urllib3.exceptions.InsecureRequestWarning,
)


# # For downloading SVG images (I can't get this to work)
# from io import BytesIO
# import cairosvg

#%%
# Load data
print(f'Starting to load at {datetime.now().isoformat(timespec="minutes")}')
with contexttimer.Timer(prefix="Loading from tsv"):
    # Headerless, tab-separated: column 0 = image URL, column 1 = caption.
    # NOTE(review): read_csv's default quoting treats a field that begins with
    # '"' as quoted, so a caption starting with a quote could merge rows.
    # Changing it would renumber every row (and thus every image id), so it is
    # deliberately left as is here.
    df = pd.read_csv('./cc12m.tsv', header=None, sep='\t')

# URL -> row index; the row index becomes the image id / filename prefix.
# A URL that appears more than once keeps the index of its last occurrence.
url_to_idx_map = dict(
    (url, row) for row, url, _caption in df.itertuples()
)
print(f'Loaded {len(url_to_idx_map)} urls')

#%%
df.head()

#%%
# Count SVG URLs in a random sample of rows.
# Column 0 holds the URL (column 1 is the caption). regex=False makes '.svg'
# a literal substring; as a regex, '.' would match any character.
# case=False also catches '.SVG'. Sampling is capped at len(df) so that
# small files do not raise.
# NOTE: re-run this before relying on the claim that there are no SVG images.
df.sample(min(10_000, len(df)))[0].str.contains('.svg', case=False, regex=False).sum()

#%% 
# Resize helper
def resize(img, *, max_size_of_short_side=512):
    """Downscale *img* so that its shorter side is at most *max_size_of_short_side*.

    Images already within the limit are returned unchanged (never upscaled).
    Larger ones are passed to ``TF.resize`` with an int ``size``, which matches
    the shorter edge to that value while keeping the aspect ratio.

    Args:
        img: A PIL image (anything exposing a ``.size`` pair of dimensions).
        max_size_of_short_side: Maximum length in pixels of the shorter side.
            Defaults to 512, the limit used by this download script.

    Returns:
        Either *img* itself, or a new LANCZOS-resampled image.
    """
    if min(img.size) > max_size_of_short_side:
        img = TF.resize(img, size=max_size_of_short_side, interpolation=Image.LANCZOS)
    return img

# Output directory for downloaded images. It is created up front because,
# if it were missing, image.save() in process() would fail for every URL and
# the failures would only show up in the log.
base_dir = os.path.join(os.getcwd(), 'images')
os.makedirs(base_dir, exist_ok=True)

def process(item):
    """Download one image, shrink it, and save it as a JPEG under ``base_dir``.

    Args:
        item: ``(url, image_id)`` pair, as produced by ``url_to_idx_map.items()``.
            ``image_id`` is formatted with ``:08d``, so it must be an int.

    Returns:
        None. Files that already exist are skipped, so the script can be
        re-run to resume. Any failure (network error, HTTP error status,
        undecodable image, write error) is logged and swallowed, so one bad
        URL does not take down a pool worker.
    """
    url, image_id = item
    try:
        # Filename = zero-padded row index + the URL's last path segment (minus extension).
        stem = os.path.splitext(os.path.basename(url))[0]
        filename = f'{image_id:08d}---{stem}.jpg'
        filepath = os.path.join(base_dir, filename)
        if os.path.isfile(filepath):
            return  # already downloaded by a previous run
        # Use the response as a context manager so the connection is released
        # even when decoding fails. Reading .raw directly bypasses the normal
        # body consumption that would otherwise release it.
        with requests.get(url, stream=True, timeout=1, verify=False) as response:
            # Reject 4xx/5xx responses instead of saving an error page or placeholder.
            response.raise_for_status()
            # .raw is the undecoded stream; have urllib3 undo any gzip/deflate
            # Content-Encoding before PIL reads it.
            response.raw.decode_content = True
            # Image.open is lazy; convert() forces the full read while the stream is open.
            image = Image.open(response.raw).convert('RGB')
        image = resize(image)  # cap the shorter side (512 px by default)
        image.save(filepath)  # format inferred from the .jpg extension
    except Exception as e:
        # Exception and URL are logged as two records, preserving the existing
        # download.log format.
        logging.info(" ".join(repr(e).splitlines()))
        logging.error(url)

#%%
# Download everything with a process pool for speed.
# To debug, call process(item) serially on a few entries of url_to_idx_map.items().

# Items before this offset are skipped. Presumably an earlier run already
# covered them (confirm); adjust to resume elsewhere. process() also skips
# files that already exist.
START_INDEX = 10_000_000
NUM_WORKERS = 128
# Hand items out in small batches to cut per-task IPC overhead over millions of URLs.
CHUNKSIZE = 16

if __name__ == '__main__':
    # The guard stops workers started with the "spawn" method (default on
    # macOS/Windows) from trying to start the pool again when they import this
    # module. NOTE(review): the module-level TSV load would still be repeated
    # in each spawned worker.
    list_of_items = list(url_to_idx_map.items())
    print(len(list_of_items))
    list_of_items = list_of_items[START_INDEX:]
    print(len(list_of_items))
    with Pool(NUM_WORKERS) as p:
        # process() always returns None, so completion order is irrelevant.
        # The unordered variant advances the progress bar as soon as any item
        # finishes, instead of waiting on the slowest earlier one.
        progress = tqdm(
            p.imap_unordered(process, list_of_items, chunksize=CHUNKSIZE),
            total=len(list_of_items),
        )
        for _ in progress:
            pass
        print('DONE')