BKMotionsAI / app.py
BFH's picture
Upload app.py
5c49760
raw
history blame
3.07 kB
#!/usr/bin/env python
# coding: utf-8
import io
from functools import lru_cache

import gradio as gr
import imageio
import numpy as np
import requests
from langdetect import detect
from langdetect.lang_detect_exception import LangDetectException
from matplotlib import pyplot as plt
from matplotlib.figure import Figure
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TextClassificationPipeline, pipeline
# Load the fine-tuned sequence classifier and its tokenizer from the local
# "saved_model" directory and wrap both in a text-classification pipeline
# (`pipe`), which `attribution` uses for every prediction.
_MODEL_DIR = "saved_model"
tokenizer = AutoTokenizer.from_pretrained(_MODEL_DIR)
model = AutoModelForSequenceClassification.from_pretrained(_MODEL_DIR)
pipe = TextClassificationPipeline(tokenizer=tokenizer, model=model)
# Bar labels: position i names the i-th score returned by `pipe` (the scores
# and labels are paired purely by index below).
_BARS_DE = ("VBS", "EDI", "AB-BA", "EJPD", "WBF", "UVEK",
            "EDA", "Parl", "BK", "EFD", "BV", "BGer")
_BARS_FR = ("DDPS", "DFI", "AS-MPC", "DFJP", "DEFR", "DETEC",
            "DFAE", "Parl", "ChF", "DFF", "AF", "TF")

# Only this many leading characters are translated / classified.
_MAX_CHARS = 1000
# Below this top probability the model is treated as "not sure".
_CONFIDENT_THRESHOLD = 0.6
# In the "not sure" case, every label at or above this probability is listed.
_LISTING_THRESHOLD = 0.1


@lru_cache(maxsize=1)
def _get_fr_de_translator():
    """Return the French->German translation pipeline, loading it on first use.

    Cached so the model is loaded once per process instead of on every
    French request.
    """
    return pipeline("translation", model="Helsinki-NLP/opus-mt-fr-de")


def _percent(rate):
    """Format a probability in [0, 1] as a whole-number percentage (0.9 -> '90%')."""
    return format(rate, ".0%")


def _format_result(language, bars, rates):
    """Build the textual prediction shown in the UI.

    If the top probability reaches ``_CONFIDENT_THRESHOLD``, only that label is
    returned; otherwise a localized "not sure" header followed by every label
    with probability >= ``_LISTING_THRESHOLD``, highest first.
    """
    best = int(np.argmax(rates))
    if rates[best] >= _CONFIDENT_THRESHOLD:
        return _percent(rates[best]) + "\t\t\t\t\t\t" + bars[best]

    if language == 'de':
        header = "Das ML-Modell ist nicht sicher. Das Departement könnte sein : \n\n"
    else:
        header = "Le modèle ML n'est pas sûr. Le département pourrait être : \n\n"
    # sorted() is stable even with reverse=True, so ties keep index order,
    # matching what repeated np.argmax would pick.
    ranked = sorted(range(len(rates)), key=lambda i: rates[i], reverse=True)
    lines = [
        "\t" + _percent(rates[i]) + "\t\t\t\t\t" + bars[i] + "\n"
        for i in ranked
        if rates[i] >= _LISTING_THRESHOLD
    ]
    return header + "".join(lines)


def _render_chart(bars, rates):
    """Draw a horizontal bar chart of *rates* labelled by *bars*.

    Returns the chart as an image array decoded from an in-memory PNG.
    A fresh Figure per call (instead of the global pyplot figure and a shared
    'rates.png' on disk) keeps concurrent requests from clobbering each
    other's chart.
    """
    fig = Figure()
    ax = fig.subplots()
    y_pos = np.arange(len(bars))
    ax.barh(y_pos, rates)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(bars)
    buf = io.BytesIO()
    fig.savefig(buf, format="png")
    return imageio.imread(buf.getvalue())


def attribution(text):
    """Classify *text* into one of the labels in ``_BARS_DE`` / ``_BARS_FR``.

    French input is first translated to German, and the classifier then runs
    on the German text. Any other detected language (or text in which no
    language can be detected) is rejected. Only the first ``_MAX_CHARS``
    characters are translated / classified.

    Args:
        text: Title and submitted text entered in the UI.

    Returns:
        ``(message, image)``: *message* lists the predicted label(s) with
        their probabilities, and *image* is a bar chart of all scores. For an
        unsupported or undetectable language, returns an error message and
        ``None``.
    """
    try:
        language = detect(text)
    except LangDetectException:
        # langdetect raises on input with no detectable features (e.g. "").
        language = None

    if language == 'fr':
        translated = _get_fr_de_translator()(text[:_MAX_CHARS])
        text = translated[0]["translation_text"]
    elif language != 'de':
        return "The language is not recognized, it must be either in German or in French.", None

    bars = _BARS_FR if language == 'fr' else _BARS_DE
    # return_all_scores=True yields one score per label, in label order.
    results = pipe(text[:_MAX_CHARS], return_all_scores=True)
    rates = [row["score"] for row in results[0]]

    return _format_result(language, bars, rates), _render_chart(bars, rates)
# Build and launch the Gradio UI: one 20-line text input (German or French)
# and two outputs, the prediction text and the bar-chart image returned by
# `attribution`.
# NOTE(review): `gr.inputs.*` and `layout=` are legacy Gradio APIs (removed in
# Gradio 4); migrate to `gr.Textbox` if the pinned Gradio version is upgraded.
interface = gr.Interface(
    fn=attribution,
    layout="vertical",
    inputs=[
        gr.inputs.Textbox(
            lines=20,
            placeholder=(
                "Geben Sie bitte den Titel und den Submitted Text des Vorstosses ein.\n"
                "Veuillez entrer le titre et le Submitted Text de la requête."
            ),
        )
    ],
    outputs=['text', 'image'],
)
interface.launch()