import json
import os
import random
import time
from os.path import join

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.transforms import GaussianBlur, RandomCrop
def normalize(lat, lon):
    """Wrap latitude into [-90, 90] and longitude into [-180, 180]."""
    lat = (lat + 90) % 360 - 90
    if lat > 90:
        # Wrapping past a pole flips the latitude and shifts the longitude by 180 degrees.
        lat = 180 - lat
        lon += 180
    lon = (lon + 180) % 360 - 180
    return lat, lon
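
# A minimal sanity check for `normalize` (a hypothetical helper added for
# illustration, safe to call directly): wrapping past the north pole flips the
# latitude and shifts the longitude by 180 degrees.
def _demo_normalize():
    assert normalize(0, 0) == (0, 0)           # already in range
    assert normalize(95, 10) == (85, -170)     # crossed the pole: 95 -> 85, 10 -> -170
    assert normalize(-45, 190) == (-45, -170)  # only the longitude wraps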
def collate_fn(batch):
    """Collate function for the dataloader.

    Args:
        batch (list): list of dictionaries with keys "img", "gps", "idx" and optionally "label"

    Returns:
        dict: dictionary with the same keys; metadata keys stay as Python lists,
            tensor-valued keys are stacked along a new batch dimension
    """
    keys = list(batch[0].keys())
    if "weight" in batch[0].keys():
        keys.remove("weight")
    output = {}
    # These keys hold non-tensor metadata (indices, names, text) and are kept as lists.
    for key in [
        "idx",
        "unique_country",
        "unique_region",
        "unique_sub-region",
        "unique_city",
        "img_idx",
        "text",
    ]:
        if key in keys:
            output[key] = [x[key] for x in batch]
            keys.remove(key)
    # Everything else is assumed to be a tensor and is stacked into a batch.
    for key in keys:
        if "text" not in key:
            output[key] = torch.stack([x[key] for x in batch])
    return output
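
# A toy sketch of what `collate_fn` does, using fabricated items (not real
# dataset samples): metadata keys become lists, tensors are stacked to (B, ...).
def _demo_collate_fn():
    batch = [
        {"img": torch.zeros(3, 4, 4), "gps": torch.zeros(2), "idx": i,
         "img_idx": 100 + i, "weight": 0.5}
        for i in range(2)
    ]
    out = collate_fn(batch)
    assert out["img"].shape == (2, 3, 4, 4)  # stacked tensors
    assert out["idx"] == [0, 1]              # metadata kept as a list
    assert "weight" not in out               # "weight" is dropped by design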
def collate_fn_streetclip(batch):
    """Collate function for the dataloader when the model is StreetCLIP.

    Identical to `collate_fn`, except that "img" is also kept as a list
    (StreetCLIP items are raw PIL images, which cannot be stacked).

    Args:
        batch (list): list of dictionaries with keys "img", "gps", "idx" and optionally "label"

    Returns:
        dict: dictionary with the same keys; see `collate_fn`
    """
    keys = list(batch[0].keys())
    if "weight" in batch[0].keys():
        keys.remove("weight")
    output = {}
    for key in [
        "idx",
        "unique_country",
        "unique_region",
        "unique_sub-region",
        "unique_city",
        "img_idx",
        "img",
        "text",
    ]:
        if key in keys:
            output[key] = [x[key] for x in batch]
            keys.remove(key)
    for key in keys:
        if "text" not in key:
            output[key] = torch.stack([x[key] for x in batch])
    return output
def collate_fn_density(batch):
    """Collate function that resamples the batch by density weight.

    Items are redrawn with replacement with probability proportional to their
    "weight" entry, so sparsely covered regions are over-represented.

    Args:
        batch (list): list of dictionaries with keys "img", "gps", "idx",
            "weight" and optionally "label"

    Returns:
        dict: dictionary with the same keys as `collate_fn`, built from the
            resampled items
    """
    keys = list(batch[0].keys())
    if "weight" in batch[0].keys():
        keys.remove("weight")

    # Draw a new batch (with replacement) proportionally to the density weights.
    weights = np.array([x["weight"] for x in batch])
    normalized_weights = weights / np.sum(weights)
    sampled_indices = np.random.choice(
        len(batch), size=len(batch), p=normalized_weights, replace=True
    )
    output = {}
    for key in [
        "idx",
        "unique_country",
        "unique_region",
        "unique_sub-region",
        "unique_city",
        "img_idx",
        "text",
    ]:
        if key in keys:
            output[key] = [batch[i][key] for i in sampled_indices]
            keys.remove(key)
    for key in keys:
        if "text" not in key:
            output[key] = torch.stack([batch[i][key] for i in sampled_indices])
    return output
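
# A toy sketch of the density resampling (fabricated weights, not real data):
# an item with weight 1.0 next to one with weight 0.0 is always the one redrawn.
def _demo_collate_fn_density():
    batch = [
        {"img": torch.zeros(3, 4, 4), "gps": torch.zeros(2), "idx": 0, "weight": 1.0},
        {"img": torch.ones(3, 4, 4), "gps": torch.ones(2), "idx": 1, "weight": 0.0},
    ]
    out = collate_fn_density(batch)
    assert out["idx"] == [0, 0]        # the zero-weight item is never drawn
    assert torch.all(out["img"] == 0)  # both batch slots hold the first image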
def collate_fn_streetclip_density(batch):
    """Collate function combining StreetCLIP handling and density resampling.

    Same as `collate_fn_density`, except that "img" is also kept as a list
    (StreetCLIP items are raw PIL images, which cannot be stacked).

    Args:
        batch (list): list of dictionaries with keys "img", "gps", "idx",
            "weight" and optionally "label"

    Returns:
        dict: dictionary with the same keys, built from the resampled items
    """
    keys = list(batch[0].keys())
    if "weight" in batch[0].keys():
        keys.remove("weight")

    weights = np.array([x["weight"] for x in batch])
    normalized_weights = weights / np.sum(weights)
    sampled_indices = np.random.choice(
        len(batch), size=len(batch), p=normalized_weights, replace=True
    )
    output = {}
    for key in [
        "idx",
        "unique_country",
        "unique_region",
        "unique_sub-region",
        "unique_city",
        "img_idx",
        "img",
        "text",
    ]:
        if key in keys:
            output[key] = [batch[i][key] for i in sampled_indices]
            keys.remove(key)
    for key in keys:
        if "text" not in key:
            output[key] = torch.stack([batch[i][key] for i in sampled_indices])
    return output
def collate_fn_contrastive(batch):
    """Collate function for contrastive training.

    Same as `collate_fn`, with the positive images stacked under "pos_img".

    Args:
        batch (list): list of dictionaries with keys "img", "pos_img", "gps",
            "idx" and optionally "label"

    Returns:
        dict: dictionary with the same keys; see `collate_fn`
    """
    output = collate_fn(batch)
    output["pos_img"] = torch.stack([x["pos_img"] for x in batch])
    return output
def collate_fn_contrastive_density(batch):
    """Collate function for contrastive training with density resampling.

    Same as `collate_fn_density`, but without a "text" key; "pos_img" is a
    tensor and is stacked with the other tensor keys.

    Args:
        batch (list): list of dictionaries with keys "img", "pos_img", "gps",
            "idx", "weight" and optionally "label"

    Returns:
        dict: dictionary with the same keys, built from the resampled items
    """
    keys = list(batch[0].keys())
    if "weight" in batch[0].keys():
        keys.remove("weight")

    weights = np.array([x["weight"] for x in batch])
    normalized_weights = weights / np.sum(weights)
    sampled_indices = np.random.choice(
        len(batch), size=len(batch), p=normalized_weights, replace=True
    )
    output = {}
    for key in [
        "idx",
        "unique_country",
        "unique_region",
        "unique_sub-region",
        "unique_city",
        "img_idx",
    ]:
        if key in keys:
            output[key] = [batch[i][key] for i in sampled_indices]
            keys.remove(key)
    for key in keys:
        if "text" not in key:
            output[key] = torch.stack([batch[i][key] for i in sampled_indices])
    return output
class osv5m(Dataset):
    csv_dtype = {"category": str, "country": str, "city": str}

    def __init__(
        self,
        path,
        transforms,
        split="train",
        class_name=None,
        aux_data=[],
        is_baseline=False,
        areas=["country", "region", "sub-region", "city"],
        streetclip=False,
        suff="",
        blur=False,
    ):
        """Initializes the dataset.

        Args:
            path (str): path to the dataset
            transforms (torchvision.transforms): transforms to apply to the images
            split (str): split to use (train, val, test, select)
            class_name (str): category to use (e.g. "city")
            aux_data (list of str): auxiliary data to use
            is_baseline (bool): if True, also read categories from the test csv
            areas (list of str): administrative levels on which to report accuracy
            streetclip (bool): if the model is StreetCLIP, do not apply transforms
            suff (str): suffix of the train csv
            blur (bool): whether to blur the bottom of the images
        """
        self.suff = suff
        self.path = path
        self.aux = len(aux_data) > 0
        self.aux_list = aux_data
        self.split = split
        if split == "select":
            self.df = self.load_split(split)
            split = "test"
        else:
            self.df = self.load_split(split)
        # From here on, "select" behaves like "test".
        self.split = split
        self.image_folder = join(
            path,
            "images",
            ("train" if split == "val" else split),
        )

        # Index every image file once so __getitem__ can resolve "<id>.jpg" in O(1).
        self.dict_names = {}
        for root, _, files in os.walk(self.image_folder):
            for file in files:
                self.dict_names[file] = os.path.join(root, file)

        self.is_baseline = is_baseline
        if self.aux:
            self.aux_data = {}
            for col in self.aux_list:
                if col in ["land_cover", "climate", "soil"]:
                    self.aux_data[col] = pd.get_dummies(self.df[col], dtype=float)
                    if col == "climate":
                        # Pad the one-hot columns to all 31 climate classes,
                        # then drop class 20, which does not occur.
                        for i in range(31):
                            if i not in list(self.aux_data[col].columns):
                                self.aux_data[col][i] = 0
                        desired_order = list(range(31))
                        desired_order.remove(20)
                        self.aux_data[col] = self.aux_data[col][desired_order]
                else:
                    self.aux_data[col] = self.df[col].apply(lambda x: [x])

        self.areas = ["_".join(["unique", area]) for area in areas]
        if class_name is None or "quadtree" in class_name:
            self.class_name = class_name
        else:
            self.class_name = "_".join(["unique", class_name])
        ex = self.extract_classes(self.class_name)
        self.df = self.df[
            ["id", "latitude", "longitude", "weight"] + self.areas + ex
        ].fillna("NaN")
        if self.class_name in self.areas:
            # The class column duplicates an area column; disambiguate its name.
            self.df.columns = list(self.df.columns)[:-1] + [self.class_name + "_2"]
        self.transforms = transforms
        self.collate_fn = collate_fn
        self.collate_fn_density = collate_fn_density
        self.blur = blur
        self.streetclip = streetclip
        if self.streetclip:
            self.collate_fn = collate_fn_streetclip
            self.collate_fn_density = collate_fn_streetclip_density
    def _add_density_weights(self, df):
        """Attaches inverse-density sampling weights from a 2D lon/lat histogram."""
        longitude = df["longitude"].values
        latitude = df["latitude"].values

        num_bins = 100
        lon_bins = np.linspace(longitude.min(), longitude.max(), num_bins)
        lat_bins = np.linspace(latitude.min(), latitude.max(), num_bins)

        hist, _, _ = np.histogram2d(longitude, latitude, bins=[lon_bins, lat_bins])
        # Map each sample to its histogram cell (the grid has num_bins - 1 cells per axis).
        lon_bin = np.clip(np.digitize(longitude, lon_bins) - 1, 0, num_bins - 2)
        lat_bin = np.clip(np.digitize(latitude, lat_bins) - 1, 0, num_bins - 2)
        # Down-weight densely covered cells; the 0.75 exponent softens the correction.
        weights = 1.0 / np.power(hist[lon_bin, lat_bin], 0.75)
        df["weight"] = weights / np.sum(weights)
        return df

    def load_split(self, split):
        """Loads the dataframe of the given split, with density weights attached."""
        start_time = time.time()
        if split in ("test", "select"):
            return self._add_density_weights(
                pd.read_csv(join(self.path, f"{split}.csv"), dtype=self.csv_dtype)
            )

        csv_name = "train.csv" if len(self.suff) == 0 else f"train_{self.suff}.csv"
        df = self._add_density_weights(
            pd.read_csv(join(self.path, csv_name), dtype=self.csv_dtype)
        )

        # Hold out 10% of the train set (density-weighted, fixed seed) for validation.
        val_df = df.sample(
            n=int(0.1 * len(df)),
            weights=df["weight"],
            replace=False,
            random_state=42,
        )

        end_time = time.time()
        print(f"Loading {split} dataset took {(end_time - start_time):.2f} seconds")

        if split == "val":
            return val_df
        return df.drop(val_df.index)
    def extract_classes(self, tag=None):
        """Extracts the categories from the dataset."""
        if tag is None:
            self.has_labels = False
            return []
        splits = ["train", "test"] if self.is_baseline else ["train"]

        print(f"Loading categories from {splits}")

        self.categories = sorted(
            pd.concat(
                [pd.read_csv(join(self.path, f"{split}.csv"))[tag] for split in splits]
            )
            .fillna("NaN")
            .unique()
            .tolist()
        )

        if "NaN" in self.categories:
            self.categories.remove("NaN")
            if self.split != "test":
                self.df = self.df.dropna(subset=[tag])

        self.num_classes = len(self.categories)

        self.category_to_index = {
            category: i for i, category in enumerate(self.categories)
        }
        self.has_labels = True
        return [tag]
    def __getitem__(self, i):
        """Returns an item from the dataset.

        Args:
            i (int): index of the item

        Returns:
            dict: dictionary with keys "img", "gps", "idx" and optionally "label"
        """
        x = list(self.df.iloc[i])
        if self.streetclip:
            # StreetCLIP consumes raw PIL images; no transforms are applied.
            img = Image.open(self.dict_names[f"{int(x[0])}.jpg"])
        elif self.blur:
            img = transforms.ToTensor()(Image.open(self.dict_names[f"{int(x[0])}.jpg"]))
            # Blur the bottom 14 pixel rows before applying the usual transforms.
            u = GaussianBlur(kernel_size=13, sigma=2.0)
            bottom_part = img[:, -14:, :].unsqueeze(0)
            blurred_bottom = u(bottom_part)
            img[:, -14:, :] = blurred_bottom.squeeze()
            img = self.transforms(transforms.ToPILImage()(img))
        else:
            img = self.transforms(Image.open(self.dict_names[f"{int(x[0])}.jpg"]))

        lat, lon = normalize(x[1], x[2])
        gps = torch.FloatTensor([np.radians(lat), np.radians(lon)]).squeeze(0)

        output = {
            "img": img,
            "gps": gps,
            "idx": i,
            "img_idx": int(x[0]),
            "weight": x[3],
        }

        # Columns 4.. hold the area names, in the same order as self.areas.
        for count, area in enumerate(self.areas):
            output[area] = x[count + 4]

        if self.has_labels:
            if x[-1] in self.categories:
                output["label"] = torch.LongTensor(
                    [self.category_to_index[x[-1]]]
                ).squeeze(-1)
            else:
                # Unknown categories (e.g. "NaN" at test time) get label -1.
                output["label"] = torch.LongTensor([-1]).squeeze(-1)
        if self.aux:
            for col in self.aux_list:
                output[col] = torch.FloatTensor(self.aux_data[col].iloc[i])
        return output

    def __len__(self):
        return len(self.df)
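
# A minimal usage sketch (the dataset root and transform pipeline here are
# assumptions for illustration, not values prescribed by this file): build the
# train split and batch it with the dataset's own collate function.
def _demo_osv5m_loader(path="datasets/osv5m"):
    from torch.utils.data import DataLoader

    train_set = osv5m(
        path,
        transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()]),
        split="train",
        class_name="country",
    )
    loader = DataLoader(
        train_set, batch_size=32, shuffle=True, collate_fn=train_set.collate_fn
    )
    batch = next(iter(loader))
    # batch["img"]: (32, 3, 224, 224); batch["gps"]: (32, 2), in radians.
    return batch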
class Contrastiveosv5m(osv5m):
    def __init__(
        self,
        path,
        transforms,
        split="train",
        class_name=None,
        aux_data=[],
        class_name2=None,
        blur=False,
    ):
        """
        class_name2 (str): if not None, contrast on this class instead of the
            one used for classification
        """
        super().__init__(
            path,
            transforms,
            split=split,
            class_name=class_name,
            aux_data=aux_data,
            blur=blur,
        )
        self.add_label = False
        if class_name2 is not None and split not in ("test", "select"):
            self.add_label = True
            self.class_name = class_name2
            self.extract_classes_contrastive(tag=class_name2)
        self.df = self.df.reset_index(drop=True)
        # Positional indices of the rows belonging to each class.
        self.dict_classes = {
            value: indices.tolist()
            for value, indices in self.df.groupby(self.class_name).groups.items()
        }
        self.collate_fn = collate_fn_contrastive
        self.random_crop = RandomCrop(224)

    def sample_positive(self, i):
        """
        Samples a positive image from the same class (e.g. same city) when one
        is available; otherwise returns a differently cropped view of image i.
        """
        x = self.df.iloc[i]
        class_name = x[self.class_name]
        # Copy the index list: removing i in place would corrupt dict_classes.
        idxs = [j for j in self.dict_classes[class_name] if j != i]

        if len(idxs) > 0:
            idx = random.choice(idxs)
            x = self.df.iloc[idx]
            pos_img = self.transforms(
                Image.open(self.dict_names[f"{int(x['id'])}.jpg"])
            )
        else:
            pos_img = self.random_crop(
                self.transforms(
                    Image.open(self.dict_names[f"{int(x['id'])}.jpg"])
                )
            )
        return pos_img

    def extract_classes_contrastive(self, tag=None):
        """Extracts the contrastive categories from the dataset."""
        if tag is None:
            self.has_labels = False
            return []
        splits = ["train", "test"] if self.is_baseline else ["train"]

        print(f"Loading categories from {splits}")

        categories = sorted(
            pd.concat(
                [pd.read_csv(join(self.path, f"{split}.csv"))[tag] for split in splits]
            )
            .fillna("NaN")
            .unique()
            .tolist()
        )

        self.contrastive_category_to_index = {
            category: i for i, category in enumerate(categories)
        }

    def __getitem__(self, i):
        output = super().__getitem__(i)
        output["pos_img"] = self.sample_positive(i)
        if self.add_label:
            output["label_contrastive"] = torch.LongTensor(
                [self.contrastive_category_to_index[self.df[self.class_name].iloc[i]]]
            ).squeeze(-1)
        return output
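
# A minimal sketch of the contrastive variant (path and transform are again
# assumptions for illustration): each item carries a "pos_img" drawn from the
# same class, here the same city, when one exists.
def _demo_contrastive(path="datasets/osv5m"):
    dataset = Contrastiveosv5m(
        path,
        transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()]),
        split="train",
        class_name="city",
    )
    item = dataset[0]
    # "img" and "pos_img" form a positive pair for the contrastive loss.
    return item["img"], item["pos_img"]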
class TextContrastiveosv5m(osv5m):
    def __init__(
        self,
        path,
        transforms,
        split="train",
        class_name=None,
        aux_data=[],
        blur=False,
    ):
        super().__init__(
            path,
            transforms,
            split=split,
            class_name=class_name,
            aux_data=aux_data,
            blur=blur,
        )
        self.df = self.df.reset_index(drop=True)

    def get_text(self, i):
        """
        Builds a caption describing the location of image i, from the finest
        available level (city) up to the country, skipping missing ("NaN") fields.
        """
        x = self.df.iloc[i]
        names = [
            name.split("_")[-1]
            for name in [
                x["unique_city"],
                x["unique_sub-region"],
                x["unique_region"],
                x["unique_country"],
            ]
        ]

        pre = False
        sentence = "An image of "
        if names[0] != "NaN":
            sentence += "the city of " + names[0]
            pre = True

        if names[1] != "NaN":
            if pre:
                sentence += ", in "
            sentence += "the area of " + names[1]
            pre = True

        if names[2] != "NaN":
            if pre:
                sentence += ", in "
            sentence += "the region of " + names[2]
            pre = True

        if names[3] != "NaN":
            if pre:
                sentence += ", in "
            sentence += names[3]

        return sentence

    def __getitem__(self, i):
        output = super().__getitem__(i)
        output["text"] = self.get_text(i)
        return output
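
# A sketch of the caption format produced by `get_text` (the place names below
# are made-up examples, not dataset entries). With every field present, the
# caption reads:
#   "An image of the city of Lyon, in the area of Rhone, in the region of
#    Auvergne-Rhone-Alpes, in France"
# and any "NaN" field is simply skipped, e.g. with only the country known:
#   "An image of France"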
class Baseline(Dataset):
    def __init__(
        self,
        path,
        which,
        transforms,
    ):
        """Initializes the dataset.

        Args:
            path (str): path to the dataset
            which (str): which baseline to use (im2gps, im2gps3k, yfcc4k)
            transforms (torchvision.transforms): transforms to apply to the images
        """
        baselines = {
            "im2gps": self.load_im2gps,
            "im2gps3k": self.load_im2gps,
            "yfcc4k": self.load_yfcc4k,
        }
        self.path = path
        self.samples = baselines[which]()
        self.transforms = transforms
        self.collate_fn = collate_fn
        self.class_name = which

    def load_im2gps(self):
        json_path = join(self.path, "info.json")
        with open(json_path) as f:
            data = json.load(f)

        samples = []
        for fname in os.listdir(join(self.path, "images")):
            if len(data[fname]):
                # The trailing metadata lines hold "latitude: ..." and "longitude: ...".
                lat = float(data[fname][-4].replace("latitude: ", ""))
                lon = float(data[fname][-3].replace("longitude: ", ""))
                samples.append((fname, lat, lon))

        return samples

    def load_yfcc4k(self):
        samples = []
        with open(join(self.path, "info.txt")) as f:
            lines = f.readlines()
        for line in lines:
            # Tab-separated metadata: column 1 is the photo id, 12/13 are lon/lat.
            x = line.split("\t")
            fname, lon, lat = x[1], x[12], x[13]
            samples.append((fname + ".jpg", float(lat), float(lon)))

        return samples

    def __getitem__(self, i):
        """Returns an item from the dataset.

        Args:
            i (int): index of the item

        Returns:
            dict: dictionary with keys "img", "gps" and "idx"
        """
        img_path, lat, lon = self.samples[i]
        img = self.transforms(
            Image.open(join(self.path, "images", img_path)).convert("RGB")
        )
        lat, lon = normalize(lat, lon)
        gps = torch.FloatTensor([np.radians(lat), np.radians(lon)]).squeeze(0)

        return {
            "img": img,
            "gps": gps,
            "idx": i,
        }

    def __len__(self):
        return len(self.samples)
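
# A minimal usage sketch for the baseline test sets (the root path is an
# assumption for illustration): im2gps-style folders carry an info.json,
# yfcc4k an info.txt, next to an "images" directory.
def _demo_baseline(path="datasets/baselines/im2gps"):
    from torch.utils.data import DataLoader

    dataset = Baseline(
        path,
        "im2gps",
        transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()]),
    )
    loader = DataLoader(dataset, batch_size=32, collate_fn=dataset.collate_fn)
    # Yields {"img": (32, 3, 224, 224), "gps": (32, 2), "idx": [...]}.
    return next(iter(loader))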