Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Chameleon License found in the
# LICENSE file in the root directory of this source tree.
from abc import ABC, abstractmethod | |
import torch | |
class PromptAlignment(ABC):
    """Abstract interface for batching token-id prompts of differing lengths.

    A concrete strategy decides how a ragged list of prompts is turned into
    a single tensor and which index is derived from the prompt lengths.

    All three methods are declared abstract, so instantiating this base
    class (or a subclass that forgets to override one of them) raises
    ``TypeError`` instead of yielding an object whose methods silently
    return ``None``.
    """

    @abstractmethod
    def start_index(self, input_ids: list[list[int]]) -> int:
        """Return an index derived from the lengths of the prompts.

        NOTE(review): presumably the sequence position at which decoding
        begins -- confirm against the caller.

        Args:
            input_ids: One list of token ids per prompt in the batch.
        """
        ...

    @abstractmethod
    def prepare_inputs(self, input_ids: list[list[int]]) -> torch.Tensor:
        """Convert the ragged prompt lists into a single tensor.

        Args:
            input_ids: One list of token ids per prompt in the batch.
        """
        ...

    @abstractmethod
    def postprocess_inputs(
        self, inputs: torch.Tensor, original_inputs: torch.Tensor
    ) -> torch.Tensor:
        """Return ``inputs``, optionally adjusted using ``original_inputs``.

        Args:
            inputs: The current token tensor.
            original_inputs: A reference tensor; presumably the output of
                ``prepare_inputs`` -- confirm against the caller.
        """
        ...
class AlignPromptRight(PromptAlignment):
    """Right-align prompts by padding the shorter ones on the left.

    After padding, every prompt ends at the same column, so the index
    returned by ``start_index`` is the length of the longest prompt.
    """

    def __init__(self, pad_id: int):
        """Store the token id used to fill the left-hand padding.

        Args:
            pad_id: Token id written into padded positions.
        """
        self.pad_id = pad_id

    def start_index(self, input_ids: list[list[int]]) -> int:
        """Return the length of the longest prompt in the batch.

        Args:
            input_ids: One list of token ids per prompt.

        Raises:
            ValueError: If ``input_ids`` is empty (``max`` of no items).
        """
        return max(len(sublist) for sublist in input_ids)

    def prepare_inputs(self, input_ids: list[list[int]]) -> torch.LongTensor:
        """Left-pad every prompt to the longest length and stack them.

        Args:
            input_ids: One list of token ids per prompt.

        Returns:
            An integer tensor with one row per prompt and as many columns
            as the longest prompt; shorter prompts are prefixed with
            ``pad_id``.

        Raises:
            ValueError: If ``input_ids`` is empty (``max`` of no items).
        """
        max_length = max(len(sublist) for sublist in input_ids)
        padded_rows = [
            [self.pad_id] * (max_length - len(sublist)) + sublist
            for sublist in input_ids
        ]
        # Pin the dtype: when every prompt is empty, torch.tensor would
        # otherwise infer the default floating-point dtype, contradicting
        # the LongTensor return annotation.
        return torch.tensor(padded_rows, dtype=torch.long, requires_grad=False)

    def postprocess_inputs(
        self,
        inputs: torch.Tensor,
        original_inputs: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``inputs`` unchanged; ``original_inputs`` is not used."""
        return inputs
class AlignPromptLeft(PromptAlignment):
    """Left-align prompts by padding the shorter ones on the right.

    Every prompt starts at column 0, so the index returned by
    ``start_index`` is the length of the shortest prompt.
    """

    def __init__(self, pad_id: int = -1):
        """Store the token id used to fill the right-hand padding.

        Args:
            pad_id: Token id written into padded positions. Defaults to -1.
        """
        self.pad_id = pad_id

    def start_index(self, input_ids: list[list[int]]) -> int:
        """Return the length of the shortest prompt in the batch.

        Args:
            input_ids: One list of token ids per prompt.

        Raises:
            ValueError: If ``input_ids`` is empty (``min`` of no items).
        """
        return min(len(sublist) for sublist in input_ids)

    def prepare_inputs(self, input_ids: list[list[int]]) -> torch.Tensor:
        """Right-pad every prompt to the longest length and stack them.

        Args:
            input_ids: One list of token ids per prompt.

        Returns:
            An integer tensor with one row per prompt and as many columns
            as the longest prompt; shorter prompts are suffixed with
            ``pad_id``.

        Raises:
            ValueError: If ``input_ids`` is empty (``max`` of no items).
        """
        max_length = max(len(sublist) for sublist in input_ids)
        padded_rows = [
            sublist + [self.pad_id] * (max_length - len(sublist))
            for sublist in input_ids
        ]
        # Pin the dtype: when every prompt is empty, torch.tensor would
        # otherwise infer the default floating-point dtype for token ids.
        return torch.tensor(padded_rows, dtype=torch.long, requires_grad=False)

    def postprocess_inputs(
        self,
        inputs: torch.Tensor,
        original_inputs: torch.Tensor,
    ) -> torch.Tensor:
        """Copy non-pad entries of ``original_inputs`` over ``inputs``.

        While ``inputs`` has no more columns than ``original_inputs``,
        every position where ``original_inputs`` differs from ``pad_id``
        is written into ``inputs`` in place. Once ``inputs`` is wider than
        ``original_inputs``, it is returned untouched.

        NOTE(review): presumably this restores prompt tokens at positions
        that were filled by generation for the shorter prompts -- confirm
        against the caller.

        Args:
            inputs: Current token tensor; modified in place.
            original_inputs: Reference tensor containing ``pad_id`` at the
                positions that must not be overwritten.

        Returns:
            The same ``inputs`` tensor object.
        """
        max_init_len = original_inputs.shape[1]
        if inputs.shape[1] <= max_init_len:
            # Compare only the columns that exist in ``inputs`` so far.
            original_inputs_limited = original_inputs[:, : inputs.shape[1]]
            mask = original_inputs_limited != self.pad_id
            inputs[mask] = original_inputs_limited[mask]
        return inputs