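# Gradio demo (Hugging Face Space) for Open-Sora text-to-video generation.
# It installs flash-attn, downloads the DeepFloyd T5 text encoder and the requested
# Open-Sora checkpoint, then shells out to the repo's scripts/inference.py via torchrun.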
import glob
import importlib.util
import os
import shutil
import subprocess
import tempfile

import gradio as gr
import spaces
from huggingface_hub import hf_hub_download, snapshot_download
from transformers import T5ForConditionalGeneration, T5Tokenizer
def check_and_install(package_name):
    """Install a package at startup if it is not already importable."""
    if importlib.util.find_spec(package_name) is None:
        print(f"{package_name} not installed, installing...")
        subprocess.run(
            f"pip install {package_name} --no-build-isolation",
            # Keep the existing environment (PATH etc.) and skip the CUDA build of flash-attn.
            env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
            shell=True,
        )
    else:
        print(f"{package_name} is already installed.")


check_and_install("flash_attn")
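# Note: the `spaces` import above provides Hugging Face ZeroGPU support. On ZeroGPU
# hardware the GPU-bound entry point is typically wrapped with the @spaces.GPU
# decorator; an illustrative sketch (not applied in this file) would be:
#
#   @spaces.GPU
#   def run_inference(model_name, prompt_text):
#       ...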
def download_t5_model(model_id, save_directory):
    # Download the T5 tokenizer and model weights into a local directory.
    if not os.path.exists(save_directory):
        os.makedirs(save_directory)
    snapshot_download(repo_id=model_id, local_dir=save_directory, local_dir_use_symlinks=False)


# Model ID and the directory to save it to
model_id = "DeepFloyd/t5-v1_1-xxl"
save_directory = "pretrained_models/t5_ckpts/t5-v1_1-xxl"

# Download the model
download_t5_model(model_id, save_directory)
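# Assumption: the Open-Sora inference configs load the T5 text encoder from
# pretrained_models/t5_ckpts/t5-v1_1-xxl, which is why the snapshot is placed there.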
def download_model(repo_id, model_name):
    model_path = hf_hub_download(repo_id=repo_id, filename=model_name)
    return model_path
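# For example, download_model("hpcai-tech/Open-Sora", "OpenSora-v1-16x256x256.pth")
# returns the local cache path of that checkpoint, downloading it on first use.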
def run_inference(model_name, prompt_text):
    repo_id = "hpcai-tech/Open-Sora"

    # Map each checkpoint to the inference config matching its frame count and resolution.
    config_mapping = {
        "OpenSora-v1-16x256x256.pth": "configs/opensora/inference/16x256x256.py",
        "OpenSora-v1-HQ-16x256x256.pth": "configs/opensora/inference/16x256x256.py",
        "OpenSora-v1-HQ-16x512x512.pth": "configs/opensora/inference/16x512x512.py",
    }
    config_path = config_mapping[model_name]
    ckpt_path = download_model(repo_id, model_name)
    # Save prompt_text to a temporary text file
    prompt_file = tempfile.NamedTemporaryFile(delete=False, suffix=".txt", mode="w")
    prompt_file.write(prompt_text)
    prompt_file.close()

    # Point the config at the temporary prompt file.
    with open(config_path, "r") as file:
        config_content = file.read()
    config_content = config_content.replace(
        'prompt_path = "./assets/texts/t2v_samples.txt"',
        f'prompt_path = "{prompt_file.name}"',
    )

    with tempfile.NamedTemporaryFile("w", delete=False, suffix=".py") as temp_file:
        temp_file.write(config_content)
        temp_config_path = temp_file.name
    cmd = [
        "torchrun", "--standalone", "--nproc_per_node", "1",
        "scripts/inference.py", temp_config_path,
        "--ckpt-path", ckpt_path,
    ]
    subprocess.run(cmd)

    # Clean up the temporary config and prompt files.
    os.remove(temp_config_path)
    os.remove(prompt_file.name)

    # Output directory used by inference.py; return the most recently written sample.
    save_dir = "./outputs/samples/"
    list_of_files = glob.glob(f"{save_dir}/*")
    if list_of_files:
        return max(list_of_files, key=os.path.getctime)
    print("No files found in the output directory.")
    return None
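# Example usage (assumes the Open-Sora repo layout, with scripts/inference.py and the
# configs/ directory present in the working directory; the prompt below is illustrative):
#
#   video_path = run_inference("OpenSora-v1-16x256x256.pth", "A serene beach at sunset")
#   print(video_path)  # newest file under ./outputs/samples/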
def main():
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                gr.HTML(
                    """
                    <h1 style='text-align: center'>
                    Open-Sora: Democratizing Efficient Video Production for All
                    </h1>
                    """
                )
                gr.HTML(
                    """
                    <h3 style='text-align: center'>
                    Follow me for more!
                    <a href='https://twitter.com/kadirnar_ai' target='_blank'>Twitter</a> | <a href='https://github.com/kadirnar' target='_blank'>Github</a> | <a href='https://www.linkedin.com/in/kadir-nar/' target='_blank'>Linkedin</a>
                    </h3>
                    """
                )
        with gr.Row():
            with gr.Column():
                model_dropdown = gr.Dropdown(
                    choices=[
                        "OpenSora-v1-16x256x256.pth",
                        "OpenSora-v1-HQ-16x256x256.pth",
                        "OpenSora-v1-HQ-16x512x512.pth",
                    ],
                    value="OpenSora-v1-16x256x256.pth",
                )
                prompt_text = gr.Textbox(show_label=False, placeholder="Enter prompt text here", lines=4)
                submit_button = gr.Button("Run Inference")
            with gr.Column():
                output_video = gr.Video()

        submit_button.click(
            fn=run_inference,
            inputs=[model_dropdown, prompt_text],
            outputs=output_video,
        )

    demo.launch()


if __name__ == "__main__":
    main()