File size: 2,890 Bytes
b8c24aa
3a82207
63b82b4
 
 
 
 
 
c8fdb3b
3a82207
4e81072
7dc3087
08c1bd3
19af97e
7dc3087
 
8ea3940
7dc3087
 
63b82b4
b9b37c9
955e4ad
0844d7e
 
 
 
7dc3087
64d8a64
63b82b4
64d8a64
 
63b82b4
64d8a64
63b82b4
fccbbf3
63b82b4
 
08c1bd3
19af97e
ea9c0d3
3a82207
 
 
 
 
 
 
 
 
63b82b4
 
3a82207
 
 
63b82b4
3a82207
63b82b4
0844d7e
3a82207
ea9c0d3
 
 
 
3a82207
 
 
 
 
 
 
 
7dc3087
0f1f78e
3a82207
63b82b4
 
 
 
e2534da
63b82b4
 
 
 
 
 
 
2cdab2a
63b82b4
 
 
 
ea9c0d3
63b82b4
 
 
 
9a34670
63b82b4
19af97e
63b82b4
3a82207
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    BitsAndBytesConfig,
)
import os
from threading import Thread
import spaces
import time

# Hugging Face access-token lookup, currently disabled.
# token = os.environ["HF_TOKEN"]

# Checkpoint identifier shared by the model and its tokenizer.
MODEL_ID = "shisa-ai/shisa-v1-qwen2-7b"

# Load the weights in 4-bit through bitsandbytes, computing in float16.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    quantization_config=quantization_config,
)
tok = AutoTokenizer.from_pretrained(MODEL_ID)

# Disabled alternative stop-token list (EOS plus "<|eot_id|>").
# terminators = [
#     tok.eos_token_id,
#     tok.convert_tokens_to_ids("<|eot_id|>")
# ]

# Choose the device that tokenized inputs are moved to before generation.
if not torch.cuda.is_available():
    device = torch.device("cpu")
    print("Using CPU")
else:
    device = torch.device("cuda")
    print(f"Using GPU: {torch.cuda.get_device_name(device)}")

# Explicit model placement is left disabled; the original note next to it
# reads "Dispatch Errors".
# model = model.to(device)


@spaces.GPU(duration=120)
def chat(message, history, temperature, do_sample, max_tokens):
    """Stream the model's reply to ``message`` given the earlier ``history``.

    Args:
        message: The latest user message.
        history: Earlier turns; each item is indexed as ``item[0]`` (user
            text) and ``item[1]`` (assistant text, skipped when ``None``).
        temperature: Sampling temperature. ``0`` forces greedy decoding.
        do_sample: Whether to sample. When false, decoding is greedy and
            ``temperature`` is not forwarded to ``generate``.
        max_tokens: Upper bound forwarded as ``max_new_tokens``.

    Yields:
        The reply text accumulated so far, growing as new text is streamed.
    """
    conversation = []
    for item in history:
        conversation.append({"role": "user", "content": item[0]})
        if item[1] is not None:
            conversation.append({"role": "assistant", "content": item[1]})
    conversation.append({"role": "user", "content": message})

    prompt = tok.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    model_inputs = tok([prompt], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(
        tok, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )

    # Honor the caller's sampling switch; a zero temperature is not a valid
    # sampling setting, so it also falls back to greedy decoding.
    sampling = bool(do_sample) and temperature > 0
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=int(max_tokens),
        do_sample=sampling,
        # Stop on the tokenizer's EOS id (replaces the old terminators list).
        eos_token_id=tok.eos_token_id,
    )
    if sampling:
        # Temperature only applies when sampling, so omit it for greedy runs.
        generate_kwargs["temperature"] = temperature

    # generate() blocks until done, so run it in a worker thread and consume
    # the streamer from this generator.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()

    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text

    # The stream has ended, so generation is finishing; wait for the worker
    # so it does not outlive this call.
    worker.join()

    # Emit the final text once more so a value is produced even when the
    # model streamed nothing.
    yield partial_text


# Chat UI. The additional inputs are passed to `chat` after
# (message, history), in the order listed below.
demo = gr.ChatInterface(
    fn=chat,
    examples=[["Write me a poem about Machine Learning."]],
    # multimodal=False,
    additional_inputs_accordion=gr.Accordion(
        label="⚙️ Parameters", open=False, render=False
    ),
    additional_inputs=[
        # -> temperature
        gr.Slider(
            minimum=0, maximum=1, step=0.1, value=0.9, label="Temperature", render=False
        ),
        # -> do_sample (render=False to match the other additional inputs)
        gr.Checkbox(label="Sampling", value=True, render=False),
        # -> max_tokens
        gr.Slider(
            minimum=128,
            maximum=4096,
            step=1,
            value=512,
            label="Max new tokens",
            render=False,
        ),
    ],
    stop_btn="Stop Generation",
    title="Chat With LLMs",
    description="Now Running [shisa-ai/shisa-v1-qwen2-7b](https://huggingface.co/shisa-ai/shisa-v1-qwen2-7b) in 4bit"
)
demo.launch()