|
import torch
|
|
|
|
def tensor_to_size(source, dest_size):
    """Pad or truncate ``source`` along its first dimension to ``dest_size``.

    If ``source`` is shorter than the target, its last entry along dim 0 is
    repeated until the target length is reached. If it is longer, the
    trailing entries are dropped. A source that already has the target
    length is returned unchanged (the same object).

    Args:
        source: Tensor with at least one dimension.
        dest_size: Target length as an int, or a tensor whose ``shape[0]``
            gives the target length.

    Returns:
        A tensor whose first dimension equals the target length. Truncation
        returns a slice (view) of ``source``; padding returns a new tensor.

    Raises:
        ValueError: If the target length is negative, or if ``source`` is
            empty along dim 0 while padding is required (there is no last
            entry to repeat).
    """
    if isinstance(dest_size, torch.Tensor):
        dest_size = dest_size.shape[0]

    # A negative value would otherwise be interpreted as a Python negative
    # slice bound (source[:-k]) and silently yield a meaningless length.
    if dest_size < 0:
        raise ValueError(f"dest_size must be non-negative, got {dest_size}")

    source_size = source.shape[0]

    if source_size < dest_size:
        if source_size == 0:
            # source[-1:] would be empty here, so the result would silently
            # stay at length 0 instead of reaching dest_size.
            raise ValueError(f"cannot pad an empty tensor to size {dest_size}")
        # expand() creates a broadcast view of the last entry; cat() then
        # copies it once, avoiding the extra copy that repeat() would make.
        padding = source[-1:].expand(dest_size - source_size, *source.shape[1:])
        source = torch.cat((source, padding), dim=0)
    elif source_size > dest_size:
        source = source[:dest_size]

    return source
|
|
|
|
def tensor_to_image(tensor):
    """Turn a tensor into a uint8 NumPy array with the last-axis channels swapped.

    Values are multiplied by 255, saturated to [0, 255] and narrowed to
    uint8. The float-to-uint8 conversion truncates rather than rounds. The
    result is moved to the CPU, and channels 2, 1, 0 of the last dimension
    are selected, which reverses a 3-channel RGB/BGR order.

    Args:
        tensor: Tensor whose last dimension holds at least three channels.
            Values are presumably floats in [0, 1]; TODO confirm with callers.

    Returns:
        A ``numpy.ndarray`` of dtype uint8 with three channels in the last
        dimension.
    """
    # Saturate before narrowing so out-of-range values cannot wrap around.
    saturated = torch.clamp(tensor * 255, min=0, max=255)
    # Narrow on the original device, then transfer the smaller uint8 data.
    quantized = saturated.to(torch.uint8).cpu()
    swapped = quantized[..., [2, 1, 0]]
    return swapped.numpy()
|
|
|
|
def image_to_tensor(image):
    """Convert a NumPy image array to a float tensor with channels swapped.

    The array is converted to float32, scaled by 1/255 and clamped to
    [0, 1]. Channels 2, 1, 0 of the last dimension are then selected, which
    reverses a 3-channel BGR/RGB order.

    Args:
        image: NumPy array whose last dimension holds at least three
            channels. Presumably uint8 in BGR order (e.g. from OpenCV);
            TODO confirm against callers.

    Returns:
        A float32 tensor with the same leading shape as ``image`` and three
        channels in the last dimension.
    """
    # Local import: numpy is already a hard requirement of torch.from_numpy.
    import numpy as np

    # torch.from_numpy rejects arrays with negative strides (e.g. the view
    # produced by ``img[..., ::-1]``). ascontiguousarray copies only when the
    # array is not already C-contiguous, so the common case costs nothing.
    array = np.ascontiguousarray(image)
    tensor = torch.clamp(torch.from_numpy(array).float() / 255., 0, 1)

    tensor = tensor[..., [2, 1, 0]]

    return tensor
|
|
|