TechSSN commited on
Commit
7c9b16b
·
verified ·
1 Parent(s): c40d23a

Update script.py

Browse files
Files changed (1) hide show
  1. script.py +52 -40
script.py CHANGED
@@ -2,78 +2,90 @@ import pandas as pd
2
  import numpy as np
3
  import os
4
  from tqdm import tqdm
 
 
5
  from PIL import Image
6
  import torch
7
- import torch.nn as nn
8
- import torchvision.transforms as T
9
- from torchvision.models import resnet50
10
 
11
  def is_gpu_available():
 
12
  return torch.cuda.is_available()
13
 
14
- class ResNetClassifier(nn.Module):
15
- def __init__(self, num_classes, metadata_size):
16
- super(ResNetClassifier, self).__init__()
17
- self.resnet = resnet50(pretrained=True)
18
- self.resnet.fc = nn.Identity() # Remove the fully connected layer
19
- self.metadata_fc = nn.Linear(metadata_size, 128)
20
- self.classifier = nn.Linear(2048 + 128, num_classes) # 2048 is the output size of ResNet50
21
-
22
- def forward(self, x, metadata_features):
23
- resnet_features = self.resnet(x)
24
- metadata_features = self.metadata_fc(metadata_features)
25
- combined_features = torch.cat((resnet_features, metadata_features), dim=1)
26
- logits = self.classifier(combined_features)
27
- return logits
28
-
29
  class PytorchWorker:
30
- def __init__(self, model_path: str, num_classes: int, metadata_size: int):
31
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
32
- print(f"Using device: {self.device}")
33
- self.model = self._load_model(model_path, num_classes, metadata_size)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  self.transforms = T.Compose([T.Resize((224, 224)),
35
  T.ToTensor(),
36
  T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
37
 
38
- def _load_model(self, model_path, num_classes, metadata_size):
39
- model = ResNetClassifier(num_classes, metadata_size)
40
- model.load_state_dict(torch.load(model_path, map_location=self.device))
41
- return model.to(self.device).eval()
42
 
43
- def predict_image(self, image: Image.Image, metadata_features: np.ndarray) -> list:
 
 
 
44
  input_tensor = self.transforms(image).unsqueeze(0).to(self.device)
45
- metadata_tensor = torch.tensor(metadata_features).unsqueeze(0).to(self.device)
 
46
  with torch.no_grad():
47
- logits = self.model(input_tensor, metadata_tensor)
 
48
  return logits.tolist()
49
 
50
- def make_submission(test_metadata, model_path, num_classes, metadata_size, output_csv_path="./submission.csv", images_root_path="/tmp/data/private_testset"):
51
- model = PytorchWorker(model_path, num_classes, metadata_size)
 
 
 
52
  predictions = []
 
53
  for _, row in tqdm(test_metadata.iterrows(), total=len(test_metadata)):
54
- image_path = os.path.join(images_root_path, row['image_path'])
 
55
  test_image = Image.open(image_path).convert("RGB")
56
- metadata_features = row.drop(['image_path', 'class_id']).values.astype(np.float32)
57
- logits = model.predict_image(test_image, metadata_features)
 
58
  predictions.append(np.argmax(logits))
 
59
  test_metadata["class_id"] = predictions
 
60
  user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")
61
  user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=None)
62
 
63
  if __name__ == "__main__":
64
  import zipfile
 
65
  with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
66
  zip_ref.extractall("/tmp/data")
67
-
68
- MODEL_PATH = "pytorch_model.pth"
 
 
69
  metadata_file_path = "./SnakeCLEF2024-TestMetadata.csv"
70
  test_metadata = pd.read_csv(metadata_file_path)
71
- num_classes = 1784
72
- metadata_size = len(test_metadata.columns) - 2 # Excluding 'image_path' and 'class_id'
73
 
74
  make_submission(
75
  test_metadata=test_metadata,
76
  model_path=MODEL_PATH,
77
- num_classes=num_classes,
78
- metadata_size=metadata_size
79
  )
 
2
  import numpy as np
3
  import os
4
  from tqdm import tqdm
5
+ import timm
6
+ import torchvision.transforms as T
7
  from PIL import Image
8
  import torch
 
 
 
9
 
10
  def is_gpu_available():
11
+ """Check if the python package `onnxruntime-gpu` is installed."""
12
  return torch.cuda.is_available()
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  class PytorchWorker:
15
+ """Run inference using PyTorch."""
16
+
17
+ def __init__(self, model_path: str, model_name: str, number_of_categories: int = 1784):
18
+
19
+ def _load_model(model_name, model_path):
20
+
21
+ print("Setting up Pytorch Model")
22
+ self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
23
+ print(f"Using device: {self.device}")
24
+
25
+ model = timm.create_model(model_name, num_classes=number_of_categories, pretrained=False)
26
+
27
+ # Load model weights
28
+ model_ckpt = torch.load(model_path, map_location=self.device)
29
+ model.load_state_dict(model_ckpt)
30
+
31
+ return model.to(self.device).eval()
32
+
33
+ self.model = _load_model(model_name, model_path)
34
+
35
  self.transforms = T.Compose([T.Resize((224, 224)),
36
  T.ToTensor(),
37
  T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
38
 
39
+ def predict_image(self, image: Image.Image) -> list:
40
+ """Run inference using PyTorch.
 
 
41
 
42
+ :param image: Input image as PIL Image.
43
+ :return: A list with logits.
44
+ """
45
+ # Transform the image
46
  input_tensor = self.transforms(image).unsqueeze(0).to(self.device)
47
+
48
+ # Get logits
49
  with torch.no_grad():
50
+ logits = self.model(input_tensor)
51
+
52
  return logits.tolist()
53
 
54
+ def make_submission(test_metadata, model_path, model_name, output_csv_path="./submission.csv", images_root_path="/tmp/data/private_testset"):
55
+ """Make submission with given """
56
+
57
+ model = PytorchWorker(model_path, model_name)
58
+
59
  predictions = []
60
+
61
  for _, row in tqdm(test_metadata.iterrows(), total=len(test_metadata)):
62
+ image_path = os.path.join(images_root_path, row.image_path)
63
+
64
  test_image = Image.open(image_path).convert("RGB")
65
+
66
+ logits = model.predict_image(test_image)
67
+
68
  predictions.append(np.argmax(logits))
69
+
70
  test_metadata["class_id"] = predictions
71
+
72
  user_pred_df = test_metadata.drop_duplicates("observation_id", keep="first")
73
  user_pred_df[["observation_id", "class_id"]].to_csv(output_csv_path, index=None)
74
 
75
  if __name__ == "__main__":
76
  import zipfile
77
+
78
  with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
79
  zip_ref.extractall("/tmp/data")
80
+
81
+ MODEL_PATH = "resnet_classifier.pth"
82
+ MODEL_NAME = "tf_efficientnet_b1.ap_in1k"
83
+
84
  metadata_file_path = "./SnakeCLEF2024-TestMetadata.csv"
85
  test_metadata = pd.read_csv(metadata_file_path)
 
 
86
 
87
  make_submission(
88
  test_metadata=test_metadata,
89
  model_path=MODEL_PATH,
90
+ model_name=MODEL_NAME
 
91
  )