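"""Interactive command-line chat demo for DeepSeek-VL.

Streams model replies as they are generated. Type `exit` to quit, `help`
for usage, and `new` to reset the conversation; each image placeholder
token in the prompt triggers a request for an image file path.
"""
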
import argparse
import os
import sys
from threading import Thread

import torch
from PIL import Image
from transformers import TextIteratorStreamer

from deepseek_vl.utils.io import load_pretrained_model


def load_image(image_file):
    image = Image.open(image_file).convert("RGB")
    return image


def get_help_message(image_token):
    help_msg = (
        "\t\t DeepSeek-VL-Chat is a chatbot that can answer questions based on the given image. Enjoy it! \n"
        "Usage: \n"
        "    1. type `exit` to quit. \n"
        f"    2. type `{image_token}` to indicate there is an image. You can enter multiple images, "
        f"e.g. '{image_token} is a dot, {image_token} is a cat, and what is it in {image_token}?'. "
        f"When you type `{image_token}`, the chatbot will ask you to input the image file path. \n"
        "    3. type `help` to get this help message. \n"
        "    4. type `new` to start a new conversation. \n"
        f"    For example, you can type: '{image_token}Describe the image.'\n"
    )

    return help_msg


@torch.inference_mode()
def response(
    args, conv, pil_images, tokenizer, vl_chat_processor, vl_gpt, generation_config
):
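    # Convert the conversation prompt and the PIL images into batched
    # model inputs on the model's device.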
    prompt = conv.get_prompt()
    prepare_inputs = vl_chat_processor(
        prompt=prompt, images=pil_images, force_batchify=True
    ).to(vl_gpt.device)

    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)

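    # Stream decoded text as it is generated, skipping the prompt and
    # any special tokens.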
    streamer = TextIteratorStreamer(
        tokenizer=tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    generation_config["inputs_embeds"] = inputs_embeds
    generation_config["attention_mask"] = prepare_inputs.attention_mask
    generation_config["streamer"] = streamer

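    # Run generation in a background thread so the streamer can yield
    # text to the caller as soon as it is produced.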
    thread = Thread(target=vl_gpt.language_model.generate, kwargs=generation_config)
    thread.start()

    yield from streamer


def get_user_input(hint: str):
    user_input = ""
    while user_input == "":
        try:
            user_input = input(hint)
        except KeyboardInterrupt:
            # Ctrl-C: print a newline and prompt again.
            print()
            continue
        except EOFError:
            # Ctrl-D / end of input: treat it as a request to exit.
            user_input = "exit"

    return user_input


def chat(args, tokenizer, vl_chat_processor, vl_gpt, generation_config):
    image_token = vl_chat_processor.image_token
    help_msg = get_help_message(image_token)

    while True:
        print(help_msg)

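        # Start each conversation with no images and a fresh chat template.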
        pil_images = []
        conv = vl_chat_processor.new_chat_template()
        roles = conv.roles

        while True:
            user_input = get_user_input(
                f"{roles[0]} [{image_token} indicates an image]: "
            )

            if user_input == "exit":
                print("Chat program exited.")
                sys.exit(0)

            elif user_input == "help":
                print(help_msg)

            elif user_input == "new":
                os.system("clear")
                pil_images = []
                conv = vl_chat_processor.new_chat_template()
                torch.cuda.empty_cache()
                print("New conversation started.")

            else:
                conv.append_message(conv.roles[0], user_input)
                conv.append_message(conv.roles[1], None)

                num_images = user_input.count(image_token)
                cur_img_idx = 0

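                # Ask for one image file path per image token in the input.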
                while cur_img_idx < num_images:
                    try:
                        image_file = input(
                            f"({cur_img_idx + 1}/{num_images}) Input the image file path: "
                        )
                        image_file = image_file.strip()
                    except KeyboardInterrupt:
                        print()
                        continue
                    except EOFError:
                        image_file = None

                    if image_file and os.path.exists(image_file):
                        pil_image = load_image(image_file)
                        pil_images.append(pil_image)
                        cur_img_idx += 1
                    elif image_file == "exit":
                        print("Chat program exited.")
                        sys.exit(0)
                    else:
                        print(
                            f"File error, `{image_file}` does not exist. Please input the correct file path."
                        )

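                # Stream the model's reply, echo it to stdout as it
                # arrives, and record the full answer in the conversation.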
                answer = ""
                answer_iter = response(
                    args,
                    conv,
                    pil_images,
                    tokenizer,
                    vl_chat_processor,
                    vl_gpt,
                    generation_config,
                )
                sys.stdout.write(f"{conv.roles[1]}: ")
                for char in answer_iter:
                    answer += char
                    sys.stdout.write(char)
                    sys.stdout.flush()

                sys.stdout.write("\n")
                sys.stdout.flush()
                conv.update_last_message(answer)


def main(args):
    tokenizer, vl_chat_processor, vl_gpt = load_pretrained_model(args.model_path)
    generation_config = dict(
        pad_token_id=vl_chat_processor.tokenizer.eos_token_id,
        bos_token_id=vl_chat_processor.tokenizer.bos_token_id,
        eos_token_id=vl_chat_processor.tokenizer.eos_token_id,
        max_new_tokens=args.max_gen_len,
        use_cache=True,
    )
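    # A positive temperature enables sampling with top-p and a repetition
    # penalty; temperature == 0 selects greedy decoding.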
    if args.temperature > 0:
        generation_config.update(
            {
                "do_sample": True,
                "top_p": args.top_p,
                "temperature": args.temperature,
                "repetition_penalty": args.repetition_penalty,
            }
        )
    else:
        generation_config.update({"do_sample": False})

    chat(args, tokenizer, vl_chat_processor, vl_gpt, generation_config)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_path",
        type=str,
        default="deepseek-ai/deepseek-vl-7b-chat",
        help="the Hugging Face model name or the local path of the downloaded model.",
    )
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--top_p", type=float, default=0.95)
    parser.add_argument("--repetition_penalty", type=float, default=1.1)
    parser.add_argument("--max_gen_len", type=int, default=512)
    args = parser.parse_args()
    main(args)