import torch
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from xora.models.transformers.transformer3d import Transformer3DModel
from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
from xora.schedulers.rf import RectifiedFlowScheduler
from xora.pipelines.pipeline_video_pixart_alpha import VideoPixArtAlphaPipeline
from pathlib import Path
import safetensors.torch
import json

# Paths to the separately saved pipeline components (transformer, VAE, scheduler)
separate_dir = Path("/opt/models/xora-img2video")
unet_dir = separate_dir / 'unet'
vae_dir = separate_dir / 'vae'
scheduler_dir = separate_dir / 'scheduler'

# Load the VAE: instantiate from its config, then load the safetensors
# weights (same pattern as the transformer below) and move it to the GPU
vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
vae_config_path = vae_dir / "config.json"
with open(vae_config_path, "r") as f:
    vae_config = json.load(f)
vae = CausalVideoAutoencoder.from_config(vae_config)
vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
vae.load_state_dict(vae_state_dict)
vae = vae.cuda().to(torch.bfloat16)

# Load the denoising transformer (stored under 'unet/' in this layout)
unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
unet_config_path = unet_dir / "config.json"
transformer_config = Transformer3DModel.load_config(unet_config_path)
transformer = Transformer3DModel.from_config(transformer_config)
unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
transformer.load_state_dict(unet_state_dict, strict=True)
transformer = transformer.cuda()

# Load the rectified-flow scheduler from its config
scheduler_config_path = scheduler_dir / "scheduler_config.json"
scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
scheduler = RectifiedFlowScheduler.from_config(scheduler_config)

# Patchifier: patch_size=1 turns each latent pixel into one token (no extra patching)
patchifier = SymmetricPatchifier(patch_size=1)

# Locally loaded submodels to hand to the pipeline; text_encoder is None,
# so precomputed prompt embeddings must be supplied via `sample` below
submodel_dict = {
    "transformer": transformer,
    "patchifier": patchifier,
    "text_encoder": None,
    "scheduler": scheduler,
    "vae": vae,
}

model_name_or_path = "PixArt-alpha/PixArt-XL-2-1024-MS"
pipeline = VideoPixArtAlphaPipeline.from_pretrained(model_name_or_path,
                                                    safety_checker=None,
                                                    revision=None,
                                                    torch_dtype=torch.float32,  # dtype for components loaded from the hub
                                                    **submodel_dict,
                                                    ).to("cuda")

num_inference_steps = 20
num_images_per_prompt = 2
guidance_scale = 3
height = 512
width = 768
num_frames = 57
frame_rate = 25
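
# Sanity-check the sizes against the VAE's compression, as a rough guard.
# This assumes 32x spatial and 8x temporal downsampling (with the causal VAE
# keeping the first frame, hence num_frames = 8*k + 1); the exact factors are
# an assumption here and should be read from the VAE config.
assert height % 32 == 0 and width % 32 == 0
assert (num_frames - 1) % 8 == 0  # 57 = 8 * 7 + 1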

# `sample` is a dict of precomputed conditioning tensors (e.g. prompt
# embeddings and attention masks) saved to a .pt file; move them to the GPU
sample = torch.load("/opt/sample.pt")
for key, item in sample.items():
    if item is not None:
        sample[key] = item.cuda()

# Conditioning media for img2video (e.g. the input image as a tensor); add it
# to `sample` so it is forwarded to the pipeline with the other inputs
media_items = torch.load("/opt/sample_media.pt")
sample["media_items"] = media_items.cuda()

# Generate images (video frames)
images = pipeline(
    num_inference_steps=num_inference_steps,
    num_images_per_prompt=num_images_per_prompt,
    guidance_scale=guidance_scale,
    generator=None,
    output_type="pt",
    callback_on_step_end=None,
    height=height,
    width=width,
    num_frames=num_frames,
    frame_rate=frame_rate,
    **sample,
    is_video=True,
    vae_per_channel_normalize=True,
).images

print("Generated video frames.")