|
import streamlit as st |
|
import torch |
|
from transformers import RobertaTokenizer, RobertaModel |
|
from prediction_sinhala import MDFEND |
|
|
|
|
|
@st.cache(allow_output_mutation=True)
def load_model():
    """Load and cache the fine-tuned MDFEND model and its RoBERTa tokenizer.

    Cached with st.cache so the slow disk load runs once instead of on every
    Streamlit script re-run.

    NOTE(review): st.cache is deprecated in newer Streamlit releases in favor
    of st.cache_resource — migrate when the Streamlit dependency is upgraded.

    Returns:
        tuple: (model, tokenizer) ready for inference.
    """
    tokenizer = RobertaTokenizer.from_pretrained("./prediction_sinhala/")
    model = MDFEND.from_pretrained("./prediction_sinhala/")
    # Switch to inference mode so dropout / batch-norm layers behave
    # deterministically (the original omitted this; presumably MDFEND is a
    # torch.nn.Module given from_pretrained + torch usage — confirm).
    model.eval()
    return model, tokenizer
|
|
|
# Load the (cached) model and tokenizer at script start-up. Streamlit
# re-executes this whole script on every user interaction, so the cache on
# load_model() is what prevents reloading the weights each time.
model, tokenizer = load_model()




# Free-form text box whose contents are classified when "Predict" is clicked.
text_input = st.text_area("Enter text here:")
|
|
|
|
|
if st.button("Predict"):
    # Guard against empty / whitespace-only input instead of running the
    # model on an empty string (the original omitted this check).
    if not text_input.strip():
        st.warning("Please enter some text before predicting.")
    else:
        # Tokenize into PyTorch tensors; run inference under no_grad so no
        # autograd graph is built for this forward pass.
        inputs = tokenizer(text_input, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        # Argmax over the class dimension -> integer class id.
        # NOTE(review): assumes the model output exposes `.logits`
        # (HF-style ModelOutput) — confirm against MDFEND's forward().
        prediction = outputs.logits.argmax(-1).item()
        st.write(f"Prediction: {prediction}")
|
|