DoBaumann committed on
Commit
8ec661c
·
1 Parent(s): 2f4692b

bert code into app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -1
app.py CHANGED
@@ -1,7 +1,89 @@
1
  import gradio as gr
 
 
 
 
 
 
2
 
 
3
  def greet(name):
4
  return "Hello " + name + "!!"
5
 
6
  iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import numpy as np
3
+ # import requests
4
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, TextClassificationPipeline, pipeline
5
+ from langdetect import detect
6
+ from matplotlib import pyplot as plt
7
+ import imageio
8
 
9
+ """
10
  def greet(name):
11
  return "Hello " + name + "!!"
12
 
13
  iface = gr.Interface(fn=greet, inputs="text", outputs="text")
14
+ iface.launch()
15
+ """
16
# Load the classifier and its tokenizer once at import time and wrap them
# in a text-classification pipeline that the UI callback reuses.
MODEL_DIR = "saved_model"  # presumably a local directory shipped with the app; verify

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR)
pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer)
20
+
21
+
22
# Function called by the UI

# Number of leading characters of the input that are sent to the classifier.
MAX_INPUT_CHARS = 512
# If the best score is below this value the model is treated as unsure and
# several candidate departments are listed instead of a single one.
CONFIDENCE_THRESHOLD = 0.6
# Minimum score a department needs to be listed when the model is unsure.
CANDIDATE_THRESHOLD = 0.1

# Department labels; position i corresponds to score i of the classifier.
BARS_FR = ("DDPS", "DFI", "AS-MPC", "DFJP", "DEFR", "DETEC", "DFAE", "Parl", "ChF", "DFF", "AF", "TF")
BARS_DE = ("VBS", "EDI", "AB-BA", "EJPD", "WBF", "UVEK", "EDA", "Parl", "BK", "EFD", "BV", "BGer")

# Cache of translation pipelines keyed by model name, so a model is loaded
# only on first use instead of on every request.
_translators = {}


def _get_translator(model_name):
    """Return the translation pipeline for *model_name*, creating it on first use."""
    if model_name not in _translators:
        _translators[model_name] = pipeline("translation", model=model_name)
    return _translators[model_name]


def _format_percent(rate):
    """Format a probability in [0, 1] as a rounded whole percentage (0.5 -> '50%')."""
    return f"{rate:.0%}"


def _describe_prediction(rates, bars, french):
    """Build the text shown next to the chart.

    Returns a single "<percent> <department>" line when the best score reaches
    CONFIDENCE_THRESHOLD; otherwise a header (French if *french*, else German)
    followed by every department scoring at least CANDIDATE_THRESHOLD, best
    first.
    """
    best = int(np.argmax(rates))
    if rates[best] >= CONFIDENCE_THRESHOLD:
        return _format_percent(rates[best]) + "\t\t\t\t\t\t" + bars[best]

    if french:
        header = "Le modèle ML n'est pas sûr. Le département pourrait être : \n\n"
    else:
        header = "Das ML-Modell ist nicht sicher. Das Departement könnte sein : \n\n"

    # sorted() is stable, so departments with equal scores keep their
    # original order, matching a repeated "take the first maximum" scan.
    ranked = sorted(zip(rates, bars), key=lambda pair: pair[0], reverse=True)
    candidates = [
        "\t" + _format_percent(rate) + "\t\t\t\t\t" + label + "\n"
        for rate, label in ranked
        if rate >= CANDIDATE_THRESHOLD
    ]
    return header + "".join(candidates)


def attribution(text):
    """Classify *text* and return a summary string plus a bar chart image.

    The language of the input is detected; French input is translated to
    German before classification and gets French labels and messages, any
    other input is classified as-is with German labels and messages.

    Args:
        text: Text entered in the UI.

    Returns:
        A tuple ``(name, im)`` where ``name`` is the textual result and
        ``im`` is the rendered bar chart of all department scores, read back
        from ``rates.png``.
    """
    language = detect(text)
    french = language == 'fr'

    # Translate the input to German if necessary
    if french:
        translator = _get_translator("Helsinki-NLP/opus-mt-fr-de")
        text = translator(text)[0]["translation_text"]

    bars = BARS_FR if french else BARS_DE

    # Make the prediction with the first MAX_INPUT_CHARS characters
    results = pipe(text[:MAX_INPUT_CHARS], return_all_scores=True)
    rates = [row["score"] for row in results[0]]

    # Bar chart (clear the shared pyplot figure left over from the last call)
    plt.clf()
    y_pos = np.arange(len(bars))
    plt.barh(y_pos, rates)
    plt.yticks(y_pos, bars)

    name = _describe_prediction(rates, bars, french)

    # Save the bar chart as png and load it (enables better display)
    plt.savefig('rates.png')
    im = imageio.imread('rates.png')

    return name, im
82
+
83
+
84
# Build the UI: one multi-line text box in, the result text and the bar
# chart image out, then start serving it.
interface = gr.Interface(
    fn=attribution,
    layout="vertical",
    inputs=[
        gr.inputs.Textbox(
            lines=20,
            # German and French prompt; "Submitted" was misspelled in the German half.
            placeholder="Geben Sie bitte den Titel und den Submitted Text des Vorstoss ein.\nVeuillez entrer le titre et le Submitted Text de la requête.",
        )
    ],
    outputs=['text', 'image'],
)
interface.launch()