import os | |
from typing import Optional | |
from transformers.modeling_utils import PretrainedConfig | |
class CaSEDConfig(PretrainedConfig):
    """Configuration class for CaSED.

    Args:
        index_name (str): Name of the index. Defaults to "cc12m".
        alpha (float): Weight of the vision loss. Defaults to 0.5.
        retrieval_num_results (int): Number of results to return. Defaults to 10.
        cache_dir (str): Path to cache directory. A leading "~" is expanded to
            the user's home directory, and the stored value is always a str.
            A falsy value (None or "") falls back to the default.
            Defaults to "~/.cache/cased".
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "cased"
    # NOTE(review): no sub-configs are visible in this class; confirm that
    # composition semantics are actually intended here.
    is_composition = True

    def __init__(
        self,
        index_name: str = "cc12m",
        alpha: float = 0.5,
        retrieval_num_results: int = 10,
        cache_dir: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.index_name = index_name
        self.alpha = alpha
        self.retrieval_num_results = retrieval_num_results
        # Expand "~" for caller-supplied paths as well as the default, so a
        # value like "~/x" is never stored literally. expanduser also accepts
        # os.PathLike objects and returns a str.
        self.cache_dir = os.path.expanduser(cache_dir or "~/.cache/cased")