import gradio as gr
from transformers import GPT2LMHeadModel
from indobenchmark import IndoNLGTokenizer
# Hugging Face Hub repositories holding the pretrained artifacts.
_KANCILGPT_REPO = "abdiharyadi/kancilgpt"
_INDOGPT_TOKENIZER_REPO = "indobenchmark/indogpt"

# The model and tokenizer are created once, when this script is loaded, and
# are shared by every call to generate_story.
kancilgpt = GPT2LMHeadModel.from_pretrained(_KANCILGPT_REPO)
gpt_tokenizer = IndoNLGTokenizer.from_pretrained(_INDOGPT_TOKENIZER_REPO)
# Use the end-of-sequence token as the padding token.
gpt_tokenizer.pad_token = gpt_tokenizer.eos_token
# Markers the model writes after the story body. "tamat" ends the story;
# "bersambung" means another chunk must be generated. Any other ending is
# rejected as invalid output.
_END_MARKERS = ("bersambung", "tamat")
# Characters treated as "lowercase letters" when locating sentence starts.
_LOWERCASE = "abcdefghijklmnopqrstuvwxyz"


def _parse_output(result: str):
    """Split a decoded sample into ``(judul, isi, end_part)``.

    The sample is split on " | ". Field 1 is the title field (its first word,
    the "judul:" label, is dropped), field 2 is the story body, and every
    remaining field is concatenated into ``end_part``.

    Returns:
        The ``(judul, isi, end_part)`` tuple, or None when the sample has
        fewer than three " | "-separated fields.
    """
    fields = result.split(" | ")
    if len(fields) < 3:
        return None
    judul_prompt, isi = fields[1], fields[2]
    end_part = "".join(fields[3:])
    judul = " ".join(judul_prompt.split()[1:])
    return judul, isi, end_part


def _is_valid(judul: str, isi: str, end_part: str, check_judul: bool) -> bool:
    """Return True if a parsed sample is acceptable.

    A sample is rejected if any of the following holds:
    - the body (or, with ``check_judul``, the title) contains the tokenizer's
      unknown-token marker;
    - the body contains a stray "|";
    - the ending does not start with one of ``_END_MARKERS``;
    - the body has an unbalanced number of double quotes.
    """
    unk = gpt_tokenizer.unk_token
    if unk and (unk in isi or (check_judul and unk in judul)):
        return False
    if "|" in isi or not end_part.startswith(_END_MARKERS):
        return False
    # An odd number of double quotes means a quotation was left unclosed.
    return isi.count('"') % 2 == 0


def _sample_until_valid(prompt: str, check_judul: bool, echo_offset=None) -> tuple:
    """Sample kancilgpt on ``prompt`` until a decoded output passes validation.

    Args:
        prompt: Text given to the tokenizer. The same prompt is used for every
            retry.
        check_judul: Also apply the unknown-token check to the title.
        echo_offset: If given, print ``isi[echo_offset:]`` for every parsable
            attempt whose slice is not blank. This is a debug trace of the
            newly generated text.

    Returns:
        ``(judul, isi, end_part)`` from the first valid sample. There is no
        retry limit.
    """
    # The prompt is fixed, so it only needs to be tokenized once.
    gpt_input = gpt_tokenizer(prompt, return_tensors='pt')
    while True:
        gpt_out = kancilgpt.generate(
            **gpt_input,
            do_sample=True,
            max_length=512,
            pad_token_id=gpt_tokenizer.eos_token_id,
        )
        parsed = _parse_output(gpt_tokenizer.decode(gpt_out[0]))
        if parsed is not None:
            judul, isi, end_part = parsed
            if echo_offset is not None and isi[echo_offset:].strip() != "":
                print(isi[echo_offset:])
            if _is_valid(judul, isi, end_part, check_judul):
                return judul, isi, end_part
        print("Invalid output! Regenerating ....")


def _find_continuation_start(isi: str, limit: int = 1750) -> int:
    """Return the index in ``isi`` where the context to carry over begins.

    The scan first finds the earliest space that meets all of these:
    - it comes after a sentence end, i.e. ".", "?" or "!" optionally followed
      by non-lowercase characters such as a closing quote;
    - it is outside a double-quoted passage;
    - fewer than ``limit`` characters run from it to the end of ``isi``.

    It then skips forward to the next lowercase letter or double quote.
    Returns ``len(isi)`` if no such position exists.
    """
    n = len(isi)
    i = 0
    in_quote = False
    end_sentence = False
    while i < n:
        ch = isi[i]
        if end_sentence and not in_quote and ch == " " and n - i < limit:
            break
        if ch == '"':
            in_quote = not in_quote
        if end_sentence:
            # Punctuation, quotes and spaces after ".?!" keep the sentence
            # "ended"; a lowercase letter means it was not a real boundary.
            end_sentence = ch not in _LOWERCASE
        else:
            end_sentence = ch in ".?!"
        i += 1
    while i < n and isi[i] not in _LOWERCASE + '"':
        i += 1
    return i


def generate_story(judul: str = ""):
    """Stream a story generated by kancilgpt, chunk by chunk.

    Yields ``(judul, cerita)`` pairs:
    - first a ``("...", "...")`` placeholder;
    - then the growing story with " ..." appended, while the model marks the
      story as unfinished;
    - finally the whole story followed by a blank line and "tamat.".

    The title produced with the opening chunk is kept for the whole story.

    Args:
        judul: Unused; the title comes from the model output. It defaults to
            "" so the function can be called with no arguments, as an
            interface without input components does.

    Raises:
        ValueError: If a generated chunk has no position from which the
            next chunk can be continued.
    """
    yield "...", "..."
    judul, isi, end_part = _sample_until_valid(' awal cerita | judul:', check_judul=True)
    yield judul, isi + " ..."
    total_isi = isi
    while not end_part.startswith("tamat"):
        yield judul, total_isi + " ..."
        start = _find_continuation_start(isi)
        if start == len(isi):
            raise ValueError("What???")
        # Re-prompt with the tail of the last chunk. The decoded body is
        # assumed to start with next_isi plus one space, so only the text
        # after that is new.
        next_isi = isi[start:]
        _, isi, end_part = _sample_until_valid(
            f' pertengahan cerita | judul: {judul} | {next_isi}',
            check_judul=False,
            echo_offset=len(next_isi) + 1,
        )
        total_isi += " " + isi[len(next_isi) + 1:]
    yield judul, total_isi + "\n\ntamat."
# Output-only interface: there are no input components. The two text boxes
# are refreshed with every (judul, cerita) pair that generate_story yields.
demo = gr.Interface(
    fn=generate_story,
    inputs=None,
    outputs=[
        gr.Textbox(label="judul", lines=1),
        gr.Textbox(label="cerita", lines=7),
    ],
)
# generate_story is a generator that streams partial stories. Older Gradio
# releases only run generator handlers when the queue is enabled; calling
# queue() is harmless where queuing is already the default.
demo.queue()
demo.launch()