# File size: 10,571 Bytes (revision c32f190)
""" huggingface model adapter
Wraps HuggingFace transformers (https://github.com/huggingface/transformers) models for use as a text tower in CLIP model.
"""
import re
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import TensorType
try:
import transformers
from transformers import AutoModel, AutoModelForMaskedLM, AutoTokenizer, AutoConfig, PretrainedConfig
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, \
BaseModelOutputWithPoolingAndCrossAttentions
except ImportError as e:
transformers = None
class BaseModelOutput:
pass
class PretrainedConfig:
pass
from .hf_configs import arch_dict
# utils
def _camel2snake(s):
return re.sub(r'(?<!^)(?=[A-Z])', '_', s).lower()
# TODO: maybe add a last-token pooler for GPT-like models
# Registry mapping a snake_case pooler name to its class.
_POOLERS = {}


def register_pooler(cls):
    """Class decorator that records `cls` in `_POOLERS`.

    The registry key is the snake_case form of the class name. The class
    itself is returned unchanged so the decorator is transparent.
    """
    registry_key = _camel2snake(cls.__name__)
    _POOLERS[registry_key] = cls
    return cls
@register_pooler
class MeanPooler(nn.Module):
    """Mean pooling over the positions selected by the attention mask."""

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        """Average `x.last_hidden_state` along dim 1, weighted by the mask.

        Hidden states are multiplied by the mask (broadcast over the last
        dimension), summed along dim 1, and divided by the per-row mask sum.
        """
        weights = attention_mask.unsqueeze(-1)
        weighted_sum = (x.last_hidden_state * weights).sum(dim=1)
        mask_total = attention_mask.sum(-1, keepdim=True)
        return weighted_sum / mask_total
@register_pooler
class MaxPooler(nn.Module):
    """Max pooling over the positions selected by the attention mask."""

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        """Return the element-wise maximum of `x.last_hidden_state` along dim 1.

        Positions where `attention_mask` is 0 are replaced by -inf before the
        maximum is taken, so they can never be selected.
        """
        # `masked_fill` needs a boolean mask that is True where values must be
        # overwritten, i.e. at the positions the attention mask excludes.
        excluded = (attention_mask == 0).unsqueeze(-1)
        masked_output = x.last_hidden_state.masked_fill(excluded, -torch.inf)
        return masked_output.max(1).values
@register_pooler
class ClsPooler(nn.Module):
    """CLS token pooling, optionally using the model's own pooled output."""

    def __init__(self, use_pooler_output=True):
        super().__init__()
        # When True, prefer `pooler_output` if the model output carries one.
        self.use_pooler_output = use_pooler_output
        # Index (along dim 1 of `last_hidden_state`) of the token to select.
        self.cls_token_position = 0

    def forward(self, x: BaseModelOutput, attention_mask: TensorType):
        """Return `x.pooler_output` when allowed and available, else the CLS state.

        `attention_mask` is accepted for interface parity with the other
        poolers and is not used here.
        """
        if self.use_pooler_output:
            # Only these output types are checked for a `pooler_output`.
            pooled_output_types = (
                BaseModelOutputWithPooling,
                BaseModelOutputWithPoolingAndCrossAttentions,
            )
            if isinstance(x, pooled_output_types) and x.pooler_output is not None:
                return x.pooler_output

        cls_index = self.cls_token_position
        return x.last_hidden_state[:, cls_index, :]
class HFTextEncoder(nn.Module):
    """HuggingFace model adapter.

    Wraps a HuggingFace transformer together with a pooler, an optional
    projection and a tokenizer.

    Args:
        model_name_or_path: Identifier handed to `AutoConfig.from_pretrained`
            and, when `pretrained` is true, to the model's `from_pretrained`.
        output_dim: Size of the projected output.
        tokenizer_name: Identifier handed to `AutoTokenizer.from_pretrained`.
            Falls back to `model_name_or_path` when None.
        config: Optional ready-made config. When given, the model is built
            with `from_config(config)` and `model_name_or_path` / `pretrained`
            are not used for the model.
        pooler_type: Key into `_POOLERS`. When None, the default pooler listed
            in `arch_dict` for the config's `model_type` is used.
        proj: None, 'linear' or 'mlp'.
        pretrained: Use `from_pretrained` (True) or `from_config` (False) when
            `config` is None.
        masked_language_modeling: Build an `AutoModelForMaskedLM` instead of
            an `AutoModel`.

    Raises:
        RuntimeError: If the `transformers` package is not installed.
    """

    def __init__(
            self,
            model_name_or_path: str,
            output_dim: int,
            tokenizer_name: str = None,
            config: PretrainedConfig = None,
            pooler_type: str = None,
            proj: str = None,
            pretrained: bool = True,
            masked_language_modeling: bool = False):
        super().__init__()

        self.output_dim = output_dim

        # TODO: find better way to get this information
        uses_transformer_pooler = (pooler_type == "cls_pooler")

        if transformers is None:
            raise RuntimeError("Please `pip install transformers` to use pre-trained HuggingFace models")

        if config is None:
            self.config = AutoConfig.from_pretrained(model_name_or_path)
            auto_cls = AutoModelForMaskedLM if masked_language_modeling else AutoModel
            if pretrained:
                create_func, model_args = auto_cls.from_pretrained, model_name_or_path
            else:
                create_func, model_args = auto_cls.from_config, self.config
            # TODO: do all model configs have this attribute? PretrainedConfig does so yes??
            if hasattr(self.config, "is_encoder_decoder") and self.config.is_encoder_decoder:
                # Only the encoder half of an encoder-decoder model is kept.
                self.transformer = create_func(model_args)
                self.transformer = self.transformer.encoder
            else:
                # NOTE(review): `add_pooling_layer` is forwarded to the model
                # constructor; masked-LM heads may not accept it — confirm.
                self.transformer = create_func(model_args, add_pooling_layer=uses_transformer_pooler)
        else:
            self.config = config
            if masked_language_modeling:
                self.transformer = AutoModelForMaskedLM.from_config(config)
            else:
                self.transformer = AutoModel.from_config(config)

        if pooler_type is None:  # get default arch pooler
            self.pooler = _POOLERS[(arch_dict[self.config.model_type]["pooler"])]()
        else:
            self.pooler = _POOLERS[pooler_type]()

        d_model = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["width"])
        if (d_model == output_dim) and (proj is None):  # do we always need a proj?
            self.proj = nn.Identity()
        elif proj == 'linear':
            self.proj = nn.Linear(d_model, output_dim, bias=False)
        elif proj == 'mlp':
            hidden_size = (d_model + output_dim) // 2
            self.proj = nn.Sequential(
                nn.Linear(d_model, hidden_size, bias=False),
                nn.GELU(),
                nn.Linear(hidden_size, output_dim, bias=False),
            )
        # NOTE(review): for any other combination (e.g. proj is None while
        # d_model != output_dim) no `self.proj` is assigned, so `forward`
        # cannot be used — confirm whether that is intended.

        # `from_pretrained(None)` cannot resolve a tokenizer, so default to
        # the model identifier when no tokenizer name is supplied.
        if tokenizer_name is None:
            tokenizer_name = model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

    def mask(self, input_ids, vocab_size, device, targets=None, masked_indices=None, probability_matrix=None):
        """Apply masked-language-modeling corruption to `input_ids` in place.

        Args:
            input_ids: Token ids; modified in place.
            vocab_size: Exclusive upper bound for the random replacement ids.
            device: Device the random replacement ids are moved to.
            targets: Optional labels; modified in place so that every position
                that is not masked becomes -100.
            masked_indices: Optional boolean selection of positions to mask;
                modified in place (pad and CLS positions are cleared).
            probability_matrix: Per-position masking probabilities, sampled
                with `torch.bernoulli` when `masked_indices` is None.

        Returns:
            `(input_ids, targets)` when `targets` is given, else `input_ids`.
        """
        if masked_indices is None:
            masked_indices = torch.bernoulli(probability_matrix).bool()

        # Never mask padding or the CLS token.
        masked_indices[input_ids == self.tokenizer.pad_token_id] = False
        masked_indices[input_ids == self.tokenizer.cls_token_id] = False

        if targets is not None:
            targets[~masked_indices] = -100  # We only compute loss on masked tokens

        # 80% of the time, we replace masked input tokens with tokenizer.mask_token ([MASK])
        indices_replaced = torch.bernoulli(torch.full(input_ids.shape, 0.8)).bool() & masked_indices
        input_ids[indices_replaced] = self.tokenizer.mask_token_id

        # 10% of the time, we replace masked input tokens with random word
        # (half of the 20% that were not replaced by the mask token).
        indices_random = torch.bernoulli(torch.full(input_ids.shape, 0.5)).bool() & masked_indices & ~indices_replaced
        random_words = torch.randint(vocab_size, input_ids.shape, dtype=torch.long).to(device)
        input_ids[indices_random] = random_words[indices_random]
        # The rest of the time (10% of the time) we keep the masked input tokens unchanged

        if targets is not None:
            return input_ids, targets
        return input_ids

    def forward_mlm(self, input_ids, image_embeds, mlm_probability=0.25):
        """Compute the masked-language-modeling loss conditioned on `image_embeds`.

        Args:
            input_ids: Token ids. Left untouched; masking is applied to a copy.
            image_embeds: Passed to the transformer as `encoder_hidden_states`.
            mlm_probability: Probability used to fill the masking matrix.

        Returns:
            The `loss` attribute of the transformer output.
        """
        labels = input_ids.clone()
        # `mask` overwrites tokens in place; work on a copy so the caller's
        # batch is not corrupted for later use.
        input_ids = input_ids.clone()
        attn_mask = (input_ids != self.config.pad_token_id).long()
        # All-ones attention over every leading position of `image_embeds`.
        image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(input_ids.device)
        vocab_size = getattr(self.config, arch_dict[self.config.model_type]["config_names"]["vocab_size"])
        probability_matrix = torch.full(labels.shape, mlm_probability)
        input_ids, labels = self.mask(
            input_ids, vocab_size, input_ids.device, targets=labels, probability_matrix=probability_matrix)
        mlm_output = self.transformer(
            input_ids,
            attention_mask=attn_mask,
            encoder_hidden_states=image_embeds,
            encoder_attention_mask=image_atts,
            return_dict=True,
            labels=labels,
        )
        return mlm_output.loss

    def forward(self, x: TensorType) -> TensorType:
        """Encode token ids `x`: transformer, then pooler, then projection.

        The attention mask is derived from `x` by comparing against the
        config's `pad_token_id`.
        """
        attn_mask = (x != self.config.pad_token_id).long()
        out = self.transformer(input_ids=x, attention_mask=attn_mask)
        pooled_out = self.pooler(out, attn_mask)
        return self.proj(pooled_out)

    def lock(self, unlocked_layers: int = 0, freeze_layer_norm: bool = True):
        """Freeze transformer parameters, optionally leaving the last layers trainable.

        Args:
            unlocked_layers: Number of trailing entries of
                `[embeddings, *layers]` to leave untouched. 0 freezes the
                whole transformer.
            freeze_layer_norm: When False, parameters whose dotted name has a
                "LayerNorm" component get `requires_grad=True` instead of
                being frozen.
        """
        if not unlocked_layers:  # full freezing
            for n, p in self.transformer.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False
            return

        encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
        layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
        print(f"Unlocking {unlocked_layers}/{len(layer_list) + 1} layers of hf model")
        embeddings = getattr(
            self.transformer, arch_dict[self.config.model_type]["config_names"]["token_embeddings_attr"])
        # Everything except the last `unlocked_layers` entries gets frozen.
        modules = [embeddings, *layer_list][:-unlocked_layers]
        # freeze layers
        for module in modules:
            for n, p in module.named_parameters():
                p.requires_grad = (not freeze_layer_norm) if "LayerNorm" in n.split(".") else False

    @torch.jit.ignore
    def set_grad_checkpointing(self, enable=True):
        """Turn gradient checkpointing on the wrapped transformer on or off."""
        if enable:
            self.transformer.gradient_checkpointing_enable()
        else:
            self.transformer.gradient_checkpointing_disable()

    def get_num_layers(self):
        """Return the number of entries in the transformer's layer list."""
        encoder = self.transformer.encoder if hasattr(self.transformer, 'encoder') else self.transformer
        layer_list = getattr(encoder, arch_dict[self.config.model_type]["config_names"]["layer_attr"])
        return len(layer_list)

    def init_parameters(self):
        """No-op placeholder; no parameters are initialized here."""
        pass
|