Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,265 Bytes
28c256d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Union
import numpy as np
def quantize(arr: np.ndarray,
             min_val: Union[int, float],
             max_val: Union[int, float],
             levels: int,
             dtype=np.int64) -> np.ndarray:
    """Quantize an array of (-inf, inf) to [0, levels-1].

    Values are clipped to ``[min_val, max_val]``, that range is divided into
    ``levels`` equal-width bins, and every value is replaced by the index of
    the bin it falls into.

    Args:
        arr (ndarray): Input array.
        min_val (int or float): Lower clipping bound; maps to bin 0.
        max_val (int or float): Upper clipping bound; maps to bin
            ``levels - 1``.
        levels (int): Number of quantization levels, must be greater than 1.
            NumPy integer scalars are accepted as well.
        dtype (np.type): The type of the quantized array.

    Returns:
        ndarray: Quantized array of bin indices with type ``dtype``.

    Raises:
        ValueError: If ``levels`` is not an integer greater than 1, or if
            ``min_val`` is not smaller than ``max_val``.
    """
    if not (isinstance(levels, (int, np.integer)) and levels > 1):
        raise ValueError(
            f'levels must be an integer greater than 1, but got {levels}')
    if min_val >= max_val:
        raise ValueError(
            f'min_val ({min_val}) must be smaller than max_val ({max_val})')

    # Work with a plain Python int so a NumPy scalar passed as `levels`
    # cannot promote the dtype of intermediate results.
    levels = int(levels)

    shifted = np.clip(arr, min_val, max_val) - min_val
    bin_idx = np.floor(levels * shifted / (max_val - min_val))
    # A value equal to max_val lands exactly on `levels`; fold it into the
    # top bin *before* casting, otherwise it can overflow a narrow dtype
    # (e.g. 256 -> uint8) and wrap around to a wrong bin.
    quantized_arr = np.minimum(bin_idx, levels - 1).astype(dtype)
    return quantized_arr
def dequantize(arr: np.ndarray,
               min_val: Union[int, float],
               max_val: Union[int, float],
               levels: int,
               dtype=np.float64) -> np.ndarray:
    """Dequantize an array.

    Each bin index ``i`` is mapped to the centre of its bin:
    ``(i + 0.5) * (max_val - min_val) / levels + min_val``.

    Args:
        arr (ndarray): Input array of bin indices. Any array-like (list,
            scalar) is accepted and converted with ``np.asarray``.
        min_val (int or float): Lower bound of the quantized value range.
        max_val (int or float): Upper bound of the quantized value range.
        levels (int): Number of quantization levels, must be greater than 1.
            NumPy integer scalars are accepted as well.
        dtype (np.type): The type of the dequantized array.

    Returns:
        ndarray: Dequantized array.

    Raises:
        ValueError: If ``levels`` is not an integer greater than 1, or if
            ``min_val`` is not smaller than ``max_val``.
    """
    if not (isinstance(levels, (int, np.integer)) and levels > 1):
        raise ValueError(
            f'levels must be an integer greater than 1, but got {levels}')
    if min_val >= max_val:
        raise ValueError(
            f'min_val ({min_val}) must be smaller than max_val ({max_val})')

    # Work with a plain Python int so a NumPy scalar passed as `levels`
    # cannot promote the result away from the requested dtype.
    levels = int(levels)

    # The +0.5 offset selects the middle of each bin rather than its lower
    # edge.
    bin_centres = (np.asarray(arr) + 0.5).astype(dtype)
    dequantized_arr = bin_centres * (max_val - min_val) / levels + min_val
    return dequantized_arr
|