import streamlit as st
import torch

from prediction_sinhala import MDFEND, TokenizerFromPreTrained

# Constants for the saved checkpoint and the pre-trained tokenizer
MODEL_SAVE_PATH = "models/last-epoch-model-2024-03-08-15_34_03_6.pth"
BERT_MODEL_NAME = 'sinhala-nlp/sinbert-sold-si'
DOMAIN_NUM = 3
MAX_LEN = 160
BATCH_SIZE = 100  # kept for reference; unused in the single-input flow below


# Load the model and tokenizer once and reuse them across Streamlit reruns.
# st.cache_resource supersedes the deprecated st.cache(allow_output_mutation=True).
@st.cache_resource
def load_model():
    # Load the tokenizer from the pre-trained model name
    tokenizer = TokenizerFromPreTrained(MAX_LEN, BERT_MODEL_NAME)

    # Initialize the custom model and load the saved weights onto the CPU
    model = MDFEND(BERT_MODEL_NAME, DOMAIN_NUM,
                   expert_num=18,
                   mlp_dims=[5080, 4020, 3010, 2024, 1012, 606, 400])
    model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=torch.device('cpu')))
    model.eval()  # set the model to evaluation mode (disables dropout, etc.)

    return model, tokenizer


model, tokenizer = load_model()

# User input
text_input = st.text_area("Enter text here:")

# Prediction
if st.button("Predict"):
    if text_input:  # check that the input is not empty
        # Process the input text through the custom tokenizer
        inputs = tokenizer.tokenize(text_input)

        # Convert to a tensor, add a batch dimension, and move it to the model's
        # device (nn.Module has no .device attribute, so read it off a parameter)
        device = next(model.parameters()).device
        inputs = torch.tensor(inputs).unsqueeze(0).to(device)

        with torch.no_grad():  # no gradient computation needed for inference
            # Get the model's predicted probability
            output_prob = model.predict(inputs)

        # Interpret the output probability with a 0.5 decision threshold
        prediction = 1 if output_prob >= 0.5 else 0
        result = "offensive" if prediction == 1 else "not offensive"
        st.write(f"Prediction: {result}")
    else:
        st.error("Please enter some text to predict.")
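# Usage note (assumption: this script is saved as app.py at the repo root):
#   streamlit run app.py
# The checkpoint at MODEL_SAVE_PATH must exist locally; BERT_MODEL_NAME
# ('sinhala-nlp/sinbert-sold-si') is resolved by TokenizerFromPreTrained,
# presumably from the Hugging Face Hub.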