Sapir commited on
Commit
b30014f
1 Parent(s): 4f52f00

import fixes.

Browse files
xora/models/autoencoders/video_autoencoder.py CHANGED
@@ -11,7 +11,7 @@ from torch.nn import functional
11
 
12
  from diffusers.utils import logging
13
 
14
- from txt2img.models.layers.nn import Identity
15
  from xora.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
16
  from xora.models.autoencoders.pixel_norm import PixelNorm
17
  from xora.models.autoencoders.vae import AutoencoderKLWrapper
 
11
 
12
  from diffusers.utils import logging
13
 
14
+ from xora.utils.torch_utils import Identity
15
  from xora.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd
16
  from xora.models.autoencoders.pixel_norm import PixelNorm
17
  from xora.models.autoencoders.vae import AutoencoderKLWrapper
xora/schedulers/rf.py CHANGED
@@ -9,7 +9,7 @@ from diffusers.schedulers.scheduling_utils import SchedulerMixin
9
  from diffusers.utils import BaseOutput
10
  from torch import Tensor
11
 
12
- from txt2img.common.torch_utils import append_dims
13
 
14
 
15
  def simple_diffusion_resolution_dependent_timestep_shift(
 
9
  from diffusers.utils import BaseOutput
10
  from torch import Tensor
11
 
12
+ from xora.utils.torch_utils import append_dims
13
 
14
 
15
  def simple_diffusion_resolution_dependent_timestep_shift(
xora/utils/torch_utils.py CHANGED
@@ -1,4 +1,5 @@
1
  import torch
 
2
 
3
  def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
4
  """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
@@ -8,3 +9,13 @@ def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
8
  elif dims_to_append == 0:
9
  return x
10
  return x[(...,) + (None,) * dims_to_append]
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ from torch import nn
3
 
4
  def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
5
  """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
 
9
  elif dims_to_append == 0:
10
  return x
11
  return x[(...,) + (None,) * dims_to_append]
12
+
13
+ class Identity(nn.Module):
14
+ """A placeholder identity operator that is argument-insensitive."""
15
+
16
+ def __init__(self, *args, **kwargs) -> None: # pylint: disable=unused-argument
17
+ super().__init__()
18
+
19
+ # pylint: disable=unused-argument
20
+ def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
21
+ return x