from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from torch import nn
from transformers.utils import ModelOutput
from surya.model.table_rec.config import SuryaTableRecDecoderConfig, SuryaTableRecTextEncoderConfig
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import BaseModelOutputWithNoAttention, CausalLMOutput
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from surya.settings import settings
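# Gradient-clamp constant that appears to be carried over from the RecurrentGemma-style
# code this decoder borrows from; it is not referenced anywhere else in this module.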
_MAX_SQRT_GRADIENT = 1000.0
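# Output container returned by SuryaTableRecDecoder.forward: bbox_logits has shape
# (batch, seq_len, 4, max_width) and class_logits has shape (batch, seq_len, max_classes).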
@dataclass
class TableRecModelOutput(ModelOutput):
bbox_logits: torch.Tensor
class_logits: torch.Tensor | None = None
hidden_states: torch.Tensor | None = None
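# Gemma-style RMSNorm: the weight is initialised to zeros and applied as (1 + weight),
# so a freshly initialised layer reduces to plain RMS normalisation.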
class SuryaTableRecDecoderRMSNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.eps = eps
self.weight = nn.Parameter(torch.zeros(dim))
def _norm(self, x):
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
def forward(self, x):
output = self._norm(x.float())
# Llama does x.to(float16) * w whilst SuryaTableRecDecoder is (x * w).to(float16)
# See https://github.com/huggingface/transformers/pull/29402
output = output * (1.0 + self.weight.float())
return output.type_as(x)
def extra_repr(self):
return f"{tuple(self.weight.shape)}, eps={self.eps}"
ALL_LAYERNORM_LAYERS.append(SuryaTableRecDecoderRMSNorm)
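# Rotary position embedding (RoPE): precomputes inverse frequencies once and, in forward,
# builds per-position cos/sin tensors that apply_rotary_pos_emb uses to rotate query/key vectors.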
class SuryaTableRecDecoderRotaryEmbedding(nn.Module):
def __init__(self, dim, base=10000, device=None):
super().__init__()
self.dim = dim
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
@torch.no_grad()
# Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->SuryaTableRecDecoder
def forward(self, x, position_ids, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
self.inv_freq.to(x.device)  # no-op (result not assigned); kept as in the upstream Gemma code this is copied from
inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
position_ids_expanded = position_ids[:, None, :].float()
freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
emb = torch.cat((freqs, freqs), dim=-1)
cos = emb.cos()
sin = emb.sin()
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
"""Applies Rotary Position Embedding to the query and key tensors.
Args:
q (`torch.Tensor`): The query tensor.
k (`torch.Tensor`): The key tensor.
cos (`torch.Tensor`): The cosine part of the rotary embedding.
sin (`torch.Tensor`): The sine part of the rotary embedding.
unsqueeze_dim (`int`, *optional*, defaults to 1):
The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
Returns:
`tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
"""
cos = cos.unsqueeze(unsqueeze_dim)
sin = sin.unsqueeze(unsqueeze_dim)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
"""
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
if n_rep == 1:
return hidden_states
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
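# Cross-attention over encoder outputs with grouped-query attention: keys/values are projected
# from encoder_hidden_states when no cache entry exists and, with use_cache, stored on the module
# for reuse on later decode steps; queries always come from the decoder stream.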
class SuryaTableRecDecoderSdpaCrossAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper
Modified for GQA
"""
def __init__(self, config: SuryaTableRecDecoderConfig):
super().__init__()
self.config = config
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.head_dim = config.head_dim
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True)
self.rotary_emb = SuryaTableRecDecoderRotaryEmbedding(
self.head_dim,
base=config.rope_theta,
)
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# Encoder attention mask currently ignored
bsz, q_len, _ = hidden_states.size()
_, v_len, _ = encoder_hidden_states.size()
query_states = self.q_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
if self.key_states is None:
key_states = self.k_proj(encoder_hidden_states)
value_states = self.v_proj(encoder_hidden_states)
key_states = key_states.view(bsz, v_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, v_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
if use_cache:
self._update_cache(key_states, value_states)
else:
key_states = self.key_states
value_states = self.value_states
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states.contiguous(),
key_states.contiguous(),
value_states.contiguous(),
attn_mask=None,
dropout_p=self.attention_dropout if self.training else 0.0,
scale=self.head_dim**-0.5,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output
def _setup_cache(self, batch_size, device, dtype=None):
# Setup initial caches
self.value_states = None
self.key_states = None
@torch.no_grad()
def _update_cache(self, key_states, value_states, **cache_kwargs):
self.value_states = value_states
self.key_states = key_states
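# Causal self-attention with RoPE and grouped-query attention. The KV cache lives directly on the
# module: either a preallocated static cache (settings.RECOGNITION_STATIC_CACHE) written at
# cache_position, or a dynamic cache that concatenates new keys/values along the sequence dimension.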
class SuryaTableRecDecoderSdpaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: SuryaTableRecDecoderConfig):
super().__init__()
self.config = config
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_attention_heads = config.num_attention_heads
self.head_dim = config.head_dim
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads
self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True)
self.rotary_emb = SuryaTableRecDecoderRotaryEmbedding(
self.head_dim,
base=config.rope_theta,
)
def forward(
self,
hidden_states: torch.Tensor,
position_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
cache_position: Optional[torch.LongTensor] = None,
use_cache: bool = False,
window_attn: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
# Final is bsz, num_attention_heads, seq_len, head_dim
query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
if use_cache and hasattr(self, "key_states"):
cache_kwargs = {"cache_position": cache_position, "window_attn": window_attn}
key_states, value_states = self._update_cache(key_states, value_states, **cache_kwargs)
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
causal_mask = attention_mask
if attention_mask is not None:
# Mask is batch, head, seq_len, kv_len
causal_mask = causal_mask[:, :, :, :key_states.shape[-2]]
current_cache_position = cache_position[-1].item() if cache_position is not None else None
if current_cache_position and settings.RECOGNITION_STATIC_CACHE:
# Mask out future cache positions
position_mask = torch.ones_like(causal_mask, dtype=torch.bool, device=causal_mask.device)
position_mask[:, :, :, :current_cache_position + 1] = False
causal_mask = torch.where(position_mask, torch.finfo(causal_mask.dtype).min, causal_mask)
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states.contiguous(),
key_states.contiguous(),
value_states.contiguous(),
attn_mask=causal_mask,
dropout_p=self.attention_dropout if self.training else 0.0,
scale=self.head_dim**-0.5,
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output
def _setup_cache(self, batch_size, device, dtype=None):
if dtype is None and self.config.torch_dtype is not None:
dtype = self.config.torch_dtype
dtype = dtype if dtype is not None else torch.float32
# Setup initial caches
self.value_states = None
self.key_states = None
if settings.RECOGNITION_STATIC_CACHE:
cache_shape = (batch_size, self.num_key_value_heads, settings.RECOGNITION_MAX_TOKENS, self.head_dim)
self.value_states = torch.zeros(cache_shape, dtype=dtype, device=device)
self.key_states = torch.zeros(cache_shape, dtype=dtype, device=device)
def _update_static_cache(self, key_states, value_states, **cache_kwargs):
cache_position = cache_kwargs.get("cache_position")
k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device)
k_out[:, :, cache_position] = key_states.to(k_out.dtype)
v_out[:, :, cache_position] = value_states.to(v_out.dtype)
self.key_states, self.value_states = k_out, v_out
return k_out, v_out
def _update_dynamic_cache(self, key_states, value_states, **cache_kwargs):
k_out = key_states
if self.key_states is not None:
k_out = torch.cat([self.key_states, key_states], dim=2)
v_out = value_states
if self.value_states is not None:
v_out = torch.cat([self.value_states, value_states], dim=2)
self.key_states, self.value_states = k_out, v_out
return k_out, v_out
@torch.no_grad()
def _update_cache(self, key_states, value_states, **cache_kwargs):
if settings.RECOGNITION_STATIC_CACHE:
return self._update_static_cache(key_states, value_states, **cache_kwargs)
return self._update_dynamic_cache(key_states, value_states, **cache_kwargs)
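# Gated MLP: down_proj(act(gate_proj(x)) * up_proj(x)), with the tanh-approximated GELU as the
# default activation when config.hidden_activation is unset.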
class SuryaTableRecDecoderMlp(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
if config.hidden_activation is None:
config.hidden_activation = "gelu_pytorch_tanh"
hidden_activation = config.hidden_activation
self.act_fn = ACT2FN[hidden_activation]
def forward(self, x):
return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
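# Decoder layer: an optional cross-attention block (layer indices in config.cross_attn_layers),
# an optional self-attention block (config.self_attn_layers), and the gated MLP, each preceded by
# its own RMSNorm and combined with the residual stream as laid out in forward below.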
class SuryaTableRecDecoderLayer(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.cross_pre_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.temporal_pre_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.temporal_block = None
if layer_idx in config.self_attn_layers:
self.temporal_block = SuryaTableRecDecoderSdpaAttention(config)
self.cross_attn_block = None
if layer_idx in config.cross_attn_layers:
self.cross_attn_block = SuryaTableRecDecoderSdpaCrossAttention(config)
self.window_attn = layer_idx not in config.global_attn_layers
self.channel_pre_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.mlp_block = SuryaTableRecDecoderMlp(config)
def forward(
self,
activations: torch.Tensor,
position_ids: torch.Tensor,
attention_mask: torch.Tensor,
encoder_hidden_states: torch.Tensor = None,
encoder_attention_mask: torch.Tensor = None,
cache_position: torch.Tensor = None,
use_cache: bool = None,
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
raw_activations = activations
if self.cross_attn_block is not None:
# Do cross-attention on encoder outputs
cross_attn_inputs = self.cross_pre_norm(activations)
cross_attn_path = self.cross_attn_block(
cross_attn_inputs, encoder_hidden_states, attention_mask, encoder_attention_mask, use_cache=use_cache
)
cross_attn_output = cross_attn_path + raw_activations
else:
cross_attn_output = raw_activations
if self.temporal_block is not None:
inputs_normalized = self.temporal_pre_norm(cross_attn_output) # RMSNorm introduces slight differences
hidden_states = self.temporal_block(
inputs_normalized, position_ids, attention_mask, cache_position=cache_position, use_cache=use_cache, window_attn=self.window_attn
)
residual = hidden_states + raw_activations
else:
residual = cross_attn_output
hidden_states = self.channel_pre_norm(residual)
hidden_states = self.mlp_block(hidden_states)
hidden_states = hidden_states + residual
return hidden_states
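# Shared PreTrainedModel base: initialises linear/embedding weights from a normal distribution with
# std config.init_std and walks the decoder layers to (re)allocate the per-module attention caches.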
class SuryaTableRecDecoderPreTrainedModel(PreTrainedModel):
config_class = SuryaTableRecDecoderConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_no_split_modules = ["SuryaTableRecDecoderLayer"]
_skip_keys_device_placement = ["cache"]
_supports_flash_attn_2 = False
_supports_sdpa = False # we can't compare with eager for now
_supports_cache_class = True
_supports_quantized_cache = True
def _init_weights(self, module):
if isinstance(module, SuryaTableRecDecoderSdpaAttention):
torch.nn.init.normal_(module.q_proj.weight, mean=0.0, std=self.config.init_std)
torch.nn.init.normal_(module.k_proj.weight, mean=0.0, std=self.config.init_std)
torch.nn.init.normal_(module.v_proj.weight, mean=0.0, std=self.config.init_std)
torch.nn.init.normal_(module.o_proj.weight, mean=0.0, std=self.config.init_std)
elif isinstance(module, nn.Linear):
torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)
if getattr(module, "bias", None) is not None:
torch.nn.init.zeros_(module.bias)
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=self.config.init_std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def _setup_cache(self, config, batch, device, dtype):
layers = getattr(self, "model", self).layers
for layer in layers:
if layer.temporal_block:
layer.temporal_block._setup_cache(batch, device, dtype)
if layer.cross_attn_block:
layer.cross_attn_block._setup_cache(batch, device, dtype)
def reset_cache(self, batch, device, dtype):
pass
def _tie_weights(self):
pass
def tie_weights(self):
pass
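# Embeds decoder labels given as (cx, cy, w, h, class): corner coordinates are derived from the
# centre and size, every component is clamped to its embedding-table range, and the coordinate
# embeddings are summed together with the class embedding.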
class LabelEmbedding(nn.Module):
def __init__(self, config):
super().__init__()
self.vocab_size = config.vocab_size
self.x1_embed = nn.Embedding(config.max_width, config.hidden_size)
self.y1_embed = nn.Embedding(config.max_height, config.hidden_size)
self.x2_embed = nn.Embedding(config.max_width, config.hidden_size)
self.y2_embed = nn.Embedding(config.max_height, config.hidden_size)
self.w_embed = nn.Embedding(config.max_width, config.hidden_size)
self.h_embed = nn.Embedding(config.max_height, config.hidden_size)
self.cx_embed = nn.Embedding(config.max_width, config.hidden_size)
self.cy_embed = nn.Embedding(config.max_height, config.hidden_size)
self.class_embed = nn.Embedding(config.max_classes, config.hidden_size)
self.max_width = config.max_width
self.max_height = config.max_height
self.max_classes = config.max_classes
def forward(self, labels: torch.LongTensor, input_box_counts: torch.LongTensor):
cx, cy, w, h, class_ = labels.to(torch.long).unbind(dim=-1)
# Shape is (batch_size, num_boxes/seq len, d_model)
x1 = (cx - w // 2).long()
y1 = (cy - h // 2).long()
x2 = (cx + w // 2).long()
y2 = (cy + h // 2).long()
x1 = torch.clamp(x1, 0, self.max_width - 1)
y1 = torch.clamp(y1, 0, self.max_height - 1)
x2 = torch.clamp(x2, 0, self.max_width - 1)
y2 = torch.clamp(y2, 0, self.max_height - 1)
class_ = torch.clamp(class_, 0, self.max_classes - 1).long()
w = torch.clamp(w, 0, self.max_width - 1).long()
h = torch.clamp(h, 0, self.max_height - 1).long()
cx = torch.clamp(cx, 0, self.max_width - 1).long()
cy = torch.clamp(cy, 0, self.max_height - 1).long()
coord_embeds = self.x1_embed(x1) + self.y1_embed(y1) + self.x2_embed(x2) + self.y2_embed(y2)
class_embeds = self.class_embed(class_)
embedded = coord_embeds + self.w_embed(w) + self.h_embed(h) + self.cx_embed(cx) + self.cy_embed(cy) + class_embeds
return embedded
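# Embeds encoder boxes given as (x1, y1, x2, y2): width, height and centre are derived, all
# components are clamped and embedded, and an optional learned positional embedding is added to
# each box span described by input_box_counts.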
class BboxEmbedding(nn.Module):
def __init__(self, config, embed_positions=False):
super().__init__()
self.x1_embed = nn.Embedding(config.max_width, config.hidden_size)
self.y1_embed = nn.Embedding(config.max_height, config.hidden_size)
self.x2_embed = nn.Embedding(config.max_width, config.hidden_size)
self.y2_embed = nn.Embedding(config.max_height, config.hidden_size)
self.w_embed = nn.Embedding(config.max_width, config.hidden_size)
self.h_embed = nn.Embedding(config.max_height, config.hidden_size)
self.cx_embed = nn.Embedding(config.max_width, config.hidden_size)
self.cy_embed = nn.Embedding(config.max_height, config.hidden_size)
self.box_pos_embed = nn.Embedding(config.max_position_embeddings, config.hidden_size)
self.max_width = config.max_width
self.max_height = config.max_height
self.embed_positions = embed_positions
def forward(self, boxes: torch.LongTensor, input_box_counts: torch.LongTensor):
x1, y1, x2, y2 = boxes.unbind(dim=-1)
x1 = torch.clamp(x1, 0, self.max_width - 1).long()
y1 = torch.clamp(y1, 0, self.max_height - 1).long()
x2 = torch.clamp(x2, 0, self.max_width - 1).long()
y2 = torch.clamp(y2, 0, self.max_height - 1).long()
# Shape is (batch_size, num_boxes/seq len, d_model)
w = x2 - x1
h = y2 - y1
# Box centers, computed as floats and cast to long below for embedding lookup
cx = (x1 + x2) / 2
cy = (y1 + y2) / 2
cx = cx.long()
cy = cy.long()
w = torch.clamp(w, 0, self.max_width - 1).long()
h = torch.clamp(h, 0, self.max_height - 1).long()
cx = torch.clamp(cx, 0, self.max_width - 1).long()
cy = torch.clamp(cy, 0, self.max_height - 1).long()
coord_embeds = self.x1_embed(x1) + self.y1_embed(y1) + self.x2_embed(x2) + self.y2_embed(y2)
embedded = coord_embeds + self.w_embed(w) + self.h_embed(h) + self.cx_embed(cx) + self.cy_embed(cy)
# Add in positional embeddings for the boxes and labels
if self.embed_positions:
for j in range(embedded.shape[0]):
box_start = input_box_counts[j, 0]
box_end = input_box_counts[j, 1] - 1 # Skip the sep token
box_count = box_end - box_start
embedded[j, box_start:box_end] = embedded[j, box_start:box_end] + self.box_pos_embed.weight[:box_count]
return embedded
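# Core transformer stack shared by the decoder and the text encoder: embed_labels selects
# LabelEmbedding vs BboxEmbedding, and config.causal controls whether _update_causal_mask builds
# a causal mask or returns None.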
class SuryaTableRecDecoderModel(SuryaTableRecDecoderPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`SuryaTableRecDecoderLayer`]
Args:
config: SuryaTableRecDecoderConfig
"""
def __init__(self, config: SuryaTableRecDecoderConfig, embed_labels=False, embed_positions=True):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.causal = config.causal
if embed_labels:
self.embed_tokens = LabelEmbedding(config)
else:
self.embed_tokens = BboxEmbedding(config, embed_positions=embed_positions)
self.layers = nn.ModuleList(
[SuryaTableRecDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
self.final_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
self.register_buffer(
"normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.float32), persistent=False
)
# Initialize weights and apply final processing
self.post_init()
# Copied from transformers.models.llama.modeling_llama.LlamaModel.get_input_embeddings
def get_input_embeddings(self):
return self.embed_tokens
# Copied from transformers.models.llama.modeling_llama.LlamaModel.set_input_embeddings
def set_input_embeddings(self, value):
self.embed_tokens = value
def forward(
self,
input_ids: torch.LongTensor = None,
input_boxes_counts: torch.LongTensor = None,
position_ids: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
prefill: bool = False
) -> Union[Tuple, BaseModelOutputWithNoAttention]:
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
if self.gradient_checkpointing and self.training and use_cache:
use_cache = False
inputs_embeds = self.embed_tokens(input_ids, input_boxes_counts)
hidden_states = inputs_embeds
if use_cache and prefill:
self._setup_cache(self.config, hidden_states.shape[0], hidden_states.device, hidden_states.dtype)
if cache_position is None:
cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
if position_ids is None:
position_ids = cache_position.unsqueeze(0)
causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)
all_hidden_states = () if output_hidden_states else None
for i, residual_block in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)
if self.gradient_checkpointing and self.training:
hidden_states = self._gradient_checkpointing_func(
residual_block.__call__, hidden_states, position_ids, causal_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache
)
else:
hidden_states = residual_block(hidden_states, position_ids, causal_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache)
hidden_states = self.final_norm(hidden_states)
# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)
if not return_dict:
return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
return BaseModelOutputWithNoAttention(
last_hidden_state=hidden_states,
hidden_states=all_hidden_states,
)
# TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
# KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
# (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
# `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
# Ignore copy
def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
if not self.causal:
return None
dtype, device = input_tensor.dtype, input_tensor.device
min_dtype = torch.finfo(dtype).min
sequence_length = input_tensor.shape[1]
target_length = max(settings.TABLE_REC_MAX_BOXES, sequence_length)
diagonal = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
causal_mask = diagonal
if sequence_length != 1:
# Select the upper triangular part of the matrix, but unmask current token (the diagonal)
# triu will be the min_dtype, everything else is 0 (attended to)
causal_mask = torch.triu(diagonal, diagonal=1)
causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
if attention_mask is not None:
causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
if attention_mask.dim() == 2:
# Mask positions in the causal mask that are masked in the attention mask
mask_length = attention_mask.shape[-1]
padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
if attention_mask is not None and attention_mask.device.type == "cuda":
# Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
# using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
# Details: https://github.com/pytorch/pytorch/issues/110213
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
return causal_mask
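# Autoregressive head for table structure: a label-embedding decoder whose hidden states feed
# bbox_head (reshaped to per-coordinate logits over max_width bins) and class_head (cell-class logits).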
class SuryaTableRecDecoder(SuryaTableRecDecoderPreTrainedModel):
_tied_weights_keys = None
def __init__(self, config, **kwargs):
super().__init__(config)
self.model = SuryaTableRecDecoderModel(config, embed_labels=True, embed_positions=False)
self.vocab_size = config.vocab_size
self.bbox_head = nn.Linear(config.hidden_size, config.max_width * 4, bias=False)
self.class_head = nn.Linear(config.hidden_size, config.max_classes, bias=False)
self.max_width = config.max_width
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
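# NOTE: `lm_head` is never defined on this model, so the output-embedding hooks below exist only
# to satisfy the PreTrainedModel interface and would raise AttributeError if actually called.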
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
# Ignore copy
def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
prefill: bool = False,
**kwargs
) -> Union[Tuple, TableRecModelOutput]:
outputs = self.model(
input_ids=input_ids,
cache_position=cache_position,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_hidden_states=True,
return_dict=True,
prefill=prefill,
)
hidden_states = outputs[0]
bbox_logits = self.bbox_head(hidden_states)
class_logits = self.class_head(hidden_states)
bsz, seq_len = class_logits.shape[:2]
bbox_logits = bbox_logits.view(bsz, seq_len, 4, self.max_width)
return TableRecModelOutput(
bbox_logits=bbox_logits,
class_logits=class_logits,
hidden_states=hidden_states,
)
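# Thin output wrapper plus the text encoder: the encoder runs the same stack over the input boxes
# (BboxEmbedding with positional embeddings) and exposes its last hidden state, which the decoder
# consumes as encoder_hidden_states through cross-attention.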
@dataclass
class TextEncoderOutput(CausalLMOutput):
hidden_states: torch.FloatTensor = None
class SuryaTableRecTextEncoder(SuryaTableRecDecoderPreTrainedModel):
_tied_weights_keys = None
config_class = SuryaTableRecTextEncoderConfig
def __init__(self, config, **kwargs):
super().__init__(config)
self.model = SuryaTableRecDecoderModel(config, embed_labels=False, embed_positions=True)
self.vocab_size = config.vocab_size
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
# Ignore copy
def forward(
self,
input_boxes: Optional[torch.LongTensor] = None,
input_boxes_counts: Optional[torch.LongTensor] = None,
cache_position: Optional[torch.LongTensor] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
encoder_attention_mask: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
**kwargs
) -> Union[Tuple, CausalLMOutput]:
outputs = self.model(
input_ids=input_boxes,
input_boxes_counts=input_boxes_counts,
cache_position=cache_position,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
use_cache=use_cache,
output_hidden_states=True,
return_dict=True,
)
return TextEncoderOutput(
hidden_states=outputs.last_hidden_state,
)
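# Usage sketch (illustration only, not part of the original file; the real wiring lives in the
# surrounding surya loading/inference code). Names such as `encoder_config`, `decoder_config`,
# `boxes`, `box_counts` and `labels` below are placeholders:
#
#   encoder = SuryaTableRecTextEncoder(encoder_config)
#   decoder = SuryaTableRecDecoder(decoder_config)
#   enc_out = encoder(input_boxes=boxes, input_boxes_counts=box_counts)
#   dec_out = decoder(
#       input_ids=labels,
#       encoder_hidden_states=enc_out.hidden_states,
#       use_cache=True,
#       prefill=True,
#   )
#   # dec_out.bbox_logits: (batch, seq_len, 4, max_width); dec_out.class_logits: (batch, seq_len, max_classes)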