"""
interactive_demo.py

Entry point for all VLM-Bench interactive demos; specify a model and get a Gradio UI where you can chat with it!

This file is heavily adapted from the script used to serve models in the LLaVA repo:
https://github.com/haotian-liu/LLaVA/blob/main/llava/serve/model_worker.py. It is
modified to ensure compatibility with our Prismatic models.
"""
import asyncio
import json
import os
import threading
import time
import uuid
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Union

import draccus
import requests
import torch
import uvicorn
from accelerate.utils import set_seed
from fastapi import BackgroundTasks, FastAPI, Request
from fastapi.responses import StreamingResponse
from llava.constants import WORKER_HEART_BEAT_INTERVAL
from llava.mm_utils import load_image_from_base64
from llava.utils import server_error_msg
from torchvision.transforms import Compose

from vlbench.models import load_vlm
from vlbench.overwatch import initialize_overwatch
from serve import INTERACTION_MODES_MAP, MODEL_ID_TO_NAME

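# Global worker state: a short random worker ID, a request counter, and a semaphore
# (created lazily in the streaming endpoint) that bounds concurrent generation calls.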
GB = 1 << 30
worker_id = str(uuid.uuid4())[:6]
global_counter = 0
model_semaphore = None


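# Background thread target: periodically pings the controller so it knows this worker is still alive.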
def heart_beat_worker(controller):
    while True:
        time.sleep(WORKER_HEART_BEAT_INTERVAL)
        controller.send_heart_beat()


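# Wraps a loaded VLM for serving: registers with the controller, reports status via heartbeats, and streams generations.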
class ModelWorker:
    def __init__(self, controller_addr, worker_addr, worker_id, no_register, vlm, model_base, model_name):
        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.worker_id = worker_id
        self.model_name = model_name

        self.vlm = vlm
        self.tokenizer, self.model, self.image_processor, self.context_len = (
            vlm.tokenizer,
            vlm.model,
            vlm.image_processor,
            vlm.max_length,
        )

        if not no_register:
            self.register_to_controller()
            self.heart_beat_thread = threading.Thread(target=heart_beat_worker, args=(self,))
            self.heart_beat_thread.start()

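    # Announce this worker (and its current status) to the controller so requests can be routed here.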
    def register_to_controller(self):
        url = self.controller_addr + "/register_worker"
        data = {"worker_name": self.worker_addr, "check_heart_beat": True, "worker_status": self.get_status()}
        r = requests.post(url, json=data)
        assert r.status_code == 200

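    # Report liveness and queue length to the controller, retrying on connection errors; re-register if the
    # controller no longer knows about this worker.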
    def send_heart_beat(self):
        url = self.controller_addr + "/receive_heart_beat"

        while True:
            try:
                ret = requests.post(
                    url, json={"worker_name": self.worker_addr, "queue_length": self.get_queue_length()}, timeout=5
                )
                exist = ret.json()["exist"]
                break
            except requests.exceptions.RequestException:
                pass

            time.sleep(5)

        if not exist:
            self.register_to_controller()

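    # Number of requests currently holding or waiting on the model semaphore (0 before the first request).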
    def get_queue_length(self):
        if model_semaphore is None:
            return 0
        else:
            return (
                limit_model_concurrency
                - model_semaphore._value
                + (len(model_semaphore._waiters) if model_semaphore._waiters is not None else 0)
            )

    def get_status(self):
        return {
            "model_names": [self.model_name],
            "speed": 1,
            "queue_length": self.get_queue_length(),
        }

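    # Core generation path: configure sampling, decode the (single) base64 image, build the prompt for the
    # requested interaction mode, and yield one null-delimited JSON chunk with the full response.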
    @torch.inference_mode()
    def generate_stream(self, params):
        prompt = params["prompt"]
        ori_prompt = prompt
        images = params.get("images", None)

        temperature = params.get("temperature", 0.2)
        max_new_tokens = params.get("max_new_tokens", 2048)
        interaction_mode = INTERACTION_MODES_MAP[params.get("interaction_mode", "Chat")]

        if temperature != 0:
            self.vlm.set_generate_kwargs(
                {"do_sample": True, "max_new_tokens": max_new_tokens, "temperature": temperature}
            )
        else:
            self.vlm.set_generate_kwargs({"do_sample": False, "max_new_tokens": max_new_tokens})

        if images is not None and len(images) == 1:
            images = [load_image_from_base64(image) for image in images]
        else:
            raise NotImplementedError("Only supports queries with one image for now")

        if interaction_mode == "chat":
            question_prompt = [prompt]
        else:
            prompt_fn = self.vlm.get_prompt_fn(interaction_mode)
            if interaction_mode != "captioning":
                question_prompt = [prompt_fn(prompt)]
            else:
                question_prompt = [prompt_fn()]

        if isinstance(self.image_processor, Compose) or hasattr(self.image_processor, "is_prismatic"):
            pixel_values = self.image_processor(images[0].convert("RGB"))
        else:
            pixel_values = self.image_processor(images[0].convert("RGB"), return_tensors="pt")["pixel_values"][0]

        generated_text = self.vlm.generate_answer(torch.unsqueeze(pixel_values.cuda(), 0), question_prompt)[0]
        generated_text = generated_text.split("USER")[0].split("ASSISTANT")[0]
        yield json.dumps({"text": ori_prompt + generated_text, "error_code": 0}).encode() + b"\0"

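    # Wrap `generate_stream` so failures surface to the client as JSON error chunks instead of killing the stream.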
    def generate_stream_gate(self, params):
        try:
            for x in self.generate_stream(params):
                yield x
        except ValueError as e:
            print("Caught ValueError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
        except torch.cuda.CudaError as e:
            print("Caught torch.cuda.CudaError:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"
        except Exception as e:
            print("Caught Unknown Error:", e)
            ret = {
                "text": server_error_msg,
                "error_code": 1,
            }
            yield json.dumps(ret).encode() + b"\0"


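# FastAPI app exposing the worker's generation and status endpoints.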
app = FastAPI()


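# Release a concurrency slot once a stream finishes; optionally invoke a follow-up callback (e.g., a heartbeat).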
def release_model_semaphore(fn=None):
    model_semaphore.release()
    if fn is not None:
        fn()


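# Streaming generation endpoint; concurrency is bounded by `limit_model_concurrency` via the shared semaphore.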
@app.post("/worker_generate_stream") |
|
async def generate_stream(request: Request): |
|
global model_semaphore, global_counter |
|
global_counter += 1 |
|
params = await request.json() |
|
|
|
if model_semaphore is None: |
|
model_semaphore = asyncio.Semaphore(limit_model_concurrency) |
|
await model_semaphore.acquire() |
|
worker.send_heart_beat() |
|
generator = worker.generate_stream_gate(params) |
|
background_tasks = BackgroundTasks() |
|
background_tasks.add_task(partial(release_model_semaphore, fn=worker.send_heart_beat)) |
|
return StreamingResponse(generator, background=background_tasks) |
|
|
|
|
|
@app.post("/worker_get_status") |
|
async def get_status(request: Request): |
|
return worker.get_status() |
|
|
|
|
|
|
|
overwatch = initialize_overwatch(__name__)


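# Draccus-parsed configuration for the interactive demo worker.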
@dataclass
class DemoConfig:
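    # Model specification (family, ID, and checkpoint directory)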
    model_family: str = "quartz"
    model_id: str = "llava-v1.5-7b"
    model_dir: Path = Path(
        "resize-naive-siglip-vit-l-16-384px-no-align-2-epochs+13b+stage-finetune+x7"
    )

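    # Serving parameters for the FastAPI worker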
    host: str = "0.0.0.0"
    port: int = 40000
    controller_address: str = "http://localhost:10000"
    model_base: str = "llava-v15"
    limit_model_concurrency: int = 5
    stream_interval: int = 1
    no_register: bool = False

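    # Batching / dataloading parameters (not referenced directly in this script)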
    device_batch_size: int = 1
    num_workers: int = 2

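    # HF Hub credential: either a Path to a token file or the name of an environment variable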
    hf_token: Union[str, Path] = Path(".hf_token")

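    # Random seed (fixed for reproducibility)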
    seed: int = 21

    def __post_init__(self) -> None:
        if self.model_family == "quartz":
            self.model_name = MODEL_ID_TO_NAME[str(self.model_dir)]
            self.run_dir = Path("/mnt/fsx/x-onyx-vlms/runs") / self.model_dir
        elif self.model_family in {"instruct-blip", "llava", "llava-v15"}:
            self.model_name = MODEL_ID_TO_NAME[self.model_id]
            self.run_dir = self.model_dir
        else:
            raise ValueError(f"Run Directory for `{self.model_family = }` does not exist!")

        self.worker_address = f"http://localhost:{self.port}"


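# Entry point: load the VLM, construct the ModelWorker, and serve the FastAPI app with uvicorn.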
@draccus.wrap()
def interactive_demo(cfg: DemoConfig):
    set_seed(cfg.seed)

    overwatch.info("Initializing VLM =>> Bundling Models, Image Processors, and Tokenizer")
    hf_token = cfg.hf_token.read_text().strip() if isinstance(cfg.hf_token, Path) else os.environ[cfg.hf_token]
    vlm = load_vlm(cfg.model_family, cfg.model_id, cfg.run_dir, hf_token=hf_token)

    global worker
    global limit_model_concurrency
    limit_model_concurrency = cfg.limit_model_concurrency
    worker = ModelWorker(
        cfg.controller_address, cfg.worker_address, worker_id, cfg.no_register, vlm, cfg.model_base, cfg.model_name
    )
    uvicorn.run(app, host=cfg.host, port=cfg.port, log_level="info")


if __name__ == "__main__":
    interactive_demo()