File size: 1,672 Bytes
b36f7c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from typing import List, Optional, Union

from transformers import PretrainedConfig, VisionEncoderDecoderConfig


class MagiConfig(PretrainedConfig):
    """Composite configuration for the "magi" model.

    Bundles three sub-model configurations (detection, OCR, crop embedding),
    their preprocessing settings, and per-component disable flags. Sub-model
    configurations may be supplied either as plain dicts (e.g. as read back
    from a serialized config) or as already-constructed config objects.
    """

    model_type = "magi"

    def __init__(
        self,
        disable_ocr: bool = False,
        disable_crop_embeddings: bool = False,
        disable_detections: bool = False,
        detection_model_config: Optional[Union[dict, PretrainedConfig]] = None,
        ocr_model_config: Optional[Union[dict, VisionEncoderDecoderConfig]] = None,
        crop_embedding_model_config: Optional[Union[dict, PretrainedConfig]] = None,
        detection_image_preprocessing_config: Optional[dict] = None,
        ocr_pretrained_processor_path: Optional[str] = None,
        crop_embedding_image_preprocessing_config: Optional[dict] = None,
        **kwargs,
    ):
        """Build the configuration.

        Args:
            disable_ocr: Flag stored on the config; it is not interpreted
                here (presumably read by the model to skip OCR).
            disable_crop_embeddings: Flag stored on the config; not
                interpreted here.
            disable_detections: Flag stored on the config; not interpreted
                here.
            detection_model_config: Dict or ``PretrainedConfig`` for the
                detection sub-model, or ``None``. A dict is converted with
                ``PretrainedConfig.from_dict``.
            ocr_model_config: Dict or ``VisionEncoderDecoderConfig`` for the
                OCR sub-model, or ``None``. A dict is converted with
                ``VisionEncoderDecoderConfig.from_dict``.
            crop_embedding_model_config: Dict or ``PretrainedConfig`` for the
                crop-embedding sub-model, or ``None``. A dict is converted
                with ``PretrainedConfig.from_dict``.
            detection_image_preprocessing_config: Stored unchanged.
            ocr_pretrained_processor_path: Stored unchanged (presumably a
                path or hub id for the OCR processor — confirm with callers).
            crop_embedding_image_preprocessing_config: Stored unchanged.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.
        """
        self.disable_ocr = disable_ocr
        self.disable_crop_embeddings = disable_crop_embeddings
        self.disable_detections = disable_detections

        self.detection_model_config = self._to_sub_config(
            detection_model_config, PretrainedConfig
        )
        self.ocr_model_config = self._to_sub_config(
            ocr_model_config, VisionEncoderDecoderConfig
        )
        self.crop_embedding_model_config = self._to_sub_config(
            crop_embedding_model_config, PretrainedConfig
        )

        self.detection_image_preprocessing_config = detection_image_preprocessing_config
        self.ocr_pretrained_processor_path = ocr_pretrained_processor_path
        self.crop_embedding_image_preprocessing_config = crop_embedding_image_preprocessing_config
        # Set our own attributes first, then let the base class consume the
        # remaining kwargs (the original order is preserved).
        super().__init__(**kwargs)

    @staticmethod
    def _to_sub_config(value, config_cls):
        """Return ``value`` as a ``config_cls`` instance, or ``None``.

        ``None`` passes through unchanged. An existing ``config_cls`` instance
        is used as-is. Any other value is treated as a dict and built with
        ``config_cls.from_dict``, matching the original dict-only behavior.
        """
        if value is None:
            return None
        if isinstance(value, config_cls):
            return value
        return config_cls.from_dict(value)