import os
import torch
import spaces
import gradio as gr
from PIL import Image
from transformers.utils import move_cache
from huggingface_hub import snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer

# https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B
MODEL_PATH = "THUDM/cogvlm2-llama3-chat-19B"
# https://huggingface.co/THUDM/cogvlm2-llama3-chat-19B-int4
# MODEL_PATH = "THUDM/cogvlm2-llama3-chat-19B-int4"

### DOWNLOAD ###
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'
MODEL_PATH = snapshot_download(MODEL_PATH)
move_cache()

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# bfloat16 requires compute capability >= 8.0 (Ampere or newer); otherwise fall back to float16.
TORCH_TYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8 else torch.float16

### TOKENIZER ###
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    trust_remote_code=True
)

### MODEL ###
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    torch_dtype=TORCH_TYPE,
    trust_remote_code=True,
).to(DEVICE).eval()

text_only_template = """USER: {} ASSISTANT:"""


@spaces.GPU
def generate_caption(image, prompt):
    print(DEVICE)
    if image is None:
        raise gr.Error("Please upload an image.")
    # Process the image and the prompt
    # image = Image.open(image_path).convert('RGB')
    image = image.convert('RGB')
    query = text_only_template.format(prompt)
    input_by_model = model.build_conversation_input_ids(
        tokenizer,
        query=query,
        history=[],
        images=[image],
        template_version='chat'
    )
    inputs = {
        'input_ids': input_by_model['input_ids'].unsqueeze(0).to(DEVICE),
        'token_type_ids': input_by_model['token_type_ids'].unsqueeze(0).to(DEVICE),
        'attention_mask': input_by_model['attention_mask'].unsqueeze(0).to(DEVICE),
        'images': [[input_by_model['images'][0].to(DEVICE).to(TORCH_TYPE)]],
    }
    gen_kwargs = {
        "max_new_tokens": 2048,
        "pad_token_id": 128002,  # Llama 3 reserved special token, used here as the pad token
    }
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
        # Strip the prompt tokens, keeping only the newly generated ones.
        outputs = outputs[:, inputs['input_ids'].shape[1]:]
        response = tokenizer.decode(outputs[0])
        response = response.split("<|end_of_text|>")[0]
    print("\nCogVLM2:", response)
    return response


### make predictions via API ###
# https://www.gradio.app/guides/getting-started-with-the-python-client#connecting-a-general-gradio-app
demo = gr.Interface(
    fn=generate_caption,
    inputs=[
        gr.Image(type="pil", label="Upload Image"),
        gr.Textbox(label="Prompt", value="Describe the image in great detail"),
    ],
    outputs=gr.Textbox(label="Generated Caption"),
)

# Launch the interface
demo.launch(share=True)
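
# --- Calling the app via the Gradio Python client ---
# A minimal sketch (run from a separate process, since demo.launch() blocks)
# of querying this Interface over the API linked above, assuming a recent
# gradio_client. The local URL, the example image path, and
# api_name="/predict" (Gradio's default for a single-function Interface)
# are assumptions, not part of the original app.
#
# from gradio_client import Client, handle_file
#
# client = Client("http://127.0.0.1:7860/")  # or the public share URL printed at launch
# result = client.predict(
#     handle_file("example.jpg"),            # image input (hypothetical file)
#     "Describe the image in great detail",  # prompt input
#     api_name="/predict",
# )
# print(result)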