Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,549 Bytes
bd0332f 1d72a65 933ec2b a475ce0 09c998a bd0332f 9ce660a bd0332f 09c998a d9b1afb 933ec2b 4becd74 933ec2b a475ce0 1d72a65 933ec2b bb98ae2 1d72a65 933ec2b 1d72a65 bd0332f d9b1afb 933ec2b 9b150da 933ec2b d9b1afb 933ec2b d9b1afb b94cdc8 d9b1afb b94cdc8 bd0332f d9b1afb 933ec2b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TextIteratorStreamer,
)
model_id = "mistralai/Mistral-Nemo-Instruct-2407"
# Prompts longer than this many tokens are left-truncated in `generate`
# (only the most recent tokens are kept).
MAX_INPUT_TOKEN_LENGTH = 4096

# Load tokenizer and model once at import time; every request reuses them.
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Passing `load_in_8bit=True` directly to `from_pretrained` is deprecated in
# recent transformers releases; the supported route is an explicit
# quantization config (still 8-bit bitsandbytes weights, same as before).
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
) -> Iterator[str]:
    """Stream the model's reply to ``message`` given the prior ``chat_history``.

    Args:
        message: The new user turn.
        chat_history: Previous turns as ``(user, assistant)`` pairs, oldest first.
        max_new_tokens: Upper bound on the number of tokens to generate.
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Yields:
        The reply text accumulated so far; each yield extends the previous one.
    """
    conversation = []
    for user, assistant in chat_history:
        conversation.extend(
            [
                {"role": "user", "content": user},
                {"role": "assistant", "content": assistant},
            ]
        )
    conversation.append({"role": "user", "content": message})

    # add_generation_prompt=True appends the assistant-turn header for chat
    # templates that define one; templates without one are unaffected.
    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    )
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens.
        # NOTE(review): a hard slice can cut into the start of the chat
        # template (e.g. a BOS token); dropping whole old turns would be cleaner.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # The streamer hands decoded text chunks from model.generate to this
    # generator; skip_prompt=True so only newly generated text is streamed.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = {
        "input_ids": input_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "temperature": temperature,
        "num_beams": 1,
    }

    # model.generate blocks until done, so run it on a worker thread and
    # consume the streamer here to yield partial output as it arrives.
    worker = Thread(target=model.generate, kwargs=generate_kwargs)
    worker.start()
    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)
    # The streamer is exhausted only once generation has finished, so this
    # join returns promptly; it just reaps the worker thread.
    worker.join()
# Set up Gradio interface.
# ChatInterface calls generate(message, history, *additional_inputs), so the
# three sliders below map, in order, to max_new_tokens, temperature and top_p.
iface = gr.ChatInterface(
    generate,
    chatbot=gr.Chatbot(height=600),
    textbox=gr.Textbox(placeholder="Enter your message here...", container=False, scale=7),
    # The previous title/description named "Mistral Next v1.1 ... 4B", which
    # is not the model loaded above; derive the description from model_id.
    title="Chat with Mistral Nemo Instruct 2407",
    description=f"This is a chat interface for the {model_id} model. Ask questions and get answers!",
    theme="soft",
    retry_btn="Retry",
    undo_btn="Undo Last",
    clear_btn="Clear",
    additional_inputs=[
        gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Maximum number of new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
    ],
)
# Launch the interface
iface.launch()