mkshing's picture
change height
a778f68
raw
history blame
8.27 kB
import gradio as gr
from gradio.utils import async_lambda
import spaces
import time
import subprocess
import torch
from models.mllava import (
MLlavaProcessor,
LlavaForConditionalGeneration,
prepare_inputs,
)
from models.conversation import Conversation, SeparatorStyle
from transformers import TextIteratorStreamer
from transformers.utils import is_flash_attn_2_available
from threading import Thread
device = "cuda" if torch.cuda.is_available() else "cpu"
IMAGE_TOKEN = "<image>"
generation_kwargs = {
"max_new_tokens": 128,
"num_beams": 1,
"do_sample": False,
"no_repeat_ngram_size": 3,
}
if device == "cpu":
processor = None
model = None
else:
if not is_flash_attn_2_available():
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
processor = MLlavaProcessor.from_pretrained("TIGER-Lab/Mantis-8B-siglip-llama3")
processor.tokenizer.pad_token = processor.tokenizer.eos_token
model = LlavaForConditionalGeneration.from_pretrained(
"SakanaAI/Llama-3-EvoVLM-JP-v2",
torch_dtype=torch.float16,
attn_implementation="flash_attention_2",
device_map=device,
).eval()
# Set the system prompt
conv_template = Conversation(
system="<|start_header_id|>system<|end_header_id|>\n\nあなたは誠実で優秀な日本人のアシスタントです。特に指示が無い場合は、常に日本語で回答してください。",
roles=("user", "assistant"),
messages=(),
offset=0,
sep_style=SeparatorStyle.LLAMA_3,
sep="<|eot_id|>",
)
def get_chat_messages(history):
chat_history = []
user_role = conv_template.roles[0]
assistant_role = conv_template.roles[1]
for i, message in enumerate(history):
if isinstance(message[0], str):
chat_history.append({"role": user_role, "text": message[0]})
if i != len(history) - 1:
assert message[1], "The bot message is not provided, internal error"
chat_history.append({"role": assistant_role, "text": message[1]})
else:
assert not message[1], "the bot message internal error, get: {}".format(
message[1]
)
chat_history.append({"role": assistant_role, "text": ""})
return chat_history
def get_chat_images(history):
images = []
for message in history:
if isinstance(message[0], tuple):
images.extend(message[0])
return images
@spaces.GPU
def bot(message, history):
if not model:
print(message, history)
images = message["files"] if message["files"] else None
text = message["text"].strip()
if not text:
raise gr.Error("You must enter a message!")
num_image_tokens = text.count(IMAGE_TOKEN)
# modify text
if images and num_image_tokens < len(images):
if num_image_tokens != 0:
gr.Warning(
"The number of images uploaded is more than the number of <image> placeholders in the text. Will automatically prepend <image> to the text."
)
# prefix image tokens
text = IMAGE_TOKEN * (len(images) - num_image_tokens) + text
if images and num_image_tokens > len(images):
raise gr.Error(
"The number of images uploaded is less than the number of <image> placeholders in the text!"
)
current_messages = []
if images:
current_messages += [[(image,), None] for image in images]
if text:
current_messages += [[text, None]]
current_history = history + current_messages
chat_messages = get_chat_messages(current_history)
chat_images = get_chat_images(current_history)
# Generate!
inputs = prepare_inputs(None, chat_images, model, processor, history=chat_messages, **generation_kwargs)
streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
inputs["streamer"] = streamer
thread = Thread(target=model.generate, kwargs=inputs)
thread.start()
buffer = ""
for new_text in streamer:
buffer += new_text
time.sleep(0.01)
yield buffer
DESCRIPTION = """# 🐟 Llama-3-EvoVLM-JP-v2
🤗 [モデル一覧](https://huggingface.co/SakanaAI) | 📚 [技術レポート](https://arxiv.org/abs/2403.13187) | 📝 [ブログ](https://sakana.ai/evovlm-jp/) | 🐦 [Twitter](https://twitter.com/SakanaAILabs)
[Llama-3-EvoVLM-JP-v2](https://huggingface.co/SakanaAI/Llama-3-EvoVLM-JP-v2)は[Sakana AI](https://sakana.ai/)が進化的モデルマージを用いて開発した日本語視覚言語モデルです。入力した画像に対して、質疑応答することが可能です。より詳しくは、上記の技術レポートとブログをご参照ください。
"""
examples = [
{
"text": "1番目と2番目の画像に写っている動物の違いは何ですか?簡潔に説明してください。",
"files": ["./examples/image_0.jpg", "./examples/image_1.jpg"],
},
{
"text": "2枚の写真について、簡単にそれぞれ説明してください。",
"files": ["./examples/image_2.jpg", "./examples/image_3.jpg"],
},
]
chat = gr.ChatInterface(
fn=bot,
multimodal=True,
chatbot=gr.Chatbot(label="Chatbot", scale=1, height=500),
textbox=gr.MultimodalTextbox(
interactive=True,
file_types=["image"],
# file_count="multiple",
placeholder="Enter message or upload images. Please use <image> to indicate the position of uploaded images",
show_label=True,
),
examples=examples,
fill_height=False,
stop_btn=None,
)
with gr.Blocks(fill_height=True) as demo:
gr.Markdown(DESCRIPTION)
chat.render()
chat.examples_handler.load_input_event.then(
fn=async_lambda(lambda: [[], [], None]),
outputs=[chat.chatbot, chat.chatbot_state, chat.saved_input],
)
gr.Markdown(
"""
### チャットの方法
Llama-3-EvoVLM-JP-v2は、画像をテキストの好きな場所に入力として配置することができます。画像をアップロードする場所は、`<image>`というフレーズで指定できます。
モデルの推論時に、自動的に`<image>`が画像トークンに置き換えられます。また、画像のアップロード数が`<image>`の数よりも少ない場合、余分な`<image>`が削除されます。
逆に、画像のアップロード数が`<image>`の数よりも多い場合、自動的に`<image>`が追加されます。
### 注意事項
本モデルは実験段階のプロトタイプであり、研究開発の目的でのみ提供されています。商用利用や、障害が重大な影響を及ぼす可能性のある環境(ミッションクリティカルな環境)での使用には適していません。
本モデルの使用は、利用者の自己責任で行われ、その性能や結果については何ら保証されません。
Sakana AIは、本モデルの使用によって生じた直接的または間接的な損失に対して、結果に関わらず、一切の責任を負いません。
利用者は、本モデルの使用に伴うリスクを十分に理解し、自身の判断で使用することが必要です。
また、このデモでは、できる限り多くの皆様にお使いいただけるように、出力テキストのサイズを制限しております。"""
)
gr.Markdown(
"""
### Citation
```bibtex
@misc{Llama-3-EvoVLM-JP-v2,
url = {[https://huggingface.co/SakanaAI/Llama-3-EvoVLM-JP-v2](https://huggingface.co/SakanaAI/Llama-3-EvoVLM-JP-v2)},
title = {Llama-3-EvoVLM-JP-v2},
author = {Yuichi, Inoue and Takuya, Akiba and Shing, Makoto}
}
```
```bibtex
@misc{akiba2024evomodelmerge,
title = {Evolutionary Optimization of Model Merging Recipes},
author. = {Takuya Akiba and Makoto Shing and Yujin Tang and Qi Sun and David Ha},
year = {2024},
eprint = {2403.13187},
archivePrefix = {arXiv},
primaryClass = {cs.NE}
}
```"""
)
demo.queue().launch()