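"""Gradio Space that autocompletes booru-style tag prompts with small causal
language models (0Tick/e621TagAutocomplete and 0Tick/danbooruTagAutocomplete).

The script appears to be adapted from the stable-diffusion-webui "promptgen"
extension; a few webui-specific hooks survive here as no-op leftovers.
"""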
import gc
import html

import torch
import transformers
import gradio as gr
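# Note: this assumes gradio 3.x; gr.components.Form and .style(container=False)
# used below were removed in gradio 4.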
class FormComponent:
    def get_expected_parent(self):
        return gr.components.Form


class FormRow(FormComponent, gr.Row):
    """Same as gr.Row but fits inside gradio forms"""

    def get_block_name(self):
        return "row"
def wrap_gradio_gpu_call(func, extra_outputs=None):
    # Apparent leftover from the webui extension's GPU job queue; this Space
    # has no such queue, so the wrapper simply forwards the call.
    def f(*args, **kwargs):
        return func(*args, **kwargs)

    return f
class Model:
    name = None
    model = None
    tokenizer = None


available_models = ["0Tick/e621TagAutocomplete", "0Tick/danbooruTagAutocomplete"]
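# Module-level holder for the currently loaded model, so repeated generations
# reuse it and a reload only happens when the dropdown selection changes.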
current = Model()
def device():
    # This Space runs on CPU hardware, so everything stays on the CPU.
    return torch.device("cpu")
def generate_batch(input_ids, min_length, max_length, num_beams, temperature, repetition_penalty, length_penalty, sampling_mode, top_k, top_p):
    # Only the sampler matching the selected mode is active; the other is disabled.
    top_p = float(top_p) if sampling_mode == 'Top P' else None
    top_k = int(top_k) if sampling_mode == 'Top K' else None

    outputs = current.model.generate(
        input_ids,
        do_sample=True,
        temperature=max(float(temperature), 1e-6),
        repetition_penalty=repetition_penalty,
        length_penalty=length_penalty,
        top_p=top_p,
        top_k=top_k,
        num_beams=int(num_beams),
        min_length=int(min_length),
        max_length=int(max_length),
        pad_token_id=current.tokenizer.pad_token_id or current.tokenizer.eos_token_id,
    )
    texts = current.tokenizer.batch_decode(outputs, skip_special_tokens=True)
    return texts
def model_selection_changed(model_name):
    # Selecting "None" unloads the current model and frees its memory.
    if model_name == "None":
        current.tokenizer = None
        current.model = None
        current.name = None
        gc.collect()
def generate(id_task, model_name, batch_count, batch_size, text, *args):
    # id_task is unused; it is kept for signature compatibility with the
    # webui-style click handler, which passes it as the first input.
    batch_count, batch_size = int(batch_count), int(batch_size)
    print(f"Model: {model_name}, Count: {batch_count * batch_size}, Starting text: {text}")

    # Lazily (re)load the requested model, dropping any previously loaded one.
    if current.name != model_name:
        current.tokenizer = None
        current.model = None
        current.name = None

        if model_name != 'None':
            path = model_name
            current.tokenizer = transformers.AutoTokenizer.from_pretrained(path)
            current.model = transformers.AutoModelForCausalLM.from_pretrained(path)
            current.name = model_name

    assert current.model, 'No model available'
    assert current.tokenizer, 'No tokenizer available'

    current.model.to(device())

    input_ids = current.tokenizer(text, return_tensors="pt").input_ids
    if input_ids.shape[1] == 0:
        # An empty prompt still needs a start token to generate from.
        input_ids = torch.asarray([[current.tokenizer.bos_token_id]], dtype=torch.long)
    input_ids = input_ids.to(device())
    input_ids = input_ids.repeat((batch_size, 1))

    # Render each generated prompt as a row of an HTML table.
    markup = '<table><tbody>'

    index = 0
    for i in range(batch_count):
        texts = generate_batch(input_ids, *args)
        for generated_text in texts:
            index += 1
            markup += f"""
<tr>
<td>
<div class="prompt gr-box gr-text-input">
    <p id='promptgen_res_{index}'>{html.escape(generated_text)}</p>
</div>
</td>
</tr>
"""

    markup += '</tbody></table>'

    return markup, ''
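# Gradio UI: prompt box and Generate button on top; sampling controls in the
# left column, generated prompts (the HTML table from generate()) on the right.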
with gr.Blocks(analytics_enabled=False) as space:
    with gr.Row():
        with gr.Column(scale=80):
            prompt = gr.Textbox(label="Prompt", elem_id="promptgen_prompt", show_label=False, lines=2, placeholder="Beginning of the prompt").style(container=False)
        with gr.Column(scale=10):
            submit = gr.Button('Generate', elem_id="promptgen_generate", variant='primary')

    with gr.Row(elem_id="promptgen_main"):
        with gr.Column(variant="compact"):
            selected_text = gr.TextArea(elem_id='promptgen_selected_text', visible=False)

            with FormRow():
                model_selection = gr.Dropdown(label="Model", elem_id="promptgen_model", value=available_models[0], choices=["None"] + available_models)

            with FormRow():
                sampling_mode = gr.Radio(label="Sampling mode", elem_id="promptgen_sampling_mode", value="Top K", choices=["Top K", "Top P"])
                top_k = gr.Slider(label="Top K", elem_id="promptgen_top_k", value=12, minimum=1, maximum=50, step=1)
                top_p = gr.Slider(label="Top P", elem_id="promptgen_top_p", value=0.15, minimum=0, maximum=1, step=0.001)

            with gr.Row():
                num_beams = gr.Slider(label="Number of beams", elem_id="promptgen_num_beams", value=1, minimum=1, maximum=8, step=1)
                temperature = gr.Slider(label="Temperature", elem_id="promptgen_temperature", value=1, minimum=0, maximum=4, step=0.01)
                repetition_penalty = gr.Slider(label="Repetition penalty", elem_id="promptgen_repetition_penalty", value=1, minimum=1, maximum=4, step=0.01)

            with FormRow():
                length_penalty = gr.Slider(label="Length preference", elem_id="promptgen_length_preference", value=1, minimum=-10, maximum=10, step=0.1)
                min_length = gr.Slider(label="Min length", elem_id="promptgen_min_length", value=20, minimum=1, maximum=400, step=1)
                max_length = gr.Slider(label="Max length", elem_id="promptgen_max_length", value=150, minimum=1, maximum=400, step=1)

            with FormRow():
                batch_count = gr.Slider(label="Batch count", elem_id="promptgen_batch_count", value=1, minimum=1, maximum=100, step=1)
                batch_size = gr.Slider(label="Batch size", elem_id="promptgen_batch_size", value=10, minimum=1, maximum=100, step=1)

        with gr.Column():
            with gr.Group(elem_id="promptgen_results_column"):
                res = gr.HTML()
                res_info = gr.HTML()
    submit.click(
        fn=generate,
        # model_selection is passed twice: once into the unused id_task slot
        # and once as model_name.
        inputs=[model_selection, model_selection, batch_count, batch_size, prompt, min_length, max_length, num_beams, temperature, repetition_penalty, length_penalty, sampling_mode, top_k, top_p],
        outputs=[res, res_info],
    )

    model_selection.change(
        fn=model_selection_changed,
        inputs=[model_selection],
        outputs=[],
    )

space.launch()
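# A possible hardening not in the original script: gradio 3.x can serialize
# long CPU generations through its request queue, e.g.
#   space.queue(concurrency_count=1).launch()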