File size: 3,663 Bytes
fefc78b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import logging
import os
from threading import Thread

import gradio as gr
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    set_seed,
)

# Log verbosity is taken from the LOG_LEVEL environment variable (default
# WARNING). The value is upper-cased because the logging module only
# recognises upper-case level names; a lower-case value such as "debug"
# would otherwise raise ValueError when the app starts.
log_level = os.environ.get("LOG_LEVEL", "WARNING").upper()
logging.basicConfig(encoding='utf-8', level=log_level)

# The tokenizer and model are loaded once at import time and shared by every
# request handled by generate().
logging.info("Loading Model")
# NOTE(review): trust_remote_code=True lets the model repository execute its
# own Python code on this machine - confirm this is still required for the
# installed transformers version and that the repository is trusted.
tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained("microsoft/phi-2", trust_remote_code=True)

def format_prompt(message, history):
    """Wrap ``message`` in the "Instruct: ... / Output: " prompt layout.

    Args:
        message: Text placed after the ``Instruct:`` marker.
        history: Earlier chat turns. They are only written to the debug log;
            the returned prompt does not contain them.

    Returns:
        The prompt string, ending with ``"Output: "`` so the model continues
        from there.
    """
    logging.info("Formatting Prompt")
    logging.debug("Input Message: %s", message)
    logging.debug("Input History: %s", history)

    instruction_line = f"Instruct: {message}\n"
    answer_marker = "Output: "
    return "".join((instruction_line, answer_marker))


def generate(
    prompt, history, system_prompt, temperature=0.9, max_new_tokens=256, top_p=0.95, repetition_penalty=1.0,
):
    """Stream the model's reply to ``prompt`` as progressively longer strings.

    Generation runs on a worker thread and decoded text is pulled from a
    ``TextIteratorStreamer`` so partial output can be yielded to the chat UI
    while the model is still producing tokens.

    Args:
        prompt: Latest user message.
        history: Earlier chat turns; forwarded to ``format_prompt``.
        system_prompt: Optional text prepended to the user message. When it is
            empty the message is used on its own.
        temperature: Sampling temperature; clamped to a minimum of 0.01.
        max_new_tokens: Maximum number of tokens to generate; coerced to an
            integer of at least 1.
        top_p: Nucleus-sampling probability mass.
        repetition_penalty: Penalty applied to repeated tokens.

    Yields:
        The response text accumulated so far, once per decoded chunk.
    """
    logging.info("Generating Response")
    logging.debug("Input Prompt: %s", prompt)
    logging.debug("Input History: %s", history)
    logging.debug("Input System Prompt: %s", system_prompt)
    logging.debug("Input Temperature: %s", temperature)
    logging.debug("Input Max New Tokens: %s", max_new_tokens)
    logging.debug("Input Top P: %s", top_p)
    logging.debug("Input Repetition Penalty: %s", repetition_penalty)

    logging.info("Converting Parameters to Correct Type")
    # UI sliders may deliver floats (or 0); normalise before calling the model.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)
    repetition_penalty = float(repetition_penalty)
    max_new_tokens = max(int(max_new_tokens), 1)
    logging.debug("Temperature: %s", temperature)
    logging.debug("Top P: %s", top_p)
    logging.debug("Max New Tokens: %s", max_new_tokens)

    logging.info("Creating Generate kwargs")
    # `seed` is intentionally not part of these kwargs: model.generate() does
    # not accept it. The fixed seed is applied with set_seed() instead.
    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        do_sample=True,
    )
    logging.debug("Generate Args: %s", generate_kwargs)

    # Avoid a dangling ", " prefix when no system prompt was supplied.
    message = f"{system_prompt}, {prompt}" if system_prompt else prompt
    formatted_prompt = format_prompt(message, history)
    logging.debug("Prompt: %s", formatted_prompt)

    logging.info("Generating Text")
    # Tokenize the formatted prompt (not the raw user text) and unpack the
    # encoding so input_ids / attention_mask reach generate() as keywords.
    inputs = tokenizer(formatted_prompt, return_tensors="pt")
    streamer = TextIteratorStreamer(
        tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    def _run_generation():
        """Run model.generate() and always release the consumer loop."""
        try:
            # NOTE(review): set_seed() seeds process-wide RNGs, so results are
            # only reproducible when requests do not overlap.
            set_seed(42)
            model.generate(**inputs, streamer=streamer, **generate_kwargs)
        except Exception:
            logging.exception("Text generation failed")
            # Without this the iterator below would wait forever.
            streamer.end()

    worker = Thread(target=_run_generation, daemon=True)
    worker.start()

    logging.info("Creating Output")
    output = ""
    for new_text in streamer:
        output += new_text
        yield output
    worker.join()

    logging.debug("Output: %s", output)


# Extra controls shown under the chat box. Their values are passed to
# generate() positionally after (prompt, history), so the order here must
# match: system_prompt, temperature, max_new_tokens, top_p,
# repetition_penalty.
additional_inputs = [
    gr.Textbox(
        label="System Prompt",
        max_lines=1,
        interactive=True,
    ),
    gr.Slider(
        label="Temperature",
        value=0.9,
        minimum=0.0,
        maximum=1.0,
        step=0.05,
        interactive=True,
        info="Higher values produce more diverse outputs",
    ),
    gr.Slider(
        label="Max new tokens",
        value=256,
        # Zero new tokens is not a usable generation budget, and the upper
        # bound must be a multiple of the step to be selectable.
        minimum=64,
        maximum=1024,
        step=64,
        interactive=True,
        info="The maximum numbers of new tokens",
    ),
    gr.Slider(
        label="Top-p (nucleus sampling)",
        value=0.90,
        minimum=0.0,
        maximum=1,
        step=0.05,
        interactive=True,
        info="Higher values sample more low-probability tokens",
    ),
    gr.Slider(
        label="Repetition penalty",
        value=1.2,
        minimum=1.0,
        maximum=2.0,
        step=0.05,
        interactive=True,
        info="Penalize repeated tokens",
    ),
]

# No canned example prompts are offered in the UI.
examples = []

# Build the chat UI around generate() and start serving it. The title names
# the model that is actually loaded (microsoft/phi-2).
logging.info("Creating Chat Interface")
gr.ChatInterface(
    fn=generate,
    chatbot=gr.Chatbot(show_label=False, show_share_button=False,
                       show_copy_button=True, likeable=True, layout="panel"),
    additional_inputs=additional_inputs,
    title="Phi-2 Instruct",
    examples=examples,
    concurrency_limit=20,
).launch(show_api=False)