import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, ChameleonProcessor, TextIteratorStreamer
from threading import Thread
from PIL import Image
import requests


model_path = "facebook/chameleon-7b"
model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, device_map="auto"
)
model.eval()
processor = ChameleonProcessor.from_pretrained(model_path)
tokenizer = processor.tokenizer
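
# Quick sanity check (illustrative only, not run at import time; assumes a GPU
# with enough free memory is available):
#   out = model.generate(**processor("Hello", return_tensors="pt").to(model.device),
#                        max_new_tokens=5)
#   print(processor.tokenizer.decode(out[0]))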

# A single file attachment as Gradio represents it: (file_name, alt_text)
multimodal_file = tuple[str, str]
multimodal_message = list[str | multimodal_file] | multimodal_file
# The current message, e.g. {'text': 'message here', 'files': []}
# TODO: verify these types against what gr.ChatInterface actually passes
message_t = dict[str, str | list[multimodal_file]]
history_t = list[tuple[str, str] | list[tuple[multimodal_message, multimodal_message]]]
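
# Illustrative shapes for the aliases above (hypothetical values; the exact
# structures gr.ChatInterface passes in multimodal mode are an assumption to
# verify, per the TODO):
#
#   message = {"text": "Describe this.", "files": ["/tmp/cat.png"]}
#   history = [
#       ("Hello", "Hi! How can I help?"),                      # plain text turn
#       ([("/tmp/dog.png", "a dog"), "What breed is this?"],   # multimodal user msg
#        "It looks like a beagle."),
#   ]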

def history_to_prompt(
        message: message_t,
        history: history_t,
        eot_id: str = "<reserved08706>",
        image_placeholder: str = "<image>",
    ):
    """Flatten the chat history plus the current message into a single
    Chameleon prompt string and the list of PIL images it references."""
    prompt = ""
    images = []

    for turn in history:
        print("turn:", turn)
        # Each turn should be a (user_message, assistant_message) pair.
        for msg in turn:
            if isinstance(msg, str):
                prompt += msg
            elif isinstance(msg, list):
                # A multimodal message: a mix of text and (file_path, alt) tuples.
                for item in msg:
                    if isinstance(item, str):
                        prompt += item
                    elif isinstance(item, tuple):
                        image_path, alt = item
                        prompt += image_placeholder
                        images.append(Image.open(image_path))
            else:
                prompt += f"(unhandled message type: {msg})"
            prompt += eot_id

    # The current message goes last. Chameleon's processor expects one image
    # placeholder in the text per image, so append one for each attached file
    # (this assumes the user did not type <image> into the text themselves).
    prompt += message["text"]
    for f in message["files"]:
        prompt += image_placeholder
        images.append(Image.open(f))
    return prompt, images
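
# Usage sketch for the function above (illustrative; the file paths are
# hypothetical):
#
#   message = {"text": "What is in this picture?", "files": ["/tmp/cat.png"]}
#   history = [("Hello", "Hi there!")]
#   prompt, images = history_to_prompt(message, history)
#   # prompt == "Hello<reserved08706>Hi there!<reserved08706>"
#   #           "What is in this picture?<image>"
#   # images == [<PIL.Image loaded from /tmp/cat.png>]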

@spaces.GPU(duration=30)
def respond(
    message,
    history: history_t,
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    # Note: system_message is accepted from the UI but not yet injected into
    # the prompt; wiring it in is left as a TODO.

    print(f"message: {message}\nhistory:\n\n{history}\n")
    prompt, images = history_to_prompt(message, history)
    print(f"prompt:\n\n{prompt}\n")

    # Debug override (uncomment to test with a fixed prompt and image):
    # prompt = "I'm very intrigued by this work of art:<image>Please tell me about the artist."
    # image = Image.open(requests.get("https://uploads4.wikiart.org/images/paul-klee/death-for-the-idea-1915.jpg!Large.jpg", stream=True).raw)
    # images = [image]

    # Pass None instead of an empty list so the processor takes its text-only path.
    inputs = processor(prompt, images=images or None, return_tensors="pt").to(
        model.device, dtype=torch.bfloat16
    )

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )

    try:
        # Launch generation on a background thread and stream tokens as they arrive.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        partial_message = ""
        for new_token in streamer:
            partial_message += new_token
            yield partial_message
        thread.join()
    except Exception as e:
        yield f"Error: {e}"


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    respond,
    multimodal=True,
    additional_inputs=[
        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
)


if __name__ == "__main__":
    demo.launch(debug=True)