Spaces:
Paused
Paused
File size: 5,516 Bytes
260fe59 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import os
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
# Hub auth token; a KeyError here means the HUGGINGFACEHUB_API_TOKEN env var is unset.
token = os.environ["HUGGINGFACEHUB_API_TOKEN"]
# Repo id of the instruction-tuned checkpoint served by this demo.
model_id = 'Deci/DeciLM-6b-instruct'
# Alpaca-style prompt wrapper; {instruction} is replaced with the user's message,
# and the model is expected to continue after "### Response:".
SYSTEM_PROMPT_TEMPLATE = """Below is an instruction that describes a task. Write a response that appropriately completes the request.
### Instruction:
{instruction}
### Response:
"""
# Markdown/HTML banner rendered at the top of the Gradio page.
DESCRIPTION = """
# <p style="text-align: center; color: #292b47;"> 🤖 <span style='color: #3264ff;'>DeciLM-6B-Instruct:</span> A Fast Instruction-Tuned Model💨 </p>
<span style='color: #292b47;'>Welcome to <a href="https://huggingface.co/Deci/DeciLM-6b-instruct" style="color: #3264ff;">DeciLM-6B-Instruct</a>! DeciLM-6B-Instruct is a 6B parameter instruction-tuned language model and released under the Llama license. It's an instruction-tuned model, not a chat-tuned model; you should prompt the model with an instruction that describes a task, and the model will respond appropriately to complete the task.</span>
<p><span style='color: #292b47;'>Learn more about the base model <a href="https://huggingface.co/Deci/DeciLM-6b" style="color: #3264ff;">DeciLM-6B.</a></span></p>
"""
# License notice intentionally kept commented out (not currently displayed in the UI).
# LICENSE = """
# <p/>
# ---
# As a derivate work of [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta,
# this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
# """
# Load the model only when a GPU is present; otherwise leave it unset and warn
# in the page banner. The original code tested torch.cuda.is_available() twice
# in two independent `if` statements — collapsed into a single if/else so the
# GPU/no-GPU paths cannot drift apart.
if torch.cuda.is_available():
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,   # half precision to fit a 6B model on one GPU
        device_map='auto',
        trust_remote_code=True,      # DeciLM ships custom modeling code on the Hub
        use_auth_token=token,
    )
else:
    model = None
    DESCRIPTION += 'You need a GPU for this example. Try using colab: https://bit.ly/decilm-instruct-nb'
# The tokenizer is needed in both cases (e.g. for prompt construction in the UI).
tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
# DeciLM defines no pad token; reuse EOS so padding-aware code paths work.
tokenizer.pad_token = tokenizer.eos_token
# Wrap a raw user instruction in the Alpaca-style system prompt.
def get_prompt_with_template(message: str) -> str:
    """Return *message* embedded into SYSTEM_PROMPT_TEMPLATE as the instruction."""
    prompt = SYSTEM_PROMPT_TEMPLATE.format(instruction=message)
    return prompt
# Run the model on a single instruction and return the full decoded output.
def generate_model_response(message: str) -> str:
    """Generate the model's decoded output (prompt echo included) for *message*.

    Raises:
        RuntimeError: if no GPU was available at startup, so the model was
            never loaded (``model is None``) — previously this surfaced as an
            opaque ``AttributeError`` on ``None.generate``.
    """
    if model is None:
        raise RuntimeError(
            'Model is not loaded: a GPU is required for this demo. '
            'Try the Colab notebook linked in the description.'
        )
    prompt = get_prompt_with_template(message)
    inputs = tokenizer(prompt, return_tensors='pt')
    if torch.cuda.is_available():
        inputs = inputs.to('cuda')
    output = model.generate(
        **inputs,
        max_new_tokens=3000,
        num_beams=5,
        no_repeat_ngram_size=4,   # suppress verbatim 4-gram repetition
        early_stopping=True,
        do_sample=True,
    )
    # Decode the first (and only) returned sequence, dropping special tokens.
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Strip the echoed prompt, keeping only what follows the response marker.
def extract_response_content(full_response: str, ) -> str:
    """Return the text after the first "### Response:" marker, stripped of
    surrounding whitespace; return *full_response* unchanged if no marker."""
    _, marker, tail = full_response.partition("### Response:")
    return tail.strip() if marker else full_response
# End-to-end helper wired to the UI: prompt -> generate -> strip prompt echo.
def get_response_with_template(message: str) -> str:
    """Generate a reply to *message* and return only the response portion."""
    return extract_response_content(generate_model_response(message))
# Gradio UI: one instruction textbox, one output box, submit/clear buttons,
# and cached examples. Fixes a stray trailing '|' after demo.launch(...) that
# made the original file a SyntaxError.
# NOTE(review): the css path is Colab-absolute ('/content/style.css') — it will
# not resolve when running outside Colab; confirm against the deployment.
with gr.Blocks(css="/content/style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value='Duplicate Space for private use',
                       elem_id='duplicate-button')
    with gr.Group():
        chatbot = gr.Textbox(label='DeciLM-6B-Instruct Output:')
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder='Type an instruction...',
                scale=10,
                elem_id="textbox"
            )
            submit_button = gr.Button(
                '💬 Submit',
                variant='primary',
                scale=1,
                min_width=0,
                elem_id="submit_button"
            )
        # Clear button resets both the input and the output box.
        clear_button = gr.Button(
            '🗑️ Clear',
            variant='secondary',
        )
        clear_button.click(
            fn=lambda: ('', ''),
            outputs=[textbox, chatbot],
            queue=False,
            api_name=False,
        )
        submit_button.click(
            fn=get_response_with_template,
            inputs=textbox,
            outputs=chatbot,
            queue=False,
            api_name=False,
        )
    # NOTE(review): cache_examples=True runs the model over every example at
    # startup — this will fail on a machine without a GPU; confirm intended.
    gr.Examples(
        examples=[
            'Write detailed instructions for making chocolate chip pancakes.',
            'Write a 250-word article about your love of pancakes.',
            'Explain the plot of Back to the Future in three sentences.',
            'How do I make a trap beat?',
            'A step-by-step guide to learning Python in one month.',
        ],
        inputs=textbox,
        outputs=chatbot,
        fn=get_response_with_template,
        cache_examples=True,
        elem_id="examples"
    )
    gr.HTML(label="Keep in touch", value="<img src='./content/deci-coder-banner.png' alt='Keep in touch' style='display: block; color: #292b47; margin: auto; max-width: 800px;'>")
demo.launch(share=True, debug=True)