File size: 975 Bytes
cd16641 ce179d5 cd16641 7ff77f3 cd16641 7ff77f3 ce179d5 7ff77f3 ce179d5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import os
from typing import Optional
from transformers.modeling_utils import PretrainedConfig
class CaSEDConfig(PretrainedConfig):
    """Configuration class for CaSED.

    Args:
        index_name (str): Name of the index. Defaults to "cc12m".
        alpha (float): Weight of the vision loss. Defaults to 0.5.
        retrieval_num_results (int): Number of results to return. Defaults to 10.
        cache_dir (str, optional): Path to cache directory. A leading ``~`` is
            expanded to the user's home directory. When ``None`` or empty,
            defaults to "~/.cache/cased" (expanded).
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "cased"
    # NOTE(review): transformers uses ``is_composition`` to mark configs built
    # from sub-configs; confirm this is intended for CaSED.
    is_composition = True

    def __init__(
        self,
        index_name: str = "cc12m",
        alpha: float = 0.5,
        retrieval_num_results: int = 10,
        cache_dir: Optional[str] = None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.index_name = index_name
        self.alpha = alpha
        self.retrieval_num_results = retrieval_num_results
        # Expand "~" for caller-supplied paths as well as the default, so that
        # e.g. "~/my_cache" does not become a literal "~" directory under the
        # current working directory. Falsy values (None or "") fall back to the
        # default location, as before.
        self.cache_dir = os.path.expanduser(cache_dir or "~/.cache/cased")
|