import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, ChameleonProcessor, TextIteratorStreamer
from threading import Thread
from PIL import Image
import requests
model_path = "facebook/chameleon-7b"
# load in bfloat16; low_cpu_mem_usage streams weights onto the devices chosen by
# device_map="auto" instead of materializing a full copy in RAM first
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map="auto"
)
model.eval()
processor = ChameleonProcessor.from_pretrained(model_path)
# the streamer below needs the raw tokenizer for decoding
tokenizer = processor.tokenizer
# a file attachment: (file_path, alt_text)
multimodal_file = tuple[str, str]
# a multimodal message: {'text': 'message here', 'files': []} flattened into
# plain strings and file tuples
multimodal_message = list[str | multimodal_file] | multimodal_file
# TODO: verify these types against gr.ChatInterface
message_t = dict[str, str | list[multimodal_file]]
history_t = list[tuple[str, str] | list[tuple[multimodal_message, multimodal_message]]]
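# rough shape of what gr.ChatInterface(multimodal=True) passes in (unverified,
# per the TODO above; file paths here are illustrative):
# history = [
#     [("/tmp/gradio/cat.png", "alt text"), None],  # file-only user turn, no reply yet
#     ["what's in this picture?", "A cat."],        # text turn with reply
# ]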
def history_to_prompt(
    message: message_t,
    history: history_t,
    eot_id: str = "<reserved08706>",  # Chameleon's end-of-turn token
    image_placeholder: str = "<image>",
):
    def open_image(path: str) -> Image.Image:
        # history entries may reference either a URL or a local temp file
        if path.startswith(("http://", "https://")):
            return Image.open(requests.get(path, stream=True).raw)
        return Image.open(path)

    prompt = ""
    images: list[Image.Image] = []
    # replay past turns first so the current message ends the prompt
    for turn in history:
        print("turn:", turn)
        # turn is a (user_message, assistant_message) pair
        for msg in turn:
            if msg is None:
                continue  # Gradio may pad file-only turns with a None response
            if isinstance(msg, str):
                prompt += msg
            elif isinstance(msg, tuple):
                # a bare (file_path, alt_text) attachment
                prompt += image_placeholder
                images.append(open_image(msg[0]))
            elif isinstance(msg, list):
                for item in msg:
                    if isinstance(item, str):
                        prompt += item
                    elif isinstance(item, tuple):
                        image_path, alt = item
                        prompt += image_placeholder
                        images.append(open_image(image_path))
            else:
                prompt += f"(unhandled message type: {msg})"
            prompt += eot_id
    # current message: its text plus one <image> placeholder per attached file,
    # so placeholders and images stay in one-to-one correspondence
    prompt += message["text"]
    for f in message["files"]:
        prompt += image_placeholder
        images.append(Image.open(f))
    return prompt, images
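# example (file name is illustrative): with an empty history and one attachment,
#   history_to_prompt({"text": "describe this", "files": ["cat.png"]}, [])
# returns roughly ("describe this<image>", [<PIL.Image for cat.png>])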
@spaces.GPU(duration=30)
def respond(
message,
history: history_t,
system_message,
max_tokens,
temperature,
top_p,
):
    print(f"message: {message}\nhistory:\n\n{history}\n")
    prompt, images = history_to_prompt(message, history)
    print(f"prompt:\n\n{prompt}\n")
    # fixed example input, handy for debugging:
    # prompt = "I'm very intrigued by this work of art:<image>Please tell me about the artist."
    # images = [Image.open(requests.get("https://uploads4.wikiart.org/images/paul-klee/death-for-the-idea-1915.jpg!Large.jpg", stream=True).raw)]
    # pass None rather than an empty list for text-only messages
    inputs = processor(prompt, images=images or None, return_tensors="pt").to(
        model.device, dtype=torch.bfloat16
    )
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    # note: system_message comes from the UI but is not yet wired into the prompt
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=int(max_tokens),
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )
    try:
        # launch generation in the background and stream tokens as they arrive
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
        partial_message = ""
        for new_token in streamer:
            partial_message += new_token
            yield partial_message
        thread.join()
    except Exception as e:
        yield f"Error: {e}"
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
respond,
multimodal=True,
additional_inputs=[
gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.95,
step=0.05,
label="Top-p (nucleus sampling)",
),
],
)
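# respond() is a generator, so ChatInterface streams each yielded string as a
# progressively longer partial reply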
if __name__ == "__main__":
    demo.launch(debug=True)