from pathlib import Path

import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoImageProcessor, AutoModel


class ImageDataset(Dataset):
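    """Dataset that reads a metadata CSV and yields raw RGB images
    (as HWC uint8 tensors) together with their observation ids."""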
    def __init__(self, metadata_path, images_root_path):
        self.metadata_path = metadata_path
        self.metadata = pd.read_csv(metadata_path)
        self.images_root_path = images_root_path

    def __len__(self):
        return len(self.metadata)

    def __getitem__(self, idx):
        row = self.metadata.iloc[idx]
        image_path = Path(self.images_root_path) / row.filename
        img = Image.open(image_path).convert("RGB")
        # convert the PIL image to an HWC uint8 tensor; resizing and normalization
        # are left to the DINOv2 processor, which is applied per batch downstream
        # (note: the default DataLoader collate assumes all images share the same size)
        img = torch.from_numpy(np.array(img))
        return {"features": img, "observation_id": row.observation_id}


class LinearClassifier(nn.Module):
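    """Single linear layer over precomputed image embeddings, returning log-probabilities."""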
    def __init__(self, num_features, num_classes):
        super().__init__()
        self.num_features = num_features
        self.num_classes = num_classes
        self.model = nn.Linear(num_features, num_classes)

    def forward(self, x):
        return torch.log_softmax(self.model(x), dim=1)


class TransformDino:
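    """Batch transform that runs a pretrained DINOv2 backbone and replaces the
    raw images in the batch with their CLS-token embeddings."""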
    def __init__(self, model_name="facebook/dinov2-base"):
        self.processor = AutoImageProcessor.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def forward(self, batch):
        model_inputs = self.processor(images=batch["features"], return_tensors="pt")
        with torch.no_grad():
            outputs = self.model(**model_inputs)
            last_hidden_states = outputs.last_hidden_state
        # keep only the [CLS] token embedding as the per-image feature
        batch["features"] = last_hidden_states[:, 0]
        return batch


def make_submission(
    test_metadata,
    model_path,
    output_csv_path="./submission.csv",
    images_root_path="/tmp/data/private_testset",
):
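    """Load the trained linear classifier from a checkpoint, extract DINOv2
    features for the test images, and write predictions to a submission CSV."""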
    # map_location="cpu" lets a GPU-trained checkpoint load on a CPU-only host
    checkpoint = torch.load(model_path, map_location="cpu")
    hparams = checkpoint["hyper_parameters"]
    model = LinearClassifier(hparams["num_features"], hparams["num_classes"])
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()

    transform = TransformDino()
    dataloader = DataLoader(
        ImageDataset(test_metadata, images_root_path), batch_size=32, num_workers=4
    )
    rows = []
    for batch in dataloader:
        batch = transform.forward(batch)
        observation_ids = batch["observation_id"]
        with torch.no_grad():
            logits = model(batch["features"])
        class_ids = torch.argmax(logits, dim=1)
        for observation_id, class_id in zip(observation_ids, class_ids):
            row = {"observation_id": int(observation_id), "class_id": int(class_id)}
            rows.append(row)
    submission_df = pd.DataFrame(rows)
    submission_df.to_csv(output_csv_path, index=False)
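

# Example invocation: a minimal sketch only. The metadata and checkpoint paths
# below are placeholders (assumptions), not paths defined by this script; the
# competition harness is expected to import make_submission and supply them.
if __name__ == "__main__":
    make_submission(
        test_metadata="/tmp/data/metadata.csv",  # assumed CSV with filename and observation_id columns
        model_path="./model.ckpt",  # assumed checkpoint containing hyper_parameters and state_dict
    )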