efeperro commited on
Commit
9a45018
β€’
1 Parent(s): 9bd634e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -5
app.py CHANGED
@@ -29,10 +29,21 @@ def load_cnn():
29
 
30
  return model
31
 
32
- def predict_sentiment(text, model):
33
- processor.transform(text)
34
- prediction = model.predict([text])
35
- return prediction
 
 
 
 
 
 
 
 
 
 
 
36
 
37
 
38
  model_1 = load_model()
@@ -62,7 +73,7 @@ with st.expander("Model 2: CNN Sentiment analysis"):
62
  user_input = st.text_area("Enter text here...", key='model2_input')
63
  if st.button('Analyze', key='model2_button'):
64
  # Displaying output
65
- result = predict_sentiment(user_input, model_2)
66
  if result >= 0.5:
67
  st.write('The sentiment is: Positive πŸ˜€', key='model2_poswrite')
68
  else:
 
29
 
30
  return model
31
 
32
+ def predict_sentiment(text, model, torch=False):
33
+ if torch == True:
34
+ processed_text = processor.transform(text)
35
+ with torch.no_grad(): # Ensure no gradients are computed
36
+ prediction = model(processed_text) # Get raw model output
37
+ # Convert output to probabilities
38
+ probs = torch.softmax(prediction, dim=1)
39
+ # Get the predicted class
40
+ pred_class = torch.argmax(probs, dim=1)
41
+ return pred_class.item() # Return the predicted class as a Python int
42
+ else:
43
+ processor.transform(text)
44
+ prediction = model.predict([text])
45
+ return prediction
46
+
47
 
48
 
49
  model_1 = load_model()
 
73
  user_input = st.text_area("Enter text here...", key='model2_input')
74
  if st.button('Analyze', key='model2_button'):
75
  # Displaying output
76
+ result = predict_sentiment(user_input, model_2, torch=True)
77
  if result >= 0.5:
78
  st.write('The sentiment is: Positive πŸ˜€', key='model2_poswrite')
79
  else: