# A100 Zero GPU
import spaces
# TroL Package
import torch
from PIL import Image
from utils.utils import *
import torch.nn.functional as F
from trol.load_trol import load_trol
from torchvision.transforms.functional import pil_to_tensor
# Gradio Package
import time
import gradio as gr
from threading import Thread
from accelerate import Accelerator
from transformers import TextIteratorStreamer
# flash attention
import subprocess
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
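# (FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE skips compiling the CUDA kernels during install,
#  avoiding a long source build when the Space starts up)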
# accel
accel = Accelerator()
# model selection
link = "TroL-7B" # [Select One] 'TroL-1.8B' | 'TroL-3.8B' | 'TroL-7B'
# User prompt
prompt_type = "with_image" # [Select One] 'text_only' | 'with_image'
img_path = 'figures/demo.png'
question = "What is the troll doing? Describe the details in the image and imagine what event is happening."
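# (these prompt settings are only used by the standalone generation example at the bottom of this file;
#  the Gradio demo below builds its inputs from the chat message instead)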
# loading model
model, tokenizer = load_trol(link=link)
# cpu -> gpu
for param in model.parameters():
    if not param.is_cuda:
        param.data = param.to('cuda:0')
def threading_function(inputs, image_token_number, streamer, device, temperature, new_max_token, top_p):
    # propagation
    _inputs = model.eval_process(inputs=inputs,
                                 data='demo',
                                 tokenizer=tokenizer,
                                 device=device,
                                 img_token_number=image_token_number)
    generation_kwargs = _inputs
    generation_kwargs.update({'streamer': streamer})
    generation_kwargs.update({'do_sample': True})
    generation_kwargs.update({'max_new_tokens': new_max_token})
    generation_kwargs.update({'top_p': top_p})
    generation_kwargs.update({'temperature': temperature})
    generation_kwargs.update({'use_cache': True})
    return model.generate(**generation_kwargs)
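# The worker thread above pushes generated tokens into the TextIteratorStreamer;
# bot_streaming below drains the streamer on the main thread so partial text can be yielded to Gradio as it arrives.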
@spaces.GPU
def bot_streaming(message, history, temperature, new_max_token, top_p):
    # `link` (the model selected at load time) is read from the module-level variable above
    try:
        # prompt type -> input prompt
        image_token_number = None
        if len(message['files']) != 0:
            # Image Load
            image = pil_to_tensor(Image.open(message['files'][0]).convert("RGB"))
            if "3.8B" not in link:
                # 1225 image tokens (a 35x35 grid) for the 490x490 resized input
                image_token_number = 1225
                image = F.interpolate(image.unsqueeze(0), size=(490, 490), mode='bicubic').squeeze(0)
            inputs = [{'image': image, 'question': message['text']}]
        else:
            inputs = [{'question': message['text']}]
        # Text Generation
        with torch.inference_mode():
            # kwargs
            streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
            # Threading generation
            thread = Thread(target=threading_function, kwargs=dict(inputs=inputs,
                                                                   image_token_number=image_token_number,
                                                                   streamer=streamer,
                                                                   device=accel.device,
                                                                   temperature=temperature,
                                                                   new_max_token=new_max_token,
                                                                   top_p=top_p))
            thread.start()
            # generated text: drain the streamer as the worker thread produces tokens
            generated_text = ""
            for new_text in streamer:
                generated_text += new_text
            # Text decoding
            response = output_filtering(generated_text, model)
    except Exception:
        response = "There may be an unsupported format (e.g. pdf, video, sound). Only a single image is supported in this version."
    # private log print
    text = message['text']
    files = message['files']
    print(f'Text: {text}')
    print(f'MM Files: {files}')
    # stream the response character by character for a typing effect
    buffer = ""
    for character in response:
        buffer += character
        time.sleep(0.015)
        yield buffer
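# Gradio chat UI: the three sliders below are passed to bot_streaming as its trailing
# (temperature, new_max_token, top_p) arguments, after message and history.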
demo = gr.ChatInterface(fn=bot_streaming,
                        additional_inputs=[gr.Slider(0, 1, 0.9, label="temperature"), gr.Slider(1, 1024, 128, label="new_max_token"), gr.Slider(0, 1, 0.95, label="top_p")],
                        additional_inputs_accordion="Generation Hyperparameters",
                        theme=gr.themes.Soft(),
                        title="☄️Meteor",
                        description="Meteor is an efficient 7B-size Large Language and Vision Model built with the help of traversal of rationale.\n"
                                    "Its inference speed highly depends on being assigned a non-scheduled GPU. (Therefore, once all GPUs are busy, inference may take indefinitely long.)",
                        stop_btn="Stop Generation", multimodal=True)
demo.launch()
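# Standalone (non-Gradio) generation example. Note that demo.launch() above blocks,
# so this path only runs after the Gradio server shuts down; the input construction below
# mirrors the Gradio path and reuses prompt_type / img_path / question from the top of the file.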
# Generate
with torch.inference_mode():
    # build demo inputs from the prompt settings defined above
    image_token_number = None
    if prompt_type == 'with_image':
        image = pil_to_tensor(Image.open(img_path).convert("RGB"))
        if "3.8B" not in link:
            image_token_number = 1225
            image = F.interpolate(image.unsqueeze(0), size=(490, 490), mode='bicubic').squeeze(0)
        inputs = [{'image': image, 'question': question}]
    else:
        inputs = [{'question': question}]
    _inputs = model.eval_process(inputs=inputs,
                                 data='demo',
                                 tokenizer=tokenizer,
                                 device='cuda:0',
                                 img_token_number=image_token_number)
    generate_ids = model.generate(**_inputs, max_new_tokens=256, use_cache=True)
    response = output_filtering(tokenizer.batch_decode(generate_ids, skip_special_tokens=False)[0], model)
    print(response)