Spaces:
Sleeping
Sleeping
import gradio as gr
from transformers import GPT2LMHeadModel
from indobenchmark import IndoNLGTokenizer

# Hugging Face Hub repositories the app loads its tokenizer and model from.
_TOKENIZER_REPO = "indobenchmark/indogpt"
_MODEL_REPO = "abdiharyadi/kancilgpt"

# Both objects are created once, at import time, and reused by generate_story.
gpt_tokenizer = IndoNLGTokenizer.from_pretrained(_TOKENIZER_REPO)
# Alias the padding token to the end-of-sequence token.
gpt_tokenizer.pad_token = gpt_tokenizer.eos_token
kancilgpt = GPT2LMHeadModel.from_pretrained(_MODEL_REPO)
# Prompt that starts a brand-new story. The model continues it with the title;
# later " | " separators introduce the story body and then an ending marker.
_NEW_STORY_PROMPT = "<s> awal cerita | judul:"
# Ending markers accepted after the body ("bersambung" ~ "to be continued",
# "tamat" ~ "finished").
_VALID_ENDINGS = ("bersambung", "tamat")


def _render_story(judul, isi, suffix):
    """Return ``judul`` underlined with dashes, followed by ``isi`` and ``suffix``."""
    underline = "-" * len(judul)
    return f"{judul}\n{underline}\n{isi}{suffix}"


def _parse_title(segment):
    """Extract the title from a ``"judul: <title>"`` segment.

    The first whitespace-separated word (the ``judul:`` label) is dropped and
    the remaining words are joined with single spaces. Returns ``None`` when
    the segment contains no words at all, instead of raising ``ValueError``.
    """
    words = segment.split()
    if not words:
        return None
    return " ".join(words[1:])


def _last_unmatched_quote(isi):
    """Return the index of the last ``"`` in ``isi`` if the quote count is odd, else ``None``."""
    positions = [i for i, c in enumerate(isi) if c == '"']
    if len(positions) % 2 == 0:
        return None
    return positions[-1]


