Spaces:
Runtime error
Runtime error
File size: 3,067 Bytes
5c49760 2f4692b 8ec661c 5c49760 8ec661c 2f4692b 8ec661c 5c49760 8ec661c 5c49760 8ec661c 5c49760 8ec661c 5c49760 8ec661c 5c49760 8ec661c 5c49760 8ec661c 5c49760 8ec661c 5c49760 8ec661c 5c49760 8ec661c 5c49760 8ec661c 5c49760 8ec661c 5c49760 8ec661c 5c49760 8ec661c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
#!/usr/bin/env python
# coding: utf-8
import functools

import gradio as gr
import numpy as np
import requests
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TextClassificationPipeline, pipeline
from langdetect import detect
from matplotlib import pyplot as plt
import imageio
# Location handed to ``from_pretrained`` for both the classifier weights and
# its tokenizer.
MODEL_PATH = "saved_model"

# Sequence-classification model plus matching tokenizer, combined into a
# pipeline so ``attribution`` can score raw strings in a single call.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
pipe = TextClassificationPipeline(tokenizer=tokenizer, model=model)
# Function called by the UI

# If the top score is below this value the model is treated as unsure and
# every plausible department is listed instead of a single answer.
CONFIDENCE_THRESHOLD = 0.6
# In the "unsure" listing, departments scoring below this value are omitted.
LISTING_THRESHOLD = 0.1
# Only this many leading characters are translated and classified.
MAX_CHARS = 1000

# Department abbreviations used as bar-chart labels, in French and German.
# Label i is paired with score i returned by ``pipe``.
# NOTE(review): assumes the model's label-id order matches these tuples;
# confirm against the saved model's config.
BARS_FR = ("DDPS", "DFI", "AS-MPC", "DFJP", "DEFR", "DETEC", "DFAE", "Parl", "ChF", "DFF", "AF", "TF")
BARS_DE = ("VBS", "EDI", "AB-BA", "EJPD", "WBF", "UVEK", "EDA", "Parl", "BK", "EFD", "BV", "BGer")

UNRECOGNIZED_LANGUAGE_MESSAGE = "The language is not recognized, it must be either in German or in French."


@functools.lru_cache(maxsize=1)
def _get_translator():
    """Return the French-to-German translation pipeline, building it only once.

    Constructing the pipeline loads model weights, so it is cached instead of
    being rebuilt on every French request.
    """
    return pipeline("translation", model="Helsinki-NLP/opus-mt-fr-de")


def _detect_language(text):
    """Return the langdetect language code for *text*, or None if undetectable.

    ``detect`` raises ``LangDetectException`` for input without usable
    features (for example empty or whitespace-only text). That case is mapped
    to None so the caller can report an unrecognized language instead of
    crashing.
    """
    if not text or not text.strip():
        return None
    from langdetect.lang_detect_exception import LangDetectException

    try:
        return detect(text)
    except LangDetectException:
        return None


def _format_percent(rate):
    """Format a probability in [0, 1] as a whole-number percentage ('10%', '100%')."""
    # Slicing digits out of str(rate) breaks for values such as 0.1 or 1.0
    # and for scientific notation; proper formatting handles every value.
    return f"{rate:.0%}"


def _describe_prediction(rates, bars, language):
    """Build the user-facing text for the per-department *rates*.

    Parameters
    ----------
    rates : sequence of float
        One score per department, aligned index-wise with *bars*.
    bars : sequence of str
        Department labels.
    language : str
        'de' selects the German "unsure" header; any other value selects French.

    Returns
    -------
    str
        If the best score is at least CONFIDENCE_THRESHOLD, a single
        "<pct>%<tabs><label>" line. Otherwise, a header followed by one line
        per department scoring at least LISTING_THRESHOLD, best first.
    """
    # Indices ordered by descending score. sorted() stays stable with
    # reverse=True, so ties keep lowest-index-first order, like repeated argmax.
    ranked = sorted(range(len(rates)), key=lambda i: rates[i], reverse=True)
    best = ranked[0]
    if rates[best] >= CONFIDENCE_THRESHOLD:
        return _format_percent(rates[best]) + "\t\t\t\t\t\t" + bars[best]
    if language == 'de':
        header = "Das ML-Modell ist nicht sicher. Das Departement könnte sein : \n\n"
    else:
        header = "Le modèle ML n'est pas sûr. Le département pourrait être : \n\n"
    entries = [
        "\t" + _format_percent(rates[i]) + "\t\t\t\t\t" + bars[i] + "\n"
        for i in ranked
        if rates[i] >= LISTING_THRESHOLD
    ]
    return header + "".join(entries)


def attribution(text):
    """Predict which department a submitted text should be attributed to.

    Parameters
    ----------
    text : str
        Title and submitted text, in German or French.

    Returns
    -------
    tuple
        ``(message, image)``. *message* is the text shown to the user. *image*
        is the bar chart of all department scores, saved to ``rates.png`` and
        read back with imageio. When the language is not German or French, or
        cannot be detected, it returns
        ``(UNRECOGNIZED_LANGUAGE_MESSAGE, None)``.
    """
    # Clear the pyplot figure left over from the previous call.
    plt.clf()
    language = _detect_language(text)
    if language not in ('de', 'fr'):
        return UNRECOGNIZED_LANGUAGE_MESSAGE, None
    if language == 'fr':
        # French input is translated to German before classification.
        translated = _get_translator()(text[:MAX_CHARS])
        text = translated[0]["translation_text"]
        bars = BARS_FR
    else:
        bars = BARS_DE
    # NOTE(review): return_all_scores is deprecated in newer transformers. Its
    # replacement, top_k=None, sorts scores by value, which would break the
    # index-based pairing with `bars`. Map by label before migrating.
    results = pipe(text[:MAX_CHARS], return_all_scores=True)
    rates = [row["score"] for row in results[0]]
    # Horizontal bar chart of every department's score.
    y_pos = np.arange(len(bars))
    plt.barh(y_pos, rates)
    plt.yticks(y_pos, bars)
    name = _describe_prediction(rates, bars, language)
    # Save the bar chart as png and load it back (enables better display).
    plt.savefig('rates.png')
    im = imageio.imread('rates.png')
    return name, im
# Display the UI: one multi-line text box in; the prediction text and the
# bar-chart image out.
# gr.Textbox replaces gr.inputs.Textbox, which was removed in Gradio 4.
# The `layout` argument is dropped because recent Gradio versions do not
# honor it.
interface = gr.Interface(
    fn=attribution,
    inputs=[
        gr.Textbox(
            lines=20,
            placeholder="Geben Sie bitte den Titel und den Submitted Text des Vorstoss ein.\nVeuillez entrer le titre et le Submitted Text de la requête.",
        )
    ],
    outputs=['text', 'image'],
)
interface.launch()