from pathlib import Path

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoImageProcessor, AutoModel


class ImageDataset(Dataset):
    """Yield raw RGB test images (uint8 HWC tensors) with their observation ids.

    Expects a metadata CSV with at least `filename` and `observation_id`
    columns; `filename` is resolved relative to `images_root_path`.
    """

    def __init__(self, metadata_path, images_root_path):
        self.metadata_path = metadata_path
        self.metadata = pd.read_csv(metadata_path)
        self.images_root_path = images_root_path

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        row = self.metadata.iloc[idx]
        image_path = Path(self.images_root_path) / row.filename
        img = Image.open(image_path).convert("RGB")
        # Keep the image as a raw uint8 (H, W, C) tensor; normalization and
        # channel reordering are delegated to the HF image processor downstream.
        img = torch.from_numpy(np.array(img))
        return {"features": img, "observation_id": row.observation_id}


class LinearClassifier(nn.Module):
    """Single linear layer over precomputed features, returning log-probabilities."""

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.model = nn.Linear(num_features, num_classes)

    def forward(self, x):
        # log_softmax is monotone per-row, so downstream argmax over these
        # values equals argmax over the raw logits.
        return torch.log_softmax(self.model(x), dim=1)


class TransformDino:
    """Wrap a frozen DINOv2 backbone: replace raw images with CLS-token embeddings."""

    def __init__(self, model_name="facebook/dinov2-base"):
        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def forward(self, batch):
        model_inputs = self.processor(images=batch["features"], return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**model_inputs)
        # CLS token (sequence position 0) serves as the global image embedding.
        batch["features"] = outputs.last_hidden_state[:, 0]
        return batch


def make_submission(
    test_metadata,
    model_path,
    output_csv_path="./submission.csv",
    images_root_path="/tmp/data/private_testset",
):
    """Run inference over the test set and write an (observation_id, class_id) CSV.

    Args:
        test_metadata: path to the test metadata CSV.
        model_path: checkpoint file containing "hyper_parameters" and
            "state_dict" entries (Lightning-style layout).
        output_csv_path: destination path for the submission CSV.
        images_root_path: directory that metadata `filename` entries are
            relative to.
    """
    # map_location keeps a GPU-saved checkpoint loadable on CPU-only hosts.
    checkpoint = torch.load(model_path, map_location="cpu")
    hparams = checkpoint["hyper_parameters"]
    model = LinearClassifier(hparams["num_features"], hparams["num_classes"])
    # NOTE(review): assumes state_dict keys match LinearClassifier directly
    # ("model.weight"/"model.bias", no extra prefix) — confirm against how
    # the checkpoint was produced.
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()  # fix: ensure inference-mode module behavior

    transform = TransformDino()
    dataloader = DataLoader(
        ImageDataset(test_metadata, images_root_path),
        batch_size=32,
        num_workers=4,
    )

    rows = []
    for batch in dataloader:
        batch = transform.forward(batch)
        observation_ids = batch["observation_id"]
        # fix: no_grad — no autograd graph needed for inference.
        with torch.no_grad():
            logits = model(batch["features"])
        class_ids = torch.argmax(logits, dim=1)
        rows.extend(
            {"observation_id": int(obs_id), "class_id": int(cls_id)}
            for obs_id, cls_id in zip(observation_ids, class_ids)
        )

    pd.DataFrame(rows).to_csv(output_csv_path, index=False)