"""Ad-hoc inference script for the Xora video PixArt-alpha pipeline.

Loads a causal video VAE, a Transformer3D denoiser, a rectified-flow scheduler
and a symmetric patchifier from local checkpoints/configs, then assembles them
into a ``VideoPixArtAlphaPipeline`` for sampling.
"""
from pathlib import Path

import torch
from transformers import T5EncoderModel

from xora.models.autoencoders.causal_video_autoencoder import CausalVideoAutoencoder
from xora.models.transformers.transformer3d import Transformer3DModel
from xora.models.transformers.symmetric_patchifier import SymmetricPatchifier
from xora.schedulers.rf import RectifiedFlowScheduler
from xora.pipelines.pipeline_video_pixart_alpha import VideoPixArtAlphaPipeline

# Base pipeline repo; the submodels passed via ``submodel_dict`` below override
# the corresponding components it would otherwise load.
model_name_or_path = "PixArt-alpha/PixArt-XL-2-1024-MS"
vae_local_path = Path(
    "/opt/models/checkpoints/vae_training/causal_vvae_32x32x8_420m_cont_32/step_2296000"
)
dtype = torch.float32

# --- VAE -------------------------------------------------------------------
# NOTE(review): the VAE is loaded in bfloat16 while the pipeline below uses
# ``dtype`` (float32) -- confirm this mixed precision is intentional.
# NOTE(review): ``revision`` is normally a string or None; ``False`` looks
# accidental -- confirm what CausalVideoAutoencoder.from_pretrained expects.
vae = CausalVideoAutoencoder.from_pretrained(
    pretrained_model_name_or_path=vae_local_path,
    revision=False,
    torch_dtype=torch.bfloat16,
    load_in_8bit=False,
).cuda()

# --- Transformer (denoiser) -------------------------------------------------
transformer_config_path = Path("/opt/txt2img/txt2img/config/transformer3d/xora_v1.2-L.json")
transformer_config = Transformer3DModel.load_config(transformer_config_path)
transformer = Transformer3DModel.from_config(transformer_config)

transformer_local_path = Path(
    "/opt/models/logs/v1.2-vae-mf-medHR-mr-cvae-nl/ckpt/01760000/model.pt"
)
# Load onto CPU first so the checkpoint is not materialized on whatever device
# it was saved from; the model is moved to the GPU once, after loading.
transformer_ckpt_state_dict = torch.load(transformer_local_path, map_location="cpu")
transformer.load_state_dict(transformer_ckpt_state_dict, strict=True)
transformer = transformer.cuda()
# The same module is registered under both "unet" and "transformer" below.
unet = transformer

# --- Scheduler and patchifier ----------------------------------------------
scheduler_config_path = Path("/opt/txt2img/txt2img/config/scheduler/RF_SD3_shifted.json")
scheduler_config = RectifiedFlowScheduler.load_config(scheduler_config_path)
scheduler = RectifiedFlowScheduler.from_config(scheduler_config)
patchifier = SymmetricPatchifier(patch_size=1)

# --- Pipeline assembly ------------------------------------------------------
# The text encoder is disabled (None). Prompts are presumably supplied already
# encoded; verify against the sample file consumed below.
# text_encoder = T5EncoderModel.from_pretrained("t5-v1_1-xxl")
submodel_dict = {
    "unet": unet,
    "transformer": transformer,
    "patchifier": patchifier,
    "text_encoder": None,
    "scheduler": scheduler,
    "vae": vae,
}
pipeline = VideoPixArtAlphaPipeline.from_pretrained(
    model_name_or_path,
    safety_checker=None,
    revision=None,
    torch_dtype=dtype,
    **submodel_dict,
)

# --- Sampling parameters ----------------------------------------------------
num_inference_steps = 20
num_images_per_prompt = 2
guidance_scale = 3
height = 512
width = 768
num_frames = 57
frame_rate = 25

# /opt/sample.pt holds the prompt inputs that are forwarded to the pipeline as
# keyword arguments. An earlier inline template listed these keys:
#   "prompt"                          (B, L, E)
#   "prompt_attention_mask"           (B, L)
#   "negative_prompt"
#   "negative_prompt_attention_mask"  (B, L)
# NOTE(review): those shapes come from that template comment -- confirm them
# against the actual file.
sample = torch.load("/opt/sample.pt")

# Move every non-None entry to the GPU. Build a new dict rather than
# reassigning a loop variable: rebinding the loop variable inside a
# ``for`` over ``.items()`` leaves the dict's values untouched.
sample = {
    key: value.cuda() if value is not None else value
    for key, value in sample.items()
}

images = pipeline(
    num_inference_steps=num_inference_steps,
    num_images_per_prompt=num_images_per_prompt,
    guidance_scale=guidance_scale,
    generator=None,
    output_type="pt",
    callback_on_step_end=None,
    height=height,
    width=width,
    num_frames=num_frames,
    frame_rate=frame_rate,
    **sample,
    is_video=True,
    vae_per_channel_normalize=True,
).images

# NOTE(review): prints an empty line only; presumably left as a breakpoint
# anchor for inspecting ``images``.
print()