# magi / configuration_magi.py
# ragavsachdeva's picture
# Upload model
# b36f7c2 verified
from transformers import PretrainedConfig, VisionEncoderDecoderConfig
from typing import List
class MagiConfig(PretrainedConfig):
    """Configuration container whose ``model_type`` is ``"magi"``.

    It stores three boolean switches, up to three nested sub-model
    configurations (detection, OCR, crop embedding) and the matching
    preprocessing settings. Sub-model configurations are passed in as
    plain dictionaries and rebuilt into configuration objects here; the
    remaining values are kept exactly as received.
    """

    model_type = "magi"

    def __init__(
        self,
        disable_ocr: bool = False,
        disable_crop_embeddings: bool = False,
        disable_detections: bool = False,
        detection_model_config: dict = None,
        ocr_model_config: dict = None,
        crop_embedding_model_config: dict = None,
        detection_image_preprocessing_config: dict = None,
        ocr_pretrained_processor_path: str = None,
        crop_embedding_image_preprocessing_config: dict = None,
        **kwargs,
    ):
        """Populate the configuration.

        Args:
            disable_ocr: Stored as-is; presumably turns the OCR component
                off (the flag is not consumed inside this class).
            disable_crop_embeddings: Stored as-is; presumably turns the
                crop-embedding component off.
            disable_detections: Stored as-is; presumably turns the
                detection component off.
            detection_model_config: Dictionary rebuilt with
                ``PretrainedConfig.from_dict``; ``None`` leaves the
                attribute set to ``None``.
            ocr_model_config: Dictionary rebuilt with
                ``VisionEncoderDecoderConfig.from_dict``; ``None`` leaves
                the attribute set to ``None``.
            crop_embedding_model_config: Dictionary rebuilt with
                ``PretrainedConfig.from_dict``; ``None`` leaves the
                attribute set to ``None``.
            detection_image_preprocessing_config: Stored unchanged.
            ocr_pretrained_processor_path: Stored unchanged.
            crop_embedding_image_preprocessing_config: Stored unchanged.
            **kwargs: Forwarded untouched to ``PretrainedConfig.__init__``.
        """
        # Feature switches are recorded verbatim.
        self.disable_ocr = disable_ocr
        self.disable_crop_embeddings = disable_crop_embeddings
        self.disable_detections = disable_detections

        # Each nested model config is either absent (None) or a dict that
        # gets turned into a configuration object.
        self.detection_model_config = (
            None
            if detection_model_config is None
            else PretrainedConfig.from_dict(detection_model_config)
        )
        self.ocr_model_config = (
            None
            if ocr_model_config is None
            else VisionEncoderDecoderConfig.from_dict(ocr_model_config)
        )
        self.crop_embedding_model_config = (
            None
            if crop_embedding_model_config is None
            else PretrainedConfig.from_dict(crop_embedding_model_config)
        )

        # Preprocessing settings are kept in the form they were supplied.
        self.detection_image_preprocessing_config = detection_image_preprocessing_config
        self.ocr_pretrained_processor_path = ocr_pretrained_processor_path
        self.crop_embedding_image_preprocessing_config = crop_embedding_image_preprocessing_config

        # The base-class initialiser runs last, receiving only the extra
        # keyword arguments.
        super().__init__(**kwargs)