Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
from .JPEG_utils import diff_round, quality_to_factor, Quantization | |
from .compression import compress_jpeg | |
from .decompression import decompress_jpeg | |
class DiffJPEG(nn.Module):
    """JPEG compression followed by decompression, as a single module.

    The input is passed through ``compress_jpeg`` (producing Y, Cb and Cr
    components) and then through ``decompress_jpeg``, which receives the
    input's original height and width for the reconstruction.
    """

    def __init__(self, differentiable: bool = True, quality: float = 75):
        """Initialize the DiffJPEG layer.

        Args:
            differentiable: If True, use the custom differentiable rounding
                function ``diff_round``; if False, use the standard
                ``torch.round``. The chosen function is passed as
                ``rounding`` to both the compression and decompression stages.
            quality: Quality setting for the JPEG compression scheme. It is
                converted with ``quality_to_factor`` and the resulting factor
                is passed to both stages.

        Note:
            Image height and width are not constructor arguments. They are
            read from each input in ``forward``, so one instance is not tied
            to a fixed image size.
        """
        super().__init__()
        # NOTE: an alternative rounding module (``Quantization()``) was
        # previously experimented with here in place of ``diff_round``.
        rounding = diff_round if differentiable else torch.round
        factor = quality_to_factor(quality)
        self.compress = compress_jpeg(rounding=rounding, factor=factor)
        self.decompress = decompress_jpeg(rounding=rounding, factor=factor)

    def forward(self, x: torch.Tensor):
        """Compress ``x`` into JPEG components and reconstruct it.

        Args:
            x: Image tensor whose dimensions 2 and 3 are treated as height
                and width (presumably ``(batch, channels, height, width)``;
                the expected channel count and value range are determined by
                ``compress_jpeg`` - TODO confirm).

        Returns:
            Whatever ``decompress_jpeg`` produces from the Y, Cb and Cr
            components, given the original height and width of ``x``.
        """
        org_height = x.shape[2]
        org_width = x.shape[3]
        y, cb, cr = self.compress(x)
        # The size is supplied per call (not fixed at construction) so that
        # the reconstruction can be matched to this particular input.
        return self.decompress(y, cb, cr, org_height, org_width)