# Lunar-4B / Test.py
# Author: Sakalti
# Commit: a6a997a "Create Test.py" (verified)
# Size: 2.32 kB
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import spaces
# Hugging Face Hub repository whose chat model this app serves.
model_name = "Sakalti/SakalFusion-7B-Alpha"

# Tokenizer from the same repository; `generate` also relies on it for
# `apply_chat_template`.
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Weights are loaded in bfloat16; device_map="auto" delegates placement of
# the weights across available devices to transformers/accelerate.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# System prompt prepended to every conversation.
# English meaning: "You are a friendly chatbot."
SYSTEM_PROMPT = "あなたはフレンドリーなチャットボットです。"


def _build_messages(prompt, history):
    """Return the chat-template message list for *prompt* following *history*.

    ``history`` is expected to be a sequence of ``[user_text, bot_text]``
    pairs -- the same shape ``generate`` appends -- and a falsy value is
    treated as an empty conversation. Empty/None halves of a pair are skipped.
    """
    messages = [{"role": "system", "content": SYSTEM_PROMPT}]
    for user_text, bot_text in history or []:
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if bot_text:
            messages.append({"role": "assistant", "content": bot_text})
    messages.append({"role": "user", "content": prompt})
    return messages


@spaces.GPU(duration=100)
def generate(prompt, history, top_p, top_k, max_new_tokens, repetition_penalty, temperature):
    """Generate one assistant reply and return it with the extended history.

    Args:
        prompt: The new user message.
        history: Earlier ``[user_text, bot_text]`` pairs (may be empty or
            None). They are replayed to the model so replies keep context.
        top_p: Nucleus-sampling probability mass.
        top_k: Number of top tokens kept when sampling; coerced to ``int``
            because UI sliders may deliver floats.
        max_new_tokens: Upper bound on generated tokens; coerced to ``int``.
        repetition_penalty: Repetition penalty passed to ``model.generate``.
        temperature: Sampling temperature.

    Returns:
        ``(response, new_history)``, where ``new_history`` is a new list equal
        to ``history`` followed by ``[prompt, response]``. ``history`` itself
        is not mutated.
    """
    history = list(history or [])
    messages = _build_messages(prompt, history)
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
    generated_ids = model.generate(
        **model_inputs,
        # Without do_sample=True transformers may decode greedily and ignore
        # top_p / top_k / temperature, making those sliders ineffective.
        do_sample=True,
        max_new_tokens=int(max_new_tokens),
        top_p=top_p,
        top_k=int(top_k),
        repetition_penalty=repetition_penalty,
        temperature=temperature,
    )
    # The output contains the prompt tokens followed by the continuation;
    # keep only the newly generated part (batch size is always 1 here).
    prompt_length = model_inputs["input_ids"].shape[1]
    response = tokenizer.decode(
        generated_ids[0, prompt_length:], skip_special_tokens=True
    )
    return response, history + [[prompt, response]]
# Chat UI: a chatbot pane, an input box, a Clear button, and a row of
# sampling controls forwarded to `generate`.
with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    clear = gr.Button("Clear")
    with gr.Row():
        top_p = gr.Slider(0.0, 1.0, value=0.9, label="Top P")
        # Explicit step=1 keeps the integer-valued settings on whole numbers;
        # without it Gradio derives a step from the range, which may be
        # coarser than 1 for wide ranges such as 1-2048.
        top_k = gr.Slider(0, 100, value=50, step=1, label="Top K")
        max_new_tokens = gr.Slider(1, 2048, value=864, step=1, label="Max New Tokens")
        repetition_penalty = gr.Slider(1.0, 2.0, value=1.2, label="Repetition Penalty")
        temperature = gr.Slider(0.1, 1.0, value=0.7, label="Temperature")

    def respond(message, chat_history, top_p, top_k, max_new_tokens, repetition_penalty, temperature):
        """Run one chat turn.

        Returns an empty string (to clear the input box) and the updated
        chat history for the Chatbot component.
        """
        _, chat_history = generate(
            message, chat_history, top_p, top_k, max_new_tokens, repetition_penalty, temperature
        )
        return "", chat_history

    msg.submit(
        respond,
        [msg, chatbot, top_p, top_k, max_new_tokens, repetition_penalty, temperature],
        [msg, chatbot],
    )
    # Reset the conversation to an empty history and the input box to an
    # empty string (a list is not a meaningful Textbox value).
    clear.click(lambda: ([], ""), None, [chatbot, msg])

demo.launch(share=True)