|
import torch |
|
import torch.nn as nn |
|
from transformers import CLIPTextModel, CLIPTokenizer |
|
|
|
|
|
class FrozenCLIPEmbedder(nn.Module):
    """Text encoder that wraps HuggingFace's pretrained CLIP text transformer.

    By default the weights are frozen (eval mode, ``requires_grad=False``) so the
    module acts as a fixed feature extractor for text conditioning.
    """

    def __init__(
        self,
        version="openai/clip-vit-large-patch14",
        device="cuda",
        max_length=77,
        freeze=True,
    ):
        super().__init__()
        # Pretrained tokenizer + text transformer, both resolved from the same hub id.
        self.tokenizer = CLIPTokenizer.from_pretrained(version)
        self.transformer = CLIPTextModel.from_pretrained(version).to(device)
        self.device = device
        # Expose the encoder width so downstream modules can size their projections.
        self.hidden_size = self.transformer.config.hidden_size
        self.max_length = max_length
        if freeze:
            self.freeze()

    def freeze(self):
        # Put the transformer in eval mode and detach every parameter from autograd.
        self.transformer = self.transformer.eval()
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, text):
        """Tokenize ``text`` and run it through the CLIP text transformer.

        Returns the transformer's output dict augmented with ``"attn_bias"``:
        an additive attention bias that is 0 at real tokens and -inf at padding.
        """
        # Fixed-length tokenization: truncate/pad to max_length, return torch tensors.
        encoded = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            return_overflowing_tokens=False,
            padding="max_length",
            return_tensors="pt",
        ).to(self.device)

        outputs = self.transformer(**encoded)

        # attention_mask is 1 for real tokens, 0 for padding; turn it into an
        # additive bias (0 / -inf) in the hidden-state dtype.
        mask = encoded["attention_mask"]
        bias = torch.zeros_like(mask, dtype=outputs["last_hidden_state"].dtype)
        bias = bias.masked_fill(mask == 0, float("-inf"))
        outputs["attn_bias"] = bias
        return outputs

    @torch.no_grad()
    def encode(self, text):
        """Inference-only entry point: same as ``forward`` but without grad tracking."""
        return self(text)
|
|