import contextlib
import tempfile

import torch
import transformers as tr

from relik.common.utils import is_package_available

# check if ONNX Runtime is available before importing the optimum bindings
if is_package_available("onnxruntime"):
    from optimum.onnxruntime import (
        ORTModel,
        ORTModelForCustomTasks,
        ORTModelForSequenceClassification,
        ORTOptimizer,
    )
    from optimum.onnxruntime.configuration import AutoOptimizationConfig


def get_autocast_context(
    device: str | torch.device, precision: str
) -> contextlib.AbstractContextManager:
    """
    Build an autocast context manager for the given device and precision.

    Args:
        device (`str` | `torch.device`):
            The device the model runs on, e.g. "cpu", "cuda" or "cuda:0".
        precision (`str`):
            A key of `PRECISION_MAP` identifying the desired dtype.

    Returns:
        `contextlib.AbstractContextManager`: An autocast context manager, or a
        no-op context when autocast is not supported for the device/dtype pair.
    """
    # torch.autocast only accepts a bare device type such as "cpu" or "cuda",
    # so strip any device index (e.g. "cuda:0" -> "cuda")
    device_type_for_autocast = str(device).split(":")[0]

    # imported here to avoid a circular import
    from relik.retriever.pytorch_modules import PRECISION_MAP

    # on CPU and MPS, autocast only supports bfloat16; for any other dtype,
    # fall back to a no-op context manager
    autocast_manager = (
        contextlib.nullcontext()
        if device_type_for_autocast in ["cpu", "mps"]
        and PRECISION_MAP[precision] != torch.bfloat16
        else torch.autocast(
            device_type=device_type_for_autocast,
            dtype=PRECISION_MAP[precision],
        )
    )
    return autocast_manager


# Currently unused, kept for reference.
# def load_ort_optimized_hf_model(
#     hf_model: tr.PreTrainedModel,
#     provider: str = "CPUExecutionProvider",
#     ort_model_type: type[ORTModel] = ORTModelForCustomTasks,
# ) -> ORTModel:
#     """
#     Load an optimized ONNX Runtime HF model.
#
#     Args:
#         hf_model (`tr.PreTrainedModel`):
#             The HF model to optimize.
#         provider (`str`, optional):
#             The ONNX Runtime provider to use. Defaults to "CPUExecutionProvider".
#         ort_model_type (`type[ORTModel]`, optional):
#             The ORT class used to wrap the exported model.
#             Defaults to `ORTModelForCustomTasks`.
#
#     Returns:
#         `ORTModel`: The optimized HF model.
#     """
#     if not is_package_available("onnxruntime"):
#         raise ValueError(
#             "onnxruntime is not installed. Please install it with `pip install relik[serve]`."
#         )
#     if isinstance(hf_model, ORTModel):
#         return hf_model
#     # export the HF model to ONNX in a temporary directory
#     temp_dir = tempfile.mkdtemp()
#     hf_model.save_pretrained(temp_dir)
#     ort_model = ort_model_type.from_pretrained(temp_dir, export=True, provider=provider)
#     # apply ONNX Runtime graph optimizations (O4 is the most aggressive level)
#     optimizer = ORTOptimizer.from_pretrained(ort_model)
#     optimization_config = AutoOptimizationConfig.O4()
#     optimizer.optimize(save_dir=temp_dir, optimization_config=optimization_config)
#     # reload the optimized graph (ORTOptimizer saves it as model_optimized.onnx
#     # by default); IO binding only pays off on CUDA
#     ort_model = ort_model_type.from_pretrained(
#         temp_dir,
#         file_name="model_optimized.onnx",
#         provider=provider,
#         use_io_binding=(provider == "CUDAExecutionProvider"),
#     )
#     return ort_model
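

if __name__ == "__main__":
    # Minimal usage sketch for `get_autocast_context`. NOTE: the precision
    # keys "fp16" and "fp32" below are assumptions -- check PRECISION_MAP in
    # relik.retriever.pytorch_modules for the exact keys it accepts.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    precision = "fp16" if device == "cuda" else "fp32"
    model = torch.nn.Linear(8, 2).to(device)
    inputs = torch.randn(4, 8, device=device)
    with torch.inference_mode(), get_autocast_context(device, precision):
        logits = model(inputs)
    # on CUDA the matmul runs in float16 under autocast; on CPU the context
    # is a no-op and the output stays float32
    print(logits.dtype, logits.shape)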