import torch
import gradio as gr
from transformers import TextIteratorStreamer, AutoProcessor, LlavaForConditionalGeneration
from PIL import Image
import threading
import spaces

DESCRIPTION = '''
<div>
<h1 style="text-align: center;">Krypton πŸ•‹</h1>
<p>This uses an Open Source model from <a href="https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers"><b>xtuner/llava-llama-3-8b-v1_1-transformers</b></a></p>
</div>
'''

model_id = "xtuner/llava-llama-3-8b-v1_1-transformers"
model = LlavaForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True
)
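
# The fp16 load above halves the weight memory footprint, and low_cpu_mem_usage
# streams shards in rather than building a full copy in RAM first (this path
# needs the accelerate package installed, even though it is never imported here).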

model.to('cuda')

processor = AutoProcessor.from_pretrained(model_id)
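
# The LLaVA processor couples the tokenizer with the image processor, so a single
# call with text and an image yields input_ids, attention_mask, and pixel_values.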

# Make sure generation stops at the tokenizer's EOS token
model.generation_config.eos_token_id = processor.tokenizer.eos_token_id

@spaces.GPU(duration=120)
def krypton(input, history):
  print(f"Input: {input}")  # Debug input
    print(f"History: {history}")  # Debug history

    if input["files"]:
        print("Found the image\n")
        image_path = input["files"][-1]["path"] if isinstance(input["files"][-1], dict) else input["files"][-1]
        print(f"Image path: {image_path}")
    else:
        image_path = None
        for hist in history:
            if isinstance(hist[0], tuple):
                image_path = hist[0][0]
    
    if not image_path:
        # gr.Error must be raised, not just constructed, to surface in the UI
        raise gr.Error("You need to upload an image for Krypton to work.")

    try:
        image = Image.open(image_path)
        print(f"Image open: {image}")
    except Exception as e:
        print(f"Error opening image: {e}")
        gr.Error("Failed to open the image.")
        return

    # Build the prompt in the Llama-3 chat format used by llava-llama-3-8b-v1_1
    # (see the model card); the <image> token marks where the image features go
    prompt = (
        f"<|start_header_id|>user<|end_header_id|>\n\n<image>\n{input['text']}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )
    print("Made the prompt")

    try:
        inputs = processor(text=prompt, images=image, return_tensors='pt').to('cuda', torch.float16)
        print(f"Processed inputs: {inputs}")
    except Exception as e:
        print(f"Error processing inputs: {e}")
        gr.Error("Failed to process the inputs.")
        return
    
    # Streamer: skip the prompt and special tokens so markers like <|eot_id|>
    # never leak into the chat output
    print('About to init streamer')
    streamer = TextIteratorStreamer(processor.tokenizer, skip_special_tokens=True, skip_prompt=True)
    
    # Generation kwargs; unpack all processor outputs so pixel_values reach the
    # model (passing only input_ids would silently drop the image)
    generation_kwargs = dict(
        **inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=False
    )
    
    # generate() blocks until it finishes, so run it in a background thread and
    # drain the streamer from this generator
    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    print('Thread about to start')
    thread.start()
    
    buffer = ""
    # time.sleep(0.5)
    for new_text in streamer:
        buffer += new_text
        generated_text_without_prompt = buffer
        # time.sleep(0.06)
        yield generated_text_without_prompt


chatbot = gr.Chatbot(height=600, label="Krypt AI")
chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter your question or upload an image.", show_label=False)
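
# gr.MultimodalTextbox delivers a dict shaped like {"text": str, "files": [...]}
# to krypton(), which is the structure the file and history handling above expects.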

with gr.Blocks(fill_height=True) as demo:
    gr.Markdown(DESCRIPTION)
    gr.ChatInterface(
        fn=krypton,
        chatbot=chatbot,
        fill_height=True,
        multimodal=True,
        textbox=chat_input,
    )

demo.queue(api_open=False)
demo.launch(show_api=False, share=False)
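
# Rough local-run sketch (assumptions: a CUDA GPU with roughly 16 GB free for the
# fp16 weights, and that the spaces ZeroGPU decorator is a no-op off HF Spaces):
#   pip install torch transformers accelerate gradio spaces pillow
#   python app.py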