Anthony Miyaguchi committed
Commit d41c4d4
1 Parent(s): a0583df

Move everything into a single script

Files changed (2):
  1. evaluate/submission.py +41 -5
  2. script.py +80 -1
evaluate/submission.py CHANGED
@@ -1,10 +1,31 @@
+from pathlib import Path
+
 import pandas as pd
 import torch
+from PIL import Image
 from torch import nn
-from torch.utils.data import DataLoader
-from torchvision.transforms import v2
+from torch.utils.data import DataLoader, Dataset
+from transformers import AutoImageProcessor, AutoModel
+import numpy as np
+
+
+class ImageDataset(Dataset):
+    def __init__(self, metadata_path, images_root_path):
+        self.metadata_path = metadata_path
+        self.metadata = pd.read_csv(metadata_path)
+        self.images_root_path = images_root_path
+
+    def __len__(self):
+        return len(self.metadata)
 
-from .data import ImageDataset, TransformDino
+    def __getitem__(self, idx):
+        row = self.metadata.iloc[idx]
+        image_path = Path(self.images_root_path) / row.filename
+        img = Image.open(image_path).convert("RGB")
+        # convert to numpy array
+        img = torch.from_numpy(np.array(img))
+        # img = torch.tensor(img).permute(2, 0, 1).float() / 255.0
+        return {"features": img, "observation_id": row.observation_id}
 
 
 class LinearClassifier(nn.Module):
@@ -18,6 +39,21 @@ class LinearClassifier(nn.Module):
         return torch.log_softmax(self.model(x), dim=1)
 
 
+class TransformDino:
+    def __init__(self, model_name="facebook/dinov2-base"):
+        self.processor = AutoImageProcessor.from_pretrained(model_name)
+        self.model = AutoModel.from_pretrained(model_name)
+
+    def forward(self, batch):
+        model_inputs = self.processor(images=batch["features"], return_tensors="pt")
+        with torch.no_grad():
+            outputs = self.model(**model_inputs)
+        last_hidden_states = outputs.last_hidden_state
+        # extract the cls token
+        batch["features"] = last_hidden_states[:, 0]
+        return batch
+
+
 def make_submission(
     test_metadata,
     model_path,
@@ -29,13 +65,13 @@ def make_submission(
     model = LinearClassifier(hparams["num_features"], hparams["num_classes"])
     model.load_state_dict(checkpoint["state_dict"])
 
-    transform = v2.Compose([TransformDino("facebook/dinov2-base")])
+    transform = TransformDino()
     dataloader = DataLoader(
         ImageDataset(test_metadata, images_root_path), batch_size=32, num_workers=4
     )
     rows = []
     for batch in dataloader:
-        batch = transform(batch)
+        batch = transform.forward(batch)
         observation_ids = batch["observation_id"]
         logits = model(batch["features"])
         class_ids = torch.argmax(logits, dim=1)
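Taken together, the inlined classes form a frozen-backbone pipeline: ImageDataset yields raw RGB tensors, TransformDino runs them through DINOv2 and keeps the CLS token (a 768-dimensional embedding for facebook/dinov2-base), and LinearClassifier maps that embedding to class log-probabilities. A minimal sketch of that path, assuming the module is importable as evaluate.submission and using a random image in place of real test data (the class count below is a placeholder, not part of the commit):

# Minimal sketch (not part of the commit): run one fake batch through the pipeline.
import numpy as np
import torch

from evaluate.submission import LinearClassifier, TransformDino

# A batch of one random 224x224 RGB image, shaped like a collated ImageDataset item
# (uint8, channels-last), standing in for real test data.
batch = {
    "features": torch.from_numpy(
        np.random.randint(0, 256, size=(1, 224, 224, 3), dtype=np.uint8)
    )
}

transform = TransformDino()          # downloads facebook/dinov2-base on first use
batch = transform.forward(batch)     # "features" becomes the CLS embedding
print(batch["features"].shape)       # torch.Size([1, 768]) for dinov2-base

# 10 classes is an arbitrary placeholder; make_submission reads the real sizes
# from the checkpoint's hyper_parameters.
model = LinearClassifier(num_features=768, num_classes=10)
log_probs = model(batch["features"])
print(torch.argmax(log_probs, dim=1))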
script.py CHANGED
@@ -1,7 +1,86 @@
 #!/usr/bin/env python
 import zipfile
-from evaluate.submission import make_submission
 from argparse import ArgumentParser
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import torch
+from PIL import Image
+from torch import nn
+from torch.utils.data import DataLoader, Dataset
+from transformers import AutoImageProcessor, AutoModel
+
+
+class ImageDataset(Dataset):
+    def __init__(self, metadata_path, images_root_path):
+        self.metadata_path = metadata_path
+        self.metadata = pd.read_csv(metadata_path)
+        self.images_root_path = images_root_path
+
+    def __len__(self):
+        return len(self.metadata)
+
+    def __getitem__(self, idx):
+        row = self.metadata.iloc[idx]
+        image_path = Path(self.images_root_path) / row.filename
+        img = Image.open(image_path)
+        img = torch.from_numpy(np.array(img))
+        return {"features": img, "observation_id": row.observation_id}
+
+
+class LinearClassifier(nn.Module):
+    def __init__(self, num_features, num_classes):
+        super().__init__()
+        self.num_features = num_features
+        self.num_classes = num_classes
+        self.model = nn.Linear(num_features, num_classes)
+
+    def forward(self, x):
+        return torch.log_softmax(self.model(x), dim=1)
+
+
+class TransformDino:
+    def __init__(self, model_name="facebook/dinov2-base"):
+        self.processor = AutoImageProcessor.from_pretrained(model_name)
+        self.model = AutoModel.from_pretrained(model_name)
+
+    def forward(self, batch):
+        model_inputs = self.processor(images=batch["features"], return_tensors="pt")
+        with torch.no_grad():
+            outputs = self.model(**model_inputs)
+        last_hidden_states = outputs.last_hidden_state
+        # extract the cls token
+        batch["features"] = last_hidden_states[:, 0]
+        return batch
+
+
+def make_submission(
+    test_metadata,
+    model_path,
+    output_csv_path="./submission.csv",
+    images_root_path="/tmp/data/private_testset",
+):
+    checkpoint = torch.load(model_path)
+    hparams = checkpoint["hyper_parameters"]
+    model = LinearClassifier(hparams["num_features"], hparams["num_classes"])
+    model.load_state_dict(checkpoint["state_dict"])
+
+    transform = TransformDino()
+    dataloader = DataLoader(
+        ImageDataset(test_metadata, images_root_path), batch_size=32, num_workers=4
+    )
+    rows = []
+    for batch in dataloader:
+        batch = transform.forward(batch)
+        observation_ids = batch["observation_id"]
+        logits = model(batch["features"])
+        class_ids = torch.argmax(logits, dim=1)
+        for observation_id, class_id in zip(observation_ids, class_ids):
+            row = {"observation_id": int(observation_id), "class_id": int(class_id)}
+            rows.append(row)
+    submission_df = pd.DataFrame(rows)
+    submission_df.to_csv(output_csv_path, index=False)
 
 
 def parse_args():
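The parse_args body and the script's entry point fall outside this hunk. As a rough, hypothetical sketch of how the consolidated script might be invoked from the command line (argument names and defaults here are assumptions, not taken from the commit):

# Hypothetical CLI wiring; the actual parse_args/entry point in script.py is not
# shown in this diff, and these argument names and defaults are assumptions.
from argparse import ArgumentParser


def parse_args():
    parser = ArgumentParser()
    parser.add_argument("--test-metadata", default="/tmp/data/test_metadata.csv")
    parser.add_argument("--model-path", default="last.ckpt")
    parser.add_argument("--output-csv-path", default="./submission.csv")
    parser.add_argument("--images-root-path", default="/tmp/data/private_testset")
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    make_submission(
        test_metadata=args.test_metadata,
        model_path=args.model_path,
        output_csv_path=args.output_csv_path,
        images_root_path=args.images_root_path,
    )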