DeLoFi / data /data.py
tronskel's picture
init
71026d8 verified
import numpy as np
import pandas as pd
import torch
import random
from os.path import join
from os.path import isfile
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
from torchvision.transforms import (
Compose,
RandomCrop,
CenterCrop,
RandomHorizontalFlip,
ToTensor,
)
import time
from torchvision.transforms import GaussianBlur
from torchvision import transforms
def normalize(lat, lon):
"""Used to put all lat lon inside ±90 and ±180."""
lat = (lat + 90) % 360 - 90
if lat > 90:
lat = 180 - lat
lon += 180
lon = (lon + 180) % 360 - 180
return lat, lon
def collate_fn(batch):
"""Collate function for the dataloader.
Args:
batch (list): list of dictionaries with keys "img", "gps", "idx" and optionally "label"
Returns:
dict: dictionary with keys "img", "gps", "idx" and optionally "label"
"""
keys = list(batch[0].keys())
if "weight" in batch[0].keys():
keys.remove("weight")
output = {}
for key in [
"idx",
"unique_country",
"unique_region",
"unique_sub-region",
"unique_city",
"img_idx",
"text",
]:
if key in keys:
idx = [x[key] for x in batch]
output[key] = idx
keys.remove(key)
for key in keys:
if not ("text" in key):
output[key] = torch.stack([x[key] for x in batch])
return output
def collate_fn_streetclip(batch):
"""Collate function for the dataloader.
Args:
batch (list): list of dictionaries with keys "img", "gps", "idx" and optionally "label"
Returns:
dict: dictionary with keys "img", "gps", "idx" and optionally "label"
"""
keys = list(batch[0].keys())
if "weight" in batch[0].keys():
keys.remove("weight")
output = {}
for key in [
"idx",
"unique_country",
"unique_region",
"unique_sub-region",
"unique_city",
"img_idx",
"img",
"text",
]:
if key in keys:
idx = [x[key] for x in batch]
output[key] = idx
keys.remove(key)
for key in keys:
if not ("text" in key):
output[key] = torch.stack([x[key] for x in batch])
return output
def collate_fn_denstity(batch):
"""Collate function for the dataloader.
Args:
batch (list): list of dictionaries with keys "img", "gps", "idx" and optionally "label"
Returns:
dict: dictionary with keys "img", "gps", "idx" and optionally "label"
"""
keys = list(batch[0].keys())
if "weight" in batch[0].keys():
keys.remove("weight")
# Sample indices based on the weights
weights = np.array([x["weight"] for x in batch])
normalized_weights = weights / np.sum(weights)
sampled_indices = np.random.choice(
len(batch), size=len(batch), p=normalized_weights, replace=True
)
output = {}
for key in [
"idx",
"unique_country",
"unique_region",
"unique_sub-region",
"unique_city",
"img_idx",
"text",
]:
if key in keys:
idx = [batch[i][key] for i in sampled_indices]
output[key] = idx
keys.remove(key)
for key in keys:
if not ("text" in key):
output[key] = torch.stack([batch[i][key] for i in sampled_indices])
return output
def collate_fn_streetclip_denstity(batch):
"""Collate function for the dataloader.
Args:
batch (list): list of dictionaries with keys "img", "gps", "idx" and optionally "label"
Returns:
dict: dictionary with keys "img", "gps", "idx" and optionally "label"
"""
keys = list(batch[0].keys())
if "weight" in batch[0].keys():
keys.remove("weight")
# Sample indices based on the weights
weights = np.array([x["weight"] for x in batch])
normalized_weights = weights / np.sum(weights)
sampled_indices = np.random.choice(
len(batch), size=len(batch), p=normalized_weights, replace=True
)
output = {}
for key in [
"idx",
"unique_country",
"unique_region",
"unique_sub-region",
"unique_city",
"img_idx",
"img",
"text",
]:
if key in keys:
idx = [batch[i][key] for i in sampled_indices]
output[key] = idx
keys.remove(key)
for key in keys:
if not ("text" in key):
output[key] = torch.stack([batch[i][key] for i in sampled_indices])
return output
def collate_fn_contrastive(batch):
"""Collate function for the dataloader.
Args:
batch (list): list of dictionaries with keys "img", "gps", "idx" and optionally "label"
Returns:
dict: dictionary with keys "img", "gps", "idx" and optionally "label"
"""
output = collate_fn(batch)
pos_img = torch.stack([x["pos_img"] for x in batch])
output["pos_img"] = pos_img
return output
def collate_fn_contrastive_density(batch):
"""Collate function for the dataloader.
Args:
batch (list): list of dictionaries with keys "img", "gps", "idx" and optionally "label"
Returns:
dict: dictionary with keys "img", "gps", "idx" and optionally "label"
"""
keys = list(batch[0].keys())
if "weight" in batch[0].keys():
keys.remove("weight")
# Sample indices based on the weights
weights = np.array([x["weight"] for x in batch])
normalized_weights = weights / np.sum(weights)
sampled_indices = np.random.choice(
len(batch), size=len(batch), p=normalized_weights, replace=True
)
output = {}
for key in [
"idx",
"unique_country",
"unique_region",
"unique_sub-region",
"unique_city",
"img_idx",
]:
if key in keys:
idx = [batch[i][key] for i in sampled_indices]
output[key] = idx
keys.remove(key)
for key in keys:
if not ("text" in key):
output[key] = torch.stack([batch[i][key] for i in sampled_indices])
return output
class osv5m(Dataset):
csv_dtype = {"category": str, "country": str, "city": str} # Don't remove.
def __init__(
self,
path,
transforms,
split="train",
class_name=None,
aux_data=[],
is_baseline=False,
areas=["country", "region", "sub-region", "city"],
streetclip=False,
suff="",
blur=False
):
"""Initializes the dataset.
Args:
path (str): path to the dataset
transforms (torchvision.transforms): transforms to apply to the images
split (str): split to use (train, val, test)
class_name (str): category to use (e.g. "city")
aux_data (list of str): auxilliary datas to use
areas (list of str): regions to perform accuracy
streetclip (bool): if the model is streetclip, do not use transform
suff (str): suffix of test csv
blur (bool): blur bottom of images or not
"""
self.suff = suff
self.path = path
self.aux = len(aux_data) > 0
self.aux_list = aux_data
self.split = split
if split == "select":
self.df = self.load_split(split)
split = "test"
else:
self.df = self.load_split(split)
self.split = split
self.image_folder = join(
path,
'images',
("train" if split == "val" else split),
)
self.dict_names = {}
for root, _, files in os.walk(self.image_folder):
for file in files:
self.dict_names[file] = os.path.join(root, file)
self.is_baseline = is_baseline
if self.aux:
self.aux_data = {}
for col in self.aux_list:
if col in ["land_cover", "climate", "soil"]:
self.aux_data[col] = pd.get_dummies(self.df[col], dtype=float)
if col == "climate":
for i in range(31):
if not (i in list(self.aux_data[col].columns)):
self.aux_data[col][i] = 0
desired_order = [i for i in range(31)]
desired_order.remove(20)
self.aux_data[col] = self.aux_data[col][desired_order]
else:
self.aux_data[col] = self.df[col].apply(lambda x: [x])
self.areas = ["_".join(["unique", area]) for area in areas]
if class_name is None:
self.class_name = class_name
elif "quadtree" in class_name:
self.class_name = class_name
else:
self.class_name = "_".join(["unique", class_name])
ex = self.extract_classes(self.class_name)
self.df = self.df[
["id", "latitude", "longitude", "weight"] + self.areas + ex
].fillna("NaN")
if self.class_name in self.areas:
self.df.columns = list(self.df.columns)[:-1] + [self.class_name + "_2"]
self.transforms = transforms
self.collate_fn = collate_fn
self.collate_fn_density = collate_fn_denstity
self.blur = blur
self.streetclip = streetclip
if self.streetclip:
self.collate_fn = collate_fn_streetclip
self.collate_fn_density = collate_fn_streetclip_denstity
def load_split(self, split):
"""Returns a new dataset with the given split."""
start_time = time.time()
if split == "test":
df = pd.read_csv(join(self.path, "test.csv"), dtype=self.csv_dtype)
# extract coord
longitude = df["longitude"].values
latitude = df["latitude"].values
# Create bins
num_bins = 100
lon_bins = np.linspace(longitude.min(), longitude.max(), num_bins)
lat_bins = np.linspace(latitude.min(), latitude.max(), num_bins)
# compute density and weights
hist, _, _ = np.histogram2d(longitude, latitude, bins=[lon_bins, lat_bins])
weights = 1.0 / np.power(hist[df["lon_bin"], df["lat_bin"]], 0.75)
normalized_weights = weights / np.sum(weights)
df["weight"] = normalized_weights
return df
elif split == "select":
df = pd.read_csv(
join(self.path, "select.csv"), dtype=self.csv_dtype
)
# extract coord
longitude = df["longitude"].values
latitude = df["latitude"].values
# Create bins
num_bins = 100
lon_bins = np.linspace(longitude.min(), longitude.max(), num_bins)
lat_bins = np.linspace(latitude.min(), latitude.max(), num_bins)
# compute density and weights
hist, _, _ = np.histogram2d(longitude, latitude, bins=[lon_bins, lat_bins])
weights = 1.0 / np.power(hist[df["lon_bin"], df["lat_bin"]], 0.75)
normalized_weights = weights / np.sum(weights)
df["weight"] = normalized_weights
return df
else:
if len(self.suff) == 0:
df = pd.read_csv(
join(self.path, "train.csv"), dtype=self.csv_dtype
)
else:
df = pd.read_csv(
join(self.path, "train" + "_" + self.suff + ".csv"),
dtype=self.csv_dtype,
)
# extract coord
longitude = df["longitude"].values
latitude = df["latitude"].values
# Create bins
num_bins = 100
lon_bins = np.linspace(longitude.min(), longitude.max(), num_bins)
lat_bins = np.linspace(latitude.min(), latitude.max(), num_bins)
# compute density and weights
hist, _, _ = np.histogram2d(longitude, latitude, bins=[lon_bins, lat_bins])
weights = 1.0 / np.power(hist[df["lon_bin"], df["lat_bin"]], 0.75)
normalized_weights = weights / np.sum(weights)
df["weight"] = normalized_weights
test_df = df.sample(
n=int(0.1 * len(df)),
weights=normalized_weights,
replace=False,
random_state=42,
)
end_time = time.time()
print(f"Loading {split} dataset took {(end_time - start_time):.2f} seconds")
if split == "val":
return test_df
else:
return df.drop(test_df.index)
def extract_classes(self, tag=None):
"""Extracts the categories from the dataset."""
if tag is None:
self.has_labels = False
return []
splits = ["train", "test"] if self.is_baseline else ["train"]
# splits = ["train", "test"]
print(f"Loading categories from {splits}")
# concatenate all categories from relevant splits to find the unique ones.
self.categories = sorted(
pd.concat(
[
pd.read_csv(join(self.path, f"{split}.csv"))[tag]
for split in splits
]
)
.fillna("NaN")
.unique()
.tolist()
)
if "NaN" in self.categories:
self.categories.remove("NaN")
if self.split != "test":
self.df = self.df.dropna(subset=[tag])
# compute the total number of categories - this name is fixed and will be used as a lookup during init
self.num_classes = len(self.categories)
# create a mapping from category to index
self.category_to_index = {
category: i for i, category in enumerate(self.categories)
}
self.has_labels = True
return [tag]
def __getitem__(self, i):
"""Returns an item from the dataset.
Args:
i (int): index of the item
Returns:
dict: dictionary with keys "img", "gps", "idx" and optionally "label"
"""
x = list(self.df.iloc[i]) # id, latitude, longitude, {category}
if self.streetclip:
img = Image.open(self.dict_names[f"{int(x[0])}.jpg"])
elif self.blur:
img = transforms.ToTensor()(Image.open(self.dict_names[f"{int(x[0])}.jpg"]))
u = GaussianBlur(kernel_size = 13, sigma=2.0)
bottom_part = img[:, -14:, :].unsqueeze(0)
blurred_bottom = u(bottom_part)
img[:, -14:, :] = blurred_bottom.squeeze()
img = self.transforms(transforms.ToPILImage()(img))
else:
img = self.transforms(
Image.open(self.dict_names[f"{int(x[0])}.jpg"])
)
lat, lon = normalize(x[1], x[2])
gps = torch.FloatTensor([np.radians(lat), np.radians(lon)]).squeeze(0)
output = {
"img": img,
"gps": gps,
"idx": i,
"img_idx": int(x[0]),
"weight": x[3],
}
for count, area in enumerate(self.areas):
output[area] = x[
count + 4
] #'country': x[3], 'region': x[4], 'sub-region': x[5], 'city': x[6]}
if self.has_labels:
if x[-1] in self.categories:
output["label"] = torch.LongTensor(
[self.category_to_index[x[-1]]]
).squeeze(-1)
else:
output["label"] = torch.LongTensor([-1]).squeeze(-1)
if self.aux:
for col in self.aux_list:
output[col] = torch.FloatTensor(self.aux_data[col].iloc[i])
return output
def __len__(self):
return len(self.df)
class Contrastiveosv5m(osv5m):
def __init__(
self,
path,
transforms,
split="train",
class_name=None,
aux_data=[],
class_name2=None,
blur=False,
):
"""
class_name2 (str): if not None, we do contrastive an other class than the one specified for classif
"""
super().__init__(
path,
transforms,
split=split,
class_name=class_name,
aux_data=aux_data,
blur=blur,
)
self.add_label = False
if not(class_name2 is None) and split != 'test' and split != 'select':
self.add_label = True
self.class_name = class_name2
self.extract_classes_contrastive(tag=class_name2)
self.df = self.df.reset_index(drop=True)
self.dict_classes = {
value: indices.tolist()
for value, indices in self.df.groupby(self.class_name).groups.items()
}
self.collate_fn = collate_fn_contrastive
self.random_crop = RandomCrop(224) # use when no positive image is available
def sample_positive(self, i):
"""
sample positive image from the same city, country if it is available
otherwise, apply different crop to the image
"""
x = self.df.iloc[i] # id, latitude, longitude, {category}
class_name = x[self.class_name]
idxs = self.dict_classes[class_name]
idxs.remove(i)
if len(idxs) > 0:
idx = random.choice(idxs)
x = self.df.iloc[idx]
pos_img = self.transforms(
Image.open(self.dict_names[f"{int(x['id'])}.jpg"])
)
else:
pos_img = self.random_crop(
self.transforms(
Image.open(self.dict_names[f"{int(x['id'])}.jpg"])
)
)
return pos_img
def extract_classes_contrastive(self, tag=None):
"""Extracts the categories from the dataset."""
if tag is None:
self.has_labels = False
return []
splits = ["train", "test"] if self.is_baseline else ["train"]
# splits = ["train", "test"]
print(f"Loading categories from {splits}")
# concatenate all categories from relevant splits to find the unique ones.
categories = sorted(
pd.concat(
[
pd.read_csv(join(self.path, f"{split}.csv"))[tag]
for split in splits
]
)
.fillna("NaN")
.unique()
.tolist()
)
# create a mapping from category to index
self.contrastive_category_to_index = {
category: i for i, category in enumerate(categories)
}
def __getitem__(self, i):
output = super().__getitem__(i)
pos_img = self.sample_positive(i)
output["pos_img"] = pos_img
if self.add_label:
output["label_contrastive"] = torch.LongTensor(
[self.contrastive_category_to_index[self.df[self.class_name].iloc[i]]]
).squeeze(-1)
return output
class TextContrastiveosv5m(osv5m):
def __init__(
self,
path,
transforms,
split="train",
class_name=None,
aux_data=[],
blur=False,
):
super().__init__(
path,
transforms,
split=split,
class_name=class_name,
aux_data=aux_data,
blur=blur,
)
self.df = self.df.reset_index(drop=True)
def get_text(self, i):
"""
sample positive image from the same city, country if it is available
otherwise, apply different crop to the image
"""
x = self.df.iloc[i] # id, latitude, longitude, {category}
l = [
name.split("_")[-1]
for name in [
x["unique_city"],
x["unique_sub-region"],
x["unique_region"],
x["unique_country"],
]
]
pre = False
sentence = "An image of "
if l[0] != "NaN":
sentence += "the city of "
sentence += l[0]
pre = True
if l[1] != "NaN":
if pre:
sentence += ", in "
sentence += "the area of "
sentence += l[1]
pre = True
if l[2] != "NaN":
if pre:
sentence += ", in "
sentence += "the region of "
sentence += l[2]
pre = True
if l[3] != "NaN":
if pre:
sentence += ", in "
sentence += l[3]
return sentence
def __getitem__(self, i):
output = super().__getitem__(i)
output["text"] = self.get_text(i)
return output
import os
import json
class Baseline(Dataset):
def __init__(
self,
path,
which,
transforms,
):
"""Initializes the dataset.
Args:
path (str): path to the dataset
which (str): which baseline to use (im2gps, im2gps3k)
transforms (torchvision.transforms): transforms to apply to the images
"""
baselines = {
"im2gps": self.load_im2gps,
"im2gps3k": self.load_im2gps,
"yfcc4k": self.load_yfcc4k,
}
self.path = path
self.samples = baselines[which]()
self.transforms = transforms
self.collate_fn = collate_fn
self.class_name = which
def load_im2gps(
self,
):
json_path = join(self.path, "info.json")
with open(json_path) as f:
data = json.load(f)
samples = []
for f in os.listdir(join(self.path, "images")):
if len(data[f]):
lat = float(data[f][-4].replace("latitude: ", ""))
lon = float(data[f][-3].replace("longitude: ", ""))
samples.append((f, lat, lon))
return samples
def load_yfcc4k(
self,
):
samples = []
with open(join(self.path, "info.txt")) as f:
lines = f.readlines()
for line in lines:
x = line.split("\t")
f, lon, lat = x[1], x[12], x[13]
samples.append((f + ".jpg", float(lat), float(lon)))
return samples
def __getitem__(self, i):
"""Returns an item from the dataset.
Args:
i (int): index of the item
Returns:
dict: dictionary with keys "img", "gps", "idx" and optionally "label"
"""
img_path, lat, lon = self.samples[i]
img = self.transforms(
Image.open(join(self.path, "images", img_path)).convert("RGB")
)
lat, lon = normalize(lat, lon)
gps = torch.FloatTensor([np.radians(lat), np.radians(lon)]).squeeze(0)
return {
"img": img,
"gps": gps,
"idx": i,
}
def __len__(self):
return len(self.samples)