import torch
import torch.nn as nn
from transformers import PreTrainedModel
from typing import Optional, Tuple, Union, List

from .configuration_reborn import RebornUASRConfig


class RebornSegmenter(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.conv1 = nn.Conv1d(
            config.segmenter_input_dim,
            config.segmenter_hidden_dim,
            config.segmenter_kernel_size,
            padding=config.segmenter_kernel_size // 2,
        )
        self.conv2 = nn.Conv1d(
            config.segmenter_hidden_dim, config.segmenter_hidden_dim, 3, padding=1
        )
        self.conv3 = nn.Conv1d(config.segmenter_hidden_dim, 2, 1)
        self.dropout = nn.Dropout(config.segmenter_dropout)
        self.relu = nn.ReLU()

    def forward(self, x):
        """
        Input:
            x: (B, T, C)
        Output:
            boundary_logits: (B, T, 2)  # [..., 0]: not boundary; [..., 1]: boundary
        """
        x = x.transpose(1, 2)
        x = self.dropout(self.relu(self.conv1(x)))
        x = self.dropout(self.relu(self.conv2(x)))
        x = self.conv3(x)
        x = x.transpose(1, 2)
        return x

    def boundary_predict(self, x, padding_mask, deterministic=False):
        """
        Input:
            x: (B, T, C)
            padding_mask: (B, T)  # 0: not padding; 1: padding
        Output:
            boundary: (B, T)  # 0: not boundary; 1: boundary; -1: padding
            boundary_logits: (B, T, 2)
        """
        boundary_logits = self.forward(x)
        if deterministic:
            boundary = boundary_logits.argmax(-1)
        else:
            boundary = torch.distributions.Categorical(logits=boundary_logits).sample()
        boundary[padding_mask] = -1
        return boundary, boundary_logits

    def pre_segment(self, logits, padding_mask, return_boundary=False, deterministic=True):
        """
        Input:
            logits: (B, T, C)
            padding_mask: (B, T)
        Output:
            new_logits: (B, T', C)
            new_padding_mask: (B, T')
        """
        bsz, tsz, csz = logits.size()

        boundary, boundary_logits = self.boundary_predict(
            logits, padding_mask, deterministic=deterministic
        )

        # T' is the largest segment count in the batch: the maximum number of
        # predicted boundaries, plus one for the trailing segment.
        new_tsz = int(torch.max(torch.sum(boundary == 1, dim=1)).item()) + 1
        new_logits = logits.new_zeros(bsz, new_tsz, csz)
        new_pad = padding_mask.new_zeros(bsz, new_tsz)

        for b in range(bsz):
            # Merge consecutive frames into segments via mean pooling
            # (mean_pool_join): a boundary at frame t closes the current
            # segment and starts a new one at t.
            new_idx = 0
            count = 0
            for t in range(tsz):
                if padding_mask[b, t] == 1:
                    break
                if boundary[b, t] == 1 and count > 0:
                    # The count > 0 guard avoids a division by zero when the
                    # very first frame is predicted as a boundary.
                    new_logits[b, new_idx] /= count
                    new_idx += 1
                    count = 0
                new_logits[b, new_idx] += logits[b, t]
                count += 1
            if count > 0:
                # Close the last segment.
                new_logits[b, new_idx] /= count
                new_idx += 1
                count = 0
            if new_idx < new_tsz:
                # Mark the unused tail as padding.
                pad = new_tsz - new_idx
                new_logits[b, -pad:] = 0
                new_pad[b, -pad:] = True

        if return_boundary:
            return new_logits, new_pad, boundary, boundary_logits
        return new_logits, new_pad
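# Illustrative sketch (not part of the original model code): how pre_segment
# collapses frame-level features into segment-level features. The dimensions
# below are hypothetical stand-ins for values that normally come from
# RebornUASRConfig; SimpleNamespace is used only to avoid constructing a full
# config here.
def _demo_pre_segment():
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        segmenter_input_dim=512,
        segmenter_hidden_dim=256,
        segmenter_kernel_size=7,
        segmenter_dropout=0.1,
    )
    segmenter = RebornSegmenter(cfg).eval()  # eval() disables dropout
    feats = torch.randn(2, 100, 512)             # (B, T, C) frame features
    pad = torch.zeros(2, 100, dtype=torch.bool)  # no padded frames
    with torch.no_grad():
        seg_feats, seg_pad = segmenter.pre_segment(feats, pad, deterministic=True)
    # seg_feats: (B, T', C) with T' <= T; frames between predicted boundaries
    # are mean-pooled into a single vector per segment.
    print(seg_feats.shape, seg_pad.shape)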
class RebornGenerator(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.output_dim = config.generator_output_dim
        self.stride = config.generator_stride
        self.dropout = nn.Dropout(config.generator_dropout)

        cnn_input_dim = config.generator_input_dim
        cnn_output_dim = config.generator_output_dim
        padding = config.generator_kernel // 2
        self.proj = nn.Sequential(
            nn.Conv1d(
                cnn_input_dim,
                cnn_output_dim,
                kernel_size=config.generator_kernel,
                stride=config.generator_stride,
                dilation=config.generator_dilation,
                padding=padding,
                bias=config.generator_bias,
            ),
        )

    def forward(self, dense_x, tokens, dense_padding_mask):
        dense_x = self.dropout(dense_x)

        # (B, T, C) -> (B, C, T)
        dense_x = dense_x.transpose(-2, -1)
        dense_x = self.proj(dense_x)
        # (B, C, T) -> (B, T, C)
        dense_x = dense_x.transpose(-2, -1)

        if self.stride > 1:
            dense_padding_mask = dense_padding_mask[:, :: self.stride]

        if dense_padding_mask.size(1) != dense_x.size(1):
            new_padding = dense_padding_mask.new_zeros(dense_x.shape[:-1])
            diff = new_padding.size(1) - dense_padding_mask.size(1)
            assert (
                diff > 0
            ), f"{new_padding.shape}, {dense_padding_mask.shape}, {dense_x.shape}, {diff}"
            if diff > 0:
                new_padding[:, diff:] = dense_padding_mask
            else:
                assert diff < 0
                new_padding = dense_padding_mask[:, :diff]

            dense_padding_mask = new_padding

        result = {}

        token_x = None
        if tokens is not None:
            # One-hot encode the target tokens.
            token_x = dense_x.new_zeros(tokens.numel(), self.output_dim)
            token_x.scatter_(1, tokens.view(-1, 1).long(), 1)
            token_x = token_x.view(tokens.shape + (self.output_dim,))

        result["dense_x"] = dense_x
        result["token_x"] = token_x
        result["dense_padding_mask"] = dense_padding_mask

        return result


def get_item(tensor):
    # tpu-comment: making this a no-op for xla devices.
    if torch.is_tensor(tensor) and tensor.device.type == "xla":
        return tensor.detach()
    if hasattr(tensor, "item"):
        return tensor.item()
    if hasattr(tensor, "__getitem__"):
        return tensor[0]
    return tensor


def post_process(sentence: str, symbol: str):
    if symbol == "sentencepiece":
        sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
    elif symbol == "wordpiece":
        sentence = sentence.replace(" ", "").replace("_", " ").strip()
    elif symbol == "letter":
        sentence = sentence.replace(" ", "").replace("|", " ").strip()
    elif symbol == "silence":
        import re

        sentence = sentence.replace("<SIL>", "")
        sentence = re.sub(" +", " ", sentence).strip()
    elif symbol == "_EOW":
        sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
    elif symbol in {"subword_nmt", "@@ ", "@@"}:
        if symbol == "subword_nmt":
            symbol = "@@ "
        sentence = (sentence + " ").replace(symbol, "").rstrip()
    elif symbol == "none":
        pass
    elif symbol is not None:
        raise NotImplementedError(f"Unknown post_process option: {symbol}")
    return sentence
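# Illustrative sketch (not in the original file): expected post_process
# behavior for the "silence" code used by RebornUASRModel.generate below,
# plus fairseq-style "letter" decoding for comparison. "<SIL>" is the
# silence token assumed by the tokenizer in this module.
def _demo_post_process():
    print(post_process("hh <SIL> ah l ow", "silence"))        # -> "hh ah l ow"
    print(post_process("h e l l o | w o r l d |", "letter"))  # -> "hello world"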
class SimpleTokenizer(object):
    def __init__(
        self,
        phones: List[str],
        bos="<s>",
        pad="<pad>",
        eos="</s>",
        unk="<unk>",
        extra_special_symbols=None,
    ) -> None:
        self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
        self.symbols = []
        self.count = []
        self.indices = {}

        self.bos_index = self.add_symbol(bos)
        self.pad_index = self.add_symbol(pad)
        self.eos_index = self.add_symbol(eos)
        self.unk_index = self.add_symbol(unk)
        if extra_special_symbols:
            for s in extra_special_symbols:
                self.add_symbol(s)
        self.nspecial = len(self.symbols)

        for phone in phones:
            self.add_symbol(phone)

        self.postprocess_code = "silence"

    def add_symbol(self, word, n=1, overwrite=False):
        """Adds a word to the dictionary"""
        if word in self.indices and not overwrite:
            idx = self.indices[word]
            self.count[idx] = self.count[idx] + n
            return idx
        else:
            idx = len(self.symbols)
            self.indices[word] = idx
            self.symbols.append(word)
            self.count.append(n)
            return idx

    def __eq__(self, other):
        return self.indices == other.indices

    def __getitem__(self, idx):
        if idx < len(self.symbols):
            return self.symbols[idx]
        return self.unk_word

    def get_count(self, idx):
        return self.count[idx]

    def __len__(self):
        """Returns the number of symbols in the dictionary"""
        return len(self.symbols)

    def __contains__(self, sym):
        return sym in self.indices

    def index(self, sym):
        """Returns the index of the specified symbol"""
        assert isinstance(sym, str)
        if sym in self.indices:
            return self.indices[sym]
        return self.unk_index

    def string(
        self,
        tensor,
        bpe_symbol=None,
        escape_unk=False,
        extra_symbols_to_ignore=None,
        unk_string=None,
        include_eos=False,
        separator=" ",
    ):
        """Helper for converting a tensor of token indices to a string.

        Can optionally remove BPE symbols or escape <unk> words.
        """
        if torch.is_tensor(tensor) and tensor.dim() == 2:
            return "\n".join(
                self.string(
                    t,
                    bpe_symbol,
                    escape_unk,
                    extra_symbols_to_ignore,
                    include_eos=include_eos,
                )
                for t in tensor
            )

        extra_symbols_to_ignore = set(extra_symbols_to_ignore or [])
        if not include_eos:
            extra_symbols_to_ignore.add(self.eos())

        def token_string(i):
            if i == self.unk():
                if unk_string is not None:
                    return unk_string
                else:
                    return self.unk_string(escape_unk)
            else:
                return self[i]

        if hasattr(self, "bos_index"):
            extra_symbols_to_ignore.add(self.bos())

        sent = separator.join(
            token_string(i)
            for i in tensor
            if get_item(i) not in extra_symbols_to_ignore
        )

        return post_process(sent, bpe_symbol)

    def unk_string(self, escape=False):
        """Return unknown string, optionally escaped as: <<unk>>"""
        if escape:
            return "<{}>".format(self.unk_word)
        else:
            return self.unk_word

    def bos(self):
        """Helper to get index of beginning-of-sentence symbol"""
        return self.bos_index

    def pad(self):
        """Helper to get index of pad symbol"""
        return self.pad_index

    def eos(self):
        """Helper to get index of end-of-sentence symbol"""
        return self.eos_index

    def unk(self):
        """Helper to get index of unk symbol"""
        return self.unk_index


class RebornUASRModel(PreTrainedModel):
    config_class = RebornUASRConfig

    def __init__(self, config):
        super().__init__(config)
        self.pca = nn.Linear(1024, 512)
        self.segmenter = RebornSegmenter(config)
        self.generator = RebornGenerator(config)

        self.tokenizer = None
        if len(config.phones) > 0:
            self.tokenizer = SimpleTokenizer(config.phones)

    def forward(
        self,
        x: Optional[torch.Tensor],  # (B, T, C)
        padding_mask: Optional[torch.Tensor],  # (B, T)
    ):
        x_reduced = self.pca(x)
        x_segmented, segmented_padding_mask = self.segmenter.pre_segment(
            x_reduced, padding_mask, deterministic=True
        )
        x_generated = self.generator(x_segmented, None, segmented_padding_mask)

        return {
            "x_reduced": x_reduced,
            "x_segmented": x_segmented,
            "x_generated": x_generated,
        }

    def generate(self, x, padding_mask, merge_consecutive=True, remove_silence=True):
        res = self.forward(x, padding_mask)
        y_raw_logits = res["x_generated"]["dense_x"]
        y_raw_padding = res["x_generated"]["dense_padding_mask"]
        # Force padded positions to decode to the pad token. Chained indexing
        # (y_raw_logits[y_raw_padding][..., pad_index] = ...) would write into
        # a copy, so index both dimensions in a single assignment.
        y_raw_logits[y_raw_padding, self.tokenizer.pad_index] = float("inf")
        preds = y_raw_logits.argmax(-1)
        hyps = []
        postprocess_code = "silence" if remove_silence else "none"
        for pred in preds:
            if merge_consecutive:
                # Merge runs of identical consecutive predictions.
                pred = torch.unique_consecutive(pred)
            hyp = self.tokenizer.string(pred, bpe_symbol=postprocess_code)
            hyps.append(hyp)
        return hyps


def main():
    model_config = RebornUASRConfig.from_pretrained(
        "/home/andybi7676/Desktop/uasr-rl/reborn_uasr/config.json"
    )
    print(model_config)
    model = RebornUASRModel(model_config)
    print(model.tokenizer.indices)


if __name__ == "__main__":
    main()
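# Illustrative usage sketch (the checkpoint path is hypothetical; substitute a
# real config/checkpoint such as the one loaded in main() above):
#
#   model = RebornUASRModel.from_pretrained("path/to/reborn-uasr-checkpoint")
#   feats = torch.randn(1, 200, 1024)            # wav2vec 2.0 features (B, T, 1024)
#   pad = torch.zeros(1, 200, dtype=torch.bool)  # all frames valid
#   hyps = model.generate(feats, pad, merge_consecutive=True, remove_silence=True)
#   print(hyps[0])                               # decoded phone string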