TestApp / app.py
import streamlit as st
import torch
from transformers import RobertaTokenizer
from prediction_sinhala import MDFEND
# Load the model and tokenizer once and cache them across reruns
# (st.cache_resource replaces the deprecated st.cache(allow_output_mutation=True))
@st.cache_resource
def load_model():
    tokenizer = RobertaTokenizer.from_pretrained("./prediction_sinhala/")
    model = MDFEND.from_pretrained("./prediction_sinhala/")
    return model, tokenizer
model, tokenizer = load_model()
# User input
text_input = st.text_area("Enter text here:")
# Prediction
if st.button("Predict"):
    # Tokenize the input; truncation guards against texts longer than the
    # model's maximum sequence length
    inputs = tokenizer(text_input, return_tensors="pt", truncation=True)
    with torch.no_grad():  # ensure no gradients are computed
        outputs = model(**inputs)
    prediction = outputs.logits.argmax(-1).item()
    st.write(f"Prediction: {prediction}")