|
|
|
|
|
import argparse |
|
from collections import defaultdict |
|
import os |
|
import platform |
|
import re |
|
|
|
from project_settings import project_path |
|
|
|
# Point the Hugging Face hub cache at <project_path>/cache/huggingface/hub.
# NOTE(review): presumably placed ahead of the gradio/transformers imports
# below so the setting is already in the environment when those libraries
# load — confirm before reordering imports.
os.environ["HUGGINGFACE_HUB_CACHE"] = (project_path / "cache/huggingface/hub").as_posix()
|
|
|
import gradio as gr |
|
from threading import Thread |
|
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel |
|
from transformers.models.bert.tokenization_bert import BertTokenizer |
|
from transformers.generation.streamers import TextIteratorStreamer |
|
import torch |
|
|
|
|
|
def get_args():
    """Parse the command-line options of the demo.

    Returns:
        argparse.Namespace: carries ``max_new_tokens``, ``top_p``,
        ``temperature``, ``repetition_penalty`` and ``device``.
    """
    # Default device follows CUDA availability at parse time.
    default_device = "cuda" if torch.cuda.is_available() else "cpu"

    # (flag, default, type) triples, registered in this order.
    option_table = (
        ("--max_new_tokens", 512, int),
        ("--top_p", 0.9, float),
        ("--temperature", 0.35, float),
        ("--repetition_penalty", 1.0, float),
        ("--device", default_device, str),
    )

    parser = argparse.ArgumentParser()
    for flag, default, caster in option_table:
        parser.add_argument(flag, default=default, type=caster)

    return parser.parse_args()
|
|
|
|
|
# Markdown text handed to gr.Interface as the page description in main().
description = """
## GPT2 Chat
"""
|
|
|
|
|
# Module-level example list; currently empty. main() passes its own inline
# example rows to gr.Interface rather than reading this name.
examples = []
|
|
|
|
|
def repl(match):
    """Substitution callback for ``re.sub``: glue capture groups 1 and 2.

    Args:
        match: a regex match object exposing at least two capture groups.

    Returns:
        str: group 1 immediately followed by group 2, with whatever
        separated them in the matched text dropped.
    """
    first, second = match.group(1, 2)
    return f"{first}{second}"
|
|
|
|
|
# One or more leading "[SEP]" markers, spaces or , . ! ? at the start of a reply.
# A real prefix pattern is used because str.lstrip("[SEP] ,.!?") strips a
# *character set* and would also eat leading "S", "E", "P", "[" or "]".
_LEADING_NOISE = re.compile(r"^(?:\[SEP\]|[ ,.!?])+")

# A single space sitting between two CJK ideographs. Look-around is used so
# the right-hand character is not consumed; consuming it would leave every
# other space in place ("你 好 吗" -> "你好 吗").
_CJK_GAP = re.compile(r"(?<=[\u4e00-\u9fa5]) (?=[\u4e00-\u9fa5])")


def _tidy_generated_text(text: str) -> str:
    """Clean accumulated streamer text for display.

    Drops leading "[SEP]"/punctuation/space noise, turns every remaining
    "[SEP]" into a newline, and removes single spaces between CJK characters.
    """
    text = _LEADING_NOISE.sub("", text)
    text = text.replace("[SEP]", "\n")
    return _CJK_GAP.sub("", text)


