Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,088 Bytes
c5b3aef c98b207 5cd56f1 c98b207 be961e6 c98b207 09399fd c98b207 09399fd c98b207 cbb017b c98b207 09399fd c98b207 5adecab 2692054 c98b207 2692054 c98b207 3193581 09399fd c98b207 3193581 de27ed6 c98b207 2692054 6904764 09399fd 3193581 6904764 999df98 c98b207 3090ac5 c98b207 09399fd c98b207 6904764 09399fd c98b207 09399fd be961e6 09399fd c98b207 cf7a112 e7455bb 60e7596 5531c0c 60e7596 e7455bb c98b207 cff68a5 c98b207 0d0766f c98b207 5531c0c c98b207 bb45d22 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
import subprocess
import os
import sys

# Best-effort install of flash-attn at startup (the return code is not checked).
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE presumably tells flash-attn's setup to skip
# compiling its CUDA extension -- confirm against the flash-attn setup script.
# The flag is *added* to the current environment rather than replacing it:
# passing only the flag would strip PATH, HOME, CUDA and proxy settings from pip.
# `sys.executable -m pip` installs into the interpreter that runs this app.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)
from threading import Thread
import torch
from PIL import Image
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer
import os
import time
from huggingface_hub import hf_hub_download
# NOTE(review): huggingface_hub appears to read HF_HUB_ENABLE_HF_TRANSFER when it
# is first imported, and it is already imported above (directly and through
# transformers), so this assignment most likely has no effect in this process.
# Move it above the imports if accelerated downloads are wanted (that also
# requires the hf_transfer package to be installed).
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Optional Hub access token; not referenced elsewhere in the visible code.
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Hub repository id of the model to serve, e.g. "org/model-name".
MODEL_ID = os.environ.get("MODEL_ID")
if not MODEL_ID:
    # Fail fast with an actionable message instead of an AttributeError on None.
    raise RuntimeError(
        "The MODEL_ID environment variable must be set to a Hugging Face Hub "
        "model id (e.g. 'org/model-name')."
    )
# Short display name: the part after the last "/" of the repo id.
MODEL_NAME = MODEL_ID.split("/")[-1]

