Spaces:
Runtime error
Runtime error
harry-stark
commited on
Commit
•
6ec105b
1
Parent(s):
3b5b50d
Refactors:New model and dropped multilingual support
Browse files- app.py +8 -10
- examples.json +1 -1
- hf_model.py +1 -8
- utils.py +1 -9
app.py
CHANGED
@@ -1,14 +1,14 @@
|
|
|
|
|
|
1 |
import streamlit as st
|
2 |
-
from hf_model import classifier_zero,load_model
|
3 |
from utils import plot_result,examples_load
|
4 |
-
import json
|
5 |
|
6 |
-
|
7 |
-
|
|
|
8 |
ex_text,ex_labels=examples_load()
|
9 |
|
10 |
-
# with open("examples.json") as f:
|
11 |
-
# data=json.load(f)
|
12 |
|
13 |
if __name__ == '__main__':
|
14 |
st.header("Zero Shot Classification")
|
@@ -16,13 +16,11 @@ if __name__ == '__main__':
|
|
16 |
|
17 |
|
18 |
with st.form(key='my_form'):
|
19 |
-
text_input = st.text_area("Input Text",
|
20 |
-
labels = st.text_input('Possible topics (separated by `,`)',
|
21 |
labels = list(set([x.strip() for x in labels.strip().split(',') if len(x.strip()) > 0]))
|
22 |
radio = st.radio("Select Multiclass",('Only one topic can be corect at a time','Multiple topics can be correct at a time'),)
|
23 |
-
radio1 = st.radio("Select Model",('Using English Language model','Using Multilingual Model'),)
|
24 |
multi_class= True if radio=="Multiple topics can be correct at a time" else False
|
25 |
-
classifier = classifier_model1 if radio1=="Using Multilingual Model" else classifier_model0
|
26 |
submit_button = st.form_submit_button(label='Submit')
|
27 |
|
28 |
if submit_button:
|
|
|
1 |
+
from os import write
|
2 |
+
from typing import Sequence
|
3 |
import streamlit as st
|
4 |
+
from hf_model import classifier_zero,load_model
|
5 |
from utils import plot_result,examples_load
|
|
|
6 |
|
7 |
+
|
8 |
+
classifier=load_model()
|
9 |
+
|
10 |
ex_text,ex_labels=examples_load()
|
11 |
|
|
|
|
|
12 |
|
13 |
if __name__ == '__main__':
|
14 |
st.header("Zero Shot Classification")
|
|
|
16 |
|
17 |
|
18 |
with st.form(key='my_form'):
|
19 |
+
text_input = st.text_area("Input Text",data['text'])
|
20 |
+
labels = st.text_input('Possible topics (separated by `,`)',data["labels"], max_chars=1000)
|
21 |
labels = list(set([x.strip() for x in labels.strip().split(',') if len(x.strip()) > 0]))
|
22 |
radio = st.radio("Select Multiclass",('Only one topic can be corect at a time','Multiple topics can be correct at a time'),)
|
|
|
23 |
multi_class= True if radio=="Multiple topics can be correct at a time" else False
|
|
|
24 |
submit_button = st.form_submit_button(label='Submit')
|
25 |
|
26 |
if submit_button:
|
examples.json
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
{
|
2 |
"text":"The Democratic president had just signed into law the most significant infrastructure package in generations. And he had done it by bringing Democrats and Republicans together.",
|
3 |
-
"labels":"Economy,Politics,Environment,Entertainment"
|
4 |
}
|
|
|
1 |
{
|
2 |
"text":"The Democratic president had just signed into law the most significant infrastructure package in generations. And he had done it by bringing Democrats and Republicans together.",
|
3 |
+
"labels":["Economy", "Politics", "Environment", "Entertainment"]
|
4 |
}
|
hf_model.py
CHANGED
@@ -2,14 +2,7 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification,pipel
|
|
2 |
import torch
|
3 |
|
4 |
def load_model():
|
5 |
-
model_name = "MoritzLaurer/DeBERTa-v3-base-mnli"
|
6 |
-
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
7 |
-
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
8 |
-
classifier = pipeline(task='zero-shot-classification', model=model, tokenizer=tokenizer, framework='pt')
|
9 |
-
return classifier
|
10 |
-
|
11 |
-
def load_model_multil():
|
12 |
-
model_name = "joeddav/xlm-roberta-large-xnli"
|
13 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
14 |
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
15 |
classifier = pipeline(task='zero-shot-classification', model=model, tokenizer=tokenizer, framework='pt')
|
|
|
2 |
import torch
|
3 |
|
4 |
def load_model():
|
5 |
+
model_name = "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
7 |
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
8 |
classifier = pipeline(task='zero-shot-classification', model=model, tokenizer=tokenizer, framework='pt')
|
utils.py
CHANGED
@@ -1,8 +1,6 @@
|
|
1 |
import streamlit as st
|
2 |
import numpy as np
|
3 |
import plotly.express as px
|
4 |
-
import json
|
5 |
-
|
6 |
def plot_result(top_topics, scores):
|
7 |
top_topics = np.array(top_topics)
|
8 |
scores = np.array(scores)
|
@@ -16,10 +14,4 @@ def plot_result(top_topics, scores):
|
|
16 |
color_continuous_scale='GnBu')
|
17 |
fig.update(layout_coloraxis_showscale=False)
|
18 |
fig.update_traces(texttemplate='%{text:0.1f}%', textposition='outside')
|
19 |
-
st.plotly_chart(fig)
|
20 |
-
|
21 |
-
def examples_load():
|
22 |
-
with open("examples.json") as f:
|
23 |
-
data=json.load(f)
|
24 |
-
|
25 |
-
return data['text'],data['labels']
|
|
|
1 |
import streamlit as st
|
2 |
import numpy as np
|
3 |
import plotly.express as px
|
|
|
|
|
4 |
def plot_result(top_topics, scores):
|
5 |
top_topics = np.array(top_topics)
|
6 |
scores = np.array(scores)
|
|
|
14 |
color_continuous_scale='GnBu')
|
15 |
fig.update(layout_coloraxis_showscale=False)
|
16 |
fig.update_traces(texttemplate='%{text:0.1f}%', textposition='outside')
|
17 |
+
st.plotly_chart(fig)
|
|
|
|
|
|
|
|
|
|
|
|