import base64
import datetime
import json
import logging
import logging.handlers
import os
import sys
from io import BytesIO

import requests
import torch
from PIL import Image
from transformers import StoppingCriteria

from xmodelvlm.constants import LOGDIR, IMAGE_TOKEN_INDEX

server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"
moderation_msg = "YOUR INPUT VIOLATES OUR CONTENT MODERATION GUIDELINES. PLEASE TRY AGAIN."

# Shared rotating file handler; created lazily by the first build_logger() call.
handler = None


def load_image_from_base64(image):
    """Decode a base64-encoded image payload and open it as a PIL image."""
    return Image.open(BytesIO(base64.b64decode(image)))


def expand2square(pil_img, background_color):
    """Pad *pil_img* to a square canvas filled with *background_color*.

    The image is centred along its shorter axis. An image that is already
    square is returned unchanged (the same object, not a copy).
    """
    width, height = pil_img.size
    if width == height:
        return pil_img
    if width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    result = Image.new(pil_img.mode, (height, height), background_color)
    result.paste(pil_img, ((height - width) // 2, 0))
    return result


def process_images(images, image_processor, model_cfg):
    """Run *images* through *image_processor*, optionally padding to square.

    When ``model_cfg.image_aspect_ratio`` is ``'pad'`` every image is first
    padded to a square using the processor's ``image_mean`` (scaled to 0-255)
    as the fill colour and then preprocessed individually. Otherwise the whole
    batch is handed to the processor in a single call.

    Returns:
        The processor's ``pixel_values`` for the non-pad path. For the pad
        path, a stacked tensor when all per-image results share one shape,
        otherwise a list of per-image tensors (an empty list for no images).
    """
    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
    if image_aspect_ratio != 'pad':
        return image_processor(images, return_tensors='pt')['pixel_values']

    fill_color = tuple(int(x * 255) for x in image_processor.image_mean)
    new_images = []
    for image in images:
        image = expand2square(image, fill_color)
        image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        new_images.append(image)

    # torch.stack rejects an empty list, so only stack when there is
    # something to stack and every tensor has the same shape.
    if new_images and all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images


def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    """Tokenize *prompt*, replacing each ``<image>`` placeholder with *image_token_index*.

    The prompt is split on the literal ``<image>`` placeholder, each text
    chunk is tokenized on its own, and the chunks are re-joined with a single
    ``image_token_index`` between them. If the tokenizer prepends a BOS token
    to every chunk, only the first one is kept.

    Args:
        prompt: Text that may contain ``<image>`` placeholders.
        tokenizer: Callable returning an object with ``input_ids`` and
            exposing ``bos_token_id``.
        image_token_index: Id inserted in place of each placeholder.
        return_tensors: ``None`` for a plain list, ``'pt'`` for a
            ``torch.long`` tensor.

    Raises:
        ValueError: If *return_tensors* is neither ``None`` nor ``'pt'``.
    """
    # An empty separator makes str.split raise ValueError, so the placeholder
    # must be spelled out here.
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split('<image>')]

    def insert_separator(X, sep):
        # Interleave ``sep`` between the elements of X (no trailing separator).
        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    # The separator holds ``offset + 1`` copies of the image token so that the
    # ``x[offset:]`` slice, which drops the leading BOS from each text chunk,
    # leaves exactly one image token per placeholder.
    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids


def get_model_name_from_path(model_path):
    """Derive a model name from the last component(s) of *model_path*.

    For a path ending in ``checkpoint-*`` the parent directory name is
    prepended (``parent_checkpoint-N``); otherwise the last component is
    returned as-is.
    """
    model_path = model_path.strip("/")
    model_paths = model_path.split("/")
    # A bare "checkpoint-N" has no parent component to prepend.
    if model_paths[-1].startswith('checkpoint-') and len(model_paths) > 1:
        return model_paths[-2] + "_" + model_paths[-1]
    return model_paths[-1]


class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once any of the given keywords appears in the output.

    A keyword is detected either by an exact token-id match on the tail of
    the sequence or by a substring match on the decoded tail.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        """Pre-tokenize *keywords* and record the prompt length.

        Args:
            keywords: Strings that terminate generation.
            tokenizer: Tokenizer used both to encode the keywords and to
                decode generated ids.
            input_ids: Prompt ids; ``input_ids.shape[1]`` is stored as the
                position where generated tokens start.
        """
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Drop a leading BOS so the ids can match mid-sequence.
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            self.max_keyword_len = max(self.max_keyword_len, len(cur_keyword_ids))
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True when the generated tail contains one of the keywords."""
        assert output_ids.shape[0] == 1, "Only support batch size 1 (yet)"  # TODO
        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            # torch.equal is False on a shape mismatch (sequence shorter than
            # the keyword) where an elementwise ``==`` would raise.
            if torch.equal(output_ids[0, -keyword_id.shape[0]:], keyword_id):
                return True
        # ``-0:`` would select the whole sequence, prompt included, so only
        # decode when at least one generated token is in the window.
        if offset > 0:
            outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
            if any(keyword in outputs for keyword in self.keywords):
                return True
        return False


def build_logger(logger_name, logger_filename):
    """Create a logger and redirect stdout/stderr into the logging system.

    Side effects: applies a common formatter to the first root handler,
    replaces ``sys.stdout`` and ``sys.stderr`` with ``StreamToLogger``
    wrappers, and, on the first call only, creates a daily-rotating file
    handler under ``LOGDIR`` and attaches it to every logger that exists at
    that moment.

    Args:
        logger_name: Name of the logger to return.
        logger_filename: File name (inside ``LOGDIR``) for the shared handler.

    Returns:
        The ``logging.Logger`` named *logger_name*, set to INFO level.
    """
    global handler

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Set the format of root handlers
    if not logging.getLogger().handlers:
        logging.basicConfig(level=logging.INFO)
    logging.getLogger().handlers[0].setFormatter(formatter)

    # Redirect stdout and stderr to loggers
    stdout_logger = logging.getLogger("stdout")
    stdout_logger.setLevel(logging.INFO)
    sl = StreamToLogger(stdout_logger, logging.INFO)
    sys.stdout = sl

    stderr_logger = logging.getLogger("stderr")
    stderr_logger.setLevel(logging.ERROR)
    sl = StreamToLogger(stderr_logger, logging.ERROR)
    sys.stderr = sl

    # Get logger
    logger = logging.getLogger(logger_name)
    logger.setLevel(logging.INFO)

    # Add a file handler for all loggers
    if handler is None:
        os.makedirs(LOGDIR, exist_ok=True)
        filename = os.path.join(LOGDIR, logger_filename)
        handler = logging.handlers.TimedRotatingFileHandler(
            filename, when='D', utc=True)
        handler.setFormatter(formatter)

        # loggerDict may also hold placeholder entries; only real loggers
        # can take a handler.
        for item in logging.root.manager.loggerDict.values():
            if isinstance(item, logging.Logger):
                item.addHandler(handler)

    return logger


class StreamToLogger(object):
    """
    Fake file-like stream object that redirects writes to a logger instance.
    """

    def __init__(self, logger, log_level=logging.INFO):
        # Remember the stream that was installed as stdout at construction
        # time; unknown attributes are delegated to it by __getattr__.
        self.terminal = sys.stdout
        self.logger = logger
        self.log_level = log_level
        self.linebuf = ''

    def __getattr__(self, attr):
        return getattr(self.terminal, attr)

    def write(self, buf):
        """Log every complete line in *buf*; buffer a trailing partial line."""
        temp_linebuf = self.linebuf + buf
        self.linebuf = ''
        for line in temp_linebuf.splitlines(True):
            # From the io.TextIOWrapper docs:
            #   On output, if newline is None, any '\n' characters written
            #   are translated to the system default line separator.
            # By default sys.stdout.write() expects '\n' newlines and then
            # translates them so this is still cross platform.
            if line[-1] == '\n':
                self.logger.log(self.log_level, line.rstrip())
            else:
                self.linebuf += line

    def flush(self):
        """Log whatever partial line is still buffered and clear the buffer."""
        if self.linebuf != '':
            self.logger.log(self.log_level, self.linebuf.rstrip())
        self.linebuf = ''


def disable_torch_init():
    """
    Disable the redundant torch default initialization to accelerate model creation.
    """
    import torch
    setattr(torch.nn.Linear, "reset_parameters", lambda self: None)
    setattr(torch.nn.LayerNorm, "reset_parameters", lambda self: None)


def violates_moderation(text):
    """
    Check whether the text violates OpenAI moderation API.

    Returns the ``flagged`` field of the first result, or ``False`` when the
    request fails or the response does not have the expected shape.

    Raises:
        KeyError: If the ``OPENAI_API_KEY`` environment variable is not set.
    """
    url = "https://api.openai.com/v1/moderations"
    headers = {"Content-Type": "application/json",
               "Authorization": "Bearer " + os.environ["OPENAI_API_KEY"]}
    text = text.replace("\n", "")
    # Serialize with json.dumps so quotes, backslashes and control characters
    # in the text are escaped instead of corrupting the request body.
    # ensure_ascii=False keeps non-ASCII text as raw UTF-8 in the payload.
    data = json.dumps({"input": text}, ensure_ascii=False).encode("utf-8")
    try:
        ret = requests.post(url, headers=headers, data=data, timeout=5)
        flagged = ret.json()["results"][0]["flagged"]
    except requests.exceptions.RequestException:
        flagged = False
    except (KeyError, IndexError, TypeError, ValueError):
        # Unexpected or non-JSON response body: best effort, treat as not flagged.
        flagged = False
    return flagged


def pretty_print_semaphore(semaphore):
    """Return a short description of *semaphore*, or ``"None"`` if it is None."""
    if semaphore is None:
        return "None"
    return f"Semaphore(value={semaphore._value}, locked={semaphore.locked()})"