|
import torch |
|
|
|
class CoordStage:
    """Parameter-free stand-in for a VQ model over a single-channel map.

    Mimics the ``encode``/``decode``/``eval`` surface of a VQ stage so a
    conditioning map with values in ``[0, 1]`` can be quantized into
    integer codes by plain down-sampling and rounding.

    Args:
        n_embed: Quantization scale; values in ``[0, 1]`` are multiplied by
            this number and rounded to the nearest integer.
        down_factor: Spatial factor used to down-sample in ``encode`` and
            to up-sample in ``decode``.
    """

    def __init__(self, n_embed, down_factor):
        self.n_embed = n_embed
        self.down_factor = down_factor

    def eval(self):
        """Return ``self`` unchanged (there is no train/eval state)."""
        return self

    def encode(self, c):
        """Quantize a map ``c`` (fake vqmodel interface).

        ``c`` is area-down-sampled by ``down_factor``, clamped to
        ``[0, 1]``, scaled by ``n_embed`` and rounded.

        NOTE(review): rounding ``n_embed * c`` yields codes in
        ``0..n_embed`` inclusive (``n_embed + 1`` levels); a value of
        exactly 1.0 maps to code ``n_embed``. If the codes index a table of
        size ``n_embed``, the top code is out of range -- confirm against
        the consumer.

        Args:
            c: 4-D tensor ``(batch, 1, height, width)`` with every value in
                ``[0, 1]``.

        Returns:
            ``(c_quant, None, (None, None, c_ind))`` where ``c_quant`` holds
            the rounded codes as floats and ``c_ind`` the same codes as
            ``torch.long``; both keep the batch and channel dimensions.

        Raises:
            ValueError: if ``c`` has values outside ``[0, 1]`` (including
                NaN), is not 4-D, or does not have exactly one channel.
        """
        # Explicit checks instead of ``assert`` so validation survives -O.
        if not (0.0 <= c.min() and c.max() <= 1.0):
            raise ValueError("CoordStage.encode expects values in [0, 1]")
        if c.dim() != 4:
            raise ValueError(
                "CoordStage.encode expects a 4-D tensor, got shape "
                f"{tuple(c.shape)}"
            )
        if c.shape[1] != 1:
            raise ValueError(
                f"CoordStage.encode expects 1 channel, got {c.shape[1]}"
            )

        pooled = torch.nn.functional.interpolate(
            c, scale_factor=1 / self.down_factor, mode="area"
        )
        # Guard against floating-point drift from the averaging step.
        pooled = pooled.clamp(0.0, 1.0)
        c_quant = (self.n_embed * pooled).round()
        c_ind = c_quant.to(dtype=torch.long)

        info = None, None, c_ind
        return c_quant, None, info

    def decode(self, c):
        """Rescale codes by ``1 / n_embed`` and nearest-up-sample them.

        Args:
            c: 4-D tensor of codes, e.g. the ``c_quant`` returned by
                ``encode``.

        Returns:
            Tensor up-sampled by ``down_factor`` in both spatial dimensions.
        """
        scaled = c / self.n_embed
        return torch.nn.functional.interpolate(
            scaled, scale_factor=self.down_factor, mode="nearest"
        )
|
|