File size: 2,944 Bytes
c40d23a
 
 
 
7c9b16b
 
c40d23a
 
 
 
7c9b16b
c40d23a
 
 
7c9b16b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c40d23a
 
 
 
7c9b16b
 
c40d23a
7c9b16b
 
 
 
c40d23a
7c9b16b
 
c40d23a
7c9b16b
 
c40d23a
 
7c9b16b
 
 
 
 
c40d23a
7c9b16b
c40d23a
7c9b16b
 
c40d23a
7c9b16b
 
 
c40d23a
7c9b16b
c40d23a
7c9b16b
c40d23a
 
 
 
 
7c9b16b
c40d23a
 
7c9b16b
 
 
 
c40d23a
 
 
 
 
 
7c9b16b
c40d23a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import pandas as pd
import numpy as np
import os
from tqdm import tqdm
import timm
import torchvision.transforms as T
from PIL import Image
import torch

def is_gpu_available() -> bool:
    """Return whether PyTorch can use a CUDA device in this process.

    This is a thin wrapper around ``torch.cuda.is_available()``; it says
    nothing about which other GPU runtimes (e.g. ONNX Runtime) are installed.
    """
    return torch.cuda.is_available()

class PytorchWorker:
    """Hold a timm image classifier and run single-image PyTorch inference.

    Construction builds the ``model_name`` architecture with
    ``number_of_categories`` output classes (no pretrained download), loads
    the state dict stored at ``model_path`` into it, moves it to ``cuda:0``
    when CUDA is available (otherwise the CPU) and puts it in eval mode.
    """

    def __init__(self, model_path: str, model_name: str, number_of_categories: int = 1784):
        print("Setting up Pytorch Model")
        target = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(target)
        print(f"Using device: {self.device}")

        network = timm.create_model(
            model_name,
            num_classes=number_of_categories,
            pretrained=False,
        )

        # NOTE(review): torch.load unpickles the file; only use trusted
        # checkpoints (consider weights_only=True where the torch version
        # supports it).
        weights = torch.load(model_path, map_location=self.device)
        network.load_state_dict(weights)
        self.model = network.to(self.device).eval()

        # Fixed 224x224 resize (aspect ratio not preserved), conversion to a
        # float tensor, then per-channel normalisation with the commonly used
        # ImageNet mean/std statistics.
        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        self.transforms = T.Compose([
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(channel_mean, channel_std),
        ])

    def predict_image(self, image: Image.Image) -> list:
        """Run the model on one image and return its raw outputs.

        :param image: Input image as PIL Image; presumably RGB, since the
            normalisation uses three channel statistics.
        :return: The model output converted with ``Tensor.tolist()``. The
            input is batched with a leading dimension of 1, so for a
            classifier head this is a nested list ``[[logit, ...]]``.
        """
        pixels = self.transforms(image)
        batch = pixels.unsqueeze(dim=0).to(self.device)

        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            output = self.model(batch)

        return output.tolist()

def make_submission(
    test_metadata,
    model_path,
    model_name,
    output_csv_path="./submission.csv",
    images_root_path="/tmp/data/private_testset",
):
    """Classify every test image and write an ``observation_id,class_id`` CSV.

    Each row of ``test_metadata`` is predicted independently: the image at
    ``os.path.join(images_root_path, row.image_path)`` is loaded as RGB and
    its class is the argmax of the model logits. When several rows share an
    ``observation_id``, only the prediction of the first such row is written.

    :param test_metadata: DataFrame with at least ``image_path`` and
        ``observation_id`` columns. Note: a ``class_id`` column holding the
        per-row predictions is added to (or overwritten in) this frame in place.
    :param model_path: Path of the state dict handed to ``PytorchWorker``.
    :param model_name: timm architecture name handed to ``PytorchWorker``.
    :param output_csv_path: Destination of the submission CSV (no index column).
    :param images_root_path: Directory that ``image_path`` values are relative to.
    """
    model = PytorchWorker(model_path, model_name)

    predictions = []

    for _, row in tqdm(test_metadata.iterrows(), total=len(test_metadata)):
        image_path = os.path.join(images_root_path, row.image_path)

        # Image.open is lazy and keeps the file open; use a context manager so
        # each handle is released right after decoding instead of leaking one
        # descriptor per test image.
        with Image.open(image_path) as raw_image:
            test_image = raw_image.convert("RGB")

        logits = model.predict_image(test_image)

        # argmax over the flattened logits; the batch dimension is 1.
        predictions.append(np.argmax(logits))

    # Kept in place for backward compatibility with callers that read
    # test_metadata["class_id"] afterwards.
    test_metadata["class_id"] = predictions

    user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")
    user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=False)

if __name__ == "__main__":
    import zipfile

    # Unpack the private test-set archive into /tmp/data before inference.
    with zipfile.ZipFile("/tmp/data/private_testset.zip", "r") as archive:
        archive.extractall("/tmp/data")

    # Per-image test metadata read from the working directory.
    metadata = pd.read_csv("./SnakeCLEF2024-TestMetadata.csv")

    # NOTE(review): the checkpoint file name says "resnet" but the
    # architecture is an EfficientNet-B1 variant — confirm they match.
    make_submission(
        test_metadata=metadata,
        model_path="resnet_classifier.pth",
        model_name="tf_efficientnet_b1.ap_in1k",
    )