import torch
from transformers import PreTrainedModel

from ..utils import torch_gc


class CPUTextEncoderWrapper(PreTrainedModel):
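    """
    Keeps a text encoder on the CPU in float32 (float16 is not supported on CPU)
    while still reporting the dtype and device the rest of the pipeline expects,
    and casts the encoder outputs back to that dtype on the caller's device.
    """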
    def __init__(self, text_encoder, torch_dtype):
        super().__init__(text_encoder.config)
        self.config = text_encoder.config
        # Remember the original device so the `device` property can report it.
        self._device = text_encoder.device
        # CPU does not support float16, so run the encoder on CPU in float32.
        self.text_encoder = text_encoder.to(torch.device("cpu"), non_blocking=True)
        self.text_encoder = self.text_encoder.to(torch.float32, non_blocking=True)
        self.torch_dtype = torch_dtype
        del text_encoder
        torch_gc()
    def __call__(self, x, **kwargs):
        input_device = x.device
        # Encode on CPU, then move every output back to the caller's device
        # and cast it to the pipeline's original dtype.
        original_output = self.text_encoder(x.to(self.text_encoder.device), **kwargs)
        for k, v in original_output.items():
            if isinstance(v, tuple):
                original_output[k] = [
                    v[i].to(input_device).to(self.torch_dtype) for i in range(len(v))
                ]
            else:
                original_output[k] = v.to(input_device).to(self.torch_dtype)
        return original_output
    @property
    def dtype(self):
        return self.torch_dtype
    @property
    def device(self) -> torch.device:
        """
        `torch.device`: The device on which the module is (assuming that all the module parameters are on the same
        device).
        """
        return self._device
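

# Usage sketch (illustrative, not part of this module): a diffusers pipeline loaded
# in float16 can have its text encoder swapped for this wrapper so prompt encoding
# runs on CPU in float32 while the rest of the pipeline stays in half precision.
# The pipeline class and model id below are assumptions chosen for the example only.
#
#   from diffusers import StableDiffusionInpaintPipeline
#
#   pipe = StableDiffusionInpaintPipeline.from_pretrained(
#       "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
#   ).to("cuda")
#   pipe.text_encoder = CPUTextEncoderWrapper(pipe.text_encoder, torch.float16)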