# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
from inspect import isfunction

import torch
from torch.nn.utils.rnn import pad_sequence

from scepter.modules.utils.distribute import we


def exists(x):
    """Return True if x is not None."""
    return x is not None


def default(val, d):
    """Return val if it is not None, otherwise the default d. If d is a
    function (including a lambda), it is called to produce the default, so
    expensive defaults are only built when actually needed."""
    if exists(val):
        return val
    return d() if isfunction(d) else d
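
# Usage sketch (values below are illustrative, not part of the library API):
#   >>> default(3, 5)
#   3
#   >>> default(None, 5)
#   5
#   >>> default(None, lambda: 5)  # the lambda is only invoked on fallback
#   5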


def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


def transfer_size(para_num):
    """Format a raw parameter count as a human-readable string,
    e.g. 1300000000 -> '1.30B'. Counts of 1000 or fewer are returned
    unformatted."""
    if para_num > 1000 * 1000 * 1000 * 1000:
        return '{:.2f}T'.format(para_num / (1000 * 1000 * 1000 * 1000))
    elif para_num > 1000 * 1000 * 1000:
        return '{:.2f}B'.format(para_num / (1000 * 1000 * 1000))
    elif para_num > (1000 * 1000):
        return '{:.2f}M'.format(para_num / (1000 * 1000))
    elif para_num > 1000:
        return '{:.2f}K'.format(para_num / 1000)
    else:
        return para_num


def count_params(model):
    """Total number of parameters in a model, as a human-readable string."""
    total_params = sum(p.numel() for p in model.parameters())
    return transfer_size(total_params)
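
# Usage sketch (assumes torch.nn is available; the layer is illustrative):
#   >>> import torch.nn as nn
#   >>> count_params(nn.Linear(1000, 1000))  # 1,001,000 parameters
#   '1.00M'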


def expand_dims_like(x, y):
    """Append trailing singleton dimensions to x until it has as many
    dimensions as y, so that x broadcasts against y."""
    while x.dim() != y.dim():
        x = x.unsqueeze(-1)
    return x
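
# Usage sketch (shapes are illustrative):
#   >>> x = torch.randn(8)              # e.g. one scale factor per sample
#   >>> y = torch.randn(8, 3, 64, 64)
#   >>> expand_dims_like(x, y).shape    # now broadcastable against y
#   torch.Size([8, 1, 1, 1])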


def unpack_tensor_into_imagelist(image_tensor, shapes):
    """Inverse of pack_imagelist_into_tensor_v2: trim each padded (c, l)
    entry back to its true h * w length and reshape it to (1, c, h, w)."""
    image_list = []
    for img, shape in zip(image_tensor, shapes):
        h, w = shape[0], shape[1]
        image_list.append(img[:, :h * w].view(1, -1, h, w))

    return image_list


def find_example(tensor_list, image_list):
    """Build a zero tensor to stand in for a missing (None) image: copy the
    shape of an already-packed tensor if one exists, otherwise flatten the
    first available image to (h * w, c). Returns None if neither is found."""
    for i in tensor_list:
        if isinstance(i, torch.Tensor):
            return torch.zeros_like(i)
    for i in image_list:
        if isinstance(i, torch.Tensor):
            _, c, h, w = i.size()
            return torch.zeros_like(i.view(c, h * w).transpose(1, 0))
    return None


def pack_imagelist_into_tensor_v2(image_list):
    """Pack a list of variable-sized (1, c, h, w) images into one padded
    (b, c, l) tensor, where l is the largest h * w in the batch. None
    entries are replaced by a zero placeholder and recorded with a None
    shape. Returns the packed tensor and the (h, w) shapes needed to
    unpack it."""
    example = None
    image_tensor, shapes = [], []
    for img in image_list:
        if img is None:
            example = find_example(image_tensor,
                                   image_list) if example is None else example
            image_tensor.append(example)
            shapes.append(None)
            continue
        _, c, h, w = img.size()
        image_tensor.append(img.view(c, h * w).transpose(1, 0))  # (h*w, c)
        shapes.append((h, w))

    image_tensor = pad_sequence(image_tensor,
                                batch_first=True).permute(0, 2, 1)  # (b, c, l)
    return image_tensor, shapes
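
# Round-trip sketch (channel count and sizes are illustrative):
#   >>> imgs = [torch.randn(1, 4, 2, 3), torch.randn(1, 4, 4, 4)]
#   >>> packed, shapes = pack_imagelist_into_tensor_v2(imgs)
#   >>> packed.shape                    # padded to the largest h * w (16)
#   torch.Size([2, 4, 16])
#   >>> shapes
#   [(2, 3), (4, 4)]
#   >>> unpack_tensor_into_imagelist(packed, shapes)[0].shape
#   torch.Size([1, 4, 2, 3])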


def to_device(inputs, strict=True):
    """Move a list of tensors to the current device (we.device_id). With
    strict=True every entry must be a tensor; with strict=False, None
    entries are passed through unchanged."""
    if inputs is None:
        return None
    if strict:
        assert all(isinstance(i, torch.Tensor) for i in inputs)
    return [i.to(we.device_id) if i is not None else None for i in inputs]
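
# Usage sketch (assumes a scepter run where we.device_id is initialized):
#   >>> feats = to_device([torch.randn(4, 8), torch.randn(4, 8)])
#   >>> mixed = to_device([torch.randn(4, 8), None], strict=False)
#   # with strict=False, the None entry is returned unchanged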


def check_list_of_list(ll):
    """Return True if ll is a list whose elements are all lists."""
    return isinstance(ll, list) and all(isinstance(i, list) for i in ll)