Plonk / scripts /preprocessing /fix_namimbia.py
nicolas-dufour's picture
squash: merge all unpushed commits
c4c7cee
from os.path import join, dirname
import numpy as np
import pandas as pd
if __name__ == "__main__":
# Define the list of cities
cities = [
"Walvis Bay",
"Keetmanshoop",
"Warmbad",
"Rundu",
"Outapi",
"Karibib",
"Otjimbingwe",
"Ondangwa",
"Oranjemund",
"Maltahohe",
"Otavi",
"Outjo",
"Swakopmund",
"Gobabis",
"Karasburg",
"Opuwo",
"Hentiesbaai",
"Katima Mulilo",
"Oshikango",
"Bethanie",
"Ongandjera",
"Mariental",
"Bagani",
"Nkurenkuru",
"Usakos",
"Rehoboth",
"Aranos",
"Omaruru",
"Arandis",
"Windhoek",
"Khorixas",
"Okahandja",
"Grootfontein",
"Tsumeb",
]
csv_dtype = {"category": str, "country": str, "city": str}
for split in ["train", "test"]:
fp = join(
dirname(dirname(__file__)), "datasets", "osv5m", f"{split}.csv"
)
# Read the CSV file into a pandas DataFrame
df = pd.read_csv(fp, dtype=csv_dtype)
# Check if the "country" column contains any of the cities in the list
mask = df["city"].isin(cities)
# If a city is found, set the corresponding rows in the "country" column to 'NMB'
df.loc[mask, "country"] = "NMB"
assert all(map(lambda x: isinstance(x, str), df["country"].unique().tolist()))
# Drop the columns that are all NaN
df.dropna(subset=["id", "latitude", "longitude"], inplace=True)
# Save the modified DataFrame back to the CSV file
df.to_csv(fp, index=False)