picekl commited on
Commit
cba4c98
·
1 Parent(s): 443dd28

fix: fixing problem with HF hub

Browse files
Files changed (3) hide show
  1. .gitattributes +1 -0
  2. pytorch_model.bin +3 -0
  3. script.py +26 -12
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  FungiCLEF2024_TestMetadata.csv filter=lfs diff=lfs merge=lfs -text
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  FungiCLEF2024_TestMetadata.csv filter=lfs diff=lfs merge=lfs -text
37
+ pytorch_model.bin filter=lfs diff=lfs merge=lfs -text
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:25dcfebee82c8b14a9a43b4934173becdc434745344257046a4f3eb4fe94dc6f
3
+ size 34696073
script.py CHANGED
@@ -16,20 +16,32 @@ def is_gpu_available():
16
  class PytorchWorker:
17
  """Run inference using ONNX runtime."""
18
 
19
- def __init__(self, onnx_path: str):
20
- print("Setting up Pytorch Model")
21
 
22
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
23
 
24
- print(f"Using devide: {self.device}")
25
- self.model = timm.create_model("hf-hub:BVRA/tf_efficientnet_b3.in1k_ft_df20_224", pretrained=True)
26
- self.model = self.model.eval()
27
- self.model.to(self.device)
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  self.transforms = T.Compose([T.Resize((224, 224)),
30
  T.ToTensor(),
31
  T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
32
 
 
33
  def predict_image(self, image: np.ndarray) -> list():
34
  """Run inference using ONNX runtime.
35
 
@@ -42,10 +54,10 @@ class PytorchWorker:
42
  return logits.tolist()
43
 
44
 
45
- def make_submission(test_metadata, model_path, output_csv_path="./submission.csv", images_root_path="/tmp/data/private_testset"):
46
  """Make submission with given """
47
 
48
- model = PytorchWorker(model_path)
49
 
50
  predictions = []
51
 
@@ -71,12 +83,14 @@ if __name__ == "__main__":
71
  with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
72
  zip_ref.extractall("/tmp/data")
73
 
74
- HFHUB_MODEL_PATH = "hf-hub:BVRA/tf_efficientnet_b3.in1k_ft_df20_224"
 
75
 
76
  metadata_file_path = "./FungiCLEF2024_TestMetadata.csv"
77
- test_metadata = pd.read_csv(metadata_file_path)
78
 
79
  make_submission(
80
  test_metadata=test_metadata,
81
- model_path=HFHUB_MODEL_PATH,
 
82
  )
 
16
  class PytorchWorker:
17
  """Run inference using ONNX runtime."""
18
 
19
+ def __init__(self, model_path: str, model_name: str, number_of_categories: int = 1604):
 
20
 
21
+ def _load_model(model_name, model_path):
22
 
23
+ print("Setting up Pytorch Model")
24
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
25
+ print(f"Using devide: {device}")
26
+
27
+ model = timm.create_model(model_name, num_classes=number_of_categories, pretrained=False)
28
+
29
+ if not torch.cuda.is_available():
30
+ model_ckpt = torch.load(model_path, map_location=torch.device("cpu"))
31
+ else:
32
+ model_ckpt = torch.load(model_path)["model"]
33
+
34
+ model.load_state_dict(model_ckpt)
35
+
36
+ return model.to(device).eval()
37
+
38
+ self.model = _load_model(model_name, model_path)
39
 
40
  self.transforms = T.Compose([T.Resize((224, 224)),
41
  T.ToTensor(),
42
  T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
43
 
44
+
45
  def predict_image(self, image: np.ndarray) -> list():
46
  """Run inference using ONNX runtime.
47
 
 
54
  return logits.tolist()
55
 
56
 
57
+ def make_submission(test_metadata, model_path, model_name, output_csv_path="./submission.csv", images_root_path="/tmp/data/private_testset"):
58
  """Make submission with given """
59
 
60
+ model = PytorchWorker(model_path, model_name)
61
 
62
  predictions = []
63
 
 
83
  with zipfile.ZipFile("/tmp/data/private_testset.zip", 'r') as zip_ref:
84
  zip_ref.extractall("/tmp/data")
85
 
86
+ MODEL_PATH = "pytorch_model.bin"
87
+ MODEL_NAME = "tf_efficientnet_b1.ap_in1k"
88
 
89
  metadata_file_path = "./FungiCLEF2024_TestMetadata.csv"
90
+ test_metadata = pd.read_csv(metadata_file_path)[:100]
91
 
92
  make_submission(
93
  test_metadata=test_metadata,
94
+ model_path=MODEL_PATH,
95
+ model_name=MODEL_NAME
96
  )