import logging
import math
import os
from contextlib import contextmanager

import timm.models.hub as timm_hub
import torch
import torch.distributed as dist
import torch.nn as nn
from transformers import StoppingCriteria, StoppingCriteriaList


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0


def download_cached_file(url, check_hash=True, progress=False):
    """
    Download a file from a URL and cache it locally. If the file already exists, it is not downloaded again.
    If distributed, only the main process downloads the file, and the other processes wait for the file to be downloaded.
    """

    def get_cached_file_path():
        # Resolve the same cache path timm uses, so every rank can locate the
        # file downloaded by the main process.
        parts = torch.hub.urlparse(url)
        filename = os.path.basename(parts.path)
        cached_file = os.path.join(timm_hub.get_cache_dir(), filename)

        return cached_file

    if is_main_process():
        timm_hub.download_cached_file(url, check_hash, progress)

    if is_dist_avail_and_initialized():
        dist.barrier()

    return get_cached_file_path()
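

# Illustrative usage sketch (not part of the original module): every rank calls
# download_cached_file with the same URL; only the main process downloads, the
# others block on the barrier and then load the cached copy. The URL below is a
# placeholder, not a checkpoint referenced by this repository.
#
#   cached_path = download_cached_file(
#       "https://example.com/checkpoints/vit_base.pth", check_hash=False, progress=True
#   )
#   state_dict = torch.load(cached_path, map_location="cpu")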


@contextmanager
def all_logging_disabled(highest_level=logging.CRITICAL):
    """
    A context manager that will prevent any logging messages
    triggered during the body from being processed.
    :param highest_level: the maximum logging level in use.
        This would only need to be changed if a custom level greater than CRITICAL
        is defined.
    """
    # logging.root.manager.disable is undocumented but stable: it holds the
    # threshold last set via logging.disable(), so it can be restored on exit.
    previous_level = logging.root.manager.disable

    logging.disable(highest_level)

    try:
        yield
    finally:
        logging.disable(previous_level)
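

# Illustrative usage sketch (not part of the original module): suppress noisy
# third-party logging while loading a model, then restore the previous level.
# The checkpoint name is only an example.
#
#   from transformers import AutoModel
#
#   with all_logging_disabled():
#       model = AutoModel.from_pretrained("bert-base-uncased")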


class LoRALinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        lora_r=8,
        lora_alpha=16,
        lora_dropout=0.05,
        **kwargs
    ) -> None:
        super().__init__(in_features, out_features, bias, device, dtype)
        self.lora_r = lora_r
        self.lora_alpha = lora_alpha
        if lora_dropout > 0.0:
            self.lora_dropout = nn.Dropout(p=lora_dropout)
        else:
            self.lora_dropout = nn.Identity()
        self.lora_scaling = self.lora_alpha / self.lora_r

        # Low-rank adapter: A projects down to rank r, B projects back up.
        self.lora_A = nn.Linear(
            in_features, self.lora_r, bias=False, device=device, dtype=dtype
        )
        self.lora_B = nn.Linear(
            self.lora_r, out_features, bias=False, device=device, dtype=dtype
        )

        self.reset_parameters()

    def reset_parameters(self):
        # nn.Linear.__init__ calls reset_parameters() before the LoRA layers
        # exist, so guard on their presence; the base weights are initialized
        # as usual, and lora_B starts at zero so the adapter is a no-op until
        # it is trained.
        super().reset_parameters()
        if hasattr(self, "lora_A"):
            nn.init.kaiming_uniform_(self.lora_A.weight, a=math.sqrt(5))
            nn.init.zeros_(self.lora_B.weight)

    def forward(self, x):
        orig_type = x.dtype
        res = super().forward(x)
        # Run the LoRA branch in fp32 for stability, then cast the sum back to
        # the input dtype (an out-of-place add avoids in-place dtype-promotion
        # errors when the base layer runs in half precision).
        x = x.float()
        res = res + self.lora_B(self.lora_A(self.lora_dropout(x))) * self.lora_scaling
        return res.to(orig_type)
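

# Illustrative usage sketch (not part of the original module): swap a frozen
# projection layer for a LoRALinear of the same shape and train only the
# adapter weights. `model.proj` is a hypothetical attribute used for
# illustration; copy the bias as well if the original layer has one.
#
#   lora_proj = LoRALinear(model.proj.in_features, model.proj.out_features, lora_r=8)
#   lora_proj.weight.data.copy_(model.proj.weight.data)
#   for name, param in lora_proj.named_parameters():
#       param.requires_grad = "lora_" in name
#   model.proj = lora_proj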


class StoppingCriteriaSub(StoppingCriteria):
    def __init__(self, stops=None, encounters=1):
        super().__init__()
        # `encounters` is accepted for API compatibility but is not used here.
        self.stops = stops if stops is not None else []

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
        # Stop generation once any stop sequence matches the tail of every
        # sequence in the batch.
        for stop in self.stops:
            if torch.all(stop == input_ids[:, -len(stop):]).item():
                return True

        return False
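

# Illustrative usage sketch (not part of the original module): stop generation
# when a given token sequence is produced. `model` and `input_ids` are assumed
# to be a Hugging Face causal LM and its tokenized prompt; the token ids below
# are placeholders.
#
#   device = input_ids.device
#   stop_words_ids = [torch.tensor([835], device=device),
#                     torch.tensor([2277, 29937], device=device)]
#   stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])
#   output = model.generate(input_ids, stopping_criteria=stopping_criteria)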