# Spaces:
# Runtime error
# Runtime error
import torch | |
class CoordStage(object):
    """Stand-in for a VQ model that quantizes a single-channel coordinate map.

    Exposes the same ``eval`` / ``encode`` / ``decode`` calls a vector-quantized
    model would, but instead of a learned codebook it area-downsamples the input
    by ``down_factor`` and rounds ``n_embed * value`` to the nearest integer.
    """

    def __init__(self, n_embed, down_factor):
        """
        Args:
            n_embed: Scale applied to values in ``[0, 1]`` before rounding. The
                resulting indices lie in ``0..n_embed`` *inclusive*, i.e.
                ``n_embed + 1`` distinct values.
                NOTE(review): confirm consumers size their vocab for that.
            down_factor: Spatial downsampling factor used by ``encode`` and the
                matching upsampling factor used by ``decode``.
        """
        self.n_embed = n_embed
        self.down_factor = down_factor

    def eval(self):
        """Return ``self`` unchanged.

        There is no train/eval state here. The method exists so callers can
        invoke ``.eval()`` as they would on a model.
        """
        return self

    def encode(self, c):
        """Quantize a coordinate map (fake vqmodel interface).

        Args:
            c: Tensor of shape ``(batch, 1, height, width)`` whose values all lie
                in ``[0, 1]``.

        Returns:
            ``(c_quant, None, (None, None, c_ind))``.

            - ``c_quant`` is the downsampled map, scaled by ``n_embed`` and
              rounded; it keeps the input's floating dtype.
            - ``c_ind`` holds the same values as ``torch.long``.
            - The ``None`` slots pad the result to the tuple layout a real VQ
              model's ``encode`` returns. Presumably these are the embedding loss
              and extra quantizer info — confirm against callers.

        Raises:
            ValueError: if any value is outside ``[0, 1]`` (or is NaN), if ``c``
                is not 4-D, or if it has more than one channel.
        """
        # Explicit checks rather than ``assert`` so validation survives ``python -O``.
        # A NaN makes both comparisons False, so it is rejected here as well.
        if not (0.0 <= c.min() and c.max() <= 1.0):
            raise ValueError("CoordStage.encode expects values in [0, 1]")
        if c.dim() != 4:
            raise ValueError(
                f"CoordStage.encode expects a 4-D tensor, got {c.dim()}-D")
        _, ch, h, w = c.shape
        if ch != 1:
            raise ValueError(f"CoordStage.encode expects 1 channel, got {ch}")

        # Pass an explicit output size. With ``scale_factor=1/down_factor``, torch
        # computes floor(h * (1/down_factor)), and float rounding can land one
        # below the exact quotient: 49 * (1/49) == 0.9999999999999999 -> 0.
        size = (int(h // self.down_factor), int(w // self.down_factor))
        c = torch.nn.functional.interpolate(c, size=size, mode="area")
        # An average of values in [0, 1] stays in [0, 1]; clamp guards float drift.
        c = c.clamp(0.0, 1.0)
        c = self.n_embed * c
        c_quant = c.round()
        c_ind = c_quant.to(dtype=torch.long)
        info = None, None, c_ind
        return c_quant, None, info

    def decode(self, c):
        """Undo ``encode``'s scaling and upsample back by ``down_factor``.

        Divides by ``n_embed`` and applies nearest-neighbour upsampling. Detail
        lost to rounding and pooling in ``encode`` is not recovered.
        """
        c = c / self.n_embed
        c = torch.nn.functional.interpolate(
            c, scale_factor=self.down_factor, mode="nearest")
        return c