Mehdi Cherti commited on
Commit
6c1d070
1 Parent(s): 023c7dd

support higher res/lower res sampling than training time

Browse files
score_sde/models/ncsnpp_generator_adagn.py CHANGED
@@ -379,7 +379,8 @@ class NCSNpp(nn.Module):
379
  #print(hs[-1].shape, temb.shape, zemb.shape, type(modules[m_idx]))
380
  h = modules[m_idx](hs[-1], temb, zemb)
381
  m_idx += 1
382
- if h.shape[-1] in self.attn_resolutions:
 
383
  if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock):
384
  h = modules[m_idx](h, cond, cond_mask)
385
  else:
@@ -415,6 +416,7 @@ class NCSNpp(nn.Module):
415
  h = hs[-1]
416
  h = modules[m_idx](h, temb, zemb)
417
  m_idx += 1
 
418
  if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock):
419
  h = modules[m_idx](h, cond, cond_mask)
420
  else:
@@ -431,7 +433,8 @@ class NCSNpp(nn.Module):
431
  h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb, zemb)
432
  m_idx += 1
433
 
434
- if h.shape[-1] in self.attn_resolutions:
 
435
  if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock):
436
  h = modules[m_idx](h, cond, cond_mask)
437
  else:
 
379
  #print(hs[-1].shape, temb.shape, zemb.shape, type(modules[m_idx]))
380
  h = modules[m_idx](hs[-1], temb, zemb)
381
  m_idx += 1
382
+ if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock, layers.AttnBlock):
383
+ #if h.shape[-1] in self.attn_resolutions:
384
  if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock):
385
  h = modules[m_idx](h, cond, cond_mask)
386
  else:
 
416
  h = hs[-1]
417
  h = modules[m_idx](h, temb, zemb)
418
  m_idx += 1
419
+
420
  if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock):
421
  h = modules[m_idx](h, cond, cond_mask)
422
  else:
 
433
  h = modules[m_idx](torch.cat([h, hs.pop()], dim=1), temb, zemb)
434
  m_idx += 1
435
 
436
+ #if h.shape[-1] in self.attn_resolutions:
437
+ if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock, layers.AttnBlock):
438
  if type(modules[m_idx]) in (layers.CondAttnBlock, CrossAndGlobalAttnBlock):
439
  h = modules[m_idx](h, cond, cond_mask)
440
  else: