# Extracted from a Hugging Face Space (page status at capture time: Sleeping).
import torch
class PrefixEncoder(torch.nn.Module):
    """Encode prefix token indices into a flat per-layer vector.

    Input shape: (batch-size, prefix-length)
    Output shape: (batch-size, prefix-length, 2*layers*hidden)

    When ``config.prefix_projection`` is truthy, indices are embedded to
    ``config.hidden_size`` and then passed through a two-layer MLP
    (Linear -> Tanh -> Linear). Otherwise a single embedding table maps the
    indices directly to the output width.
    """

    def __init__(self, config):
        """Build the embedding (and optional projection MLP).

        Args:
            config: Object exposing ``prefix_projection``, ``pre_seq_len``,
                ``hidden_size`` and ``num_hidden_layers``;
                ``prefix_hidden_size`` is only read when
                ``prefix_projection`` is truthy.
        """
        super().__init__()
        self.prefix_projection = config.prefix_projection
        # Width of the last output dimension, shared by both branches.
        output_size = config.num_hidden_layers * 2 * config.hidden_size
        if self.prefix_projection:
            # Use a two-layer MLP to encode the prefix.
            self.embedding = torch.nn.Embedding(
                config.pre_seq_len, config.hidden_size
            )
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.prefix_hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(config.prefix_hidden_size, output_size),
            )
        else:
            # No projection: the embedding table emits the full width itself.
            self.embedding = torch.nn.Embedding(config.pre_seq_len, output_size)

    def forward(self, prefix: torch.Tensor) -> torch.Tensor:
        """Return the encoded prefix for the given index tensor.

        Args:
            prefix: Tensor of prefix indices; it is moved to the embedding's
                device before lookup.

        Returns:
            The embedding output, passed through ``self.trans`` when
            ``prefix_projection`` is enabled.
        """
        # Look up on whichever device the embedding weights live on, so
        # callers may pass indices from a different device.
        device = self.embedding.weight.device
        embedded = self.embedding(prefix.to(device))
        if self.prefix_projection:
            return self.trans(embedded)
        return embedded