File size: 6,027 Bytes
92be70e
 
5da0eba
92be70e
 
 
c09ae13
 
92be70e
5da0eba
 
92be70e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70ebd67
ff69146
92be70e
 
70ebd67
92be70e
 
70ebd67
92be70e
 
5da0eba
70ebd67
 
92be70e
 
 
 
 
 
 
5da0eba
92be70e
 
5da0eba
92be70e
 
 
 
 
 
 
 
 
 
 
 
 
 
b3cae23
92be70e
 
 
 
 
 
 
 
 
8ac7f12
92be70e
1c9f7e0
92be70e
 
20062d7
37fe09a
 
 
1fd8303
 
ff69146
 
1fd8303
ab91900
 
 
 
 
1fd8303
92be70e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import datetime
import gradio as gr
from huggingface_hub import hf_hub_download
import fasttext, torch, clip
from sentence_transformers import SentenceTransformer, util
 
# CLIP ViT-B/32 model; sequence_to_classify uses it when the detected language
# is 'en'. The second value returned by clip.load is discarded.
model_en, _ = clip.load("ViT-B/32")
# Multilingual CLIP text encoder; sequence_to_classify uses it for every
# language other than 'en'.
model_multi = SentenceTransformer("sentence-transformers/clip-ViT-B-32-multilingual-v1")

# fastText language-identification model, fetched from the Hugging Face Hub
# (repo "julien-c/fasttext-language-id", file "lid.176.bin"); used by detect_lang.
fasttext_model = fasttext.load_model(hf_hub_download("julien-c/fasttext-language-id", "lid.176.bin"))

def prep_examples():
    """Build the example inputs shown in the demo UI.

    Returns:
        A list of ``[text, labels]`` pairs, where ``labels`` is a single
        string of candidate labels separated by ``;;``.

    Long texts are assembled with implicit string concatenation rather than
    backslash continuations inside the literal, so that source indentation
    does not leak into the example text as runs of spaces.
    """
    example_text1 = (
        "Coronavirus disease (COVID-19) is an infectious disease caused by the SARS-CoV-2 virus. "
        "Most people who fall sick with COVID-19 will experience mild to moderate symptoms "
        "and recover without special treatment. "
        "However, some will become seriously ill and require medical attention."
    )
    example_labels1 = "business;;health related;;politics;;climate change"

    example_text2 = "Elephants are"
    example_labels2 = "big;;small;;strong;;fast;;carnivorous"

    example_text3 = "Elephants"
    example_labels3 = "are big;;can be very small;;generally not strong enough;;are faster than you think"

    example_text4 = "Dogs are man's best friend"
    example_labels4 = "positive;;negative;;neutral"

    example_text5 = (
        "Şampiyonlar Ligi’nde 5. hafta oynanan karşılaşmaların ardından sona erdi. Real Madrid, "
        "Inter ve Sporting oynadıkları mücadeleler sonrasında Son 16 turuna yükselmeyi başardı. "
        "Gecenin dev mücadelesinde ise Manchester City, PSG’yi yenerek liderliği garantiledi."
    )
    example_labels5 = "dünya;;ekonomi;;kültür;;siyaset;;spor;;teknoloji"

    example_text6 = "Letzte Woche gab es einen Selbstmord in einer nahe gelegenen kolonie"
    example_labels6 = "verbrechen;;tragödie;;stehlen"

    example_text7 = "El autor se perfila, a los 50 años de su muerte, como uno de los grandes de su siglo"
    example_labels7 = "cultura;;sociedad;;economia;;salud;;deportes"

    example_text8 = (
        "Россия в среду заявила, что военные учения в аннексированном Москвой Крыму закончились "
        "и что солдаты возвращаются в свои гарнизоны, на следующий день после того, как она объявила о первом выводе "
        "войск от границ Украины."
    )
    example_labels8 = "новости;;комедия"

    example_text9 = (
        "I quattro registi - Federico Fellini, Pier Paolo Pasolini, Bernardo Bertolucci e Vittorio De Sica - "
        "hanno utilizzato stili di ripresa diversi, ma hanno fortemente influenzato le giovani generazioni di registi."
    )
    example_labels9 = "cinema;;politica;;cibo"

    example_text10 = (
        "Ja, vi elsker dette landet, "
        "som det stiger frem, "
        "furet, værbitt over vannet, "
        "med de tusen hjem. "
        "Og som fedres kamp har hevet "
        "det av nød til seir"
    )
    example_labels10 = "helse;;sport;;religion;;mat;;patriotisme og nasjonalisme"

    example_text11 = "Amar sonar bangla ami tomay bhalobasi"
    example_labels11 = "bhalo;;kharap"

    examples = [
        [example_text1, example_labels1],
        [example_text2, example_labels2],
        [example_text3, example_labels3],
        [example_text4, example_labels4],
        [example_text5, example_labels5],
        [example_text6, example_labels6],
        [example_text7, example_labels7],
        [example_text8, example_labels8],
        [example_text9, example_labels9],
        [example_text10, example_labels10],
        [example_text11, example_labels11],
    ]

    return examples

