File size: 1,090 Bytes
36a67ca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
import torch
import torch.nn as nn
from .attentions import MultiHeadAttention
class VAEMemoryBank(nn.Module):
    """Learned memory bank that an input sequence attends over.

    The bank is a trainable ``(n_hidden_dims, bank_size)`` parameter. In
    ``forward`` the input ``z`` and the bank (shared across the batch) are fed
    to a ``MultiHeadAttention`` encoder, and the encoder output is projected
    to ``output_channels`` with a pointwise (kernel size 1) ``Conv1d``.

    Args:
        bank_size: Number of slots (columns) in the memory bank.
        n_hidden_dims: Channel width of the bank and of the attention encoder.
        n_attn_heads: Number of attention heads in the encoder.
        init_values: Optional tensor copied into the bank at construction via
            ``Tensor.copy_`` (so it must be broadcastable to
            ``(n_hidden_dims, bank_size)``). When ``None`` the bank keeps its
            ``torch.randn`` initialisation.
        output_channels: Channel width of the projected output.
    """

    def __init__(
        self,
        bank_size=1000,
        n_hidden_dims=512,
        n_attn_heads=2,
        init_values=None,
        output_channels=192,
    ):
        super().__init__()
        self.bank_size = bank_size
        self.n_hidden_dims = n_hidden_dims
        self.n_attn_heads = n_attn_heads
        # Attention encoder whose input and output width are both n_hidden_dims.
        self.encoder = MultiHeadAttention(
            channels=n_hidden_dims,
            out_channels=n_hidden_dims,
            n_heads=n_attn_heads,
        )
        # Trainable bank laid out channels-first: (n_hidden_dims, bank_size).
        self.memory_bank = nn.Parameter(torch.randn(n_hidden_dims, bank_size))
        # Pointwise projection from n_hidden_dims to output_channels.
        self.proj = nn.Conv1d(n_hidden_dims, output_channels, 1)
        if init_values is not None:
            # Overwrite the random init without recording the copy in autograd.
            with torch.no_grad():
                self.memory_bank.copy_(init_values)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        """Attend from ``z`` to the memory bank and project the result.

        Args:
            z: 3-D tensor; the batch size is read from dim 0. Presumably
                ``(batch, n_hidden_dims, time)`` to match the encoder's
                ``channels`` -- TODO confirm against callers.

        Returns:
            ``self.proj`` applied to the encoder output: a 3-D tensor whose
            dim 1 has size ``output_channels``.

        Raises:
            ValueError: if ``z`` is not 3-D (from unpacking ``z.shape``).
        """
        b, _, _ = z.shape
        # Broadcast the shared bank over the batch as a view. ``expand`` avoids
        # allocating the ``b`` copies that ``repeat`` would create on every
        # call; forward values and the gradient accumulated into the parameter
        # are identical either way.
        # NOTE(review): relies on the encoder not writing to its second
        # argument in place; confirm for `.attentions.MultiHeadAttention`.
        memory = self.memory_bank.unsqueeze(0).expand(b, -1, -1)
        # z is passed first and the bank second -- presumably query and
        # key/value source respectively; confirm against MultiHeadAttention.
        ret = self.encoder(z, memory, attn_mask=None)
        return self.proj(ret)
|