def generate_story():
    """Generate a story with ``kancilgpt`` and stream it as progressively longer text.

    This is a generator (used as the Gradio ``fn``). It first yields ``"..."``,
    then repeatedly asks the model for two more tokens, decodes the whole
    sequence and splits it on ``" | "`` into: a header, ``"judul: <title>"``,
    the body (``isi``) and the ending marker. Each step yields either the
    partial title or the title plus the partial body.

    Validation and retries (there is no retry limit, so a model that keeps
    producing invalid output keeps this loop running):

    * output that cannot be split into at least a header and a title segment
      restarts the story from scratch;
    * a title containing ``"."`` restarts the story from scratch;
    * once the ending segment first appears, a body with an odd number of
      double quotes is cut back to just before its last quote and regenerated;
    * a finished ending that does not start with ``"bersambung"`` or
      ``"tamat"`` is regenerated, keeping the title and body.

    The last value yielded is the title, an underline, the body, a blank line
    and ``tamat.`` -- the continuation stage is currently disabled (see the
    TODO below), so ``tamat.`` is appended whichever ending was accepted.

    NOTE(review): the prompt grows without bound; very long stories may exceed
    the model's context window -- confirm against the model config.
    """
    eos_id = gpt_tokenizer.eos_token_id
    prompt = _NEW_STORY_PROMPT
    judul = ""
    isi = ""
    end_part = ""
    # Becomes False once the complete body has passed the quote-balance check,
    # so later ending regenerations do not re-check it.
    isi_not_checked = True
    yield "..."
    while True:
        # Extend the prompt two tokens at a time until the model emits EOS
        # after the ending segment has started.
        while True:
            gpt_input = gpt_tokenizer(prompt, return_tensors="pt")
            gpt_out = kancilgpt.generate(
                **gpt_input,
                do_sample=True,
                max_new_tokens=2,
                pad_token_id=eos_id,
                eos_token_id=eos_id,
            )[0]
            hit_eos = bool(gpt_out[-1] == eos_id)
            result = gpt_tokenizer.decode(gpt_out)
            segments = result.split(" | ")

            title = _parse_title(segments[1]) if len(segments) >= 2 else None
            if title is None:
                # Unpacking this output would raise ValueError; start over.
                print("Invalid output! Restarting ....")
                prompt = _NEW_STORY_PROMPT
                continue
            judul = title

            if len(segments) == 2:
                # Still generating the title.
                if "." in judul:
                    # Reject before yielding so an invalid title never reaches the UI.
                    print("Invalid judul!")
                    prompt = _NEW_STORY_PROMPT
                    continue
                yield judul + "..."
                isi = ""
                end_part = ""
                if hit_eos:
                    # EOS before the body began: retry from the same prompt.
                    continue
                prompt = result
                continue

            _, _, isi, *rest = segments
            end_part = "".join(rest)
            yield _render_story(judul, isi, "...")

            if len(segments) == 3:
                if hit_eos:
                    # EOS before the ending began: retry from the same prompt.
                    continue
            elif isi_not_checked:
                # The body is complete once the ending segment appears; an odd
                # number of quotes means a quotation was left open.
                quote_at = _last_unmatched_quote(isi)
                if quote_at is not None:
                    print("Invalid isi!")
                    trimmed_isi = isi[:quote_at].rstrip()
                    prompt = f"{_NEW_STORY_PROMPT} {judul} | {trimmed_isi}"
                    continue
                isi_not_checked = False

            if hit_eos:
                break
            prompt = result

        if end_part.startswith(_VALID_ENDINGS):
            break
        print("Invalid ending! Regenerating ....")
        prompt = f"{_NEW_STORY_PROMPT} {judul} | {isi} |"

    total_isi = isi
    print("We skip the rest of the part for debug.")
    # TODO: Solve this.
    # Disabled continuation stage, kept for reference: it would keep prompting
    # the model with "<s> pertengahan cerita | ..." segments built from the tail
    # of the current body until the ending starts with "tamat".
    # ellipsis = "..."
    # while not end_part.startswith("tamat"):
    #     yield judul + "\n" + ("-" * len(judul)) + "\n" + total_isi + f" {ellipsis}"
    #     ellipsis += "."
    #     i = 0
    #     in_quote = False
    #     end_sentence = False
    #     limit = 1750
    #     while i < len(isi) and not (end_sentence and (not in_quote) and isi[i] == " " and (len(isi) - i) < limit):
    #         if isi[i] == "\"":
    #             in_quote = not in_quote
    #         if end_sentence:
    #             end_sentence = isi[i] not in "abcdefghijklmnopqrstuvwxyz"
    #         else:
    #             end_sentence = isi[i] in ".?!"
    #         i += 1
    #     # i == len(isi) or end_sentence or (not in_quote) or isi[i] == " "
    #     while i < len(isi) and not (isi[i] in "abcdefghijklmnopqrstuvwxyz\""):
    #         i += 1
    #     # i == len(isi) or isi[i] in "abcdefghijklmnopqrstuvwxyz\""
    #     if i == len(isi):
    #         raise ValueError("What???")
    #     next_isi = isi[i:]
    #     stop = False
    #     while not stop:
    #         gpt_input = gpt_tokenizer(f'<s> pertengahan cerita | judul: {judul} | {next_isi}', return_tensors='pt')
    #         gpt_out = kancilgpt.generate(**gpt_input, do_sample=True, max_length=512, pad_token_id=gpt_tokenizer.eos_token_id)
    #         result = gpt_tokenizer.decode(gpt_out[0])
    #         _, judul_prompt, isi, *end_part = result.split(" | ")
    #         end_part = "".join(end_part)
    #         _, *judul_words = judul_prompt.split()
    #         judul = " ".join(judul_words)
    #         if isi[len(next_isi) + 1:].strip() != "":
    #             print(isi[len(next_isi) + 1:])
    #         if "</s>" in isi or "|" in isi or (not any(end_part.startswith(x) for x in ["bersambung", "tamat"])):
    #             print("Invalid output! Regenerating ....")
    #             continue
    #         quote_count = 0
    #         for c in isi:
    #             if c == "\"":
    #                 quote_count += 1
    #         if quote_count % 2 != 0:
    #             print("Invalid output! Regenerating ....")
    #             continue
    #         stop = True
    #         total_isi += " " + isi[len(next_isi) + 1:]
    #     ellipsis = "..."
    yield _render_story(judul, total_isi, "\n\ntamat.")
# Output widget: a 7-line textbox labelled "cerita", updated with each value
# yielded by generate_story.
story_output = gr.Textbox(label="cerita", lines=7)
# generate_story takes no arguments, so the interface has no input widgets.
demo = gr.Interface(
    fn=generate_story,
    inputs=None,
    outputs=[story_output],
)
demo.launch()