import argparse
import json
import os
import platform
import re
import string
from threading import Thread

from project_settings import project_path

# Point the HuggingFace hub cache into the project directory before
# transformers is imported, so model downloads land there.
os.environ["HUGGINGFACE_HUB_CACHE"] = (project_path / "cache/huggingface/hub").as_posix()

import gradio as gr
import torch
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
from transformers.models.bert.tokenization_bert import BertTokenizer
from transformers.generation.streamers import TextIteratorStreamer


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument("--max_new_tokens", default=512, type=int)
    parser.add_argument("--top_p", default=0.9, type=float)
    parser.add_argument("--temperature", default=0.35, type=float)
    parser.add_argument("--repetition_penalty", default=1.0, type=float)
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", type=str)

    parser.add_argument(
        "--examples_json_file",
        default="examples.json",
        type=str
    )
    args = parser.parse_args()
    return args


def repl1(match):
    """re.sub callback: keep the two captured groups, drop what was matched between them."""
    result = "{}{}".format(match.group(1), match.group(2))
    return result


def repl2(match):
    """re.sub callback: keep only the first captured group."""
    result = "{}".format(match.group(1))
    return result
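

# repl1/repl2 are re.sub replacement callbacks. A hedged sketch of how they are
# typically wired up (the patterns here are illustrative assumptions, not taken
# from this file):
#
#   re.sub(r"([\u4e00-\u9fa5]) +([\u4e00-\u9fa5])", repl1, text)  # drop spaces between CJK chars
#   re.sub(r"(\S) +$", repl2, text)                               # drop trailing spaces after a token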


def remove_space_between_cn_en(text):
    """Collapse spaces around CJK characters while keeping single spaces between ASCII tokens."""
    splits = re.split(" ", text)
    if len(splits) < 2:
        return text

    # re.escape keeps the punctuation characters from being interpreted as
    # character-class metacharacters (e.g. "]", "-", "\").
    ascii_tail = "[a-zA-Z0-9{}]$".format(re.escape(string.punctuation))

    result = ""
    for t in splits:
        if t == "":
            continue
        # Keep a single space only between an ASCII token end and an ASCII token start.
        if re.search(ascii_tail, result) and re.search("^[a-zA-Z0-9]", t):
            result += " "
        result += t

    if text.endswith(" "):
        result += " "
    return result
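

# Example of the behaviour above (derived from the logic, not from a test suite):
#   remove_space_between_cn_en("hello world 你好 吗")  ->  "hello world你好吗"
# The space between the two ASCII words survives; spaces touching CJK text are dropped.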


def main():
    args = get_args()

    description = """
    ## GPT2 Chat
    """

    with open(args.examples_json_file, "r", encoding="utf-8") as f:
        examples = json.load(f)
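
    # `examples` feeds gr.Interface below; examples.json is assumed to hold one
    # row per example, matching the declared inputs in order, e.g.:
    #   [["你好", 512, 0.85, 0.35, 1.2, "qgyd2021/lip_service_4chan", true], ...]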

    # "--device auto" defers the cuda/cpu choice to runtime; any other value is used as given.
    if args.device == 'auto':
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    else:
        device = args.device

    input_text_box = gr.Text(label="text")
    output_text_box = gr.Text(lines=4, label="generated_content")

    def fn_stream(text: str,
                  max_new_tokens: int = 200,
                  top_p: float = 0.85,
                  temperature: float = 0.35,
                  repetition_penalty: float = 1.2,
                  model_name: str = "qgyd2021/lip_service_4chan",
                  is_chat: bool = True,
                  ):
        tokenizer = BertTokenizer.from_pretrained(model_name)
        model = GPT2LMHeadModel.from_pretrained(model_name)
        # Move the model to the selected device; without this, generation
        # fails whenever the input ids live on "cuda".
        model = model.to(device)
        model = model.eval()

        text_encoded = tokenizer(text, add_special_tokens=False)
        input_ids_ = text_encoded["input_ids"]

        # Build [CLS] + prompt (+ [SEP] in chat mode) by hand.
        input_ids = [tokenizer.cls_token_id]
        input_ids.extend(input_ids_)
        if is_chat:
            input_ids.append(tokenizer.sep_token_id)

        input_ids = torch.tensor([input_ids], dtype=torch.long)
        input_ids = input_ids.to(device)

        streamer = TextIteratorStreamer(tokenizer=tokenizer)

        generation_kwargs = dict(
            inputs=input_ids,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            top_p=top_p,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            eos_token_id=tokenizer.sep_token_id if is_chat else None,
            pad_token_id=tokenizer.pad_token_id,
            streamer=streamer,
        )
        thread = Thread(target=model.generate, kwargs=generation_kwargs)
        thread.start()
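
        # generate() runs on the background thread and feeds decoded text into
        # `streamer`; iterating the streamer below blocks until each new chunk
        # arrives, which is what lets the Gradio output update incrementally.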

        output: str = ""
        first_answer = True
        for output_ in streamer:
            # The first chunk from the streamer echoes the prompt; drop it.
            if first_answer:
                first_answer = False
                continue

            output_ = output_.replace("[UNK] ", "")
            output_ = output_.replace("[UNK]", "")
            output_ = output_.replace("[CLS] ", "")
            output_ = output_.replace("[CLS]", "")

            output += output_
            if output.startswith("[SEP]"):
                output = output[len("[SEP]"):]

            output = output.lstrip(" ,.!?")
            output = remove_space_between_cn_en(output)

            # In chat mode [SEP] separates turns; render it as a line break.
            output = output.replace("[SEP] ", "\n")
            output = output.replace("[SEP]", "\n")

            yield output

    # On Windows, use locally trained checkpoints; otherwise pull models from the hub.
    model_name_choices = ["trained_models/lip_service_4chan", "trained_models/chinese_porn_novel"] \
        if platform.system() == "Windows" else \
        [
            "qgyd2021/lip_service_4chan", "qgyd2021/chinese_chitchat",
            "qgyd2021/chinese_porn_novel", "qgyd2021/few_shot_intent_gpt2",
            "qgyd2021/similar_question_generation",
        ]

    demo = gr.Interface(
        fn=fn_stream,
        inputs=[
            input_text_box,
            gr.Slider(minimum=0, maximum=512, value=512, step=1, label="max_new_tokens"),
            gr.Slider(minimum=0, maximum=1, value=0.85, step=0.01, label="top_p"),
            gr.Slider(minimum=0, maximum=1, value=0.35, step=0.01, label="temperature"),
            gr.Slider(minimum=0, maximum=2, value=1.2, step=0.01, label="repetition_penalty"),
            gr.Dropdown(choices=model_name_choices, value=model_name_choices[0], label="model_name"),
            gr.Checkbox(value=True, label="is_chat")
        ],
        outputs=[output_text_box],
        examples=examples,
        cache_examples=False,
        examples_per_page=50,
        title="GPT2 Chat",
        description=description,
    )
    demo.queue().launch(
        share=False,
        server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
        server_port=7860
    )
    return
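

# A hedged usage sketch (the script filename is an assumption):
#
#   python3 main.py --device cpu --examples_json_file examples.json
#
# then open http://127.0.0.1:7860 in a browser.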
if __name__ == '__main__':
    main()