Sapir committed on
Commit
d699d2b
·
1 Parent(s): bebbcd0

CausalVideoAutoencoder: made neater load_ckpt.

Browse files
xora/examples/image_to_video.py CHANGED
@@ -19,12 +19,12 @@ vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
19
  vae_config_path = vae_dir / "config.json"
20
  with open(vae_config_path, 'r') as f:
21
  vae_config = json.load(f)
 
22
  vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
23
- vae = CausalVideoAutoencoder.from_pretrained_conf(
24
- config=vae_config,
25
  state_dict=vae_state_dict,
26
- torch_dtype=torch.bfloat16
27
- ).cuda()
28
 
29
  # Load UNet (Transformer) from separate mode
30
  unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
 
19
  vae_config_path = vae_dir / "config.json"
20
  with open(vae_config_path, 'r') as f:
21
  vae_config = json.load(f)
22
+ vae = CausalVideoAutoencoder.from_config(vae_config)
23
  vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
24
+ vae.load_state_dict(
 
25
  state_dict=vae_state_dict,
26
+ )
27
+ vae = vae.cuda().to(torch.bfloat16)
28
 
29
  # Load UNet (Transformer) from separate mode
30
  unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
xora/examples/text_to_video.py CHANGED
@@ -10,7 +10,7 @@ import safetensors.torch
10
  import json
11
 
12
  # Paths for the separate mode directories
13
- separate_dir = Path("/opt/models/xora-txt2video")
14
  unet_dir = separate_dir / 'unet'
15
  vae_dir = separate_dir / 'vae'
16
  scheduler_dir = separate_dir / 'scheduler'
@@ -20,12 +20,12 @@ vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
20
  vae_config_path = vae_dir / "config.json"
21
  with open(vae_config_path, 'r') as f:
22
  vae_config = json.load(f)
 
23
  vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
24
- vae = CausalVideoAutoencoder.from_pretrained_conf(
25
- config=vae_config,
26
  state_dict=vae_state_dict,
27
- torch_dtype=torch.bfloat16
28
- ).cuda()
29
 
30
  # Load UNet (Transformer) from separate mode
31
  unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
 
10
  import json
11
 
12
  # Paths for the separate mode directories
13
+ separate_dir = Path("/opt/models/xora-img2video")
14
  unet_dir = separate_dir / 'unet'
15
  vae_dir = separate_dir / 'vae'
16
  scheduler_dir = separate_dir / 'scheduler'
 
20
  vae_config_path = vae_dir / "config.json"
21
  with open(vae_config_path, 'r') as f:
22
  vae_config = json.load(f)
23
+ vae = CausalVideoAutoencoder.from_config(vae_config)
24
  vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
25
+ vae.load_state_dict(
 
26
  state_dict=vae_state_dict,
27
+ )
28
+ vae = vae.cuda().to(torch.bfloat16)
29
 
30
  # Load UNet (Transformer) from separate mode
31
  unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
xora/models/autoencoders/causal_video_autoencoder.py CHANGED
@@ -41,35 +41,6 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
41
 
42
  return video_vae
43
 
44
- @classmethod
45
- def from_pretrained_conf(cls, config, state_dict, torch_dtype=torch.float32):
46
- video_vae = cls.from_config(config)
47
- video_vae.to(torch_dtype)
48
-
49
- per_channel_statistics_prefix = "per_channel_statistics."
50
- ckpt_state_dict = {
51
- key: value
52
- for key, value in state_dict.items()
53
- if not key.startswith(per_channel_statistics_prefix)
54
- }
55
- video_vae.load_state_dict(ckpt_state_dict)
56
-
57
- data_dict = {
58
- key.removeprefix(per_channel_statistics_prefix): value
59
- for key, value in state_dict.items()
60
- if key.startswith(per_channel_statistics_prefix)
61
- }
62
- if len(data_dict) > 0:
63
- video_vae.register_buffer("std_of_means", data_dict["std-of-means"])
64
- video_vae.register_buffer(
65
- "mean_of_means",
66
- data_dict.get(
67
- "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
68
- ),
69
- )
70
-
71
- return video_vae
72
-
73
  @staticmethod
74
  def from_config(config):
75
  assert config["_class_name"] == "CausalVideoAutoencoder", "config must have _class_name=CausalVideoAutoencoder"
@@ -155,6 +126,13 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
155
  return json.dumps(self.config.__dict__)
156
 
157
  def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
 
 
 
 
 
 
 
158
  model_keys = set(name for name, _ in self.named_parameters())
159
 
160
  key_mapping = {
@@ -162,9 +140,8 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
162
  "downsamplers.0": "downsample",
163
  "upsamplers.0": "upsample",
164
  }
165
-
166
  converted_state_dict = {}
167
- for key, value in state_dict.items():
168
  for k, v in key_mapping.items():
169
  key = key.replace(k, v)
170
 
@@ -176,6 +153,20 @@ class CausalVideoAutoencoder(AutoencoderKLWrapper):
176
 
177
  super().load_state_dict(converted_state_dict, strict=strict)
178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  def last_layer(self):
180
  if hasattr(self.decoder, "conv_out"):
181
  if isinstance(self.decoder.conv_out, nn.Sequential):
 
41
 
42
  return video_vae
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  @staticmethod
45
  def from_config(config):
46
  assert config["_class_name"] == "CausalVideoAutoencoder", "config must have _class_name=CausalVideoAutoencoder"
 
126
  return json.dumps(self.config.__dict__)
127
 
128
  def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
129
+ per_channel_statistics_prefix = "per_channel_statistics."
130
+ ckpt_state_dict = {
131
+ key: value
132
+ for key, value in state_dict.items()
133
+ if not key.startswith(per_channel_statistics_prefix)
134
+ }
135
+
136
  model_keys = set(name for name, _ in self.named_parameters())
137
 
138
  key_mapping = {
 
140
  "downsamplers.0": "downsample",
141
  "upsamplers.0": "upsample",
142
  }
 
143
  converted_state_dict = {}
144
+ for key, value in ckpt_state_dict.items():
145
  for k, v in key_mapping.items():
146
  key = key.replace(k, v)
147
 
 
153
 
154
  super().load_state_dict(converted_state_dict, strict=strict)
155
 
156
+ data_dict = {
157
+ key.removeprefix(per_channel_statistics_prefix): value
158
+ for key, value in state_dict.items()
159
+ if key.startswith(per_channel_statistics_prefix)
160
+ }
161
+ if len(data_dict) > 0:
162
+ self.register_buffer("std_of_means", data_dict["std-of-means"])
163
+ self.register_buffer(
164
+ "mean_of_means",
165
+ data_dict.get(
166
+ "mean-of-means", torch.zeros_like(data_dict["std-of-means"])
167
+ ),
168
+ )
169
+
170
  def last_layer(self):
171
  if hasattr(self.decoder, "conv_out"):
172
  if isinstance(self.decoder.conv_out, nn.Sequential):