# HTML snippets shown at the top of the page.
TITLE = "<h1><center>VL-Chatbox</center></h1>"
DESCRIPTION = "<h3><center>MODEL: " + MODEL_NAME + "</center></h3>"
# Styles for the "Duplicate Space" button (matched via elem_classes).
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
"""
# Load the model and its processor once, at import time; stream_chat reuses them.
# NOTE(review): trust_remote_code=True executes Python code shipped in the
# MODEL_ID repository -- only point MODEL_ID at repositories you trust.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
).to(0)  # integer device 0 == the first CUDA device
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

# End-of-sequence id, passed to model.generate as the stop token.
eos_token_id = processor.tokenizer.eos_token_id
def _build_conversation(message, history):
    """Convert a multimodal chat message and its history into chat-template turns.

    Assumes Gradio's tuple-format history, where an uploaded file is stored as
    its own ``((path, ...), answer_or_None)`` pair and a text turn as
    ``(text, answer)`` -- NOTE(review): confirm against the installed Gradio
    version. Only one image is sent to the model: the last file attached to
    the current message, otherwise the most recent file found in ``history``.
    The ``<|image_1|>`` placeholder is put on the user turn that introduced
    that image.

    Returns:
        ``(conversation, image_path)``; ``image_path`` is ``None`` when no
        image has been uploaded in this chat.
    """
    # Pick the image to send and the history index it came from;
    # len(history) stands for "the current message".
    if message["files"]:
        image_path, image_turn = message["files"][-1], len(history)
    else:
        image_path, image_turn = None, None
        for idx, (prompt, _) in enumerate(history):
            if isinstance(prompt, (tuple, list)):
                image_path, image_turn = prompt[0], idx

    conversation = []
    tag_next_turn = False
    for idx, (prompt, answer) in enumerate(history):
        if isinstance(prompt, (tuple, list)):
            # File entry: its content is a path tuple, never chat text.
            if idx == image_turn:
                tag_next_turn = True
            if answer is None:
                # No reply stored here; the accompanying text (if any) is a
                # later entry.
                continue
            # Image sent without text: the reply is stored on the file entry.
            prompt = ""
        if tag_next_turn:
            prompt = f"<|image_1|>\n{prompt}"
            tag_next_turn = False
        conversation.append({"role": "user", "content": prompt})
        conversation.append({"role": "assistant", "content": answer})

    text = message["text"]
    if tag_next_turn or image_turn == len(history):
        text = f"<|image_1|>\n{text}"
    conversation.append({"role": "user", "content": text})
    return conversation, image_path


@spaces.GPU(queue=False)
def stream_chat(message, history: list, temperature: float, max_new_tokens: int):
    """Stream the model's reply to a multimodal chat message.

    Args:
        message: Multimodal textbox payload with ``"text"`` and ``"files"`` keys.
        history: Previous turns as passed by ``gr.ChatInterface``.
        temperature: Sampling temperature; values <= 0 use greedy decoding.
        max_new_tokens: Upper bound on the number of generated tokens.

    Yields:
        The accumulated reply text after each streamed chunk.

    Raises:
        gr.Error: if the first message of a chat has no image attached.
    """
    print(message)
    conversation, image_path = _build_conversation(message, history)
    if image_path is None and not history:
        raise gr.Error("Please upload an image first.")
    # Follow-up turns resend the most recent upload, so the model still sees
    # the image; the request is text-only only if no image is found at all.
    image = Image.open(image_path) if image_path is not None else None
    print(conversation)

    prompt = processor.tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    model_inputs = processor(prompt, image, return_tensors="pt").to(0)

    # skip_prompt: stream only newly generated text, not the echoed prompt.
    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    generate_kwargs = {
        **model_inputs,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "eos_token_id": eos_token_id,
    }
    if temperature > 0:
        generate_kwargs.update(do_sample=True, temperature=temperature)
    else:
        # Greedy decoding; temperature is omitted since it only affects sampling.
        generate_kwargs["do_sample"] = False

    # generate() blocks until done, so run it in a worker thread and drain the
    # streamer here, yielding the growing reply to Gradio.
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()
    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
    thread.join()
# Chat transcript and the text+image input box; both are passed to
# gr.ChatInterface further down in this file.
chatbot = gr.Chatbot(height=450)
chat_input = gr.MultimodalTextbox(
    placeholder="Enter message or upload file...",
    file_types=["image"],
    show_label=False,
    interactive=True,
)
# Clickable example prompts in the multimodal textbox format ({"text", "files"}).
# The image paths are relative, so they resolve against the process working
# directory (presumably the Space root, where the images live -- verify).
EXAMPLES = [
    {"text": "What is on the desk?", "files": ["./laptop.jpg"]},
    {"text": "Where is it?", "files": ["./hotel.jpg"]},
    {"text": "Can you describe this image?", "files": ["./spacecat.png"]},
]
with gr.Blocks(css=CSS) as demo:
    # Page header: title, served model name, and a "duplicate this Space" button.
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.DuplicateButton(
        value="Duplicate Space for private use",
        elem_classes="duplicate-button",
    )

    # The chat UI. The two sliders are created with render=False so that
    # ChatInterface renders them; their values are passed to stream_chat as
    # (temperature, max_new_tokens), in list order.
    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        textbox=chat_input,
        multimodal=True,
        examples=EXAMPLES,
        fill_height=True,
        additional_inputs=[
            gr.Slider(
                label="Temperature",
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.8,
                render=False,
            ),
            gr.Slider(
                label="Max new tokens",
                minimum=128,
                maximum=4096,
                step=1,
                value=1024,
                render=False,
            ),
        ],
    )
if __name__ == "__main__":
    # Serve through Gradio's request queue. api_open=False rejects direct
    # REST calls that bypass the queue, show_api=False hides the API docs
    # link, and share=False creates no public share link.
    demo.queue(api_open=False).launch(show_api=False, share=False)