import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
from dgl.nn.pytorch import GatedGraphConv


def sequence_mask(lengths, maxlen, dtype=torch.bool):
    """Build a [B, maxlen] boolean mask that is True at valid (non-padded) positions."""
    if maxlen is None:
        maxlen = lengths.max()
    mask = ~(torch.ones((len(lengths), maxlen)).to(lengths.device).cumsum(dim=1).t() > lengths).t()
    mask = mask.type(dtype)
    return mask
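
# Illustrative check (hypothetical values, not part of the original module):
# sequence_mask turns per-sample lengths into a validity mask, e.g.
#   sequence_mask(torch.tensor([1, 3]), maxlen=3)
#   -> tensor([[ True, False, False],
#              [ True,  True,  True]])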


def group_hidden_by_segs(h, seg_ids, max_len):
    """
    Mean-pool hidden states over segments (e.g. phoneme-level -> word-level).

    :param h: [B, T, H]
    :param seg_ids: [B, T], 1-based segment index per timestep (0 = padding)
    :param max_len: number of segments kept in the output
    :return: [B, max_len, H], mean of h over each segment
    """
    B, T, H = h.shape
    # Sum hidden states into their segment bucket; bucket 0 collects padding.
    h_gby_segs = h.new_zeros([B, max_len + 1, H]).scatter_add_(1, seg_ids[:, :, None].repeat([1, 1, H]), h)
    all_ones = h.new_ones(h.shape[:2])
    cnt_gby_segs = h.new_zeros([B, max_len + 1]).scatter_add_(1, seg_ids, all_ones).contiguous()
    # Drop the padding bucket and turn sums into means (empty segments stay zero).
    h_gby_segs = h_gby_segs[:, 1:]
    cnt_gby_segs = cnt_gby_segs[:, 1:]
    h_gby_segs = h_gby_segs / torch.clamp(cnt_gby_segs[:, :, None], min=1)
    return h_gby_segs
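
# Illustrative check (hypothetical values, not part of the original module):
# for one sample with seg_ids = [1, 1, 2, 0] and max_len = 2, segment 1 gets the
# mean of the first two frames, segment 2 gets the third frame, and frames mapped
# to index 0 (padding) are discarded:
#   group_hidden_by_segs(h, seg_ids, 2)[0] == [(h[0, 0] + h[0, 1]) / 2, h[0, 2]]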


class GraphAuxEnc(nn.Module):
    def __init__(self, in_dim, hid_dim, out_dim, n_iterations=5, n_edge_types=6):
        super(GraphAuxEnc, self).__init__()
        self.in_dim = in_dim
        self.hid_dim = hid_dim
        self.out_dim = out_dim
        self.skip_connect = True
        self.dropout_after_gae = False

        # Two stacked gated graph convolution layers over the syntactic graph.
        self.ggc_1 = GatedGraphConv(in_feats=in_dim, out_feats=hid_dim,
                                    n_steps=n_iterations, n_etypes=n_edge_types)
        self.ggc_2 = GatedGraphConv(in_feats=hid_dim, out_feats=out_dim,
                                    n_steps=n_iterations, n_etypes=n_edge_types)
        self.dropout = nn.Dropout(p=0.5)

    @staticmethod
    def ph_encoding_to_word_encoding(ph_encoding, ph2word, word_len):
        """
        ph_encoding: [batch, t_p, hid]
        ph2word: LongTensor [batch, t_p], 1-based word index for each phoneme
        word_len: LongTensor [batch], number of words per sample
        """
        word_encoding_for_graph, batch_word_encoding, has_word_row_idx = GraphAuxEnc._process_ph_to_word_encoding(
            ph_encoding, ph2word, word_len)
        return batch_word_encoding, word_encoding_for_graph

    def pad_word_encoding_to_phoneme(self, word_encoding, ph2word, t_p):
        return self._postprocess_word2ph(word_encoding, ph2word, t_p)

    @staticmethod
    def _process_ph_to_word_encoding(ph_encoding, ph2word, word_len=None):
        """
        ph_encoding: [batch, t_p, hid]
        ph2word: LongTensor [batch, t_p], 1-based word index for each phoneme
        word_len: LongTensor [batch], number of words per sample
        """
        word_len = word_len.reshape([-1, ])
        max_len = max(word_len)
        num_nodes = sum(word_len)

        # Mean-pool phoneme encodings into word encodings: [batch, t_w, hid].
        batch_word_encoding = group_hidden_by_segs(ph_encoding, ph2word, max_len)
        bs, t_w, hid = batch_word_encoding.shape
        has_word_mask = sequence_mask(word_len, max_len)  # [batch, t_w]
        # Flatten and keep only the rows that correspond to real words, so the
        # result lines up with the nodes of the batched graph.
        word_encoding = batch_word_encoding.reshape([bs * t_w, hid])
        has_word_row_idx = has_word_mask.reshape([-1])
        word_encoding = word_encoding[has_word_row_idx]
        assert word_encoding.shape[0] == num_nodes
        return word_encoding, batch_word_encoding, has_word_row_idx

    @staticmethod
    def _postprocess_word2ph(word_encoding, ph2word, t_p):
        # Prepend a zero vector at word index 0 so that padded phonemes
        # (ph2word == 0) gather an all-zero encoding.
        word_encoding = F.pad(word_encoding, [0, 0, 1, 0])
        ph2word_ = ph2word[:, :, None].repeat([1, 1, word_encoding.shape[-1]])
        out = torch.gather(word_encoding, 1, ph2word_)  # [B, t_p, hid]
        return out
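
    # Illustrative check (hypothetical values, not part of the original module):
    # with ph2word[0] = [1, 1, 2, 0], the first two phonemes receive word 1's
    # vector, the third phoneme receives word 2's vector, and the padded phoneme
    # (index 0) receives the prepended all-zero vector.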

    @staticmethod
    def _repeat_one_sequence(x, d, T):
        """Repeat each frame of x according to its duration d, then zero-pad to length T."""
        if d.sum() == 0:
            d = d.fill_(1)
        hid = x.shape[-1]
        expanded_lst = [x_.repeat(int(d_), 1) for x_, d_ in zip(x, d) if d_ != 0]
        expanded = torch.cat(expanded_lst, dim=0)
        if T > expanded.shape[0]:
            expanded = torch.cat([expanded, torch.zeros([T - expanded.shape[0], hid]).to(expanded.device)], dim=0)
        return expanded
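
    # Illustrative check (hypothetical values, not part of the original module):
    # with x = [[a], [b]], d = [2, 1] and T = 4, the result is [[a], [a], [b], [0]]:
    # each frame is repeated d_i times, then zero-padded up to T frames.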

    def word_forward(self, graph_lst, word_encoding, etypes_lst):
        """
        Word encoding in, word encoding out.
        """
        batched_graph = dgl.batch(graph_lst)
        inp = word_encoding
        batched_etypes = torch.cat(etypes_lst)  # edge types for every edge in the batched graph
        assert batched_graph.num_nodes() == inp.shape[0]

        gcc1_out = self.ggc_1(batched_graph, inp, batched_etypes)
        if self.dropout_after_gae:
            gcc1_out = self.dropout(gcc1_out)
        gcc2_out = self.ggc_2(batched_graph, gcc1_out, batched_etypes)
        if self.dropout_after_gae:
            gcc2_out = self.dropout(gcc2_out)
        if self.skip_connect:
            # The residual sum requires all three feature sizes to match.
            assert self.in_dim == self.hid_dim and self.hid_dim == self.out_dim
            gcc2_out = inp + gcc1_out + gcc2_out

        # Scatter the flat per-node output back into a padded [batch, t_w, hid] tensor.
        word_len = torch.tensor([g.num_nodes() for g in graph_lst]).reshape([-1])
        max_len = max(word_len)
        has_word_mask = sequence_mask(word_len, max_len)  # [batch, t_w]
        has_word_row_idx = has_word_mask.reshape([-1])
        bs = len(graph_lst)
        t_w = max([g.num_nodes() for g in graph_lst])
        hid = word_encoding.shape[-1]
        output = torch.zeros([bs * t_w, hid]).to(gcc2_out.device)
        output[has_word_row_idx] = gcc2_out
        output = output.reshape([bs, t_w, hid])
        word_level_output = output
        return torch.transpose(word_level_output, 1, 2)  # [batch, hid, t_w]
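
    # Usage sketch (hypothetical shapes, not part of the original module): given a
    # list of word-level graphs whose total node count matches the number of rows
    # in word_encoding ([total_words_in_batch, hid]), word_forward applies the two
    # gated graph convolutions and returns a padded [batch, hid, t_w] tensor.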

    def forward(self, graph_lst, ph_encoding, ph2word, etypes_lst, return_word_encoding=False):
        """
        graph_lst: list of dgl graphs, one per sample
        ph_encoding: [batch, hid, t_p]
        ph2word: LongTensor [batch, t_p], e.g. [[1, 2, 2, 2, 3, 3, 3], ...]
        etypes_lst: list of per-graph edge-type LongTensors
        """
        t_p = ph_encoding.shape[-1]
        ph_encoding = ph_encoding.transpose(1, 2)  # -> [batch, t_p, hid]
        word_len = torch.tensor([g.num_nodes() for g in graph_lst]).reshape([-1])
        batched_graph = dgl.batch(graph_lst)
        # Pool phoneme encodings into per-word node features for the batched graph.
        inp, batched_word_encoding, has_word_row_idx = self._process_ph_to_word_encoding(
            ph_encoding, ph2word, word_len=word_len)
        bs, t_w, hid = batched_word_encoding.shape
        batched_etypes = torch.cat(etypes_lst)
        gcc1_out = self.ggc_1(batched_graph, inp, batched_etypes)
        gcc2_out = self.ggc_2(batched_graph, gcc1_out, batched_etypes)

        # Residual connection over both graph-convolution layers.
        gcc2_out = inp + gcc1_out + gcc2_out

        # Scatter the flat per-node output back into a padded [batch, t_w, hid] tensor,
        # then expand the word encodings back to phoneme resolution.
        output = torch.zeros([bs * t_w, hid]).to(gcc2_out.device)
        output[has_word_row_idx] = gcc2_out
        output = output.reshape([bs, t_w, hid])
        word_level_output = output
        output = self._postprocess_word2ph(word_level_output, ph2word, t_p)
        output = torch.transpose(output, 1, 2)  # -> [batch, hid, t_p]

        if return_word_encoding:
            return output, torch.transpose(word_level_output, 1, 2)
        else:
            return output


if __name__ == '__main__':
    # Minimal smoke test: parse two sentences into syntactic graphs and run the encoder.
    from modules.syntaspeech.syntactic_graph_buider import Sentence2GraphParser, plot_dgl_sentence_graph

    parser = Sentence2GraphParser("en")

    text1 = "To be or not to be , that 's a question ."
    text2 = "I love you . You love me . Mixue ice-cream and tea ."
    graph1, etypes1 = parser.parse(text1)
    graph2, etypes2 = parser.parse(text2)
    batched_text = "<BOS> " + text1 + " <EOS>" + " " + "<BOS> " + text2 + " <EOS>"
    batched_nodes = [graph1.num_nodes(), graph2.num_nodes()]
    plot_dgl_sentence_graph(dgl.batch([graph1, graph2]), {i: w for i, w in enumerate(batched_text.split(" "))})
    etypes_lst = [etypes1, etypes2]

    # Dummy phoneme encodings: 2 samples, 4 channels, 17 phonemes each.
    in_feats = 4
    out_feats = 4
    enc = GraphAuxEnc(in_dim=in_feats, hid_dim=in_feats, out_dim=out_feats)
    ph2word = torch.tensor([
        [1, 2, 3, 3, 3, 4, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 0],
        [1, 2, 3, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16]
    ])
    inp = torch.randn([2, in_feats, 17])
    graph_lst = [graph1, graph2]
    out = enc(graph_lst, inp, ph2word, etypes_lst)
    print(out.shape)  # expected: [2, out_feats, 17]