Plonk / scripts / preprocessing / nearest-neighbors.py
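"""Precompute image features for OSV-5M with either CLIP ViT-L/14 (laion2B) or
DINOv2-base, and optionally retrieve, for each test image, the top-k most similar
train images by cosine similarity. Features and neighbors are streamed to JSON
files under <json_path>/<which>/.
"""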
import os
import sys
import json
from PIL import Image
from tqdm import tqdm
from os.path import dirname, join

sys.path.append(dirname(dirname(__file__)))

import torch
from transformers import AutoImageProcessor, AutoModel
from transformers import CLIPProcessor, CLIPModel
from data.data import osv5m
from json_stream import streamable_list

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
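

# Feature extractors: both loaders return (processor, model) with the model on DEVICE.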
def load_model_clip():
    model = CLIPModel.from_pretrained("laion/CLIP-ViT-L-14-laion2B-s32B-b82K")
    processor = CLIPProcessor.from_pretrained("laion/CLIP-ViT-L-14-laion2B-s32B-b82K")
    return processor, model.to(DEVICE)


def load_model_dino():
    model = AutoModel.from_pretrained("facebook/dinov2-base")
    processor = AutoImageProcessor.from_pretrained("facebook/dinov2-base")
    return processor, model.to(DEVICE)
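

# Each compute_* generator embeds a batch of PIL images and yields one
# [feature, lat, lon, id] record per image: DINOv2 keeps the full last_hidden_state
# token sequence, CLIP uses the L2-normalized pooled image embedding.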
def compute_dino(processor, model, x):
    inputs = processor(images=x[0], return_tensors="pt", device=DEVICE).to(DEVICE)
    outputs = model(**inputs)
    last_hidden_states = outputs.last_hidden_state.cpu().numpy()
    for i in range(len(x[0])):
        yield [last_hidden_states[i].tolist(), x[1][i], x[2][i], x[3][i]]


def compute_clip(processor, model, x):
    inputs = processor(images=x[0], return_tensors="pt", device=DEVICE).to(DEVICE)
    features = model.get_image_features(**inputs)
    features /= features.norm(dim=-1, keepdim=True)
    features = features.cpu().numpy()
    for i in range(len(x[0])):
        yield [features[i].tolist(), x[1][i], x[2][i], x[3][i]]
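

# Walk the dataset dataframe and yield (images, lats, lons, ids) batches;
# the final batch may be smaller than batch_size.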
def get_batch(dataset, batch_size):
    data, lats, lons, ids = [], [], [], []
    for i in range(len(dataset)):
        id, lat, lon = dataset.df.iloc[i]
        data.append(Image.open(join(dataset.image_folder, f"{int(id)}.jpg")))
        lats.append(lat)
        lons.append(lon)
        ids.append(id)
        if len(data) == batch_size:
            yield data, lats, lons, ids
            data, lats, lons, ids = [], [], [], []
    if len(data) > 0:
        yield data, lats, lons, ids
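

# CLI: --compute_features writes <split>.json feature files for the chosen backbone
# (--which); --compute_nearest then ranks train features against each test feature
# and writes nearest.json.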
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--compute_features", action="store_true")
    parser.add_argument("--compute_nearest", action="store_true")
    parser.add_argument("--json_path", default="features")
    parser.add_argument("--which", type=str, default="clip", choices=["clip", "dino"])
    args = parser.parse_args()

    json_path = join(args.json_path, args.which)
    os.makedirs(json_path, exist_ok=True)

    if args.compute_features:
        processor, model = (
            load_model_clip() if args.which == "clip" else load_model_dino()
        )
        compute_fn = compute_clip if args.which == "clip" else compute_dino
        for split in ["test"]:  # add "train" here if train features still need to be computed
            json_path_ = join(json_path, f"{split}.json")
            dataset = osv5m(
                "datasets/osv5m", transforms=None, split=split, dont_split=True
            )
            @torch.no_grad()
            def compute(batch_size):
                for data in tqdm(
                    get_batch(dataset, batch_size),
                    total=len(dataset) // batch_size,
                    desc=f"Computing {split} on {args.which}",
                ):
                    features = compute_fn(processor, model, data)
                    for feature, lat, lon, id in features:
                        yield feature, lat, lon, id

            data = streamable_list(compute(args.batch_size))
            with open(json_path_, "w") as f:
                json.dump(data, f, indent=4)
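
    # Nearest-neighbor retrieval: expects both train.json and test.json under
    # json_path, i.e. train features were computed in an earlier run.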
    if args.compute_nearest:
        from sklearn.metrics.pairwise import cosine_similarity
        import numpy as np

        train, test = [
            json.load(open(join(json_path, f"{split}.json"), "r"))
            for split in ["train", "test"]
        ]
        def get_neighbors(k=10):
            # Stack all train features once; they are reused for every test sample.
            features_train = np.stack(
                [np.array(train_data[0]) for train_data in train]
            )
            for test_data in tqdm(test):
                feature, lat, lon, id = test_data
                cs = np.squeeze(
                    cosine_similarity(np.expand_dims(feature, axis=0), features_train),
                    axis=0,
                )
                top_idx = np.argsort(cs)[-k:][::-1].tolist()
                # Each neighbor contributes single-key dicts in the order
                # feature / lat / lon / id / distance (cosine similarity).
                yield [
                    {n: x}
                    for idx in top_idx
                    for n, x in zip(
                        ["feature", "lat", "lon", "id", "distance"],
                        train[idx] + [cs[idx]],
                    )
                ]

        data = streamable_list(get_neighbors())
        with open(join(json_path, "nearest.json"), "w") as f:
            json.dump(data, f, indent=4)