Spaces:
Runtime error
Runtime error
harry-stark
commited on
Commit
•
3b5b50d
1
Parent(s):
bb514b9
Added Multilingual Model support
Browse files- app.py +6 -2
- hf_model.py +7 -1
app.py
CHANGED
@@ -1,9 +1,10 @@
|
|
1 |
import streamlit as st
|
2 |
-
from hf_model import classifier_zero,load_model
|
3 |
from utils import plot_result,examples_load
|
4 |
import json
|
5 |
|
6 |
-
|
|
|
7 |
ex_text,ex_labels=examples_load()
|
8 |
|
9 |
# with open("examples.json") as f:
|
@@ -13,12 +14,15 @@ if __name__ == '__main__':
|
|
13 |
st.header("Zero Shot Classification")
|
14 |
st.write("This app allows you to classify any text into any categories you are interested in.")
|
15 |
|
|
|
16 |
with st.form(key='my_form'):
|
17 |
text_input = st.text_area("Input Text",ex_text)
|
18 |
labels = st.text_input('Possible topics (separated by `,`)',ex_labels, max_chars=1000)
|
19 |
labels = list(set([x.strip() for x in labels.strip().split(',') if len(x.strip()) > 0]))
|
20 |
radio = st.radio("Select Multiclass",('Only one topic can be corect at a time','Multiple topics can be correct at a time'),)
|
|
|
21 |
multi_class= True if radio=="Multiple topics can be correct at a time" else False
|
|
|
22 |
submit_button = st.form_submit_button(label='Submit')
|
23 |
|
24 |
if submit_button:
|
|
|
1 |
import streamlit as st
|
2 |
+
from hf_model import classifier_zero,load_model,load_model_multil
|
3 |
from utils import plot_result,examples_load
|
4 |
import json
|
5 |
|
6 |
+
classifier_model0=load_model()
|
7 |
+
classifier_model1=load_model_multil()
|
8 |
ex_text,ex_labels=examples_load()
|
9 |
|
10 |
# with open("examples.json") as f:
|
|
|
14 |
st.header("Zero Shot Classification")
|
15 |
st.write("This app allows you to classify any text into any categories you are interested in.")
|
16 |
|
17 |
+
|
18 |
with st.form(key='my_form'):
|
19 |
text_input = st.text_area("Input Text",ex_text)
|
20 |
labels = st.text_input('Possible topics (separated by `,`)',ex_labels, max_chars=1000)
|
21 |
labels = list(set([x.strip() for x in labels.strip().split(',') if len(x.strip()) > 0]))
|
22 |
radio = st.radio("Select Multiclass",('Only one topic can be corect at a time','Multiple topics can be correct at a time'),)
|
23 |
+
radio1 = st.radio("Select Model",('Using English Language model','Using Multilingual Model'),)
|
24 |
multi_class= True if radio=="Multiple topics can be correct at a time" else False
|
25 |
+
classifier = classifier_model1 if radio1=="Using Multilingual Model" else classifier_model0
|
26 |
submit_button = st.form_submit_button(label='Submit')
|
27 |
|
28 |
if submit_button:
|
hf_model.py
CHANGED
@@ -2,13 +2,19 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification,pipel
|
|
2 |
import torch
|
3 |
|
4 |
def load_model():
|
5 |
-
|
6 |
model_name = "MoritzLaurer/DeBERTa-v3-base-mnli"
|
7 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
8 |
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
9 |
classifier = pipeline(task='zero-shot-classification', model=model, tokenizer=tokenizer, framework='pt')
|
10 |
return classifier
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
def classifier_zero(classifier,sequence:str,labels:list,multi_class:bool):
|
13 |
outputs=classifier(sequence, labels,multi_label=multi_class)
|
14 |
return outputs['labels'], outputs['scores']
|
|
|
2 |
import torch
|
3 |
|
4 |
def load_model():
|
|
|
5 |
model_name = "MoritzLaurer/DeBERTa-v3-base-mnli"
|
6 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
7 |
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
8 |
classifier = pipeline(task='zero-shot-classification', model=model, tokenizer=tokenizer, framework='pt')
|
9 |
return classifier
|
10 |
|
11 |
+
def load_model_multil():
|
12 |
+
model_name = "joeddav/xlm-roberta-large-xnli"
|
13 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
14 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_name)
|
15 |
+
classifier = pipeline(task='zero-shot-classification', model=model, tokenizer=tokenizer, framework='pt')
|
16 |
+
return classifier
|
17 |
+
|
18 |
def classifier_zero(classifier,sequence:str,labels:list,multi_class:bool):
|
19 |
outputs=classifier(sequence, labels,multi_label=multi_class)
|
20 |
return outputs['labels'], outputs['scores']
|