import os
import re
from io import BytesIO

import requests
import spaces
import gradio as gr
import torch
from PIL import Image

from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
    IMAGE_PLACEHOLDER,
)
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
)
from conversation import Conversation, SeparatorStyle

model_id = "ytu-ce-cosmos/Turkish-LLaVA-v0.1"

# Load the tokenizer, model, and image processor once at import time.
disable_torch_init()
model_name = get_model_name_from_path(model_id)
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_id, None, model_name
)

def load_image(image_file):
    """Load an image from an http(s) URL or a local path and convert it to RGB."""
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file)
        image = Image.open(BytesIO(response.content)).convert("RGB")
    elif os.path.exists(image_file):
        image = Image.open(image_file).convert("RGB")
    else:
        # Error message (Turkish): "Image file {image_file} not found."
        raise FileNotFoundError(f"Görüntü dosyası {image_file} bulunamadı.")
    return image
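
# Usage sketch (hypothetical paths/URLs, not executed by the app):
#   img = load_image("./bee.jpg")
#   img = load_image("https://example.com/photo.jpg")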

def infer_single_image(model_id, image_file, prompt):
    """Run a single image/prompt pair through the model and return the decoded reply."""
    # Insert the image token(s) where the placeholder appears, or prepend
    # them if the prompt contains no placeholder.
    image_token_se = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
    if IMAGE_PLACEHOLDER in prompt:
        if model.config.mm_use_im_start_end:
            prompt = re.sub(IMAGE_PLACEHOLDER, image_token_se, prompt)
        else:
            prompt = re.sub(IMAGE_PLACEHOLDER, DEFAULT_IMAGE_TOKEN, prompt)
    else:
        if model.config.mm_use_im_start_end:
            prompt = image_token_se + "\n" + prompt
        else:
            prompt = DEFAULT_IMAGE_TOKEN + "\n" + prompt

    # Build a Llama-3-style chat prompt. The system message (in Turkish)
    # tells the assistant to complete the user's task faithfully, reasoning
    # step by step.
    conv = Conversation(
        system="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nSen bir yapay zeka asistanısın. Kullanıcı sana bir görev verecek. Amacın görevi olabildiğince sadık bir şekilde tamamlamak. Görevi yerine getirirken adım adım düşün ve adımlarını gerekçelendir.""",
        roles=("<|start_header_id|>user<|end_header_id|>\n\n", "<|start_header_id|>assistant<|end_header_id|>\n\n"),
        version="llama3",
        messages=[],
        offset=0,
        sep_style=SeparatorStyle.MPT,
        sep="<|eot_id|>",
    )
    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], None)
    full_prompt = conv.get_prompt()

    print("full prompt: ", full_prompt)

    image = load_image(image_file)
    image_tensor = process_images(
        [image],
        image_processor,
        model.config
    ).to(model.device, dtype=torch.float16)

    # Tokenize the prompt, mapping each image token to IMAGE_TOKEN_INDEX.
    input_ids = (
        tokenizer_image_token(full_prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )

    # Greedy decoding (do_sample=False), capped at 512 new tokens.
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            image_sizes=[image.size],
            do_sample=False,
            max_new_tokens=512,
            use_cache=True,
        )

    output = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
    return output
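
# Direct-call sketch for local testing (assumes an example image such as
# ./bee.jpg from the demo below is present; the Gradio app calls this via
# bot_streaming instead):
#   answer = infer_single_image(model_id, "./bee.jpg", "Çiçeğin üzerinde ne var?")
#   print(answer)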

@spaces.GPU  # Requests GPU time per call when hosted on Hugging Face Spaces (ZeroGPU).
def bot_streaming(message, history):
    print(message)
    # Prefer an image attached to the current message; otherwise fall back
    # to the most recent image found in the chat history.
    image = None
    if message["files"]:
        if isinstance(message["files"][-1], dict):
            image = message["files"][-1]["path"]
        else:
            image = message["files"][-1]
    else:
        for hist in history:
            if isinstance(hist[0], tuple):
                image = hist[0][0]
    if image is None:
        # Error message (Turkish): "You need to upload an image for LLaVA to work."
        raise gr.Error("LLaVA'nın çalışması için bir resim yüklemeniz gerekir.")

    prompt = message["text"]

    result = infer_single_image(model_id, image, prompt)

    print(result)

    yield result

chatbot = gr.Chatbot(scale=1)
chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Mesaj girin veya dosya yükleyin...", show_label=False)

with gr.Blocks(fill_height=True) as demo:
    gr.ChatInterface(
        fn=bot_streaming,
        title="Cosmos LLaVA",
        examples=[{"text": "Bu kitabın adı ne?", "files": ["./book.jpg"]},
                  {"text": "Çiçeğin üzerinde ne var?", "files": ["./bee.jpg"]},
                  {"text": "Bu tatlı nasıl yapılır?", "files": ["./baklava.png"]}],
        description="",
        stop_btn="Stop Generation",
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
    )

demo.queue(api_open=False)
demo.launch(show_api=False, share=False)
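
# Launch note (assumption, standard Gradio/Spaces behavior): run `python app.py`
# locally; on Hugging Face Spaces the app starts automatically and @spaces.GPU
# holds a GPU only while bot_streaming executes.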