# dsgt-snakeclef / evaluate / submission.py
# Author: Anthony Miyaguchi
# "Move everything into a single script" (commit d41c4d4)
from pathlib import Path
import pandas as pd
import torch
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoImageProcessor, AutoModel
import numpy as np
class ImageDataset(Dataset):
    """Dataset of raw RGB images keyed by a metadata CSV.

    Each item is a dict with:
      - "features": uint8 tensor of shape (H, W, 3) — the raw image; no
        resizing/normalization is done here because the downstream DINOv2
        processor handles preprocessing.
      - "observation_id": the row's observation identifier.

    The metadata CSV must contain `filename` and `observation_id` columns.
    """

    def __init__(self, metadata_path, images_root_path):
        self.metadata_path = metadata_path
        self.metadata = pd.read_csv(metadata_path)
        self.images_root_path = images_root_path

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        row = self.metadata.iloc[idx]
        image_path = Path(self.images_root_path) / row.filename
        img = Image.open(image_path).convert("RGB")
        # Keep uint8 HWC layout; the image processor does its own scaling.
        features = torch.from_numpy(np.array(img))
        return {"features": features, "observation_id": row.observation_id}
class LinearClassifier(nn.Module):
    """A single affine layer producing log-probabilities over classes.

    Parameters
    ----------
    num_features : size of the input embedding vector.
    num_classes : number of output classes.
    """

    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        # Named `model` so checkpoint state_dict keys stay compatible.
        self.model = nn.Linear(num_features, num_classes)

    def forward(self, x):
        """Return log-softmax scores of shape (batch, num_classes)."""
        logits = self.model(x)
        return torch.log_softmax(logits, dim=1)
class TransformDino:
    """Batch transform that replaces raw images with DINOv2 CLS embeddings.

    Loads a pretrained DINOv2 backbone plus its matching image processor;
    `forward` swaps the "features" entry of a batch dict from raw images to
    the backbone's CLS-token embedding.
    """

    def __init__(self, model_name="facebook/dinov2-base"):
        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def forward(self, batch):
        """Embed batch["features"] in place and return the batch dict."""
        inputs = self.processor(images=batch["features"], return_tensors="pt")
        with torch.no_grad():
            hidden = self.model(**inputs).last_hidden_state
        # Position 0 of the sequence is the CLS token embedding.
        batch["features"] = hidden[:, 0]
        return batch
def make_submission(
    test_metadata,
    model_path,
    output_csv_path="./submission.csv",
    images_root_path="/tmp/data/private_testset",
):
    """Run inference over the test set and write a submission CSV.

    Parameters
    ----------
    test_metadata : path to the test metadata CSV (filename, observation_id).
    model_path : path to a checkpoint containing "hyper_parameters"
        (with "num_features"/"num_classes") and "state_dict".
    output_csv_path : where to write the (observation_id, class_id) CSV.
    images_root_path : directory containing the test images.
    """
    # map_location lets a checkpoint saved on GPU load on CPU-only hosts.
    checkpoint = torch.load(model_path, map_location="cpu")
    hparams = checkpoint["hyper_parameters"]
    model = LinearClassifier(hparams["num_features"], hparams["num_classes"])
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()  # inference mode
    transform = TransformDino()
    dataloader = DataLoader(
        ImageDataset(test_metadata, images_root_path), batch_size=32, num_workers=4
    )
    rows = []
    # no_grad: the classifier forward needs no autograd bookkeeping here.
    with torch.no_grad():
        for batch in dataloader:
            batch = transform.forward(batch)
            observation_ids = batch["observation_id"]
            logits = model(batch["features"])
            class_ids = torch.argmax(logits, dim=1)
            for observation_id, class_id in zip(observation_ids, class_ids):
                rows.append(
                    {"observation_id": int(observation_id), "class_id": int(class_id)}
                )
    submission_df = pd.DataFrame(rows)
    submission_df.to_csv(output_csv_path, index=False)