zaidmehdi commited on
Commit
2cb1421
1 Parent(s): 1b150a5

removing unnecessary util functions

Browse files
Files changed (1) hide show
  1. src/utils.py +8 -15
src/utils.py CHANGED
@@ -1,6 +1,7 @@
 
 
1
  import matplotlib.pyplot as plt
2
  import seaborn as sns
3
- from sklearn.metrics import accuracy_score, f1_score
4
  from sklearn.metrics import confusion_matrix
5
  import torch
6
 
@@ -13,22 +14,14 @@ def extract_hidden_state(input_text, tokenizer, language_model):
13
  return outputs.last_hidden_state[:,0].numpy()
14
 
15
 
16
- def get_metrics(y_true, y_preds):
17
- accuracy = accuracy_score(y_true, y_preds)
18
- f1_macro = f1_score(y_true, y_preds, average="macro")
19
- f1_weighted = f1_score(y_true, y_preds, average="weighted")
20
- print(f"Accuracy: {accuracy}")
21
- print(f"F1 macro average: {f1_macro}")
22
- print(f"F1 weighted average: {f1_weighted}")
23
 
24
 
25
- def evaluate_predictions(model:str, train_preds, y_train, test_preds, y_test):
26
- print(model)
27
- print("\nTrain set:")
28
- get_metrics(y_train, train_preds)
29
- print("-"*50)
30
- print("Test set:")
31
- get_metrics(y_test, test_preds)
32
 
33
 
34
  def plot_confusion_matrix(y_true, y_preds):
 
1
+ import pickle
2
+
3
  import matplotlib.pyplot as plt
4
  import seaborn as sns
 
5
  from sklearn.metrics import confusion_matrix
6
  import torch
7
 
 
14
  return outputs.last_hidden_state[:,0].numpy()
15
 
16
 
17
+ def serialize_data(data, output_path:str):
18
+ with open(output_path, "wb") as f:
19
+ pickle.dump(data, f)
 
 
 
 
20
 
21
 
22
+ def load_data(input_path:str):
23
+ with open(input_path, "rb") as f:
24
+ return pickle.load(f)
 
 
 
 
25
 
26
 
27
  def plot_confusion_matrix(y_true, y_preds):