""" This script is used to create a Streamlit web application for generating videos using the CogVideoX model. Run the script using Streamlit: $ export OPENAI_API_KEY=your OpenAI Key or ZhiupAI Key $ export OPENAI_BASE_URL=https://open.bigmodel.cn/api/paas/v4/ # using with ZhipuAI, Not using this when using OpenAI $ streamlit run web_demo.py """ import base64 import json import os import time from datetime import datetime from typing import List import imageio import numpy as np import streamlit as st import torch from convert_demo import convert_prompt from diffusers import CogVideoXPipeline model_path: str = "THUDM/CogVideoX-2b" # Load the model at the start @st.cache_resource def load_model(model_path: str, dtype: torch.dtype, device: str) -> CogVideoXPipeline: """ Load the CogVideoX model. Args: - model_path (str): Path to the model. - dtype (torch.dtype): Data type for model. - device (str): Device to load the model on. Returns: - CogVideoXPipeline: Loaded model pipeline. """ return CogVideoXPipeline.from_pretrained(model_path, torch_dtype=dtype).to(device) # Define a function to generate video based on the provided prompt and model path def generate_video( pipe: CogVideoXPipeline, prompt: str, num_inference_steps: int = 50, guidance_scale: float = 6.0, num_videos_per_prompt: int = 1, device: str = "cuda", dtype: torch.dtype = torch.float16, ) -> List[np.ndarray]: """ Generate a video based on the provided prompt and model path. Args: - pipe (CogVideoXPipeline): The pipeline for generating videos. - prompt (str): Text prompt for video generation. - num_inference_steps (int): Number of inference steps. - guidance_scale (float): Guidance scale for generation. - num_videos_per_prompt (int): Number of videos to generate per prompt. - device (str): Device to run the generation on. - dtype (torch.dtype): Data type for the model. Returns: - List[np.ndarray]: Generated video frames. """ prompt_embeds, _ = pipe.encode_prompt( prompt=prompt, negative_prompt=None, do_classifier_free_guidance=True, num_videos_per_prompt=num_videos_per_prompt, max_sequence_length=226, device=device, dtype=dtype, ) # Generate video video = pipe( num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, prompt_embeds=prompt_embeds, negative_prompt_embeds=torch.zeros_like(prompt_embeds), ).frames[0] return video def save_video(video: List[np.ndarray], path: str, fps: int = 8) -> None: """ Save the generated video to a file. Args: - video (List[np.ndarray]): Video frames. - path (str): Path to save the video. - fps (int): Frames per second for the video. """ # Remove the first frame video = video[1:] writer = imageio.get_writer(path, fps=fps, codec="libx264") for frame in video: np_frame = np.array(frame) writer.append_data(np_frame) writer.close() def save_metadata( prompt: str, converted_prompt: str, num_inference_steps: int, guidance_scale: float, num_videos_per_prompt: int, path: str, ) -> None: """ Save metadata to a JSON file. Args: - prompt (str): Original prompt. - converted_prompt (str): Converted prompt. - num_inference_steps (int): Number of inference steps. - guidance_scale (float): Guidance scale. - num_videos_per_prompt (int): Number of videos per prompt. - path (str): Path to save the metadata. """ metadata = { "prompt": prompt, "converted_prompt": converted_prompt, "num_inference_steps": num_inference_steps, "guidance_scale": guidance_scale, "num_videos_per_prompt": num_videos_per_prompt, } with open(path, "w") as f: json.dump(metadata, f, indent=4) def main() -> None: """ Main function to run the Streamlit web application. """ st.set_page_config(page_title="CogVideoX-Demo", page_icon="🎥", layout="wide") st.write("# CogVideoX 🎥") dtype: torch.dtype = torch.float16 device: str = "cuda" global pipe pipe = load_model(model_path, dtype, device) with st.sidebar: st.info("It will take some time to generate a video (~90 seconds per videos in 50 steps).", icon="ℹ️") num_inference_steps: int = st.number_input("Inference Steps", min_value=1, max_value=100, value=50) guidance_scale: float = st.number_input("Guidance Scale", min_value=0.0, max_value=20.0, value=6.0) num_videos_per_prompt: int = st.number_input("Videos per Prompt", min_value=1, max_value=10, value=1) share_links_container = st.empty() prompt: str = st.chat_input("Prompt") if prompt: # Not Necessary, Suggestions with st.spinner("Refining prompts..."): converted_prompt = convert_prompt(prompt=prompt, retry_times=1) if converted_prompt is None: st.error("Failed to Refining the prompt, Using origin one.") st.info(f"**Origin prompt:** \n{prompt} \n \n**Convert prompt:** \n{converted_prompt}") torch.cuda.empty_cache() with st.spinner("Generating Video..."): start_time = time.time() video_paths = [] timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") output_dir = f"./output/{timestamp}" os.makedirs(output_dir, exist_ok=True) metadata_path = os.path.join(output_dir, "config.json") save_metadata( prompt, converted_prompt, num_inference_steps, guidance_scale, num_videos_per_prompt, metadata_path ) for i in range(num_videos_per_prompt): video_path = os.path.join(output_dir, f"output_{i + 1}.mp4") video = generate_video( pipe, converted_prompt or prompt, num_inference_steps, guidance_scale, 1, device, dtype ) save_video(video, video_path, fps=8) video_paths.append(video_path) with open(video_path, "rb") as video_file: video_bytes: bytes = video_file.read() st.video(video_bytes, autoplay=True, loop=True, format="video/mp4") torch.cuda.empty_cache() used_time: float = time.time() - start_time st.success(f"Videos generated in {used_time:.2f} seconds.") # Create download links in the sidebar with share_links_container: st.sidebar.write("### Download Links:") for video_path in video_paths: video_name = os.path.basename(video_path) with open(video_path, "rb") as f: video_bytes: bytes = f.read() b64_video = base64.b64encode(video_bytes).decode() href = f'Download {video_name}' st.sidebar.markdown(href, unsafe_allow_html=True) if __name__ == "__main__": main()