File size: 3,908 Bytes
f6b7ba0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4dd73d1
f6b7ba0
 
 
 
 
 
3139a51
f6b7ba0
 
 
4dd73d1
f6b7ba0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5f9027
a9a69bd
 
f6b7ba0
 
 
d79744b
 
 
 
 
 
 
 
f6b7ba0
 
 
 
 
 
 
 
d79744b
 
f6b7ba0
 
4dd73d1
a9a69bd
 
 
 
 
 
4dd73d1
f6b7ba0
 
 
 
 
 
a337221
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import gradio as gr
from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer
from threading import Thread
import re
import time 
from PIL import Image
import torch
import spaces

# Load the processor and the model from the same Hugging Face Hub checkpoint.
# trust_remote_code=True lets transformers execute the model/processor code
# published in that Hub repository.
processor = AutoProcessor.from_pretrained(
    "ucsahin/TraVisionLM-base",
    trust_remote_code=True,
)
# nn.Module.to() moves the weights in place and returns the same module, so the
# placement can be chained directly onto the load.
model = AutoModelForCausalLM.from_pretrained(
    "ucsahin/TraVisionLM-base",
    trust_remote_code=True,
).to("cuda:0")

def _resolve_image_path(message, history):
    """Return the path of the image to answer about, or None if there is none.

    The last file attached to the current message wins. Otherwise the most
    recent image found in earlier turns of ``history`` is used; image uploads
    appear there as a tuple (whose first item is the path) in the user slot.
    """
    if message.files:
        return message.files[-1].path
    # Walk backwards so the latest uploaded image is the one that is reused.
    for hist in reversed(history):
        if isinstance(hist[0], tuple):
            return hist[0][0]
    return None


@spaces.GPU
def bot_streaming(message, history, max_tokens, temperature, top_p, top_k, repetition_penalty):
    """Stream the model's answer to ``message`` about the relevant image.

    Args:
        message: Multimodal chat message exposing ``.text`` and ``.files``
            (each file exposing ``.path``).
        history: Previous chat turns; each turn's first item is the user entry.
        max_tokens: Maximum number of new tokens to generate.
        temperature, top_p, top_k, repetition_penalty: Sampling settings
            forwarded to ``model.generate``.

    Yields:
        The answer generated so far (cumulative text, prompt echo removed).

    Raises:
        gr.Error: If no image was uploaded in this or any previous turn.
    """
    image_path = _resolve_image_path(message, history)
    if image_path is None:
        # gr.Error only reaches the UI when it is raised; merely constructing
        # it (as before) did nothing, and `image` was left unbound.
        raise gr.Error("Lütfen önce bir resim yükleyin.")

    prompt = f"{message.text}"
    # Close the underlying file handle; convert() returns an independent copy.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda:0")

    streamer = TextIteratorStreamer(processor, skip_special_tokens=True)
    # Gradio sliders may deliver integral values as int (e.g. 2 instead of
    # 2.0). transformers' temperature / repetition-penalty processors reject
    # non-float values and top-k requires an int, so normalize types here.
    generation_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=int(max_tokens),
        do_sample=True,
        temperature=float(temperature),
        top_p=float(top_p),
        top_k=int(top_k),
        repetition_penalty=float(repetition_penalty),
    )

    # generate() blocks, so it runs in a worker thread while this generator
    # consumes the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # The streamed text is assumed to begin with the echoed prompt followed by
    # "\n" (as this slicing has always assumed); strip that prefix.
    prompt_prefix_len = len(f"{message.text}\n")

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        # Small pause to pace UI updates.
        time.sleep(0.04)
        yield buffer[prompt_prefix_len:]

    # The streamer is exhausted once generation ends; reap the worker thread.
    thread.join()


# Directory registered with Gradio as a static path; it holds the logo that
# the chat placeholder below references via the /file= route.
_STATIC_IMAGES_DIR = "static/images/"
gr.set_static_paths(paths=[_STATIC_IMAGES_DIR])
logo_path = _STATIC_IMAGES_DIR + "logo-color-v2.png"

