File size: 4,347 Bytes
854044a
 
 
 
 
 
 
552b26e
854044a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
552b26e
854044a
 
 
 
 
 
 
 
552b26e
854044a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f69474
 
 
fe17515
 
 
 
 
6f69474
 
fe17515
 
 
 
 
6f69474
fe17515
 
 
854044a
 
 
552b26e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
854044a
 
552b26e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from huggingface_hub import InferenceClient
import gradio as gr

# Hosted inference client for the instruct-tuned Mistral 7B model; all text
# generation below is sent to this Hugging Face Inference API endpoint.
client = InferenceClient(
    "mistralai/Mistral-7B-Instruct-v0.1"
)


def format_prompt(message, history):
  """Render chat history plus the new message in Mistral-instruct format.

  Each past (user, assistant) turn becomes ``[INST] user [/INST] reply</s> ``,
  the whole prompt opens with ``<s>``, and the new message is appended as a
  final, unanswered ``[INST] ... [/INST]`` block.
  """
  past_turns = "".join(
      f"[INST] {user_msg} [/INST] {assistant_msg}</s> "
      for user_msg, assistant_msg in history
  )
  return f"<s>{past_turns}[INST] {message} [/INST]"

def generate(
    prompt, history, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    """Stream a model reply for *prompt* given the chat *history*.

    Yields the cumulative response text after each received token so Gradio
    can render a live "typing" effect.

    Parameters
    ----------
    prompt : str
        The new user message.
    history : list[tuple[str, str]]
        Prior (user, assistant) turns, as provided by gr.ChatInterface.
    temperature, max_new_tokens, top_p, repetition_penalty :
        Sampling controls; arrive as floats from the Gradio sliders.
    """
    # Sliders deliver floats; clamp temperature away from 0, which the
    # sampling endpoint rejects.
    temperature = float(temperature)
    if temperature < 1e-2:
        temperature = 1e-2
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        # The slider can produce a float (e.g. 256.0); the API expects an int.
        max_new_tokens=int(max_new_tokens),
        top_p=top_p,
        repetition_penalty=float(repetition_penalty),
        do_sample=True,
        seed=42,  # fixed seed for reproducible sampling
    )

    formatted_prompt = format_prompt(prompt, history)

    stream = client.text_generation(formatted_prompt, **generate_kwargs, stream=True, details=True, return_full_text=False)
    output = ""

    # Accumulate token-by-token and yield the running text each time.
    for response in stream:
        output += response.token.text
        yield output


# Sampling-control sliders rendered under the chat box. Their order must
# match the extra keyword parameters of generate():
# (temperature, max_new_tokens, top_p, repetition_penalty).
_SLIDER_SPECS = (
    dict(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    dict(
        label="Max new tokens",
        value=256,
        minimum=0,
        maximum=1048,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    dict(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    dict(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
)

additional_inputs = [gr.Slider(**spec) for spec in _SLIDER_SPECS]

# Custom CSS injected into the Blocks app: constrains the #mkd element to a
# scrollable 500px box with a light border.
css = """
  #mkd {
    height: 500px; 
    overflow: auto; 
    border: 1px solid #ccc; 
  }
"""

# Every example prompt shares the same instruction prefix; only the topic
# varies, so the list is generated from a topic table.
_EXAMPLE_TOPICS = (
    "Decreased Ξ±-ketoglutarate dehydrogenase activity in astrocytes",
    "Lewy body dementia",
    "Delusional disorder",
    "Galantamine",
    "Neural crest",
    "Progressive multifocal encephalopathy (PML)",
    "CT head",
    "Ξ²-Galactocerebrosidase",
    "Dopamine",
    "G protein-coupled receptors",
    "CT scan of the head without contrast",
    "Pyogenic brain abscess",
    "Pneumocystitis jiroveci",
)

# Assemble the demo UI: an HTML header card followed by a streaming chat
# interface wired to generate() and the shared sampling sliders.
with gr.Blocks(css=css) as demo:
    gr.HTML("""πŸ€– Mistral 7B Instruct πŸ€–
        In this demo, you can chat with <a href='https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.1'>Mistral-7B-Instruct</a> model. πŸ’¬
        πŸ›  Model Features πŸ› 
        <ul>
          <li>πŸͺŸ Sliding Window Attention with 128K tokens span</li>
          <li>πŸš€ GQA for faster inference</li>
          <li>πŸ“ Byte-fallback BPE tokenizer</li>
        </ul>
        πŸ“œ License πŸ“œ  Released under Apache 2.0 License
        πŸ“¦ Usage πŸ“¦
        <ul>
          <li>πŸ“š Available on Huggingface Hub</li>
          <li>🐍 Python code snippets for easy setup</li>
          <li>πŸ“ˆ Expected speedups with Flash Attention 2</li>
        </ul>
        Learn more about the model <a href='https://huggingface.co/docs/transformers/main/model_doc/mistral'>here</a>. πŸ“š
    """)

    gr.ChatInterface(
        generate,
        additional_inputs=additional_inputs,
        examples=[
            [f"Create a ten-point markdown outline with emojis about: {topic}"]
            for topic in _EXAMPLE_TOPICS
        ],
    )

# queue() enables request queuing so streamed generations don't block others.
demo.queue().launch(debug=True)