import gradio as gr
import numpy as np
import pandas as pd
import pickle
import sklearn
import plotly.express as px
from sentence_transformers import SentenceTransformer
from sklearn.cluster import MiniBatchKMeans
from learn_multi_doc_model import Model
#css_code='body {background-image:url("https://picsum.photos/seed/picsum/200/300");} div.gradio-container {background: white;}, button#component-8{background-color: rgb(158,202,225);}'
css_code='button#component-8{background-color: rgb(158,202,225);}'
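# The multilingual topic model below was presumably pickled from a script where Model
# lived in __main__, so register the class there for pickle to resolve it during loading.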
import __main__
setattr(__main__, "Model", Model)
categories = ["Censorship","Development","Digital Activism","Disaster","Economics & Business","Education","Environment","Governance","Health","History","Humanitarian Response","International Relations","Law","Media & Journalism","Migration & Immigration","Politics","Protest","Religion","Sport","Travel","War & Conflict","Technology + Science","Women & Gender + LGBTQ + Youth","Freedom of Speech + Human Rights","Literature + Arts & Culture"]
input_cvect_key_file = 'topic_discovery/cvects.key'
model_labse = SentenceTransformer('sentence-transformers/LaBSE')
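# LaBSE produces 768-dimensional multilingual sentence embeddings; the MLP classifier
# below maps an averaged article embedding to one probability per category listed above.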
with open('models/MLP_classifier_average_en.pkl', 'rb') as f:
classifier = pickle.load(f)
mul_model = None
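# Multilingual topic model: mul_model.E holds a word-embedding matrix per language,
# used in get_words() to pick the most representative words for a document.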
with open('models/model_0.0001_100.pkl', 'rb') as f:
mul_model = pickle.load(f)
def get_embedding(text):
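    # Guard against None input (e.g. a cleared textbox) by encoding an empty string instead.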
if text is None:
text = ""
return model_labse.encode(text)
def get_categories(y_pred):
indices = []
for idx, value in enumerate(y_pred):
if value == 1:
indices.append(idx)
cats = [categories[i] for i in indices]
return cats
def get_words(doc_emb):
# load countvectorizers
cvects = {}
vocab = {} # load vocabulary of words for each lang
with open(input_cvect_key_file, "r") as fpr:
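        # Each line of the key file is expected to map a language code to a pickled
        # CountVectorizer: "<lang> <path>".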
for line in fpr:
#print(line)
lang, fpath = line.strip().split()
            with open(fpath, "rb") as fin:  # separate handle so the key-file iterator is not shadowed
                #print(f"loading {fpath}")
                cvects[lang] = pickle.load(fin)
            # NOTE: get_feature_names() was removed in scikit-learn 1.2; newer releases use get_feature_names_out()
            vocab[lang] = cvects[lang].get_feature_names()
#print(
# "Loaded CountVectorizer for lang",
# lang,
# "with vocab size:",
# len(vocab[lang]),
#)
topn = 10 # top N words per cluster
#print(vocab["en"])
#print("MODEL KEYS")
#print(mul_model.E.keys())
doc_emb = doc_emb.flatten()
words_dict = {}
for lang in mul_model.E.keys():
#print(lang, end=": ")
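        # E[lang] appears to be a (vocab_size, embedding_dim) matrix, so the product below
        # yields one relevance score per vocabulary word for this document.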
scores = mul_model.E[lang].detach().numpy() @ (doc_emb).T
k_ixs = np.argsort(scores)[::-1][:topn].squeeze() # sort them in descending order and pick topn
tmp = []
for i in k_ixs:
#print(vocab[lang][i], end=", ")
tmp.append(vocab[lang][i])
words_dict[lang] = tmp
#print()
return words_dict
def generate_output(article):
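    # Pipeline: embed each paragraph with LaBSE, average into one document vector,
    # run the multi-label classifier (0.5 threshold), plot per-category probabilities,
    # and list the most representative words per language.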
paragraphs = article.split("\n")
embdds = []
for par in paragraphs:
embdds.append(get_embedding(par))
embedding = np.average(embdds, axis=0)
#y_pred = classifier.predict_proba(embedding.reshape(1, 768))
reshaped = embedding.reshape(1, 768)
#y_pred = classifier.predict(reshaped)
#y_pred = y_pred.flatten()
y_prob = classifier.predict_proba(reshaped)
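    # Multi-label decision: assign every category whose probability reaches 0.5.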
    y_prob = y_prob.flatten() # one probability per category
    y_pred = [1 if p >= 0.5 else 0 for p in y_prob]
classes = get_categories(y_pred)
if len(classes) > 1:
classes_string = ', '.join(classes)
elif len(classes) == 1:
classes_string = classes[0]
else:
classes_string = 'No category was found.'
data = pd.DataFrame()
data['Category'] = categories
data['Probability'] = y_prob
fig = px.bar(data, x='Probability', y='Category', orientation='h', height=600)#, title="Category probability")
fig.update_xaxes(range=[0, 1])
fig.update_layout(margin=dict(l=5, r=5, t=20, b=5)) #paper_bgcolor="LightSteelBlue")
fig.update_traces(marker_color='rgb(158,202,225)')
#print(f"LEN Y_PROB {len(y_prob)}")
#print(f"LEN CAT {len(categories)}")
words_dict = get_words(reshaped)
words_string = ""
for lang, w in words_dict.items():
words_string += f"{lang}: "
words_string += ', '.join(w)
words_string += "\n"
return (classes_string, fig, words_string)
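# Quick sanity check outside the UI (hypothetical input; assumes the model files above are present):
#   classes, fig, words = generate_output("First paragraph of an article.\nSecond paragraph.")
#   print(classes)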
# demo = gr.Interface(fn=generate_output,
# inputs=gr.Textbox(lines=6, placeholder="Insert text of the article here...", label="Article"),
# outputs=[gr.Textbox(lines=1, label="Category"), gr.Plot(label="Category probability"), gr.Textbox(lines=5, label="Topic discovery")],
# title="Article classification & topic discovery demo",
# flagging_options=["Incorrect"],
# theme=gr.themes.Base())
#css=css_code)
demo = gr.Blocks(css=css_code, theme=gr.themes.Base(), title="Article classification & topic discovery demo")
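# Layout: article textbox with Clear/Submit buttons on the left; tabs on the right for
# classification (category + probability chart) and topic discovery (representative words).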
with demo:
with gr.Row():
my_title = gr.HTML("<h1 align='center'>Article classification & topic discovery demo</h1>")
with gr.Row():
with gr.Column():
input_text = gr.Textbox(lines=22, placeholder="Insert text of the article here...", label="Article")
with gr.Row():
clear_button = gr.Button("Clear")
submit_button = gr.Button("Submit")
with gr.Column():
with gr.Tabs():
with gr.TabItem("Classification"):
category_text = gr.Textbox(lines=1, label="Category")
category_plot = gr.Plot()
with gr.TabItem("Topic discovery"):
topic_text = gr.Textbox(lines=22, label="The most representative words")
submit_button.click(generate_output, inputs=input_text, outputs=[category_text, category_plot, topic_text])
clear_button.click(lambda: None, None, input_text, queue=False)
demo.launch()