import os

import streamlit as st
import torch
from huggingface_hub import snapshot_download
from torch import nn
from transformers import BertModel, BertTokenizer


class BERTClassifier(nn.Module):
    """BERT encoder followed by dropout and a linear classification head.

    Args:
        bert_model_name: Name or path passed to ``BertModel.from_pretrained``.
        num_classes: Number of output logits produced by the linear head.
    """

    def __init__(self, bert_model_name, num_classes):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(self.bert.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return raw (unnormalized) class logits from BERT's pooled output."""
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.pooler_output
        x = self.dropout(pooled_output)
        return self.fc(x)


# Hugging Face Hub repository holding the fine-tuned weights and tokenizer files.
repo_id = "Makima57/sentiment-model"
bert_model_name = "bert-base-uncased"
num_classes = 2
max_length = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


@st.cache_resource
def load_model_and_tokenizer(repo_id, bert_model_name, num_classes, device_name):
    """Download the checkpoint and build the model and tokenizer once per process.

    Streamlit re-executes this script on every widget interaction. Caching the
    result avoids re-downloading the repo and re-initialising BERT on each click.

    Args:
        repo_id: Hugging Face Hub repository to snapshot.
        bert_model_name: Base BERT model the classifier is built on.
        num_classes: Number of output classes of the classifier head.
        device_name: String form of the target ``torch.device`` (e.g. "cpu").

    Returns:
        Tuple ``(model_path, model, tokenizer)``. ``model_path`` is the local
        snapshot directory; ``model`` is on ``device_name`` in eval mode.
    """
    model_path = snapshot_download(repo_id=repo_id)
    target_device = torch.device(device_name)

    model = BERTClassifier(bert_model_name, num_classes)
    # The checkpoint comes from a remote repository. weights_only=True limits
    # unpickling to tensors and plain containers instead of arbitrary objects.
    state_dict = torch.load(
        os.path.join(model_path, "pytorch_model.bin"),
        map_location=target_device,
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    model.to(target_device)
    model.eval()

    tokenizer = BertTokenizer.from_pretrained(os.path.join(model_path, "tokenizer"))
    return model_path, model, tokenizer


# device is passed as a string so Streamlit's cache can hash the argument.
model_path, model, tokenizer = load_model_and_tokenizer(
    repo_id, bert_model_name, num_classes, str(device)
)


def predict_sentiment(text, model, tokenizer, device, max_length=128):
    """Classify a single piece of text as "positive" or "negative".

    Args:
        text: Review text to classify.
        model: Classifier returning logits of shape (batch, num_classes).
        tokenizer: Tokenizer used to encode ``text``.
        device: Device the input tensors are moved to (must match ``model``).
        max_length: Sequence length to pad/truncate the encoded input to.

    Returns:
        "positive" when class index 1 has the highest logit, else "negative".
    """
    encoding = tokenizer(
        text,
        return_tensors="pt",
        max_length=max_length,
        padding="max_length",
        truncation=True,
    )
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    with torch.no_grad():
        logits = model(input_ids=input_ids, attention_mask=attention_mask)

    # A single text is encoded, so the batch has one row and .item() is valid.
    predicted_class = logits.argmax(dim=1).item()
    return "positive" if predicted_class == 1 else "negative"


# Streamlit app interface
st.title("IMDB Movie Review Sentiment Analyzer")

user_input = st.text_area("Enter a movie review:", "")

if st.button("Analyze Sentiment"):
    # Whitespace-only input is rejected rather than sent to the model.
    if user_input.strip():
        sentiment = predict_sentiment(
            user_input, model, tokenizer, device, max_length=max_length
        )
        st.write(f"The sentiment of the review is: **{sentiment}**")
    else:
        st.write("Please enter a valid review!")