def detect_lang(text):
    """Return the language code fastText predicts for *text*.

    Args:
        text: The input sequence; newlines are replaced with spaces before
            prediction.

    Returns:
        The code that follows the ``__label__`` prefix of the top-1
        prediction, or ``'en'`` when detection fails for any reason
        (the failure is printed, not raised).
    """
    seq_lang = 'en'

    # Prediction is done on a single line of text, so flatten newlines first.
    text = text.replace('\n', ' ')

    try:
        # The top-1 label has the form "__label__<code>"; keep only <code>.
        seq_lang = fasttext_model.predict(text, k=1)[0][0].split("__label__")[1]
    except Exception:
        # Best effort: log the failure and fall back to the 'en' default.
        print("Language detection failed!",
              "Date:{}, Sequence: {}".format(
                  str(datetime.datetime.now()),
                  text))

    return seq_lang

def sequence_to_classify(text, labels):
    """Score *text* against candidate labels and return label probabilities.

    Args:
        text: The sequence to classify.
        labels: Candidate labels as one string, separated by ``;;``.

    Returns:
        A dict mapping each label to its softmax probability (a float),
        inserted in descending score order.

    The text and every label are embedded with the English CLIP model when
    detect_lang reports 'en', and with the multilingual sentence-transformers
    model otherwise; labels are ranked by cosine similarity to the text.
    """
    lang = detect_lang(text)
    use_clip = lang == 'en'
    model = model_en if use_clip else model_multi
    # Both models currently use the bare label as the hypothesis.
    hypothesis_template = "{}"

    labels = [hypothesis_template.format(label) for label in labels.split(";;")]

    # Inference only: no gradients are needed for the embeddings.
    with torch.no_grad():
        if use_clip:
            # Tokens must live on the same device as the model weights.
            # NOTE(review): clip.tokenize may reject inputs longer than the
            # model's context length; consider truncate=True if the installed
            # CLIP version supports it.
            device = next(model.parameters()).device
            text_features = model.encode_text(clip.tokenize(text).to(device))
            labels_features = model.encode_text(clip.tokenize(labels).to(device))
        else:
            text_features = torch.tensor(model.encode(text))
            labels_features = torch.tensor(model.encode(labels))

    # One row of similarities (text vs. each label); scale by 100 before the
    # softmax so the resulting distribution is sharper.
    sim_scores = util.cos_sim(text_features, labels_features)
    scores = (sim_scores[0].float().cpu() * 100).softmax(dim=-1).tolist()

    ranked = sorted(zip(scores, labels), key=lambda pair: pair[0], reverse=True)
    predicted_scores = [score for score, _ in ranked]
    predicted_labels = [label for _, label in ranked]
    print(predicted_labels)
    print(predicted_scores)

    output = {label: float(score) for label, score in zip(predicted_labels, predicted_scores)}
    print("Date:{}, Sequence:{}, Labels: {}".format(
        str(datetime.datetime.now()),
        text,
        predicted_labels))

    return output

# Gradio UI: two text boxes (the text to classify and the ";;"-separated
# candidate labels) wired to sequence_to_classify, with the returned
# label -> score dict rendered by a Label output limited to the top 5 classes.
iface = gr.Interface(
    title="Alternate Zero-shot Multi-label Multilingual NLP Classifier",
    description="Work in progress.",
    fn=sequence_to_classify,
    inputs=[gr.inputs.Textbox(lines=10,
        label="Please enter the text you would like to classify...",
        placeholder="Text here..."),
        gr.inputs.Textbox(lines=2,
        label="Please enter the candidate labels (separated by 2 consecutive semicolons)...",
        placeholder="Labels here separated by ;;")],
    outputs=gr.outputs.Label(num_top_classes=5),
    #interpretation="default",
    # Clickable examples, each a [text, labels] pair from prep_examples().
    examples=prep_examples())

# Start the app; this runs at import time because there is no __main__ guard.
iface.launch()