Spaces:
Running
Running
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
class Registry:
    """Registry for task classes, processor classes, paths and shared state.

    All data is kept in the class-level ``mapping`` dict, so the class and
    every instance of it read and write the same registrations. Every method
    is a classmethod and can be called either on the class or on an instance.
    """

    # Shared storage, one sub-dict per category:
    #   processor_name_mapping: name -> processor class
    #   task_name_mapping:      name -> task class
    #   state:                  nested dicts addressed with dotted keys
    #   paths:                  name -> path string
    mapping = {
        "processor_name_mapping": {},
        "task_name_mapping": {},
        "state": {},
        "paths": {},
    }

    @classmethod
    def register_task(cls, name):
        r"""Return a class decorator that registers a task under ``name``.

        Args:
            name: Key with which the task will be registered.

        Returns:
            A decorator that records the decorated class and returns it
            unchanged.

        Raises:
            AssertionError: If the decorated class is not a ``BaseTask``
                subclass.
            KeyError: If ``name`` is already registered as a task.

        Usage::

            from speakers.common.registry import registry

            @registry.register_task("my_task")
            class MyTask(BaseTask):
                ...
        """
        print(f"from speakers.common.registry import registry {name}")

        def wrap(task_cls):
            # Imported lazily, at decoration time rather than at module import.
            from speakers.tasks.base_task import BaseTask

            assert issubclass(
                task_cls, BaseTask
            ), "All tasks must inherit BaseTask class"
            if name in cls.mapping["task_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["task_name_mapping"][name]
                    )
                )
            cls.mapping["task_name_mapping"][name] = task_cls
            return task_cls

        return wrap

    @classmethod
    def register_processor(cls, name):
        r"""Return a class decorator that registers a processor under ``name``.

        Args:
            name: Key with which the processor will be registered.

        Returns:
            A decorator that records the decorated class and returns it
            unchanged.

        Raises:
            AssertionError: If the decorated class is not a ``BaseProcessor``
                subclass.
            KeyError: If ``name`` is already registered as a processor.

        Usage::

            from speakers.common.registry import registry

            @registry.register_processor("my_processor")
            class MyProcessor(BaseProcessor):
                ...
        """
        print(f"from speakers.common.registry import registry {name}")

        def wrap(processor_cls):
            # Imported lazily, at decoration time rather than at module import.
            from speakers.processors import BaseProcessor

            assert issubclass(
                processor_cls, BaseProcessor
            ), "All processors must inherit BaseProcessor class"
            if name in cls.mapping["processor_name_mapping"]:
                raise KeyError(
                    "Name '{}' already registered for {}.".format(
                        name, cls.mapping["processor_name_mapping"][name]
                    )
                )
            cls.mapping["processor_name_mapping"][name] = processor_cls
            return processor_cls

        return wrap

    @classmethod
    def register_path(cls, name, path):
        r"""Register a path string under ``name``.

        Args:
            name: Key with which the path will be registered.
            path: The path to store; must be a ``str``.

        Raises:
            AssertionError: If ``path`` is not a ``str``.
            KeyError: If ``name`` is already registered as a path.

        Usage::

            from speakers.common.registry import registry

            registry.register_path("cache_root", "/tmp/cache")
        """
        assert isinstance(path, str), "All path must be str."
        if name in cls.mapping["paths"]:
            raise KeyError("Name '{}' already registered.".format(name))
        cls.mapping["paths"][name] = path

    @classmethod
    def register(cls, name, obj):
        r"""Store ``obj`` in the shared state under the dotted key ``name``.

        Each ``.`` in ``name`` descends one level; missing intermediate levels
        are created as empty dicts. An existing value at the final key is
        overwritten.

        Args:
            name: Key (optionally dotted) with which the item will be
                registered.
            obj: The item to store.

        Usage::

            from speakers.common.registry import registry

            registry.register("config", {})
        """
        *parents, leaf = name.split(".")
        current = cls.mapping["state"]
        for part in parents:
            if part not in current:
                current[part] = {}
            current = current[part]

        current[leaf] = obj
        print(f" Key with which the item will be registered {current}")

    @classmethod
    def get_processor_class(cls, name):
        """Return the processor class registered as ``name``, or ``None``."""
        return cls.mapping["processor_name_mapping"].get(name, None)

    @classmethod
    def get_task_class(cls, name):
        """Return the task class registered as ``name``, or ``None``."""
        return cls.mapping["task_name_mapping"].get(name, None)

    @classmethod
    def list_processors(cls):
        """Return the sorted list of registered processor names."""
        return sorted(cls.mapping["processor_name_mapping"].keys())

    @classmethod
    def list_tasks(cls):
        """Return the sorted list of registered task names."""
        return sorted(cls.mapping["task_name_mapping"].keys())

    @classmethod
    def get_path(cls, name):
        """Return the path registered as ``name``, or ``None``."""
        return cls.mapping["paths"].get(name, None)

    @classmethod
    def get(cls, name, default=None, no_warning=False):
        r"""Get an item from the shared state with the dotted key ``name``.

        Args:
            name (string): Key whose value needs to be retrieved.
            default: Value returned when the key is not in the registry.
                Default: None
            no_warning (bool): If passed as True, no warning is emitted when
                the key doesn't exist. Default: False

        Returns:
            The stored value, or ``default`` when any component of ``name``
            is missing. When the key is missing, ``no_warning`` is False and
            a ``"writer"`` entry exists in the state, its ``warning`` method
            is called with an explanatory message.
        """
        # A private sentinel distinguishes "key absent" from "stored value
        # happens to equal the default", so the warning only fires for
        # lookups that really missed.
        missing = object()

        value = cls.mapping["state"]
        for subname in name.split("."):
            # Only descend into objects that offer a dict-style ``get``; a
            # plain value (int, str, None, ...) cannot hold further keys.
            getter = getattr(value, "get", None)
            value = getter(subname, missing) if callable(getter) else missing
            if value is missing:
                break

        if value is not missing:
            return value

        if "writer" in cls.mapping["state"] and no_warning is False:
            cls.mapping["state"]["writer"].warning(
                "Key {} is not present in registry, returning default value "
                "of {}".format(name, default)
            )
        return default

    @classmethod
    def unregister(cls, name):
        r"""Remove an item from the shared state with the dotted key ``name``.

        ``name`` is resolved the same way as in ``register``: every ``.``
        descends one level into the nested state.

        Args:
            name: Key (optionally dotted) which needs to be removed.

        Returns:
            The removed value, or ``None`` if nothing was stored under
            ``name``.

        Usage::

            from speakers.common.registry import registry

            config = registry.unregister("config")
        """
        *parents, leaf = name.split(".")
        container = cls.mapping["state"]
        for part in parents:
            container = container.get(part)
            # A missing or non-dict level means the key was never registered.
            if not isinstance(container, dict):
                return None
        return container.pop(leaf, None)
# Module-level instance that callers import to access the registry.
registry = Registry()