File size: 3,610 Bytes
829da7c
 
31a1ff8
e2b5fc2
829da7c
 
54fe16b
829da7c
 
 
 
 
 
 
54fe16b
829da7c
 
 
 
 
54fe16b
31a1ff8
54fe16b
829da7c
 
54fe16b
829da7c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53b40bf
54fe16b
 
829da7c
4856892
 
829da7c
 
54fe16b
 
4856892
 
54fe16b
829da7c
54fe16b
 
 
 
 
829da7c
 
54fe16b
 
994685c
2f52a8c
eb95198
658eb41
eb95198
 
994685c
2f52a8c
994685c
54fe16b
53b40bf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import argparse
import os
import spaces

import gradio as gr

import json
from threading import Thread
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MAX_LENGTH = 4096
DEFAULT_MAX_NEW_TOKENS = 1024


def parse_args():
    """Parse the command-line options of the demo script.

    Recognised flags:
        --base_model  string, no default (``None`` when omitted)
        --n_gpus      integer, defaults to 1

    Returns:
        The ``argparse.Namespace`` produced from ``sys.argv``.
    """
    cli = argparse.ArgumentParser()
    option_table = (
        ("--base_model", {"type": str}),  # model path
        ("--n_gpus", {"type": int, "default": 1}),  # n_gpu
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()

@spaces.GPU()
def predict(message, history, system_prompt, temperature, max_tokens):
    """Stream the assistant's reply to ``message`` as a growing string.

    Reads the module-level ``model``, ``tokenizer`` and ``device`` that the
    ``__main__`` section assigns before the UI is launched.

    Args:
        message: Latest user message.
        history: Iterable of ``(user, assistant)`` pairs from earlier turns.
        system_prompt: Text placed in the leading ``system`` message.
        temperature: Sampling temperature. A value of 0 (or ``None``) selects
            greedy decoding, because ``generate`` requires a strictly
            positive temperature when sampling.
        max_tokens: Upper bound on newly generated tokens. Falls back to
            ``DEFAULT_MAX_NEW_TOKENS`` when falsy.

    Yields:
        The reply accumulated so far, once per chunk emitted by the streamer.
    """
    messages = [{'role': 'system', 'content': system_prompt}]
    for human, assistant in history:
        messages.append({'role': 'user', 'content': human})
        messages.append({'role': 'assistant', 'content': assistant})
    messages.append({'role': 'user', 'content': message})

    prompt = [tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)]
    streamer = TextIteratorStreamer(tokenizer, timeout=100.0, skip_prompt=True, skip_special_tokens=True)
    enc = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
    input_ids = enc.input_ids
    attention_mask = enc.attention_mask

    # Keep only the most recent MAX_LENGTH tokens. The mask must be cut the
    # same way, otherwise its shape no longer matches input_ids.
    if input_ids.shape[1] > MAX_LENGTH:
        input_ids = input_ids[:, -MAX_LENGTH:]
        attention_mask = attention_mask[:, -MAX_LENGTH:]

    # Honour the UI's "Max Tokens" value instead of always using the default.
    max_new_tokens = int(max_tokens) if max_tokens else DEFAULT_MAX_NEW_TOKENS

    generate_kwargs = dict(
        input_ids=input_ids.to(device),
        attention_mask=attention_mask.to(device),
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        use_cache=True,
        eos_token_id=100278,  # <|im_end|>
    )
    if temperature is not None and temperature > 0:
        generate_kwargs.update(do_sample=True, top_p=0.95, temperature=temperature)
    else:
        # The temperature slider starts at 0; sampling with it would be rejected.
        generate_kwargs["do_sample"] = False

    # generate() blocks, so it runs in a worker thread while this generator
    # relays the streamer's chunks to the UI.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()
    reply = ""
    for text in streamer:
        reply += text
        yield reply
    # The streamer is exhausted, so generation is finishing; reap the thread.
    worker.join()



# Script entry point: load the tokenizer and model, then serve the chat UI.
# `tokenizer`, `model` and `device` are assigned at module level here because
# predict() reads them as globals.
if __name__ == "__main__":
    args = parse_args()
    # Use --base_model when it is given; otherwise load the same checkpoint as
    # before. predict() passes a fixed eos_token_id (100278), so an override
    # should be a checkpoint that shares this tokenizer.
    model_id = args.base_model or "stabilityai/stablelm-2-12b-chat"
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True, torch_dtype=torch.bfloat16)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    gr.ChatInterface(
        predict,
        title="StableLM 2 12B Chat - Demo",
        description="StableLM 2 12B Chat - StabilityAI",
        theme="soft",
        chatbot=gr.Chatbot(label="Chat History",),
        textbox=gr.Textbox(placeholder="input", container=False, scale=7),
        retry_btn=None,
        undo_btn="Delete Previous",
        clear_btn="Clear",
        # Passed to predict() after (message, history), in this order:
        # system_prompt, temperature, max_tokens.
        additional_inputs=[
            gr.Textbox("You are a helpful assistant.", label="System Prompt"),
            gr.Slider(0, 1, 0.5, label="Temperature"),
            gr.Slider(100, 2048, 1024, label="Max Tokens"),
        ],
        examples=[
            ["What's been the role of music in human societies?"],
            ["Escribe un poema corto sobre la historia del Mediterráneo."],
            ["Scrivi un Haiku che celebri il gelato."],
            ["Schreibe ein Haiku über die Alpen."],
            ["Ecris une prose a propos de la mer du Nord."],
            ["Escreva um poema sobre a saudade."],
            ["Jane has 8 apples, out of which 2 are red and 3 are green. Assuming there are only red, green and white apples, how many of them are white? Solve this in Python."],
        ],
        additional_inputs_accordion_name="Parameters",
    ).queue().launch()