import argparse
import operator
import os
import traceback
from collections import Counter
from os.path import join

import country_converter as coco
import joblib
import numpy as np
import pandas as pd
import PIL
import reverse_geocoder
import torch
from tqdm import tqdm
from torch.utils.data import Dataset
from transformers import CLIPProcessor, CLIPModel

from utils import haversine
class GeoDataset(Dataset):
    """Images plus ground-truth GPS coordinates read from a CSV annotation file."""

    def __init__(self, image_folder, annotation_file, tag="image_id"):
        self.image_folder = image_folder
        gt = pd.read_csv(annotation_file, dtype={tag: str})
        # Keep only the annotations whose image is actually present on disk.
        files = set(f.replace(".jpg", "") for f in os.listdir(image_folder))
        gt = gt[gt[tag].isin(files)]
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.gt = [
            (g[1][tag], g[1]["latitude"], g[1]["longitude"]) for g in gt.iterrows()
        ]
        self.tag = tag

    def fid(self, i):
        return self.gt[i][0]

    def latlon(self, i):
        return self.gt[i][1], self.gt[i][2]

    def __len__(self):
        return len(self.gt)

    def __getitem__(self, idx):
        fp = join(self.image_folder, self.gt[idx][0] + ".jpg")
        pil = PIL.Image.open(fp)
        proc = self.processor(images=pil, return_tensors="pt")
        proc["image_id"] = self.gt[idx][0]
        return proc
@torch.no_grad()
def compute_features_clip(img, model):
    """Return (image_ids, L2-normalized CLIP image embeddings) for one batch."""
    image_ids = img.data.pop("image_id")
    image_input = img.to(model.device)
    image_input["pixel_values"] = image_input["pixel_values"].squeeze(1)
    features = model.get_image_features(**image_input)
    features /= features.norm(dim=-1, keepdim=True)
    return image_ids, features.cpu()
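# A minimal sketch of how the per-image .npy feature files consumed in the
# evaluation loop below could be produced with the pieces above; the file
# layout and the loop itself are assumptions, not part of this script (note
# that GeoDataset hardcodes the openai/clip-vit-base-patch32 processor, which
# has to match the CLIP model actually used for extraction):
#
#   data = GeoDataset("test", "test.csv", tag="id")
#   for i in tqdm(range(len(data)), desc="Extracting features"):
#       fid, feats = compute_features_clip(data[i], model)
#       np.save(join("features-test", fid + ".npy"), feats.squeeze(0).numpy())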
def get_prompts(country, region, sub_region, city):
    """Build hierarchical location prompts: country, +region, +sub-region, +city.

    A level is None when its own field is empty (a missing region also prunes
    the two levels below it). The city-level prompt is built even when
    sub_region is empty, leaving a ", , " that is cleaned up before embedding.
    """
    a = country if country != "" else None
    b, c, d = None, None, None
    if a is not None:
        b = country + ", " + region if region != "" else None
        if b is not None:
            c = (
                country + ", " + region + ", " + sub_region
                if sub_region != ""
                else None
            )
            d = (
                country + ", " + region + ", " + sub_region + ", " + city
                if city != ""
                else None
            )
    return a, b, c, d
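# For example, get_prompts("France", "Ile-de-France", "Paris", "Paris") gives
#   ("France",
#    "France, Ile-de-France",
#    "France, Ile-de-France, Paris",
#    "France, Ile-de-France, Paris, Paris"),
# while get_prompts("France", "Ile-de-France", "", "Paris") gives
#   ("France", "France, Ile-de-France", None, "France, Ile-de-France, , Paris").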
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--annotation_file", type=str, required=False, default="train.csv"
    )
    parser.add_argument(
        "--features_parent", type=str, default="/home/isig/gaia-v2/faiss/street-clip"
    )
    parser.add_argument(
        "--data_parent", type=str, default="/home/isig/gaia-v2/loic-data/"
    )
    args = parser.parse_args()

    test_path_csv = join(args.data_parent, "test.csv")
    test_image_dir = join(args.data_parent, "test")
    save_path = join(args.features_parent, "indexes/test.index")
    test_features_dir = join(args.features_parent, "indexes/features-test")

    # StreetCLIP: a CLIP model trained for open-domain image geolocalization.
    processor = CLIPProcessor.from_pretrained("geolocal/StreetCLIP")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = CLIPModel.from_pretrained("geolocal/StreetCLIP").to(device)
    @torch.no_grad()
    def compute_text_features_clip(text):
        """Return the L2-normalized StreetCLIP text embedding as a numpy vector."""
        text_pt = processor(text=text, return_tensors="pt").to(device)
        features = model.get_text_features(**text_pt)
        features /= features.norm(dim=-1, keepdim=True)
        return features.cpu().squeeze(0).numpy()
if not os.path.isfile("text_street-clip-features.pkl"): | |
if not os.path.isfile("rg_cities1000.csv"): | |
os.system( | |
"wget https://raw.githubusercontent.com/thampiman/reverse-geocoder/master/reverse_geocoder/rg_cities1000.csv" | |
) | |
cities = pd.read_csv("rg_cities1000.csv") | |
cities = cities[["lat", "lon", "name", "admin1", "admin2", "cc"]] | |
reprs = {0: {}, 1: {}, 2: {}, 3: {}} | |
for line in tqdm( | |
cities.iterrows(), total=len(cities), desc="Creating hierarchy" | |
): | |
lat, lon, city, region, sub_region, cc = line[1] | |
try: | |
city, region, sub_region, cc = [ | |
("" if pd.isna(x) else x) | |
for x in [ | |
city, | |
region, | |
sub_region, | |
coco.convert(cc, to="name_short"), | |
] | |
] | |
a, b, c, d = get_prompts(cc, region, sub_region, city) | |
if a is not None: | |
if a not in reprs[0]: | |
reprs[0][a] = { | |
"gps": {(lat, lon)}, | |
"embedding": compute_text_features_clip(a), | |
} | |
else: | |
reprs[0][a]["gps"].add((lat, lon)) | |
if b is not None: | |
if b not in reprs[1]: | |
reprs[1][b] = { | |
"gps": {(lat, lon)}, | |
"embedding": compute_text_features_clip(b), | |
} | |
else: | |
reprs[1][b]["gps"].add((lat, lon)) | |
if c is not None: | |
if c not in reprs[2]: | |
reprs[2][c] = { | |
"gps": {(lat, lon)}, | |
"embedding": compute_text_features_clip(c), | |
} | |
else: | |
reprs[2][c]["gps"].add((lat, lon)) | |
if d is not None: | |
if d not in reprs[3]: | |
reprs[3][d] = { | |
"gps": {(lat, lon)}, | |
"embedding": compute_text_features_clip( | |
d.replace(", , ", ", ") | |
), | |
} | |
else: | |
reprs[3][d]["gps"].add((lat, lon)) | |
except Exception as e: | |
# print stack trace into file log.txt | |
with open("log.txt", "a") as f: | |
print(traceback.format_exc(), file=f) | |
reprs[-1] = {"": {"gps": (0, 0), "embedding": compute_text_features_clip("")}} | |
# compute mean for gps of all 'a' and 'b' and 'c' and 'd' | |
for i in range(4): | |
for k in reprs[i].keys(): | |
reprs[i][k]["gps"] = tuple( | |
np.array(list(reprs[i][k]["gps"])).mean(axis=0).tolist() | |
) | |
joblib.dump(reprs, "text_street-clip-features.pkl") | |
else: | |
reprs = joblib.load("text_street-clip-features.pkl") | |
    def get_loc(x):
        """Reverse-geocode a (1, 2) lat/lon array into the four prompt levels."""
        location = reverse_geocoder.search(x[0].tolist())[0]
        country = coco.convert(names=location["cc"], to="name_short")
        region = location.get("admin1", "")
        sub_region = location.get("admin2", "")
        city = location.get("name", "")
        return get_prompts(country, region, sub_region, city)
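    # Illustrative only (exact fields depend on the reverse_geocoder database):
    # a point near Paris resolves to something like
    #   {"cc": "FR", "admin1": "Ile-de-France", "admin2": "Paris", "name": "Paris"},
    # which get_prompts turns into the four hierarchy levels above.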
    def matches(embed, level_reprs, control, gt, sw=None):
        """Pick the best-matching prompt at one level of the hierarchy.

        Candidates are optionally restricted to prompts starting with `sw`
        (i.e. children of the parent prediction). The winner is kept only if
        it scores higher than the parent/empty `control` embedding; otherwise
        the control's GPS is returned and the match counts as wrong.
        """
        first_max = max(
            (
                (k, embed.dot(v["embedding"]))
                for k, v in level_reprs.items()
                if sw is None or k.startswith(sw)
            ),
            key=operator.itemgetter(1),
        )
        if first_max[1] > embed.dot(control["embedding"]):
            return level_reprs[first_max[0]]["gps"], gt == first_max[0]
        else:
            return control["gps"], False
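    # All embeddings are unit-norm, so the dot products above are cosine
    # similarities. A toy illustration with hypothetical level-0 entries:
    #
    #   countries = {"France": reprs[0]["France"], "Spain": reprs[0]["Spain"]}
    #   gps, correct = matches(embed, countries, reprs[-1][""], gt="France")
    #   # -> (mean GPS of the winner, True) iff "France" both wins the argmax
    #   #    and scores above the empty-prompt control.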
    def get_match_values(gt, embed, N, pos):
        """Match one image embedding at every hierarchy level and update counters."""
        xa, xb, xc, xd = get_loc(gt)
        if xa is not None:
            N["country"] += 1
            gps, flag = matches(embed, reprs[0], reprs[-1][""], xa)
            if flag:
                pos["country"] += 1
            if xb is not None:
                N["region"] += 1
                gps, flag = matches(embed, reprs[1], reprs[0][xa], xb, sw=xa)
                if flag:
                    pos["region"] += 1
                if xc is not None:
                    N["sub-region"] += 1
                    gps, flag = matches(embed, reprs[2], reprs[1][xb], xc, sw=xb)
                    if flag:
                        pos["sub-region"] += 1
                    if xd is not None:
                        N["city"] += 1
                        gps, flag = matches(embed, reprs[3], reprs[2][xc], xd, sw=xc)
                        if flag:
                            pos["city"] += 1
                elif xd is not None:
                    # Empty sub-region: match city prompts against the region's
                    # children directly.
                    N["city"] += 1
                    gps, flag = matches(embed, reprs[3], reprs[1][xb], xd, sw=xb + ", ")
                    if flag:
                        pos["city"] += 1
        # Score the finest predicted GPS against the ground truth.
        haversine(np.array(gps)[None, :], np.array(gt), N, pos)
    def compute_print_accuracy(N, pos):
        """Normalize the counters into rates and pretty-print them as percentages."""
        for k in N.keys():
            pos[k] /= N[k]
        print(
            f'Accuracy: {pos["country"]*100.0:.2f} (country), {pos["region"]*100.0:.2f} (region), '
            f'{pos["sub-region"]*100.0:.2f} (sub-region), {pos["city"]*100.0:.2f} (city)'
        )
        print(
            f'Haversine: {pos["haversine"]:.2f} (haversine), {pos["geoguessr"]:.2f} (geoguessr)'
        )
    test_gt = pd.read_csv(test_path_csv, dtype={"id": str})[
        ["id", "latitude", "longitude"]
    ]
    test_gt = {
        g[1]["id"]: np.array([g[1]["latitude"], g[1]["longitude"]])
        for g in tqdm(test_gt.iterrows(), total=len(test_gt), desc="Loading test_gt")
    }

    # Restrict evaluation to the image ids listed in test3_indices.txt.
    with open("/home/isig/gaia-v2/loic/plonk/test3_indices.txt", "r") as f:
        lines = set(l.strip() for l in f.readlines())

    N, pos = Counter(), Counter()
    for f in tqdm(os.listdir(test_features_dir)):
        if f.replace(".npy", "") not in lines:
            continue
        query_vector = np.squeeze(np.load(join(test_features_dir, f)))
        test_gps = test_gt[f.replace(".npy", "")][None, :]
        get_match_values(test_gps, query_vector, N, pos)
    compute_print_accuracy(N, pos)
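    # Example invocation (script name and paths are placeholders for this setup):
    #   python eval_streetclip.py --data_parent /path/to/data \
    #       --features_parent /path/to/faiss/street-clip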