# Spaces: Runtime error
import streamlit as st | |
from transformers import BertTokenizer, BertForSequenceClassification | |
import torch | |
# Config class
class Config:
    """Static settings read by the model wrapper and the Streamlit UI."""

    # Checkpoint identifier handed to ``from_pretrained`` by the model wrapper.
    MODEL_PATH = "Sandy0909/finance_sentiment"
    # Use tokenizer from the original model
    TOKENIZER_PATH = "ahmedrachid/FinancialBERT"
    # Passed as ``max_length`` when the input sentence is tokenized.
    MAX_LEN = 512
    # Instantiated once, while this class body executes.
    TOKENIZER = BertTokenizer.from_pretrained(TOKENIZER_PATH)
class FinancialBERT(torch.nn.Module):
    """Thin ``torch`` module wrapping the checkpoint at ``Config.MODEL_PATH``."""

    def __init__(self):
        """Load the three-label sequence classifier from ``Config.MODEL_PATH``."""
        super().__init__()
        classifier = BertForSequenceClassification.from_pretrained(
            Config.MODEL_PATH,
            num_labels=3,
            hidden_dropout_prob=0.5,
        )
        self.bert = classifier

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
        """Run the wrapped classifier and return ``(loss, logits)``.

        All arguments are forwarded unchanged to the wrapped model; ``loss``
        and ``logits`` are read from the attributes of its output object.
        """
        result = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            labels=labels,
        )
        loss, logits = result.loss, result.logits
        return loss, logits
# Labels in the order used to decode the arg-max class index.
# NOTE(review): assumes the checkpoint's class ids follow this order --
# confirm against the model's label mapping.
SENTIMENT_LABELS = ("negative", "neutral", "positive")


@st.cache_resource
def _load_model():
    """Build the classifier once and switch it to evaluation mode.

    Streamlit re-executes this script on every widget interaction; caching
    the instance avoids rebuilding the model on each rerun.
    """
    classifier = FinancialBERT()
    classifier.eval()
    return classifier


def _predict_sentiment(classifier, sentence):
    """Classify one sentence.

    Args:
        classifier: Callable returning ``(loss, logits)`` for the tokenized
            inputs, e.g. a ``FinancialBERT`` instance.
        sentence: Text to classify.

    Returns:
        A ``(label, probabilities)`` pair, where ``label`` is an entry of
        ``SENTIMENT_LABELS`` and ``probabilities`` is the softmax over the
        logits of the single input sentence.
    """
    tokenizer = Config.TOKENIZER
    inputs = tokenizer(
        [sentence],
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=Config.MAX_LEN,
    )
    # Inference only: skip gradient tracking.
    with torch.no_grad():
        _, logits = classifier(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            token_type_ids=inputs.get('token_type_ids'),
        )
    probs = torch.nn.functional.softmax(logits, dim=-1)
    predicted_index = torch.argmax(probs, dim=-1)[0].item()
    return SENTIMENT_LABELS[predicted_index], probs[0]


# Load model
model = _load_model()

# Streamlit App
# Set title and an image/banner if you have one
st.title("Financial Sentiment Analysis")
# st.image("path_to_your_image.jpg", use_column_width=True)

# Description
st.write("""
This application predicts the sentiment of financial sentences using a state-of-the-art model. Enter a financial sentence below and click 'Predict' to get its sentiment.
""")

sentence = st.text_area("Enter a financial sentence:", "")

if st.button("Predict"):
    if not sentence.strip():
        # Nothing meaningful to classify; ask for input instead of scoring "".
        st.warning("Please enter a sentence to analyze.")
    else:
        sentiment, probs = _predict_sentiment(model, sentence)

        # Output visualization
        st.subheader('Predicted Sentiment:')
        st.write(f"The sentiment is: **{sentiment.capitalize()}**")

        # Show Confidence levels as a bar chart
        st.subheader('Model Confidence Levels:')
        st.bar_chart(probs.numpy(), use_container_width=True)

# Sidebar: Documentation/Help
st.sidebar.header('About')
st.sidebar.text("""
This application uses a BERT-based model trained specifically for financial sentences. The model can predict if the sentiment of a sentence is positive, negative, or neutral.
""")