satia / app.py
stinoco's picture
Added classification models for subcategories
cef1d7f
raw
history blame contribute delete
No virus
6.07 kB
import gradio as gr
import numpy as np
from transformers import pipeline
from utils.tokenizer import Tokenizer
from utils.lstm import LSTM
from utils.load_model import load_model
from utils.production_model import ProductionModel
# Load models
## Transformers
# Macro-category classifier. predict() expects one score per label from this
# pipeline (it ranks them and keeps the top 4), hence return_all_scores=True.
# NOTE(review): predict() indexes the result with [0], which relies on the output
# nesting produced by return_all_scores=True — keep the two in sync if this
# argument is ever migrated (it may be deprecated in newer transformers; verify).
pipeline_clf = pipeline(
    "text-classification",
    model="stinoco/beto-sentiment-analysis-finetuned",
    return_all_scores=True,
)
# NOTE(review): pipeline_pos is not referenced anywhere else in this file as shown,
# yet it is still constructed at startup — confirm it is actually needed.
pipeline_pos = pipeline(
    "token-classification",
    model="sagorsarker/codeswitch-spaeng-pos-lince",
)
## LSTM
# One subcategory classifier per macro category, built by utils.load_model.
clf_marketing = load_model("marketing")
clf_cliente = load_model("cliente")
clf_conforme = load_model("conforme")
clf_devoluciones = load_model("devoluciones")
clf_entrega = load_model("entrega")
clf_financiamiento = load_model("financiamiento")
clf_otros = load_model("otros")
clf_stock = load_model("stock")
clf_ventas = load_model("ventas")
# PREDICT
def _macro_probabilities(classes, top_k = 4):
    """Keep the ``top_k`` best macro labels and lump everything else into 'Resto'.

    Args:
        classes: iterable of dicts with 'label' and 'score' keys, one per macro
            category (the per-label output of ``pipeline_clf``).
        top_k: number of labels kept individually.

    Returns:
        Tuple ``(probas, top_label)``. ``probas`` maps the kept labels (highest
        score first) plus 'Resto' to their scores. ``top_label`` is the
        highest-scoring *real* label — never the synthetic 'Resto' bucket — or
        None when ``classes`` is empty.
    """
    scores = {element['label']: element['score'] for element in classes}
    ranked = sorted(scores.items(), key = lambda item: item[1], reverse = True)
    probas = dict(ranked[:top_k])
    # 'Resto' holds the probability mass of the dropped labels; clamp at 0 so
    # float rounding can never display a negative score.
    probas['Resto'] = max(0.0, 1 - sum(probas.values()))
    # Pick the winner among real labels only: if 'Resto' took part, it could win
    # whenever the tail outweighs the top label, and no subcategory would be shown.
    top_label = ranked[0][0] if ranked else None
    return probas, top_label


def predict(text):
    """Classify a complaint into macro categories, then into subcategories.

    Args:
        text: the complaint entered by the user.

    Returns:
        Dict keyed by the Gradio components defined in the DEMO section:
        ``macro_output`` receives the top-4 macro probabilities plus 'Resto';
        each per-category row is made visible only when its category appears in
        the top macro label, and that row's Label receives the category's
        subcategory prediction (None for hidden rows).
    """
    # The pipeline result is indexed once to get the list of
    # {'label', 'score'} dicts for this text.
    classes = pipeline_clf(text)[0]
    macro_probas, macro_label = _macro_probabilities(classes)
    # A label may name several categories joined by ' - ' (hence the split).
    macro_labels = macro_label.split(' - ') if macro_label else []

    # (macro category label, row container, output Label, subcategory classifier)
    routes = (
        ('Atención al cliente', row_cliente, cliente_component, clf_cliente),
        ('Conforme', row_conforme, conforme_component, clf_conforme),
        ('Devoluciones', row_devoluciones, devoluciones_component, clf_devoluciones),
        ('Entrega', row_entrega, entrega_component, clf_entrega),
        ('Financiamiento', row_financiamiento, financiamiento_component, clf_financiamiento),
        ('Otros', row_otros, otros_component, clf_otros),
        ('Stock', row_stock, stock_component, clf_stock),
        ('Trade Marketing', row_marketing, marketing_component, clf_marketing),
        ('Ventas', row_ventas, ventas_component, clf_ventas),
    )

    output = {macro_output: macro_probas}
    for label, row, component, clf in routes:
        is_present = label in macro_labels
        output[row] = gr.update(visible = is_present)
        # Only run a subcategory model when its row will actually be shown.
        output[component] = clf.predict([text]) if is_present else None
    return output
# DEMO
# Builds the UI: a complaint textbox, the macro-category Label, and one hidden
# row per macro category whose visibility and content predict() controls.
with gr.Blocks(title = 'Modelo NPS') as demo:
    # Same text as a triple-quoted literal, written as concatenated pieces so the
    # Markdown is unaffected by this block's indentation.
    gr.Markdown(
        '\n'
        '# <center>Modelo de Clasificación NPS</center>\n'
        'Este es un modelo para categorizar reclamos de NPS, prueba escribiendo reclamos abajo!\n'
    )
    with gr.Column() as text_col:
        with gr.Row():
            text_input = gr.Textbox(placeholder = "Ingresa el reclamo acá", label = 'Reclamo')
        with gr.Row():
            # gr.Label replaces the deprecated gr.outputs.Label alias.
            macro_output = gr.Label(label = 'Categorías Generales')
        # Per-category subcategory rows, hidden until predict() reveals them.
        with gr.Row():
            with gr.Row(visible = False) as row_cliente:
                cliente_component = gr.Label(label = 'Categorías Atención al Cliente')
            with gr.Row(visible = False) as row_conforme:
                conforme_component = gr.Label(label = 'Categorías Conforme')
            with gr.Row(visible = False) as row_devoluciones:
                devoluciones_component = gr.Label(label = 'Categorías Devoluciones')
            with gr.Row(visible = False) as row_entrega:
                entrega_component = gr.Label(label = 'Categorías Entrega')
            with gr.Row(visible = False) as row_financiamiento:
                financiamiento_component = gr.Label(label = 'Categorías Financiamiento')
            with gr.Row(visible = False) as row_otros:
                otros_component = gr.Label(label = 'Categorías Otros')
            with gr.Row(visible = False) as row_stock:
                stock_component = gr.Label(label = 'Categorías Stock')
            with gr.Row(visible = False) as row_marketing:
                marketing_component = gr.Label(label = 'Categorías Trade Marketing')
            with gr.Row(visible = False) as row_ventas:
                ventas_component = gr.Label(label = 'Categorías Ventas')
    # Every component predict() returns a value for: it returns a dict keyed by
    # these components, so each must be registered as an output here.
    outputs = [
        macro_output, cliente_component, conforme_component, devoluciones_component,
        entrega_component, financiamiento_component, otros_component, stock_component,
        marketing_component, ventas_component, row_cliente, row_conforme,
        row_devoluciones, row_entrega, row_financiamiento, row_otros,
        row_stock, row_marketing, row_ventas,
    ]
    button = gr.Button('Submit')
    button.click(fn = predict, inputs = text_input, outputs = outputs)
    gr.Examples(
        examples = [['sale mas a cuenta comprar en los supermercados que a la cervecería'],
                    ['llega las latas abolladas sucias'],
                    ['vendedor no viene presencialmente solo por whatsapp'],
                    ['mejorar la atención de los repartidores porque roban'],
                    ['seria bueno mas promociones y publicidad']],
        inputs = text_input)
demo.launch()