# Source: DCWIR-Demo/textattack/models/wrappers/pytorch_model_wrapper.py
# Commit 63775f2 ("add necessary file"), 3.73 kB
"""
PyTorch Model Wrapper
--------------------------
"""
import torch
from torch.nn import CrossEntropyLoss
import textattack
from .model_wrapper import ModelWrapper
# Module-level side effect: runs once when this module is imported and asks
# PyTorch's CUDA caching allocator to release cached memory it is not using.
torch.cuda.empty_cache()
class PyTorchModelWrapper(ModelWrapper):
    """Loads a PyTorch model (`nn.Module`) and tokenizer.

    Args:
        model (torch.nn.Module): PyTorch model
        tokenizer: tokenizer whose output can be packed as a tensor and passed to the model.
            No type requirement, but most have `tokenizer` method that accepts list of strings.
    """

    def __init__(self, model, tokenizer):
        """Store the model and tokenizer.

        Raises:
            TypeError: if `model` is not a `torch.nn.Module`.
        """
        if not isinstance(model, torch.nn.Module):
            raise TypeError(
                f"PyTorch model must be torch.nn.Module, got type {type(model)}"
            )

        self.model = model
        self.tokenizer = tokenizer

    def to(self, device):
        """Move the wrapped model to `device`."""
        self.model.to(device)

    def __call__(self, text_input_list, batch_size=32):
        """Run the model on a list of input strings without tracking gradients.

        Args:
            text_input_list (list[str]): inputs handed to the tokenizer as-is.
            batch_size (int): forwarded to `batch_model_predict`.
        Returns:
            Whatever `textattack.shared.utils.batch_model_predict` returns
            for the tokenized inputs.
        """
        # Place the ids on whichever device the model's parameters live on.
        model_device = next(self.model.parameters()).device
        ids = self.tokenizer(text_input_list)
        ids = torch.tensor(ids).to(model_device)

        with torch.no_grad():
            outputs = textattack.shared.utils.batch_model_predict(
                self.model, ids, batch_size=batch_size
            )

        return outputs

    def get_grad(self, text_input, loss_fn=CrossEntropyLoss()):
        """Get gradient of loss with respect to input tokens.

        The loss is computed against the model's own argmax prediction for
        `text_input`, and the gradient is captured at the output of the
        input embedding layer via a backward hook.

        Args:
            text_input (str): input string
            loss_fn (torch.nn.Module): loss function. Default is `torch.nn.CrossEntropyLoss`
        Returns:
            Dict with keys `"ids"` (list of token ids for the input) and
            `"gradient"` (numpy array of the gradient w.r.t. the word
            embeddings of the input).
        Raises:
            AttributeError: if the model has no `get_input_embeddings` method.
            ValueError: if `loss_fn` is not a `torch.nn.Module`.
        """
        if not hasattr(self.model, "get_input_embeddings"):
            raise AttributeError(
                f"{type(self.model)} must have method `get_input_embeddings` that returns `torch.nn.Embedding` object that represents input embedding layer"
            )
        if not isinstance(loss_fn, torch.nn.Module):
            raise ValueError("Loss function must be of type `torch.nn.Module`.")

        self.model.train()

        embedding_layer = self.model.get_input_embeddings()
        original_state = embedding_layer.weight.requires_grad
        embedding_layer.weight.requires_grad = True

        emb_grads = []

        def grad_hook(module, grad_in, grad_out):
            emb_grads.append(grad_out[0])

        emb_hook = embedding_layer.register_backward_hook(grad_hook)
        try:
            self.model.zero_grad()
            model_device = next(self.model.parameters()).device
            ids = self.tokenizer([text_input])
            ids = torch.tensor(ids).to(model_device)

            predictions = self.model(ids)

            # Use the model's own prediction as the target label.
            predicted_labels = predictions.argmax(dim=1)
            loss = loss_fn(predictions, predicted_labels)
            loss.backward()
        finally:
            # Always undo the temporary changes, even if tokenization, the
            # forward pass or the backward pass raised. Otherwise the hook
            # stays registered and the model is left in train mode.
            embedding_layer.weight.requires_grad = original_state
            emb_hook.remove()
            self.model.eval()

        # grad w.r.t to word embeddings
        # Fix for Issue #601
        # Check if gradient has shape [max_sequence,1,_] ( when model input in transpose of input sequence)
        if emb_grads[0].shape[1] == 1:
            grad = torch.transpose(emb_grads[0], 0, 1)[0].cpu().numpy()
        else:
            # gradient has shape [1,max_sequence,_]
            grad = emb_grads[0][0].cpu().numpy()

        return {"ids": ids[0].tolist(), "gradient": grad}

    def _tokenize(self, inputs):
        """Helper method for `tokenize`.

        Args:
            inputs (list[str]): list of input strings
        Returns:
            tokens (list[list[str]]): List of list of tokens as strings
        """
        return [self.tokenizer.convert_ids_to_tokens(self.tokenizer(x)) for x in inputs]