from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from transformers.utils import ModelOutput

from surya.model.table_rec.config import SuryaTableRecDecoderConfig, SuryaTableRecTextEncoderConfig
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import BaseModelOutputWithNoAttention, CausalLMOutput
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from surya.settings import settings

_MAX_SQRT_GRADIENT = 1000.0

@dataclass
class TableRecModelOutput(ModelOutput):
    bbox_logits: torch.Tensor
    class_logits: torch.Tensor | None = None
    hidden_states: torch.Tensor | None = None

class SuryaTableRecDecoderRMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float())
        # Llama does x.to(float16) * w whilst SuryaTableRecDecoder is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        output = output * (1.0 + self.weight.float())
        return output.type_as(x)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.eps}"


ALL_LAYERNORM_LAYERS.append(SuryaTableRecDecoderRMSNorm)

class SuryaTableRecDecoderRotaryEmbedding(nn.Module):
    def __init__(self, dim, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

    # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->SuryaTableRecDecoder
    def forward(self, x, position_ids, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        self.inv_freq.to(x.device)
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()
        sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)

# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
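
# Shape sketch (illustrative, not executed; the shapes below are hypothetical): with q of
# shape [batch, heads, seq_len, head_dim], k of shape [batch, kv_heads, seq_len, head_dim]
# and cos/sin of shape [batch, seq_len, head_dim], unsqueeze_dim=1 expands cos/sin to
# [batch, 1, seq_len, head_dim] so they broadcast over the head dimension:
#
#     q = torch.randn(2, 8, 16, 64)
#     k = torch.randn(2, 2, 16, 64)
#     cos, sin = torch.randn(2, 16, 64), torch.randn(2, 16, 64)
#     q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)  # input shapes are preserved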

# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
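
# Illustrative sketch (hypothetical shapes): repeat_kv expands grouped KV heads to match
# the number of query heads, e.g. (batch=1, kv_heads=2, seq=4, head_dim=8) with n_rep=4
# becomes (1, 8, 4, 8):
#
#     kv = torch.randn(1, 2, 4, 8)
#     assert repeat_kv(kv, n_rep=4).shape == (1, 8, 4, 8)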

class SuryaTableRecDecoderSdpaCrossAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper
    Modified for GQA
    """

    def __init__(self, config: SuryaTableRecDecoderConfig):
        super().__init__()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True)
        self.rotary_emb = SuryaTableRecDecoderRotaryEmbedding(
            self.head_dim,
            base=config.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        # Encoder attention mask currently ignored
        bsz, q_len, _ = hidden_states.size()
        _, v_len, _ = encoder_hidden_states.size()

        query_states = self.q_proj(hidden_states)
        query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2)

        if self.key_states is None:
            key_states = self.k_proj(encoder_hidden_states)
            value_states = self.v_proj(encoder_hidden_states)
            key_states = key_states.view(bsz, v_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
            value_states = value_states.view(bsz, v_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
            if use_cache:
                self._update_cache(key_states, value_states)
        else:
            key_states = self.key_states
            value_states = self.value_states

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states.contiguous(),
            key_states.contiguous(),
            value_states.contiguous(),
            attn_mask=None,
            dropout_p=self.attention_dropout if self.training else 0.0,
            scale=self.head_dim**-0.5,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output

    def _setup_cache(self, batch_size, device, dtype=None):
        # Setup initial caches
        self.value_states = None
        self.key_states = None

    def _update_cache(self, key_states, value_states, **cache_kwargs):
        self.value_states = value_states
        self.key_states = key_states
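
# Note on SuryaTableRecDecoderSdpaCrossAttention caching: the encoder key/value
# projections are computed once (when self.key_states is None, i.e. right after
# _setup_cache) and reused on every later decode step, since the encoder hidden
# states do not change during generation.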

class SuryaTableRecDecoderSdpaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: SuryaTableRecDecoderConfig):
        super().__init__()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True)
        self.rotary_emb = SuryaTableRecDecoderRotaryEmbedding(
            self.head_dim,
            base=config.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: bool = False,
        window_attn: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Final is bsz, num_attention_heads, seq_len, head_dim
        query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if use_cache and hasattr(self, "key_states"):
            cache_kwargs = {"cache_position": cache_position, "window_attn": window_attn}
            key_states, value_states = self._update_cache(key_states, value_states, **cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            # Mask is batch, head, seq_len, kv_len
            causal_mask = causal_mask[:, :, :, :key_states.shape[-2]]
            current_cache_position = cache_position[-1].item() if cache_position is not None else None
            if current_cache_position and settings.RECOGNITION_STATIC_CACHE:
                # Mask out future cache positions
                position_mask = torch.ones_like(causal_mask, dtype=torch.bool, device=causal_mask.device)
                position_mask[:, :, :, :current_cache_position + 1] = False
                causal_mask = torch.where(position_mask, torch.finfo(causal_mask.dtype).min, causal_mask)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states.contiguous(),
            key_states.contiguous(),
            value_states.contiguous(),
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            scale=self.head_dim**-0.5,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output

    def _setup_cache(self, batch_size, device, dtype=None):
        if dtype is None and self.config.torch_dtype is not None:
            dtype = self.config.torch_dtype
        dtype = dtype if dtype is not None else torch.float32

        # Setup initial caches
        self.value_states = None
        self.key_states = None

        if settings.RECOGNITION_STATIC_CACHE:
            cache_shape = (batch_size, self.num_key_value_heads, settings.RECOGNITION_MAX_TOKENS, self.head_dim)
            self.value_states = torch.zeros(cache_shape, dtype=dtype, device=device)
            self.key_states = torch.zeros(cache_shape, dtype=dtype, device=device)

    def _update_static_cache(self, key_states, value_states, **cache_kwargs):
        cache_position = cache_kwargs.get("cache_position")
        k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device)

        k_out[:, :, cache_position] = key_states.to(k_out.dtype)
        v_out[:, :, cache_position] = value_states.to(v_out.dtype)

        self.key_states, self.value_states = k_out, v_out
        return k_out, v_out

    def _update_dynamic_cache(self, key_states, value_states, **cache_kwargs):
        k_out = key_states
        if self.key_states is not None:
            k_out = torch.cat([self.key_states, key_states], dim=2)

        v_out = value_states
        if self.value_states is not None:
            v_out = torch.cat([self.value_states, value_states], dim=2)

        self.key_states, self.value_states = k_out, v_out
        return k_out, v_out

    def _update_cache(self, key_states, value_states, **cache_kwargs):
        if settings.RECOGNITION_STATIC_CACHE:
            return self._update_static_cache(key_states, value_states, **cache_kwargs)

        return self._update_dynamic_cache(key_states, value_states, **cache_kwargs)
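
# Cache behaviour in SuryaTableRecDecoderSdpaAttention: with settings.RECOGNITION_STATIC_CACHE
# enabled, key/value tensors are pre-allocated to settings.RECOGNITION_MAX_TOKENS and new
# states are written in place at `cache_position`; otherwise the dynamic path concatenates
# along dim=2, so the cache grows by q_len on every forward call.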

class SuryaTableRecDecoderMlp(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        if config.hidden_activation is None:
            config.hidden_activation = "gelu_pytorch_tanh"
        hidden_activation = config.hidden_activation
        self.act_fn = ACT2FN[hidden_activation]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))

class SuryaTableRecDecoderLayer(nn.Module):
    def __init__(self, config, layer_idx):
        super().__init__()
        self.cross_pre_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.temporal_pre_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.temporal_block = None
        if layer_idx in config.self_attn_layers:
            self.temporal_block = SuryaTableRecDecoderSdpaAttention(config)

        self.cross_attn_block = None
        if layer_idx in config.cross_attn_layers:
            self.cross_attn_block = SuryaTableRecDecoderSdpaCrossAttention(config)

        self.window_attn = layer_idx not in config.global_attn_layers
        self.channel_pre_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp_block = SuryaTableRecDecoderMlp(config)

    def forward(
        self,
        activations: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        encoder_attention_mask: torch.Tensor = None,
        cache_position: torch.Tensor = None,
        use_cache: bool = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
        raw_activations = activations

        if self.cross_attn_block is not None:
            # Do cross-attention on encoder outputs
            cross_attn_inputs = self.cross_pre_norm(activations)
            cross_attn_path = self.cross_attn_block(
                cross_attn_inputs, encoder_hidden_states, attention_mask, encoder_attention_mask, use_cache=use_cache
            )
            cross_attn_output = cross_attn_path + raw_activations
        else:
            cross_attn_output = raw_activations

        if self.temporal_block is not None:
            inputs_normalized = self.temporal_pre_norm(cross_attn_output)  # RMSNorm introduces slight differences
            hidden_states = self.temporal_block(
                inputs_normalized, position_ids, attention_mask, cache_position=cache_position, use_cache=use_cache, window_attn=self.window_attn
            )
            residual = hidden_states + raw_activations
        else:
            residual = cross_attn_output

        hidden_states = self.channel_pre_norm(residual)
        hidden_states = self.mlp_block(hidden_states)

        hidden_states = hidden_states + residual
        return hidden_states
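
# Layer structure sketch for SuryaTableRecDecoderLayer: pre-norm blocks with residual
# connections -- optional cross-attention added to the raw activations, then optional
# self-attention also added to the raw activations, then RMSNorm + gated MLP added on
# top of that result.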

class SuryaTableRecDecoderPreTrainedModel(PreTrainedModel):
    config_class = SuryaTableRecDecoderConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SuryaTableRecDecoderLayer"]
    _skip_keys_device_placement = ["cache"]
    _supports_flash_attn_2 = False
    _supports_sdpa = False  # we can't compare with eager for now
    _supports_cache_class = True
    _supports_quantized_cache = True

    def _init_weights(self, module):
        if isinstance(module, SuryaTableRecDecoderSdpaAttention):
            torch.nn.init.normal_(module.q_proj.weight, mean=0.0, std=self.config.init_std)
            torch.nn.init.normal_(module.k_proj.weight, mean=0.0, std=self.config.init_std)
            torch.nn.init.normal_(module.v_proj.weight, mean=0.0, std=self.config.init_std)
            torch.nn.init.normal_(module.o_proj.weight, mean=0.0, std=self.config.init_std)
        elif isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)
            if getattr(module, "bias", None) is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _setup_cache(self, config, batch, device, dtype):
        layers = getattr(self, "model", self).layers
        for layer in layers:
            if layer.temporal_block:
                layer.temporal_block._setup_cache(batch, device, dtype)
            if layer.cross_attn_block:
                layer.cross_attn_block._setup_cache(batch, device, dtype)

    def reset_cache(self, batch, device, dtype):
        pass

    def _tie_weights(self):
        pass

    def tie_weights(self):
        pass

class LabelEmbedding(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.x1_embed = nn.Embedding(config.max_width, config.hidden_size)
        self.y1_embed = nn.Embedding(config.max_height, config.hidden_size)
        self.x2_embed = nn.Embedding(config.max_width, config.hidden_size)
        self.y2_embed = nn.Embedding(config.max_height, config.hidden_size)
        self.w_embed = nn.Embedding(config.max_width, config.hidden_size)
        self.h_embed = nn.Embedding(config.max_height, config.hidden_size)
        self.cx_embed = nn.Embedding(config.max_width, config.hidden_size)
        self.cy_embed = nn.Embedding(config.max_height, config.hidden_size)
        self.class_embed = nn.Embedding(config.max_classes, config.hidden_size)
        self.max_width = config.max_width
        self.max_height = config.max_height
        self.max_classes = config.max_classes

    def forward(self, labels: torch.LongTensor, input_box_counts: torch.LongTensor):
        cx, cy, w, h, class_ = labels.to(torch.long).unbind(dim=-1)
        # Shape is (batch_size, num_boxes/seq len, d_model)
        x1 = (cx - w // 2).long()
        y1 = (cy - h // 2).long()
        x2 = (cx + w // 2).long()
        y2 = (cy + h // 2).long()
        x1 = torch.clamp(x1, 0, self.max_width - 1)
        y1 = torch.clamp(y1, 0, self.max_height - 1)
        x2 = torch.clamp(x2, 0, self.max_width - 1)
        y2 = torch.clamp(y2, 0, self.max_height - 1)

        class_ = torch.clamp(class_, 0, self.max_classes - 1).long()

        w = torch.clamp(w, 0, self.max_width - 1).long()
        h = torch.clamp(h, 0, self.max_height - 1).long()
        cx = torch.clamp(cx, 0, self.max_width - 1).long()
        cy = torch.clamp(cy, 0, self.max_height - 1).long()

        coord_embeds = self.x1_embed(x1) + self.y1_embed(y1) + self.x2_embed(x2) + self.y2_embed(y2)
        class_embeds = self.class_embed(class_)
        embedded = coord_embeds + self.w_embed(w) + self.h_embed(h) + self.cx_embed(cx) + self.cy_embed(cy) + class_embeds

        return embedded
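
# Worked example (hypothetical values): a label row (cx, cy, w, h, class) = (60, 45, 100, 50, 2)
# yields corners x1 = 60 - 100 // 2 = 10, y1 = 45 - 50 // 2 = 20, x2 = 110, y2 = 70; all eight
# geometric quantities plus the class id are embedded separately and summed into a single
# hidden_size vector per box.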

class BboxEmbedding(nn.Module):
    def __init__(self, config, embed_positions=False):
        super().__init__()
        self.x1_embed = nn.Embedding(config.max_width, config.hidden_size)
        self.y1_embed = nn.Embedding(config.max_height, config.hidden_size)
        self.x2_embed = nn.Embedding(config.max_width, config.hidden_size)
        self.y2_embed = nn.Embedding(config.max_height, config.hidden_size)
        self.w_embed = nn.Embedding(config.max_width, config.hidden_size)
        self.h_embed = nn.Embedding(config.max_height, config.hidden_size)
        self.cx_embed = nn.Embedding(config.max_width, config.hidden_size)
        self.cy_embed = nn.Embedding(config.max_height, config.hidden_size)
        self.box_pos_embed = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.max_width = config.max_width
        self.max_height = config.max_height
        self.embed_positions = embed_positions

    def forward(self, boxes: torch.LongTensor, input_box_counts: torch.LongTensor):
        x1, y1, x2, y2 = boxes.unbind(dim=-1)
        x1 = torch.clamp(x1, 0, self.max_width - 1).long()
        y1 = torch.clamp(y1, 0, self.max_height - 1).long()
        x2 = torch.clamp(x2, 0, self.max_width - 1).long()
        y2 = torch.clamp(y2, 0, self.max_height - 1).long()

        # Shape is (batch_size, num_boxes/seq len, d_model)
        w = x2 - x1
        h = y2 - y1
        # Center x and y in torch long tensors
        cx = (x1 + x2) / 2
        cy = (y1 + y2) / 2
        cx = cx.long()
        cy = cy.long()

        w = torch.clamp(w, 0, self.max_width - 1).long()
        h = torch.clamp(h, 0, self.max_height - 1).long()
        cx = torch.clamp(cx, 0, self.max_width - 1).long()
        cy = torch.clamp(cy, 0, self.max_height - 1).long()

        coord_embeds = self.x1_embed(x1) + self.y1_embed(y1) + self.x2_embed(x2) + self.y2_embed(y2)
        embedded = coord_embeds + self.w_embed(w) + self.h_embed(h) + self.cx_embed(cx) + self.cy_embed(cy)

        # Add in positional embeddings for the boxes and labels
        if self.embed_positions:
            for j in range(embedded.shape[0]):
                box_start = input_box_counts[j, 0]
                box_end = input_box_counts[j, 1] - 1  # Skip the sep token
                box_count = box_end - box_start
                embedded[j, box_start:box_end] = embedded[j, box_start:box_end] + self.box_pos_embed.weight[:box_count]

        return embedded
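
# Worked example (hypothetical values): a box (x1, y1, x2, y2) = (10, 20, 110, 70) gives
# w = 100, h = 50, cx = 60, cy = 45; each of the eight quantities is looked up in its own
# embedding table and the results are summed, with optional learned positional embeddings
# added over the box span given by input_box_counts.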

class SuryaTableRecDecoderModel(SuryaTableRecDecoderPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`SuryaTableRecDecoderLayer`]

    Args:
        config: SuryaTableRecDecoderConfig
    """

    def __init__(self, config: SuryaTableRecDecoderConfig, embed_labels=False, embed_positions=True):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.causal = config.causal

        if embed_labels:
            self.embed_tokens = LabelEmbedding(config)
        else:
            self.embed_tokens = BboxEmbedding(config, embed_positions=embed_positions)

        self.layers = nn.ModuleList(
            [SuryaTableRecDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.final_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        self.register_buffer(
            "normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.float32), persistent=False
        )
        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.llama.modeling_llama.LlamaModel.get_input_embeddings
    def get_input_embeddings(self):
        return self.embed_tokens

    # Copied from transformers.models.llama.modeling_llama.LlamaModel.set_input_embeddings
    def set_input_embeddings(self, value):
        self.embed_tokens = value
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        input_boxes_counts: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        prefill: bool = False,
    ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        inputs_embeds = self.embed_tokens(input_ids, input_boxes_counts)
        hidden_states = inputs_embeds

        if use_cache and prefill:
            self._setup_cache(self.config, hidden_states.shape[0], hidden_states.device, hidden_states.dtype)

        if cache_position is None:
            cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

        all_hidden_states = () if output_hidden_states else None
        for i, residual_block in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    residual_block.__call__, hidden_states, position_ids, causal_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache
                )
            else:
                hidden_states = residual_block(hidden_states, position_ids, causal_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache)

        hidden_states = self.final_norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)

        return BaseModelOutputWithNoAttention(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
        )
    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
    # Ignore copy
    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
        if not self.causal:
            return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        target_length = max(settings.TABLE_REC_MAX_BOXES, sequence_length)

        diagonal = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        causal_mask = diagonal
        if sequence_length != 1:
            # Select the upper triangular part of the matrix, but unmask current token (the diagonal)
            # triu will be the min_dtype, everything else is 0 (attended to)
            causal_mask = torch.triu(diagonal, diagonal=1)

        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            if attention_mask.dim() == 2:
                # Mask positions in the causal mask that are masked in the attention mask
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)

        if attention_mask is not None and attention_mask.device.type == "cuda":
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask
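
# Mask sketch (illustrative): for sequence_length = 3 and cache_position = [0, 1, 2], the
# first three key columns of the mask (before batch/head expansion) are
#
#     [[0,   min, min],
#      [0,   0,   min],
#      [0,   0,   0  ]]
#
# where min = torch.finfo(dtype).min, i.e. each query token attends to itself and to earlier
# positions only; the remaining columns up to TABLE_REC_MAX_BOXES follow the same rule.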

class SuryaTableRecDecoder(SuryaTableRecDecoderPreTrainedModel):
    _tied_weights_keys = None

    def __init__(self, config, **kwargs):
        super().__init__(config)
        self.model = SuryaTableRecDecoderModel(config, embed_labels=True, embed_positions=False)
        self.vocab_size = config.vocab_size

        self.bbox_head = nn.Linear(config.hidden_size, config.max_width * 4, bias=False)
        self.class_head = nn.Linear(config.hidden_size, config.max_classes, bias=False)
        self.max_width = config.max_width

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    # Ignore copy
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        prefill: bool = False,
        **kwargs
    ) -> Union[Tuple, TableRecModelOutput]:
        outputs = self.model(
            input_ids=input_ids,
            cache_position=cache_position,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_hidden_states=True,
            return_dict=True,
            prefill=prefill,
        )

        hidden_states = outputs[0]
        bbox_logits = self.bbox_head(hidden_states)
        class_logits = self.class_head(hidden_states)
        bsz, seq_len = class_logits.shape[:2]
        bbox_logits = bbox_logits.view(bsz, seq_len, 4, self.max_width)

        return TableRecModelOutput(
            bbox_logits=bbox_logits,
            class_logits=class_logits,
            hidden_states=hidden_states,
        )

@dataclass
class TextEncoderOutput(CausalLMOutput):
    hidden_states: torch.FloatTensor = None


class SuryaTableRecTextEncoder(SuryaTableRecDecoderPreTrainedModel):
    _tied_weights_keys = None
    config_class = SuryaTableRecTextEncoderConfig

    def __init__(self, config, **kwargs):
        super().__init__(config)
        self.model = SuryaTableRecDecoderModel(config, embed_labels=False, embed_positions=True)
        self.vocab_size = config.vocab_size

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    # Ignore copy
    def forward(
        self,
        input_boxes: Optional[torch.LongTensor] = None,
        input_boxes_counts: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutput]:
        outputs = self.model(
            input_ids=input_boxes,
            input_boxes_counts=input_boxes_counts,
            cache_position=cache_position,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_hidden_states=True,
            return_dict=True,
        )

        return TextEncoderOutput(
            hidden_states=outputs.last_hidden_state,
        )
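
# Minimal usage sketch (commented out; assumes a SuryaTableRecDecoderConfig with its default
# max_width / max_height / max_classes and a matching encoder_hidden_size -- the shapes and
# values here are hypothetical and this is not part of the module):
#
#     config = SuryaTableRecDecoderConfig()
#     decoder = SuryaTableRecDecoder(config)
#     labels = torch.zeros(1, 5, 5, dtype=torch.long)  # (batch, seq, [cx, cy, w, h, class])
#     encoder_hidden = torch.zeros(1, 10, config.encoder_hidden_size)
#     out = decoder(
#         input_ids=labels,
#         encoder_hidden_states=encoder_hidden,
#         use_cache=True,
#         prefill=True,
#     )
#     # out.bbox_logits has shape (1, 5, 4, config.max_width)
#     # out.class_logits has shape (1, 5, config.max_classes)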