'''
Previous (non-streaming) version of this demo, kept for reference.

import os

import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

model_name_2_7B_instruct = "Zyphra/Zamba2-2.7B-instruct"
model_name_7B_instruct = "Zyphra/Zamba2-7B-instruct"
max_context_length = 4096

tokenizer_2_7B_instruct = AutoTokenizer.from_pretrained(model_name_2_7B_instruct)
model_2_7B_instruct = AutoModelForCausalLM.from_pretrained(
    model_name_2_7B_instruct, device_map="cuda", torch_dtype=torch.bfloat16
)

tokenizer_7B_instruct = AutoTokenizer.from_pretrained(model_name_7B_instruct)
model_7B_instruct = AutoModelForCausalLM.from_pretrained(
    model_name_7B_instruct, device_map="cuda", torch_dtype=torch.bfloat16
)

def extract_assistant_response(generated_text):
    assistant_token = '<|im_start|> assistant'
    end_token = '<|im_end|>'
    start_idx = generated_text.rfind(assistant_token)
    if start_idx == -1:
        # Assistant token not found
        return generated_text.strip()
    start_idx += len(assistant_token)
    end_idx = generated_text.find(end_token, start_idx)
    if end_idx == -1:
        # End token not found, return from start_idx to end
        return generated_text[start_idx:].strip()
    else:
        return generated_text[start_idx:end_idx].strip()


def generate_response(chat_history, max_new_tokens, model, tokenizer):
    sample = []
    for turn in chat_history:
        if turn[0]:
            sample.append({'role': 'user', 'content': turn[0]})
        if turn[1]:
            sample.append({'role': 'assistant', 'content': turn[1]})
    chat_sample = tokenizer.apply_chat_template(sample, tokenize=False)
    input_ids = tokenizer(chat_sample, return_tensors='pt', add_special_tokens=False).to(model.device)

    max_new_tokens = int(max_new_tokens)
    max_input_length = max_context_length - max_new_tokens
    if input_ids['input_ids'].size(1) > max_input_length:
        input_ids['input_ids'] = input_ids['input_ids'][:, -max_input_length:]
        if 'attention_mask' in input_ids:
            input_ids['attention_mask'] = input_ids['attention_mask'][:, -max_input_length:]

    with torch.no_grad():
        outputs = model.generate(
            **input_ids,
            max_new_tokens=int(max_new_tokens),
            return_dict_in_generate=False,
            output_scores=False,
            use_cache=True,
            num_beams=1,
            do_sample=False,
        )
    """
    outputs = model.generate(
        input_ids=input_ids,
        max_new_tokens=int(max_new_tokens),
        do_sample=True,
        use_cache=True,
        temperature=temperature,
        top_k=int(top_k),
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        num_beams=int(num_beams),
        length_penalty=length_penalty,
        num_return_sequences=1
    )
    """
    generated_text = tokenizer.decode(outputs[0])
    assistant_response = extract_assistant_response(generated_text)

    del input_ids
    del outputs
    torch.cuda.empty_cache()

    return assistant_response

with gr.Blocks() as demo:
    gr.Markdown("# Zamba2 Model Selector")
    with gr.Tabs():
        with gr.TabItem("7B Instruct Model"):
            gr.Markdown("### Zamba2-7B Instruct Model")
            with gr.Column():
                chat_history_7B_instruct = gr.State([])
                chatbot_7B_instruct = gr.Chatbot()
                message_7B_instruct = gr.Textbox(lines=2, placeholder="Enter your message...", label="Your Message")
                with gr.Accordion("Generation Parameters", open=False):
                    max_new_tokens_7B_instruct = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")
                    # temperature_7B_instruct = gr.Slider(0.1, 1.5, step=0.1, value=0.2, label="Temperature")
                    # top_k_7B_instruct = gr.Slider(1, 100, step=1, value=50, label="Top K")
                    # top_p_7B_instruct = gr.Slider(0.1, 1.0, step=0.1, value=1.0, label="Top P")
                    # repetition_penalty_7B_instruct = gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty")
                    # num_beams_7B_instruct = gr.Slider(1, 10, step=1, value=1, label="Number of Beams")
                    # length_penalty_7B_instruct = gr.Slider(0.0, 2.0, step=0.1, value=1.0, label="Length Penalty")

                def user_message_7B_instruct(message, chat_history):
                    chat_history = chat_history + [[message, None]]
                    return gr.update(value=""), chat_history, chat_history

                def bot_response_7B_instruct(chat_history, max_new_tokens):
                    response = generate_response(chat_history, max_new_tokens, model_7B_instruct, tokenizer_7B_instruct)
                    chat_history[-1][1] = response
                    return chat_history, chat_history

                send_button_7B_instruct = gr.Button("Send")
                send_button_7B_instruct.click(
                    fn=user_message_7B_instruct,
                    inputs=[message_7B_instruct, chat_history_7B_instruct],
                    outputs=[message_7B_instruct, chat_history_7B_instruct, chatbot_7B_instruct]
                ).then(
                    fn=bot_response_7B_instruct,
                    inputs=[
                        chat_history_7B_instruct,
                        max_new_tokens_7B_instruct
                    ],
                    outputs=[chat_history_7B_instruct, chatbot_7B_instruct]
                )

        with gr.TabItem("2.7B Instruct Model"):
            gr.Markdown("### Zamba2-2.7B Instruct Model")
            with gr.Column():
                chat_history_2_7B_instruct = gr.State([])
                chatbot_2_7B_instruct = gr.Chatbot()
                message_2_7B_instruct = gr.Textbox(lines=2, placeholder="Enter your message...", label="Your Message")
                with gr.Accordion("Generation Parameters", open=False):
                    max_new_tokens_2_7B_instruct = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")
                    # temperature_2_7B_instruct = gr.Slider(0.1, 1.5, step=0.1, value=0.2, label="Temperature")
                    # top_k_2_7B_instruct = gr.Slider(1, 100, step=1, value=50, label="Top K")
                    # top_p_2_7B_instruct = gr.Slider(0.1, 1.0, step=0.1, value=1.0, label="Top P")
                    # repetition_penalty_2_7B_instruct = gr.Slider(1.0, 2.0, step=0.1, value=1.2, label="Repetition Penalty")
                    # num_beams_2_7B_instruct = gr.Slider(1, 10, step=1, value=1, label="Number of Beams")
                    # length_penalty_2_7B_instruct = gr.Slider(0.0, 2.0, step=0.1, value=1.0, label="Length Penalty")

                def user_message_2_7B_instruct(message, chat_history):
                    chat_history = chat_history + [[message, None]]
                    return gr.update(value=""), chat_history, chat_history

                def bot_response_2_7B_instruct(chat_history, max_new_tokens):
                    response = generate_response(chat_history, max_new_tokens, model_2_7B_instruct, tokenizer_2_7B_instruct)
                    chat_history[-1][1] = response
                    return chat_history, chat_history

                send_button_2_7B_instruct = gr.Button("Send")
                send_button_2_7B_instruct.click(
                    fn=user_message_2_7B_instruct,
                    inputs=[message_2_7B_instruct, chat_history_2_7B_instruct],
                    outputs=[message_2_7B_instruct, chat_history_2_7B_instruct, chatbot_2_7B_instruct]
                ).then(
                    fn=bot_response_2_7B_instruct,
                    inputs=[
                        chat_history_2_7B_instruct,
                        max_new_tokens_2_7B_instruct
                    ],
                    outputs=[chat_history_2_7B_instruct, chatbot_2_7B_instruct]
                )


if __name__ == "__main__":
    demo.queue().launch(max_threads=1)

'''

import os
import threading
import re

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

model_name_2_7B_instruct = "Zyphra/Zamba2-2.7B-instruct"
model_name_7B_instruct = "Zyphra/Zamba2-7B-instruct"
max_context_length = 4096

# Load both instruct-tuned checkpoints onto the GPU in bfloat16.
tokenizer_2_7B_instruct = AutoTokenizer.from_pretrained(model_name_2_7B_instruct)
model_2_7B_instruct = AutoModelForCausalLM.from_pretrained(
    model_name_2_7B_instruct, device_map="cuda", torch_dtype=torch.bfloat16
)

tokenizer_7B_instruct = AutoTokenizer.from_pretrained(model_name_7B_instruct)
model_7B_instruct = AutoModelForCausalLM.from_pretrained(
    model_name_7B_instruct, device_map="cuda", torch_dtype=torch.bfloat16
)

def generate_response(chat_history, max_new_tokens, model, tokenizer):
    # Rebuild the conversation as a list of chat-template messages.
    sample = []
    for turn in chat_history:
        if turn[0]:
            sample.append({'role': 'user', 'content': turn[0]})
        if turn[1]:
            sample.append({'role': 'assistant', 'content': turn[1]})
    chat_sample = tokenizer.apply_chat_template(sample, tokenize=False)
    input_ids = tokenizer(chat_sample, return_tensors='pt', add_special_tokens=False).to(model.device)

    # Keep only the most recent tokens so that prompt + generated tokens fit in the context window.
    max_new_tokens = int(max_new_tokens)
    max_input_length = max_context_length - max_new_tokens
    if input_ids['input_ids'].size(1) > max_input_length:
        input_ids['input_ids'] = input_ids['input_ids'][:, -max_input_length:]
        if 'attention_mask' in input_ids:
            input_ids['attention_mask'] = input_ids['attention_mask'][:, -max_input_length:]

    # Stream tokens as they are generated: generation runs in a background thread
    # while the streamer yields decoded text chunks back to the UI.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(**input_ids, max_new_tokens=max_new_tokens, streamer=streamer)

    thread = threading.Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    assistant_response = ""

    for new_text in streamer:
        # Drop a leading "assistant" role marker if the chat template leaks it into the output.
        new_text = re.sub(r'^\s*(?i:assistant)[:\s]*', '', new_text)
        assistant_response += new_text
        yield assistant_response

    thread.join()
    del input_ids
    torch.cuda.empty_cache()
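
# The generate() call above relies on the model's default decoding settings (no sampling
# flags are passed). If explicit sampling control were desired, generation_kwargs inside
# generate_response could include the standard transformers generation arguments.
# A minimal sketch, commented out; the values are illustrative, not tuned:
#
#     generation_kwargs = dict(
#         **input_ids,
#         max_new_tokens=max_new_tokens,
#         do_sample=True,          # enable sampling instead of greedy decoding
#         temperature=0.7,         # assumed example value
#         top_p=0.9,               # nucleus sampling cutoff
#         top_k=50,                # top-k cutoff
#         repetition_penalty=1.2,  # discourage verbatim repetition
#         streamer=streamer,
#     )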
with gr.Blocks() as demo:
    gr.Markdown("# Zamba2 Model Selector")
    with gr.Tabs():
        with gr.TabItem("7B Instruct Model"):
            gr.Markdown("### Zamba2-7B Instruct Model")
            with gr.Column():
                chat_history_7B_instruct = gr.State([])
                chatbot_7B_instruct = gr.Chatbot()
                message_7B_instruct = gr.Textbox(lines=2, placeholder="Enter your message...", label="Your Message")
                with gr.Accordion("Generation Parameters", open=False):
                    max_new_tokens_7B_instruct = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens")

                def user_message_7B_instruct(message, chat_history):
                    # Append the user's turn, clear the textbox, and refresh the chatbot display.
                    chat_history = chat_history + [[message, None]]
                    return gr.update(value=""), chat_history, chat_history

                def bot_response_7B_instruct(chat_history, max_new_tokens):
                    # Stream partial responses into the last turn as they arrive.
                    assistant_response_generator = generate_response(chat_history, max_new_tokens, model_7B_instruct, tokenizer_7B_instruct)
                    for assistant_response in assistant_response_generator:
                        chat_history[-1][1] = assistant_response
                        yield chat_history

                send_button_7B_instruct = gr.Button("Send")
                # First record the user message, then stream the model's reply into the chatbot.
                send_button_7B_instruct.click(
                    fn=user_message_7B_instruct,
                    inputs=[message_7B_instruct, chat_history_7B_instruct],
                    outputs=[message_7B_instruct, chat_history_7B_instruct, chatbot_7B_instruct]
                ).then(
                    fn=bot_response_7B_instruct,
                    inputs=[chat_history_7B_instruct, max_new_tokens_7B_instruct],
                    outputs=chatbot_7B_instruct,
                )
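
                # Pressing Enter in the textbox could trigger the same two-step flow as the
                # Send button. A minimal sketch, commented out; gr.Textbox.submit mirrors the
                # Button.click wiring used above:
                #
                #     message_7B_instruct.submit(
                #         fn=user_message_7B_instruct,
                #         inputs=[message_7B_instruct, chat_history_7B_instruct],
                #         outputs=[message_7B_instruct, chat_history_7B_instruct, chatbot_7B_instruct]
                #     ).then(
                #         fn=bot_response_7B_instruct,
                #         inputs=[chat_history_7B_instruct, max_new_tokens_7B_instruct],
                #         outputs=chatbot_7B_instruct,
                #     )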
with gr.TabItem("2.7B Instruct Model"): |
|
gr.Markdown("### Zamba2-2.7B Instruct Model") |
|
with gr.Column(): |
|
chat_history_2_7B_instruct = gr.State([]) |
|
chatbot_2_7B_instruct = gr.Chatbot() |
|
message_2_7B_instruct = gr.Textbox(lines=2, placeholder="Enter your message...", label="Your Message") |
|
with gr.Accordion("Generation Parameters", open=False): |
|
max_new_tokens_2_7B_instruct = gr.Slider(50, 1000, step=50, value=500, label="Max New Tokens") |
|
|
|
def user_message_2_7B_instruct(message, chat_history): |
|
chat_history = chat_history + [[message, None]] |
|
return gr.update(value=""), chat_history, chat_history |
|
|
|
def bot_response_2_7B_instruct(chat_history, max_new_tokens): |
|
assistant_response_generator = generate_response(chat_history, max_new_tokens, model_2_7B_instruct, tokenizer_2_7B_instruct) |
|
for assistant_response in assistant_response_generator: |
|
chat_history[-1][1] = assistant_response |
|
yield chat_history |
|
|
|
send_button_2_7B_instruct = gr.Button("Send") |
|
send_button_2_7B_instruct.click( |
|
fn=user_message_2_7B_instruct, |
|
inputs=[message_2_7B_instruct, chat_history_2_7B_instruct], |
|
outputs=[message_2_7B_instruct, chat_history_2_7B_instruct, chatbot_2_7B_instruct] |
|
).then( |
|
fn=bot_response_2_7B_instruct, |
|
inputs=[chat_history_2_7B_instruct, max_new_tokens_2_7B_instruct], |
|
outputs=chatbot_2_7B_instruct, |
|
) |
|
|
|
if __name__ == "__main__":
    demo.queue().launch(max_threads=1)
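    # If the demo needs to be reachable beyond localhost, launch() also accepts
    # networking options, e.g. (illustrative values):
    #
    #     demo.queue().launch(max_threads=1, server_name="0.0.0.0", server_port=7860, share=True)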