import base64
import io
import re
import time
from dataclasses import dataclass, field
from typing import Optional

import cv2
import hydra
import numpy as np
import pyrootutils
import torch
import transformers
from diffusers import AutoencoderKL, EulerDiscreteScheduler, UNet2DConditionModel
from flask import Flask, request
from omegaconf import OmegaConf
from PIL import Image

pyrootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)

from src.data.any_res import process_anyres_image

# Special tokens. NOTE: the angle-bracket literals were stripped from the
# source (likely by HTML rendering); the spellings below are reconstructed
# from context and should be checked against the tokenizer config.
BOI_TOKEN = '<img>'
BOP_TOKEN = '<patch>'
EOI_TOKEN = '</img>'
EOP_TOKEN = '</patch>'
IMG_TOKEN = '<img_{:05d}>'
IMG_FLAG = '<image>'

num_img_in_tokens = 64
num_img_out_tokens = 64

resolution_grids = [
    '1x1', '1x2', '1x3', '1x4', '1x5', '1x6', '1x10',
    '2x1', '3x1', '4x1', '5x1', '6x1', '10x1',
    '2x2', '2x3', '3x2', '2x4', '4x2',
]
base_resolution = 448

app = Flask(__name__)


def decode_image(encoded_image: str) -> Image.Image:
    decoded_bytes = base64.b64decode(encoded_image.encode('utf-8'))
    buffer = io.BytesIO(decoded_bytes)
    image = Image.open(buffer)
    return image


def encode_image(image: Image.Image, format: str = 'PNG') -> str:
    with io.BytesIO() as buffer:
        image.save(buffer, format=format)
        encoded_image = base64.b64encode(buffer.getvalue()).decode('utf-8')
    return encoded_image


@dataclass
class Arguments:
    image_transform: Optional[str] = field(default=None, metadata={"help": "config path of image transform"})
    tokenizer: Optional[str] = field(default=None, metadata={"help": "config path of tokenizer used to initialize tokenizer"})
    llm: Optional[str] = field(default=None, metadata={"help": "config path of llm"})
    visual_encoder: Optional[str] = field(default=None, metadata={"help": "config path of visual encoder"})
    sd_adapter: Optional[str] = field(default=None, metadata={"help": "config path of sd adapter"})
    agent: Optional[str] = field(default=None, metadata={"help": "config path of agent model"})
    diffusion_path: Optional[str] = field(default=None, metadata={"help": "diffusion model path"})
    has_bbox: Optional[bool] = field(default=False, metadata={"help": "visualize the box"})
    port: Optional[int] = field(default=80, metadata={"help": "network port"})
    llm_device: Optional[str] = field(default='cuda:0', metadata={"help": "llm device"})
    vit_sd_device: Optional[str] = field(default='cuda:0', metadata={"help": "sd and vit device"})
    dtype: Optional[str] = field(default='fp16', metadata={"help": "mixed precision"})
    multi_resolution: Optional[bool] = field(default=False, metadata={"help": "multi resolution"})


parser = transformers.HfArgumentParser(Arguments)
args, = parser.parse_args_into_dataclasses()


def extract_box(output_str):
    # NOTE: the grounding token patterns were stripped from the source; the
    # '<box_start>'/'<box_end>'/'<loc-N>' spellings below are assumptions
    # reconstructed from the surrounding code.
    boxes = re.findall('<box_start>(.*?)<box_end>', output_str)
    if len(boxes) > 0:
        bboxes = [[int(num) for num in re.findall(r'<loc-(\d+)>', box)] for box in boxes]
    else:
        bboxes = None
    return bboxes


def visualize_bbox(image, bboxes):
    img_width, img_height = image.size
    image = np.array(image)
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    for bbox in bboxes:
        # Boxes are predicted as (x_center, y_center, width, height) on a
        # 224x224 canvas; rescale them to the original image resolution.
        x_center, y_center, box_width, box_height = bbox
        x_center = x_center / 224 * img_width
        y_center = y_center / 224 * img_height
        box_width = box_width / 224 * img_width
        box_height = box_height / 224 * img_height

        x1 = int(x_center - box_width / 2)
        y1 = int(y_center - box_height / 2)
        x2 = int(x_center + box_width / 2)
        y2 = int(y_center + box_height / 2)
        cv2.rectangle(image, (x1, y1), (x2, y2), (0, 255, 0), 4)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = Image.fromarray(image)
    return image
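
# Worked example of the box convention above (a sketch, not part of the
# service): for a hypothetical 448x896 input image, a predicted box
# (112, 112, 56, 56) rescales to center (224, 448) with size (112, 224),
# i.e. corners (x1, y1, x2, y2) = (168, 336, 280, 560).
#
#   boxed = visualize_bbox(Image.new('RGB', (448, 896)), [[112, 112, 56, 56]])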
class LLMService:

    def __init__(self, args) -> None:
        self.llm_device = args.llm_device
        self.vit_sd_device = args.vit_sd_device

        dtype = args.dtype
        if dtype == 'fp16':
            self.dtype = torch.float16
        elif dtype == 'bf16':
            self.dtype = torch.bfloat16
        else:
            raise ValueError(f'Unsupported dtype: {dtype}')

        image_transform_cfg = OmegaConf.load(args.image_transform)
        self.image_transform = hydra.utils.instantiate(image_transform_cfg)

        tokenizer_cfg = OmegaConf.load(args.tokenizer)
        self.tokenizer = hydra.utils.instantiate(tokenizer_cfg)

        visual_encoder_cfg = OmegaConf.load(args.visual_encoder)
        self.visual_encoder = hydra.utils.instantiate(visual_encoder_cfg)
        self.visual_encoder.eval().to(self.vit_sd_device, dtype=self.dtype)
        print('Init visual encoder done')

        llm_cfg = OmegaConf.load(args.llm)
        llm = hydra.utils.instantiate(llm_cfg, torch_dtype=self.dtype)
        print('Init llm done.')

        agent_cfg = OmegaConf.load(args.agent)
        self.agent = hydra.utils.instantiate(agent_cfg, llm=llm)
        self.agent.eval().to(self.llm_device, dtype=self.dtype)
        print('Init agent model done.')

        noise_scheduler = EulerDiscreteScheduler.from_pretrained(args.diffusion_path, subfolder="scheduler")
        vae = AutoencoderKL.from_pretrained(args.diffusion_path, subfolder="vae").to(self.vit_sd_device, dtype=self.dtype)
        unet = UNet2DConditionModel.from_pretrained(args.diffusion_path, subfolder="unet").to(dtype=self.dtype)

        sd_adapter_cfg = OmegaConf.load(args.sd_adapter)
        self.sd_adapter = hydra.utils.instantiate(sd_adapter_cfg, unet=unet).eval().to(dtype=self.dtype)
        # The SD pipeline is initialized on CPU; at request time the handler
        # swaps modules between CPU and GPU so that the LLM and the diffusion
        # model do not have to fit in GPU memory simultaneously.
        self.sd_adapter.init_pipe(vae=vae,
                                  scheduler=noise_scheduler,
                                  visual_encoder=self.visual_encoder.to("cpu"),
                                  image_transform=self.image_transform,
                                  discrete_model=None,
                                  dtype=self.dtype,
                                  device="cpu")
        print('Init sd adapter pipe done.')
        self.visual_encoder.to(self.vit_sd_device, dtype=self.dtype)

        self.boi_token_id = self.tokenizer.encode(BOI_TOKEN, add_special_tokens=False)[0]
        self.eoi_token_id = self.tokenizer.encode(EOI_TOKEN, add_special_tokens=False)[0]
        self.bop_token_id = self.tokenizer.encode(BOP_TOKEN, add_special_tokens=False)[0]
        self.eop_token_id = self.tokenizer.encode(EOP_TOKEN, add_special_tokens=False)[0]

        self.multi_resolution = args.multi_resolution
        if self.multi_resolution:
            self.base_resolution = base_resolution
            grid_pinpoints = []
            for scale in resolution_grids:
                s1, s2 = scale.split('x')
                grid_pinpoints.append([int(s1) * base_resolution, int(s2) * base_resolution])
            self.grid_pinpoints = grid_pinpoints


service = LLMService(args)
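
# A quick sketch of what multi-resolution mode sets up, assuming the grids
# above: each 'RxC' entry maps to a [R * 448, C * 448] pixel pinpoint, e.g.
# '2x3' becomes [896, 1344]. process_anyres_image() tiles an input image into
# N base-resolution patches, and the prompt then encodes the first N - 1
# patches as <patch>...</patch> spans followed by one <img>...</img> span
# (see the image_tokens_list construction in the handler below).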
@app.route('/generate', methods=['GET', 'POST'])
def generate():
    with torch.no_grad():
        request_info = request.get_json()

        text_list = request_info['text'].split(IMG_FLAG)
        image_list = request_info['images']
        max_new_tokens = request_info.get('max_new_tokens', 256)
        top_p = 0.5
        force_boi = request_info.get('force_boi', False)
        force_bbox = request_info.get('force_bbox', False)

        # Each IMG_FLAG placeholder in the text must pair with one image.
        assert len(text_list) == len(image_list) + 1

        image_tokens = BOI_TOKEN + ''.join([IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)]) + EOI_TOKEN

        input_images = []
        if len(image_list) > 0:
            image_tensor_list = []
            embeds_cmp_mask = []
            embeds_gen_mask = []

            if service.multi_resolution:
                patch_pos = []
                image_patch_length = []
                image_size_list = []

            for idx, image_item in enumerate(image_list):
                if isinstance(image_item, str):
                    image = decode_image(image_item)
                    print('after decode image size:', image.size)
                    input_images.append(image)

                    if service.multi_resolution:
                        image_size_list.append(image.size)
                        print('image size:', image.size)
                        image_tensor, patch_pos_tensor = process_anyres_image(image, service.image_transform,
                                                                              service.grid_pinpoints,
                                                                              service.base_resolution)
                        image_tensor_list.append(image_tensor)
                        patch_pos.append(patch_pos_tensor)
                        image_patch_length.append(image_tensor.shape[0])
                        print('image_patch_length', image_patch_length)
                        embeds_cmp_mask.extend([True] * image_tensor.shape[0])
                        embeds_gen_mask.extend([False] * image_tensor.shape[0])
                    else:
                        image_tensor = service.image_transform(image)
                        image_tensor_list.append(image_tensor)
                        embeds_cmp_mask.append(True)
                        embeds_gen_mask.append(False)
                else:
                    raise ValueError('image_list entries must be base64-encoded strings')

            if service.multi_resolution:
                pixel_values = torch.cat(image_tensor_list).to(service.vit_sd_device, dtype=service.dtype)
                patch_position = torch.cat(patch_pos, dim=0)

                image_tokens_list = []
                for patch_length in image_patch_length:
                    image_tokens = ''
                    for _ in range(patch_length - 1):
                        image_tokens += BOP_TOKEN + ''.join(IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOP_TOKEN
                    image_tokens += BOI_TOKEN + ''.join(IMG_TOKEN.format(int(item)) for item in range(num_img_in_tokens)) + EOI_TOKEN
                    image_tokens_list.append(image_tokens)
            else:
                pixel_values = torch.stack(image_tensor_list).to(service.vit_sd_device, dtype=service.dtype)

            image_embeds = service.visual_encoder(pixel_values)
            image_embeds = image_embeds.to(service.llm_device)

            embeds_cmp_mask = torch.tensor(embeds_cmp_mask, dtype=torch.bool).to(service.llm_device)
            embeds_gen_mask = torch.tensor(embeds_gen_mask, dtype=torch.bool).to(service.llm_device)
        else:
            image_embeds = None
            patch_position = 0
            embeds_cmp_mask = None
            embeds_gen_mask = None

        if service.multi_resolution:
            input_text = ''
            for i, c in enumerate(text_list[:-1]):
                input_text += c + image_tokens_list[i]
            input_text += text_list[-1]
        else:
            input_text = image_tokens.join(text_list)

        if force_boi:
            input_text = input_text + BOI_TOKEN

        if force_bbox:
            input_text = input_text + '[[ '

        print('input_text:', input_text)
        input_ids = service.tokenizer.encode(input_text, add_special_tokens=False)
        input_ids = [service.tokenizer.bos_token_id] + input_ids

        input_ids = torch.tensor(input_ids).to(service.llm_device, dtype=torch.long)
        ids_cmp_mask = torch.zeros_like(input_ids, dtype=torch.bool).to(service.llm_device)
        ids_gen_mask = torch.zeros_like(input_ids, dtype=torch.bool).to(service.llm_device)

        if service.multi_resolution:
            boi_indices = torch.where(torch.logical_or(input_ids == service.boi_token_id, input_ids == service.bop_token_id))[0].tolist()
            eoi_indices = torch.where(torch.logical_or(input_ids == service.eoi_token_id, input_ids == service.eop_token_id))[0].tolist()
        else:
            boi_indices = torch.where(input_ids == service.boi_token_id)[0].tolist()
            eoi_indices = torch.where(input_ids == service.eoi_token_id)[0].tolist()

        # Mark the token positions between each begin/end pair so the agent
        # substitutes the visual embeddings at those slots.
        for boi_idx, eoi_idx in zip(boi_indices, eoi_indices):
            ids_cmp_mask[boi_idx + 1:eoi_idx] = True

        input_ids = input_ids.unsqueeze(0)
        ids_cmp_mask = ids_cmp_mask.unsqueeze(0)
        ids_gen_mask = ids_gen_mask.unsqueeze(0)

        error_msg = []

        if service.multi_resolution:
            output = service.agent.generate(
                tokenizer=service.tokenizer,
                input_ids=input_ids,
                image_embeds=image_embeds,
                patch_positions=patch_position,
                embeds_cmp_mask=embeds_cmp_mask,
                ids_cmp_mask=ids_cmp_mask,
                num_img_gen_tokens=num_img_out_tokens,
                max_new_tokens=max_new_tokens,
                dtype=service.dtype,
                device=service.llm_device,
                top_p=top_p,
            )
        else:
            output = service.agent.generate(
                tokenizer=service.tokenizer,
                input_ids=input_ids,
                image_embeds=image_embeds,
                embeds_cmp_mask=embeds_cmp_mask,
                ids_cmp_mask=ids_cmp_mask,
                num_img_gen_tokens=num_img_out_tokens,
                max_new_tokens=max_new_tokens,
                dtype=service.dtype,
                device=service.llm_device,
                top_p=top_p,
            )

        gen_imgs_base64_list = []
        generated_text = output['text']
        generated_text = generated_text.replace(EOI_TOKEN, IMG_FLAG).replace(service.tokenizer.eos_token, '')

        if output['has_img_output']:
            print('loading visual encoder and llm to CPU, and sd to GPU')
            a = time.time()
            service.agent = service.agent.to("cpu")
            service.sd_adapter = service.sd_adapter.to(service.vit_sd_device, dtype=service.dtype)
            print("Loading finished: ", time.time() - a)

            img_gen_feat = output['img_gen_feat'].to(service.vit_sd_device, dtype=service.dtype)

            for img_idx in range(output['num_gen_imgs']):
                img_feat = img_gen_feat[img_idx:img_idx + 1]
                generated_image = service.sd_adapter.generate(image_embeds=img_feat, num_inference_steps=50)[0]
                image_base64 = encode_image(generated_image)
                gen_imgs_base64_list.append(image_base64)

            print('loading visual encoder and llm to GPU, and sd to CPU')
            a = time.time()
            service.sd_adapter = service.sd_adapter.to("cpu")
            service.visual_encoder = service.visual_encoder.to(service.vit_sd_device, dtype=service.dtype)
            # Move the agent back to the device it was initialized on.
            service.agent = service.agent.to(service.llm_device, dtype=service.dtype)
            print("Loading finished: ", time.time() - a)

        if args.has_bbox:
            bboxes = extract_box(generated_text)
            if bboxes is not None and len(input_images) > 0:
                image_viz = visualize_bbox(input_images[0], bboxes)
                image_base64 = encode_image(image_viz)
                gen_imgs_base64_list.append(image_base64)
                # Pattern reconstructed; token spellings are assumed (see the
                # note in extract_box).
                generated_text = re.sub(r'\[\[ <box_start>.*?<box_end>.*?\]\]', 'the green bounding box', generated_text)
                generated_text += IMG_FLAG

        print(input_text + generated_text)

        return {'text': generated_text, 'images': gen_imgs_base64_list, 'error_msg': error_msg}
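
# Example request payload for /generate (a sketch; field values are
# illustrative). 'text' uses IMG_FLAG as a placeholder for each entry in
# 'images', so len(text.split(IMG_FLAG)) == len(images) + 1:
#
#   payload = {
#       'text': 'Here is an image: <image> Describe it.',
#       'images': ['<base64-encoded PNG>'],
#       'max_new_tokens': 256,
#       'force_boi': False,
#       'force_bbox': False,
#   }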
if __name__ == '__main__':
    app.run(host='0.0.0.0', port=args.port)
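
# Hedged client-side sketch (run in a separate process once the server is up;
# the host, port, and file name here are illustrative):
#
#   import requests
#   from PIL import Image
#   img_b64 = encode_image(Image.open('example.jpg'))  # 'example.jpg' is hypothetical
#   resp = requests.post('http://localhost:80/generate',
#                        json={'text': 'Describe this image: <image>', 'images': [img_b64]})
#   print(resp.json()['text'])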