harry-stark commited on
Commit
3b5b50d
1 Parent(s): bb514b9

Added Multilingual Model support

Browse files
Files changed (2) hide show
  1. app.py +6 -2
  2. hf_model.py +7 -1
app.py CHANGED
@@ -1,9 +1,10 @@
1
  import streamlit as st
2
- from hf_model import classifier_zero,load_model
3
  from utils import plot_result,examples_load
4
  import json
5
 
6
- classifier=load_model()
 
7
  ex_text,ex_labels=examples_load()
8
 
9
  # with open("examples.json") as f:
@@ -13,12 +14,15 @@ if __name__ == '__main__':
13
  st.header("Zero Shot Classification")
14
  st.write("This app allows you to classify any text into any categories you are interested in.")
15
 
 
16
  with st.form(key='my_form'):
17
  text_input = st.text_area("Input Text",ex_text)
18
  labels = st.text_input('Possible topics (separated by `,`)',ex_labels, max_chars=1000)
19
  labels = list(set([x.strip() for x in labels.strip().split(',') if len(x.strip()) > 0]))
20
  radio = st.radio("Select Multiclass",('Only one topic can be corect at a time','Multiple topics can be correct at a time'),)
 
21
  multi_class= True if radio=="Multiple topics can be correct at a time" else False
 
22
  submit_button = st.form_submit_button(label='Submit')
23
 
24
  if submit_button:
 
1
  import streamlit as st
2
+ from hf_model import classifier_zero,load_model,load_model_multil
3
  from utils import plot_result,examples_load
4
  import json
5
 
6
+ classifier_model0=load_model()
7
+ classifier_model1=load_model_multil()
8
  ex_text,ex_labels=examples_load()
9
 
10
  # with open("examples.json") as f:
 
14
  st.header("Zero Shot Classification")
15
  st.write("This app allows you to classify any text into any categories you are interested in.")
16
 
17
+
18
  with st.form(key='my_form'):
19
  text_input = st.text_area("Input Text",ex_text)
20
  labels = st.text_input('Possible topics (separated by `,`)',ex_labels, max_chars=1000)
21
  labels = list(set([x.strip() for x in labels.strip().split(',') if len(x.strip()) > 0]))
22
  radio = st.radio("Select Multiclass",('Only one topic can be corect at a time','Multiple topics can be correct at a time'),)
23
+ radio1 = st.radio("Select Model",('Using English Language model','Using Multilingual Model'),)
24
  multi_class= True if radio=="Multiple topics can be correct at a time" else False
25
+ classifier = classifier_model1 if radio1=="Using Multilingual Model" else classifier_model0
26
  submit_button = st.form_submit_button(label='Submit')
27
 
28
  if submit_button:
hf_model.py CHANGED
@@ -2,13 +2,19 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification,pipel
2
  import torch
3
 
4
  def load_model():
5
-
6
  model_name = "MoritzLaurer/DeBERTa-v3-base-mnli"
7
  tokenizer = AutoTokenizer.from_pretrained(model_name)
8
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
9
  classifier = pipeline(task='zero-shot-classification', model=model, tokenizer=tokenizer, framework='pt')
10
  return classifier
11
 
 
 
 
 
 
 
 
12
  def classifier_zero(classifier,sequence:str,labels:list,multi_class:bool):
13
  outputs=classifier(sequence, labels,multi_label=multi_class)
14
  return outputs['labels'], outputs['scores']
 
2
  import torch
3
 
4
  def load_model():
 
5
  model_name = "MoritzLaurer/DeBERTa-v3-base-mnli"
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
7
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
8
  classifier = pipeline(task='zero-shot-classification', model=model, tokenizer=tokenizer, framework='pt')
9
  return classifier
10
 
11
+ def load_model_multil():
12
+ model_name = "joeddav/xlm-roberta-large-xnli"
13
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
14
+ model = AutoModelForSequenceClassification.from_pretrained(model_name)
15
+ classifier = pipeline(task='zero-shot-classification', model=model, tokenizer=tokenizer, framework='pt')
16
+ return classifier
17
+
18
  def classifier_zero(classifier,sequence:str,labels:list,multi_class:bool):
19
  outputs=classifier(sequence, labels,multi_label=multi_class)
20
  return outputs['labels'], outputs['scores']