# NOTE: the following lines are Hugging Face Hub web-page residue accidentally
# captured when this file was downloaded; kept as comments so the file parses:
# CarlosMalaga's picture
# Upload 201 files
# 2f044c1 verified
# raw / history blame / 2.73 kB
import contextlib
import tempfile
import torch
import transformers as tr
from relik.common.utils import is_package_available
# check if ORT is available
if is_package_available("onnxruntime"):
from optimum.onnxruntime import (
ORTModel,
ORTModelForCustomTasks,
ORTModelForSequenceClassification,
ORTOptimizer,
)
from optimum.onnxruntime.configuration import AutoOptimizationConfig
# from relik.retriever.pytorch_modules import PRECISION_MAP
def get_autocast_context(
    device: str | torch.device, precision: str
) -> contextlib.AbstractContextManager:
    """
    Build the autocast context manager to wrap a model forward pass.

    Args:
        device (`str | torch.device`):
            Device the model runs on (e.g. ``"cpu"``, ``"cuda:0"``,
            ``torch.device("cuda", 0)``).
        precision (`str`):
            Key into ``PRECISION_MAP`` selecting the autocast dtype.
            Raises ``KeyError`` if the key is unknown.

    Returns:
        `contextlib.AbstractContextManager`: a `torch.autocast` context when
        autocast is supported for the device/dtype combination, otherwise a
        no-op `contextlib.nullcontext`.
    """
    # torch.autocast only accepts bare device-type strings ("cpu", "cuda"),
    # so strip any device index suffix such as ":0".
    device_type = str(device).split(":")[0]

    # Imported lazily to avoid a circular import with
    # relik.retriever.pytorch_modules (see the commented-out import above).
    from relik.retriever.pytorch_modules import PRECISION_MAP

    dtype = PRECISION_MAP[precision]

    # autocast on CPU/MPS only supports bfloat16; for any other dtype on
    # those devices fall back to a no-op context manager.
    if device_type in ("cpu", "mps") and dtype != torch.bfloat16:
        return contextlib.nullcontext()
    return torch.autocast(device_type=device_type, dtype=dtype)
# def load_ort_optimized_hf_model(
# hf_model: tr.PreTrainedModel,
# provider: str = "CPUExecutionProvider",
#     ort_model_type: callable = ORTModelForCustomTasks,
# ) -> ORTModel:
# """
# Load an optimized ONNX Runtime HF model.
#
# Args:
# hf_model (`tr.PreTrainedModel`):
# The HF model to optimize.
# provider (`str`, optional):
# The ONNX Runtime provider to use. Defaults to "CPUExecutionProvider".
#
# Returns:
# `ORTModel`: The optimized HF model.
# """
# if isinstance(hf_model, ORTModel):
# return hf_model
# temp_dir = tempfile.mkdtemp()
# hf_model.save_pretrained(temp_dir)
# ort_model = ort_model_type.from_pretrained(
# temp_dir, export=True, provider=provider, use_io_binding=True
# )
# if is_package_available("onnxruntime"):
# optimizer = ORTOptimizer.from_pretrained(ort_model)
# optimization_config = AutoOptimizationConfig.O4()
# optimizer.optimize(save_dir=temp_dir, optimization_config=optimization_config)
# ort_model = ort_model_type.from_pretrained(
# temp_dir,
# export=True,
# provider=provider,
# use_io_binding=bool(provider == "CUDAExecutionProvider"),
# )
# return ort_model
# else:
#         raise ValueError("onnxruntime is not installed. Please install the serve extras with `pip install relik[serve]`.")