harpreetsahota's picture
Update app.py
0dd9c69
raw
history blame
5.21 kB
import os
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Hugging Face access token used by the from_pretrained calls below.
# Read with .get() so the app can still start when the secret is not
# configured (e.g. in a duplicated Space): None makes huggingface_hub fall
# back to its default credential lookup, which should be enough if the model
# repo is public (NOTE(review): confirm the repo is not gated).
token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")
model_id = 'Deci/DeciLM-6b-instruct'

# Instruction-tuning prompt format; `{instruction}` is filled with the user's text.
SYSTEM_PROMPT_TEMPLATE = """Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Response:
"""

# Markdown/HTML header rendered at the top of the page.
DESCRIPTION = """
# <p style="text-align: center; color: #292b47;"> 🤖 <span style='color: #3264ff;'>DeciLM-6B-Instruct:</span> A Fast Instruction-Tuned Model💨 </p>
<span style='color: #292b47;'>Welcome to <a href="https://huggingface.co/Deci/DeciLM-6b-instruct" style="color: #3264ff;">DeciLM-6B-Instruct</a>! DeciLM-6B-Instruct is a 6B parameter instruction-tuned language model and released under the Llama license. It's an instruction-tuned model, not a chat-tuned model; you should prompt the model with an instruction that describes a task, and the model will respond appropriately to complete the task.</span>
<p><span style='color: #292b47;'>Learn more about the base model <a href="https://deci.ai/blog/decilm-15-times-faster-than-llama2-nas-generated-llm-with-variable-gqa/" style="color: #3264ff;">DeciLM-6B.</a></span></p>
"""
if not torch.cuda.is_available():
    # Tell visitors on a CPU-only host why generation is unavailable.
    DESCRIPTION += 'You need a GPU for this example. Try using colab: https://bit.ly/decilm-instruct-nb'
# Load the fp16 weights only when a CUDA device exists; otherwise `model`
# stays None (the page description then carries the GPU notice).
model = (
    AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map='auto',
        trust_remote_code=True,
        use_auth_token=token,
    )
    if torch.cuda.is_available()
    else None
)

# The tokenizer is loaded unconditionally. EOS doubles as the padding token
# (presumably because the tokenizer ships without one — confirm).
tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
tokenizer.pad_token = tokenizer.eos_token
def get_prompt_with_template(message: str) -> str:
    """Wrap a user instruction in the instruction-tuning prompt template.

    Args:
        message: The raw instruction typed by the user.

    Returns:
        SYSTEM_PROMPT_TEMPLATE with its ``{instruction}`` field set to ``message``.
    """
    fields = {"instruction": message}
    return SYSTEM_PROMPT_TEMPLATE.format_map(fields)
def generate_model_response(message: str, **generate_kwargs) -> str:
    """Run the model on ``message`` and return the decoded sequence.

    The message is wrapped with ``get_prompt_with_template`` before
    tokenization. The decoded text contains the prompt followed by the
    model's continuation (special tokens removed); ``get_response_with_template``
    strips the prompt part afterwards.

    Args:
        message: The user's instruction.
        **generate_kwargs: Optional overrides for ``model.generate`` (for
            example ``max_new_tokens=512``). They replace the defaults below;
            calling with no overrides behaves as before.

    Returns:
        The full decoded sequence, prompt included.

    Raises:
        RuntimeError: If no model is loaded (``model`` is None).
    """
    if model is None:
        # Without this check the call fails with the unhelpful
        # "'NoneType' object has no attribute 'generate'".
        raise RuntimeError(
            'DeciLM-6B-Instruct is not loaded: a CUDA GPU is required.'
        )

    prompt = get_prompt_with_template(message)
    # Follow the model's placement (device_map='auto') instead of hard-coding 'cuda'.
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)

    # Default decoding: beam search (5 beams) combined with sampling.
    options = {
        'max_new_tokens': 3000,
        'num_beams': 5,
        'no_repeat_ngram_size': 4,
        'early_stopping': True,
        'do_sample': True,
    }
    options.update(generate_kwargs)

    output = model.generate(**inputs, **options)
    return tokenizer.decode(output[0], skip_special_tokens=True)
def extract_response_content(full_response: str) -> str:
    """Return the model's answer from a decoded prompt+completion string.

    Everything up to and including the first "### Response:" marker is
    dropped, and the remainder is stripped of surrounding whitespace. If the
    marker does not occur, the input is returned unchanged (not stripped).
    """
    _, marker, answer = full_response.partition("### Response:")
    if not marker:
        return full_response
    return answer.strip()
def get_response_with_template(message: str) -> str:
    """Answer a user instruction end to end; used as the Gradio callback.

    Generates text with ``generate_model_response`` and keeps only what
    follows the "### Response:" marker, via ``extract_response_content``.
    """
    return extract_response_content(generate_model_response(message))
# NOTE(review): "/content/style.css" looks like a Colab path. If that file does
# not exist where the app runs, the elem_id-based styling is likely not
# applied — confirm the stylesheet location for the deployed Space.
with gr.Blocks(css="/content/style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value='Duplicate Space for private use',
                       elem_id='duplicate-button')
    with gr.Group():
        # Single-turn output box (a Textbox, despite the variable name).
        chatbot = gr.Textbox(label='DeciLM-6B-Instruct Output:')
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder='Type an instruction...',
                scale=10,
                elem_id="textbox",
            )
            submit_button = gr.Button(
                '💬 Submit',
                variant='primary',
                scale=1,
                min_width=0,
                elem_id="submit_button",
            )
            # Clears both the instruction box and the output box.
            clear_button = gr.Button(
                '🗑️ Clear',
                variant='secondary',
            )

    clear_button.click(
        fn=lambda: ('', ''),
        outputs=[textbox, chatbot],
        queue=False,
        api_name=False,
    )
    # NOTE(review): generation may be slow (beam search, large max_new_tokens);
    # with queue=False long requests may time out and concurrent requests are
    # not serialized on the GPU — consider enabling the queue.
    submit_button.click(
        fn=get_response_with_template,
        inputs=textbox,
        outputs=chatbot,
        queue=False,
        api_name=False,
    )

    gr.Examples(
        examples=[
            'Write detailed instructions for making chocolate chip pancakes.',
            'Write a 250-word article about your love of pancakes.',
            'Explain the plot of Back to the Future in three sentences.',
            'How do I make a trap beat?',
            'A step-by-step guide to learning Python in one month.',
        ],
        inputs=textbox,
        outputs=chatbot,
        fn=get_response_with_template,
        # Caching runs the model on every example at startup. Without CUDA no
        # model is loaded, so caching would crash the app before it launches.
        cache_examples=torch.cuda.is_available(),
        elem_id="examples",
    )
    gr.HTML(label="Keep in touch", value="<img src='https://huggingface.co/spaces/Deci/DeciLM-6b-instruct/resolve/main/deci-coder-banner.png' alt='Keep in touch' style='display: block; color: #292b47; margin: auto; max-width: 800px;'>")

if __name__ == "__main__":
    demo.launch()