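# NOTE: the following header lines appear to be web file-viewer residue
# (path, avatar alt text, commit message, commit hash, link labels, file size)
# rather than part of the script; they are kept here as comments.
# cifar10-outlier-low / prepare.py
# MarkusStoll's picture
# Duplicate from renumics/cifar10-outlier
# 713f16f
# raw
# history blame
# 1.34 kB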
import pickle
import datasets
import os
import umap
def _project_with_umap(vectors):
    """Fit UMAP on *vectors* and return the projection as nested lists.

    The reducer uses the script's fixed settings (``n_neighbors=70``,
    ``min_dist=0.5``) and a fixed ``random_state`` so repeated runs use the
    same seed.
    """
    reducer = umap.UMAP(n_neighbors=70, min_dist=0.5, random_state=42)
    return reducer.fit_transform(vectors).tolist()


def _build_dataset():
    """Download the dataset and return it as a DataFrame with derived columns.

    Adds ``label_str`` (class name for each integer label) and one
    precalculated UMAP column per embedding column.
    """
    ds = datasets.load_dataset("renumics/cifar10-outlier", split="train")
    print("Dataset loaded using datasets.load_dataset().")

    df = ds.rename_columns({"img": "image", "label": "labels"}).to_pandas()
    # ``ds`` itself is not renamed, so its class-label feature is still "label".
    label_feature = ds.features["label"]
    df["label_str"] = df["labels"].apply(label_feature.int2str)
    # For a quick trial run, subsample here, e.g. ``df = df[:1000]``.

    # Precalculate UMAP embeddings.
    df["embedding_ft_precalc"] = _project_with_umap(df["embedding_ft"].tolist())
    print("Umap for ft done")
    df["embedding_foundation_precalc"] = _project_with_umap(
        df["embedding_foundation"].tolist()
    )
    print("Umap for base done")
    return df


def _save_cache(df, cache_file):
    """Pickle *df* to *cache_file* without ever exposing a partial file.

    The data is written to a sibling temporary file first and then moved into
    place with ``os.replace``; an interrupted run therefore cannot leave a
    truncated cache that would break ``pickle.load`` on the next start.
    """
    tmp_file = cache_file + ".tmp"
    with open(tmp_file, "wb") as file:
        pickle.dump(df, file)
    os.replace(tmp_file, cache_file)


def main(cache_file="dataset_cache.pkl"):
    """Load the prepared dataset from *cache_file*, building it if missing.

    Args:
        cache_file: Path of the pickle cache. Defaults to the name the
            script has always used, relative to the working directory.

    Returns:
        The object read from the cache, or the freshly built DataFrame
        (which is also written to *cache_file*).
    """
    if os.path.exists(cache_file):
        # NOTE(security): unpickling can execute arbitrary code. Only point
        # this at a cache file produced by this script on a trusted machine.
        with open(cache_file, "rb") as file:
            df = pickle.load(file)
        print("Dataset loaded from cache.")
        return df

    df = _build_dataset()
    _save_cache(df, cache_file)
    print("Dataset saved to cache.")
    return df


if __name__ == "__main__":
    main()