"""Gradio Space that generates booru-style tag prompts with a causal language model."""
import gc
import html

import torch
import transformers

import gradio as gr

# Helpers in the style of the AUTOMATIC1111 webui's form components.
class FormComponent:
    def get_expected_parent(self):
        return gr.components.Form


class FormRow(FormComponent, gr.Row):
    """Same as gr.Row but fits inside gradio forms"""

    def get_block_name(self):
        return "row"

def wrap_gradio_gpu_call(func, extra_outputs=None):
    """Stand-in for the webui's GPU-call wrapper: the original wrap_gradio_call
    helper does not exist in this standalone Space, so calls pass straight through."""
    def f(*args, **kwargs):
        return func(*args, **kwargs)

    return f


class Model:
    """Holds the currently loaded model and tokenizer so repeated runs reuse them."""
    name = None
    model = None
    tokenizer = None


available_models = ["0Tick/e621TagAutocomplete", "0Tick/danbooruTagAutocomplete"]
current = Model()


def device():
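    # Everything runs on the CPU in this Space; swap in torch.device("cuda") to use a GPU.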
    return torch.device("cpu")


def generate_batch(input_ids, min_length, max_length, num_beams, temperature, repetition_penalty, length_penalty, sampling_mode, top_k, top_p):
    # Only the parameter matching the selected sampling mode is used;
    # the other is disabled by passing None.
    top_p = float(top_p) if sampling_mode == 'Top P' else None
    top_k = int(top_k) if sampling_mode == 'Top K' else None

    outputs = current.model.generate(
        input_ids,
        do_sample=True,
        temperature=max(float(temperature), 1e-6),  # generate() rejects a temperature of exactly 0
        repetition_penalty=float(repetition_penalty),
        length_penalty=float(length_penalty),
        top_p=top_p,
        top_k=top_k,
        num_beams=int(num_beams),
        min_length=int(min_length),
        max_length=int(max_length),
        pad_token_id=current.tokenizer.pad_token_id or current.tokenizer.eos_token_id
    )
    return current.tokenizer.batch_decode(outputs, skip_special_tokens=True)


def model_selection_changed(model_name):
    # Selecting "None" unloads the model right away; real models load lazily in generate().
    if model_name == "None":
        current.tokenizer = None
        current.model = None
        current.name = None

        gc.collect()


def generate(model_name, batch_count, batch_size, text, *args):
    batch_count, batch_size = int(batch_count), int(batch_size)
    print(f"Model: {model_name}, Count: {batch_count * batch_size}, Starting text: {text}")

    # Swap models lazily, dropping the old one first so its memory can be reclaimed.
    if current.name != model_name:
        current.tokenizer = None
        current.model = None
        current.name = None

        if model_name != 'None':
            path = model_name
            current.tokenizer = transformers.AutoTokenizer.from_pretrained(path)
            current.model = transformers.AutoModelForCausalLM.from_pretrained(path)
            current.name = model_name

    assert current.model, 'No model available'
    assert current.tokenizer, 'No tokenizer available'

    current.model.to(device())

    input_ids = current.tokenizer(text, return_tensors="pt").input_ids
    if input_ids.shape[1] == 0:
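        # An empty prompt tokenizes to zero tokens; seed generation with the BOS token instead.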
        input_ids = torch.asarray([[current.tokenizer.bos_token_id]], dtype=torch.long)
    input_ids = input_ids.to(device())
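    # One identical row per sample so the whole batch is generated in a single call.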
    input_ids = input_ids.repeat((batch_size, 1))

    # Render each generated prompt as a row of an HTML table for the results pane.
    markup = '<table><tbody>'

    index = 0
    for _ in range(batch_count):
        texts = generate_batch(input_ids, *args)
        for generated_text in texts:
            index += 1
            markup += f"""
<tr>
<td>
<div class="prompt gr-box gr-text-input">
    <p id='promptgen_res_{index}'>{html.escape(generated_text)}</p>
</div>
</td>
</tr>
"""

    markup += '</tbody></table>'

    return markup, ''


# UI: prompt box and Generate button on top, sampling controls on the left,
# generated prompts on the right.
with gr.Blocks(analytics_enabled=False) as space:
    with gr.Row():
        with gr.Column(scale=80):
            prompt = gr.Textbox(label="Prompt", elem_id="promptgen_prompt", show_label=False, lines=2, placeholder="Beginning of the prompt").style(container=False)
        with gr.Column(scale=10):
            submit = gr.Button('Generate', elem_id="promptgen_generate", variant='primary')

    with gr.Row(elem_id="promptgen_main"):
        with gr.Column(variant="compact"):
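            # Hidden field carried over from the webui promptgen extension; nothing reads it here.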
            selected_text = gr.TextArea(elem_id='promptgen_selected_text', visible=False)

            with FormRow():
                model_selection = gr.Dropdown(label="Model", elem_id="promptgen_model", value=available_models[0], choices=["None"] + available_models)

            with FormRow():
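                # Only the slider matching the selected sampling mode takes effect; see generate_batch.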
                sampling_mode = gr.Radio(label="Sampling mode", elem_id="promptgen_sampling_mode", value="Top K", choices=["Top K", "Top P"])
                top_k = gr.Slider(label="Top K", elem_id="promptgen_top_k", value=12, minimum=1, maximum=50, step=1)
                top_p = gr.Slider(label="Top P", elem_id="promptgen_top_p", value=0.15, minimum=0, maximum=1, step=0.001)

            with gr.Row():
                num_beams = gr.Slider(label="Number of beams", elem_id="promptgen_num_beams", value=1, minimum=1, maximum=8, step=1)
                temperature = gr.Slider(label="Temperature", elem_id="promptgen_temperature", value=1, minimum=0, maximum=4, step=0.01)
                repetition_penalty = gr.Slider(label="Repetition penalty", elem_id="promptgen_repetition_penalty", value=1, minimum=1, maximum=4, step=0.01)

            with FormRow():
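                # length_penalty only influences results when num_beams > 1.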
                length_penalty = gr.Slider(label="Length preference", elem_id="promptgen_length_preference", value=1, minimum=-10, maximum=10, step=0.1)
                min_length = gr.Slider(label="Min length", elem_id="promptgen_min_length", value=20, minimum=1, maximum=400, step=1)
                max_length = gr.Slider(label="Max length", elem_id="promptgen_max_length", value=150, minimum=1, maximum=400, step=1)

            with FormRow():
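                # batch_count repeats generation sequentially; batch_size generates that many prompts per pass.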
                batch_count = gr.Slider(label="Batch count", elem_id="promptgen_batch_count", value=1, minimum=1, maximum=100, step=1)
                batch_size = gr.Slider(label="Batch size", elem_id="promptgen_batch_size", value=10, minimum=1, maximum=100, step=1)

        with gr.Column():
            with gr.Group(elem_id="promptgen_results_column"):
                res = gr.HTML()
                res_info = gr.HTML()

    # The inputs after `prompt` must stay in generate_batch's parameter order,
    # because generate() forwards them verbatim as *args.
    submit.click(
            fn=generate,
            inputs=[model_selection, batch_count, batch_size, prompt, min_length, max_length, num_beams, temperature, repetition_penalty, length_penalty, sampling_mode, top_k, top_p],
            outputs=[res, res_info]
    )

    # Unload immediately when "None" is selected; other choices load lazily on Generate.
    model_selection.change(
            fn=model_selection_changed,
            inputs=[model_selection],
            outputs=[],
    )


if __name__ == "__main__":
    space.launch()