File size: 5,656 Bytes
1cea0e1
 
 
 
 
 
 
 
11eab42
1cea0e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11eab42
 
 
 
 
1cea0e1
 
 
11eab42
9d33c66
70f2494
 
11eab42
 
70f2494
11eab42
 
 
70f2494
11eab42
70f2494
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1cea0e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11eab42
1cea0e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9d33c66
 
 
 
70f2494
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import spaces
import os
import re
import gradio as gr
import torch
from transformers import AutoModelForCausalLM
from transformers import TextIteratorStreamer
from threading import Thread
from PIL import Image

# Hugging Face Hub id of the checkpoint this demo serves.
model_name = 'AIDC-AI/Ovis1.6-Gemma2-9B'

# Marker string inserted into the conversation text where the image belongs.
image_placeholder = '<image>'

# Directory containing this script; used to locate bundled resources.
cur_dir = os.path.dirname(os.path.abspath(__file__))

# Load the model once at import time in bfloat16 and move it to the CUDA
# device. trust_remote_code=True lets transformers run the model code that
# ships with the Hub repository.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    multimodal_max_length=8192,
    torch_dtype=torch.bfloat16,
).to(device='cuda')

# Both tokenizers are obtained from the loaded model itself.
text_tokenizer = model.get_text_tokenizer()
visual_tokenizer = model.get_visual_tokenizer()

# Module-level streamer that decodes generated tokens into text, omitting the
# prompt and special tokens.
streamer = TextIteratorStreamer(
    text_tokenizer,
    skip_special_tokens=True,
    skip_prompt=True,
)


def _open_rgb_image(file_ref):
    """Open *file_ref* with PIL and convert it to RGB; return None on failure.

    A file that is missing, unreadable, or not an image is treated as
    "no image" rather than aborting the whole chat request.
    """
    try:
        return Image.open(file_ref).convert("RGB")
    except (OSError, ValueError, TypeError, AttributeError):
        return None


@spaces.GPU
def ovis_chat(message, history):
    """Stream the model's reply to a multimodal chat message.

    Args:
        message: Mapping with a "text" entry (the user's prompt) and a
            "files" entry (sequence of uploaded files; only the first one is
            considered and it is opened as an image).
        history: Sequence of two-element ``[user, bot]`` entries from earlier
            turns. An entry whose user part is a tuple/list is treated as a
            file upload: it can supply the image when the current message has
            none, and it contributes no text of its own.
            NOTE(review): this assumes Gradio's tuple-style multimodal
            history, where file uploads are stored as ``[(path,), None]``
            entries separate from the text turn -- confirm against the
            installed Gradio version.

    Yields:
        The reply accumulated so far, once per decoded chunk.

    Raises:
        Exception: whatever ``model.generate`` raised in the worker thread,
            re-raised here after the stream has been closed.
    """
    # Image attached to the current message, if any.
    try:
        first_file = message["files"][0]
    except (KeyError, IndexError, TypeError):
        first_file = None
    image_input = _open_rgb_image(first_file) if first_file is not None else None

    # A missing/None text is treated as an empty prompt.
    text_input = message["text"] or ""

    # Rebuild the conversation from the history.
    conversations = []
    for entry in history:
        user_part, bot_part = entry[0], entry[1]
        if isinstance(user_part, (tuple, list)):
            # File entry: reuse the first readable image from the history
            # when the current message did not bring one.
            if image_input is None and user_part:
                image_input = _open_rgb_image(user_part[0])
            user_part = None
        # Only answered turns become part of the conversation; file entries
        # normally carry no reply and are therefore skipped here.
        if isinstance(bot_part, str):
            conversations.append({
                "from": "human",
                "value": user_part if isinstance(user_part, str) else ""
            })
            conversations.append({
                "from": "gpt",
                "value": bot_part
            })

    # The placeholder is reserved for the real image, so strip any the user
    # typed into the new prompt.
    text_input = text_input.replace(image_placeholder, '')
    conversations.append({
        "from": "human",
        "value": text_input
    })
    if image_input is not None:
        # The image marker is always attached to the first human turn.
        conversations[0]["value"] = image_placeholder + '\n' + conversations[0]["value"]

    _, input_ids, pixel_values = model.preprocess_inputs(conversations, [image_input])
    attention_mask = torch.ne(input_ids, text_tokenizer.pad_token_id)
    input_ids = input_ids.unsqueeze(0).to(device=model.device)
    attention_mask = attention_mask.unsqueeze(0).to(device=model.device)
    if image_input is None:
        pixel_values = [None]
    else:
        pixel_values = [pixel_values.to(dtype=visual_tokenizer.dtype, device=visual_tokenizer.device)]

    # Greedy decoding; the sampling knobs are explicitly disabled.
    gen_kwargs = dict(
        max_new_tokens=512,
        do_sample=False,
        top_p=None,
        top_k=None,
        temperature=None,
        repetition_penalty=None,
        eos_token_id=model.generation_config.eos_token_id,
        pad_token_id=text_tokenizer.pad_token_id,
        use_cache=True
    )

    # One streamer per request, so concurrent chats cannot receive each
    # other's tokens through a shared queue.
    request_streamer = TextIteratorStreamer(text_tokenizer, skip_prompt=True, skip_special_tokens=True)
    failures = []

    def _run_generation():
        # inference_mode is thread-local, so it has to be entered inside the
        # worker thread to cover the generate call.
        try:
            with torch.inference_mode():
                model.generate(inputs=input_ids,
                               pixel_values=pixel_values,
                               attention_mask=attention_mask,
                               streamer=request_streamer,
                               **gen_kwargs)
        except Exception as exc:
            # Remember the error and close the stream; otherwise the consumer
            # loop below would wait forever for an end-of-stream signal.
            failures.append(exc)
            request_streamer.end()

    thread = Thread(target=_run_generation)
    thread.start()
    response = ""
    for new_text in request_streamer:
        response += new_text
        yield response
    thread.join()
    if failures:
        raise failures[0]

def clear_chat():
    """Return reset values: an empty list, ``None`` and an empty string."""
    fresh_history = []
    cleared_image = None
    cleared_text = ""
    return fresh_history, cleared_image, cleared_text

# Size applied both to the header text and, as a height, to the inlined logo.
font_size = "2.5em"

# Read the bundled logo so it can be embedded directly in the header markup.
logo_path = f"{cur_dir}/resource/logo.svg"
with open(logo_path, "r", encoding="utf-8") as svg_file:
    svg_content = svg_file.read()

# Add a height and inline layout style to the opening <svg ...> tag so the
# logo lines up with the text next to it.
svg_open_tag = re.compile(r'(<svg[^>]*)(>)')
svg_content = svg_open_tag.sub(
    rf'\1 height="{font_size}" style="vertical-align: middle; display: inline-block;"\2',
    svg_content,
)

# Header markup: logo, model short name, and links to the model's pages.
html = f"""
<p align="center" style="font-size: {font_size}; line-height: 1;">
    <span style="display: inline-block; vertical-align: middle;">{svg_content}</span>
    <span style="display: inline-block; vertical-align: middle;">{model_name.split('/')[-1]}</span>
</p>
<center><font size=3><b>Ovis</b> has been open-sourced on <a href='https://huggingface.co/{model_name}'>😊 Huggingface</a> and <a href='https://github.com/AIDC-AI/Ovis'>🌟 GitHub</a>. If you find Ovis useful, a like❤️ or a star🌟 would be appreciated.</font></center>
"""

# LaTeX delimiter pairs as dicts with "left", "right" and "display" keys.
# Only the inline \( ... \) pair has display=False; every other pair is a
# display-mode delimiter.
latex_delimiters_set = [
    {"left": opening, "right": closing, "display": is_display}
    for opening, closing, is_display in (
        ("\\(", "\\)", False),
        ("\\begin{equation}", "\\end{equation}", True),
        ("\\begin{align}", "\\end{align}", True),
        ("\\begin{alignat}", "\\end{alignat}", True),
        ("\\begin{gather}", "\\end{gather}", True),
        ("\\begin{CD}", "\\end{CD}", True),
        ("\\[", "\\]", True),
    )
]

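# Legacy event wiring kept for reference only; send_btn, submit_chat and the
# other components it names are not defined in this file, and the
# gr.ChatInterface below handles submission instead.
#     send_click_event = send_btn.click(submit_chat, [chatbot, text_input], [chatbot, text_input]).then(ovis_chat,[chatbot, image_input],chatbot)
#     submit_event = text_input.submit(submit_chat, [chatbot, text_input], [chatbot, text_input]).then(ovis_chat,[chatbot, image_input],chatbot)
#     clear_btn.click(clear_chat, outputs=[chatbot, image_input, text_input])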

# Build the chat UI around ovis_chat with a multimodal (text + files) input
# box, then start the server in debug mode.
chat_textbox = gr.MultimodalTextbox()
demo = gr.ChatInterface(
    fn=ovis_chat,
    multimodal=True,
    textbox=chat_textbox,
)
demo.launch(debug=True)