from transformers import PretrainedConfig

from surya.settings import settings

BOX_DIM = 1024
SPECIAL_TOKENS = 7
MAX_ROWS = 384


class SuryaTableRecConfig(PretrainedConfig):
    model_type = "vision-encoder-decoder"
    is_composition = True

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        encoder_config = kwargs.pop("encoder")
        decoder_config = kwargs.pop("decoder")
        text_enc_config = kwargs.pop("text_encoder")

        self.encoder = encoder_config
        self.decoder = decoder_config
        self.text_encoder = text_enc_config
        self.is_encoder_decoder = True

        if isinstance(decoder_config, dict):
            self.decoder_start_token_id = decoder_config["bos_token_id"]
            self.pad_token_id = decoder_config["pad_token_id"]
            self.eos_token_id = decoder_config["eos_token_id"]
        else:
            self.decoder_start_token_id = decoder_config.bos_token_id
            self.pad_token_id = decoder_config.pad_token_id
            self.eos_token_id = decoder_config.eos_token_id


class DonutSwinTableRecConfig(PretrainedConfig):
    model_type = "donut-swin"

    attribute_map = {
        "num_attention_heads": "num_heads",
        "num_hidden_layers": "num_layers",
    }

    def __init__(
        self,
        image_size=(settings.TABLE_REC_IMAGE_SIZE["width"], settings.TABLE_REC_IMAGE_SIZE["height"]),
        patch_size=4,
        num_channels=3,
        embed_dim=128,
        depths=[2, 2, 14, 2],
        num_heads=[4, 8, 16, 32],
        num_kv_heads=[4, 8, 16, 32],
        window_size=8,
        mlp_ratio=4.0,
        qkv_bias=True,
        hidden_dropout_prob=0.0,
        attention_probs_dropout_prob=0.0,
        drop_path_rate=0.1,
        hidden_act="gelu",
        use_absolute_embeddings=True,
        initializer_range=0.02,
        layer_norm_eps=1e-5,
        encoder_length=1024,
        **kwargs,
    ):
        super().__init__(**kwargs)

        self.image_size = image_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.embed_dim = embed_dim
        self.depths = depths
        self.num_layers = len(depths)
        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias
        self.hidden_dropout_prob = hidden_dropout_prob
        self.attention_probs_dropout_prob = attention_probs_dropout_prob
        self.drop_path_rate = drop_path_rate
        self.hidden_act = hidden_act
        self.use_absolute_embeddings = use_absolute_embeddings
        self.layer_norm_eps = layer_norm_eps
        self.initializer_range = initializer_range
        # We set the hidden_size attribute in order to make Swin work with VisionEncoderDecoderModel.
        # This indicates the channel dimension after the last stage of the model.
        self.hidden_size = int(embed_dim * 2 ** (len(depths) - 1))
        self.encoder_length = encoder_length
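# Note on the derived size: with the defaults above, the channel dimension
# doubles after each of the four Swin stages, so
#   int(embed_dim * 2 ** (len(depths) - 1)) == int(128 * 2 ** 3) == 1024,
# which lines up with the decoder's default encoder_hidden_size below.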
class SuryaTableRecDecoderConfig(PretrainedConfig):
    model_type = "surya_tablerec"

    def __init__(
        self,
        num_hidden_layers=3,
        vocab_size=settings.TABLE_REC_MAX_ROWS + SPECIAL_TOKENS,
        hidden_size=512,
        intermediate_size=4 * 512,
        encoder_hidden_size=1024,
        num_attention_heads=8,
        lru_width=None,
        attention_window_size=16,
        conv1d_width=4,
        logits_soft_cap=30.0,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        hidden_activation="gelu_pytorch_tanh",
        rope_theta=10000.0,
        block_types=("attention",),
        cross_attn_layers=(0, 1, 2, 3),
        encoder_cross_attn_layers=(0, 1, 2, 3),
        self_attn_layers=(0, 1, 2, 3),
        global_attn_layers=(0, 1, 2, 3),
        attention_dropout=0.0,
        num_key_value_heads=4,
        attention_bias=False,
        w_init_variance_scale=0.01,
        init_std=0.02,
        tie_word_embeddings=False,
        aux_heads=0,  # How many n-token-ahead heads to add
        causal=True,
        max_classes=2 + SPECIAL_TOKENS,
        max_width=1024 + SPECIAL_TOKENS,
        max_height=1024 + SPECIAL_TOKENS,
        out_box_size=1024,
        **kwargs,
    ):
        self.num_hidden_layers = num_hidden_layers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.lru_width = lru_width if lru_width is not None else hidden_size
        self.attention_window_size = attention_window_size
        self.conv1d_width = conv1d_width
        self.logits_soft_cap = logits_soft_cap
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.block_types = list(block_types)
        self.hidden_activation = hidden_activation
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
        if self.num_key_value_heads > self.num_attention_heads:
            raise ValueError("`num_key_value_heads` must be smaller than or equal to `num_attention_heads`")
        self.cross_attn_layers = cross_attn_layers
        self.self_attn_layers = self_attn_layers
        self.global_attn_layers = global_attn_layers
        self.attention_dropout = attention_dropout
        self.attention_bias = attention_bias
        self.w_init_variance_scale = w_init_variance_scale
        self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers
        self.init_std = init_std
        self.tie_word_embeddings = tie_word_embeddings
        self.aux_heads = aux_heads
        self.encoder_hidden_size = encoder_hidden_size
        self.causal = causal
        self.encoder_cross_attn_layers = encoder_cross_attn_layers
        self.max_classes = max_classes
        self.max_width = max_width
        self.max_height = max_height
        self.out_box_size = out_box_size

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

    @property
    def layers_block_type(self):
        return (self.block_types * 100)[: self.num_hidden_layers]
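# Example: with the defaults above (block_types=("attention",),
# num_hidden_layers=3), layers_block_type expands to
# ["attention", "attention", "attention"], one entry per hidden layer.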
class SuryaTableRecTextEncoderConfig(PretrainedConfig):
    model_type = "surya_tablerec"

    def __init__(
        self,
        num_hidden_layers=4,
        vocab_size=settings.TABLE_REC_MAX_ROWS + SPECIAL_TOKENS,
        hidden_size=1024,
        intermediate_size=4 * 1024,
        encoder_hidden_size=1024,
        num_attention_heads=16,
        lru_width=None,
        attention_window_size=16,
        conv1d_width=4,
        logits_soft_cap=30.0,
        rms_norm_eps=1e-6,
        use_cache=True,
        pad_token_id=0,
        eos_token_id=1,
        bos_token_id=2,
        hidden_activation="gelu_pytorch_tanh",
        rope_theta=10000.0,
        block_types=("attention",),
        cross_attn_layers=(0, 1, 2, 3, 4, 5),
        self_attn_layers=(0, 1, 2, 3, 4, 5),
        global_attn_layers=(0, 1, 2, 3, 4, 5),
        attention_dropout=0.0,
        num_key_value_heads=16,
        attention_bias=False,
        w_init_variance_scale=0.01,
        init_std=0.02,
        tie_word_embeddings=False,
        causal=False,
        max_width=BOX_DIM + SPECIAL_TOKENS,
        max_height=BOX_DIM + SPECIAL_TOKENS,
        max_position_embeddings=1024,
        **kwargs,
    ):
        self.num_hidden_layers = num_hidden_layers
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_attention_heads = num_attention_heads
        self.lru_width = lru_width if lru_width is not None else hidden_size
        self.attention_window_size = attention_window_size
        self.conv1d_width = conv1d_width
        self.logits_soft_cap = logits_soft_cap
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.block_types = list(block_types)
        self.hidden_activation = hidden_activation
        self.head_dim = self.hidden_size // self.num_attention_heads
        self.num_key_value_heads = num_key_value_heads if num_key_value_heads is not None else num_attention_heads
        if self.num_key_value_heads > self.num_attention_heads:
            raise ValueError("`num_key_value_heads` must be smaller than or equal to `num_attention_heads`")
        self.cross_attn_layers = cross_attn_layers
        self.self_attn_layers = self_attn_layers
        self.global_attn_layers = global_attn_layers
        self.attention_dropout = attention_dropout
        self.attention_bias = attention_bias
        self.w_init_variance_scale = w_init_variance_scale
        self.final_w_init_variance_scale = 2.0 / self.num_hidden_layers
        self.init_std = init_std
        self.tie_word_embeddings = tie_word_embeddings
        self.encoder_hidden_size = encoder_hidden_size
        self.causal = causal
        self.max_width = max_width
        self.max_height = max_height
        self.max_position_embeddings = max_position_embeddings

        super().__init__(
            pad_token_id=pad_token_id,
            bos_token_id=bos_token_id,
            eos_token_id=eos_token_id,
            **kwargs,
        )

    @property
    def layers_block_type(self):
        return (self.block_types * 100)[: self.num_hidden_layers]
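if __name__ == "__main__":
    # Minimal usage sketch: build the three sub-configs with their defaults and
    # compose them the way SuryaTableRecConfig expects. The keyword names
    # ("encoder", "decoder", "text_encoder") are the ones popped in
    # SuryaTableRecConfig.__init__ above; everything else here is illustrative.
    encoder = DonutSwinTableRecConfig()
    decoder = SuryaTableRecDecoderConfig()
    text_encoder = SuryaTableRecTextEncoderConfig()

    config = SuryaTableRecConfig(
        encoder=encoder,
        decoder=decoder,
        text_encoder=text_encoder,
    )

    # The composite config mirrors the decoder's special token ids.
    assert config.decoder_start_token_id == decoder.bos_token_id
    assert config.pad_token_id == decoder.pad_token_id
    assert config.eos_token_id == decoder.eos_token_id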