|
import os |
|
|
|
import torch |
|
from torch import Tensor, nn |
|
from transformers import ( |
|
CLIPTextModel, |
|
CLIPTokenizer, |
|
T5EncoderModel, |
|
T5Tokenizer, |
|
__version__, |
|
) |
|
from transformers.utils.quantization_config import QuantoConfig, BitsAndBytesConfig |
|
|
|
# Root of the local Hugging Face directory: $HF_HOME when set, otherwise
# ~/.cache/huggingface. ``~`` is expanded here because Python's file APIs do
# not expand it; an unexpanded value would name a literal "~" directory
# relative to the current working directory.
CACHE_DIR = os.path.expanduser(os.environ.get("HF_HOME", "~/.cache/huggingface"))
|
|
|
|
|
def auto_quantization_config( |
|
quantization_dtype: str, |
|
) -> QuantoConfig | BitsAndBytesConfig: |
|
if quantization_dtype == "qfloat8": |
|
return QuantoConfig(weights="float8") |
|
elif quantization_dtype == "qint4": |
|
return BitsAndBytesConfig( |
|
load_in_4bit=True, |
|
bnb_4bit_compute_dtype=torch.bfloat16, |
|
bnb_4bit_quant_type="nf4", |
|
) |
|
elif quantization_dtype == "qint8": |
|
return BitsAndBytesConfig(load_in_8bit=True, llm_int8_has_fp16_weight=False) |
|
elif quantization_dtype == "qint2": |
|
return QuantoConfig(weights="int2") |
|
elif quantization_dtype is None or quantization_dtype == "bfloat16": |
|
return None |
|
else: |
|
raise ValueError(f"Unsupported quantization dtype: {quantization_dtype}") |
|
|
|
|
|
class HFEmbedder(nn.Module): |
|
def __init__( |
|
self, |
|
version: str, |
|
max_length: int, |
|
device: torch.device | int, |
|
quantization_dtype: str | None = None, |
|
offloading_device: torch.device | int | None = torch.device("cpu"), |
|
is_clip: bool = False, |
|
**hf_kwargs, |
|
): |
|
super().__init__() |
|
self.offloading_device = ( |
|
offloading_device |
|
if isinstance(offloading_device, torch.device) |
|
else torch.device(offloading_device) |
|
) |
|
self.device = ( |
|
device if isinstance(device, torch.device) else torch.device(device) |
|
) |
|
self.is_clip = version.startswith("openai") or is_clip |
|
self.max_length = max_length |
|
self.output_key = "pooler_output" if self.is_clip else "last_hidden_state" |
|
|
|
auto_quant_config = ( |
|
auto_quantization_config(quantization_dtype) |
|
if quantization_dtype is not None |
|
and quantization_dtype != "bfloat16" |
|
and quantization_dtype != "float16" |
|
else None |
|
) |
|
|
|
|
|
if isinstance(auto_quant_config, BitsAndBytesConfig): |
|
hf_kwargs["device_map"] = {"": self.device.index} |
|
if auto_quant_config is not None: |
|
hf_kwargs["quantization_config"] = auto_quant_config |
|
|
|
if self.is_clip: |
|
self.tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained( |
|
version, max_length=max_length |
|
) |
|
|
|
self.hf_module: CLIPTextModel = CLIPTextModel.from_pretrained( |
|
version, |
|
**hf_kwargs, |
|
) |
|
|
|
else: |
|
self.tokenizer: T5Tokenizer = T5Tokenizer.from_pretrained( |
|
version, max_length=max_length |
|
) |
|
self.hf_module: T5EncoderModel = T5EncoderModel.from_pretrained( |
|
version, |
|
**hf_kwargs, |
|
) |
|
|
|
def offload(self): |
|
self.hf_module.to(device=self.offloading_device) |
|
torch.cuda.empty_cache() |
|
|
|
def cuda(self): |
|
self.hf_module.to(device=self.device) |
|
|
|
def forward(self, text: list[str]) -> Tensor: |
|
batch_encoding = self.tokenizer( |
|
text, |
|
truncation=True, |
|
max_length=self.max_length, |
|
return_length=False, |
|
return_overflowing_tokens=False, |
|
padding="max_length", |
|
return_tensors="pt", |
|
) |
|
outputs = self.hf_module( |
|
input_ids=batch_encoding["input_ids"].to(self.hf_module.device), |
|
attention_mask=None, |
|
output_hidden_states=False, |
|
) |
|
return outputs[self.output_key] |
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test: load the T5 v1.1 XXL encoder with Quanto float8 weights
    # ("qfloat8") and encode a single prompt.
    model = HFEmbedder(
        "city96/t5-v1_1-xxl-encoder-bf16",
        max_length=512,
        device=0,
        quantization_dtype="qfloat8",
    )
    # Inference only: don't record an autograd graph for this forward pass.
    with torch.inference_mode():
        o = model(["hello"])
    print(o)
|
|