import json

import gradio as gr
import nltk
import torch
from huggingface_hub import hf_hub_download
from nltk.corpus import wordnet
from transformers import AutoConfig, AutoTokenizer

from models import BERTLstmCRF

# Only the WordNet corpus is used below; downloading it directly is much
# faster than shelling out to fetch the entire NLTK data collection.
nltk.download("wordnet")

checkpoint = "gundruke/bert-lstm-crf-absa"
config = AutoConfig.from_pretrained(checkpoint)
id2label = config.id2label

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = BERTLstmCRF(config)

# Load the fine-tuned weights from the same Hub repo and keep everything on CPU.
model.load_state_dict(
    torch.load(
        hf_hub_download(repo_id=checkpoint, filename="pytorch_model.bin"),
        map_location=torch.device("cpu"),
    )
)
model.eval()  # inference only; disables dropout

# Word -> aspect-category lookup; read the JSON once at startup instead of
# re-opening the file for every token.
dictionary_file_path = hf_hub_download(repo_id=checkpoint, filename="dictionary.json")
with open(dictionary_file_path, "r") as file:
    category_dictionary = json.load(file)
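
# The dictionary is assumed to map surface words directly to one of the five
# aspect categories, e.g. {"waiter": "service", "pasta": "food"}; these
# entries are illustrative, not taken from the actual file.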

def tokenize_text(text):
    # Return both the raw WordPiece tokens (for display) and the encoded
    # inputs (for the model).
    tokens = tokenizer.tokenize(text)
    tokenized_text = tokenizer(text)

    return tokens, tokenized_text


def convert_to_multilabel(label_list):
    # Collapse the token-level BIO tags into sentence-level sentiment labels.
    multilabel = []
    if "B-POS" in label_list or "I-POS" in label_list:
        multilabel.append("Positive")
    if "B-NEG" in label_list or "I-NEG" in label_list:
        multilabel.append("Negative")
    if "B-NEU" in label_list or "I-NEU" in label_list:
        multilabel.append("Neutral")

    return " and ".join(multilabel)
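
# A quick illustration of the collapsing logic (hypothetical tag list):
#   convert_to_multilabel(["B-POS", "I-POS", "B-NEG", "O"]) -> "Positive and Negative"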


def classify_word(word, dictionary):
    synsets = wordnet.synsets(word)
    if synsets:
        hypernyms = synsets[0].hypernyms()  # Get the hypernym of the first synset
        if hypernyms:
            nltk_result = hypernyms[0].lemmas()[0].name()
        else:
            nltk_result = "Unknown"
    else:
        nltk_result = "Unknown"
    
    if word in dictionary:
        result = dictionary[word]
    elif nltk_result in ['atmosphere', 'drinks', 'food', 'price', 'service']:
        result = nltk_result
    else:
        result = 'other'

    return result, nltk_result
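
# Illustration of the lookup order, assuming "pizza" has no dictionary entry:
# WordNet's first synset for "pizza" has the hypernym "dish", which is not one
# of the five known categories, so classify_word("pizza", {}) would return
# ("other", "dish").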


def get_outputs(tokenized_text):
    input_ids = tokenized_text["input_ids"]
    token_type_ids = tokenized_text["token_type_ids"]
    attention_mask = tokenized_text["attention_mask"]

    # Batch of one: add a leading batch dimension to every field.
    inputs = {
        'input_ids': torch.tensor([input_ids]),
        'token_type_ids': torch.tensor([token_type_ids]),
        'attention_mask': torch.tensor([attention_mask])
    }

    with torch.no_grad():
        outputs = model(**inputs)

    # outputs[1] holds the decoded tag ids; dropping the first and last
    # entries strips the [CLS] and [SEP] positions so the labels line up
    # with the visible tokens.
    labels = [id2label.get(i) for i in torch.flatten(outputs[1]).tolist()][1:-1]

    return labels
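
# The returned list carries one BIO tag per WordPiece, in the same order as
# tokenizer.tokenize(text), e.g. ["O", "B-POS", "I-POS", "O"] for a
# four-piece input (hypothetical tags).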


def join_wordpieces(tokens, labels):
    # Merge WordPiece continuations ("##...") back into whole words; a merged
    # word keeps the label of its last piece, and "O" tags become None so the
    # HighlightedText widget leaves those words unhighlighted.
    joined_tokens = []

    for token, label in zip(tokens, labels):
        if label == "O":
            label = None
        if token.startswith("##"):
            last_token = joined_tokens[-1][0]
            joined_tokens[-1] = (last_token + token[2:], label)
        else:
            joined_tokens.append((token, label))

    return joined_tokens
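
# Example of the merge (hypothetical tags):
#   join_wordpieces(["egg", "##plant", "was", "good"],
#                   ["B-POS", "I-POS", "O", "O"])
#   -> [("eggplant", "I-POS"), ("was", None), ("good", None)]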


def get_category(word):
    # Resolve a labelled word to its aspect category via the preloaded
    # dictionary, with classify_word's WordNet hypernym lookup as fallback.
    result, _ = classify_word(word, category_dictionary)

    return result


def text_analysis(text):
    tokens, tokenized_text = tokenize_text(text)
    labels = get_outputs(tokenized_text)
    multilabel = convert_to_multilabel(labels)

    token_tuple = join_wordpieces(tokens, labels)

    # Attach an aspect category to every labelled token; unlabelled tokens
    # keep None so they stay unhighlighted.
    categories = []
    for tok in token_tuple:
        if tok[1]:
            categories.append((tok[0], get_category(tok[0])))
        else:
            categories.append((tok[0], None))

    return token_tuple, multilabel, categories


theme = gr.themes.Base()
with gr.Blocks(theme=theme) as demo:
    with gr.Column():
        input_textbox = gr.Textbox(placeholder="Enter sentence here...")
        btn = gr.Button("Submit", variant="primary")

        btn.click(fn=text_analysis,
                  inputs=input_textbox,
                  outputs=[gr.HighlightedText(label="Token labels"),
                           gr.Label(label="Multilabel classification"),
                           gr.HighlightedText(label="Category")],
                  queue=False)
    
    with gr.Column():
        examples = [
            ["I've been coming here as a child and always come back for the taste."],
            ["The tea is great and all the sweets are homemade."],
            ["Strong build which really adds to its durability but poor battery life."],
            ["We loved the recommendation for the wine, and I think the eggplant parmigiana appetizer should become an entree."],
            ["chicken pasta was tasty, wine was super nice but waiter was rude."]
        ]
        gr.Examples(examples, input_textbox)

demo.launch(debug=True)