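# Gradio chat demo for microsoft/Phi-3-mini-128k-instruct with streamed
# generation, intended to run as a Hugging Face Space.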
import os
import torch
import spaces
import subprocess
import gradio as gr

from threading import Thread
from huggingface_hub import login
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

# Authenticate with the Hugging Face Hub (HF_TOKEN is set as a Space secret).
login(os.environ.get("HF_TOKEN"))

# Install flash-attn at runtime; FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE pulls the
# prebuilt wheel instead of compiling the CUDA kernels on startup.
subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)

model_id = "microsoft/Phi-3-mini-128k-instruct"

# Load the tokenizer and model; trust_remote_code is required for Phi-3's custom
# modeling code, and bfloat16 + FlashAttention-2 keep memory use and latency down.
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
    attn_implementation="flash_attention_2"
)

@spaces.GPU()  # ZeroGPU: a GPU is attached only while this function runs
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    repetition_penalty: float,
):

    # Rebuild the full conversation in the chat-template message format.
    conversation = []
    if system_prompt:
        conversation.append({"role": "system", "content": system_prompt})

    for user, assistant in chat_history:
        conversation.append({"role": "user", "content": user})
        conversation.append({"role": "assistant", "content": assistant})

    conversation.append({"role": "user", "content": message})

    # Stream decoded tokens as they are produced instead of waiting for the full reply.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # Tokenize with the model's chat template and move the tensors to the model's
    # device; indexing by key avoids relying on the dict ordering of .values().
    inputs = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True
    ).to(model.device)

    generate_kwargs = dict(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        streamer=streamer,
        do_sample=True,
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
        top_p=top_p
    )

    # Run generation on a background thread so the streamer can be consumed here.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Yield the accumulated text so the UI updates token by token.
    outputs = []
    for new_token in streamer:
        outputs.append(new_token)
        yield "".join(outputs)


# Wire up the UI: each additional_input maps positionally onto generate()'s
# parameters after (message, chat_history).
gr.ChatInterface(
    fn=generate,
    title="🚀 Phi-3 mini 128k instruct",
    description="",
    additional_inputs=[
        gr.Textbox(
            label="System prompt",
            lines=5,
            value="You are a helpful digital assistant."
        ),
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=2048,
            step=1,
            value=1024,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.1,
            maximum=1.0,
            step=0.1,
            value=0.6,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    stop_btn=None,
    examples=[
        ["Can you provide ways to eat combinations of bananas and dragonfruits?"],
        ["Write a story about a dragon fruit that flies into outer space!"],
        ["I am going to Bali, what should I see"],
    ],
).queue().launch()
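
# A minimal usage sketch outside the UI (hypothetical argument values, not part
# of the original file) — generate() is a generator yielding partial text:
#
#   for partial in generate("Hi!", [], "You are a helpful digital assistant.",
#                           256, 0.6, 0.9, 50, 1.2):
#       print(partial)
#
# Assumed accompanying requirements.txt (not shown in the source): torch,
# transformers, accelerate, gradio, and spaces; flash-attn is installed at
# startup by the subprocess call above.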