"""Text classification with a fine-tuned PhoBERT model.

Loads a tokenizer and a fully pickled ``PhoBertModel`` from ``./bert`` at
import time and exposes ``BERT_predict`` for scoring a single text.
"""
import torch
from transformers import AutoModel, AutoTokenizer
# NOTE(review): imported but unused. PhoBERT is normally fed word-segmented
# text; confirm whether training segmented input with underthesea.
from underthesea import word_tokenize
import __main__

# Paths are resolved against the current working directory, not this file.
TOKENIZER_PATH = "./bert/bert_tokenizer"
MODEL_PATH = "./bert/phoBertModel.pth"
PRETRAINED_NAME = "vinai/phobert-base"
NUM_LABELS = 6

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_PATH)


class PhoBertModel(torch.nn.Module):
    """PhoBERT encoder followed by a Linear-Tanh-Dropout-Linear-Sigmoid head.

    Args:
        bert: encoder to wrap. When ``None`` (the default) it is fetched with
            ``AutoModel.from_pretrained(PRETRAINED_NAME)``. Instances restored
            by ``torch.load`` do not run ``__init__``, so this download only
            happens when a fresh model is constructed.
    """

    def __init__(self, bert=None):
        super().__init__()
        if bert is None:
            bert = AutoModel.from_pretrained(PRETRAINED_NAME)
        self.bert = bert
        hidden_size = self.bert.config.hidden_size
        self.pre_classifier = torch.nn.Linear(hidden_size, hidden_size)
        self.dropout = torch.nn.Dropout(0.1)
        self.classifier = torch.nn.Linear(hidden_size, NUM_LABELS)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return one sigmoid score in (0, 1) per label for each input row.

        ``token_type_ids`` is accepted to keep the call signature stable but
        is not passed to the encoder. Assumes ``input_ids`` is
        (batch, seq_len), giving an output of (batch, NUM_LABELS).
        TODO confirm against training code.
        """
        # With return_dict=False the encoder returns a tuple; the second
        # element is the pooled output that feeds the classification head.
        _, pooled = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False,
        )
        hidden = torch.tanh(self.pre_classifier(pooled))
        logits = self.classifier(self.dropout(hidden))
        return torch.sigmoid(logits)


# The checkpoint stores a pickled PhoBertModel. Its class reference presumably
# points at ``__main__`` (it was saved from a script), so expose the class
# there for unpickling to succeed when this module is imported.
setattr(__main__, "PhoBertModel", PhoBertModel)


def getModel():
    """Load the pickled PhoBertModel from MODEL_PATH onto the CPU in eval mode.

    SECURITY: ``torch.load`` on a full module uses pickle and can execute
    arbitrary code. Only load checkpoints from a trusted source.
    """
    import inspect

    load_kwargs = {"map_location": torch.device("cpu")}
    # torch >= 2.6 defaults weights_only=True, which cannot rebuild a pickled
    # nn.Module. Opt out explicitly where the keyword exists (torch >= 1.13).
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    loaded = torch.load(MODEL_PATH, **load_kwargs)
    loaded.eval()
    return loaded


model = getModel()


def tokenize(data, max_length=200):
    """Encode each string in ``data`` into fixed-length tensors.

    Args:
        data: iterable of strings. Each one becomes one row.
        max_length: every row is truncated or padded to exactly this length.

    Returns:
        dict with keys ``'ids'``, ``'mask'`` and ``'token_type_ids'``. Each
        value is an integer tensor of shape (len(data), max_length).

    Raises:
        ValueError: if ``data`` is empty.
    """
    encoded = [
        tokenizer.encode_plus(
            line,
            max_length=max_length,
            add_special_tokens=False,
            padding="max_length",
            truncation=True,
        )
        for line in data
    ]
    if not encoded:
        raise ValueError("tokenize() needs at least one input string")
    return {
        "ids": torch.tensor([e["input_ids"] for e in encoded]),
        "mask": torch.tensor([e["attention_mask"] for e in encoded]),
        "token_type_ids": torch.tensor([e["token_type_ids"] for e in encoded]),
    }


def BERT_predict(text):
    """Score one text and return its per-label sigmoid scores as floats."""
    encoded = tokenize([text])
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        scores = model(encoded["ids"], encoded["mask"], encoded["token_type_ids"])
    return scores.tolist()[0]


if __name__ == "__main__":
    for sample in (
        "xin chaof",
        "con chó",
        "đồ chó",
        "đồ ngu",
        "cái lồn",
        "óc chó",
        "đồ chó đẻ",
        "con đĩ",
    ):
        print(BERT_predict(sample))