# File size: 5,516 Bytes
# 260fe59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Hugging Face access token forwarded to `from_pretrained`. A missing
# environment variable yields None instead of raising KeyError at import time;
# with None, transformers sends no explicit token (it may fall back to a cached
# login). A token is only required if the repository is gated/private.
token = os.environ.get("HUGGINGFACEHUB_API_TOKEN")

# Hugging Face Hub repository id of the model served by this demo.
model_id = 'Deci/DeciLM-6b-instruct'

# Instruction-following prompt. The user's text replaces the "{instruction}"
# placeholder via str.format; the model's answer is read from whatever it
# generates after the trailing "### Response:" line.
SYSTEM_PROMPT_TEMPLATE = (
    "Below is an instruction that describes a task. "
    "Write a response that appropriately completes the request.\n"
    "\n"
    "### Instruction:\n"
    "\n"
    "{instruction}\n"
    "\n"
    "### Response:\n"
)

# Markdown/HTML header rendered at the top of the demo page via gr.Markdown.
# When no CUDA device is available, a GPU notice is appended to it further
# down in this file.
DESCRIPTION = """
# <p style="text-align: center; color: #292b47;"> 🤖 <span style='color: #3264ff;'>DeciLM-6B-Instruct:</span> A Fast Instruction-Tuned Model💨 </p>
<span style='color: #292b47;'>Welcome to <a href="https://huggingface.co/Deci/DeciLM-6b-instruct" style="color: #3264ff;">DeciLM-6B-Instruct</a>! DeciLM-6B-Instruct is a 6B parameter instruction-tuned language model and released under the Llama license. It's an instruction-tuned model, not a chat-tuned model;  you should prompt the model with an instruction that describes a task, and the model will respond appropriately to complete the task.</span>
<p><span style='color: #292b47;'>Learn more about the base model <a href="https://huggingface.co/Deci/DeciLM-6b" style="color: #3264ff;">DeciLM-6B.</a></span></p>
"""

# LICENSE = """
# <p/>

# ---
# As a derivative work of [Llama-2-7b-chat](https://huggingface.co/meta-llama/Llama-2-7b-chat) by Meta,
# this demo is governed by the original [license](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/LICENSE.txt) and [acceptable use policy](https://huggingface.co/spaces/huggingface-projects/llama-2-7b-chat/blob/main/USE_POLICY.md).
# """

if not torch.cuda.is_available():
    DESCRIPTION += 'You need a GPU for this example. Try using colab: https://bit.ly/decilm-instruct-nb'

if torch.cuda.is_available():
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map='auto',
        trust_remote_code=True, 
        use_auth_token=token
    )
else:
    model = None

tokenizer = AutoTokenizer.from_pretrained(model_id, use_auth_token=token)
tokenizer.pad_token = tokenizer.eos_token

# Wrap a raw user instruction in the instruction/response prompt template.
def get_prompt_with_template(message: str) -> str:
    """Return *message* substituted into ``SYSTEM_PROMPT_TEMPLATE``'s ``{instruction}`` slot."""
    template = SYSTEM_PROMPT_TEMPLATE
    prompt = template.format(instruction=message)
    return prompt

# Function to generate the model's response
def generate_model_response(message: str, **generate_kwargs) -> str:
    """Run generation for the instruction *message* and return the decoded text.

    The decoded text is the whole output sequence with special tokens removed.
    That sequence starts with the prompt, which is why callers cut everything
    up to the ``"### Response:"`` marker.

    Args:
        message: The user's instruction, inserted into the prompt template.
        **generate_kwargs: Extra ``model.generate`` options. They override the
            defaults below, e.g. ``max_new_tokens=512`` or ``do_sample=False``.

    Raises:
        gr.Error: If no model was loaded because no CUDA device is available.
    """
    if model is None:
        # Without this guard the call fails with an opaque
        # "'NoneType' object has no attribute 'generate'".
        raise gr.Error('You need a GPU for this example. Try using colab: https://bit.ly/decilm-instruct-nb')
    prompt = get_prompt_with_template(message)
    # Put the inputs on the device holding the model (chosen by
    # device_map='auto') instead of assuming the default 'cuda' device.
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    options = {
        'max_new_tokens': 3000,
        'num_beams': 5,
        'no_repeat_ngram_size': 4,
        'early_stopping': True,
        # With num_beams > 1 this makes generation beam-search sampling.
        'do_sample': True,
    }
    # Caller-supplied options take precedence over the defaults.
    options.update(generate_kwargs)
    output = model.generate(**inputs, **options)
    return tokenizer.decode(output[0], skip_special_tokens=True)

