|
import copy
from typing import List, Optional

from transformers import PretrainedConfig, VisionEncoderDecoderConfig
|
|
|
|
|
class MagiConfig(PretrainedConfig):
    """Configuration object for the ``"magi"`` model type.

    The config stores three ``disable_*`` switches, one optional nested
    configuration per sub-model (detection, OCR, crop embedding) and the
    preprocessing settings that accompany them. Nested sub-model
    configurations given as plain dicts are converted into config objects;
    every other value is stored exactly as received.

    Args:
        disable_ocr: Stored verbatim as ``self.disable_ocr``.
        disable_crop_embeddings: Stored verbatim as
            ``self.disable_crop_embeddings``.
        disable_detections: Stored verbatim as ``self.disable_detections``.
        detection_model_config: Dict converted with
            ``PretrainedConfig.from_dict``, an already-built config object,
            or ``None``.
        ocr_model_config: Dict converted with
            ``VisionEncoderDecoderConfig.from_dict``, an already-built config
            object, or ``None``.
        crop_embedding_model_config: Dict converted with
            ``PretrainedConfig.from_dict``, an already-built config object,
            or ``None``.
        detection_image_preprocessing_config: Stored verbatim.
        ocr_pretrained_processor_path: Stored verbatim.
        crop_embedding_image_preprocessing_config: Stored verbatim.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "magi"

    def __init__(
        self,
        disable_ocr: bool = False,
        disable_crop_embeddings: bool = False,
        disable_detections: bool = False,
        detection_model_config: Optional[dict] = None,
        ocr_model_config: Optional[dict] = None,
        crop_embedding_model_config: Optional[dict] = None,
        detection_image_preprocessing_config: Optional[dict] = None,
        ocr_pretrained_processor_path: Optional[str] = None,
        crop_embedding_image_preprocessing_config: Optional[dict] = None,
        **kwargs,
    ):
        self.disable_ocr = disable_ocr
        self.disable_crop_embeddings = disable_crop_embeddings
        self.disable_detections = disable_detections

        # Each sub-config attribute ends up as either None or a config object.
        self.detection_model_config = self._build_sub_config(
            PretrainedConfig, detection_model_config
        )
        self.ocr_model_config = self._build_sub_config(
            VisionEncoderDecoderConfig, ocr_model_config
        )
        self.crop_embedding_model_config = self._build_sub_config(
            PretrainedConfig, crop_embedding_model_config
        )

        self.detection_image_preprocessing_config = detection_image_preprocessing_config
        self.ocr_pretrained_processor_path = ocr_pretrained_processor_path
        self.crop_embedding_image_preprocessing_config = crop_embedding_image_preprocessing_config

        # The base initializer runs last, after all attributes above are set.
        super().__init__(**kwargs)

    @staticmethod
    def _build_sub_config(config_cls, raw_config):
        """Return ``raw_config`` as a config object, or ``None`` if it is ``None``.

        Already-built ``PretrainedConfig`` instances are returned unchanged.
        Anything else is treated as a dict and handed to
        ``config_cls.from_dict``.
        """
        if raw_config is None or isinstance(raw_config, PretrainedConfig):
            return raw_config
        # Work on a deep copy: building a config from a dict can mutate the
        # mapping it is given (e.g. popping keys from nested dicts), and the
        # caller's dict must stay reusable for a second construction.
        return config_cls.from_dict(copy.deepcopy(raw_config))
|
|