# Spaces: Running on Zero  (hosting-page status banner captured with this file; not part of the module)
# -*- coding: utf-8 -*- | |
# Copyright (c) Alibaba, Inc. and its affiliates. | |
from inspect import isfunction | |
import torch | |
from torch.nn.utils.rnn import pad_sequence | |
from scepter.modules.utils.distribute import we | |
def exists(x):
    """Return True when ``x`` is anything other than None."""
    if x is None:
        return False
    return True
def default(val, d):
    """Return ``val`` unless it is None, otherwise fall back to ``d``.

    Args:
        val: Preferred value; returned whenever it is not None (falsy values
            such as 0 or '' are still returned).
        d: Fallback. A plain Python function (per ``inspect.isfunction``) is
            called with no arguments and its result returned; anything else
            is returned as-is.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d
def disabled_train(self, mode=True):
    """No-op stand-in for ``train``: ignore ``mode`` and return ``self``.

    Meant to be assigned over a model's ``train`` method so that later
    train/eval switches no longer change the model's mode.
    """
    return self
def transfer_size(para_num):
    """Format a count with a K/M/B/T suffix (powers of 1000).

    Args:
        para_num: Number to format, e.g. a model's parameter count.

    Returns:
        A string with two decimals and a unit suffix when ``para_num`` is at
        least 1000 (``'1.50K'``, ``'2.00M'``, ``'3.25B'``, ``'1.00T'``).
        Smaller values are returned unchanged and unformatted, so the result
        is not always a string.
    """
    # Ordered largest first so the first unit that fits wins.
    units = (
        (1000 ** 4, 'T'),
        (1000 ** 3, 'B'),
        (1000 ** 2, 'M'),
        (1000, 'K'),
    )
    for scale, suffix in units:
        # ``>=`` so an exact power of 1000 is reported in its own unit:
        # 1_000_000 is '1.00M', not '1000.00K'.
        if para_num >= scale:
            return '{:.2f}{}'.format(para_num / scale, suffix)
    return para_num
def count_params(model):
    """Return the total parameter count of ``model`` in readable form.

    Adds up ``numel()`` over everything yielded by ``model.parameters()`` and
    hands the total to ``transfer_size`` for formatting.
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return transfer_size(total)
def expand_dims_like(x, y):
    """Append trailing size-1 dimensions to ``x`` until it matches ``y``.

    Args:
        x: Tensor to expand.
        y: Tensor whose number of dimensions is the target.

    Returns:
        ``x`` with singleton dimensions appended at the end so that
        ``x.dim() == y.dim()``. If ``x`` already has at least as many
        dimensions as ``y`` it is returned unchanged.
    """
    # ``<`` rather than ``!=``: unsqueeze can only add dimensions, so a
    # ``!=`` test never terminates when x starts with more dims than y.
    while x.dim() < y.dim():
        x = x.unsqueeze(-1)
    return x
def unpack_tensor_into_imagelist(image_tensor, shapes):
    """Split a packed batch back into one tensor per image.

    ``image_tensor`` and ``shapes`` are walked in lockstep. For each pair the
    first ``h * w`` positions along the entry's second axis are kept (the
    rest is dropped) and the slice is viewed as ``(1, -1, h, w)``.

    Args:
        image_tensor: Iterable of 2-D entries, indexed as ``[:, :h * w]``.
        shapes: Iterable of indexable ``(h, w)`` pairs, one per entry.

    Returns:
        A list with one ``(1, -1, h, w)`` tensor per input pair.
    """
    return [
        packed[:, :shape[0] * shape[1]].view(1, -1, shape[0], shape[1])
        for packed, shape in zip(image_tensor, shapes)
    ]
def find_example(tensor_list, image_list):
    """Build an all-zero placeholder modelled on an existing entry.

    ``tensor_list`` is searched first: the first ``torch.Tensor`` found there
    is mirrored with ``zeros_like``. Otherwise the first tensor in
    ``image_list`` is unpacked as 4-D ``(_, c, h, w)``, viewed as
    ``(c, h * w)``, transposed to ``(h * w, c)`` and mirrored the same way.

    Returns:
        The zero tensor, or None when neither list contains a tensor.
    """
    packed = next(
        (t for t in tensor_list if isinstance(t, torch.Tensor)), None)
    if packed is not None:
        return torch.zeros_like(packed)

    image = next(
        (t for t in image_list if isinstance(t, torch.Tensor)), None)
    if image is None:
        return None
    _, channels, height, width = image.size()
    flat = image.view(channels, height * width).transpose(1, 0)
    return torch.zeros_like(flat)
def pack_imagelist_into_tensor_v2(image_list):
    """Pack images of differing sizes into one zero-padded tensor.

    Each non-None image is unpacked as 4-D ``(_, c, h, w)`` and flattened to
    ``(h * w, c)``; the ``view(c, h * w)`` only succeeds when the leading
    dimension is 1. None entries are allowed: they are replaced by a zero
    placeholder obtained from ``find_example`` (looked up once and reused)
    and recorded with a None shape.

    Args:
        image_list: List of 4-D tensors and/or None.

    Returns:
        A tuple ``(image_tensor, shapes)``: the padded batch laid out as
        ``(b, c, l)`` and a list holding ``(h, w)`` per image, or None for
        entries that were None.
    """
    flattened = []
    shapes = []
    placeholder = None
    for image in image_list:
        if image is not None:
            _, channels, height, width = image.size()
            # h*w, c
            flattened.append(
                image.view(channels, height * width).transpose(1, 0))
            shapes.append((height, width))
        else:
            # Look the placeholder up lazily, and only until one is found.
            if placeholder is None:
                placeholder = find_example(flattened, image_list)
            flattened.append(placeholder)
            shapes.append(None)
    padded = pad_sequence(flattened, batch_first=True)  # b, l, c
    return padded.permute(0, 2, 1), shapes  # b, c, l
def to_device(inputs, strict=True):
    """Move every tensor in ``inputs`` to ``we.device_id``.

    Args:
        inputs: Iterable of tensors (None entries are tolerated only when
            ``strict`` is false), or None.
        strict: When true, assert that every element is a ``torch.Tensor``
            before moving anything.

    Returns:
        None when ``inputs`` is None; otherwise a new list in which each
        element has been moved with ``.to(we.device_id)`` and None entries
        are passed through as None.
    """
    if inputs is None:
        return None
    if strict:
        assert all(isinstance(item, torch.Tensor) for item in inputs)
    moved = []
    for item in inputs:
        moved.append(None if item is None else item.to(we.device_id))
    return moved
def check_list_of_list(ll):
    """Return True when ``ll`` is a list whose every element is a list.

    An empty list qualifies, since it has no non-list element.
    """
    if not isinstance(ll, list):
        return False
    for item in ll:
        if not isinstance(item, list):
            return False
    return True