|
import os
import shutil
import subprocess
import tempfile

import gradio as gr
|
|
|
def run_inference(config_path, ckpt_path, prompt_path):
    """Run Open-Sora inference with a config whose prompt path is overridden.

    Reads the config file at *config_path*, replaces the default line
    ``prompt_path = "./assets/texts/t2v_samples.txt"`` with one pointing at
    *prompt_path*, writes the result to a temporary file, and runs
    ``torchrun --standalone --nproc_per_node 1 scripts/inference.py
    <temp config> --ckpt-path <ckpt_path>``.

    Args:
        config_path: Path to the base inference config file.
        ckpt_path: Checkpoint path forwarded as ``--ckpt-path``.
        prompt_path: Prompt file path substituted into the config.

    Returns:
        A ``(status, output)`` tuple of strings. On success the status is
        ``"Inference completed successfully."`` and output is the child's
        stdout. Otherwise the status is ``"Error occurred:"`` and output is
        the child's stderr, or the launch error message if the command could
        not be started.
    """
    with open(config_path, "r", encoding="utf-8") as file:
        config_content = file.read()

    # repr() emits a valid, properly escaped Python string literal. Quotes or
    # backslashes in a user-supplied path therefore cannot break (or inject
    # code into) the generated config.
    # NOTE: if the config lacks this exact default line, the replacement is a
    # no-op and the config's own prompt_path is used unchanged.
    config_content = config_content.replace(
        'prompt_path = "./assets/texts/t2v_samples.txt"',
        f"prompt_path = {prompt_path!r}",
    )

    # NamedTemporaryFile adds no extension by default. Keep the original one,
    # since the config loader presumably dispatches on the suffix (TODO confirm).
    suffix = os.path.splitext(config_path)[1]
    with tempfile.NamedTemporaryFile(
        "w", suffix=suffix, delete=False, encoding="utf-8"
    ) as temp_file:
        temp_file.write(config_content)
        temp_config_path = temp_file.name

    cmd = [
        "torchrun", "--standalone", "--nproc_per_node", "1",
        "scripts/inference.py", temp_config_path,
        "--ckpt-path", ckpt_path,
    ]
    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except OSError as err:
        # E.g. torchrun is not on PATH. Report it through the same
        # (status, output) channel rather than raising into the UI.
        return "Error occurred:", str(err)
    finally:
        # The temp config is a single file, so unlink it; shutil.rmtree only
        # removes directories. Running in `finally` also cleans up on failure.
        os.remove(temp_config_path)

    if result.returncode == 0:
        return "Inference completed successfully.", result.stdout
    return "Error occurred:", result.stderr
|
|
|
def main(checkpoint_choices=None):
    """Build and launch the Gradio UI for Open-Sora inference.

    Args:
        checkpoint_choices: Checkpoint paths offered in the "Checkpoint Path"
            dropdown. Defaults to three placeholder paths
            (``./path/to/model{1,2,3}.ckpt``), which must be replaced with real
            checkpoints for a working setup.
    """
    if checkpoint_choices is None:
        checkpoint_choices = [
            "./path/to/model1.ckpt",
            "./path/to/model2.ckpt",
            "./path/to/model3.ckpt",
        ]

    gr.Interface(
        fn=run_inference,
        # Order must match run_inference(config_path, ckpt_path, prompt_path).
        inputs=[
            gr.Textbox(label="Configuration Path"),
            gr.Dropdown(choices=list(checkpoint_choices), label="Checkpoint Path"),
            gr.Textbox(label="Prompt Path"),
        ],
        # These two components receive the (status, output) tuple returned
        # by run_inference.
        outputs=[
            gr.Text(label="Status"),
            gr.Text(label="Output"),
        ],
        title="Open-Sora Inference",
        description="Run Open-Sora Inference with Custom Parameters",
    ).launch()
|
|
|
# Script entry point: start the UI only when run directly, not when imported.
if __name__ == "__main__":
    main()
|
|