Sandy0909 commited on
Commit
8ecc4e0
·
1 Parent(s): ef4ba0f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -14,13 +14,12 @@ class FinancialBERT(torch.nn.Module):
14
  super(FinancialBERT, self).__init__()
15
  self.bert = BertForSequenceClassification.from_pretrained(Config.MODEL_PATH, num_labels=3, hidden_dropout_prob=0.5)
16
 
17
- def forward(self, input_ids, attention_mask, labels=None):
18
- output = self.bert(input_ids, attention_mask=attention_mask, labels=labels)
19
  return output.loss, output.logits
20
 
21
  # Load model
22
  model = FinancialBERT()
23
-
24
  model.eval()
25
 
26
  # Streamlit App
@@ -30,9 +29,9 @@ if st.button("Predict"):
30
  tokenizer = Config.TOKENIZER
31
  inputs = tokenizer([sentence], return_tensors="pt", truncation=True, padding=True, max_length=Config.MAX_LEN)
32
  with torch.no_grad():
33
- logits = model(**inputs)[1]
34
  probs = torch.nn.functional.softmax(logits, dim=-1)
35
  predictions = torch.argmax(probs, dim=-1)
36
  sentiment = ['negative', 'neutral', 'positive'][predictions[0].item()]
37
 
38
- st.write(f"The predicted sentiment is: {sentiment}")
 
14
  super(FinancialBERT, self).__init__()
15
  self.bert = BertForSequenceClassification.from_pretrained(Config.MODEL_PATH, num_labels=3, hidden_dropout_prob=0.5)
16
 
17
+ def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
18
+ output = self.bert(input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, labels=labels)
19
  return output.loss, output.logits
20
 
21
  # Load model
22
  model = FinancialBERT()
 
23
  model.eval()
24
 
25
  # Streamlit App
 
29
  tokenizer = Config.TOKENIZER
30
  inputs = tokenizer([sentence], return_tensors="pt", truncation=True, padding=True, max_length=Config.MAX_LEN)
31
  with torch.no_grad():
32
+ logits = model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'], token_type_ids=inputs.get('token_type_ids'))[1]
33
  probs = torch.nn.functional.softmax(logits, dim=-1)
34
  predictions = torch.argmax(probs, dim=-1)
35
  sentiment = ['negative', 'neutral', 'positive'][predictions[0].item()]
36
 
37
+ st.write(f"The predicted sentiment is: {sentiment}")