File size: 902 Bytes
1ba3df3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
import numpy as np

def magic_image_handler(img):
    """Arrange an image-like tensor/array into a single channel-last image.

    The input is interpreted by its number of dimensions:

    * 2-D: the array is repeated three times along a new last axis.
    * 3-D: axis 0 is moved to the end (treated as channel-first input).
    * 4-D: the first 4 entries along axis 0 are placed side by side along
      the last axis, then the result is made channel-last.
    * 5-D: the first 4 entries along axis 0 are stacked along the
      second-to-last axis, the entries of the next axis are placed side by
      side along the last axis, then the result is made channel-last.

    If the resulting channel count is neither 1 nor 3, the channels are
    stacked along axis 0 into one single-channel image. Values are finally
    clipped to the range [0, 255].

    Args:
        img: A ``torch.Tensor``, a NumPy array, or anything
            ``np.asanyarray`` can convert, with 2 to 5 dimensions.

    Returns:
        A 3-D NumPy array whose last axis is the channel axis.

    Raises:
        ValueError: If ``img`` does not have 2 to 5 dimensions.
    """
    if isinstance(img, torch.Tensor):
        img = img.detach().cpu().numpy()
    else:
        # Accept nested lists / array-likes; ndarrays pass through unchanged.
        img = np.asanyarray(img)

    if img.ndim == 2:
        img = np.repeat(img[..., np.newaxis], 3, axis=2)
    elif img.ndim == 3:
        img = img.transpose((1, 2, 0))
    elif img.ndim == 4:
        img = img[:4]  # keep at most the first 4 entries of axis 0
        img = np.concatenate(img, axis=-1)
        img = img.transpose((1, 2, 0))
    elif img.ndim == 5:
        img = img[:4]  # keep at most the first 4 entries of axis 0
        img = np.concatenate(img, axis=-2)  # axis-0 entries become rows
        img = np.concatenate(img, axis=-1)  # next-axis entries become columns
        img = img.transpose((1, 2, 0))
    else:
        raise ValueError(f'img ndim is {img.ndim}, should be 2~5')

    # Only collapse channels when the count is neither 1 nor 3. The previous
    # `!= 1 or != 3` test was always true and flattened every image.
    num_channels = img.shape[-1]
    if num_channels not in (1, 3):
        stacked = np.concatenate(
            [img[..., idx] for idx in range(num_channels)], axis=0
        )
        img = np.expand_dims(stacked, -1)

    return np.clip(img, a_min=0, a_max=255)