|
import torch |
|
|
|
|
|
def forward(model_name, model, input_ids, past, device='cpu'):
    """Run one forward pass of ``model``, shaping the inputs per model family.

    The family is picked by substring match on ``model_name``, checked in
    this order:

    * ``"gpt2"`` / ``"ctrl"``: without ``past`` the full ``input_ids`` are
      fed; with ``past`` only the last column (``input_ids[:, -1]``) is fed
      together with ``past=past``.
    * ``"xlnet"``: a zero token is appended to every row, and a
      ``perm_mask`` / ``target_mapping`` pair is built whose last column is
      set to 1, so the appended position is the single prediction target.
    * ``"transfo-xl"``: ``past`` is forwarded as ``mems``.
    * anything else: plain ``model(input_ids)``; ``past`` is ignored.

    Args:
        model_name: Identifier used only for the substring dispatch above.
        model: Callable model; its result is returned unchanged.
        input_ids: Token-id tensor indexed as ``[row, position]`` (the code
            reads ``shape[0]`` and ``shape[1]``).
        past: Cached state for gpt2/ctrl (``past``) or transfo-xl (``mems``).
        device: Device on which the xlnet helper tensors are allocated.

    Returns:
        Whatever ``model(...)`` returns.
    """
    is_cached_family = "gpt2" in model_name or "ctrl" in model_name
    if is_cached_family:
        if past is None:
            return model(input_ids)
        # With a cache, only the newest token needs to be fed.
        # NOTE(review): ``[:, -1]`` yields a 1-D tensor of shape (rows,), which
        # drops the position axis; this presumably only behaves for a single
        # row. ``[:, -1:]`` would keep it 2-D but changes the output shape -
        # confirm against callers before switching.
        return model(input_ids[:, -1], past=past)

    if "xlnet" in model_name:
        rows = input_ids.shape[0]
        dummy = torch.zeros((rows, 1), dtype=torch.long, device=device)
        extended_ids = torch.cat((input_ids, dummy), dim=1)
        seq_len = extended_ids.shape[1]

        # Mark the appended slot in the last column of the permutation mask
        # (in XLNet's convention, 1 means "may not attend to this token").
        perm_mask = torch.zeros(
            (rows, seq_len, seq_len), dtype=torch.float, device=device
        )
        perm_mask[:, :, -1] = 1.0

        # A single target per row: the appended slot.
        target_mapping = torch.zeros(
            (rows, 1, seq_len), dtype=torch.float, device=device
        )
        target_mapping[:, 0, -1] = 1.0

        return model(
            extended_ids, perm_mask=perm_mask, target_mapping=target_mapping
        )

    if "transfo-xl" in model_name:
        return model(input_ids, mems=past)

    return model(input_ids)
|
|
|
|
|
def create_context(model_name, tokenizer, initial_text="", padding_text=None, max_tokens=512):
    """Tokenize the generation prompt and look up the special token ids.

    Args:
        model_name: Model identifier; substring matches on ``"gpt2"``,
            ``"xlnet"`` and ``"transfo-xl"`` select model-specific handling.
        tokenizer: Tokenizer providing ``encode``; gpt2 models also need
            ``encoder`` (a token -> id mapping) and xlnet models need
            ``convert_tokens_to_ids``.
        initial_text: Prompt text. ``None`` is treated as ``""``. For gpt2 an
            empty prompt is replaced by ``"<|endoftext|>"``.
        padding_text: Text prepended to the prompt for xlnet and transfo-xl
            (presumably to give them enough leading context - TODO confirm).
            ``None`` means no padding.
        max_tokens: Keep at most this many trailing prompt tokens. It is
            halved (truncating) for transfo-xl.

    Returns:
        ``(context_tokens, eot_token, dot_token)`` where ``context_tokens`` is
        the (possibly truncated) token-id list, ``eot_token`` is the id of
        ``"<|endoftext|>"`` (gpt2), ``"<eop>"`` (xlnet) or ``None`` otherwise,
        and ``dot_token`` is the last id produced by encoding ``"."``.
    """
    if initial_text is None:
        initial_text = ""
    if not initial_text and "gpt2" in model_name:
        initial_text = "<|endoftext|>"
    if "xlnet" in model_name or "transfo-xl" in model_name:
        # The previous ``padding_text + initial_text`` raised TypeError when
        # padding_text kept its default of None.
        initial_text = (padding_text or "") + initial_text

    if "transfo-xl" in model_name:
        max_tokens = int(max_tokens / 2)

    tokens = tokenizer.encode(initial_text)
    # ``tokens[-k:]`` returns the WHOLE list when k == 0, so compute the start
    # index explicitly; a zero budget now yields an empty context.
    context_tokens = tokens[max(len(tokens) - max_tokens, 0):]

    if "gpt2" in model_name:
        eot_token = tokenizer.encoder["<|endoftext|>"]
        if not context_tokens:
            # gpt2 always gets at least one token to condition on.
            context_tokens = [eot_token]
    elif "xlnet" in model_name:
        eot_token = tokenizer.convert_tokens_to_ids('<eop>')
    else:
        eot_token = None

    dot_token = tokenizer.encode(".")[-1]

    return context_tokens, eot_token, dot_token
|
|
|
|