# Drop the echoed prompt and keep only what follows the first response marker.
def extract_response_content(full_response: str) -> str:
    """Return the answer portion of *full_response*.

    Everything up to and including the first ``"### Response:"`` marker is
    discarded, and the remainder is stripped of surrounding whitespace. If the
    marker does not occur, *full_response* is returned unchanged and unstripped.
    """
    _, marker, answer = full_response.partition("### Response:")
    if not marker:
        return full_response
    return answer.strip()

# UI entry point: generate a reply for an instruction and return only the answer.
def get_response_with_template(message: str) -> str:
    """Return the model's answer to the instruction *message*.

    Blank or whitespace-only instructions return ``""`` without calling the
    model, because each generation call is expensive (beam search with a
    large token budget).
    """
    if not message or not message.strip():
        return ''
    full_response = generate_model_response(message)
    return extract_response_content(full_response)

# NOTE(review): "/content/style.css" is a Colab-specific absolute path; on other
# hosts the file will not exist, so the custom styling is not applied — confirm.
with gr.Blocks(css="/content/style.css") as demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value='Duplicate Space for private use',
                       elem_id='duplicate-button')
    with gr.Group():
        # Despite its name, this is a plain Textbox that shows a single
        # answer; it is not a gr.Chatbot with history.
        chatbot = gr.Textbox(label='DeciLM-6B-Instruct Output:')
        with gr.Row():
            textbox = gr.Textbox(
                container=False,
                show_label=False,
                placeholder='Type an instruction...',
                scale=10,
                elem_id="textbox",
            )
            submit_button = gr.Button(
                '💬 Submit',
                variant='primary',
                scale=1,
                min_width=0,
                elem_id="submit_button",
            )

            # Clears both the instruction box and the output box.
            clear_button = gr.Button(
                '🗑️ Clear',
                variant='secondary',
            )

    # Clearing is instant, so it bypasses the request queue.
    clear_button.click(
        fn=lambda: ('', ''),
        outputs=[textbox, chatbot],
        queue=False,
        api_name=False,
    )

    # Generation is long-running. It therefore goes through the request queue
    # (enabled before launch) instead of bypassing it with queue=False, which
    # exposes long requests to HTTP timeouts. Both the Submit button and
    # pressing Enter in the instruction box trigger it.
    for trigger in (submit_button.click, textbox.submit):
        trigger(
            fn=get_response_with_template,
            inputs=textbox,
            outputs=chatbot,
            api_name=False,
        )

    gr.Examples(
        examples=[
            'Write detailed instructions for making chocolate chip pancakes.',
            'Write a 250-word article about your love of pancakes.',
            'Explain the plot of Back to the Future in three sentences.',
            'How do I make a trap beat?',
            'A step-by-step guide to learning Python in one month.',
        ],
        inputs=textbox,
        outputs=chatbot,
        fn=get_response_with_template,
        # Caching runs the model on every example at startup. Without a GPU
        # `model` is None, so caching would abort the launch; only cache when
        # a model is loaded.
        cache_examples=model is not None,
        elem_id="examples",
    )

    # NOTE(review): Gradio does not serve arbitrary relative paths from an
    # <img src>; the banner probably needs Gradio's file route — confirm.
    gr.HTML(label="Keep in touch", value="<img src='./content/deci-coder-banner.png' alt='Keep in touch' style='display: block; color: #292b47; margin: auto; max-width: 800px;'>")


# Enable the request queue so the long-running generation events are queued.
demo.queue()
demo.launch(share=True, debug=True)