Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -7,7 +7,7 @@ import json
|
|
7 |
@st.cache()
|
8 |
def get_model():
|
9 |
model = AutoModelForSequenceClassification.from_pretrained("siebert/sentiment-roberta-large-english", num_labels=2)
|
10 |
-
model.load_state_dict(torch.load('
|
11 |
return model
|
12 |
|
13 |
@st.cache()
|
@@ -15,9 +15,11 @@ def get_tokenizer():
|
|
15 |
tokenizer = AutoTokenizer.from_pretrained("siebert/sentiment-roberta-large-english")
|
16 |
return tokenizer
|
17 |
|
18 |
-
def make_prediction():
|
19 |
model = get_model()
|
20 |
tokenizer = tokenizer()
|
|
|
|
|
21 |
|
22 |
|
23 |
|
@@ -41,7 +43,8 @@ with st.form(key='input_form'):
|
|
41 |
button = st.form_submit_button(label='Classify')
|
42 |
if button:
|
43 |
if to_analyze:
|
44 |
-
make_prediction(to_analyze)
|
|
|
45 |
else:
|
46 |
st.markdown("Empty request. Please resubmit")
|
47 |
|
|
|
7 |
@st.cache()
|
8 |
def get_model():
|
9 |
model = AutoModelForSequenceClassification.from_pretrained("siebert/sentiment-roberta-large-english", num_labels=2)
|
10 |
+
model.load_state_dict(torch.load('cached_model.pth'))
|
11 |
return model
|
12 |
|
13 |
@st.cache()
|
|
|
15 |
tokenizer = AutoTokenizer.from_pretrained("siebert/sentiment-roberta-large-english")
|
16 |
return tokenizer
|
17 |
|
18 |
+
def make_prediction(to_analyze):
|
19 |
model = get_model()
|
20 |
tokenizer = tokenizer()
|
21 |
+
to_return = model(**tokenizer(to_anayze))
|
22 |
+
return to_return
|
23 |
|
24 |
|
25 |
|
|
|
43 |
button = st.form_submit_button(label='Classify')
|
44 |
if button:
|
45 |
if to_analyze:
|
46 |
+
pred = make_prediction(to_analyze)
|
47 |
+
st.markdown(pred)
|
48 |
else:
|
49 |
st.markdown("Empty request. Please resubmit")
|
50 |
|