CASY85 commited on
Commit
ccb91a2
·
verified ·
1 Parent(s): af7d307

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -9
app.py CHANGED
@@ -1,16 +1,17 @@
1
  import streamlit as st
2
  import joblib
3
  from sentence_transformers import SentenceTransformer
 
4
 
5
  # Load the pre-trained embedding model
6
  @st.cache_resource # Cache the embedding model to save loading time
7
  def load_embedding_model():
8
  return SentenceTransformer('neuml/pubmedbert-base-embeddings')
9
 
10
- # Load the MLP model
11
  @st.cache_resource # Cache the loaded model
12
- def load_mlp_model():
13
- with open("MLP.pkl", "rb") as file:
14
  return joblib.load(file)
15
 
16
  # Embed text
@@ -19,9 +20,15 @@ def get_embeddings(title, abstract, embedding_model):
19
  combined_text = title + " " + abstract
20
  return embedding_model.encode(combined_text)
21
 
 
 
 
 
 
 
22
  # Main Streamlit app
23
  def main():
24
- st.title("MLP Predictor for Titles and Abstracts")
25
 
26
  # Input fields
27
  title = st.text_input("Enter the Title:")
@@ -29,10 +36,10 @@ def main():
29
 
30
  # Load models
31
  embedding_model = load_embedding_model()
32
- mlp_model = load_mlp_model()
33
 
34
  # Predict button
35
- if st.button("Predict Label"):
36
  if title.strip() == "" or abstract.strip() == "":
37
  st.error("Both Title and Abstract are required!")
38
  else:
@@ -40,10 +47,17 @@ def main():
40
  embeddings = get_embeddings(title, abstract, embedding_model)
41
 
42
  # Make prediction
43
- prediction = mlp_model.predict([embeddings])[0] # Input should be a 2D array
44
 
45
- # Display result
46
- st.success(f"The predicted label is: {prediction}")
 
 
 
 
 
 
47
 
48
  if __name__ == "__main__":
49
  main()
 
 
1
  import streamlit as st
2
  import joblib
3
  from sentence_transformers import SentenceTransformer
4
+ import numpy as np
5
 
6
  # Load the pre-trained embedding model
7
  @st.cache_resource # Cache the embedding model to save loading time
8
  def load_embedding_model():
9
  return SentenceTransformer('neuml/pubmedbert-base-embeddings')
10
 
11
+ # Load the multilabel classification model
12
  @st.cache_resource # Cache the loaded model
13
+ def load_multilabel_model():
14
+ with open("multilabel_model.pkl", "rb") as file:
15
  return joblib.load(file)
16
 
17
  # Embed text
 
20
  combined_text = title + " " + abstract
21
  return embedding_model.encode(combined_text)
22
 
23
+ # Map predicted binary outputs to labels
24
+ LABELS = ["device", "screening", "drug", "surgery", "imaging", "telemedicine"]
25
+
26
+ def decode_predictions(predictions):
27
+ return [label for label, pred in zip(LABELS, predictions) if pred == 1]
28
+
29
  # Main Streamlit app
30
  def main():
31
+ st.title("Multilabel Classifier for Titles and Abstracts")
32
 
33
  # Input fields
34
  title = st.text_input("Enter the Title:")
 
36
 
37
  # Load models
38
  embedding_model = load_embedding_model()
39
+ multilabel_model = load_multilabel_model()
40
 
41
  # Predict button
42
+ if st.button("Predict Labels"):
43
  if title.strip() == "" or abstract.strip() == "":
44
  st.error("Both Title and Abstract are required!")
45
  else:
 
47
  embeddings = get_embeddings(title, abstract, embedding_model)
48
 
49
  # Make prediction
50
+ predictions = multilabel_model.predict([embeddings])[0] # Input should be a 2D array
51
 
52
+ # Decode predictions
53
+ predicted_labels = decode_predictions(predictions)
54
+
55
+ # Display results
56
+ if predicted_labels:
57
+ st.success(f"The predicted labels are: {', '.join(predicted_labels)}")
58
+ else:
59
+ st.warning("No relevant labels were predicted.")
60
 
61
  if __name__ == "__main__":
62
  main()
63
+