def main():
    """Build the Gradio text-generation demo and launch it.

    Reads the target device from the command line (``--device``; the value
    ``auto`` resolves to CUDA when available, else CPU), wires a streaming
    generation callback into a ``gr.Interface`` and starts the queued app.

    NOTE: the sampling flags parsed by ``get_args`` (``--max_new_tokens``,
    ``--top_p``, ...) are not used here; the slider defaults below are fixed.
    """
    args = get_args()

    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device

    input_text_box = gr.Text(label="text")
    output_text_box = gr.Text(lines=4, label="generated_content")

    # model_name -> (tokenizer, model); avoids reloading weights per request.
    loaded = {}

    def load(model_name: str):
        """Return the cached (tokenizer, model) pair, loading it on first use."""
        if model_name not in loaded:
            tokenizer = BertTokenizer.from_pretrained(model_name)
            model = GPT2LMHeadModel.from_pretrained(model_name)
            # The model must live on the same device as the input ids below,
            # otherwise generate() fails with a device-mismatch error on CUDA.
            model = model.to(device).eval()
            loaded[model_name] = (tokenizer, model)
        return loaded[model_name]

    def fn_stream(text: str,
                  max_new_tokens: int = 200,
                  top_p: float = 0.85,
                  temperature: float = 0.35,
                  repetition_penalty: float = 1.2,
                  model_name: str = "qgyd2021/lib_service_4chan",
                  is_chat: bool = True,
                  ):
        """Stream generated text for ``text``.

        Args:
            text: prompt typed by the user.
            max_new_tokens, top_p, temperature, repetition_penalty: sampling
                settings forwarded to ``model.generate``.
            model_name: identifier/path handed to ``from_pretrained``.
            is_chat: when true, a ``[SEP]`` id is appended to the prompt and
                also used as the end-of-sequence id.

        Yields:
            str: the cleaned, cumulative output after each streamed chunk.
        """
        tokenizer, model = load(model_name)

        text_encoded = tokenizer(text, add_special_tokens=False)

        # Prompt layout: [CLS] <text ids> ([SEP] in chat mode).
        input_ids = [tokenizer.cls_token_id, *text_encoded["input_ids"]]
        if is_chat:
            input_ids.append(tokenizer.sep_token_id)

        input_ids = torch.tensor([input_ids], dtype=torch.long).to(device)

        streamer = TextIteratorStreamer(tokenizer=tokenizer)

        generation_kwargs = dict(
            inputs=input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            eos_token_id=tokenizer.sep_token_id if is_chat else None,
            pad_token_id=tokenizer.pad_token_id,
            streamer=streamer,
        )
        # generate() blocks, so it runs in a worker thread while this
        # generator drains the streamer.
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()

        output: str = ""
        first_answer = True
        for chunk in streamer:
            # The first streamed chunk is discarded.
            # NOTE(review): presumably it is the echoed prompt, since the
            # streamer is created without skip_prompt — confirm.
            if first_answer:
                first_answer = False
                continue

            chunk = chunk.replace("[UNK] ", "").replace("[UNK]", "")

            # Re-clean the whole accumulated text each time so markers that
            # arrive in later chunks are handled as well.
            output = _tidy_generated_text(output + chunk)

            # Only the yielded value is returned to Gradio; no shared
            # component state is mutated between requests.
            yield output

        # The streamer is exhausted once generate() finishes; reap the worker.
        thread.join()

    model_name_choices = ["trained_models/lib_service_4chan"] \
        if platform.system() == "Windows" else ["qgyd2021/lib_service_4chan"]

    demo = gr.Interface(
        fn=fn_stream,
        inputs=[
            input_text_box,
            gr.Slider(minimum=0, maximum=512, value=512, step=1, label="max_new_tokens"),
            gr.Slider(minimum=0, maximum=1, value=0.85, step=0.01, label="top_p"),
            gr.Slider(minimum=0, maximum=1, value=0.35, step=0.01, label="temperature"),
            gr.Slider(minimum=0, maximum=2, value=1.2, step=0.01, label="repetition_penalty"),
            gr.Dropdown(choices=model_name_choices, value=model_name_choices[0], label="model_name"),
            gr.Checkbox(value=True, label="is_chat"),
        ],
        outputs=[output_text_box],
        examples=[
            # Use the platform's dropdown choice so the example row always
            # selects a model that is actually offered.
            ["怎样擦屁股才能擦的干净", 512, 0.75, 0.35, 1.2, model_name_choices[0], True],
        ],
        cache_examples=False,
        examples_per_page=50,
        title="GPT2 Chat",
        description=description,
    )
    demo.queue().launch()
|
|
|
|
|
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    main()
|
|