# Copied from a Hugging Face Spaces page; the Space status shown there was "Runtime error".
import json
import os
import sys
from multiprocessing import Pool

import h5py
import pandas as pd
import torch
from PIL import Image
from tqdm import tqdm
from transformers import logging
from transformers import CLIPFeatureExtractor, CLIPVisionModel
# Silence transformers' informational/warning output.
logging.set_verbosity_error()

# Input images and output feature files (paths are relative to the CWD).
data_dir = 'data/images/'
features_dir = 'features/'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder_name = 'openai/clip-vit-base-patch32'
# NOTE(review): recent transformers releases deprecate CLIPFeatureExtractor in
# favour of CLIPImageProcessor — confirm against the pinned version.
feature_extractor = CLIPFeatureExtractor.from_pretrained(encoder_name)
# from_pretrained already returns the model in eval mode; calling .eval() makes
# the inference-only intent explicit.
clip_encoder = CLIPVisionModel.from_pretrained(encoder_name).to(device).eval()

# Per-image annotation records; only the 'images' list is used. The `with`
# block closes the file once it is parsed (JSON text is UTF-8 by spec).
with open('data/dataset_coco.json', encoding='utf-8') as _annotation_file:
    annotations = json.load(_annotation_file)['images']
def load_data(items=None):
    """Group annotation records into 'train' and 'val' image lists.

    Args:
        items: iterable of annotation dicts with 'filename', 'split' and
            'cocoid' keys. Defaults to the module-level ``annotations`` list,
            so calling ``load_data()`` with no arguments behaves as before.

    Returns:
        ``{'train': [...], 'val': [...]}``. Each entry is
        ``{'file_name': ..., 'cocoid': ...}``, where ``file_name`` is the part
        of 'filename' after its last underscore. Records whose split is
        'restval' go to 'train'. Records with any other split value
        (e.g. 'test') are skipped.
    """
    if items is None:
        items = annotations

    # Which output list each accepted split value feeds into.
    target_split = {'train': 'train', 'restval': 'train', 'val': 'val'}

    data = {'train': [], 'val': []}
    for item in items:
        split = target_split.get(item['split'])
        if split is None:
            continue
        data[split].append({
            'file_name': item['filename'].split('_')[-1],
            'cocoid': item['cocoid'],
        })
    return data
def _load_rgb(path):
    """Open the image at *path*, return an RGB copy, and close the source file."""
    with Image.open(path) as img:
        # convert() returns a new, fully loaded image, so it stays valid after
        # the source file is closed.
        return img.convert("RGB")


def encode_split(data, split, batch_size=256):
    """Encode one split's images with the CLIP vision encoder and save them to HDF5.

    The output goes to ``features_dir + '<split>.hdf5'``, which is overwritten if
    it exists. Each image's ``last_hidden_state`` is stored as its own dataset,
    named ``str(cocoid)``.

    Args:
        data: mapping from split name to a list of records with 'file_name'
            and 'cocoid' keys (the shape produced by ``load_data``).
        split: key of ``data`` to encode; also used as the output file stem.
        batch_size: number of images per encoder forward pass.
    """
    df = pd.DataFrame(data[split])

    # Create the output directory if needed; h5py does not create parents.
    if features_dir:
        os.makedirs(features_dir, exist_ok=True)
    out_path = features_dir + '{}.hdf5'.format(split)

    # The context manager guarantees the file is flushed and closed, even if
    # encoding fails part-way through.
    with h5py.File(out_path, 'w') as h5py_file:
        for idx in tqdm(range(0, len(df), batch_size)):
            batch = df.iloc[idx:idx + batch_size]
            images = [_load_rgb(data_dir + file_name) for file_name in batch['file_name']]
            with torch.no_grad():
                pixel_values = feature_extractor(images, return_tensors='pt').pixel_values.to(device)
                encodings = clip_encoder(pixel_values=pixel_values).last_hidden_state.cpu().numpy()
            for cocoid, encoding in zip(batch['cocoid'], encodings):
                # Shape is taken from the encoder output rather than hard-coded.
                # For clip-vit-base-patch32 it is (50, 768).
                h5py_file.create_dataset(str(cocoid), data=encoding)
# Script entry: build the split lists once, then write features for the
# 'train' split followed by the 'val' split.
data = load_data()
for split_name in ('train', 'val'):
    encode_split(data, split_name)