File size: 3,046 Bytes
2f4692b
8ec661c
 
 
 
 
 
2f4692b
8ec661c
2f4692b
 
 
 
8ec661c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import gradio as gr
import numpy as np
# import requests
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TextClassificationPipeline, pipeline
from langdetect import detect
from matplotlib import pyplot as plt
import imageio

# Disabled "hello world" Gradio example, kept as a bare string literal.
# It comes after the imports, so it is not the module docstring; the string is
# evaluated and discarded, and none of the code written inside it runs.
"""
def greet(name):
    return "Hello " + name + "!!"

iface = gr.Interface(fn=greet, inputs="text", outputs="text")
iface.launch()
"""
# Load the tokenizer and the sequence-classification model from the local
# "saved_model" location, then wrap both in a text-classification pipeline
# that `attribution` calls for every request.
tokenizer = AutoTokenizer.from_pretrained("saved_model")
model = AutoModelForSequenceClassification.from_pretrained("saved_model")
pipe = TextClassificationPipeline(tokenizer=tokenizer, model=model)


# Number of leading characters of the (possibly translated) input that are
# handed to the classifier.
_MAX_INPUT_CHARS = 512
# At or above this top score only the best department is reported.
_CONFIDENT_THRESHOLD = 0.6
# Minimum score for a department to be listed when the model is not sure.
_CANDIDATE_THRESHOLD = 0.1

# Department abbreviations used as bar labels, in French and in German.
# NOTE(review): the order presumably matches the label order of the saved
# model's output scores -- confirm against the training setup.
_BARS_FR = ("DDPS", "DFI", "AS-MPC", "DFJP", "DEFR", "DETEC", "DFAE", "Parl", "ChF", "DFF", "AF", "TF")
_BARS_DE = ("VBS", "EDI", "AB-BA", "EJPD", "WBF", "UVEK", "EDA", "Parl", "BK", "EFD", "BV", "BGer")

# French -> German translation pipeline, created on first use and then reused
# so that it is not rebuilt for every French request.
_fr_de_translator = None


def _detect_language(text):
    """Return the language code langdetect reports for *text*, or "de" on failure."""
    # Imported locally so the block does not depend on an extra file-level import.
    from langdetect.lang_detect_exception import LangDetectException

    try:
        return detect(text)
    except LangDetectException:
        # Raised e.g. for empty or feature-less input; treat it like the
        # untranslated (German) path instead of failing the request.
        return "de"


def _translate_fr_to_de(text):
    """Translate French *text* to German with the cached opus-mt-fr-de pipeline."""
    global _fr_de_translator
    if _fr_de_translator is None:
        _fr_de_translator = pipeline("translation", model="Helsinki-NLP/opus-mt-fr-de")
    return _fr_de_translator(text)[0]["translation_text"]


def _format_percent(rate):
    """Format a score in [0, 1] as a whole-number percentage such as "60%"."""
    # Slicing str(rate) would turn 0.6 into "6%" and 1.0 into "0%".
    return "{:.0%}".format(rate)


def _summarize(rates, bars, is_french):
    """Build the result text for the given scores and their labels.

    Returns a single "<percent> <label>" line when the best score reaches
    _CONFIDENT_THRESHOLD. Otherwise returns a header followed by one line per
    label whose score is at least _CANDIDATE_THRESHOLD, best first.
    """
    # Stable sort: equal scores keep their original (lowest index first) order.
    ranking = sorted(range(len(rates)), key=rates.__getitem__, reverse=True)
    best = ranking[0]

    if rates[best] >= _CONFIDENT_THRESHOLD:
        return _format_percent(rates[best]) + "\t\t\t\t\t\t" + bars[best]

    if is_french:
        header = "Le modèle ML n'est pas sûr. Le département pourrait être : \n\n"
    else:
        header = "Das ML-Modell ist nicht sicher. Das Departement könnte sein : \n\n"

    lines = [header]
    for idx in ranking:
        if rates[idx] < _CANDIDATE_THRESHOLD:
            break
        lines.append("\t" + _format_percent(rates[idx]) + "\t\t\t\t\t" + bars[idx] + "\n")
    return "".join(lines)


# Function called by the UI
def attribution(text):
    """Classify *text* and return the result text together with a bar chart.

    The language of the input is detected; French input is translated to
    German before classification and gets French labels and messages, any
    other input is classified as-is and gets German labels and messages.

    Args:
        text: The text entered in the UI.

    Returns:
        A tuple ``(name, im)``: ``name`` is the result text, ``im`` is the
        image of the horizontal bar chart of all scores, as loaded by imageio.
    """
    text = text or ""

    # Clean the plot (pyplot's current figure is reused between calls)
    plt.clf()

    # One flag drives translation, bar labels and message language, so the
    # output can never mix French text with German labels or vice versa.
    is_french = _detect_language(text) == "fr"
    if is_french:
        text = _translate_fr_to_de(text)
    bars = _BARS_FR if is_french else _BARS_DE

    # Make the prediction with the first _MAX_INPUT_CHARS characters
    results = pipe(text[:_MAX_INPUT_CHARS], return_all_scores=True)
    rates = [row["score"] for row in results[0]]

    # Bar chart
    y_pos = np.arange(len(bars))
    plt.barh(y_pos, rates)
    plt.yticks(y_pos, bars)

    # Set the output text
    name = _summarize(rates, bars, is_french)

    # Save the bar chart as png and load it (enables better display).
    # NOTE(review): 'rates.png' is a fixed path in the working directory that
    # every call overwrites -- confirm this is acceptable for concurrent users.
    plt.savefig('rates.png')
    im = imageio.imread('rates.png')

    return name, im


# Build and display the UI: a 20-line textbox feeds `attribution`, whose two
# return values are shown as a text output and an image output.
# The placeholder prompts in German and in French for the title and the
# "Submitted Text" of the request.
interface = gr.Interface(fn=attribution, layout="vertical",
                         inputs=[gr.inputs.Textbox(lines=20,
                                                   placeholder="Geben Sie bitte den Titel und den Submitted Text des Vorstoss ein.\nVeuillez entrer le titre et le Submitted Text de la requête.")],
                         outputs=['text', 'image'])
interface.launch()