Spaces:
Runtime error
Runtime error
import torch.nn as nn | |
from losses import masked_log_probs | |
from utils import _logits, shift_targets | |
class EditableModel(nn.Module):
    """Base wrapper that gives a model a common editing interface.

    Subclasses are expected to implement :meth:`edit`. This base class only
    forwards calls to the wrapped model and exposes default loss functions
    and an optimizer parameter helper.
    """

    def __init__(self, model, config, model_constructor):
        """Store the wrapped model and install the default loss functions.

        Args:
            model: The wrapped model. It is called in ``forward`` and its
                ``generate`` method backs ``generate``.
            config: Configuration object. ``config.lr`` is read by
                ``outer_parameters``, and the whole object is passed to
                ``shift_targets`` when the edit loss is computed.
            model_constructor: Stored as ``self.model_constructor``; not used
                by this base class (presumably consumed by subclasses --
                TODO confirm).
        """
        super().__init__()
        self.model = model
        self.config = config
        self.model_constructor = model_constructor

        def _edit_loss_fn(pred, targ, **kwargs):
            # shift_targets(self.config) is evaluated at call time, so a
            # reassigned or mutated self.config is reflected on each call.
            return masked_log_probs(
                pred, targ, shift=shift_targets(self.config), **kwargs
            )

        # The same function is installed for both the edit loss and the
        # locality loss; either attribute can be replaced independently.
        self.edit_loss_fn = _edit_loss_fn
        self.loc_loss_fn = _edit_loss_fn

    def edit(self, batch, condition=None, detach_history=False):
        """Apply an edit to the wrapped model; must be overridden.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement edit()"
        )

    def forward(self, *inputs, **kwargs):
        """Call the wrapped model and return ``_logits`` of its output.

        ``_logits`` is imported from ``utils``; presumably it extracts the
        logits from the model's output object -- TODO confirm.
        """
        return _logits(self.model(*inputs, **kwargs))

    def outer_parameters(self, grouped=False):
        """Return this module's parameters for an outer-loop optimizer.

        Args:
            grouped: If true, return a one-element list holding a parameter
                group dict with keys ``params`` and ``lr`` (taken from
                ``config.lr``). Otherwise return a flat list of parameters.

        Returns:
            list: ``[{"params": [...], "lr": config.lr}]`` when ``grouped``,
            otherwise the list of parameters.
        """
        params = list(self.parameters())
        if grouped:
            # Store a list rather than the raw generator from parameters():
            # a generator is exhausted after one pass, which would leave the
            # group empty for any second consumer.
            return [dict(params=params, lr=self.config.lr)]
        return params

    def generate(self, *args, **kwargs):
        """Delegate to the wrapped model's ``generate`` method unchanged."""
        return self.model.generate(*args, **kwargs)

    def base_loss(self, input_ids, attention_masks, label_ids):
        """Placeholder for a base-model loss; returns ``None`` here.

        NOTE(review): unlike ``edit`` this does not raise, so calling it on
        the base class silently yields ``None`` -- subclasses presumably
        override it.
        """
        return None