mattb512 committed on
Commit 62f32cb (2 parents: 54e6841, c10578e)

Merge pull request #2 from TRI-ML/master

interactive_demo.py CHANGED
@@ -47,20 +47,12 @@ def heart_beat_worker(controller):
 
 
 class ModelWorker:
-    def __init__(self, controller_addr, worker_addr, worker_id, no_register, vlm, model_base, model_name):
+    def __init__(self, controller_addr, worker_addr, worker_id, no_register, vlm, model_name):
         self.controller_addr = controller_addr
         self.worker_addr = worker_addr
         self.worker_id = worker_id
         self.model_name = model_name
-
-        # logger.info(f"Loading the model {self.model_name} on worker {worker_id} ...")
         self.vlm = vlm
-        self.tokenizer, self.model, self.image_processor, self.context_len = (
-            vlm.tokenizer,
-            vlm.model,
-            vlm.image_processor,
-            vlm.max_length,
-        )
 
         if not no_register:
             self.register_to_controller()
@@ -68,18 +60,12 @@ class ModelWorker:
             self.heart_beat_thread.start()
 
     def register_to_controller(self):
-        # logger.info("Register to controller")
-
         url = self.controller_addr + "/register_worker"
         data = {"worker_name": self.worker_addr, "check_heart_beat": True, "worker_status": self.get_status()}
         r = requests.post(url, json=data)
         assert r.status_code == 200
 
     def send_heart_beat(self):
-        # logger.info(f"Send heart beat. Models: {[self.model_name]}. "
-        #             f"Semaphore: {pretty_print_semaphore(model_semaphore)}. "
-        #             f"global_counter: {global_counter}")
-
         url = self.controller_addr + "/receive_heart_beat"
 
         while True:
@@ -91,7 +77,6 @@ class ModelWorker:
                 break
             except requests.exceptions.RequestException:
                 pass
-                # logger.error(f"heart beat error: {e}")
             time.sleep(5)
 
         if not exist:
@@ -145,12 +130,12 @@ class ModelWorker:
         else:
             question_prompt = [prompt_fn()]
 
-        if isinstance(self.image_processor, Compose) or hasattr(self.image_processor, "is_prismatic"):
+        if isinstance(self.vlm.image_processor, Compose) or hasattr(self.vlm.image_processor, "is_prismatic"):
             # This is a standard `torchvision.transforms` object or custom PrismaticVLM wrapper
-            pixel_values = self.image_processor(images[0].convert("RGB"))
+            pixel_values = self.vlm.image_processor(images[0].convert("RGB"))
         else:
             # Assume `image_transform` is a HF ImageProcessor...
-            pixel_values = self.image_processor(images[0].convert("RGB"), return_tensors="pt")["pixel_values"][0]
+            pixel_values = self.vlm.image_processor(images[0].convert("RGB"), return_tensors="pt")["pixel_values"][0]
 
         if type(pixel_values) is dict:
             for k in pixel_values.keys():
@@ -227,31 +212,29 @@ overwatch = initialize_overwatch(__name__)
 class DemoConfig:
     # fmt: off
 
-    # === Model Parameters =>> Quartz ===
-    model_family: str = "quartz"                    # Model family to load from in < `quartz` | `llava-v15` | ... >
-    model_id: str = "llava-v1.5-7b"                 # Model ID to load and run (instance of `model_family`)
-    model_dir: Path = (                             # Path to model checkpoint to load --> should be self-contained
-        "resize-naive-siglip-vit-l-16-384px-no-align-2-epochs+13b+stage-finetune+x7"
-    )
+    # === Model Parameters =>> Prismatic ===
+    model_family: str = "prismatic"                 # Model family to load from in < `prismatic` | `llava-v15` | ... >
+    model_id: str = "prism-dinosiglip+7b"           # Model ID to load and run (instance of `model_family`)
+    model_dir: str = None                           # Can optionally supply model_dir instead of model_id
 
     # === Model Parameters =>> Official LLaVa ===
     # model_family: str = "llava-v15"
    # model_id: str = "llava-v1.5-13b"
     # model_dir: Path = "liuhaotian/llava-v1.5-13b"
 
+    # === Model Parameters =>> Official InstructBLIP ===
+    # model_family: str = "instruct-blip"
+    # model_id: str = "instructblip-vicuna-7b"
+    # model_dir: Path = "Salesforce/instructblip-vicuna-7b"
+
     # Model Worker Parameters
     host: str = "0.0.0.0"
     port: int = 40000
     controller_address: str = "http://localhost:10000"
-    model_base: str = "llava-v15"
     limit_model_concurrency: int = 5
     stream_interval: int = 1
     no_register: bool = False
 
-    # Inference Parameters
-    device_batch_size: int = 1                      # Device Batch Size set to 1 until LLaVa/HF LLaMa fixes bugs!
-    num_workers: int = 2                            # Number of Dataloader Workers (on each process)
-
     # HF Hub Credentials (for LLaMa-2)
     hf_token: Union[str, Path] = Path(".hf_token")  # Environment variable or Path to HF Token
 
@@ -259,14 +242,8 @@ class DemoConfig:
     seed: int = 21                                  # Random Seed (for reproducibility)
 
     def __post_init__(self) -> None:
-        if self.model_family == "quartz":
-            self.model_name = MODEL_ID_TO_NAME[str(self.model_dir)]
-            self.run_dir = Path("/mnt/fsx/x-onyx-vlms/runs") / self.model_dir
-        elif self.model_family in {"instruct-blip", "llava", "llava-v15"}:
-            self.model_name = MODEL_ID_TO_NAME[self.model_id]
-            self.run_dir = self.model_dir
-        else:
-            raise ValueError(f"Run Directory for `{self.model_family = }` does not exist!")
+        self.run_dir = self.model_dir
+        self.model_name = MODEL_ID_TO_NAME[str(self.model_id)]
         self.worker_address = f"http://localhost:{self.port}"
 
     # fmt: on
@@ -286,7 +263,7 @@ def interactive_demo(cfg: DemoConfig):
     global limit_model_concurrency
     limit_model_concurrency = cfg.limit_model_concurrency
     worker = ModelWorker(
-        cfg.controller_address, cfg.worker_address, worker_id, cfg.no_register, vlm, cfg.model_base, cfg.model_name
+        cfg.controller_address, cfg.worker_address, worker_id, cfg.no_register, vlm, cfg.model_name
     )
     uvicorn.run(app, host=cfg.host, port=cfg.port, log_level="info")
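Editor's note (not part of the commit): the functional change above is that the worker no longer unpacks `tokenizer` / `model` / `image_processor` from the VLM at construction time and instead reads `self.vlm.image_processor` when preprocessing. Below is a minimal sketch of that preprocessing branch, assuming a loaded `vlm` object that exposes `image_processor` as in the diff; `preprocess_image` is a hypothetical helper, not a function in this repo.

    from PIL import Image
    from torchvision.transforms import Compose

    def preprocess_image(vlm, image: Image.Image):
        """Hypothetical helper mirroring the branch in ModelWorker: defer to the VLM's own transform."""
        transform = vlm.image_processor
        if isinstance(transform, Compose) or hasattr(transform, "is_prismatic"):
            # A torchvision pipeline (or Prismatic wrapper) is applied directly to the PIL image.
            return transform(image.convert("RGB"))
        # Otherwise assume a Hugging Face ImageProcessor that returns a BatchFeature.
        return transform(image.convert("RGB"), return_tensors="pt")["pixel_values"][0]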
serve/__init__.py CHANGED
@@ -5,31 +5,24 @@ from collections import OrderedDict
 MODEL_ID_TO_NAME = OrderedDict(
     [
         (
-            "llava-lvis4v-lrv+redux-lvis4v-lrv-resize-naive-dinosiglip-vit-so-14-384px-no-align+13b+stage-finetune+x7",
-            "PrismaticVLM 13B - Chat",
-        ),
-        (
-            "llava-lvis4v-lrv+redux-lvis4v-lrv-resize-naive-dinosiglip-vit-so-14-384px-no-align+7b+stage-finetune+x7",
-            "PrismaticVLM 7B - Chat",
-        ),
-        (
-            "llava-lvis4v-lrv+redux-lvis4v-lrv-resize-naive-dinosiglip-vit-so-14-384px-no-align-llama2pure+13b+stage-finetune+x7",
+            "prism-dinosiglip+13b",
             "PrismaticVLM 13B",
         ),
         (
-            "llava-lvis4v-lrv+redux-lvis4v-lrv-resize-naive-dinosiglip-vit-so-14-384px-no-align-llama2pure+7b+stage-finetune+x7",
+            "prism-dinosiglip+7b",
             "PrismaticVLM 7B",
         ),
         (
-            "redux-resize-naive-dinosiglip-vit-so-14-384px-no-align-llama2pure+13b+stage-finetune+x7",
+            "prism-dinosiglip-controlled+13b",
             "PrismaticVLM 13B (Controlled)",
         ),
         (
-            "redux-resize-naive-dinosiglip-vit-so-14-384px-no-align-llama2pure+7b+stage-finetune+x7",
+            "prism-dinosiglip-controlled+7b",
             "PrismaticVLM 7B (Controlled)",
         ),
-        ("llava-v1.5-13b", "LLaVA 1.5: 13B"),
-        ("llava-v1.5-7b", "LLaVA 1.5: 7B"),
+        ("llava-v1.5-13b", "LLaVA 1.5 13B"),
+        ("llava-v1.5-7b", "LLaVA 1.5 7B"),
+        ("instructblip-vicuna-7b", "InstructBLIP 7B"),
     ]
 )
 
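Editor's note (illustrative, not part of the commit): `MODEL_ID_TO_NAME` now keys on the short released model IDs, which is what `DemoConfig.__post_init__` looks up in interactive_demo.py. A minimal usage sketch, assuming the package is importable as `serve`:

    from serve import MODEL_ID_TO_NAME

    model_id = "prism-dinosiglip+7b"           # one of the keys added in this commit
    display_name = MODEL_ID_TO_NAME[model_id]  # -> "PrismaticVLM 7B"
    print(display_name)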
serve/gradio_web_server.py CHANGED
@@ -93,24 +93,6 @@ def vote_last_response(state, vote_type, model_selector, request: gr.Request):
         fout.write(json.dumps(data) + "\n")
 
 
-# def upvote_last_response(state, model_selector, request: gr.Request):
-#     logger.info(f"upvote. ip: {request.client.host}")
-#     vote_last_response(state, "upvote", model_selector, request)
-#     return ("",) + (disable_btn,) * 3
-
-
-# def downvote_last_response(state, model_selector, request: gr.Request):
-#     logger.info(f"downvote. ip: {request.client.host}")
-#     vote_last_response(state, "downvote", model_selector, request)
-#     return ("",) + (disable_btn,) * 3
-
-
-# def flag_last_response(state, model_selector, request: gr.Request):
-#     logger.info(f"flag. ip: {request.client.host}")
-#     vote_last_response(state, "flag", model_selector, request)
-#     return ("",) + (disable_btn,) * 3
-
-
 def regenerate(state, image_process_mode, request: gr.Request):
     logger.info(f"regenerate. ip: {request.client.host}")
     state.messages[-1][-1] = None
@@ -388,15 +370,6 @@ def build_demo(embed_mode):
 
     # Register listeners
     btn_list = [regenerate_btn, clear_btn]
-    # upvote_btn.click(
-    #     upvote_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn], queue=False
-    # )
-    # downvote_btn.click(
-    #     downvote_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn], queue=False
-    # )
-    # flag_btn.click(
-    #     flag_last_response, [state, model_selector], [textbox, upvote_btn, downvote_btn, flag_btn], queue=False
-    # )
 
     regenerate_btn.click(
         regenerate, [state, image_process_mode], [state, chatbot, textbox, imagebox, *btn_list], queue=False
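Editor's note (illustrative, not part of the commit): with the commented-out voting handlers removed, only the regenerate and clear listeners stay registered. A self-contained toy sketch of that `Button.click` wiring pattern in Gradio; the component names here are stand-ins, not the server's actual objects.

    import gradio as gr

    def regenerate(history):
        # Drop the last assistant reply so it can be produced again (toy behavior).
        if history:
            history[-1] = [history[-1][0], None]
        return history

    def clear_history():
        return []

    with gr.Blocks() as demo:
        chatbot = gr.Chatbot()
        regenerate_btn = gr.Button("Regenerate")
        clear_btn = gr.Button("Clear History")
        btn_list = [regenerate_btn, clear_btn]

        # Register listeners: only regenerate and clear, mirroring the post-commit wiring.
        regenerate_btn.click(regenerate, [chatbot], [chatbot], queue=False)
        clear_btn.click(clear_history, None, [chatbot], queue=False)

    if __name__ == "__main__":
        demo.launch()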