# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from inspect import isfunction
import torch
from torch.nn.utils.rnn import pad_sequence
from scepter.modules.utils.distribute import we


def exists(x):
    return x is not None


def default(val, d):
    # Fall back to `d` when `val` is None; call `d` if it is a factory function.
    if exists(val):
        return val
    return d() if isfunction(d) else d


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self
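
# Illustrative usage (not part of the original file): a frozen sub-module is put in
# eval mode once, and its .train hook is replaced so later train(mode) calls issued
# by parent modules can no longer flip it back to training mode.
#   frozen_encoder.eval()
#   frozen_encoder.train = disabled_train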


def transfer_size(para_num):
    """Format a raw parameter count as a human-readable string (K/M/B/T)."""
    if para_num > 1000 * 1000 * 1000 * 1000:
        bill = para_num / (1000 * 1000 * 1000 * 1000)
        return '{:.2f}T'.format(bill)
    elif para_num > 1000 * 1000 * 1000:
        gyte = para_num / (1000 * 1000 * 1000)
        return '{:.2f}B'.format(gyte)
    elif para_num > (1000 * 1000):
        meta = para_num / (1000 * 1000)
        return '{:.2f}M'.format(meta)
    elif para_num > 1000:
        kelo = para_num / 1000
        return '{:.2f}K'.format(kelo)
    else:
        return para_num  # counts of 1000 or fewer are returned unformatted


def count_params(model):
    """Return the total number of parameters in `model` as a readable string."""
    total_params = sum(p.numel() for p in model.parameters())
    return transfer_size(total_params)
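
# Example (illustrative): a torch.nn.Linear(1000, 1000) layer has 1000 * 1000 weights
# plus 1000 biases, i.e. 1,001,000 parameters, so:
#   count_params(torch.nn.Linear(1000, 1000))  # -> '1.00M'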


def expand_dims_like(x, y):
    # Append trailing singleton dimensions to `x` until it has as many dims as `y`.
    while x.dim() != y.dim():
        x = x.unsqueeze(-1)
    return x
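
# Example (illustrative): a per-sample scale of shape (B,) broadcast against a
# (B, C, H, W) tensor becomes shape (B, 1, 1, 1):
#   expand_dims_like(torch.ones(2), torch.zeros(2, 3, 4, 4)).shape  # -> (2, 1, 1, 1)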


def unpack_tensor_into_imagelist(image_tensor, shapes):
    """Inverse of `pack_imagelist_into_tensor_v2`: crop the padded length back to
    h * w for each entry and reshape it to (1, c, h, w)."""
    image_list = []
    for img, shape in zip(image_tensor, shapes):
        h, w = shape[0], shape[1]
        image_list.append(img[:, :h * w].view(1, -1, h, w))
    return image_list


def find_example(tensor_list, image_list):
    # Build an all-zero placeholder matching the first tensor found, so that None
    # entries in the image list can still be packed.
    for i in tensor_list:
        if isinstance(i, torch.Tensor):
            return torch.zeros_like(i)
    for i in image_list:
        if isinstance(i, torch.Tensor):
            _, c, h, w = i.size()
            return torch.zeros_like(i.view(c, h * w).transpose(1, 0))
    return None


def pack_imagelist_into_tensor_v2(image_list):
    """Flatten a list of (1, c, h, w) images of varying sizes into one padded
    (b, c, l) tensor plus the (h, w) shapes needed to unpack it. None entries
    are allowed and are replaced by an all-zero placeholder."""
    example = None
    image_tensor, shapes = [], []
    for img in image_list:
        if img is None:
            example = find_example(image_tensor,
                                   image_list) if example is None else example
            image_tensor.append(example)
            shapes.append(None)
            continue
        _, c, h, w = img.size()
        image_tensor.append(img.view(c, h * w).transpose(1, 0))  # h*w, c
        shapes.append((h, w))
    image_tensor = pad_sequence(image_tensor,
                                batch_first=True).permute(0, 2, 1)  # b, c, l
    return image_tensor, shapes
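
# Shape walk-through (illustrative sizes, not from the original file): two inputs of
# (1, 4, 8, 8) and (1, 4, 16, 12) become sequences of length 64 and 192; pad_sequence
# pads the shorter one, so the packed tensor is (2, 4, 192) with shapes
# [(8, 8), (16, 12)].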


def to_device(inputs, strict=True):
    # Move a list of tensors to the current device; with strict=False, None entries
    # are tolerated and passed through as None.
    if inputs is None:
        return None
    if strict:
        assert all(isinstance(i, torch.Tensor) for i in inputs)
    return [i.to(we.device_id) if i is not None else None for i in inputs]
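
# Example (illustrative; `cond_tensor` and `mask_tensor` are hypothetical names):
# move a batch with optional entries without tripping the strict type check:
#   to_device([cond_tensor, None, mask_tensor], strict=False)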


def check_list_of_list(ll):
    return isinstance(ll, list) and all(isinstance(i, list) for i in ll)
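

if __name__ == '__main__':
    # Minimal smoke test (illustrative, not part of the original module); it assumes
    # torch is installed and that the scepter import at the top of the file succeeded.
    imgs = [torch.randn(1, 4, 8, 8), None, torch.randn(1, 4, 16, 12)]
    packed, shapes = pack_imagelist_into_tensor_v2(imgs)
    print(packed.shape, shapes)  # torch.Size([3, 4, 192]) [(8, 8), None, (16, 12)]
    restored = unpack_tensor_into_imagelist(packed[:1], shapes[:1])
    print(restored[0].shape)  # torch.Size([1, 4, 8, 8])
    print(count_params(torch.nn.Linear(1000, 1000)))  # 1.00M
    print(check_list_of_list([[1, 2], [3]]))  # True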