import os
import subprocess

# Install flash-attn without compiling its CUDA extension; ZeroGPU Spaces have
# no GPU available at startup, so the build step is skipped.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},  # merge os.environ so PATH survives
    shell=True
)
from threading import Thread

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer
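
# Use the faster hf_transfer download backend for Hub files.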
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
HF_TOKEN = os.environ.get("HF_TOKEN", None)
MODEL_ID = os.environ.get("MODEL_ID")  # model repo id; must be set in the Space's variables
MODEL_NAME = MODEL_ID.split("/")[-1]
TITLE = "<h1><center>VL-Chatbox</center></h1>"
DESCRIPTION = "<h3><center>MODEL: " + MODEL_NAME + "</center></h3>"
DEFAULT_SYSTEM = "You are named Chatbox. You are a helpful assistant."
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
"""
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    trust_remote_code=True,
).to(0)
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
eos_token_id = processor.tokenizer.eos_token_id
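

# Each call runs on a ZeroGPU device attached by the @spaces.GPU decorator.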
@spaces.GPU(queue=False)
def stream_chat(message: dict, history: list, system: str, temperature: float, max_new_tokens: int):
    # The system prompt is sent as the leading user turn of the conversation.
    conversation = [{"role": "user", "content": system or DEFAULT_SYSTEM}]
    for prompt, answer in history:
        conversation.extend([
            {"role": "user", "content": f"<|image_1|>\n{prompt}"},
            {"role": "assistant", "content": answer},
        ])
    if message["files"]:
        # Use the most recent upload and reference it via the <|image_1|> placeholder.
        image = Image.open(message["files"][-1]).convert('RGB')
        conversation.append({"role": "user", "content": f"<|image_1|>\n{message['text']}"})
    else:
        image = None
        conversation.append({"role": "user", "content": message["text"]})
    # Apply the chat template, then tensorize the prompt and image together.
    prompt = processor.tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    inputs = processor(prompt, images=image, return_tensors="pt").to(0)
    generate_kwargs = dict(
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        do_sample=temperature > 0,  # greedy decoding when temperature is 0
        eos_token_id=eos_token_id,
    )
    generate_kwargs = {**inputs, **generate_kwargs}
    streamer = TextIteratorStreamer(
        processor,
        skip_special_tokens=True,
        skip_prompt=True,
        clean_up_tokenization_spaces=False,
    )
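    # Run generation on a background thread; the streamer yields decoded text as it arrives.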
    thread = Thread(target=model.generate, kwargs=generate_kwargs)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
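

# UI components handed to the ChatInterface below.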
chatbot = gr.Chatbot(height=450)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False,
)
# ChatInterface expects plain example values, not a gr.Examples component.
examples = [
    {"text": "What is on the desk?", "files": ["./laptop.jpg"]},
    {"text": "Where is it?", "files": ["./hotel.jpg"]},
    {"text": "Can you describe this image?", "files": ["./spacecat.png"]},
]
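
# Assemble the page: title, description, duplicate button, and the chat interface.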
with gr.Blocks(css=CSS) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")
    gr.ChatInterface(
        fn=stream_chat,
        multimodal=True,
        examples=examples,
        textbox=chat_input,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Text(
                value="",
                label="System",
                render=False,
            ),
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.8,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=128,
                maximum=4096,
                step=1,
                value=1024,
                label="Max new tokens",
                render=False,
            ),
        ],
    )
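
# Keep the HTTP API closed; the Space is meant to be used through the web UI.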
if __name__ == "__main__":
    demo.queue(api_open=False).launch(show_api=False, share=False)