File size: 3,755 Bytes
8b33d6d f289b70 c09265b 966d74c 8b33d6d 966d74c f289b70 8b33d6d f289b70 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import spaces
import fire
import subprocess
import os
import time
import signal
import subprocess
import atexit
try:
    import flash_attn
except ImportError:
    # NOTE(review): the install is wrapped in ``spaces.GPU`` as in the original;
    # why it must run under that decorator is not visible here — confirm.
    @spaces.GPU
    def install_flash_attn():
        """Install the pinned ``flash-attn`` release into the running interpreter.

        Raises:
            subprocess.CalledProcessError: if pip exits with a non-zero status.
        """
        import sys  # local import: sys is not imported at file level

        # ``sys.executable -m pip`` installs into the same environment this
        # script runs in; a bare ``pip`` found on PATH may belong to another
        # interpreter. ``check=True`` reports a failed install right here
        # instead of as a confusing ImportError on the re-import below.
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "flash-attn==2.5.9.post1"],
            check=True,
        )

    install_flash_attn()
    import flash_attn
def kill_processes_by_cmd_substring(cmd_substring):
    """Send SIGTERM to every process whose command line contains ``cmd_substring``.

    Args:
        cmd_substring: Text searched for (plain substring test) in each
            process's full command line as reported by ``ps``.

    The calling process is never signalled. A process that has already exited
    by the time it is signalled, or that belongs to another user, is reported
    and skipped instead of raising.
    """
    # ``-o pid=,args=`` (POSIX) prints only the PID and the full command line,
    # with the header row suppressed. Matching therefore only looks at the
    # command, never at unrelated columns such as user or tty, and no header
    # line has to be skipped when parsing the PID.
    result = subprocess.run(
        ["ps", "-eo", "pid=,args="], stdout=subprocess.PIPE, text=True
    )
    own_pid = os.getpid()
    for line in result.stdout.splitlines():
        fields = line.split(None, 1)
        if len(fields) != 2:
            continue  # blank line or a process with an empty command line
        pid_text, cmd = fields
        if cmd_substring not in cmd:
            continue
        try:
            pid = int(pid_text)
        except ValueError:
            continue  # not a PID column; ignore unparsable output
        if pid == own_pid:
            continue  # never terminate the launcher itself
        print(f"Killing process with PID: {pid}, CMD: {cmd}")
        try:
            os.kill(pid, signal.SIGTERM)
        except (ProcessLookupError, PermissionError) as err:
            # The process may have exited since ``ps`` ran, or be owned by
            # another user; neither should abort the caller.
            print(f"Could not kill process {pid}: {err}")
def _launch_process(label, cmd_args):
    """Start ``cmd_args`` as a child process after clearing stale copies of it.

    Any running process whose command line contains the same command string
    is first sent SIGTERM via ``kill_processes_by_cmd_substring``. The new
    process is then started, and its ``terminate`` is registered with
    ``atexit`` so it is stopped when this launcher exits.

    Args:
        label: Name shown in the ``Launching <label>:`` log line.
        cmd_args: Argument vector passed to ``subprocess.Popen``.

    Returns:
        The ``subprocess.Popen`` handle of the started process.
    """
    cmd_line = " ".join(cmd_args)
    kill_processes_by_cmd_substring(cmd_line)
    print(f"Launching {label}: ", cmd_line)
    process = subprocess.Popen(cmd_args)
    atexit.register(process.terminate)
    return process


def main(
    python_path="python",
    run_controller=True,
    run_worker=True,
    run_gradio=True,
    controller_port=10086,
    gradio_port=7860,
    worker_names=("OpenGVLab/InternVL2-8B",),
    run_sd_worker=False,
    **kwargs,
):
    """Launch the controller, model workers, Gradio web server and optional SD worker.

    Each enabled service is started as a child process. Stale copies started
    with the same command line are killed first. The function then blocks
    until every started process has exited.

    Args:
        python_path: Interpreter used to run each service script.
        run_controller: Start ``controller.py`` on ``controller_port``.
        run_worker: Start one ``model_worker.py`` (with ``--load-8bit``) per
            entry of ``worker_names``, on consecutive ports from 10088.
        run_gradio: Start ``gradio_web_server.py`` on ``gradio_port``.
        controller_port: Controller port; also used to build the controller
            URL passed to the workers and the web server.
        gradio_port: Port for the Gradio web server.
        worker_names: Model paths, one worker each. A single ``str`` is
            treated as one path.
        run_sd_worker: Start ``sd_worker.py``.
        **kwargs: Unrecognized options; ignored (a notice is printed).
    """
    if kwargs:
        # Surface typos in CLI flags instead of silently dropping them.
        print(f"Ignoring unrecognized arguments: {sorted(kwargs)}")
    # python-fire hands a single non-list flag value through as a plain str.
    # Iterating that str would launch one worker per character.
    if isinstance(worker_names, str):
        worker_names = [worker_names]

    host = "http://0.0.0.0"
    controller_url = f"{host}:{controller_port}"

    controller_process = None
    if run_controller:
        # python controller.py --host 0.0.0.0 --port <controller_port>
        controller_process = _launch_process(
            "controller",
            [
                f"{python_path}",
                "controller.py",
                "--host",
                "0.0.0.0",
                "--port",
                f"{controller_port}",
            ],
        )

    worker_processes = []
    if run_worker:
        # One worker per model, each on its own port counting up from 10088.
        for worker_port, worker_name in enumerate(worker_names, start=10088):
            worker_processes.append(
                _launch_process(
                    "worker",
                    [
                        f"{python_path}",
                        "model_worker.py",
                        "--port",
                        f"{worker_port}",
                        "--controller-url",
                        controller_url,
                        "--model-path",
                        f"{worker_name}",
                        "--load-8bit",
                    ],
                )
            )

    # Fixed delay before the web UI starts; presumably lets the controller and
    # workers come up and register first — TODO confirm.
    time.sleep(10)

    gradio_process = None
    if run_gradio:
        # python gradio_web_server.py --port <gradio_port> --controller-url <controller_url> ...
        gradio_process = _launch_process(
            "gradio",
            [
                f"{python_path}",
                "gradio_web_server.py",
                "--port",
                f"{gradio_port}",
                "--controller-url",
                controller_url,
                "--model-list-mode",
                "reload",
            ],
        )

    sd_worker_process = None
    if run_sd_worker:
        # python sd_worker.py
        sd_worker_process = _launch_process(
            "sd_worker", [f"{python_path}", "sd_worker.py"]
        )

    # Block until every started service has exited (workers, then controller,
    # web server and SD worker — the same order as before).
    for process in (
        *worker_processes,
        controller_process,
        gradio_process,
        sd_worker_process,
    ):
        if process is not None:
            process.wait()
if __name__ == "__main__":
    # Script entry point: python-fire turns main()'s keyword parameters into
    # command-line flags (e.g. --run_gradio=False --gradio_port=7861).
    fire.Fire(component=main)
|