|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import os |
|
|
|
import imageio |
|
import torch |
|
|
|
from cosmos1.models.autoregressive.inference.world_generation_pipeline import ARVideo2WorldGenerationPipeline |
|
from cosmos1.models.autoregressive.utils.inference import add_common_arguments, load_vision_input, validate_args |
|
from .log import log |
|
from io import read_prompts_from_file |
|
|
|
|
|
def parse_args():
    """Build the command-line parser for this demo and parse the process arguments.

    Options contributed by ``add_common_arguments`` are registered first,
    followed by the options specific to this script:

    - ``--ar_model_dir``: string, defaults to the 5B Video2World model name
    - ``--input_type``: ``text_and_image`` or ``text_and_video`` (default)
    - ``--prompt``: text prompt for generating a single video
    - ``--offload_text_encoder_model``: boolean flag

    Returns:
        argparse.Namespace: The parsed command-line arguments.
    """
    cli = argparse.ArgumentParser(description="Prompted video to world generation demo script")
    add_common_arguments(cli)

    # Script-specific options as (flag, keyword arguments) pairs. They are
    # registered in this exact order so the generated ``--help`` text keeps
    # the same layout.
    script_options = (
        (
            "--ar_model_dir",
            {"type": str, "default": "Cosmos-1.0-Autoregressive-5B-Video2World"},
        ),
        (
            "--input_type",
            {
                "type": str,
                "default": "text_and_video",
                "choices": ["text_and_image", "text_and_video"],
                "help": "Input types",
            },
        ),
        (
            "--prompt",
            {"type": str, "help": "Text prompt for generating a single video"},
        ),
        (
            "--offload_text_encoder_model",
            {"action": "store_true", "help": "Offload T5 model after inference"},
        ),
    )
    for flag, spec in script_options:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
|
|
|
|
|
def main(args):
    """Run the prompted video-to-world generation demo.

    Steps performed:

    - Validate the arguments and obtain the sampling configuration
    - Build the ``ARVideo2WorldGenerationPipeline`` from the CLI settings
    - Load the input image(s)/video(s)
    - Collect the prompt entries, either from the batch file or from the
      single ``--prompt`` / ``--input_image_or_video_path`` pair
    - Generate one video per entry and write it to disk as MP4 at 25 fps

    Args:
        args (argparse.Namespace): Parsed command-line arguments. The
            attributes read here are ``checkpoint_dir``, ``ar_model_dir``,
            ``disable_diffusion_decoder``, ``offload_guardrail_models``,
            ``offload_diffusion_decoder``, ``offload_ar_model``,
            ``offload_tokenizer``, ``offload_text_encoder_model``,
            ``input_type``, ``batch_input_path``,
            ``input_image_or_video_path``, ``data_resolution``,
            ``num_input_frames``, ``prompt``, ``seed``,
            ``video_save_folder`` and ``video_save_name``.

    Output naming:
        - Batch mode (``batch_input_path`` set): ``<video_save_folder>/<idx>.mp4``,
          where ``idx`` is the position of the entry in the batch file.
        - Single mode: ``<video_save_folder>/<video_save_name>.mp4``.

    An entry is skipped, with a critical log message, when its input file is
    not among the loaded inputs or when the pipeline returns ``None`` (logged
    as a guardrail block). Processing then continues with the next entry.
    """
    inference_type = "video2world"
    sampling_config = validate_args(args, inference_type)

    # Build the generation pipeline from the CLI configuration.
    pipeline = ARVideo2WorldGenerationPipeline(
        inference_type=inference_type,
        checkpoint_dir=args.checkpoint_dir,
        checkpoint_name=args.ar_model_dir,
        disable_diffusion_decoder=args.disable_diffusion_decoder,
        offload_guardrail_models=args.offload_guardrail_models,
        offload_diffusion_decoder=args.offload_diffusion_decoder,
        offload_network=args.offload_ar_model,
        offload_tokenizer=args.offload_tokenizer,
        offload_text_encoder_model=args.offload_text_encoder_model,
    )

    # Load the conditioning image(s)/video(s); entries are looked up by file
    # basename in the loop below.
    input_videos = load_vision_input(
        input_type=args.input_type,
        batch_input_path=args.batch_input_path,
        input_image_or_video_path=args.input_image_or_video_path,
        data_resolution=args.data_resolution,
        num_input_frames=args.num_input_frames,
    )

    # Decide batch vs. single mode once and reuse it for output naming, so
    # that prompt selection and file naming can never disagree. Previously
    # the file name was keyed on ``input_image_or_video_path``; supplying both
    # options made every batch entry overwrite the same output file.
    is_batch = bool(args.batch_input_path)
    if is_batch:
        prompts_list = read_prompts_from_file(args.batch_input_path)
    else:
        prompts_list = [{"visual_input": args.input_image_or_video_path, "prompt": args.prompt}]

    for idx, prompt_entry in enumerate(prompts_list):
        video_path = prompt_entry["visual_input"]
        input_filename = os.path.basename(video_path)

        # Skip entries whose visual input was not loaded.
        if input_filename not in input_videos:
            log.critical(f"Input file {input_filename} not found, skipping prompt.")
            continue

        inp_vid = input_videos[input_filename]
        inp_prompt = prompt_entry["prompt"]

        # Generate the video for this entry.
        log.info(f"Run with input: {prompt_entry}")
        out_vid = pipeline.generate(
            inp_prompt=inp_prompt,
            inp_vid=inp_vid,
            num_input_frames=args.num_input_frames,
            seed=args.seed,
            sampling_config=sampling_config,
        )
        if out_vid is None:
            log.critical("Guardrail blocked video2world generation.")
            continue

        # Save the video: index-based names in batch mode, the user-chosen
        # name for a single run.
        out_name = f"{idx}.mp4" if is_batch else f"{args.video_save_name}.mp4"
        out_vid_path = os.path.join(args.video_save_folder, out_name)
        imageio.mimsave(out_vid_path, out_vid, fps=25)

        log.info(f"Saved video to {out_vid_path}")
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: runs only when this file is executed directly.
    # Disable the TorchScript tensor-expression fuser before any model code
    # runs (the reason is not shown in this file).
    torch._C._jit_set_texpr_fuser_enabled(False)

    # Parse the command line, then run the demo with the parsed arguments.
    args = parse_args()
    main(args)
|
|