"""Launch a LlavaGuard demo: a LLaVA controller, one model worker and the Gradio UI.

Configuration comes from the CLI flags defined in the ``__main__`` section and
from the environment variables ``model``, ``bits``, ``concurrency_count`` and
``token``.
"""
import argparse
import os
import subprocess
import sys
import time
import traceback

import gradio_web_server as gws

# NOTE: runs at import time, before argument parsing, on every start.
# The worker below is launched with --use-flash-attn, which presumably needs
# this package to be installed.
subprocess.check_call(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation", "-U"]
)


def start_controller():
    """Start ``llava.serve.controller`` on 0.0.0.0:10000 as a child process.

    Returns:
        subprocess.Popen: handle of the controller process; the caller is
        responsible for terminating it.
    """
    print("Starting the controller")
    controller_command = [
        sys.executable,
        "-m",
        "llava.serve.controller",
        "--host",
        "0.0.0.0",
        "--port",
        "10000",
    ]
    print(controller_command)
    return subprocess.Popen(controller_command)


def start_worker(model_path: str, model_name: str, bits=16, device=0):
    """Start ``llava.serve.model_worker`` for a single model as a child process.

    The worker is pointed at the controller at http://localhost:10000 and is
    always launched with ``--use-flash-attn``.

    Args:
        model_path: value passed as ``--model-path``.
        model_name: value passed as ``--model-name``; ``-{bits}bit`` is
            appended when ``bits`` is 4 or 8.
        bits: 16 for unquantized loading, or 4 / 8 to add ``--load-{bits}bit``.
        device: an int GPU index (mapped to ``cuda:<index>``) or a device
            string that is passed through unchanged.

    Returns:
        subprocess.Popen: handle of the worker process; the caller is
        responsible for terminating it.

    Raises:
        ValueError: if ``bits`` is not 4, 8 or 16.
    """
    print(f"Starting the model worker for the model {model_path}")
    device = f"cuda:{device}" if isinstance(device, int) else device
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if bits not in (4, 8, 16):
        raise ValueError("It can be only loaded with 16-bit, 8-bit, and 4-bit.")
    if bits != 16:
        model_name += f"-{bits}bit"
    worker_command = [
        sys.executable,
        "-m",
        "llava.serve.model_worker",
        "--host",
        "0.0.0.0",
        "--controller",
        "http://localhost:10000",
        "--model-path",
        model_path,
        "--model-name",
        model_name,
        "--use-flash-attn",
        "--device",
        device,
    ]
    if bits != 16:
        worker_command += [f"--load-{bits}bit"]
    print(worker_command)
    return subprocess.Popen(worker_command)


def _hf_login():
    """Authenticate with the Hugging Face Hub (best effort).

    If the ``token`` environment variable is set, run ``huggingface-cli login``
    with it. Otherwise fall back to ``llavaguard.hf_utils.set_up_env_and_token``
    from the /workspace checkout.
    """
    api_key = os.getenv("token")
    if api_key:
        # Argument list, no shell: the token is never parsed by a shell.
        # check=False keeps the previous best-effort behaviour (os.system
        # ignored the exit status).
        try:
            subprocess.run(
                ["huggingface-cli", "login", "--token", api_key, "--add-to-git-credential"],
                check=False,
            )
        except FileNotFoundError:
            print("huggingface-cli not found; skipping Hugging Face login")
        return
    if "/workspace" not in sys.path:
        sys.path.append("/workspace")
    from llavaguard.hf_utils import set_up_env_and_token

    set_up_env_and_token(read=True, write=False)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
    parser.add_argument("--concurrency-count", type=int, default=5)
    parser.add_argument("--model-list-mode", type=str, default="reload", choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    # Configuration is handed to gradio_web_server through module attributes
    # (presumably read there as globals; gws.args.host/port/share are used below).
    gws.args = parser.parse_args()
    gws.models = []
    gws.title_markdown += (
        " ONLY WORKS WITH GPU! \n"
        "Set the environment variable `model` to change the model: "
        "['AIML-TUDA/LlavaGuard-7B'](https://huggingface.co/AIML-TUDA/LlavaGuard-7B), "
        "['AIML-TUDA/LlavaGuard-13B'](https://huggingface.co/AIML-TUDA/LlavaGuard-13B), "
        "['AIML-TUDA/LlavaGuard-34B'](https://huggingface.co/AIML-TUDA/LlavaGuard-34B), "
    )
    print(f"args: {gws.args}")

    controller_proc = start_controller()
    worker_proc = None
    exit_status = 0
    # Everything after the controller starts lives inside try/finally so the
    # child processes are cleaned up even if setup fails (bad env values,
    # failed login import, invalid `bits`, ...).
    try:
        # The env var takes precedence; otherwise honour --concurrency-count.
        concurrency_count = int(os.getenv("concurrency_count", gws.args.concurrency_count))
        _hf_login()

        models = [
            "LukasHug/LlavaGuard-7B-hf",
            "LukasHug/LlavaGuard-13B-hf",
            "LukasHug/LlavaGuard-34B-hf",
        ]
        bits = int(os.getenv("bits", 16))
        model_path = os.getenv("model", models[0])
        model_name = model_path.split("/")[-1]
        worker_proc = start_worker(model_path, model_name, bits=bits)

        # Fixed delay for the controller and worker to come up (no readiness check).
        time.sleep(10)

        demo = gws.build_demo(embed_mode=False, cur_dir="./", concurrency_count=concurrency_count)
        demo.queue(status_update_rate=10, api_open=False).launch(
            server_name=gws.args.host,
            server_port=gws.args.port,
            share=gws.args.share,
        )
    except Exception:
        traceback.print_exc()
        exit_status = 1
    finally:
        for proc in (worker_proc, controller_proc):
            if proc is not None:
                proc.kill()
                proc.wait()  # reap the child so it does not linger as a zombie
        # As before, exiting from `finally` also turns Ctrl-C into a normal exit.
        sys.exit(exit_status)