mawairon commited on
Commit
65391f8
1 Parent(s): f7eb5ec

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -2
app.py CHANGED
@@ -8,6 +8,9 @@ import matplotlib.pyplot as plt
8
  import io
9
  import base64
10
  import os
 
 
 
11
 
12
  # Load label mapping
13
  label_to_int = pd.read_pickle('label_to_int.pkl')
@@ -47,8 +50,17 @@ def load_model():
47
  input_size = 768 + metadata_features
48
  log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)
49
 
50
- model_weights_path = os.getenv('MODEL_PATH')
51
- weights = torch.load(model_weights_path, map_location=torch.device('cpu'))
 
 
 
 
 
 
 
 
 
52
 
53
  base_model.load_state_dict(weights['model_state_dict'])
54
  log_reg.load_state_dict(weights['log_reg_state_dict'])
 
8
  import io
9
  import base64
10
  import os
11
+ import huggingface_hub
12
+ from huggingface_hub import hf_hub_download
13
+
14
 
15
  # Load label mapping
16
  label_to_int = pd.read_pickle('label_to_int.pkl')
 
50
  input_size = 768 + metadata_features
51
  log_reg = LogisticRegressionTorch(input_dim=input_size, output_dim=N_UNIQUE_CLASSES)
52
 
53
+ #model_weights_path = os.getenv('MODEL_PATH')
54
+ #weights = torch.load(model_weights_path, map_location=torch.device('cpu'))
55
+
56
+
57
+ # Download the file
58
+ file_path = hf_hub_download(repo_id="mawairon/noo_test", filename="/gena-blastln-bs33-lr4e-05-S168.pth")
59
+ ")
60
+
61
+ # Now you can use the file_path to load your .pth file
62
+ weights = torch.load(file_path)
63
+
64
 
65
  base_model.load_state_dict(weights['model_state_dict'])
66
  log_reg.load_state_dict(weights['log_reg_state_dict'])