harry-stark commited on
Commit
6ec105b
1 Parent(s): 3b5b50d

Refactors:New model and dropped multilingual support

Browse files
Files changed (4) hide show
  1. app.py +8 -10
  2. examples.json +1 -1
  3. hf_model.py +1 -8
  4. utils.py +1 -9
app.py CHANGED
@@ -1,14 +1,14 @@
 
 
1
  import streamlit as st
2
- from hf_model import classifier_zero,load_model,load_model_multil
3
  from utils import plot_result,examples_load
4
- import json
5
 
6
- classifier_model0=load_model()
7
- classifier_model1=load_model_multil()
 
8
  ex_text,ex_labels=examples_load()
9
 
10
- # with open("examples.json") as f:
11
- # data=json.load(f)
12
 
13
  if __name__ == '__main__':
14
  st.header("Zero Shot Classification")
@@ -16,13 +16,11 @@ if __name__ == '__main__':
16
 
17
 
18
  with st.form(key='my_form'):
19
- text_input = st.text_area("Input Text",ex_text)
20
- labels = st.text_input('Possible topics (separated by `,`)',ex_labels, max_chars=1000)
21
  labels = list(set([x.strip() for x in labels.strip().split(',') if len(x.strip()) > 0]))
22
  radio = st.radio("Select Multiclass",('Only one topic can be corect at a time','Multiple topics can be correct at a time'),)
23
- radio1 = st.radio("Select Model",('Using English Language model','Using Multilingual Model'),)
24
  multi_class= True if radio=="Multiple topics can be correct at a time" else False
25
- classifier = classifier_model1 if radio1=="Using Multilingual Model" else classifier_model0
26
  submit_button = st.form_submit_button(label='Submit')
27
 
28
  if submit_button:
 
1
+ from os import write
2
+ from typing import Sequence
3
  import streamlit as st
4
+ from hf_model import classifier_zero,load_model
5
  from utils import plot_result,examples_load
 
6
 
7
+
8
+ classifier=load_model()
9
+
10
  ex_text,ex_labels=examples_load()
11
 
 
 
12
 
13
  if __name__ == '__main__':
14
  st.header("Zero Shot Classification")
 
16
 
17
 
18
  with st.form(key='my_form'):
19
+ text_input = st.text_area("Input Text",data['text'])
20
+ labels = st.text_input('Possible topics (separated by `,`)',data["labels"], max_chars=1000)
21
  labels = list(set([x.strip() for x in labels.strip().split(',') if len(x.strip()) > 0]))
22
  radio = st.radio("Select Multiclass",('Only one topic can be corect at a time','Multiple topics can be correct at a time'),)
 
23
  multi_class= True if radio=="Multiple topics can be correct at a time" else False
 
24
  submit_button = st.form_submit_button(label='Submit')
25
 
26
  if submit_button:
examples.json CHANGED
@@ -1,4 +1,4 @@
1
  {
2
  "text":"The Democratic president had just signed into law the most significant infrastructure package in generations. And he had done it by bringing Democrats and Republicans together.",
3
- "labels":"Economy,Politics,Environment,Entertainment"
4
  }
 
1
  {
2
  "text":"The Democratic president had just signed into law the most significant infrastructure package in generations. And he had done it by bringing Democrats and Republicans together.",
3
+ "labels":["Economy", "Politics", "Environment", "Entertainment"]
4
  }
hf_model.py CHANGED
@@ -2,14 +2,7 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification,pipel
2
  import torch
3
 
4
  def load_model():
5
- model_name = "MoritzLaurer/DeBERTa-v3-base-mnli"
6
- tokenizer = AutoTokenizer.from_pretrained(model_name)
7
- model = AutoModelForSequenceClassification.from_pretrained(model_name)
8
- classifier = pipeline(task='zero-shot-classification', model=model, tokenizer=tokenizer, framework='pt')
9
- return classifier
10
-
11
- def load_model_multil():
12
- model_name = "joeddav/xlm-roberta-large-xnli"
13
  tokenizer = AutoTokenizer.from_pretrained(model_name)
14
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
15
  classifier = pipeline(task='zero-shot-classification', model=model, tokenizer=tokenizer, framework='pt')
 
2
  import torch
3
 
4
  def load_model():
5
+ model_name = "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
 
 
 
 
 
 
 
6
  tokenizer = AutoTokenizer.from_pretrained(model_name)
7
  model = AutoModelForSequenceClassification.from_pretrained(model_name)
8
  classifier = pipeline(task='zero-shot-classification', model=model, tokenizer=tokenizer, framework='pt')
utils.py CHANGED
@@ -1,8 +1,6 @@
1
  import streamlit as st
2
  import numpy as np
3
  import plotly.express as px
4
- import json
5
-
6
  def plot_result(top_topics, scores):
7
  top_topics = np.array(top_topics)
8
  scores = np.array(scores)
@@ -16,10 +14,4 @@ def plot_result(top_topics, scores):
16
  color_continuous_scale='GnBu')
17
  fig.update(layout_coloraxis_showscale=False)
18
  fig.update_traces(texttemplate='%{text:0.1f}%', textposition='outside')
19
- st.plotly_chart(fig)
20
-
21
- def examples_load():
22
- with open("examples.json") as f:
23
- data=json.load(f)
24
-
25
- return data['text'],data['labels']
 
1
  import streamlit as st
2
  import numpy as np
3
  import plotly.express as px
 
 
4
  def plot_result(top_topics, scores):
5
  top_topics = np.array(top_topics)
6
  scores = np.array(scores)
 
14
  color_continuous_scale='GnBu')
15
  fig.update(layout_coloraxis_showscale=False)
16
  fig.update_traces(texttemplate='%{text:0.1f}%', textposition='outside')
17
+ st.plotly_chart(fig)