File size: 35,345 Bytes
52f1bcb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
from dataclasses import dataclass
from typing import Dict, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from transformers.utils import ModelOutput

from surya.model.table_rec.config import SuryaTableRecDecoderConfig, SuryaTableRecTextEncoderConfig
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import BaseModelOutputWithNoAttention, CausalLMOutput
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS

from surya.settings import settings

# NOTE(review): this constant is not referenced anywhere in the visible part of
# this file; presumably a leftover from the reference implementation this model
# was adapted from — confirm before relying on it or removing it.
_MAX_SQRT_GRADIENT = 1000.0

@dataclass
class TableRecModelOutput(ModelOutput):
    """Output container for the table recognition decoder.

    Only ``bbox_logits`` is required; the other two fields default to ``None``.
    """

    # Bounding-box logits — presumably the output of the decoder's bbox head; TODO confirm shape.
    bbox_logits: torch.Tensor
    # Class logits — presumably the output of the decoder's class head; optional.
    class_logits: torch.Tensor | None = None
    # Hidden states returned by the decoder model; optional.
    hidden_states: torch.Tensor | None = None


class SuryaTableRecDecoderRMSNorm(nn.Module):
    """Root-mean-square normalisation with a zero-initialised, offset gain.

    The input is normalised over its last dimension in float32 and then scaled
    by ``1 + weight``; because ``weight`` starts at zero, a freshly constructed
    module applies a gain of exactly one. The result is cast back to the dtype
    of the input.

    Args:
        dim: Size of the last (normalised) dimension.
        eps: Constant added to the mean square before the reciprocal sqrt.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.zeros(dim))

    def _norm(self, x):
        # Divide by the RMS of the last dimension; eps keeps rsqrt finite.
        mean_square = torch.pow(x, 2).mean(dim=-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        normalized = self._norm(x.float())
        # Llama does x.to(float16) * w whilst SuryaTableRecDecoder is (x * w).to(float16)
        # See https://github.com/huggingface/transformers/pull/29402
        gain = 1.0 + self.weight.float()
        return (normalized * gain).type_as(x)

    def extra_repr(self):
        weight_shape = tuple(self.weight.shape)
        return f"{weight_shape}, eps={self.eps}"


# Register the RMSNorm class in transformers' list of layer-norm classes
# (presumably so library utilities treat its parameters like other norm
# parameters — confirm against the transformers version in use).
ALL_LAYERNORM_LAYERS.append(SuryaTableRecDecoderRMSNorm)


class SuryaTableRecDecoderRotaryEmbedding(nn.Module):
    """Computes the cos/sin tables used for rotary position embeddings.

    Args:
        dim: Rotary dimension; ``dim / 2`` inverse frequencies are created.
        base: Base of the geometric frequency progression.
        device: Accepted for signature compatibility; not used by this module.
    """

    def __init__(self, dim, base=10000, device=None):
        super().__init__()
        self.dim = dim
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
        # Non-persistent: recomputed on construction, never stored in checkpoints.
        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)

    @torch.no_grad()
    # Copied from transformers.models.gemma.modeling_gemma.GemmaRotaryEmbedding.forward with Gemma->SuryaTableRecDecoder
    def forward(self, x, position_ids, seq_len=None):
        """Return ``(cos, sin)`` for the given positions.

        Args:
            x: Tensor whose device and dtype the outputs follow
                (``[bs, num_attention_heads, seq_len, head_size]`` per the caller).
            position_ids: Integer positions, indexed as ``[batch, seq_len]``.
            seq_len: Unused; kept for signature compatibility.

        Returns:
            Two tensors of shape ``[batch, seq_len, dim]`` cast to ``x.dtype``.
        """
        # Tensor.to() is not in-place: keep the returned tensor so the
        # frequencies really live on the same device as the activations.
        inv_freq = self.inv_freq.to(x.device)
        inv_freq_expanded = inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()

        # Outer product of frequencies and positions, computed in float32.
        freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)
        # Duplicate so both halves used by rotate_half share the same angles.
        emb = torch.cat((freqs, freqs), dim=-1)
        cos = emb.cos()
        sin = emb.sin()
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


# Copied from transformers.models.llama.modeling_llama.rotate_half
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    # Split the last dimension at its midpoint, then swap the halves and
    # negate the one that moves to the front.
    midpoint = x.shape[-1] // 2
    front, back = x[..., :midpoint], x[..., midpoint:]
    return torch.cat((-back, front), dim=-1)


# Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    """Rotate the query and key tensors with rotary position embeddings.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos` and `sin` are unsqueezed so that they broadcast against `q` and `k`.
            With `cos`/`sin` of shape [batch_size, seq_len, head_dim], use 1 when `q` and `k` are
            [batch_size, heads, seq_len, head_dim] and 2 when they are [batch_size, seq_len, heads, head_dim].
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)

    def _rotate(states):
        # Same formula for queries and keys: t*cos + rotate_half(t)*sin.
        return (states * cos_b) + (rotate_half(states) * sin_b)

    return _rotate(q), _rotate(k)


# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Repeat every key/value head `n_rep` times along the head dimension.

    Equivalent to torch.repeat_interleave(x, dim=1, repeats=n_rep): the input goes from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_key_value_heads * n_rep, seqlen, head_dim).
    With `n_rep == 1` the input tensor itself is returned.
    """
    bsz, kv_heads, seq_len, dim_per_head = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis after the head axis, broadcast it, then fold it into the heads.
    expanded = hidden_states.unsqueeze(2).expand(bsz, kv_heads, n_rep, seq_len, dim_per_head)
    return expanded.reshape(bsz, kv_heads * n_rep, seq_len, dim_per_head)


class SuryaTableRecDecoderSdpaCrossAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper
    Modified for GQA

    Queries come from the decoder hidden states; keys and values come from the
    encoder hidden states. The projected keys/values can be cached on the
    module so they are computed only once per sequence.
    """

    def __init__(self, config: SuryaTableRecDecoderConfig):
        super().__init__()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.config.encoder_hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True)
        # Not used by forward() below; kept so the module structure is unchanged.
        self.rotary_emb = SuryaTableRecDecoderRotaryEmbedding(
            self.head_dim,
            base=config.rope_theta,
        )

        # Start with an empty cache so forward() works even when _setup_cache()
        # has never been called (previously this raised AttributeError).
        self.key_states = None
        self.value_states = None

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        use_cache: bool = False,
    ) -> torch.Tensor:
        """Attend from decoder states to encoder states.

        Args:
            hidden_states: Decoder activations, ``(batch, q_len, hidden_size)``.
            encoder_hidden_states: Encoder activations,
                ``(batch, v_len, encoder_hidden_size)``.
            attention_mask: Ignored.
            encoder_attention_mask: Ignored (see comment below).
            use_cache: When true and no keys/values are cached yet, store the
                freshly projected keys/values on the module.

        Returns:
            Attention output of shape ``(batch, q_len, hidden_size)``.
        """
        # Encoder attention mask currently ignored

        bsz, q_len, _ = hidden_states.size()
        _, v_len, _ = encoder_hidden_states.size()

        query_states = self.q_proj(hidden_states)
        query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2)

        if self.key_states is None:
            # No cached projections: compute them from the encoder output.
            key_states = self.k_proj(encoder_hidden_states)
            value_states = self.v_proj(encoder_hidden_states)
            key_states = key_states.view(bsz, v_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
            value_states = value_states.view(bsz, v_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
            if use_cache:
                self._update_cache(key_states, value_states)
        else:
            # Reuse the cached projections; encoder_hidden_states is not re-projected.
            key_states = self.key_states
            value_states = self.value_states

        # Expand the grouped KV heads to match the number of query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states.contiguous(),
            key_states.contiguous(),
            value_states.contiguous(),
            attn_mask=None,
            dropout_p=self.attention_dropout if self.training else 0.0,
            scale=self.head_dim**-0.5,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output

    def _setup_cache(self, batch_size, device, dtype=None):
        """Clear the cached key/value projections (arguments are unused)."""
        # Setup initial caches
        self.value_states = None
        self.key_states = None

    @torch.no_grad()
    def _update_cache(self, key_states, value_states, **cache_kwargs):
        """Store the given key/value projections for reuse by later calls."""
        self.value_states = value_states
        self.key_states = key_states


class SuryaTableRecDecoderSdpaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper

    Self-attention with grouped key/value heads, rotary position embeddings and
    an optional key/value cache that is either preallocated ("static") or grown
    by concatenation ("dynamic"), chosen by ``settings.RECOGNITION_STATIC_CACHE``.
    """

    def __init__(self, config: SuryaTableRecDecoderConfig):
        super().__init__()
        self.config = config
        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.head_dim = config.head_dim
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_attention_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(self.hidden_size, self.num_attention_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_attention_heads * self.head_dim, self.hidden_size, bias=True)
        self.rotary_emb = SuryaTableRecDecoderRotaryEmbedding(
            self.head_dim,
            base=config.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: bool = False,
        window_attn: bool = False,
    ) -> torch.Tensor:
        """Run self-attention over ``hidden_states``.

        Args:
            hidden_states: Input activations, ``(batch, q_len, hidden_size)``.
            position_ids: Positions passed to the rotary embedding.
            attention_mask: Additive mask indexed as (batch, head, seq_len,
                kv_len); ``None`` disables masking.
            cache_position: Cache slots written by this call. Its last element
                is taken as the current position when the static cache is on.
            use_cache: Update and read the KV cache, provided ``_setup_cache``
                has created it.
            window_attn: Forwarded to the cache update; the update methods in
                this class do not read it.

        Returns:
            Attention output of shape ``(batch, q_len, hidden_size)``.
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Final is bsz, num_attention_heads, seq_len, head_dim
        query_states = query_states.view(bsz, q_len, self.num_attention_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # value_states is only used by the rotary module for its device and dtype.
        cos, sin = self.rotary_emb(value_states, position_ids, seq_len=None)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # The cache attributes exist only after _setup_cache() has been called.
        if use_cache and hasattr(self, "key_states"):
            cache_kwargs = {"cache_position": cache_position, "window_attn": window_attn}
            key_states, value_states = self._update_cache(key_states, value_states, **cache_kwargs)

        # Expand the grouped KV heads to match the number of query heads.
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:
            # Mask is batch, head, seq_len, kv_len
            causal_mask = causal_mask[:, :, :, :key_states.shape[-2]]
            # Check the setting first so the .item() host sync is only paid
            # when the static cache is in use. Compare against None explicitly:
            # position 0 is a valid cache position, not "no position".
            if settings.RECOGNITION_STATIC_CACHE and cache_position is not None:
                current_cache_position = cache_position[-1].item()
                # Mask out future cache positions
                position_mask = torch.ones_like(causal_mask, dtype=torch.bool, device=causal_mask.device)
                position_mask[:, :, :, :current_cache_position + 1] = False
                causal_mask = torch.where(position_mask, torch.finfo(causal_mask.dtype).min, causal_mask)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states.contiguous(),
            key_states.contiguous(),
            value_states.contiguous(),
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            scale=self.head_dim**-0.5,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output

    def _setup_cache(self, batch_size, device, dtype=None):
        """Reset the KV cache; preallocate it when the static cache is enabled.

        The dtype falls back to ``config.torch_dtype`` and then to float32.
        """
        if dtype is None and self.config.torch_dtype is not None:
            dtype = self.config.torch_dtype
        dtype = dtype if dtype is not None else torch.float32

        # Setup initial caches
        self.value_states = None
        self.key_states = None

        if settings.RECOGNITION_STATIC_CACHE:
            cache_shape = (batch_size, self.num_key_value_heads, settings.RECOGNITION_MAX_TOKENS, self.head_dim)
            self.value_states = torch.zeros(cache_shape, dtype=dtype, device=device)
            self.key_states = torch.zeros(cache_shape, dtype=dtype, device=device)

    def _update_static_cache(self, key_states, value_states, **cache_kwargs):
        """Write the new keys/values into the preallocated cache at ``cache_position``.

        Returns the full cache tensors (all slots, including unwritten ones).
        """
        cache_position = cache_kwargs.get("cache_position")
        k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device)

        k_out[:, :, cache_position] = key_states.to(k_out.dtype)
        v_out[:, :, cache_position] = value_states.to(v_out.dtype)

        self.key_states, self.value_states = k_out, v_out
        return k_out, v_out

    def _update_dynamic_cache(self, key_states, value_states, **cache_kwargs):
        """Append the new keys/values to the cache along the sequence axis (dim 2)."""
        k_out = key_states
        if self.key_states is not None:
            k_out = torch.cat([self.key_states, key_states], dim=2)

        v_out = value_states
        if self.value_states is not None:
            v_out = torch.cat([self.value_states, value_states], dim=2)

        self.key_states, self.value_states = k_out, v_out
        return k_out, v_out

    @torch.no_grad()
    def _update_cache(self, key_states, value_states, **cache_kwargs):
        """Dispatch to the static or dynamic cache update based on settings."""
        if settings.RECOGNITION_STATIC_CACHE:
            return self._update_static_cache(key_states, value_states, **cache_kwargs)

        return self._update_dynamic_cache(key_states, value_states, **cache_kwargs)


class SuryaTableRecDecoderMlp(nn.Module):
    """Gated feed-forward block: ``down(act(gate(x)) * up(x))``.

    All three projections are bias-free. When ``config.hidden_activation`` is
    ``None`` it is set on the config to ``"gelu_pytorch_tanh"`` before lookup.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        width, inner = self.hidden_size, self.intermediate_size
        self.gate_proj = nn.Linear(width, inner, bias=False)
        self.up_proj = nn.Linear(width, inner, bias=False)
        self.down_proj = nn.Linear(inner, width, bias=False)

        # The default is written back to the config object, as before.
        if config.hidden_activation is None:
            config.hidden_activation = "gelu_pytorch_tanh"
        self.act_fn = ACT2FN[config.hidden_activation]

    def forward(self, x):
        gate = self.act_fn(self.gate_proj(x))
        value = self.up_proj(x)
        return self.down_proj(gate * value)


class SuryaTableRecDecoderLayer(nn.Module):
    """One decoder layer: optional cross-attention, optional self-attention, then an MLP.

    Whether the cross-attention and self-attention ("temporal") blocks exist is
    decided per layer by ``config.cross_attn_layers`` and
    ``config.self_attn_layers``; the MLP block is always present.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.cross_pre_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.temporal_pre_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.temporal_block = None
        if layer_idx in config.self_attn_layers:
            self.temporal_block = SuryaTableRecDecoderSdpaAttention(config)

        self.cross_attn_block = None
        if layer_idx in config.cross_attn_layers:
            self.cross_attn_block = SuryaTableRecDecoderSdpaCrossAttention(config)

        # Layers not listed as global are flagged as window-attention layers.
        self.window_attn = layer_idx not in config.global_attn_layers
        self.channel_pre_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.mlp_block = SuryaTableRecDecoderMlp(config)

    def forward(
        self,
        activations: torch.Tensor,
        position_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        use_cache: Optional[bool] = None,
    ) -> torch.Tensor:
        """Apply the layer to ``activations`` and return the new hidden states.

        Args:
            activations: Layer input.
            position_ids: Positions forwarded to the self-attention block.
            attention_mask: Mask forwarded to both attention blocks.
            encoder_hidden_states: Encoder output used by cross-attention.
            encoder_attention_mask: Forwarded to the cross-attention block.
            cache_position: Cache slots forwarded to the self-attention block.
            use_cache: Forwarded to both attention blocks.

        Returns:
            A single tensor with the layer output.
        """
        raw_activations = activations

        if self.cross_attn_block is not None:
            # Do cross-attention on encoder outputs
            cross_attn_inputs = self.cross_pre_norm(activations)
            cross_attn_path = self.cross_attn_block(
                cross_attn_inputs, encoder_hidden_states, attention_mask, encoder_attention_mask, use_cache=use_cache
            )
            cross_attn_output = cross_attn_path + raw_activations
        else:
            cross_attn_output = raw_activations

        if self.temporal_block is not None:
            inputs_normalized = self.temporal_pre_norm(cross_attn_output)  # RMSNorm introduces slight differences
            hidden_states = self.temporal_block(
                inputs_normalized, position_ids, attention_mask, cache_position=cache_position, use_cache=use_cache, window_attn=self.window_attn
            )

            # NOTE(review): the residual added here is the layer input, not
            # cross_attn_output, so with both blocks present the cross-attention
            # path only reaches the output through the self-attention input.
            # Left as-is because changing it would alter model outputs — confirm intent.
            residual = hidden_states + raw_activations
        else:
            residual = cross_attn_output

        hidden_states = self.channel_pre_norm(residual)
        hidden_states = self.mlp_block(hidden_states)

        hidden_states = hidden_states + residual
        return hidden_states


class SuryaTableRecDecoderPreTrainedModel(PreTrainedModel):
    """Shared base for the table recognition decoder models.

    Provides weight initialisation, per-layer cache setup, and no-op overrides
    for cache reset and weight tying.
    """

    config_class = SuryaTableRecDecoderConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["SuryaTableRecDecoderLayer"]
    _skip_keys_device_placement = ["cache"]
    _supports_flash_attn_2 = False
    _supports_sdpa = False  # we can't compare with eager for now
    _supports_cache_class = True
    _supports_quantized_cache = True

    def _init_weights(self, module):
        """Initialise ``module`` with a normal distribution of std ``config.init_std``."""
        if isinstance(module, SuryaTableRecDecoderSdpaAttention):
            # Only the projection weights are touched here, in q, k, v, o order.
            for projection in (module.q_proj, module.k_proj, module.v_proj, module.o_proj):
                torch.nn.init.normal_(projection.weight, mean=0.0, std=self.config.init_std)
            return

        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.init_std)
            bias = getattr(module, "bias", None)
            if bias is not None:
                torch.nn.init.zeros_(bias)
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.init_std)
            # Keep the padding row at zero when the embedding defines one.
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _setup_cache(self, config, batch, device, dtype):
        """Call ``_setup_cache`` on every attention block of every layer."""
        # Works both on the wrapper (which holds `.model`) and on the model itself.
        decoder = getattr(self, "model", self)
        for layer in decoder.layers:
            for block in (layer.temporal_block, layer.cross_attn_block):
                if block:
                    block._setup_cache(batch, device, dtype)

    def reset_cache(self, batch, device, dtype):
        """No-op."""
        pass

    def _tie_weights(self):
        """No-op: no weights are tied."""
        pass

    def tie_weights(self):
        """No-op: no weights are tied."""
        pass


class LabelEmbedding(nn.Module):
    """Embeds ``(cx, cy, w, h, class)`` labels as a sum of learned embeddings.

    Corner coordinates are derived from the centre and size, every index is
    clamped into its table's valid range, and the nine lookups (four corners,
    width, height, centre x/y, class) are added together.
    """

    def __init__(self, config):
        super().__init__()
        self.vocab_size = config.vocab_size
        width, height, hidden = config.max_width, config.max_height, config.hidden_size
        self.x1_embed = nn.Embedding(width, hidden)
        self.y1_embed = nn.Embedding(height, hidden)
        self.x2_embed = nn.Embedding(width, hidden)
        self.y2_embed = nn.Embedding(height, hidden)
        self.w_embed = nn.Embedding(width, hidden)
        self.h_embed = nn.Embedding(height, hidden)
        self.cx_embed = nn.Embedding(width, hidden)
        self.cy_embed = nn.Embedding(height, hidden)
        self.class_embed = nn.Embedding(config.max_classes, hidden)
        self.max_width = config.max_width
        self.max_height = config.max_height
        self.max_classes = config.max_classes

    def forward(self, labels: torch.LongTensor, input_box_counts: torch.LongTensor):
        """Return the summed embedding for each label.

        Args:
            labels: Tensor whose last dimension holds ``cx, cy, w, h, class``.
            input_box_counts: Unused.

        Returns:
            Tensor with the last dimension of ``labels`` replaced by ``hidden_size``.
        """
        cx, cy, w, h, class_ = labels.to(torch.long).unbind(dim=-1)
        last_x = self.max_width - 1
        last_y = self.max_height - 1

        # Corners from centre and size, clamped into the embedding tables.
        # Shape is (batch_size, num_boxes/seq len, d_model)
        x1 = torch.clamp((cx - w // 2).long(), 0, last_x)
        y1 = torch.clamp((cy - h // 2).long(), 0, last_y)
        x2 = torch.clamp((cx + w // 2).long(), 0, last_x)
        y2 = torch.clamp((cy + h // 2).long(), 0, last_y)

        class_ = torch.clamp(class_, 0, self.max_classes - 1).long()
        w = torch.clamp(w, 0, last_x).long()
        h = torch.clamp(h, 0, last_y).long()
        cx = torch.clamp(cx, 0, last_x).long()
        cy = torch.clamp(cy, 0, last_y).long()

        # Summed left to right in this fixed order.
        lookups = (
            (self.x1_embed, x1),
            (self.y1_embed, y1),
            (self.x2_embed, x2),
            (self.y2_embed, y2),
            (self.w_embed, w),
            (self.h_embed, h),
            (self.cx_embed, cx),
            (self.cy_embed, cy),
            (self.class_embed, class_),
        )
        first_table, first_index = lookups[0]
        embedded = first_table(first_index)
        for table, index in lookups[1:]:
            embedded = embedded + table(index)

        return embedded


class BboxEmbedding(nn.Module):
    """Embeds ``(x1, y1, x2, y2)`` boxes as a sum of learned embeddings.

    Corners are clamped into range; width, height and centre are derived from
    the clamped corners and embedded too. With ``embed_positions`` enabled, a
    learned per-box position embedding is added to a span of each sample.
    """

    def __init__(self, config, embed_positions=False):
        super().__init__()
        width, height, hidden = config.max_width, config.max_height, config.hidden_size
        self.x1_embed = nn.Embedding(width, hidden)
        self.y1_embed = nn.Embedding(height, hidden)
        self.x2_embed = nn.Embedding(width, hidden)
        self.y2_embed = nn.Embedding(height, hidden)
        self.w_embed = nn.Embedding(width, hidden)
        self.h_embed = nn.Embedding(height, hidden)
        self.cx_embed = nn.Embedding(width, hidden)
        self.cy_embed = nn.Embedding(height, hidden)
        self.box_pos_embed = nn.Embedding(config.max_position_embeddings, hidden)
        self.max_width = config.max_width
        self.max_height = config.max_height
        self.embed_positions = embed_positions

    def forward(self, boxes: torch.LongTensor, input_box_counts: torch.LongTensor):
        """Return the summed embedding for each box.

        Args:
            boxes: Tensor whose last dimension holds ``x1, y1, x2, y2``.
            input_box_counts: Per-sample ``[start, end]`` pairs; only read when
                ``embed_positions`` is true. The span that receives position
                embeddings is ``start`` to ``end - 1`` (exclusive).

        Returns:
            Tensor with the last dimension of ``boxes`` replaced by ``hidden_size``.
        """
        last_x = self.max_width - 1
        last_y = self.max_height - 1

        x1, y1, x2, y2 = boxes.unbind(dim=-1)
        x1 = torch.clamp(x1, 0, last_x).long()
        y1 = torch.clamp(y1, 0, last_y).long()
        x2 = torch.clamp(x2, 0, last_x).long()
        y2 = torch.clamp(y2, 0, last_y).long()

        # Size and centre come from the clamped corners.
        # Shape is (batch_size, num_boxes/seq len, d_model)
        w = torch.clamp(x2 - x1, 0, last_x).long()
        h = torch.clamp(y2 - y1, 0, last_y).long()
        # True division followed by truncation to long, as in the original.
        cx = torch.clamp(((x1 + x2) / 2).long(), 0, last_x).long()
        cy = torch.clamp(((y1 + y2) / 2).long(), 0, last_y).long()

        # Summed left to right in this fixed order.
        lookups = (
            (self.x1_embed, x1),
            (self.y1_embed, y1),
            (self.x2_embed, x2),
            (self.y2_embed, y2),
            (self.w_embed, w),
            (self.h_embed, h),
            (self.cx_embed, cx),
            (self.cy_embed, cy),
        )
        first_table, first_index = lookups[0]
        embedded = first_table(first_index)
        for table, index in lookups[1:]:
            embedded = embedded + table(index)

        # Add in positional embeddings for the boxes and labels
        if self.embed_positions:
            for sample in range(embedded.shape[0]):
                span_start = input_box_counts[sample, 0]
                span_end = input_box_counts[sample, 1] - 1  # Skip the sep token
                span_len = span_end - span_start
                position_rows = self.box_pos_embed.weight[:span_len]
                embedded[sample, span_start:span_end] = embedded[sample, span_start:span_end] + position_rows

        return embedded


class SuryaTableRecDecoderModel(SuryaTableRecDecoderPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`SuryaTableRecDecoderLayer`]

    Args:
        config: SuryaTableRecDecoderConfig
        embed_labels: When true, inputs are embedded with `LabelEmbedding`; otherwise with `BboxEmbedding`.
        embed_positions: Forwarded to `BboxEmbedding`; has no effect when `embed_labels` is true.
    """

    def __init__(self, config: SuryaTableRecDecoderConfig, embed_labels=False, embed_positions=True):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        # When false, _update_causal_mask returns None and no mask is applied.
        self.causal = config.causal

        if embed_labels:
            self.embed_tokens = LabelEmbedding(config)
        else:
            self.embed_tokens = BboxEmbedding(config, embed_positions=embed_positions)

        self.layers = nn.ModuleList(
            [SuryaTableRecDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.final_norm = SuryaTableRecDecoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        # sqrt(hidden_size), kept out of the state dict. Not read by any method
        # of this class in the visible code.
        self.register_buffer(
            "normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.float32), persistent=False
        )
        # Initialize weights and apply final processing
        self.post_init()

    # Copied from transformers.models.llama.modeling_llama.LlamaModel.get_input_embeddings
    def get_input_embeddings(self):
        return self.embed_tokens

    # Copied from transformers.models.llama.modeling_llama.LlamaModel.set_input_embeddings
    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        input_boxes_counts: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        prefill: bool = False
    ) -> Union[Tuple, BaseModelOutputWithNoAttention]:
        """Embed the inputs, run every decoder layer and apply the final norm.

        Args:
            input_ids: Passed to the embedding module together with `input_boxes_counts`.
            input_boxes_counts: Passed to the embedding module.
            position_ids: Defaults to `cache_position.unsqueeze(0)`.
            attention_mask: Optional mask merged into the causal mask.
            encoder_hidden_states: Forwarded to every layer.
            encoder_attention_mask: Forwarded to every layer.
            cache_position: Defaults to `arange(sequence_length)`.
            use_cache: Defaults to `config.use_cache`; forced off while training with gradient checkpointing.
            output_hidden_states: When truthy, collect the input of every layer plus the final normalized output.
            return_dict: Defaults to `config.use_return_dict`.
            prefill: When true together with `use_cache`, the layer caches are (re)initialized first.

        Returns:
            A `BaseModelOutputWithNoAttention`, or a tuple of the non-None values when `return_dict` is false.
        """
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        inputs_embeds = self.embed_tokens(input_ids, input_boxes_counts)
        hidden_states = inputs_embeds

        # Caches are reset only on the prefill call.
        if use_cache and prefill:
            self._setup_cache(self.config, hidden_states.shape[0], hidden_states.device, hidden_states.dtype)

        if cache_position is None:
            cache_position = torch.arange(hidden_states.shape[1], device=hidden_states.device)
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

        all_hidden_states = () if output_hidden_states else None
        for i, residual_block in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            if self.gradient_checkpointing and self.training:
                hidden_states = self._gradient_checkpointing_func(
                    residual_block.__call__, hidden_states, position_ids, causal_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache
                )
            else:
                hidden_states = residual_block(hidden_states, position_ids, causal_mask, encoder_hidden_states, encoder_attention_mask, cache_position, use_cache)

        hidden_states = self.final_norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)

        return BaseModelOutputWithNoAttention(
            last_hidden_state=hidden_states,
            hidden_states=all_hidden_states,
        )

    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
    # Ignore copy
    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
        """Build the additive causal mask, or return None for a non-causal model.

        The mask has shape (batch, 1, sequence_length, target_length), where target_length is the larger of
        `settings.TABLE_REC_MAX_BOXES` and the sequence length. Attended positions are 0 and masked positions
        hold the most negative value of the input dtype. A 2D `attention_mask` additionally masks its zero entries.
        """
        if not self.causal:
            return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        target_length = max(settings.TABLE_REC_MAX_BOXES, sequence_length)

        diagonal = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        causal_mask = diagonal
        if sequence_length != 1:
            # Select the upper triangular part of the matrix, but unmask current token (the diagonal)
            # triu will be the min_dtype, everything else is 0 (attended to)
            causal_mask = torch.triu(diagonal, diagonal=1)

        # Keep the masking value only in columns strictly after each row's cache position; zero the rest.
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            if attention_mask.dim() == 2:
                # Mask positions in the causal mask that are masked in the attention mask
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)

        if attention_mask is not None and attention_mask.device.type == "cuda":
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask


class SuryaTableRecDecoder(SuryaTableRecDecoderPreTrainedModel):
    """Decoder with bounding-box and class prediction heads.

    Wraps a ``SuryaTableRecDecoderModel`` (built with ``embed_labels=True`` and
    ``embed_positions=False``) and applies two bias-free linear heads to its
    hidden states:

    * ``bbox_head``: ``hidden_size -> max_width * 4``, reshaped in ``forward``
      to ``(batch, seq_len, 4, max_width)``.
    * ``class_head``: ``hidden_size -> max_classes``.
    """

    _tied_weights_keys = None

    def __init__(self, config, **kwargs):
        """Create the decoder backbone and both prediction heads from ``config``.

        Extra ``**kwargs`` are accepted but not used.
        """
        super().__init__(config)
        self.model = SuryaTableRecDecoderModel(config, embed_labels=True, embed_positions=False)
        self.vocab_size = config.vocab_size

        self.bbox_head = nn.Linear(config.hidden_size, config.max_width * 4, bias=False)
        self.class_head = nn.Linear(config.hidden_size, config.max_classes, bias=False)
        self.max_width = config.max_width

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        """Return the backbone's token embedding module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the backbone's token embedding module with ``value``."""
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        """Return the output embedding module, or ``None`` if there is none.

        ``__init__`` never creates ``lm_head`` (this class only has
        ``bbox_head`` and ``class_head``), so a plain attribute access would
        raise ``AttributeError``. ``None`` is returned instead unless
        ``set_output_embeddings`` has stored a module.
        """
        return getattr(self, "lm_head", None)

    def set_output_embeddings(self, new_embeddings):
        """Store ``new_embeddings`` as ``lm_head``.

        Note that ``forward`` does not read ``lm_head``; it only uses
        ``bbox_head`` and ``class_head``.
        """
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        """Replace the wrapped decoder backbone."""
        self.model = decoder

    def get_decoder(self):
        """Return the wrapped decoder backbone."""
        return self.model

    # Ignore copy
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        prefill: bool = False,
        **kwargs
    ) -> Union[Tuple, TableRecModelOutput]:
        """Run the backbone and project its hidden states through both heads.

        All named arguments are forwarded unchanged to ``self.model``; extra
        ``**kwargs`` are accepted but ignored.

        Returns:
            A ``TableRecModelOutput`` with ``bbox_logits`` of shape
            ``(batch, seq_len, 4, max_width)``, ``class_logits`` as produced by
            ``class_head``, and the backbone ``hidden_states`` the heads were
            applied to.
        """
        outputs = self.model(
            input_ids=input_ids,
            cache_position=cache_position,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=use_cache,
            output_hidden_states=True,
            return_dict=True,
            prefill=prefill,
        )

        # First output element; presumably the last hidden state -- confirm
        # against SuryaTableRecDecoderModel's return type.
        hidden_states = outputs[0]
        bbox_logits = self.bbox_head(hidden_states)
        class_logits = self.class_head(hidden_states)
        bsz, seq_len = class_logits.shape[:2]
        # Split the flat bbox projection into 4 groups of max_width logits.
        bbox_logits = bbox_logits.view(bsz, seq_len, 4, self.max_width)

        return TableRecModelOutput(
            bbox_logits=bbox_logits,
            class_logits=class_logits,
            hidden_states=hidden_states,
        )
@dataclass
class TextEncoderOutput(CausalLMOutput):
    """Output container returned by ``SuryaTableRecTextEncoder.forward``.

    Extends ``CausalLMOutput`` and re-declares ``hidden_states`` as a single
    tensor that defaults to ``None``.
    """

    # Filled with the wrapped model's ``last_hidden_state`` by the text encoder.
    hidden_states: torch.FloatTensor = None


class SuryaTableRecTextEncoder(SuryaTableRecDecoderPreTrainedModel):
    """Encoder that runs box inputs through a ``SuryaTableRecDecoderModel``.

    The wrapped model is built with ``embed_labels=False`` and
    ``embed_positions=True``. ``forward`` returns only the model's last hidden
    state, wrapped in a ``TextEncoderOutput``.
    """

    config_class = SuryaTableRecTextEncoderConfig
    _tied_weights_keys = None

    def __init__(self, config, **kwargs):
        """Build the wrapped model from ``config``; ``**kwargs`` are unused."""
        super().__init__(config)
        self.model = SuryaTableRecDecoderModel(config, embed_labels=False, embed_positions=True)
        self.vocab_size = config.vocab_size

        # Initialize weights and apply final processing
        self.post_init()

    def get_decoder(self):
        """Return the wrapped model."""
        return self.model

    def set_decoder(self, decoder):
        """Swap in a different wrapped model."""
        self.model = decoder

    def get_input_embeddings(self):
        """Return the wrapped model's ``embed_tokens`` module."""
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the wrapped model's ``embed_tokens`` module."""
        self.model.embed_tokens = value

    # Ignore copy
    def forward(
        self,
        input_boxes: Optional[torch.LongTensor] = None,
        input_boxes_counts: Optional[torch.LongTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        **kwargs
    ) -> Union[Tuple, CausalLMOutput]:
        """Encode ``input_boxes`` and return the final hidden states.

        ``input_boxes`` is handed to the wrapped model as its ``input_ids``;
        the remaining named arguments are forwarded under their own names.
        Extra ``**kwargs`` are accepted but ignored.

        Returns:
            A ``TextEncoderOutput`` whose ``hidden_states`` is the wrapped
            model's ``last_hidden_state``.
        """
        model_kwargs = {
            "input_ids": input_boxes,
            "input_boxes_counts": input_boxes_counts,
            "cache_position": cache_position,
            "attention_mask": attention_mask,
            "encoder_hidden_states": encoder_hidden_states,
            "encoder_attention_mask": encoder_attention_mask,
            "use_cache": use_cache,
            "output_hidden_states": True,
            "return_dict": True,
        }
        encoded = self.model(**model_kwargs)

        final_hidden = encoded.last_hidden_state
        return TextEncoderOutput(hidden_states=final_hidden)