File size: 3,836 Bytes
37448c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a6e6449
 
 
37448c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d1d5efe
37448c1
 
 
 
 
 
 
 
a6e6449
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37448c1
 
a6e6449
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import time

import gradio as gr
import torch
from transformers import AutoModelForCausalLM
from transformers import AutoProcessor

model_id = "microsoft/Phi-3.5-vision-instruct"

# This Space runs on CPU, so load the weights straight onto the CPU.
# Previously `device_map="auto"` was followed by `model.to(cpu)`. On a GPU host
# that loaded to the GPU and then copied everything back. Calling `.to()` on a
# model that accelerate dispatched across several devices is not supported.
# flash_attn requires a CUDA GPU, hence the 'eager' attention implementation;
# switch to 'flash_attention_2' only on a GPU with flash_attn installed.
device = torch.device("cpu")
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map = "cpu",
    trust_remote_code = True,
    torch_dtype = torch.bfloat16,
    _attn_implementation = 'eager'
)

# for best performance, use num_crops=4 for multi-frame, num_crops=16 for single-frame.
# NOTE(review): this app sends a single image; 4 crops presumably trades some
# accuracy for lower CPU latency — confirm before changing.
processor = AutoProcessor.from_pretrained(model_id,
                                          trust_remote_code = True,
                                          num_crops = 4
                                          )

# Phi-3.5 chat-template markers; a full prompt is assembled as
# user_prompt + <question> + prompt_suffix + assistant_prompt.
user_prompt, assistant_prompt, prompt_suffix = "<|user|>\n", "<|assistant|>\n", "<|end|>\n"

# Header rendered at the top of the Gradio page.
title_html = "\n<h2> This space uses the model/microsoft/Phi-3.5-vision-instruct </h2>\n"

def call_model(raw_image = None, text_input = None, max_new_tokens = 1000):
    """Run Phi-3.5-vision on one image and an optional question.

    Args:
        raw_image: PIL image to analyse. Gradio passes ``None`` when nothing
            was uploaded.
        text_input: Question about the image. ``None`` is treated as an empty
            question rather than being inserted into the prompt as "None".
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded model answer, without the prompt tokens.

    Raises:
        gr.Error: If no image was provided.
    """
    if raw_image is None:
        # Show a readable message in the UI instead of an AttributeError
        # from ``None.convert``.
        raise gr.Error("Please upload an image first.")
    question = "" if text_input is None else text_input
    prompt = f"{user_prompt}<|image_1|>\n{question}{prompt_suffix}{assistant_prompt}"
    image = raw_image.convert("RGB")

    # Move the inputs to wherever the model actually lives instead of a hard-coded device.
    inputs = processor(prompt, image, return_tensors = "pt").to(model.device)
    generate_ids = model.generate(**inputs,
                                  max_new_tokens = max_new_tokens,
                                  eos_token_id = processor.tokenizer.eos_token_id,
                                  )
    # generate() returns prompt + completion; keep only the newly generated tokens.
    generate_ids = generate_ids[:, inputs['input_ids'].shape[1]:]
    response = processor.batch_decode(generate_ids,
                                      skip_special_tokens = True,
                                      clean_up_tokenization_spaces = False)[0]
    return response


def get_model_memory_footprint(model_):
    """Return a human-readable description of a model's memory footprint.

    Args:
        model_: Object exposing ``get_memory_footprint()``, which returns a
            size in bytes (e.g. a Hugging Face ``PreTrainedModel``).

    Returns:
        A string such as ``"Footprint of the model in MBs: 8293.78 MB"``.
    """
    # bytes -> decimal megabytes; labelled "MB" (megabytes), not "Mb" (megabits).
    footprint_mb = model_.get_memory_footprint() / 1e6
    return f"Footprint of the model in MBs: {footprint_mb:.2f} MB"


def process(raw_image, prompt):
    """Gradio callback: answer ``prompt`` about ``raw_image`` and time the call.

    Args:
        raw_image: Image from the ``gr.Image(type='pil')`` input.
        prompt: Text from the question textbox.

    Returns:
        A ``(memory_usage, model_response, execution_time_min)`` tuple. The
        execution time is in minutes, rounded to 2 decimals.
    """
    print("start...")
    # perf_counter is monotonic, so the measurement is not skewed if the
    # system clock is adjusted during the slow CPU generation.
    start_time = time.perf_counter()
    memory_usage = get_model_memory_footprint(model)
    model_response = call_model(raw_image = raw_image, text_input = prompt)
    execution_time = time.perf_counter() - start_time
    execution_time_min = round((execution_time / 60), 2)
    print(f"Execution time: {execution_time:.4f} seconds")
    print(f"Execution time: {execution_time_min:.2f} min")
    return memory_usage, model_response, execution_time_min


# Gradio UI: image + question on the left, memory / answer / timing on the right.
with gr.Blocks() as demo:
    gr.HTML(title_html)
    gr.Markdown("""
        NOTES :
        - The performance of this model is low since it runs on a CPU in a free space; each request takes at least 1 min.
        - If the input text is not specified the model will describe the image, which takes more time.
    """)
    with gr.Row():
        with gr.Column():
            _raw_image = gr.Image(type = 'pil')
            user_input = gr.Textbox(label = "What do you want to ask?")
            submit_btn = gr.Button(value = "Submit")
        with gr.Column():
            memory = gr.Textbox(label = "Memory usage")
            results = gr.Textbox(label = "Model response")
            exec_time = gr.Textbox(label = "Execution time (min)")

    submit_btn.click(
        process, inputs = [_raw_image, user_input], outputs = [memory, results, exec_time]
    )

    gr.Examples(
        examples=[
            ["assets/img.jpg", 'after you can split horizontally  the image into 6 rows,  extract all text into JSON format. ignore "Au-dessous de Normal" and "Au-dessus de Normal"'],
            ["assets/cats.jpg", 'how many cats are here? and what are they doing ?'],
            ["assets/demo.jpg", 'is it night time ?'],
        ],
        inputs=[_raw_image, user_input],
        outputs=[memory, results, exec_time],
        fn=process,
        # Every inference takes a minute or more on CPU. Do not precompute the
        # examples at startup, which some Gradio versions do by default on Spaces.
        # Clicking an example only fills the inputs; the user then presses Submit.
        cache_examples=False,
        label="Examples",
    )


if __name__ == '__main__':
    demo.launch()