Spaces:
Running
on
Zero
Running
on
Zero
File size: 3,867 Bytes
cc5b602 6f619d7 d381360 6386510 51a7d9e 3eed0af 6386510 970d940 51a7d9e e6367a7 423ddc8 51a7d9e 6386510 bd34f0b 423ddc8 bd34f0b 51a7d9e 970d940 423ddc8 970d940 423ddc8 970d940 423ddc8 970d940 423ddc8 3eed0af d381360 4ed884e 1d4c579 4ed884e e59867b 423ddc8 3eed0af 423ddc8 e59867b 423ddc8 970d940 3eed0af 970d940 8c5184e 970d940 51a7d9e 970d940 51a7d9e 8c5184e 51a7d9e 1d4c579 51a7d9e 4ed884e 51a7d9e b64165b 51a7d9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
import os
import time
import spaces
import torch
import gradio as gr
from threading import Thread
from huggingface_hub import snapshot_download
from pathlib import Path
from mistral_inference.transformer import Transformer
from mistral_inference.generate import generate
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
from mistral_common.protocol.instruct.messages import AssistantMessage, UserMessage
from mistral_common.protocol.instruct.request import ChatCompletionRequest
# Hugging Face access token taken from the environment; None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Heading HTML rendered at the top of the page.
TITLE = "<h1><center>Mistral-lab</center></h1>"

# HTML handed to gr.Chatbot as its placeholder.
PLACEHOLDER = (
    "\n"
    "<center>\n"
    "<p>Chat with Mistral AI LLM.</p>\n"
    "</center>\n"
)

# Page CSS: styles elements with class "duplicate-button" and centers <h3> text.
CSS = (
    "\n"
    ".duplicate-button {\n"
    "margin: auto !important;\n"
    "color: white !important;\n"
    "background: black !important;\n"
    "border-radius: 100vh !important;\n"
    "}\n"
    "h3 {\n"
    "text-align: center;\n"
    "}\n"
)
# Download only the files mistral_inference needs (model params, weights and
# the tekken tokenizer) into ~/mistral_models/8B-Instruct. HF_TOKEN is passed
# explicitly so the token read from the environment is actually used; when it
# is None, huggingface_hub applies its normal credential lookup.
mistral_models_path = Path.home() / "mistral_models" / "8B-Instruct"
mistral_models_path.mkdir(parents=True, exist_ok=True)
snapshot_download(
    repo_id="mistralai/Ministral-8B-Instruct-2410",
    allow_patterns=["params.json", "consolidated.safetensors", "tekken.json"],
    local_dir=mistral_models_path,
    token=HF_TOKEN,
)

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Tokenizer and model are loaded from the downloaded folder.
tokenizer = MistralTokenizer.from_file(str(mistral_models_path / "tekken.json"))
model = Transformer.from_folder(
    mistral_models_path,
    device=device,
    dtype=torch.bfloat16,
)
def _history_to_messages(history: list) -> list:
    """Convert Gradio chat history into a list of mistral_common messages.

    Both Gradio history layouts are accepted:

    * "tuples": ``[[user_text, assistant_text], ...]``
    * "messages": ``[{"role": "user" | "assistant", "content": ...}, ...]``

    Entries without text are skipped instead of being forwarded as
    content-less messages. In the tuple layout, a pair with a ``None`` half
    is skipped. In the messages layout, an entry whose content is not a
    non-empty string (e.g. a file) is skipped.
    """
    conversation = []
    for turn in history or []:
        if isinstance(turn, dict):
            # "messages" layout: one dict per message.
            content = turn.get("content")
            if not isinstance(content, str) or not content:
                continue
            role = turn.get("role")
            if role == "user":
                conversation.append(UserMessage(content=content))
            elif role == "assistant":
                conversation.append(AssistantMessage(content=content))
            continue
        # "tuples" layout: one [user, assistant] pair per turn.
        prompt, answer = turn
        if prompt is None or answer is None:
            # Incomplete turn (no reply recorded); drop the whole pair.
            continue
        conversation.append(UserMessage(content=prompt))
        conversation.append(AssistantMessage(content=answer))
    return conversation


@spaces.GPU()
def _generate_reply(tokens: list, temperature: float, max_new_tokens: int) -> str:
    """Generate a reply for the encoded prompt *tokens* and return it decoded.

    This is the only GPU-decorated step, so the GPU is requested just for
    model generation.
    """
    out_tokens, _ = generate(
        [tokens],
        model,
        # Cast defensively: UI widgets may deliver numbers as floats, and
        # max_tokens is used as a token count.
        max_tokens=int(max_new_tokens),
        temperature=float(temperature),
        eos_id=tokenizer.instruct_tokenizer.tokenizer.eos_id,
    )
    return tokenizer.instruct_tokenizer.tokenizer.decode(out_tokens[0])


def stream_chat(
    message: str,
    history: list,
    temperature: float = 0.3,
    max_new_tokens: int = 1024,
):
    """Gradio ChatInterface handler: answer *message* in the context of *history*.

    Args:
        message: the new user message.
        history: previous turns, in either Gradio history layout.
        temperature: sampling temperature passed to ``generate``.
        max_new_tokens: maximum number of tokens to generate.

    Yields:
        Growing prefixes of the reply, one character at a time (typing effect).
        An empty reply is yielded once as-is.
    """
    print(f'message: {message}')
    print(f'history: {history}')
    conversation = _history_to_messages(history)
    conversation.append(UserMessage(content=message))
    completion_request = ChatCompletionRequest(messages=conversation)
    tokens = tokenizer.encode_chat_completion(completion_request).tokens

    result = _generate_reply(tokens, temperature, max_new_tokens)

    # The full reply already exists; the loop below only simulates streaming.
    # It runs outside the @spaces.GPU function so the per-character delay no
    # longer keeps the GPU allocated after generation has finished.
    if not result:
        yield result
        return
    for end in range(1, len(result) + 1):
        time.sleep(0.05)
        yield result[:end]
# Chat display: 600px tall, showing PLACEHOLDER as its placeholder.
chatbot = gr.Chatbot(placeholder=PLACEHOLDER, height=600)

# Starter prompts offered as clickable examples in the chat interface.
_EXAMPLE_PROMPTS = [
    "Help me study vocabulary: write a sentence for me to fill in the blank, and I'll try to pick the correct option.",
    "What are 5 creative things I could do with my kids' art? I don't want to throw them away, but it's also so much clutter.",
    "Tell me a random fun fact about the Roman Empire.",
    "Show me a code snippet of a website's sticky header in CSS and JavaScript.",
]

with gr.Blocks(theme="ocean", css=CSS) as demo:
    gr.HTML(TITLE)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")

    # Collapsed container for the sampling controls (rendered by ChatInterface).
    params_accordion = gr.Accordion(label="⚙️ Parameters", open=False, render=False)
    # Extra inputs passed to stream_chat after (message, history), in this
    # order: temperature, max_new_tokens.
    sampling_controls = [
        gr.Slider(label="Temperature", minimum=0, maximum=1, step=0.1, value=0.3, render=False),
        gr.Slider(label="Max new tokens", minimum=128, maximum=8192, step=1, value=1024, render=False),
    ]

    gr.ChatInterface(
        fn=stream_chat,
        chatbot=chatbot,
        examples=[{"text": prompt} for prompt in _EXAMPLE_PROMPTS],
        additional_inputs=sampling_controls,
        additional_inputs_accordion=params_accordion,
    )

if __name__ == "__main__":
    demo.launch()
|