Spaces:
Runtime error
Runtime error
#!/usr/bin/env python | |
# coding: utf-8 | |
import functools
import io

import gradio as gr
import imageio
import numpy as np
import requests
from langdetect import detect
from langdetect.lang_detect_exception import LangDetectException
from matplotlib import pyplot as plt
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TextClassificationPipeline, pipeline
# Load the fine-tuned sequence classifier and its tokenizer from the local
# "saved_model" directory and wrap both in a text-classification pipeline
# that `attribution` queries for per-department scores.
MODEL_DIR = "saved_model"
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
pipe = TextClassificationPipeline(tokenizer=tokenizer, model=model)
# Function called by the UI | |
def attribution(text): | |
# Clean the plot | |
plt.clf() | |
# Detect the language | |
language = detect(text) | |
# Translate the input in german if necessary | |
if language == 'fr': | |
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-fr-de") | |
translatedText = translator(text[0:1000]) | |
text = translatedText[0]["translation_text"] | |
elif language != 'de': | |
return "The language is not recognized, it must be either in German or in French.", None | |
# Set the bars of the bar chart | |
bars = "" | |
if language == 'fr': | |
bars = ("DDPS", "DFI", "AS-MPC", "DFJP", "DEFR", "DETEC", "DFAE", "Parl", "ChF", "DFF", "AF", "TF") | |
else: | |
bars = ("VBS", "EDI", "AB-BA", "EJPD", "WBF", "UVEK", "EDA", "Parl", "BK", "EFD", "BV", "BGer") | |
# Make the prediction with the 1000 first characters | |
results = pipe(text[0:1000], return_all_scores=True) | |
rates = [row["score"] for row in results[0]] | |
# Bar chart | |
y_pos = np.arange(len(bars)) | |
plt.barh(y_pos, rates) | |
plt.yticks(y_pos, bars) | |
# Set the output text | |
name = "" | |
maxRate = np.max(rates) | |
maxIndex = np.argmax(rates) | |
# ML model not sure if highest probability < 60% | |
if maxRate < 0.6: | |
# de / fr | |
if language == 'de': | |
name = "Das ML-Modell ist nicht sicher. Das Departement könnte sein : \n\n" | |
else: | |
name = "Le modèle ML n'est pas sûr. Le département pourrait être : \n\n" | |
i = 0 | |
# Show each department that has a probability > 10% | |
while i == 0: | |
if rates[maxIndex] >= 0.1: | |
name = name + "\t" + str(rates[maxIndex])[2:4] + "%" + "\t\t\t\t\t" + bars[maxIndex] + "\n" | |
rates[maxIndex] = 0 | |
maxIndex = np.argmax(rates) | |
else: | |
i = 1 | |
# ML model pretty sure, show only one department | |
else: | |
name = str(maxRate)[2:4] + "%" + "\t\t\t\t\t\t" + bars[maxIndex] | |
# Save the bar chart as png and load it (enables better display) | |
plt.savefig('rates.png') | |
im = imageio.imread('rates.png') | |
return name, im | |
# Build the web UI: a 20-line text box in; the prediction text and the
# bar-chart image out.
# gr.inputs.* and the Interface `layout` argument are no longer accepted by
# current Gradio releases; gr.Textbox is the supported component API.
interface = gr.Interface(
    fn=attribution,
    inputs=[
        gr.Textbox(
            lines=20,
            placeholder=(
                "Geben Sie bitte den Titel und den Submitted Text des Vorstoss ein.\n"
                "Veuillez entrer le titre et le Submitted Text de la requête."
            ),
        )
    ],
    outputs=["text", "image"],
)

# Launch only when run as a script (as `python app.py` does), not on import.
if __name__ == "__main__":
    interface.launch()