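# Text-to-video inference example for Xora: the VAE, Transformer3D ("unet"),
# and scheduler are loaded from separately saved checkpoint directories and
# assembled into a VideoPixArtAlphaPipeline, which then generates video frames
# from a text prompt.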
import torch
from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from xora.models.transformers.transformer3d import Transformer3DModel
from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
from xora.schedulers.rf import RectifiedFlowScheduler
from xora.pipelines.pipeline_video_pixart_alpha import VideoPixArtAlphaPipeline
from pathlib import Path
from transformers import T5EncoderModel
import safetensors.torch
import json

# Paths for the separate mode directories
separate_dir = Path("/opt/models/xora-txt2video")
unet_dir = separate_dir / 'unet'
vae_dir = separate_dir / 'vae'
scheduler_dir = separate_dir / 'scheduler'

# Load VAE from separate mode
vae_ckpt_path = vae_dir / "diffusion_pytorch_model.safetensors"
vae_config_path = vae_dir / "config.json"
with open(vae_config_path, 'r') as f:
    vae_config = json.load(f)
vae_state_dict = safetensors.torch.load_file(vae_ckpt_path)
vae = CausalVideoAutoencoder.from_pretrained_conf(
    config=vae_config,
    state_dict=vae_state_dict,
    torch_dtype=torch.bfloat16
).cuda()

# Load UNet (Transformer) from separate mode
unet_ckpt_path = unet_dir / "diffusion_pytorch_model.safetensors"
unet_config_path = unet_dir / "config.json"
transformer_config = Transformer3DModel.load_config(unet_config_path)
transformer = Transformer3DModel.from_config(transformer_config)
unet_state_dict = safetensors.torch.load_file(unet_ckpt_path)
transformer.load_state_dict(unet_state_dict, strict=True)
transformer = transformer.cuda()
unet = transformer

# Load Scheduler from separate mode
scheduler_config_path = scheduler_dir / "scheduler_config.json"
scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
scheduler = RectifiedFlowScheduler.from_config(scheduler_config)

# Patchifier (remains the same)
patchifier = SymmetricPatchifier(patch_size=1)

# Use submodels for the pipeline
submodel_dict = {
    "unet": unet,
    "transformer": transformer,
    "patchifier": patchifier,
    "scheduler": scheduler,
    "vae": vae,
}
model_name_or_path = "PixArt-alpha/PixArt-XL-2-1024-MS"
pipeline = VideoPixArtAlphaPipeline.from_pretrained(
    model_name_or_path,
    safety_checker=None,
    revision=None,
    torch_dtype=torch.float32,
    **submodel_dict,
).to("cuda")

# Sample input
num_inference_steps = 20
num_images_per_prompt = 2
guidance_scale = 3
height = 512
width = 768
num_frames = 57
frame_rate = 25
sample = {
    "prompt": "A middle-aged man with glasses and a salt-and-pepper beard is driving a car and talking, gesturing with his right hand. "
              "The man is wearing a dark blue zip-up jacket and a light blue collared shirt. He is sitting in the driver's seat of a car with a black interior. The car is moving on a road with trees and bushes on either side. The man has a serious expression on his face and is looking straight ahead.",
    'prompt_attention_mask': None,  # Adjust attention masks as needed
    'negative_prompt': "Ugly deformed",
    'negative_prompt_attention_mask': None
}
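
# Optional: for reproducible sampling, a seeded generator can be passed to the
# pipeline call below in place of `generator=None` (assumption: the pipeline
# forwards it to the noise sampling, as in standard diffusers-style pipelines).
# generator = torch.Generator(device="cuda").manual_seed(42)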

# Generate images (video frames)
images = pipeline(
    num_inference_steps=num_inference_steps,
    num_images_per_prompt=num_images_per_prompt,
    guidance_scale=guidance_scale,
    generator=None,
    output_type="pt",
    callback_on_step_end=None,
    height=height,
    width=width,
    num_frames=num_frames,
    frame_rate=frame_rate,
    **sample,
    is_video=True,
    vae_per_channel_normalize=True,
).images

print("Generated images (video frames).")