LlavaGuard / app.py
LukasHug's picture
Upload 87 files
cd043e1 verified
raw
history blame
4.09 kB
import sys
import os
import argparse
import time
import subprocess
import gradio_web_server as gws
# Execute the pip install command with additional options
# subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'flash-attn', '--no-build-isolation', '-U'])
def start_controller():
    """Launch the LLaVA controller as a background subprocess.

    The controller is run with the current interpreter
    (``sys.executable -m llava.serve.controller``) and listens on
    ``0.0.0.0:10000``. The full command list is printed before launching.

    Returns:
        subprocess.Popen: Handle to the controller process. Its stdout and
        stderr are inherited from this process.
    """
    print("Starting the controller")
    bind_host, bind_port = "0.0.0.0", "10000"
    cmd = [sys.executable, "-m", "llava.serve.controller"]
    cmd.extend(["--host", bind_host, "--port", bind_port])
    print(cmd)
    return subprocess.Popen(cmd)
def start_worker(model_path: str, model_name: str, bits=16, device=0, *, stdout=None, stderr=None):
    """Launch a LLaVA model worker that registers with the local controller.

    The worker runs as ``sys.executable -m llava.serve.model_worker`` on host
    ``0.0.0.0``, is pointed at the controller on ``http://localhost:10000``
    and is always started with ``--use-flash-attn``.

    Args:
        model_path: Value passed to ``--model-path``.
        model_name: Value passed to ``--model-name``. For 4- and 8-bit loads
            the suffix ``-{bits}bit`` is appended.
        bits: Load precision, one of 4, 8 or 16. For 4 and 8 the flag
            ``--load-{bits}bit`` is added to the command.
        device: An ``int`` GPU index (turned into ``"cuda:<index>"``) or a
            device string that is passed through unchanged.
        stdout, stderr: Forwarded to ``subprocess.Popen``. The default
            ``None`` inherits this process's streams. If ``subprocess.PIPE``
            is passed, the caller must keep draining the pipes, otherwise the
            worker blocks once the OS pipe buffer is full.

    Returns:
        subprocess.Popen: Handle of the started worker process.

    Raises:
        ValueError: If ``bits`` is not 4, 8 or 16.
    """
    print(f"Starting the model worker for the model {model_path}")
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if bits not in (4, 8, 16):
        raise ValueError("It can be only loaded with 16-bit, 8-bit, and 4-bit.")
    device = f"cuda:{device}" if isinstance(device, int) else device
    if bits != 16:
        model_name += f"-{bits}bit"
    worker_command = [
        sys.executable, "-m", "llava.serve.model_worker",
        "--host", "0.0.0.0",
        "--controller", "http://localhost:10000",
        "--model-path", model_path,
        "--model-name", model_name,
        "--use-flash-attn",
        "--device", device,
    ]
    if bits != 16:
        worker_command.append(f"--load-{bits}bit")
    print(worker_command)
    # The output used to go to PIPEs that nothing read, which can stall a
    # chatty worker; inherit the streams by default, as start_controller does.
    return subprocess.Popen(worker_command, stdout=stdout, stderr=stderr)
if __name__ == "__main__":
    import traceback  # only the entry point needs it, for full error reports

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
    parser.add_argument("--concurrency-count", type=int, default=5)
    parser.add_argument("--model-list-mode", type=str, default="reload", choices=["once", "reload"])
    parser.add_argument("--share", action="store_true")
    parser.add_argument("--moderate", action="store_true")
    parser.add_argument("--embed", action="store_true")
    gws.args = parser.parse_args()
    gws.models = []
    gws.title_markdown += """
ONLY WORKS WITH GPU!
Set the environment variable `model` to change the model:
['AIML-TUDA/LlavaGuard-7B'](https://huggingface.co/AIML-TUDA/LlavaGuard-7B),
['AIML-TUDA/LlavaGuard-13B'](https://huggingface.co/AIML-TUDA/LlavaGuard-13B),
['AIML-TUDA/LlavaGuard-34B'](https://huggingface.co/AIML-TUDA/LlavaGuard-34B),
"""
    # set_up_env_and_token(read=True)
    print(f"args: {gws.args}")

    controller_proc = start_controller()
    # The worker launch below is currently disabled. Keep the handle as None
    # so the cleanup in `finally` never hits an undefined name.
    worker_proc = None
    exit_status = 0
    # Everything after the controller starts runs inside `try` so the
    # controller is killed even if setup (login, imports, env parsing) fails.
    try:
        # The environment variable overrides the --concurrency-count flag.
        concurrency_count = int(os.getenv("concurrency_count", gws.args.concurrency_count))

        # Set the Hugging Face login token.
        api_key = os.getenv("token")
        if api_key:
            # argv list instead of a shell string: the token is never parsed by a shell.
            try:
                subprocess.run(
                    ["huggingface-cli", "login", "--token", api_key, "--add-to-git-credential"],
                    check=False,
                )
            except OSError as err:
                # Best effort, like the former os.system call: report and continue.
                print(f"huggingface-cli login could not be run: {err}")
        else:
            if '/workspace' not in sys.path:
                sys.path.append('/workspace')
            from llavaguard.hf_utils import set_up_env_and_token
            set_up_env_and_token(read=True, write=False)

        models = [
            'LukasHug/LlavaGuard-7B-hf',
            'LukasHug/LlavaGuard-13B-hf',
            'LukasHug/LlavaGuard-34B-hf',
        ]
        bits = int(os.getenv("bits", 16))
        model = os.getenv("model", models[0])
        available_devices = os.getenv("CUDA_VISIBLE_DEVICES", "0")
        model_path, model_name = model, model.split("/")[-1]
        # NOTE(review): with the worker launch disabled, model_path, model_name,
        # bits and available_devices are unused here -- confirm where the model
        # is actually served before re-enabling this line.
        # worker_proc = start_worker(model_path, model_name, bits=bits)

        # Wait for worker and controller to start
        time.sleep(10)

        demo = gws.build_demo(
            embed_mode=gws.args.embed,  # honour the parsed --embed flag
            cur_dir='./',
            concurrency_count=concurrency_count,
        )
        demo.queue(
            status_update_rate=10,
            api_open=False
        ).launch(
            server_name=gws.args.host,
            server_port=gws.args.port,
            share=gws.args.share
        )
    except Exception:
        # Report the full traceback, not just the message, and signal failure.
        traceback.print_exc()
        exit_status = 1
    finally:
        # Kill only the child processes that were actually started.
        for proc in (worker_proc, controller_proc):
            if proc is not None:
                proc.kill()
    sys.exit(exit_status)