# Copyright (C) 2024 Charles O. Goddard
#
# This software is free software: you can redistribute it and/or
# modify it under the terms of the GNU Lesser General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This software is distributed in the hope that it will be useful, but
# WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
# Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with this program. If not, see http://www.gnu.org/licenses/.

import os
from abc import ABC, abstractmethod
from typing import Dict, Optional, Sequence, Union

import safetensors
import torch

from mergekit.io.lazy_unpickle import DeferredLoad, TorchArchiveReader, torch_lazy_load


class TensorLoader(ABC):
    """Base class for (potentially lazy) tensor loaders."""

    @abstractmethod
    def get_tensor(self, key: str) -> torch.Tensor:
        """Return the tensor stored under ``key``."""
        ...

    @abstractmethod
    def keys(self) -> Sequence[str]:
        """Return the names of all tensors available from this loader."""
        ...

    @classmethod
    def get(
        cls,
        shard_path: Union[str, os.PathLike],
        use_lazy_unpickle: bool = False,
        device: Optional[str] = None,
    ) -> "TensorLoader":
        """Open a checkpoint shard with a loader chosen by its file type.

        Args:
            shard_path: Path to the shard file. A ``str`` or any path-like
                object (e.g. ``pathlib.Path``) is accepted.
            use_lazy_unpickle: For non-safetensors files, use
                ``LazyPickleLoader`` instead of ``DumbPytorchLoader``.
            device: Device forwarded to the chosen loader. For
                ``.safetensors`` files, ``None`` falls back to ``"cpu"``.

        Returns:
            An object exposing ``get_tensor`` and ``keys``. For files ending
            in ``.safetensors`` (case-insensitive) this is the handle from
            ``safetensors.safe_open``, which is not a subclass of
            ``TensorLoader`` but exposes the same API.
        """
        # Normalize path-like objects so the string checks below also work
        # for ``pathlib.Path`` inputs; a plain ``str`` is returned unchanged.
        shard_path = os.fspath(shard_path)
        if shard_path.lower().endswith(".safetensors"):
            # not a subclass of TensorLoader, but exposes same api
            return safetensors.safe_open(
                shard_path, framework="pt", device=device or "cpu"
            )
        if use_lazy_unpickle:
            return LazyPickleLoader(shard_path, device=device)
        return DumbPytorchLoader(shard_path, device=device)


class LazyPickleLoader(TensorLoader):
    """Loader for pytorch files using a custom unpickler and vigorous monkeypatching.

    Construction runs ``torch.load`` inside ``torch_lazy_load()`` to build an
    index mapping each key to a ``DeferredLoad``. Each ``get_tensor`` call
    then executes one of those handles against the archive opened with
    ``TorchArchiveReader``.
    """

    zip_reader: TorchArchiveReader
    index: Dict[str, DeferredLoad]
    device: Optional[str] = None

    def __init__(self, path: str, device: Optional[str] = None):
        self.zip_reader = TorchArchiveReader(path)
        self.device = device
        # NOTE(review): this is a full ``torch.load`` unpickle of the file
        # (no ``weights_only``); loading an untrusted checkpoint this way can
        # execute arbitrary code. Whether ``weights_only=True`` is compatible
        # with ``torch_lazy_load`` is not visible from this module — confirm.
        with torch_lazy_load():
            self.index = torch.load(path)

    def get_tensor(self, key: str) -> torch.Tensor:
        """Load the tensor for ``key``, mapped to ``self.device``.

        Raises:
            KeyError: if ``key`` is not present in the index.
        """
        # Keep the lookup narrow so a KeyError raised inside ``execute`` is
        # not mistaken for a missing key.
        try:
            deferred = self.index[key]
        except KeyError:
            raise KeyError(key) from None
        return deferred.execute(self.zip_reader, map_location=self.device)

    def keys(self) -> Sequence[str]:
        """Return the tensor names in the index as a list."""
        # A dict keys view is not a Sequence (it cannot be indexed), so
        # return a list to honour the declared ``Sequence[str]`` contract.
        return list(self.index)


class DumbPytorchLoader(TensorLoader):
    """Naive `torch.load` shard loading.

    The whole file is loaded into memory when the loader is constructed.
    """

    tensors: Dict[str, torch.Tensor]

    def __init__(self, path: str, device: Optional[str] = None):
        # ``weights_only=True`` restricts unpickling to tensors and plain
        # data containers instead of arbitrary Python objects.
        self.tensors = torch.load(path, map_location=device, weights_only=True)

    def get_tensor(self, key: str) -> torch.Tensor:
        """Return the already-loaded tensor for ``key`` (KeyError if absent)."""
        return self.tensors[key]

    def keys(self) -> Sequence[str]:
        """Return the loaded tensor names as a list."""
        # List rather than a keys view so the result is a real Sequence.
        return list(self.tensors)