BK-AI commited on
Commit
3a52501
1 Parent(s): 226545e

initial commit building on PoC

Browse files
app.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding: utf-8
3
+
4
+ import gradio as gr
5
+ import numpy as np
6
+ import requests
7
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer, TextClassificationPipeline, pipeline
8
+ from langdetect import detect
9
+ from matplotlib import pyplot as plt
10
+ import imageio
11
+
12
+ # Load the model
13
+ model = AutoModelForSequenceClassification.from_pretrained("saved_model")
14
+ tokenizer = AutoTokenizer.from_pretrained("saved_model")
15
+ pipe = TextClassificationPipeline(model=model, tokenizer=tokenizer)
16
+
17
+ # Function called by the UI
18
+ def attribution(text):
19
+
20
+ # Clean the plot
21
+ plt.clf()
22
+
23
+ # Detect the language
24
+ language = detect(text)
25
+
26
+ # Translate the input in german if necessary
27
+ if language == 'fr':
28
+ translator = pipeline("translation", model="Helsinki-NLP/opus-mt-fr-de")
29
+ translatedText = translator(text[0:1000])
30
+ text = translatedText[0]["translation_text"]
31
+ elif language != 'de':
32
+ return "The language is not recognized, it must be either in German or in French.", None
33
+
34
+ # Set the bars of the bar chart
35
+ bars = ""
36
+ if language == 'fr':
37
+ bars = ("DDPS", "DFI", "AS-MPC", "DFJP", "DEFR", "DETEC", "DFAE", "Parl", "ChF", "DFF", "AF", "TF")
38
+ else:
39
+ bars = ("VBS", "EDI", "AB-BA", "EJPD", "WBF", "UVEK", "EDA", "Parl", "BK", "EFD", "BV", "BGer")
40
+
41
+ # Make the prediction with the 1000 first characters
42
+ results = pipe(text[0:1000], return_all_scores=True)
43
+ rates = [row["score"] for row in results[0]]
44
+
45
+ # Bar chart
46
+ y_pos = np.arange(len(bars))
47
+ plt.barh(y_pos, rates)
48
+ plt.yticks(y_pos, bars)
49
+
50
+ # Set the output text
51
+ name = ""
52
+ maxRate = np.max(rates)
53
+ maxIndex = np.argmax(rates)
54
+
55
+ # ML model not sure if highest probability < 60%
56
+ if maxRate < 0.6:
57
+ # de / fr
58
+ if language == 'de':
59
+ name = "Das ML-Modell ist nicht sicher. Das Departement könnte sein : \n\n"
60
+ else:
61
+ name = "Le modèle ML n'est pas sûr. Le département pourrait être : \n\n"
62
+ i = 0
63
+ # Show each department that has a probability > 10%
64
+ while i == 0:
65
+ if rates[maxIndex] >= 0.1:
66
+ name = name + "\t" + str(rates[maxIndex])[2:4] + "%" + "\t\t\t\t\t" + bars[maxIndex] + "\n"
67
+ rates[maxIndex] = 0
68
+ maxIndex = np.argmax(rates)
69
+ else:
70
+ i = 1
71
+ # ML model pretty sure, show only one department
72
+ else:
73
+ name = str(maxRate)[2:4] + "%" + "\t\t\t\t\t\t" + bars[maxIndex]
74
+
75
+ # Save the bar chart as png and load it (enables better display)
76
+ plt.savefig('rates.png')
77
+ im = imageio.imread('rates.png')
78
+
79
+ return name, im
80
+
81
+
82
+ # display the UI
83
+ interface = gr.Interface(fn=attribution,
84
+ inputs=[gr.inputs.Textbox(lines=20, placeholder="Geben Sie bitte den Titel und den Sumbmitted Text des Vorstoss ein.\nVeuillez entrer le titre et le Submitted Text de la requête.")],
85
+ outputs=['text', 'image'])
86
+ interface.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ numpy
2
+ transformers
3
+ langdetect
4
+ matplotlib
5
+ imageio
6
+ torch
7
+ sentencepiece
saved_model/config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f1dd5122dedc8fdf6eb1ec32b25f3246f8c3c64432abfd4d9bad4b626f378fc4
3
+ size 1255
saved_model/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:766bc088608d7eb221c530029f2a704887b4072dab8b79448ec89729aef0bd87
3
+ size 436430445
saved_model/special_tokens_map.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:303df45a03609e4ead04bc3dc1536d0ab19b5358db685b6f3da123d05ec200e3
3
+ size 112
saved_model/tokenizer.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5d6f6affc6b91020cabef56fe9289907e34a89e7f3463a93250c0d94cc61000d
3
+ size 726371
saved_model/tokenizer_config.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5ed472c8edcb18869d09d7bc852465911b105dd301fda14b4283b01577a5ebd7
3
+ size 327
saved_model/training_args.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d35049b1861176f257b17db726fad1ace03c2b81216d26e34180592cfe717fa2
3
+ size 3183
saved_model/vocab.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:982f8396ec746db0ed414dcc4789398ab6b365663cada50f776afb905dacbb61
3
+ size 254729