|
import spaces |
|
import fire |
|
import subprocess |
|
import os |
|
import time |
|
import signal |
|
import subprocess |
|
import atexit |
|
|
|
# flash-attn is used when it is already importable; otherwise the pinned
# release is installed on first start-up and the import is retried.
try:
    import flash_attn
except ImportError:
    import importlib
    import sys

    @spaces.GPU
    def install_flash_attn():
        """Install the pinned flash-attn release with pip.

        pip is run as ``<running interpreter> -m pip`` with an argument list,
        so the package is installed for the interpreter executing this file
        and no shell is involved.  A non-zero exit status is reported here;
        the import that follows the call then raises ``ImportError`` as it
        would have before.
        """
        completed = subprocess.run(
            [sys.executable, "-m", "pip", "install", "flash-attn==2.5.9.post1"],
            check=False,
        )
        if completed.returncode != 0:
            print(
                "pip install flash-attn==2.5.9.post1 failed with exit code "
                f"{completed.returncode}"
            )

    install_flash_attn()
    # Path finders cache directory listings; refresh them so a package that
    # was installed after interpreter start-up can be found.
    importlib.invalidate_caches()
    import flash_attn
|
|
|
|
|
def kill_processes_by_cmd_substring(cmd_substring):
    """Send SIGTERM to every process whose ``ps -ef`` line contains a substring.

    The match is made against the whole ``ps -ef`` output line, not only the
    command column.

    Args:
        cmd_substring: Non-empty text to look for in each ``ps -ef`` line.

    Raises:
        ValueError: If ``cmd_substring`` is empty.  An empty string is
            contained in every line, so it would select every process.
    """
    if not cmd_substring:
        raise ValueError("cmd_substring must be a non-empty string")

    result = subprocess.run(["ps", "-ef"], stdout=subprocess.PIPE, text=True)
    own_pid = os.getpid()

    for line in result.stdout.splitlines():
        if cmd_substring not in line:
            continue

        fields = line.split()
        # The second column of ``ps -ef`` is the PID.  The header row (where
        # that column reads "PID") and malformed lines are skipped.
        if len(fields) < 2 or not fields[1].isdigit():
            continue

        pid = int(fields[1])
        # Never signal the process that is doing the clean-up.
        if pid == own_pid:
            continue

        print(f"Killing process with PID: {pid}, CMD: {line}")
        try:
            os.kill(pid, signal.SIGTERM)
        except ProcessLookupError:
            # The process ended between the ``ps`` snapshot and the signal
            # (the ``ps`` process itself is one such case).
            print(f"Process with PID {pid} no longer exists")
|
|
|
|
|
def _launch_process(label, cmd_args):
    """Terminate stale copies of a command, start it, and return the Popen."""
    command_line = " ".join(cmd_args)
    kill_processes_by_cmd_substring(command_line)
    print(f"Launching {label}: ", command_line)
    process = subprocess.Popen(cmd_args)
    # Terminate the child when this interpreter exits normally.
    atexit.register(process.terminate)
    return process


def main(
    python_path="python",
    run_controller=True,
    run_worker=True,
    run_gradio=True,
    controller_port=10086,
    gradio_port=7860,
    # NOTE: list default kept for interface compatibility; it is never mutated.
    worker_names=[
        "OpenGVLab/InternVL2-8B",
    ],
    run_sd_worker=False,
    **kwargs,
):
    """Start the selected services as child processes and wait for them.

    Before each launch, running processes whose ``ps -ef`` line contains the
    same command line are sent SIGTERM.

    Args:
        python_path: Interpreter used to run every script.
        run_controller: Launch ``controller.py`` on ``controller_port``.
        run_worker: Launch one ``model_worker.py`` per entry of
            ``worker_names`` with ``--load-8bit``, on consecutive ports
            starting at 10088.
        run_gradio: Launch ``gradio_web_server.py`` on ``gradio_port``.
        controller_port: Port passed to the controller and used to build the
            controller URL given to the workers and the web server.
        gradio_port: Port passed to the Gradio web server.
        worker_names: Model paths to serve.  A single string is treated as
            one model path rather than being iterated character by character.
        run_sd_worker: Launch ``sd_worker.py``.
        **kwargs: Accepted and ignored.
    """
    host = "http://0.0.0.0"
    controller_url = f"{host}:{controller_port}"
    interpreter = f"{python_path}"

    # A lone model name (for example one value given on the command line)
    # arrives as a plain string; wrap it so the loop sees one model.
    if isinstance(worker_names, str):
        worker_names = [worker_names]

    controller_process = None
    if run_controller:
        controller_process = _launch_process(
            "controller",
            [
                interpreter,
                "controller.py",
                "--host",
                "0.0.0.0",
                "--port",
                f"{controller_port}",
            ],
        )

    worker_processes = []
    if run_worker:
        for worker_port, worker_name in enumerate(worker_names, start=10088):
            worker_processes.append(
                _launch_process(
                    "worker",
                    [
                        interpreter,
                        "model_worker.py",
                        "--port",
                        f"{worker_port}",
                        "--controller-url",
                        controller_url,
                        "--model-path",
                        f"{worker_name}",
                        "--load-8bit",
                    ],
                )
            )

    # Pause before starting the front end - presumably so the services
    # launched above have time to come up (the original gives no reason).
    # There is nothing to wait for when this call started none of them.
    if controller_process is not None or worker_processes:
        time.sleep(10)

    gradio_process = None
    if run_gradio:
        gradio_process = _launch_process(
            "gradio",
            [
                interpreter,
                "gradio_web_server.py",
                "--port",
                f"{gradio_port}",
                "--controller-url",
                controller_url,
                "--model-list-mode",
                "reload",
            ],
        )

    sd_worker_process = None
    if run_sd_worker:
        sd_worker_process = _launch_process(
            "sd_worker", [interpreter, "sd_worker.py"]
        )

    # Block until every launched child has exited: workers first, then the
    # controller, the web server and the sd worker.
    children = [
        *worker_processes,
        controller_process,
        gradio_process,
        sd_worker_process,
    ]
    for child in children:
        if child is not None:
            child.wait()
|
|
|
|
|
# Script entry point: hand main() to Fire when this file is run directly.
if __name__ == "__main__":
    fire.Fire(component=main)
|
|