import os
import sys

# Add the repository root to the module search path
root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
sys.path.append(root_dir)

import argparse
from pathlib import Path

import numpy as np
import torch
from torchvision import transforms
from tqdm import tqdm

from data.extract_embeddings.dataset_with_path import ImageWithPathDataset
from utils.image_processing import CenterCrop
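
# Command-line arguments: the dataset can be sharded across several processes,
# each one handling a single contiguous split.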
parser = argparse.ArgumentParser()
parser.add_argument(
    "--number_of_splits",
    type=int,
    help="Total number of splits the dataset is divided into",
    default=1,
)
parser.add_argument(
    "--split_index",
    type=int,
    help="Index of the split processed by this run",
    default=0,
)
parser.add_argument(
    "--input_path",
    type=str,
    required=True,
    help="Path to the input image dataset",
)
parser.add_argument(
    "--output_path",
    type=str,
    required=True,
    help="Path to the output directory for the embeddings",
)
args = parser.parse_args()
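
# Load DINOv2 ViT-L/14 with registers from torch.hub and compile it for faster inference.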
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14_reg")
model = torch.compile(model, mode="max-autotune")
model.eval()
model.to(device)

input_path = Path(args.input_path)
output_path = Path(args.output_path)
output_path.mkdir(exist_ok=True, parents=True)
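
# Preprocessing: center crop to a 1:1 aspect ratio, bicubic resize to 336 px
# (a multiple of the model's 14 px patch size), and ImageNet normalization.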
augmentation = transforms.Compose(
    [
        CenterCrop(ratio="1:1"),
        transforms.Resize(336, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)
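
# Build the dataset, then keep only the contiguous slice assigned to this split
# (the last split takes any remainder).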
dataset = ImageWithPathDataset(input_path, output_path, transform=augmentation)
dataset = torch.utils.data.Subset(
    dataset,
    range(
        args.split_index * len(dataset) // args.number_of_splits,
        (
            (args.split_index + 1) * len(dataset) // args.number_of_splits
            if args.split_index != args.number_of_splits - 1
            else len(dataset)
        ),
    ),
)
batch_size = 128
dataloader = torch.utils.data.DataLoader(
    dataset, batch_size=batch_size, num_workers=16, collate_fn=lambda x: zip(*x)
)
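# The collate_fn transposes each batch of (image, output_path) samples into a
# tuple of images and a tuple of paths; stacking happens in the loop below.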
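
# Embed each batch and save every embedding to its own .npy file.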
for images, output_emb_paths in tqdm(dataloader):
    images = torch.stack(images, dim=0).to(device)
    with torch.no_grad():
        embeddings = model(images)
    numpy_embeddings = embeddings.cpu().numpy()
    for emb, output_emb_path in zip(numpy_embeddings, output_emb_paths):
        np.save(f"{output_emb_path}.npy", emb)