"""Gradio chat demo for a vision-language model hosted on the Hugging Face Hub.

At import time the script downloads the model files listed in ``filenames``
into ``./model/``, loads the model, tokenizer and vision tower onto the GPU,
and builds a multimodal ``gr.ChatInterface`` around ``stream_chat``.
"""

import os
import time
from threading import Thread

import gradio as gr
import spaces
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

# NOTE(review): this flag is set after huggingface_hub has been imported; the
# library may read it at import time, in which case it has no effect here.
# Confirm before relying on hf_transfer downloads.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL_ID = os.environ.get("MODEL_ID")
if not MODEL_ID:
    # Fail with an actionable message instead of AttributeError on None.split.
    raise RuntimeError(
        "The MODEL_ID environment variable must be set to a Hugging Face repo id."
    )
MODEL_NAME = MODEL_ID.split("/")[-1]

# NOTE(review): these two strings are passed to gr.HTML but contain no markup,
# only newlines around the text; presumably HTML tags were lost at some point.
# The visible content is preserved as-is.
TITLE = "\n\nVL-Chatbox\n\n"

DESCRIPTION = "\n\nMODEL LOADED: " + MODEL_NAME + "\n\n"

DEFAULT_SYSTEM = "You named Chatbox. You are a good assitant."

CSS = """
.duplicate-button {
  margin: auto !important;
  color: white !important;
  background: black !important;
  border-radius: 100vh !important;
}
"""

# Files fetched from the MODEL_ID repository into ./model/.
filenames = [
    "config.json",
    "generation_config.json",
    "model-00001-of-00004.safetensors",
    "model-00002-of-00004.safetensors",
    "model-00003-of-00004.safetensors",
    "model-00004-of-00004.safetensors",
    "model.safetensors.index.json",
    "special_tokens_map.json",
    "tokenizer.json",
    "tokenizer_config.json",
]

for filename in filenames:
    # Pass the token explicitly so gated/private repositories can be fetched.
    hf_hub_download(
        repo_id=MODEL_ID,
        filename=filename,
        local_dir="./model/",
        token=HF_TOKEN,
    )

# Log what ended up in the local model directory.
for item in os.listdir("./model"):
    print(item)

model = AutoModelForCausalLM.from_pretrained(
    "./model/",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
).to(0)
tokenizer = AutoTokenizer.from_pretrained("./model/", trust_remote_code=True)

vision_tower = model.get_vision_tower()
vision_tower.load_model()
vision_tower.to(device="cuda", dtype=torch.float16)
image_processor = vision_tower.image_processor
tokenizer.pad_token = tokenizer.eos_token

# Token ids that end generation. Ids that resolve to None (e.g. "<|eot_id|>"
# missing from the vocabulary) are dropped so generate() never receives None.
terminators = [
    token_id
    for token_id in (
        tokenizer.eos_token_id,
        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
    )
    if token_id is not None
]


@spaces.GPU
def stream_chat(message, history: list, system: str, temperature: float, max_new_tokens: int):
    """Generate a reply to ``message`` and yield it as growing prefixes.

    Args:
        message: Mapping with a ``"text"`` entry (the user prompt) and a
            ``"files"`` entry (uploaded files; only the first one is used).
        history: Previous ``(prompt, answer)`` pairs. Pairs whose prompt is
            not a string or whose answer is ``None`` are skipped.
        system: System prompt; ``DEFAULT_SYSTEM`` is used when it is empty.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        max_new_tokens: Upper bound on the number of generated tokens.

    Yields:
        The decoded reply, one additional character per step.
    """
    print(message)

    conversation = [{"role": "system", "content": system or DEFAULT_SYSTEM}]
    for prompt, answer in history:
        # presumably file-only turns in multimodal history (non-string prompt,
        # no answer) — verify against the installed Gradio version; they
        # cannot be rendered by the text chat template, so leave them out.
        if not isinstance(prompt, str) or answer is None:
            continue
        conversation.append({"role": "user", "content": prompt})
        conversation.append({"role": "assistant", "content": answer})
    conversation.append({"role": "user", "content": message["text"]})

    files = message.get("files")
    if files:
        first_file = files[0]
        # assumes an upload is either a plain path or a dict carrying the
        # path under "path" — TODO confirm for the installed Gradio version.
        image_path = first_file["path"] if isinstance(first_file, dict) else first_file
        # Close the underlying file once the RGB copy has been made.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert("RGB")

        # The image path only sends the current query to the model; the
        # system prompt and history in `conversation` are not used here.
        inputs = model.build_conversation_input_ids(
            tokenizer,
            query=message["text"],
            image=image,
            image_processor=image_processor,
        )
        input_ids = inputs["input_ids"].to(device="cuda", non_blocking=True)
        images = inputs["image"].to(dtype=torch.float16, device="cuda", non_blocking=True)
    else:
        input_ids = tokenizer.apply_chat_template(
            conversation,
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(model.device)
        images = None

    do_sample = temperature > 0
    generate_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        num_beams=1,
        eos_token_id=terminators,
        images=images,
    )
    if do_sample:
        # Temperature is only meaningful when sampling; a value of 0 is not
        # forwarded in greedy mode.
        generate_kwargs["temperature"] = temperature

    output_ids = model.generate(**generate_kwargs)
    input_token_len = input_ids.shape[1]
    # Decode only the newly generated tokens, not the prompt.
    outputs = tokenizer.batch_decode(
        output_ids[:, input_token_len:], skip_special_tokens=True
    )[0].strip()

    # Simulated streaming: the full reply already exists, reveal it one
    # character at a time.
    for end in range(1, len(outputs) + 1):
        time.sleep(0.05)
        yield outputs[:end]


chatbot = gr.Chatbot(height=450)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False,
)

with gr.Blocks(css=CSS) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
    gr.ChatInterface(
        fn=stream_chat,
        multimodal=True,
        chatbot=chatbot,
        textbox=chat_input,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Text(
                value="",
                label="System",
                render=False,
            ),
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.8,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=128,
                maximum=4096,
                step=1,
                value=1024,
                label="Max new tokens",
                render=False,
            ),
        ],
    )


if __name__ == "__main__":
    demo.queue(api_open=False).launch(show_api=False, share=False)