import torch
from transformers import T5Tokenizer, GPT2LMHeadModel
from flask import Flask, request, jsonify
import cutlet

# Build one romaji converter per supported romanization system.
convertors = {}
for romaji_sys in ["hepburn", "kunrei", "nippon"]:
    convertors[romaji_sys] = cutlet.Cutlet(romaji_sys)

# Use the GPU when available, otherwise fall back to CPU.
device = torch.device("cpu")
if torch.cuda.is_available():
    device = torch.device("cuda")

tokenizer = T5Tokenizer.from_pretrained("skytnt/gpt2-japanese-lyric-medium")
model = GPT2LMHeadModel.from_pretrained("skytnt/gpt2-japanese-lyric-medium")
model = model.to(device)
def gen_lyric(title: str, prompt_text: str):
    # Build the prompt tensor; the model expects "<s>title[CLS]lyric" with
    # newlines escaped as "\\n ". With no title and no prompt, sample from scratch.
    if len(title) != 0 or len(prompt_text) != 0:
        prompt_text = "<s>" + title + "[CLS]" + prompt_text
        prompt_text = prompt_text.replace("\n", "\\n ")
        prompt_tokens = tokenizer.tokenize(prompt_text)
        prompt_token_ids = tokenizer.convert_tokens_to_ids(prompt_tokens)
        prompt_tensor = torch.LongTensor(prompt_token_ids)
        prompt_tensor = prompt_tensor.view(1, -1).to(device)
    else:
        prompt_tensor = None
    # Sample one continuation from the model.
    output_sequences = model.generate(
        input_ids=prompt_tensor,
        max_length=512,
        top_p=0.95,
        top_k=40,
        temperature=1.0,
        do_sample=True,
        early_stopping=True,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        num_return_sequences=1
    )
    # Convert the generated token ids back into readable text.
    generated_sequence = output_sequences.tolist()[0]
    generated_tokens = tokenizer.convert_ids_to_tokens(generated_sequence)
    generated_text = tokenizer.convert_tokens_to_string(generated_tokens)
    # Undo the "\\n" escaping, restore full-width spaces, and strip the special tokens.
    generated_text = "\n".join([s.strip() for s in generated_text.split('\\n')])
    generated_text = generated_text.replace(' ', '\u3000').replace('<s>', '').replace('</s>', '\n\n---end---')
    # Split the output into a title and the lyric body.
    title_and_lyric = generated_text.split("[CLS]", 1)
    if len(title_and_lyric) == 1:
        title, lyric = "", title_and_lyric[0].strip()
    else:
        title, lyric = title_and_lyric[0].strip(), title_and_lyric[1].strip()
    return title, lyric
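# Illustrative direct use of gen_lyric (not part of the HTTP API; the seed title
# below is made up): empty title and prompt sample a lyric from scratch, while a
# non-empty title seeds the generation.
#   title, lyric = gen_lyric("", "")
#   title, lyric = gen_lyric("夏の夜", "")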
app = Flask(__name__, static_url_path="", static_folder="frontend/dist")


@app.route("/")
def index_page():
    return app.send_static_file("index.html")
# The API route paths below are assumptions; they must match whatever paths the
# bundled frontend in frontend/dist actually calls.
@app.route("/gen", methods=["POST"])
def generate():
    if request.method == "POST":
        try:
            data = request.get_json()
            title = data['title']
            text = data['text']
            title, lyric = gen_lyric(title, text)
            result = {
                "state": 200,
                "title": title,
                "lyric": lyric
            }
        except Exception as e:
            result = {
                "state": 400,
                "msg": f"{e}"
            }
        return jsonify(result), result["state"]
@app.route("/romaji", methods=["POST"])
def romaji():
    if request.method == "POST":
        try:
            data = request.get_json()
            text = data['text']
            system = data['system']
            lines = []
            # Direct conversion of text containing newlines is not supported,
            # so convert line by line.
            for line in text.split("\n"):
                lines.append(convertors[system].romaji(line))
            result = {
                "state": 200,
                "romaji": "\n".join(lines),
            }
        except Exception as e:
            result = {
                "state": 400,
                "msg": f"{e}"
            }
        return jsonify(result), result["state"]
if __name__ == '__main__':
    app.run(host="0.0.0.0", port=7860, debug=False, use_reloader=False)
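# Illustrative client call against a locally running server (the "/gen" path is
# the assumption made above and must match the route the frontend uses):
#   import requests
#   resp = requests.post("http://localhost:7860/gen",
#                        json={"title": "", "text": ""})
#   print(resp.json())  # {"state": 200, "title": ..., "lyric": ...} on success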