ChineseBERT-for-csc / csc_model.py
iioSnail's picture
Upload 14 files
ae509ea
raw history blame
No virus
23.1 kB
import json
import os
import shutil
import time
from pathlib import Path
from typing import List
import numpy as np
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.file_download import http_user_agent
from torch import nn
from torch.nn import functional as F
from transformers import BertPreTrainedModel, BertModel
from transformers.modeling_outputs import MaskedLMOutput, BaseModelOutputWithPooling
from transformers.models.bert.modeling_bert import BertEncoder, BertPooler, BertLMPredictionHead
# Directory that contains this module; resource files are cached beneath it.
cache_path = Path(os.path.abspath(__file__)).parent


def download_file(filename: str, path: Path):
    """Make sure ``cache_path / filename`` exists, fetching it if necessary.

    Lookup order:

    1. If the file is already present under ``cache_path``, do nothing.
    2. Otherwise, if ``path / filename`` exists, copy it into ``cache_path``.
    3. Otherwise download it from the ``iioSnail/ChineseBERT-for-csc`` hub
       repository into ``cache_path``.

    Args:
        filename: File name relative to the repository root; may contain
            sub-directories (e.g. ``"config/pinyin_map.json"``).
        path: Local directory that may already hold a copy of the file.
    """
    cached_file = cache_path / filename
    if os.path.exists(cached_file):
        return

    local_file = Path(path) / filename
    if os.path.exists(local_file):
        # ``filename`` may be nested; copyfile does not create missing parent directories.
        cached_file.parent.mkdir(parents=True, exist_ok=True)
        shutil.copyfile(local_file, cached_file)
        return

    hf_hub_download(
        "iioSnail/ChineseBERT-for-csc",
        filename,
        local_dir=cache_path,
        user_agent=http_user_agent(),
    )
    # NOTE(review): short pause kept from the original implementation; its purpose is not
    # documented (possibly to throttle consecutive downloads) - confirm before removing.
    time.sleep(0.2)
class ChineseBertForCSC(BertPreTrainedModel):
    """Spelling-correction wrapper around :class:`Dynamic_GlyceBertForMultiTask`.

    ``forward`` simply delegates to the wrapped model. ``predict`` offers a
    string-in / string-out helper and needs a tokenizer registered through
    :meth:`set_tokenizer` first.
    """

    def __init__(self, config):
        super(ChineseBertForCSC, self).__init__(config)
        self.model = Dynamic_GlyceBertForMultiTask(config)
        # Set later via ``set_tokenizer``; only required by ``predict``.
        self.tokenizer = None

    def forward(self, **kwargs):
        """Forward all keyword arguments to the wrapped model and return its output."""
        return self.model(**kwargs)

    def set_tokenizer(self, tokenizer):
        """Register the tokenizer used by :meth:`predict`."""
        self.tokenizer = tokenizer

    def _predict(self, sentence):
        """Run one correction pass and return the predicted tokens.

        Args:
            sentence: Input text.

        Returns:
            The tokens with the highest logit at every position, excluding the
            first and last position of the model output.

        Raises:
            ValueError: If no tokenizer has been registered.
        """
        if self.tokenizer is None:
            # Raising (instead of returning the message) keeps ``predict`` from
            # treating the message text as if it were a list of tokens.
            raise ValueError("Please init tokenizer by `set_tokenizer(tokenizer)` before predict.")

        inputs = self.tokenizer([sentence], return_tensors='pt')
        # Move the tokenizer output to wherever the model weights live.
        device = self.device
        inputs = {
            key: value.to(device) if isinstance(value, torch.Tensor) else value
            for key, value in inputs.items()
        }
        # Pure inference: no autograd graph is needed.
        with torch.no_grad():
            logits = self.model(**inputs).logits
        predicted_ids = logits.argmax(-1)[0, 1:-1].tolist()
        return self.tokenizer.convert_ids_to_tokens(predicted_ids)

    def predict(self, sentence, window=1):
        """Correct ``sentence`` and return the corrected text.

        After the first pass the prediction is fed back into the model
        ``window`` more times. In each extra round a newly changed token is
        accepted only if it lies within one position of a token changed in
        the previous round; every other position keeps its previous value.

        Args:
            sentence: Input text.
            window: Number of additional refinement rounds.

        Returns:
            The corrected text (predicted tokens joined without separator).
        """
        src_tokens = list(sentence)
        pred_tokens = self._predict(sentence)

        for _ in range(window):
            # Positions changed by the previous pass.
            record_index = [i for i, (a, b) in enumerate(zip(src_tokens, pred_tokens)) if a != b]

            src_tokens = pred_tokens
            pred_tokens = self._predict(''.join(pred_tokens))
            for i, (a, b) in enumerate(zip(src_tokens, pred_tokens)):
                # Keep a new modification only when it is adjacent to a previous one.
                if a != b and any(abs(i - x) <= 1 for x in record_index):
                    continue
                pred_tokens[i] = src_tokens[i]

        return ''.join(pred_tokens)
