|
import gradio as gr |
|
from huggingface_hub import hf_hub_download, snapshot_download |
|
import subprocess |
|
import tempfile |
|
import shutil |
|
import os |
|
import spaces |
|
import importlib |
|
from transformers import T5ForConditionalGeneration, T5Tokenizer |
|
import os |
|
|
|
def download_t5_model(model_id, save_directory):
    """Download a Hugging Face Hub repository snapshot into *save_directory*.

    Args:
        model_id: Hub repository id to download
            (the module passes "DeepFloyd/t5-v1_1-xxl").
        save_directory: Local directory that receives the repository files.
            It is created if it does not exist yet.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(save_directory, exist_ok=True)
    # Honour the caller-supplied repository id instead of a hard-coded one.
    snapshot_download(
        repo_id=model_id,
        local_dir=save_directory,
        local_dir_use_symlinks=False,
    )
|
|
|
|
|
# Hub repository id of the T5 checkpoint and the local directory it is stored in.
model_id = "DeepFloyd/t5-v1_1-xxl"
save_directory = "pretrained_models/t5_ckpts/t5-v1_1-xxl"

# Runs at import time, so the checkpoint download is started before the UI is built.
download_t5_model(model_id=model_id, save_directory=save_directory)
|
|
|
def download_model(repo_id, model_name):
    """Fetch one file from a Hugging Face Hub repository.

    Args:
        repo_id: Hub repository id that hosts the file.
        model_name: Name of the file inside the repository.

    Returns:
        Whatever ``hf_hub_download`` returns for that file (used by the
        caller as the local checkpoint path).
    """
    return hf_hub_download(repo_id=repo_id, filename=model_name)
|
|
|
import glob |
|
|
|
# Install flash-attn at start-up. The flag is merged into the current
# environment: passing a bare dict as `env` would replace the whole
# environment (PATH, HOME, ...) seen by the shell and by pip.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)
|
|
|
@spaces.GPU(duration=200)
def run_inference(prompt_text):
    """Run the Open-Sora inference script for *prompt_text*.

    The prompt is written to a temporary text file, the inference config is
    copied to a temporary ``.py`` file with its ``prompt_path`` pointing at
    that prompt file, and ``scripts/inference.py`` is launched via
    ``torchrun``. Both temporary files are removed before returning, whether
    or not the run succeeded.

    Args:
        prompt_text: Text prompt written verbatim to the prompt file.

    Returns:
        Path of the entry in ``./outputs/samples/`` with the greatest
        ``os.path.getctime``, or ``None`` when that directory has no entries.

    NOTE(review): the exit status of the subprocess is not checked, so after
    a failed run a file left over from an earlier run may be returned.
    """
    repo_id = "hpcai-tech/Open-Sora"
    model_name = "OpenSora-v1-HQ-16x512x512.pth"

    # NOTE(review): the two HQ checkpoints map to configs whose names do not
    # match the checkpoint names; kept unchanged, confirm this is intended.
    config_mapping = {
        "OpenSora-v1-16x256x256.pth": "configs/opensora/inference/16x256x256.py",
        "OpenSora-v1-HQ-16x256x256.pth": "configs/opensora/inference/16x512x512.py",
        "OpenSora-v1-HQ-16x512x512.pth": "configs/opensora/inference/64x512x512.py"
    }

    config_path = config_mapping[model_name]
    ckpt_path = download_model(repo_id, model_name)

    # Tracked outside the try block so the finally clause can clean up
    # whichever temporary files were actually created.
    prompt_path = None
    temp_config_path = None
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode='w') as prompt_file:
            prompt_path = prompt_file.name
            prompt_file.write(prompt_text)

        with open(config_path, 'r') as file:
            config_content = file.read()

        # Point the config at the per-request prompt file.
        config_content = config_content.replace(
            'prompt_path = "./assets/texts/t2v_samples.txt"',
            f'prompt_path = "{prompt_path}"',
        )

        with tempfile.NamedTemporaryFile('w', delete=False, suffix='.py') as temp_file:
            temp_config_path = temp_file.name
            temp_file.write(config_content)

        cmd = [
            "torchrun", "--standalone", "--nproc_per_node", "1",
            "scripts/inference.py", temp_config_path,
            "--ckpt-path", ckpt_path
        ]
        subprocess.run(cmd)

        save_dir = "./outputs/samples/"
        list_of_files = glob.glob(f'{save_dir}/*')
        if not list_of_files:
            print("No files found in the output directory.")
            return None
        return max(list_of_files, key=os.path.getctime)
    finally:
        # Previously this cleanup sat after the return statements and never
        # ran, leaking two temporary files per request.
        for leftover in (temp_config_path, prompt_path):
            if leftover is not None and os.path.exists(leftover):
                os.remove(leftover)
|
|
|
def main():
    """Build the Gradio Blocks interface and launch it.

    The page has a header row with two HTML banners, a row holding the prompt
    textbox with its "Run Inference" button next to the video output, and one
    cached example prompt. Both the button and the example call
    ``run_inference``.
    """
    example_prompt = "Animated scene features a close-up of a short fluffy monster kneeling beside a melting red candle. The art style is 3D and realistic, with a focus on lighting and texture. The mood of the painting is one of wonder and curiosity, as the monster gazes at the flame with wide eyes and open mouth. Its pose and expression convey a sense of innocence and playfulness, as if it is exploring the world around it for the first time. The use of warm colors and dramatic lighting further enhances the cozy atmosphere of the image."

    with gr.Blocks() as demo:
        # Header: project title and author links.
        with gr.Row(), gr.Column():
            gr.HTML(
                """
                <h1 style='text-align: center'>
                Open-Sora: Democratizing Efficient Video Production for All
                </h1>
                """
            )
            gr.HTML(
                """
                <h3 style='text-align: center'>
                Follow me for more!
                <a href='https://twitter.com/kadirnar_ai' target='_blank'>Twitter</a> | <a href='https://github.com/kadirnar' target='_blank'>Github</a> | <a href='https://www.linkedin.com/in/kadir-nar/' target='_blank'>Linkedin</a>
                </h3>
                """
            )

        # Body: prompt input on the left, generated video on the right.
        with gr.Row():
            with gr.Column():
                prompt_text = gr.Textbox(show_label=False, placeholder="Enter prompt text here", lines=4)
                submit_button = gr.Button("Run Inference")

            with gr.Column():
                output_video = gr.Video()

        submit_button.click(fn=run_inference, inputs=[prompt_text], outputs=output_video)

        gr.Examples(
            examples=[[example_prompt]],
            fn=run_inference,
            inputs=[prompt_text],
            outputs=[output_video],
            cache_examples=True,
        )

    demo.launch(debug=True)
|
|
|
if __name__ == "__main__": |
|
main() |
|
|