import os
import re
import subprocess
import sys
import time
from threading import Thread

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer, StoppingCriteria
from PIL import Image
import spaces

# Best-effort install of flash-attn at startup; the return code is not checked.
# FLASH_ATTENTION_SKIP_CUDA_BUILD skips compiling the CUDA kernels. The current
# environment is merged in (not replaced) so PATH, HOME, etc. are preserved.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)

MODEL_ID = "qnguyen3/nanoLLaVA"

torch.set_default_device("cuda")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)

# Literal tag marking where the image belongs in the chat-formatted prompt.
IMAGE_TOKEN = "<image>"
# Placeholder id spliced into input_ids at the tag position. Presumably the
# model's remote code replaces it with image features (LLaVA convention) -- confirm.
IMAGE_TOKEN_INDEX = -200
# Generation stops once this string is produced.
STOP_STR = "<|im_end|>"


class KeywordsStoppingCriteria(StoppingCriteria):
    """Stop generation once every sequence in the batch ends with one of ``keywords``.

    A sequence matches when its trailing token ids equal a keyword's token ids,
    or when the decoded tail of newly generated tokens contains the keyword text.
    """

    def __init__(self, keywords, tokenizer, input_ids):
        """
        Args:
            keywords: strings that end generation.
            tokenizer: tokenizer used to encode keywords and decode outputs.
            input_ids: prompt ids; only ``input_ids.shape[1]`` (prompt length) is kept.
        """
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            # Drop a leading BOS so the ids match the keyword as it appears mid-sequence.
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            self.max_keyword_len = max(self.max_keyword_len, len(cur_keyword_ids))
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True if the single sequence ``output_ids`` (shape (1, L)) ends with a keyword."""
        # Keep keyword tensors on the same device as the generated ids.
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            if torch.equal(output_ids[0, -keyword_id.shape[0]:], keyword_id):
                return True

        # Fallback: decode at most max_keyword_len trailing *generated* tokens
        # and search for the keyword text.
        num_new = output_ids.shape[1] - self.start_len
        if num_new <= 0:
            # output_ids is not longer than the prompt, so it evidently does not
            # contain the prompt (presumably only new tokens are returned when the
            # model generates from embeddings). Treat every token as generated
            # instead of slicing with a negative/zero offset.
            num_new = output_ids.shape[1]
        offset = min(num_new, self.max_keyword_len)
        if offset <= 0:
            return False
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        return any(keyword in outputs for keyword in self.keywords)

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        """Return True only when every sequence in the batch has hit a keyword."""
        return all(
            self.call_for_batch(output_ids[i].unsqueeze(0), scores)
            for i in range(output_ids.shape[0])
        )


def _find_image(message, history):
    """Return ``(path, turn_index)`` of the image to condition on, or ``(None, None)``.

    A file attached to the current message wins (``turn_index`` is ``None``);
    otherwise the most recent file-only turn in ``history`` is used.
    """
    if message["files"]:
        return message["files"][-1]["path"], None
    image, image_turn = None, None
    for i, hist in enumerate(history):
        if isinstance(hist[0], tuple):
            image, image_turn = hist[0][0], i
    return image, image_turn


def _build_messages(message, history, image_turn):
    """Convert tuple-style chat history plus the new message into chat-template messages.

    File-only turns are skipped. Exactly one IMAGE_TOKEN is inserted:
    - on the first text turn after ``image_turn`` when the image came from history;
    - otherwise on the current user message.
    """
    messages = []
    tag_pending = image_turn is not None
    for i, (human, assistant) in enumerate(history):
        if isinstance(human, tuple):
            continue  # file-upload turn; the image itself is passed separately
        if tag_pending and i > image_turn:
            human = f"{IMAGE_TOKEN}\n{human}"
            tag_pending = False
        messages.append({"role": "user", "content": human})
        if assistant is not None:
            messages.append({"role": "assistant", "content": assistant})

    current = message["text"]
    if image_turn is None or tag_pending:
        current = f"{IMAGE_TOKEN}\n{current}"
    messages.append({"role": "user", "content": current})
    return messages


def _tokenize_prompt(prompt):
    """Tokenize ``prompt`` and splice IMAGE_TOKEN_INDEX where IMAGE_TOKEN appeared.

    Raises:
        gr.Error: if the prompt does not contain exactly one IMAGE_TOKEN.
    """
    chunks = prompt.split(IMAGE_TOKEN)
    if len(chunks) != 2:
        raise gr.Error(
            f"The prompt must contain exactly one {IMAGE_TOKEN!r} tag; "
            f"please avoid typing {IMAGE_TOKEN!r} in your messages."
        )
    before, after = (tokenizer(chunk).input_ids for chunk in chunks)
    return torch.tensor(before + [IMAGE_TOKEN_INDEX] + after, dtype=torch.long).unsqueeze(0)


@spaces.GPU
def bot_streaming(message, history):
    """Stream the model's reply to ``message`` in the context of ``history``.

    Args:
        message: multimodal chat message dict with ``"text"`` and ``"files"``.
        history: prior turns as ``(user, assistant)`` pairs; a file upload
            appears as a turn whose user entry is a tuple of paths.

    Yields:
        The reply text accumulated so far, after each streamed chunk.

    Raises:
        gr.Error: if no image is attached to the message or present in history.
    """
    image_path, image_turn = _find_image(message, history)
    if image_path is None:
        raise gr.Error("You need to upload an image for LLaVA to work.")

    messages = _build_messages(message, history, image_turn)
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    input_ids = _tokenize_prompt(prompt)

    with Image.open(image_path) as img:
        image = img.convert("RGB")
    image_tensor = model.process_images([image], model.config).to(dtype=model.dtype)

    stopping_criteria = KeywordsStoppingCriteria([STOP_STR], tokenizer, input_ids)
    # skip_prompt=True: the streamer yields only generated text, so the buffer
    # is the answer itself and must not be sliced.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=input_ids,
        images=image_tensor,
        streamer=streamer,
        max_new_tokens=100,
        stopping_criteria=[stopping_criteria],
    )
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        time.sleep(0.04)  # throttle UI updates
        yield buffer
    thread.join()


demo = gr.ChatInterface(
    fn=bot_streaming,
    title="nanoLLaVA",
    examples=[
        {"text": "What is on the flower?", "files": ["./bee.jpg"]},
        {"text": "How to make this pastry?", "files": ["./baklava.png"]},
    ],
    description=(
        "Try [nanoLLaVA](https://huggingface.co/qnguyen3/nanoLLaVA) in this demo. "
        "Upload an image and start chatting about it, or simply try one of the examples below. "
        "If you don't upload an image, you will receive an error."
    ),
    stop_btn="Stop Generation",
    multimodal=True,
)
demo.launch(debug=True)