#################################ChineseBERT Source Code##############################################
class Dynamic_GlyceBertForMultiTask(BertPreTrainedModel):
    """:class:`GlyceBertModel` encoder followed by a token-prediction head."""

    def __init__(self, config):
        super(Dynamic_GlyceBertForMultiTask, self).__init__(config)

        self.bert = GlyceBertModel(config)
        self.cls = MultiTaskHeads(config)

    def get_output_embeddings(self):
        """Return the decoder layer of the prediction head."""
        return self.cls.predictions.decoder

    def forward(
            self,
            input_ids=None,
            pinyin_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            **kwargs
    ):
        """Encode the inputs and compute per-token prediction scores.

        All named arguments are passed through to :class:`GlyceBertModel`.
        Any additional keyword argument is rejected.

        Returns:
            A ``MaskedLMOutput`` holding ``logits``, ``hidden_states`` and
            ``attentions`` when ``return_dict`` resolves to true; otherwise a
            tuple ``(logits, *extra_encoder_outputs)``.
        """
        # NOTE: kept as an assert so the raised exception type is unchanged;
        # be aware that the check disappears under ``python -O``.
        assert kwargs == {}, f"Unexpected keyword arguments: {list(kwargs.keys())}."
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs_x = self.bert(
            input_ids,
            pinyin_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        encoded_x = outputs_x[0]
        prediction_scores = self.cls(encoded_x)

        if not return_dict:
            # The encoder returned a plain tuple (sequence_output, pooled_output, ...),
            # which has no ``hidden_states`` / ``attentions`` attributes.
            return (prediction_scores,) + outputs_x[2:]

        return MaskedLMOutput(
            logits=prediction_scores,
            hidden_states=outputs_x.hidden_states,
            attentions=outputs_x.attentions,
        )
class GlyceBertModel(BertModel):
    r"""
    BERT encoder whose embedding layer is a :class:`FusionBertEmbeddings`.

    Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
        **last_hidden_state**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, hidden_size)``
            Sequence of hidden-states at the output of the last layer of the models.
        **pooler_output**: ``torch.FloatTensor`` of shape ``(batch_size, hidden_size)``
            Output of the pooler applied to the last-layer hidden states.
        **hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
            list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
            of shape ``(batch_size, sequence_length, hidden_size)``:
            Hidden-states of the models at the output of each layer plus the initial embedding outputs.
        **attentions**: (`optional`, returned when ``config.output_attentions=True``)
            list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.

    Note:
        ``forward`` hands ``input_ids`` and ``pinyin_ids`` to the fusion embedding layer, so
        both are needed there. ``forward_with_embedding`` skips the embedding layer and
        consumes a pre-computed ``embedding`` tensor instead.
    """

    def __init__(self, config):
        super(GlyceBertModel, self).__init__(config)
        self.config = config

        # Replace the sub-modules created by the parent constructor.
        self.embeddings = FusionBertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config)

        self.init_weights()

    def _glyce_prepare_inputs(
            self,
            input_ids,
            attention_mask,
            token_type_ids,
            head_mask,
            inputs_embeds,
            encoder_hidden_states,
            encoder_attention_mask,
            output_attentions,
            output_hidden_states,
            return_dict,
    ):
        """Validate inputs, resolve config defaults and build the masks the encoder needs.

        Shared by ``forward`` and ``forward_with_embedding``.

        Returns:
            A dict with the keys ``token_type_ids``, ``attention_mask`` (extended),
            ``encoder_attention_mask`` (extended or ``None``), ``head_mask``,
            ``output_attentions``, ``output_hidden_states`` and ``return_dict``.

        Raises:
            ValueError: If both or neither of ``input_ids`` / ``inputs_embeds`` are given.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        device = input_ids.device if input_ids is not None else inputs_embeds.device

        if attention_mask is None:
            attention_mask = torch.ones(input_shape, device=device)
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
        # ourselves in which case we just need to make it broadcastable to all heads.
        extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape, device)

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        # 1.0 in head_mask indicate we keep the head
        # attention_probs has shape bsz x n_heads x N x N
        # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
        # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        return {
            "token_type_ids": token_type_ids,
            "attention_mask": extended_attention_mask,
            "encoder_attention_mask": encoder_extended_attention_mask,
            "head_mask": head_mask,
            "output_attentions": output_attentions,
            "output_hidden_states": output_hidden_states,
            "return_dict": return_dict,
        }

    def _glyce_encode(self, embedding_output, encoder_hidden_states, prepared):
        """Run encoder and pooler on ``embedding_output`` and package the result.

        ``prepared`` is the dict produced by ``_glyce_prepare_inputs``.
        """
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=prepared["attention_mask"],
            head_mask=prepared["head_mask"],
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=prepared["encoder_attention_mask"],
            output_attentions=prepared["output_attentions"],
            output_hidden_states=prepared["output_hidden_states"],
            return_dict=prepared["return_dict"],
        )
        sequence_output = encoder_outputs[0]
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        if not prepared["return_dict"]:
            return (sequence_output, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def forward(
            self,
            input_ids=None,
            pinyin_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
    ):
        r"""
        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
            if the models is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask
            is used in the cross-attention if the models is configured as a decoder.
            Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        """
        prepared = self._glyce_prepare_inputs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        embedding_output = self.embeddings(
            input_ids=input_ids, pinyin_ids=pinyin_ids, position_ids=position_ids,
            token_type_ids=prepared["token_type_ids"], inputs_embeds=inputs_embeds
        )
        return self._glyce_encode(embedding_output, encoder_hidden_states, prepared)

    def forward_with_embedding(
            self,
            input_ids=None,
            pinyin_ids=None,
            attention_mask=None,
            token_type_ids=None,
            position_ids=None,
            head_mask=None,
            inputs_embeds=None,
            encoder_hidden_states=None,
            encoder_attention_mask=None,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=None,
            embedding=None
    ):
        r"""
        Same as ``forward`` except that the embedding layer is skipped and ``embedding`` is fed
        to the encoder directly. ``input_ids`` / ``inputs_embeds`` are still required because
        they determine the input shape and device used for the masks; ``pinyin_ids`` and
        ``position_ids`` are accepted but not used here.

        encoder_hidden_states  (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
            Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
            if the models is configured as a decoder.
        encoder_attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Mask to avoid performing attention on the padding token indices of the encoder input. This mask
            is used in the cross-attention if the models is configured as a decoder.
            Mask values selected in ``[0, 1]``:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.
        embedding (:obj:`torch.FloatTensor`):
            Pre-computed embedding output passed to the encoder; must not be ``None``.
        """
        prepared = self._glyce_prepare_inputs(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        assert embedding is not None
        return self._glyce_encode(embedding, encoder_hidden_states, prepared)
class MultiTaskHeads(nn.Module):
    """Container for the token-prediction head applied to the encoder output."""

    def __init__(self, config):
        nn.Module.__init__(self)
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output):
        """Return the scores produced by ``self.predictions`` for ``sequence_output``."""
        return self.predictions(sequence_output)
class FusionBertEmbeddings(nn.Module):
    """
    Construct the embeddings from word, position, glyph, pinyin and token_type embeddings.
    """

    def __init__(self, config):
        super(FusionBertEmbeddings, self).__init__()
        self.path = Path(config._name_or_path)
        config_path = cache_path / 'config'
        # exist_ok avoids the check-then-create race of a separate existence test.
        os.makedirs(config_path, exist_ok=True)

        # Use exactly these three font tables, in a fixed order. Listing the directory
        # instead would give a platform-dependent order and pick up any stray ``.npy``
        # file, which would no longer match the fixed input width of ``glyph_map`` below.
        font_names = ("STFANGSO.TTF24.npy", "STXINGKA.TTF24.npy", "方正古隶繁体.ttf24.npy")
        font_files = []
        for font_name in font_names:
            download_file(f"config/{font_name}", self.path)
            font_files.append(config_path / font_name)

        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=0)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
        self.pinyin_embeddings = PinyinEmbedding(embedding_size=128, pinyin_out_dim=config.hidden_size, config=config)
        self.glyph_embeddings = GlyphEmbedding(font_npy_files=font_files)

        # Projects the flattened glyph vector to hidden_size.
        # NOTE(review): 1728 presumably equals 3 fonts * 24 * 24 pixels - confirm against the font tables.
        self.glyph_map = nn.Linear(1728, config.hidden_size)
        # Fuses the concatenated word / pinyin / glyph embeddings back to hidden_size.
        self.map_fc = nn.Linear(config.hidden_size * 3, config.hidden_size)
        # self.LayerNorm is not snake-cased to stick with TensorFlow models variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
        self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)))

    def forward(self, input_ids=None, pinyin_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
        """Build the fused input embeddings.

        Args:
            input_ids: Token ids; also used for the glyph lookup, so they are needed
                even when ``inputs_embeds`` is supplied.
            pinyin_ids: Pinyin ids passed to the pinyin embedding module.
            token_type_ids: Optional segment ids; defaults to all zeros.
            position_ids: Optional position ids; defaults to ``0 .. seq_length - 1``.
            inputs_embeds: Optional pre-computed word embeddings used instead of the
                ``word_embeddings`` lookup.

        Returns:
            The sum of fused, position and token-type embeddings after LayerNorm and dropout.
        """
        if input_ids is not None:
            input_shape = input_ids.size()
        else:
            input_shape = inputs_embeds.size()[:-1]

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, :seq_length]

        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        # get char embedding, pinyin embedding and glyph embedding
        word_embeddings = inputs_embeds  # [bs,l,hidden_size]
        pinyin_embeddings = self.pinyin_embeddings(pinyin_ids)  # [bs,l,hidden_size]
        glyph_embeddings = self.glyph_map(self.glyph_embeddings(input_ids))  # [bs,l,hidden_size]
        # fusion layer
        concat_embeddings = torch.cat((word_embeddings, pinyin_embeddings, glyph_embeddings), 2)
        inputs_embeds = self.map_fc(concat_embeddings)

        position_embeddings = self.position_embeddings(position_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + position_embeddings + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
class PinyinEmbedding(nn.Module):
    """Embed per-character pinyin id sequences with a 1-D convolution and max-pooling."""

    def __init__(self, embedding_size: int, pinyin_out_dim: int, config):
        """
            Pinyin Embedding Module
        Args:
            embedding_size: the size of each embedding vector
            pinyin_out_dim: kernel number of conv
            config: model config; ``config._name_or_path`` is used as the local
                directory to look for ``config/pinyin_map.json``
        """
        super(PinyinEmbedding, self).__init__()
        download_file("config/pinyin_map.json", Path(config._name_or_path))
        # Explicit encoding: without it the platform's locale encoding would be used.
        with open(cache_path / 'config' / 'pinyin_map.json', encoding="utf-8") as fin:
            pinyin_dict = json.load(fin)
        self.pinyin_out_dim = pinyin_out_dim
        self.embedding = nn.Embedding(len(pinyin_dict['idx2char']), embedding_size)
        self.conv = nn.Conv1d(in_channels=embedding_size, out_channels=self.pinyin_out_dim, kernel_size=2,
                              stride=1, padding=0)

    def forward(self, pinyin_ids):
        """
        Args:
            pinyin_ids: (bs*sentence_length*pinyin_locs)

        Returns:
            pinyin_embed: (bs,sentence_length,pinyin_out_dim)
        """
        # input pinyin ids for 1-D conv
        embed = self.embedding(pinyin_ids)  # [bs,sentence_length,pinyin_locs,embed_size]
        bs, sentence_length, pinyin_locs, embed_size = embed.shape
        view_embed = embed.view(-1, pinyin_locs, embed_size)  # [(bs*sentence_length),pinyin_locs,embed_size]
        input_embed = view_embed.permute(0, 2, 1)  # [(bs*sentence_length), embed_size, pinyin_locs]
        # conv + max_pooling
        pinyin_conv = self.conv(input_embed)  # [(bs*sentence_length),pinyin_out_dim,H]
        # Pool over the whole remaining length so a single value per channel is left.
        pinyin_embed = F.max_pool1d(pinyin_conv, pinyin_conv.shape[-1])  # [(bs*sentence_length),pinyin_out_dim,1]
        return pinyin_embed.view(bs, sentence_length, self.pinyin_out_dim)  # [bs,sentence_length,pinyin_out_dim]
class GlyphEmbedding(nn.Module):
    """Look up flattened glyph bitmaps (one per font) for every token id."""

    def __init__(self, font_npy_files: List[str]):
        """Load the per-font glyph tables and store them as an embedding matrix.

        Args:
            font_npy_files: Paths of ``.npy`` files, one per font. The first
                axis of each array indexes the vocabulary and the last axis
                is taken as the font size.
        """
        super().__init__()
        glyph_tables = []
        for npy_path in font_npy_files:
            glyph_tables.append(np.load(npy_path).astype(np.float32))

        first_table = glyph_tables[0]
        self.vocab_size = first_table.shape[0]
        self.font_num = len(glyph_tables)
        self.font_size = first_table.shape[-1]

        # Stack the fonts on a new axis 1 (N, C, H, W), then flatten everything
        # but the vocabulary axis into one row per token id.
        stacked = np.stack(glyph_tables, axis=1)
        flat_weight = torch.from_numpy(stacked.reshape([self.vocab_size, -1]))
        self.embedding = nn.Embedding(
            num_embeddings=self.vocab_size,
            embedding_dim=self.font_size ** 2 * self.font_num,
            _weight=flat_weight,
        )

    def forward(self, input_ids):
        """
            get glyph images for batch inputs
        Args:
            input_ids: [batch, sentence_length]
        Returns:
            images: [batch, sentence_length, self.font_num*self.font_size*self.font_size]
        """
        glyph_vectors = self.embedding(input_ids)
        return glyph_vectors