gundruke committed on
Commit
584d3dc
1 Parent(s): dc9f186

added app file

app.py ADDED
@@ -0,0 +1,149 @@
+ import gradio as gr
+ import torch
+ import json
+ from nltk.corpus import wordnet
+ from transformers import AutoConfig, AutoTokenizer
+ from models import BERTLstmCRF
+ from huggingface_hub import hf_hub_download
+
+ checkpoint = "gundruke/bert-lstm-crf-absa"
+ config = AutoConfig.from_pretrained(checkpoint)
+ id2label = config.id2label
+
+ tokenizer = AutoTokenizer.from_pretrained("gundruke/bert-lstm-crf-absa")
+ model = BERTLstmCRF(config)
+
+
+ repo = "gundruke/bert-lstm-crf-absa"
+ filename = "pytorch_model.bin"
+ model.load_state_dict(torch.load(hf_hub_download(repo_id=repo, filename=filename),
+                                  map_location=torch.device('cpu')))
+
+
+ def tokenize_text(text):
+     tokens = tokenizer.tokenize(text)
+     tokenized_text = tokenizer(text)
+
+     return tokens, tokenized_text
+
+
+ def convert_to_multilabel(label_list):
+     multilabel = []
+     if "B-POS" in label_list or "I-POS" in label_list:
+         multilabel.append("Positive")
+     if "B-NEG" in label_list or "I-NEG" in label_list:
+         multilabel.append("Negative")
+     if "B-NEU" in label_list or "I-NEU" in label_list:
+         multilabel.append("Neutral")
+
+     return " and ".join(multilabel)
+
+
+ def classify_word(word, dictionary):
+     synsets = wordnet.synsets(word)
+     if synsets:
+         hypernyms = synsets[0].hypernyms()  # Get the hypernym of the first synset
+         if hypernyms:
+             nltk_result = hypernyms[0].lemmas()[0].name()
+         else:
+             nltk_result = "Unknown"
+     else:
+         nltk_result = "Unknown"
+
+     if word in dictionary:
+         result = dictionary[word]
+     elif nltk_result in ['atmosphere', 'drinks', 'food', 'price', 'service']:
+         result = nltk_result
+     else:
+         result = 'other'
+
+     return result, nltk_result
+
+
+ def get_outputs(tokenized_text):
+     input_ids = tokenized_text["input_ids"]
+     token_type_ids = tokenized_text["token_type_ids"]
+     attention_mask = tokenized_text["attention_mask"]
+
+     inputs = {
+         'input_ids': torch.tensor([input_ids]),
+         'token_type_ids': torch.tensor([token_type_ids]),
+         'attention_mask': torch.tensor([attention_mask])
+     }
+
+     with torch.no_grad():
+         outputs = model(**inputs)
+
+     labels = [id2label.get(i) for i in torch.flatten(outputs[1]).tolist()][1:-1]
+
+     return labels
+
+
+ def join_wordpieces(tokens, labels):
+     joined_tokens = []
+
+     for token, label in zip(tokens, labels):
+         if label == "O":
+             label = None
+         if token.startswith("##"):
+             last_token = joined_tokens[-1][0]
+             joined_tokens[-1] = (last_token+token[2:], label)
+         else:
+             joined_tokens.append((token, label))
+
+     return joined_tokens
+
+
+ def get_category(word, dict_file):
+     with open(dict_file, "r") as file:
+         dictionary = json.load(file)
+
+     r, n = classify_word(word, dictionary)
+
+     return r
+
+
+ def text_analysis(text):
+     tokens, tokenized_text = tokenize_text(text)
+     labels = get_outputs(tokenized_text)
+     multilabel = convert_to_multilabel(labels)
+
+     token_tuple = join_wordpieces(tokens, labels)
+     tokenized_text["tokens"] = tokens
+
+     categories = []
+     for tok in token_tuple:
+         if tok[1]:
+             categories.append((tok[0], get_category(tok[0], "dictionary.json")))
+         else:
+             categories.append((tok[0], None))
+
+     return token_tuple, multilabel, categories
+
+
+ theme = gr.themes.Base()
+ with gr.Blocks(theme=theme) as demo:
+     with gr.Column():
+         input_textbox = gr.Textbox(placeholder="Enter sentence here...")
+         btn = gr.Button("Submit", variant="primary")
+
+         btn.click(fn=text_analysis,
+                   inputs=input_textbox,
+                   outputs=[gr.HighlightedText(label="Token labels"),
+                            gr.Label(label="Multilabel classification"),
+                            gr.HighlightedText(label="Category")],
+                   queue=False)
+
+     with gr.Column():
+         examples = [
+             ["I've been coming here as a child and always come back for the taste."],
+             ["The tea is great and all the sweets are homemade."],
+             ["Strong build which really adds to its durability but poor battery life."],
+             ["We loved the recommendation for the wine, and I think the eggplant parmigiana appetizer should become an entree."]
+         ]
+         gr.Examples(examples, input_textbox)
+
+ demo.launch()
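
The label post-processing above is plain Python and can be sanity-checked without loading the model. A minimal illustration of join_wordpieces with hypothetical tokens and labels (not an actual model prediction); note that a merged word keeps the label of its last WordPiece, and "O" becomes None so Gradio leaves it unhighlighted:

    tokens = ["the", "egg", "##plant", "was", "great"]
    labels = ["O", "B-POS", "I-POS", "O", "O"]
    join_wordpieces(tokens, labels)
    # -> [("the", None), ("eggplant", "I-POS"), ("was", None), ("great", None)]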
dictionary.json ADDED
@@ -0,0 +1,899 @@
+ {
+ "afternoon": "other",
+ "alfredo": "food",
+ "alternatives": "other",
+ "amazin": "other",
+ "ambiance": "atmosphere",
+ "ambience": "atmosphere",
+ "anchovy": "food",
+ "and": "other",
+ "apetizers": "food",
+ "appetizer": "food",
+ "appetizers": "food",
+ "apple": "food",
+ "area": "other",
+ "argentine": "food",
+ "array": "other",
+ "artifical": "other",
+ "asian": "food",
+ "asparagus": "food",
+ "assortment": "food",
+ "atmoshere": "atmosphere",
+ "atmosphere": "atmosphere",
+ "attitude": "service",
+ "avocado": "food",
+ "back": "other",
+ "baclava": "food",
+ "baked": "food",
+ "ball": "other",
+ "banana": "food",
+ "bar": "atmosphere",
+ "barley": "food",
+ "bartender": "service",
+ "bartenders": "service",
+ "base": "food",
+ "bathroom": "service",
+ "bbq": "food",
+ "beans": "food",
+ "beef": "food",
+ "beer": "drinks",
+ "beers": "drinks",
+ "beets": "food",
+ "benedict": "food",
+ "bi": "other",
+ "big": "other",
+ "bill": "price",
+ "billed": "price",
+ "bistro": "atmosphere",
+ "black": "food",
+ "blended": "food",
+ "blue": "food",
+ "blueberry": "food",
+ "booth": "atmosphere",
+ "bottle": "drinks",
+ "boutique": "other",
+ "braised": "food",
+ "branzini": "food",
+ "bread": "food",
+ "breads": "food",
+ "breakfast": "food",
+ "brisket": "food",
+ "brulee": "food",
+ "brunch": "food",
+ "buffalo": "food",
+ "burger": "food",
+ "burgers": "food",
+ "burrito": "food",
+ "butter": "food",
+ "by": "other",
+ "caeser": "food",
+ "cajun": "food",
+ "cake": "food",
+ "cakes": "food",
+ "calamari": "food",
+ "calf": "food",
+ "canai": "food",
+ "candlelight": "atmosphere",
+ "carinthia": "food",
+ "carrots": "food",
+ "cart": "other",
+ "casseroles": "food",
+ "casual": "atmosphere",
+ "catfish": "food",
+ "caviar": "food",
+ "chair": "atmosphere",
+ "chairs": "atmosphere",
+ "cheese": "food",
+ "cheeses": "food",
+ "chef": "service",
+ "cherry": "food",
+ "chick": "food",
+ "chicken": "food",
+ "chickens": "food",
+ "chickpea": "food",
+ "chickpeas": "food",
+ "chili": "food",
+ "chillis": "food",
+ "chinese": "food",
+ "chocolate": "food",
+ "choices": "other",
+ "chops": "food",
+ "chorizo": "food",
+ "churrasco": "food",
+ "cinna": "food",
+ "classics": "food",
+ "clientele": "other",
+ "cobb": "food",
+ "cocktail": "drinks",
+ "cocoa": "food",
+ "coconut": "food",
+ "cod": "food",
+ "codfish": "food",
+ "coffee": "drinks",
+ "cold": "food",
+ "concoctions": "drinks",
+ "confitte": "food",
+ "cooked": "food",
+ "cookie": "food",
+ "cookies": "food",
+ "corn": "food",
+ "corner": "other",
+ "cosi": "other",
+ "cost": "price",
+ "counter": "other",
+ "courses": "food",
+ "crab": "food",
+ "crabcakes": "food",
+ "cranberry": "food",
+ "creamy": "food",
+ "creme": "food",
+ "creole": "food",
+ "crepes": "food",
+ "crust": "food",
+ "crusted": "food",
+ "cuccumber": "food",
+ "cuisine": "food",
+ "curried": "food",
+ "curry": "food",
+ "dabs": "food",
+ "dance": "atmosphere",
+ "day": "other",
+ "de": "other",
+ "deco": "atmosphere",
+ "decor": "atmosphere",
+ "decoration": "atmosphere",
+ "delights": "food",
+ "delivary": "service",
+ "delivers": "service",
+ "delivery": "service",
+ "delmonico": "food",
+ "deserts": "food",
+ "design": "atmosphere",
+ "dessert": "food",
+ "desserts": "food",
+ "dill": "food",
+ "dine": "service",
+ "dining": "atmosphere",
+ "dinner": "food",
+ "dinners": "food",
+ "dip": "food",
+ "dipping": "food",
+ "disco": "atmosphere",
+ "dish": "food",
+ "dishes": "food",
+ "dishs": "food",
+ "display": "other",
+ "dog": "food",
+ "dogs": "food",
+ "donut": "food",
+ "downstairs": "other",
+ "dressed": "food",
+ "dressing": "food",
+ "drink": "drinks",
+ "drinks": "drinks",
+ "duck": "food",
+ "dumplings": "food",
+ "duo": "other",
+ "eastern": "food",
+ "eating": "other",
+ "egg": "food",
+ "eggplant": "food",
+ "eggs": "food",
+ "emiliana": "food",
+ "empenadas": "food",
+ "english": "food",
+ "entertainment": "atmosphere",
+ "entree": "food",
+ "entrees": "food",
+ "erbazzone": "food",
+ "escargot": "food",
+ "experience": "atmosphere",
+ "fajita": "food",
+ "falafal": "food",
+ "falafel": "food",
+ "famous": "other",
+ "fare": "food",
+ "female": "other",
+ "fennel": "food",
+ "fettuccine": "food",
+ "fettucino": "food",
+ "filet": "food",
+ "fish": "food",
+ "fixe": "price",
+ "flan": "food",
+ "flavor": "food",
+ "flavored": "food",
+ "flavors": "food",
+ "floor": "atmosphere",
+ "focacchia": "food",
+ "foie": "food",
+ "folding": "other",
+ "food": "food",
+ "foods": "food",
+ "fooood": "food",
+ "for": "other",
+ "fork": "other",
+ "fortune": "other",
+ "french": "food",
+ "fresh": "food",
+ "fried": "food",
+ "fries": "food",
+ "frosty": "food",
+ "fruit": "food",
+ "fusion": "food",
+ "garden": "atmosphere",
+ "garlic": "food",
+ "gelato": "food",
+ "ginger": "food",
+ "glass": "other",
+ "gnocchi": "food",
+ "goat": "food",
+ "gorgonzola": "food",
+ "gosht": "food",
+ "grand": "other",
+ "gras": "food",
+ "gratin": "food",
+ "gratuity": "price",
+ "greek": "food",
+ "green": "food",
+ "greens": "food",
+ "grill": "food",
+ "grilled": "food",
+ "ground": "food",
+ "guacamole": "food",
+ "ham": "food",
+ "hamburger": "food",
+ "happy": "other",
+ "hibiscus": "food",
+ "hint": "other",
+ "homemade": "food",
+ "honey": "food",
+ "hong": "food",
+ "host": "service",
+ "hostess": "service",
+ "hot": "food",
+ "hotdogs": "food",
+ "hour": "other",
+ "humus": "food",
+ "ice": "food",
+ "iced": "drinks",
+ "in": "other",
+ "indian": "food",
+ "ingredients": "food",
+ "interior": "atmosphere",
+ "italian": "food",
+ "items": "other",
+ "jap": "food",
+ "japanese": "food",
+ "jazz": "atmosphere",
+ "jerusalem": "food",
+ "juice": "drinks",
+ "juices": "drinks",
+ "kalmata": "food",
+ "kebabs": "food",
+ "kickers": "food",
+ "kimono": "other",
+ "king": "food",
+ "kitchen": "service",
+ "knots": "food",
+ "kompot": "drinks",
+ "kong": "other",
+ "korean": "food",
+ "lamb": "food",
+ "large": "other",
+ "lasagna": "food",
+ "latkes": "food",
+ "latte": "drinks",
+ "leaves": "food",
+ "lemon": "food",
+ "lemonade": "drinks",
+ "lettuce": "food",
+ "li": "other",
+ "life": "other",
+ "light": "atmosphere",
+ "lime": "food",
+ "linguini": "food",
+ "lobster": "food",
+ "location": "atmosphere",
+ "lomo": "food",
+ "long": "other",
+ "lovely": "atmosphere",
+ "low": "other",
+ "lunch": "food",
+ "lychee": "food",
+ "madison": "food",
+ "main": "food",
+ "make": "other",
+ "maki": "food",
+ "mango": "food",
+ "margherita": "food",
+ "margarita": "drinks",
+ "martini": "drinks",
+ "martinis": "drinks",
+ "masala": "food",
+ "mashed": "food",
+ "massaman": "food",
+ "matzo": "food",
+ "meal": "food",
+ "meat": "food",
+ "meatballs": "food",
+ "mediterranean": "food",
+ "melon": "food",
+ "menu": "food",
+ "meringue": "food",
+ "met": "other",
+ "microbrews": "drinks",
+ "midtown": "other",
+ "milk": "food",
+ "mimosa": "drinks",
+ "minestrone": "food",
+ "mixed": "food",
+ "mojito": "drinks",
+ "monkfish": "food",
+ "more": "other",
+ "mousse": "food",
+ "muffin": "food",
+ "muffins": "food",
+ "mushroom": "food",
+ "mushrooms": "food",
+ "music": "atmosphere",
+ "musical": "atmosphere",
+ "mustard": "food",
+ "nasi": "food",
+ "natural": "atmosphere",
+ "noodles": "food",
+ "north": "other",
+ "nova": "food",
+ "oatmeal": "food",
+ "oil": "food",
+ "olives": "food",
+ "omelette": "food",
+ "onion": "food",
+ "open": "other",
+ "opener": "other",
+ "option": "other",
+ "options": "other",
+ "orange": "food",
+ "organic": "food",
+ "out": "other",
+ "outside": "other",
+ "over": "other",
+ "paella": "food",
+ "pan": "food",
+ "pancake": "food",
+ "pancakes": "food",
+ "parfait": "food",
+ "pasta": "food",
+ "pastries": "food",
+ "patties": "food",
+ "peanut": "food",
+ "pear": "food",
+ "pears": "food",
+ "pecan": "food",
+ "peking": "food",
+ "pepperoni": "food",
+ "persian": "food",
+ "pesto": "food",
+ "phad": "food",
+ "philly": "food",
+ "pho": "food",
+ "pia": "food",
+ "pie": "food",
+ "pierogies": "food",
+ "pierogi": "food",
+ "pies": "food",
+ "pigeon": "food",
+ "pita": "food",
+ "pizza": "food",
+ "place": "atmosphere",
+ "platters": "food",
+ "plate": "food",
+ "plates": "food",
+ "pocket": "food",
+ "pomegranate": "food",
+ "pop": "food",
+ "pops": "food",
+ "popular": "other",
+ "porc": "food",
+ "pork": "food",
+ "pot": "food",
+ "potato": "food",
+ "potatoes": "food",
+ "prawns": "food",
+ "prix": "price",
+ "prosciutto": "food",
+ "prosecco": "drinks",
+ "protein": "food",
+ "pub": "atmosphere",
+ "puff": "food",
+ "puffs": "food",
+ "pumpkin": "food",
+ "quail": "food",
+ "quartino": "drinks",
+ "quick": "service",
+ "quiche": "food",
+ "quinoa": "food",
+ "rack": "food",
+ "radish": "food",
+ "ramp": "food",
+ "ramyeon": "food",
+ "ravioli": "food",
+ "raw": "food",
+ "razor": "food",
+ "red": "food",
+ "refreshing": "atmosphere",
+ "restaurant": "other",
+ "restauraunt": "other",
+ "restaurantthe": "other",
+ "restaurants": "other",
+ "resturant": "other",
+ "roast": "food",
+ "roasted": "food",
+ "roll": "food",
+ "rolls": "food",
+ "romaine": "food",
+ "room": "atmosphere",
+ "root": "food",
+ "rose": "drinks",
+ "rotisserie": "food",
+ "rueben": "food",
+ "rum": "drinks",
+ "rump": "food",
+ "saganaki": "food",
+ "salad": "food",
+ "salads": "food",
+ "salami": "food",
+ "salmon": "food",
+ "sandwich": "food",
+ "sandwiches": "food",
+ "sangria": "drinks",
+ "sauce": "food",
+ "sauces": "food",
+ "sausage": "food",
+ "savory": "food",
+ "scallops": "food",
+ "schnitzel": "food",
+ "seasonal": "food",
+ "seaweed": "food",
+ "selection": "food",
+ "service": "service",
+ "services": "service",
+ "set": "other",
+ "shake": "drinks",
+ "shakes": "drinks",
+ "shakshuka": "food",
+ "shawarma": "food",
+ "shellfish": "food",
+ "sherry": "drinks",
+ "shiitake": "food",
+ "short": "other",
+ "shot": "drinks",
+ "shots": "drinks",
+ "shrimp": "food",
+ "side": "food",
+ "sides": "food",
+ "siu": "food",
+ "sliced": "food",
+ "sliders": "food",
+ "smoked": "food",
+ "smoothie": "drinks",
+ "smoothies": "drinks",
+ "soba": "food",
+ "soft": "other",
+ "soju": "drinks",
+ "soup": "food",
+ "soups": "food",
+ "south": "other",
+ "southern": "food",
+ "soya": "food",
+ "spanish": "food",
+ "sparkling": "drinks",
+ "special": "other",
+ "specials": "other",
+ "spice": "food",
+ "spicy": "food",
+ "spinach": "food",
+ "spoons": "other",
+ "spritz": "drinks",
+ "squash": "food",
+ "squeezed": "other",
+ "sriracha": "food",
+ "st": "other",
+ "stadium": "atmosphere",
+ "steak": "food",
+ "steaks": "food",
+ "stew": "food",
+ "sticks": "food",
+ "stir": "food",
+ "stix": "food",
+ "stone": "other",
+ "strawberry": "food",
+ "strudel": "food",
+ "style": "other",
+ "sugarcane": "food",
+ "sugarfish": "food",
+ "sukiyaki": "food",
+ "sundae": "food",
+ "sundays": "food",
+ "super": "other",
+ "sushi": "food",
+ "swedish": "food",
+ "sweet": "food",
+ "sweetbread": "food",
+ "sweetbreads": "food",
+ "sweets": "food",
+ "swiss": "food",
+ "swordfish": "food",
+ "szechuan": "food",
+ "table": "other",
+ "taco": "food",
+ "tacos": "food",
+ "tahini": "food",
+ "takeout": "service",
+ "tapas": "food",
+ "tart": "food",
+ "tartare": "food",
+ "tartufo": "food",
+ "tea": "drinks",
+ "teapot": "drinks",
+ "teas": "drinks",
+ "tempura": "food",
+ "tenderloin": "food",
+ "teriyaki": "food",
+ "thai": "food",
+ "thali": "food",
+ "the": "other",
+ "thee": "other",
+ "then": "other",
+ "thin": "other",
+ "thursday": "other",
+ "tikka": "food",
+ "to": "other",
+ "toast": "food",
+ "toasts": "food",
+ "tofu": "food",
+ "toffee": "food",
+ "tom": "food",
+ "tomyum": "food",
+ "tongue": "food",
+ "tonkatsu": "food",
+ "tony": "food",
+ "top": "other",
+ "topping": "food",
+ "toppings": "food",
+ "toro": "food",
+ "torte": "food",
+ "tortilla": "food",
+ "tortillas": "food",
+ "tortoise": "food",
+ "truffle": "food",
+ "truffles": "food",
+ "tuna": "food",
+ "turkey": "food",
+ "turkish": "food",
+ "turmeric": "food",
+ "turnip": "food",
+ "tuscan": "food",
+ "twist": "other",
+ "udon": "food",
+ "umami": "food",
+ "unagi": "food",
+ "union": "food",
+ "up": "other",
+ "upbeat": "atmosphere",
+ "upside": "other",
+ "urchin": "food",
+ "us": "other",
+ "uzbek": "food",
+ "vadai": "food",
+ "veal": "food",
+ "vegan": "food",
+ "vegetable": "food",
+ "vegetables": "food",
+ "vegetarian": "food",
+ "venison": "food",
+ "vermicelli": "food",
+ "vermouth": "drinks",
+ "vietnamese": "food",
+ "vindaloo": "food",
+ "vinegar": "food",
+ "vodka": "drinks",
+ "vol": "other",
+ "waffle": "food",
+ "waffles": "food",
+ "wagyu": "food",
+ "warm": "atmosphere",
+ "wasabi": "food",
+ "water": "drinks",
+ "watermelon": "food",
+ "wednesday": "other",
+ "weekend": "other",
+ "weekends": "other",
+ "weight": "other",
+ "wheat": "food",
+ "whiskey": "drinks",
+ "white": "food",
+ "whole": "food",
+ "wine": "drinks",
+ "wines": "drinks",
+ "wing": "food",
+ "wings": "food",
+ "winter": "other",
+ "with": "other",
+ "wok": "food",
+ "wonton": "food",
+ "wrap": "food",
+ "wraps": "food",
+ "xiao": "food",
+ "yakitori": "food",
+ "yam": "food",
+ "yellow": "food",
+ "yogurt": "food",
+ "york": "other",
+ "yorkshire": "food",
+ "yuzu": "food",
+ "zealand": "other",
+ "zucchini": "food",
+ "fontina": "food",
+ "staples": "other",
+ "ceasar": "food",
+ "octopus": "food",
+ "dough": "food",
+ "candle": "atmosphere",
+ "ricotta": "food",
+ "tac": "food",
+ "scoop": "food",
+ "employees": "service",
+ "sea": "food",
+ "tramezzini": "food",
+ "appreciated": "other",
+ "collapse": "other",
+ "negimaki": "food",
+ "napoleon": "food",
+ "beverage": "drinks",
+ "tip": "service",
+ "pleasure": "other",
+ "dhosas": "food",
+ "parmasean": "food",
+ "broiled": "food",
+ "stuff": "food",
+ "earthy": "other",
+ "frites": "food",
+ "hawaiian": "food",
+ "tamales": "food",
+ "cluding": "other",
+ "order": "service",
+ "sardines": "food",
+ "skin": "food",
+ "cigar": "atmosphere",
+ "district": "other",
+ "joint": "other",
+ "pinot": "drinks",
+ "barbecued": "food",
+ "dim": "atmosphere",
+ "polenta": "food",
+ "eateries": "other",
+ "terrine": "food",
+ "slice": "food",
+ "busboy": "service",
+ "scallop": "food",
+ "lobby": "atmosphere",
+ "seviche": "food",
+ "mirrors": "atmosphere",
+ "bakery": "food",
+ "rasberry": "food",
+ "frozen": "food",
+ "serving": "service",
+ "brasserie": "food",
+ "role": "other",
+ "category": "other",
+ "balls": "food",
+ "pepper": "food",
+ "range": "other",
+ "course": "food",
+ "lentil": "food",
+ "beverages": "drinks",
+ "noise": "atmosphere",
+ "app": "other",
+ "nachos": "food",
+ "seasoning": "food",
+ "kobe": "food",
+ "bagels": "food",
+ "varieties": "other",
+ "suace": "food",
+ "appropriately": "other",
+ "appys": "food",
+ "abijah": "other",
+ "exotic": "food",
+ "dj": "atmosphere",
+ "olive": "food",
+ "citrus": "food",
+ "country": "other",
+ "establishment": "other",
+ "decker": "other",
+ "banquet": "other",
+ "chard": "food",
+ "smoky": "other",
+ "sandwhich": "food",
+ "soupy": "food",
+ "walls": "atmosphere",
+ "california": "other",
+ "edamame": "food",
+ "resting": "other",
+ "tomatoes": "food",
+ "cooks": "service",
+ "chai": "drinks",
+ "glasses": "other",
+ "onions": "food",
+ "gold": "other",
+ "tortelli": "food",
+ "bloom": "other",
+ "closed": "other",
+ "bowl": "food",
+ "nostalgia": "other",
+ "scrambled": "food",
+ "time": "other",
+ "triple": "other",
+ "papaya": "food",
+ "busboys": "service",
+ "single": "other",
+ "msg": "food",
+ "concoction": "food",
+ "calves": "food",
+ "captain": "service",
+ "mint": "food",
+ "detail": "other",
+ "champagne": "drinks",
+ "chop": "food",
+ "bacon": "food",
+ "cooking": "other",
+ "dulce": "food",
+ "un": "other",
+ "candles": "atmosphere",
+ "door": "other",
+ "drunken": "food",
+ "guac": "food",
+ "eel": "food",
+ "slicked": "other",
+ "dumpling": "food",
+ "broth": "food",
+ "otoro": "food",
+ "secret": "other",
+ "2": "other",
+ "moules": "food",
+ "cup": "drinks",
+ "full": "other",
+ "yams": "food",
+ "answer": "other",
+ "skewers": "food",
+ "lox": "food",
+ "halibut": "food",
+ "enough": "other",
+ "sheet": "other",
+ "mexican": "food",
+ "fry": "food",
+ "bathrooms": "other",
+ "game": "food",
+ "darling": "other",
+ "jelly": "food",
+ "local": "other",
+ "piano": "atmosphere",
+ "cosmos": "drinks",
+ "filo": "food",
+ "polish": "food",
+ "seating": "atmosphere",
+ "ceviche": "food",
+ "cannoli": "food",
+ "versatile": "other",
+ "brick": "other",
+ "decore": "other",
+ "oven": "other",
+ "toe": "other",
+ "deluxe": "other",
+ "clientelle": "other",
+ "terrace": "atmosphere",
+ "salsa": "food",
+ "hummus": "food",
+ "attitudes": "service",
+ "color": "other",
+ "leche": "food",
+ "beats": "other",
+ "furnishings": "atmosphere",
+ "spread": "food",
+ "peppers": "food",
+ "coat": "other",
+ "whisper": "atmosphere",
+ "chole": "food",
+ "presentation": "other",
+ "deep": "other",
+ "desert": "food",
+ "cream": "food",
+ "buffet": "food",
+ "frisee": "food",
+ "speck": "food",
+ "diners": "service",
+ "individual": "other",
+ "front": "other",
+ "environment": "atmosphere",
+ "beet": "food",
+ "spring": "other",
+ "marina": "other",
+ "marinated": "food",
+ "tabs": "other",
+ "sardinian": "food",
+ "check": "other",
+ "squid": "food",
+ "bass": "food",
+ "clams": "food",
+ "beginning": "other",
+ "sinatra": "other",
+ "diner": "food",
+ "tic": "other",
+ "backyard": "other",
+ "tomato": "food",
+ "steamed": "food",
+ "per": "other",
+ "breast": "food",
+ "chips": "food",
+ "brown": "other",
+ "sommelier": "service",
+ "servings": "food",
+ "pineapple": "food",
+ "shirted": "other",
+ "oysters": "food",
+ "salt": "food",
+ "fragrant": "other",
+ "dhal": "food",
+ "pleasures": "other",
+ "seat": "service",
+ "appetites": "food",
+ "stained": "other",
+ "samosas": "food",
+ "ceiling": "atmosphere",
+ "escabeche": "food",
+ "crowds": "atmosphere",
+ "club": "other",
+ "bruschetta": "food",
+ "family": "other",
+ "poached": "food",
+ "crew": "service",
+ "temperature": "other",
+ "influence": "other",
+ "plantains": "food",
+ "suggestion": "other",
+ "pico": "food",
+ "bagel": "food",
+ "melt": "food",
+ "bountiful": "other",
+ "drop": "other",
+ "maitre": "service",
+ "artworks": "atmosphere",
+ "sicilian": "food",
+ "alternative": "other",
+ "spot": "other",
+ "kaiseki": "food",
+ "pompous": "other",
+ "comfort": "other",
+ "american": "food",
+ "tap": "drinks",
+ "ribbon": "other",
+ "guy": "other",
+ "customer": "service",
+ "kumquat": "food",
+ "mix": "other",
+ "brioche": "food",
+ "souffle": "food",
+ "knife": "other",
+ "soda": "drinks",
+ "nelson": "other",
+ "faced": "other",
+ "sum": "other",
+ "crowd": "atmosphere",
+ "summer": "other",
+ "holiday": "other",
+ "freaking": "other",
+ "waiter": "service",
+ "waitress": "service",
+ "manager": "service",
+ "servers": "service"
+ }
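
This file is the word-to-aspect lookup consumed by classify_word/get_category in app.py. A quick sketch of how it is used (the two lookups below are entries from the mapping above):

    import json

    with open("dictionary.json") as f:
        aspect_dict = json.load(f)

    aspect_dict["tea"]     # "drinks"
    aspect_dict["waiter"]  # "service"
    # Words missing from the dictionary fall back to the WordNet hypernym check in
    # classify_word, and anything still unresolved is bucketed as "other".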
models/BERT_LSTM_CRF.py ADDED
@@ -0,0 +1,78 @@
+ #from transformers import BertPreTrainedModel, BertForSequenceClassification, BertModel
+ from transformers import AutoModel, PreTrainedModel
+ from transformers.modeling_outputs import TokenClassifierOutput
+ from torch import nn
+ from torch.nn import CrossEntropyLoss
+ import torch
+ from .layers import CRF
+ from itertools import islice
+
+ NUM_PER_LAYER = 16
+
+ class BERTLstmCRF(PreTrainedModel):
+     _keys_to_ignore_on_load_unexpected = [r"pooler"]
+
+     def __init__(self, config):
+         super().__init__(config)
+         print(config)
+         self.num_labels = config.num_labels
+         self.bert = AutoModel.from_pretrained(config._name_or_path, config=config, add_pooling_layer=False)
+         classifier_dropout = (config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob)
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.bilstm = nn.LSTM(config.hidden_size, (config.hidden_size) // 2, batch_first=True, bidirectional=True)
+         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+         self.crf = CRF(num_tags=config.num_labels, batch_first=True)
+
+         if self.config.freeze == True:
+             self.manage_freezing()
+
+         #self.bert.init_weights() # load pretrained weights
+
+     def manage_freezing(self):
+         for _, param in self.bert.embeddings.named_parameters():
+             param.requires_grad = False
+
+         num_encoders_to_freeze = self.config.num_frozen_encoder
+         if num_encoders_to_freeze > 0:
+             for _, param in islice(self.bert.encoder.named_parameters(), num_encoders_to_freeze*NUM_PER_LAYER):
+                 param.requires_grad = False
+
+
+     def forward(self,
+                 input_ids=None,
+                 attention_mask=None,
+                 token_type_ids=None,
+                 position_ids=None,
+                 head_mask=None,
+                 inputs_embeds=None,
+                 labels=None,
+                 output_attentions=None,
+                 output_hidden_states=None,
+                 return_dict=None
+                 ):
+         # Default `model.config.use_return_dict` is `True`
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.bert(input_ids,
+                             attention_mask=attention_mask,
+                             token_type_ids=token_type_ids,
+                             position_ids=position_ids,
+                             head_mask=head_mask,
+                             inputs_embeds=inputs_embeds,
+                             output_attentions=output_attentions,
+                             output_hidden_states=output_hidden_states,
+                             return_dict=return_dict)
+
+         sequence_output = outputs[0]
+         sequence_output = self.dropout(sequence_output)
+         lstm_output, hc = self.bilstm(sequence_output)
+         logits = self.classifier(lstm_output)
+
+         loss = None
+         if labels is not None:
+             # Only computed during training/evaluation; labels are not passed at inference
+             loss = -1 * self.crf(logits, labels)
+
+         tags = torch.Tensor(self.crf.decode(logits))
+
+         return loss, tags
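
A note on NUM_PER_LAYER = 16 in manage_freezing: each BERT encoder layer exposes 16 parameter tensors (query/key/value weights and biases, the attention-output dense layer and its LayerNorm, the intermediate dense layer, and the output dense layer and its LayerNorm), so slicing the first n * 16 entries of encoder.named_parameters() freezes the first n layers. A quick check, assuming bert-base-uncased is available locally:

    from transformers import AutoModel

    bert = AutoModel.from_pretrained("bert-base-uncased")
    len(list(bert.encoder.layer[0].named_parameters()))  # -> 16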
models/__init__.py ADDED
@@ -0,0 +1 @@
+ from .BERT_LSTM_CRF import BERTLstmCRF
models/layers/CRF.py ADDED
@@ -0,0 +1,353 @@
+ # Taken from https://github.com/kmkurn/pytorch-crf/blob/master/torchcrf/__init__.py and fixed the uint8 warning
+ __version__ = '0.7.2'
+
+ from typing import List, Optional
+
+ import torch
+ import torch.nn as nn
+
+ LARGE_NEGATIVE_NUMBER = -1e9
+
+ class CRF(nn.Module):
+     """Conditional random field.
+     This module implements a conditional random field [LMP01]_. The forward computation
+     of this class computes the log likelihood of the given sequence of tags and
+     emission score tensor. This class also has `~CRF.decode` method which finds
+     the best tag sequence given an emission score tensor using `Viterbi algorithm`_.
+     Args:
+         num_tags: Number of tags.
+         batch_first: Whether the first dimension corresponds to the size of a minibatch.
+     Attributes:
+         start_transitions (`~torch.nn.Parameter`): Start transition score tensor of size
+             ``(num_tags,)``.
+         end_transitions (`~torch.nn.Parameter`): End transition score tensor of size
+             ``(num_tags,)``.
+         transitions (`~torch.nn.Parameter`): Transition score tensor of size
+             ``(num_tags, num_tags)``.
+     .. [LMP01] Lafferty, J., McCallum, A., Pereira, F. (2001).
+        "Conditional random fields: Probabilistic models for segmenting and
+        labeling sequence data". *Proc. 18th International Conf. on Machine
+        Learning*. Morgan Kaufmann. pp. 282–289.
+     .. _Viterbi algorithm: https://en.wikipedia.org/wiki/Viterbi_algorithm
+     """
+
+     def __init__(self, num_tags: int, batch_first: bool = False) -> None:
+         if num_tags <= 0:
+             raise ValueError(f'invalid number of tags: {num_tags}')
+         super().__init__()
+         self.num_tags = num_tags
+         self.batch_first = batch_first
+         self.start_transitions = nn.Parameter(torch.empty(num_tags))
+         self.end_transitions = nn.Parameter(torch.empty(num_tags))
+         self.transitions = nn.Parameter(torch.empty(num_tags, num_tags))
+
+         self.reset_parameters()
+         self.mask_impossible_transitions()
+
+     def reset_parameters(self) -> None:
+         """Initialize the transition parameters.
+         The parameters will be initialized randomly from a uniform distribution
+         between -0.1 and 0.1.
+         """
+         nn.init.uniform_(self.start_transitions, -0.1, 0.1)
+         nn.init.uniform_(self.end_transitions, -0.1, 0.1)
+         nn.init.uniform_(self.transitions, -0.1, 0.1)
+
+     def mask_impossible_transitions(self) -> None:
+         """Set the value of impossible transitions to LARGE_NEGATIVE_NUMBER
+         - start transition value of I-X
+         - transition score of O -> I
+         """
+         with torch.no_grad():
+             self.start_transitions[2] = LARGE_NEGATIVE_NUMBER
+             self.start_transitions[4] = LARGE_NEGATIVE_NUMBER
+             self.start_transitions[6] = LARGE_NEGATIVE_NUMBER
+
+             self.transitions[0][2] = LARGE_NEGATIVE_NUMBER
+             self.transitions[0][4] = LARGE_NEGATIVE_NUMBER
+             self.transitions[0][6] = LARGE_NEGATIVE_NUMBER
+             self.transitions[1][4] = LARGE_NEGATIVE_NUMBER
+             self.transitions[1][6] = LARGE_NEGATIVE_NUMBER
+             self.transitions[2][4] = LARGE_NEGATIVE_NUMBER
+             self.transitions[2][6] = LARGE_NEGATIVE_NUMBER
+             self.transitions[3][2] = LARGE_NEGATIVE_NUMBER
+             self.transitions[3][6] = LARGE_NEGATIVE_NUMBER
+             self.transitions[4][2] = LARGE_NEGATIVE_NUMBER
+             self.transitions[4][6] = LARGE_NEGATIVE_NUMBER
+             self.transitions[5][2] = LARGE_NEGATIVE_NUMBER
+             self.transitions[5][4] = LARGE_NEGATIVE_NUMBER
+             self.transitions[6][2] = LARGE_NEGATIVE_NUMBER
+             self.transitions[6][4] = LARGE_NEGATIVE_NUMBER
+
+     def __repr__(self) -> str:
+         return f'{self.__class__.__name__}(num_tags={self.num_tags})'
+
+     def forward(
+             self,
+             emissions: torch.Tensor,
+             tags: torch.LongTensor,
+             mask: Optional[torch.ByteTensor] = None,
+             reduction: str = 'sum',
+     ) -> torch.Tensor:
+         """Compute the conditional log likelihood of a sequence of tags given emission scores.
+         Args:
+             emissions (`~torch.Tensor`): Emission score tensor of size
+                 ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
+                 ``(batch_size, seq_length, num_tags)`` otherwise.
+             tags (`~torch.LongTensor`): Sequence of tags tensor of size
+                 ``(seq_length, batch_size)`` if ``batch_first`` is ``False``,
+                 ``(batch_size, seq_length)`` otherwise.
+             mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
+                 if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
+             reduction: Specifies the reduction to apply to the output:
+                 ``none|sum|mean|token_mean``. ``none``: no reduction will be applied.
+                 ``sum``: the output will be summed over batches. ``mean``: the output will be
+                 averaged over batches. ``token_mean``: the output will be averaged over tokens.
+         Returns:
+             `~torch.Tensor`: The log likelihood. This will have size ``(batch_size,)`` if
+             reduction is ``none``, ``()`` otherwise.
+         """
+         #self.mask_impossible_transitions()
+         self._validate(emissions, tags=tags, mask=mask)
+         if reduction not in ('none', 'sum', 'mean', 'token_mean'):
+             raise ValueError(f'invalid reduction: {reduction}')
+         if mask is None:
+             mask = torch.ones_like(tags, dtype=torch.uint8)
+
+         if self.batch_first:
+             emissions = emissions.transpose(0, 1)
+             tags = tags.transpose(0, 1)
+             mask = mask.transpose(0, 1)
+
+         # shape: (batch_size,)
+         numerator = self._compute_score(emissions, tags, mask)
+         # shape: (batch_size,)
+         denominator = self._compute_normalizer(emissions, mask)
+         # shape: (batch_size,)
+         llh = numerator - denominator
+
+         if reduction == 'none':
+             return llh
+         if reduction == 'sum':
+             return llh.sum()
+         if reduction == 'mean':
+             return llh.mean()
+         assert reduction == 'token_mean'
+         return llh.sum() / mask.type_as(emissions).sum()
+
+     def decode(self, emissions: torch.Tensor,
+                mask: Optional[torch.ByteTensor] = None) -> List[List[int]]:
+         """Find the most likely tag sequence using Viterbi algorithm.
+         Args:
+             emissions (`~torch.Tensor`): Emission score tensor of size
+                 ``(seq_length, batch_size, num_tags)`` if ``batch_first`` is ``False``,
+                 ``(batch_size, seq_length, num_tags)`` otherwise.
+             mask (`~torch.ByteTensor`): Mask tensor of size ``(seq_length, batch_size)``
+                 if ``batch_first`` is ``False``, ``(batch_size, seq_length)`` otherwise.
+         Returns:
+             List of list containing the best tag sequence for each batch.
+         """
+         self._validate(emissions, mask=mask)
+         if mask is None:
+             mask = emissions.new_ones(emissions.shape[:2], dtype=torch.uint8)
+
+         if self.batch_first:
+             emissions = emissions.transpose(0, 1)
+             mask = mask.transpose(0, 1)
+
+         return self._viterbi_decode(emissions, mask)
+
+     def _validate(
+             self,
+             emissions: torch.Tensor,
+             tags: Optional[torch.LongTensor] = None,
+             mask: Optional[torch.ByteTensor] = None) -> None:
+         if emissions.dim() != 3:
+             raise ValueError(f'emissions must have dimension of 3, got {emissions.dim()}')
+         if emissions.size(2) != self.num_tags:
+             raise ValueError(
+                 f'expected last dimension of emissions is {self.num_tags}, '
+                 f'got {emissions.size(2)}')
+
+         if tags is not None:
+             if emissions.shape[:2] != tags.shape:
+                 raise ValueError(
+                     'the first two dimensions of emissions and tags must match, '
+                     f'got {tuple(emissions.shape[:2])} and {tuple(tags.shape)}')
+
+         if mask is not None:
+             if emissions.shape[:2] != mask.shape:
+                 raise ValueError(
+                     'the first two dimensions of emissions and mask must match, '
+                     f'got {tuple(emissions.shape[:2])} and {tuple(mask.shape)}')
+             no_empty_seq = not self.batch_first and mask[0].all()
+             no_empty_seq_bf = self.batch_first and mask[:, 0].all()
+             if not no_empty_seq and not no_empty_seq_bf:
+                 raise ValueError('mask of the first timestep must all be on')
+
+     def _compute_score(
+             self, emissions: torch.Tensor, tags: torch.LongTensor,
+             mask: torch.ByteTensor) -> torch.Tensor:
+         # emissions: (seq_length, batch_size, num_tags)
+         # tags: (seq_length, batch_size)
+         # mask: (seq_length, batch_size)
+         assert emissions.dim() == 3 and tags.dim() == 2
+         assert emissions.shape[:2] == tags.shape
+         assert emissions.size(2) == self.num_tags
+         assert mask.shape == tags.shape
+         assert mask[0].all()
+
+         seq_length, batch_size = tags.shape
+         mask = mask.type_as(emissions)
+
+         # Start transition score and first emission
+         # shape: (batch_size,)
+         score = self.start_transitions[tags[0]]
+         score += emissions[0, torch.arange(batch_size), tags[0]]
+
+         for i in range(1, seq_length):
+             # Transition score to next tag, only added if next timestep is valid (mask == 1)
+             # shape: (batch_size,)
+             score += self.transitions[tags[i - 1], tags[i]] * mask[i]
+
+             # Emission score for next tag, only added if next timestep is valid (mask == 1)
+             # shape: (batch_size,)
+             score += emissions[i, torch.arange(batch_size), tags[i]] * mask[i]
+
+         # End transition score
+         # shape: (batch_size,)
+         seq_ends = mask.long().sum(dim=0) - 1
+         # shape: (batch_size,)
+         last_tags = tags[seq_ends, torch.arange(batch_size)]
+         # shape: (batch_size,)
+         score += self.end_transitions[last_tags]
+
+         return score
+
+     def _compute_normalizer(
+             self, emissions: torch.Tensor, mask: torch.ByteTensor) -> torch.Tensor:
+         # emissions: (seq_length, batch_size, num_tags)
+         # mask: (seq_length, batch_size)
+         assert emissions.dim() == 3 and mask.dim() == 2
+         assert emissions.shape[:2] == mask.shape
+         assert emissions.size(2) == self.num_tags
+         assert mask[0].all()
+
+         seq_length = emissions.size(0)
+
+         # Start transition score and first emission; score has size of
+         # (batch_size, num_tags) where for each batch, the j-th column stores
+         # the score that the first timestep has tag j
+         # shape: (batch_size, num_tags)
+         score = self.start_transitions + emissions[0]
+
+         for i in range(1, seq_length):
+             # Broadcast score for every possible next tag
+             # shape: (batch_size, num_tags, 1)
+             broadcast_score = score.unsqueeze(2)
+
+             # Broadcast emission score for every possible current tag
+             # shape: (batch_size, 1, num_tags)
+             broadcast_emissions = emissions[i].unsqueeze(1)
+
+             # Compute the score tensor of size (batch_size, num_tags, num_tags) where
+             # for each sample, entry at row i and column j stores the sum of scores of all
+             # possible tag sequences so far that end with transitioning from tag i to tag j
+             # and emitting
+             # shape: (batch_size, num_tags, num_tags)
+             next_score = broadcast_score + self.transitions + broadcast_emissions
+
+             # Sum over all possible current tags, but we're in score space, so a sum
+             # becomes a log-sum-exp: for each sample, entry i stores the sum of scores of
+             # all possible tag sequences so far, that end in tag i
+             # shape: (batch_size, num_tags)
+             next_score = torch.logsumexp(next_score, dim=1)
+
+             # Set score to the next score if this timestep is valid (mask == 1)
+             # shape: (batch_size, num_tags)
+             score = torch.where(mask[i].unsqueeze(1).bool(), next_score, score)
+
+         # End transition score
+         # shape: (batch_size, num_tags)
+         score += self.end_transitions
+
+         # Sum (log-sum-exp) over all possible tags
+         # shape: (batch_size,)
+         return torch.logsumexp(score, dim=1)
+
+     def _viterbi_decode(self, emissions: torch.FloatTensor,
+                         mask: torch.ByteTensor) -> List[List[int]]:
+         # emissions: (seq_length, batch_size, num_tags)
+         # mask: (seq_length, batch_size)
+         assert emissions.dim() == 3 and mask.dim() == 2
+         assert emissions.shape[:2] == mask.shape
+         assert emissions.size(2) == self.num_tags
+         assert mask[0].all()
+
+         seq_length, batch_size = mask.shape
+
+         # Start transition and first emission
+         # shape: (batch_size, num_tags)
+         score = self.start_transitions + emissions[0]
+         history = []
+
+         # score is a tensor of size (batch_size, num_tags) where for every batch,
+         # value at column j stores the score of the best tag sequence so far that ends
+         # with tag j
+         # history saves where the best tags candidate transitioned from; this is used
+         # when we trace back the best tag sequence
+
+         # Viterbi algorithm recursive case: we compute the score of the best tag sequence
+         # for every possible next tag
+         for i in range(1, seq_length):
+             # Broadcast viterbi score for every possible next tag
+             # shape: (batch_size, num_tags, 1)
+             broadcast_score = score.unsqueeze(2)
+
+             # Broadcast emission score for every possible current tag
+             # shape: (batch_size, 1, num_tags)
+             broadcast_emission = emissions[i].unsqueeze(1)
+
+             # Compute the score tensor of size (batch_size, num_tags, num_tags) where
+             # for each sample, entry at row i and column j stores the score of the best
+             # tag sequence so far that ends with transitioning from tag i to tag j and emitting
+             # shape: (batch_size, num_tags, num_tags)
+             next_score = broadcast_score + self.transitions + broadcast_emission
+
+             # Find the maximum score over all possible current tag
+             # shape: (batch_size, num_tags)
+             next_score, indices = next_score.max(dim=1)
+
+             # Set score to the next score if this timestep is valid (mask == 1)
+             # and save the index that produces the next score
+             # shape: (batch_size, num_tags)
+             score = torch.where(mask[i].unsqueeze(1).bool(), next_score, score)
+             history.append(indices)
+
+         # End transition score
+         # shape: (batch_size, num_tags)
+         score += self.end_transitions
+
+         # Now, compute the best path for each sample
+
+         # shape: (batch_size,)
+         seq_ends = mask.long().sum(dim=0) - 1
+         best_tags_list = []
+
+         for idx in range(batch_size):
+             # Find the tag which maximizes the score at the last timestep; this is our best tag
+             # for the last timestep
+             _, best_last_tag = score[idx].max(dim=0)
+             best_tags = [best_last_tag.item()]
+
+             # We trace back where the best last tag comes from, append that to our best tag
+             # sequence, and trace it back again, and so on
+             for hist in reversed(history[:seq_ends[idx]]):
+                 best_last_tag = hist[idx][best_tags[-1]]
+                 best_tags.append(best_last_tag.item())
+
+             # Reverse the order because we start from the last timestep
+             best_tags.reverse()
+             best_tags_list.append(best_tags)
+
+         return best_tags_list
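
A standalone sanity check of this layer, with random scores and seven tags as in this repo's O/B-POS/I-POS/B-NEG/I-NEG/B-NEU/I-NEU scheme (the numbers are illustrative only): forward returns the summed log likelihood, which BERTLstmCRF negates as its loss, and decode runs Viterbi to get the best tag-id sequence per sample.

    import torch
    from models.layers import CRF

    crf = CRF(num_tags=7, batch_first=True)
    emissions = torch.randn(2, 5, 7)     # (batch, seq_len, num_tags)
    tags = torch.randint(0, 7, (2, 5))   # gold tag ids

    loss = -crf(emissions, tags)         # negative log likelihood, as in BERTLstmCRF.forward
    best_paths = crf.decode(emissions)   # list of 2 best tag-id sequences, each of length 5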
models/layers/__init__.py ADDED
@@ -0,0 +1 @@
+ from .CRF import CRF
models/layers/__pycache__/CRF.cpython-310.pyc ADDED
Binary file (9.37 kB).
models/layers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (175 Bytes).