liuhaotian commited on
Commit
14f034e
1 Parent(s): 3e4d21c
Files changed (2) hide show
  1. app.py +93 -0
  2. requirements.txt +2 -0
app.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import argparse
4
+ import time
5
+ import subprocess
6
+
7
+ import llava.serve.gradio_web_server as gws
8
+
9
+
10
+ def start_controller():
11
+ print("Starting the controller")
12
+ controller_command = [
13
+ "python",
14
+ "-m",
15
+ "llava.serve.controller",
16
+ "--host",
17
+ "0.0.0.0",
18
+ "--port",
19
+ "10000",
20
+ ]
21
+ return subprocess.Popen(controller_command)
22
+
23
+
24
+ def start_worker(model_path: str, bits=16):
25
+ print(f"Starting the model worker for the model {model_path}")
26
+ model_name = model_path.strip("/").split("/")[-1]
27
+ assert bits in [4, 8, 16], "It can be only loaded with 16-bit, 8-bit, and 4-bit."
28
+ if bits != 16:
29
+ model_name += f"-{bits}bit"
30
+ worker_command = [
31
+ "python",
32
+ "-m",
33
+ "llava.serve.model_worker",
34
+ "--host",
35
+ "0.0.0.0",
36
+ "--controller",
37
+ "http://localhost:10000",
38
+ "--model-path",
39
+ model_path,
40
+ "--model-name",
41
+ model_name,
42
+ ]
43
+ if bits != 16:
44
+ worker_command += [f"--load-{bits}bit"]
45
+ return subprocess.Popen(worker_command)
46
+
47
+
48
+ if __name__ == "__main__":
49
+ parser = argparse.ArgumentParser()
50
+ parser.add_argument("--host", type=str, default="0.0.0.0")
51
+ parser.add_argument("--port", type=int)
52
+ parser.add_argument("--controller-url", type=str, default="http://localhost:10000")
53
+ parser.add_argument("--concurrency-count", type=int, default=5)
54
+ parser.add_argument("--model-list-mode", type=str, default="reload", choices=["once", "reload"])
55
+ parser.add_argument("--share", action="store_true")
56
+ parser.add_argument("--moderate", action="store_true")
57
+ parser.add_argument("--embed", action="store_true")
58
+ gws.args = parser.parse_args()
59
+ gws.models = []
60
+
61
+ print(f"args: {gws.args}")
62
+
63
+ model_path = "liuhaotian/llava-v1.6-mistral-7b"
64
+ bits = int(os.getenv("bits", 4))
65
+ concurrency_count = int(os.getenv("concurrency_count", 5))
66
+
67
+ controller_proc = start_controller()
68
+ worker_proc = start_worker(model_path, bits=bits)
69
+
70
+ # Wait for worker and controller to start
71
+ time.sleep(10)
72
+
73
+ exit_status = 0
74
+ try:
75
+ demo = gws.build_demo(embed_mode=False, cur_dir='./')
76
+ demo.queue(
77
+ concurrency_count=concurrency_count,
78
+ status_update_rate=10,
79
+ api_open=False
80
+ ).launch(
81
+ server_name=gws.args.host,
82
+ server_port=gws.args.port,
83
+ share=gws.args.share
84
+ )
85
+
86
+ except Exception as e:
87
+ print(e)
88
+ exit_status = 1
89
+ finally:
90
+ worker_proc.kill()
91
+ controller_proc.kill()
92
+
93
+ sys.exit(exit_status)
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ llava-torch==1.2.1.post1
2
+ protobuf==4.23.3