# HTML used as the chatbot placeholder: logo plus a short Turkish prompt
# inviting the user to upload an image and ask a question.
PLACEHOLDER = f"""
<div style="display: flex; flex-direction: column; align-items: center; text-align: center; margin: 30px">
    <img src="/file={logo_path}" style="width: 40%; height: auto; opacity: 80%">
    <h3>Resim yükleyin ve bir soru sorun!</h3>
    <p>Örnek resim ve soruları kullanabilirsiniz.</p>
</div>
"""

# Markdown description rendered under the demo title. A plain string literal:
# the former f-prefix had no placeholders (ruff F541).
DESCRIPTION = """
### 875M parametreli küçük ama çok hızlı bir Türkçe Görsel Dil Modeli 🇹🇷🌟⚡️⚡️🇹🇷

Yüklediğiniz resimleri açıklatabilir ve onlarla ilgili ucu açık sorular sorabilirsiniz 🖼️🤖

Detaylar için [ucsahin/TraVisionLM-base](https://huggingface.co/ucsahin/TraVisionLM-base) kontrol etmeyi unutmayın!
"""

# Collapsible panel of sampling settings. These sliders are the chat
# interface's additional inputs, so their values reach bot_streaming after
# (message, history) in this order.
with gr.Accordion("Generation parameters", open=False) as parameter_accordion:
    max_tokens_item = gr.Slider(64, 1024, value=512, step=64, label="Max tokens")
    temperature_item = gr.Slider(0.1, 2, value=0.6, step=0.1, label="Temperature")
    top_p_item = gr.Slider(0, 1.0, value=0.9, step=0.05, label="Top_p")
    # Integer steps; transformers skips top-k filtering when top_k is 0.
    top_k_item = gr.Slider(0, 100, value=50, step=1, label="Top_k")
    # Explicit fractional step: gr.Slider defaults to step=1, which only
    # offered 0, 1 and 2 (leaving the 1.2 default off-grid). The minimum is
    # kept above 0 because transformers requires a strictly positive
    # repetition penalty.
    repeat_penalty_item = gr.Slider(0.1, 2, value=1.2, step=0.1, label="Repeat penalty")

# (prompt, image) pairs offered as clickable examples in the chat UI.
_EXAMPLE_PROMPTS = [
    ("Detaylı açıkla", "./family.jpg"),
    ("Görüntüde uçaklar ne yapıyor?", "./plane.jpg"),
    ("Kısaca açıkla", "./dog.jpg"),
    ("Tren istasyonu kalabalık mı yoksa boş mu?", "./train.jpg"),
    ("Resimdeki araba hangi renk?", "./car.jpg"),
    ("Görüntünün odak noktası nedir?", "./mandog.jpg"),
]

# Multimodal chat front-end: each example is a single multimodal message
# (text plus one attached file), and the accordion sliders are passed to
# bot_streaming as extra positional inputs.
demo = gr.ChatInterface(
    fn=bot_streaming,
    title="TraVisionLM - Demo",
    description=DESCRIPTION,
    chatbot=gr.Chatbot(placeholder=PLACEHOLDER, scale=1),
    examples=[
        [{"text": text, "files": [path]}] for text, path in _EXAMPLE_PROMPTS
    ],
    additional_inputs=[
        max_tokens_item,
        temperature_item,
        top_p_item,
        top_k_item,
        repeat_penalty_item,
    ],
    additional_inputs_accordion=parameter_accordion,
    stop_btn="Stop Generation",
    multimodal=True,
)

# Start the server only when run as a script, so importing this module
# (e.g. by tooling or `gradio app.py` reload mode) does not launch it.
if __name__ == "__main__":
    # debug=True keeps the process attached and prints errors; uploads are
    # capped at 5 MB.
    demo.launch(debug=True, max_file_size="5mb")