import os

import torch
import gradio as gr
import sentencepiece  # noqa: F401 -- required by the (slow) Llama tokenizer
from transformers import AutoModelForCausalLM, AutoTokenizer
title = "Welcome to 🙋🏻‍♂️Tonic's🌷Tulu Chat!"
description = "[allenai/tulu-2-dpo-7b](https://huggingface.co/allenai/tulu-2-dpo-7b) and the larger Tulu-2 models are instruction-tuned Llama finetunes built with the [mistralai/Mistral-7B](https://huggingface.co/mistralai/Mistral-7B-v0.1) recipe. You can use [allenai/tulu-2-13b](https://huggingface.co/allenai/tulu-2-13b) here via API by scrolling down and clicking 'Use via API', or privately by [cloning this space on Hugging Face](https://huggingface.co/spaces/Tonic1/TuluDemo?duplicate=true). See also the large model here: [allenai/tulu-2-dpo-70b](https://huggingface.co/allenai/tulu-2-dpo-70b). [Join my active builders' server on Discord](https://discord.gg/VqTxc76K3u). Let's build together!"
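# A minimal sketch of querying this Space programmatically with gradio_client,
# as the description suggests. The Space name is taken from the description
# above; adjust it to wherever this app is actually hosted. Argument order
# matches the inputs wired to submit_button.click at the bottom of this file.
#
#   from gradio_client import Client
#   client = Client("Tonic1/TuluDemo")
#   result = client.predict(
#       "Why is the sky blue?",              # user_message
#       "You are 🌷Tulu, a helpful model.",  # system_message
#       650,                                 # max_new_tokens
#       0.7,                                 # temperature
#       0.9,                                 # top_p
#       1.2,                                 # repetition_penalty
#       True,                                # do_sample
#   )
#   print(result)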
# Reduce CUDA memory fragmentation for long generations.
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:50'
model_name = "allenai/tulu-2-dpo-13b"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
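# If the full bfloat16 checkpoint does not fit in GPU memory, a smaller
# footprint is possible with 8-bit quantization (a sketch, assuming the
# bitsandbytes package is installed):
#
#   model = AutoModelForCausalLM.from_pretrained(
#       model_name, load_in_8bit=True, device_map="auto"
#   )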
class TuluChatBot:
    def __init__(self, model, tokenizer, system_message="You are 🌷Tulu, an AI language model created by Tonic-AI. You are a cautious assistant. You carefully follow instructions. You are helpful and harmless and you follow ethical guidelines and promote positive behavior."):
        self.model = model
        self.tokenizer = tokenizer
        self.system_message = system_message

    def set_system_message(self, new_system_message):
        self.system_message = new_system_message

    def format_prompt(self, user_message):
        # Tulu-2 chat format: <|system|>, <|user|> and <|assistant|> tags,
        # each followed by a newline, with generation continuing after the
        # final "<|assistant|>\n".
        prompt = f"<|system|>\n{self.system_message}\n<|user|>\n{user_message}\n<|assistant|>\n"
        return prompt
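    # For example, format_prompt("Why is the sky blue?") renders as:
    #
    #   <|system|>
    #   You are 🌷Tulu, an AI language model created by Tonic-AI. ...
    #   <|user|>
    #   Why is the sky blue?
    #   <|assistant|>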
    def predict(self, user_message, temperature, max_new_tokens, top_p, repetition_penalty, do_sample):
        prompt = self.format_prompt(user_message)
        # The prompt already carries the chat tags, so skip the tokenizer's
        # own special tokens.
        inputs = self.tokenizer(prompt, return_tensors='pt', add_special_tokens=False)
        input_ids = inputs["input_ids"].to(self.model.device)
        attention_mask = inputs["attention_mask"].to(self.model.device)
        output_ids = self.model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            do_sample=do_sample,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        # Decode only the newly generated tokens, not the echoed prompt.
        response = self.tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
        return response
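# Direct use of the class, bypassing the UI (illustrative values only):
#
#   bot = TuluChatBot(model, tokenizer)
#   print(bot.predict("Write a haiku about tulips.", temperature=0.7,
#                     max_new_tokens=128, top_p=0.9,
#                     repetition_penalty=1.1, do_sample=True))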
def gradio_predict(user_message, system_message, max_new_tokens, temperature, top_p, repetition_penalty, do_sample):
    Tulu_bot.set_system_message(system_message)
    # When do_sample is False, generate() decodes greedily and ignores the
    # sampling parameters, so no extra handling is needed here.
    response = Tulu_bot.predict(user_message, temperature, max_new_tokens, top_p, repetition_penalty, do_sample)
    return response
Tulu_bot = TuluChatBot(model, tokenizer)
with gr.Blocks(theme="ParityError/Anime") as demo:
    with gr.Row():
        user_message = gr.Textbox(label="Your Message", lines=3)
        system_message = gr.Textbox(label="Introduce a Character Here or Set a Scene (system prompt)", lines=2)
    with gr.Row():
        do_sample = gr.Checkbox(label="Advanced", value=False)
    # Advanced sampling settings; they only take effect when "Advanced"
    # (do_sample) is checked.
    with gr.Accordion("Advanced Settings", open=False):
        with gr.Row():
            max_new_tokens = gr.Slider(label="Max new tokens", value=1269, minimum=550, maximum=3200, step=1)
            temperature = gr.Slider(label="Temperature", value=1.2, minimum=0.05, maximum=4.0, step=0.05)
            top_p = gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.01, maximum=0.99, step=0.05)
            repetition_penalty = gr.Slider(label="Repetition penalty", value=1.9, minimum=1.0, maximum=2.0, step=0.05)
    submit_button = gr.Button("Submit")
    output_text = gr.Textbox()
    submit_button.click(
        gradio_predict,
        inputs=[user_message, system_message, max_new_tokens, temperature, top_p, repetition_penalty, do_sample],
        outputs=output_text,
    )
demo.launch()