File size: 1,371 Bytes
a50b54b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 |
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
from argparse import Namespace
from typing import Dict, Any
import torch
from .adaptor_generic import GenericAdaptor, AdaptorBase
# Alias for a string-keyed mapping with arbitrary values; annotates the
# `adaptor_config` parameter of AdaptorRegistry.create_adaptor.
dict_t = Dict[str, Any]
# Alias for a string-keyed mapping of tensors; annotates the `state`
# parameter of AdaptorRegistry.create_adaptor.
# NOTE(review): presumably a model state dict (parameter name -> tensor) -- confirm against callers.
state_t = Dict[str, torch.Tensor]
class AdaptorRegistry:
    """Name-keyed table of adaptor factory callables.

    Factories are added with the :meth:`register_adaptor` decorator and
    invoked through :meth:`create_adaptor`. A name that has no registered
    factory is served by ``GenericAdaptor``.
    """

    def __init__(self):
        # Maps a registered name to the factory callable stored for it.
        self._registry = {}

    def register_adaptor(self, name):
        """Return a decorator that stores the decorated callable under *name*.

        The decorated callable is returned unchanged, so it remains usable
        under its own name after registration.

        Raises:
            ValueError: raised when the decorator is applied and *name*
                already has a factory registered.
        """

        def _register(factory):
            known = self._registry
            if name not in known:
                known[name] = factory
                return factory
            raise ValueError(f"Model '{name}' already registered")

        return _register

    def create_adaptor(self, name, main_config: Namespace, adaptor_config: dict_t, state: state_t) -> AdaptorBase:
        """Build an adaptor by calling the factory registered under *name*.

        When *name* has no registered factory, ``GenericAdaptor`` is called
        instead. In both cases the callable receives
        ``(main_config, adaptor_config, state)`` positionally and its result
        is returned as-is.
        """
        # One lookup: an unknown name falls through to the generic adaptor.
        factory = self._registry.get(name, GenericAdaptor)
        return factory(main_config, adaptor_config, state)
# Module-level AdaptorRegistry instance, created once when this module is imported.
# NOTE(review): presumably the shared registry that adaptor modules decorate their
# factories with (`@adaptor_registry.register_adaptor("name")`) -- confirm against callers.
adaptor_registry = AdaptorRegistry()
|