|
import streamlit as st |
|
from transformers import DistilBertTokenizer, DistilBertModel |
|
import logging |
|
logging.basicConfig(level=logging.ERROR) |
|
import torch |
|
|
|
# Sequence length (in tokens) that model inputs are padded/truncated to.
MAX_LEN = 100

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Tokenizer matching the pretrained DistilBERT backbone used by the model below.
tokenizer = DistilBertTokenizer.from_pretrained(
    'distilbert-base-uncased',
    truncation=True,
    do_lower_case=True,
)
|
|
|
class DistilBERTClass(torch.nn.Module):
    """DistilBERT backbone with a small feed-forward head producing one raw logit.

    The head takes the backbone's hidden vector at sequence position 0 and
    applies Linear -> ReLU -> Dropout -> Linear(dim, 1). No sigmoid is applied
    here; callers convert the logit to a probability themselves.

    Attribute names (``l1``, ``pre_classifier``, ``dropout``, ``classifier``)
    determine the state_dict keys, so they must not be renamed or saved
    checkpoints will no longer load.
    """

    def __init__(self):
        super().__init__()
        # Pretrained DistilBERT encoder, loaded by model name.
        self.l1 = DistilBertModel.from_pretrained("distilbert-base-uncased")
        # Width of the backbone's hidden states (768 for distilbert-base-uncased);
        # read from the config instead of hard-coding it.
        hidden_dim = self.l1.config.dim
        self.pre_classifier = torch.nn.Linear(hidden_dim, hidden_dim)
        self.dropout = torch.nn.Dropout(0.1)
        self.classifier = torch.nn.Linear(hidden_dim, 1)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return one raw logit per input sequence.

        Args:
            input_ids: Token id tensor passed straight to the backbone.
            attention_mask: Attention mask tensor passed straight to the backbone.
            token_type_ids: Accepted for call-site compatibility but ignored;
                it is never passed to the DistilBERT backbone.

        Returns:
            Output of the final ``Linear(dim, 1)`` layer (unnormalized logits).
        """
        output_1 = self.l1(input_ids=input_ids, attention_mask=attention_mask)
        # First element of the backbone output; presumably the last hidden
        # state shaped (batch, seq, dim) — TODO confirm for the installed version.
        hidden_state = output_1[0]
        # Vector at sequence position 0 (the [CLS] slot when special tokens are added).
        pooler = hidden_state[:, 0]
        pooler = self.pre_classifier(pooler)
        # Functional ReLU: avoids constructing a new nn.ReLU module on every call.
        pooler = torch.relu(pooler)
        pooler = self.dropout(pooler)
        return self.classifier(pooler)
|
|
|
|
|
# Build the architecture, then restore the fine-tuned weights from disk.
model_DB = DistilBERTClass()

loaded_model_path = './model_DB_1.pt'

# Weights are first loaded onto the CPU so the checkpoint opens on machines
# without a GPU; the model is moved to `device` afterwards.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a
# trusted source.
model_DB.load_state_dict(torch.load(loaded_model_path, map_location=torch.device('cpu')))

model_DB.to(device)
# Inference only: disable Dropout so predictions are deterministic.
model_DB.eval()
|
|
|
|
|
def sentiment_analysis_DB(input):
    """Classify the sentiment of a single text with the fine-tuned DistilBERT model.

    Args:
        input: The text to classify.

    Returns:
        1 when the model's sigmoid score is above 0.5 (shown as positive by the
        UI), otherwise 0.
    """
    # Pad/truncate to exactly MAX_LEN tokens, matching the shared constant.
    inputs = tokenizer.encode_plus(
        input,
        None,
        add_special_tokens=True,
        max_length=MAX_LEN,
        padding='max_length',
        truncation=True,
        return_token_type_ids=True,
    )

    # Inputs must sit on the same device as the model; otherwise a CUDA host
    # fails with a device-mismatch error.
    ids = torch.tensor([inputs['input_ids']], device=device)
    mask = torch.tensor([inputs['attention_mask']], device=device)
    token_type_ids = torch.tensor([inputs["token_type_ids"]], device=device)

    # Idempotent: guarantees Dropout is off even if the model was left in
    # training mode. no_grad skips building an autograd graph we never use.
    model_DB.eval()
    with torch.no_grad():
        output = model_DB(ids, mask, token_type_ids)

    print('Raw output is ', output)

    sigmoid_output = torch.sigmoid(output)

    print('Sigmoid output is ', sigmoid_output)

    return 1 if sigmoid_output.item() > 0.5 else 0
|
|
|
|
|
# Streamlit page: a text box plus a button that runs the classifier.
st.title("Sentiment Analysis App")

user_input = st.text_area("Enter some text:")

if st.button("Analyze Sentiment"):
    if not user_input.strip():
        # Blank input would still be run through the model and yield a
        # meaningless label, so ask for text instead.
        st.warning("Please enter some text to analyze.")
    else:
        result = sentiment_analysis_DB(user_input)

        # sentiment_analysis_DB returns 1 for positive, 0 for negative.
        if result == 1:
            st.success("Positive sentiment detected!")
        else:
            st.error("Negative sentiment detected.")