Spaces:
Runtime error
Runtime error
File size: 2,410 Bytes
cd20c72 2eedd09 cd20c72 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
class PoetryGenerator:
    """Generate poems with a causal language model loaded via transformers.

    The prompt fed to the model is ``bos_token + genre + sep_token + start_words``.
    The decoded output is cut after the first ``sep_token`` and split on
    ``eos_token`` into verses.
    """

    # Genres accepted by ``generate``; the chosen string is inserted verbatim
    # into the prompt.
    GENRES = (
        'bốn chữ', 'năm chữ', 'sáu chữ', 'bảy chữ', 'tám chữ',
        'lục bát', 'song thất lục bát',
    )

    def __init__(
        self,
        model_name_or_path: str = './checkpoint',
        max_length: int = 70,
    ):
        """Load the tokenizer and model.

        Args:
            model_name_or_path: path or identifier passed to ``from_pretrained``.
            max_length: forwarded as ``max_length`` to ``model.generate``
                (in transformers this length includes the prompt tokens).
        """
        self.max_length = max_length
        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path)

    def generate(
        self,
        start_words: str,
        genre: str,
        n_poems: int = 1,
        collate: bool = False,
        n_verses: int = 4,
    ):
        """Generate ``n_poems`` poems that begin with ``start_words``.

        Args:
            start_words: text the poem should start with.
            genre: one of ``GENRES``.
            n_poems: number of independent sampling runs. Float values such as
                ``2.0`` (what UI sliders may deliver) are truncated to ``int``.
            collate: if True, return one string where each poem is headed by
                ``BÀI <i>`` and poems are separated by a blank line.
            n_verses: keep at most this many verses per poem (default 4).

        Returns:
            A list of poem strings (verses joined by ``\\n``), or a single
            string when ``collate`` is True.

        Raises:
            ValueError: if ``genre`` is not one of ``GENRES``.
        """
        # Explicit check instead of ``assert``: asserts are removed under -O.
        if genre not in self.GENRES:
            raise ValueError(f"Expect genre in {self.GENRES}. Got {genre}.")
        # range() rejects floats, which a UI slider may hand us.
        n_poems = int(n_poems)

        tokenized = self.tokenizer(
            self.tokenizer.bos_token
            + genre
            + self.tokenizer.sep_token
            + start_words,
            return_tensors='pt',
        )
        poems = [
            self._decode_poem(self._sample(tokenized), n_verses)
            for _ in range(n_poems)
        ]
        if collate:
            # Single string so a plain text output widget can show every poem.
            return self._collate(poems)
        return poems

    def _sample(self, tokenized):
        """Run one generation pass and return the token ids of its first sequence."""
        return self.model.generate(
            **tokenized,
            do_sample=True,
            max_length=self.max_length,
            top_k=4,
            num_beams=5,
            no_repeat_ngram_size=2,
            num_return_sequences=1,
        )[0]

    def _decode_poem(self, token_ids, n_verses: int = 4) -> str:
        """Decode token ids and return up to ``n_verses`` verses joined by newlines.

        The text before the first ``sep_token`` (the bos + genre prefix) is
        dropped, and only the segment up to the next ``sep_token`` is kept.
        Within it, verses are delimited by ``eos_token``.
        """
        decoded = self.tokenizer.decode(token_ids)
        poem_content = decoded.split(self.tokenizer.sep_token)[1]
        verses = poem_content.split(self.tokenizer.eos_token)[:n_verses]
        return '\n'.join(verses)

    @staticmethod
    def _collate(poems) -> str:
        """Prefix each poem with a ``BÀI <i>`` header and join them with blank lines."""
        return '\n\n'.join(
            f'BÀI {index}\n{poem}' for index, poem in enumerate(poems, start=1)
        )
if __name__ == '__main__':
    # Script entry point: load the model and serve a Gradio text UI.
    generator = PoetryGenerator()
    MAX_POEMS = 5

    def _generate_for_ui(start_words, genre, n_poems):
        """Gradio callback returning all poems collated into one string.

        The slider value is cast to ``int`` because ``generate`` passes
        ``n_poems`` to ``range``, which rejects floats.
        """
        return generator.generate(start_words, genre, int(n_poems), collate=True)

    gr.Interface(
        _generate_for_ui,
        inputs=[
            gr.Textbox(label="Start words"),
            # Pass choices as a list and preselect a genre so that submitting
            # without touching the dropdown does not send None.
            gr.Dropdown(
                choices=list(PoetryGenerator.GENRES),
                value=PoetryGenerator.GENRES[0],
                label="Genre",
            ),
            gr.Slider(1, MAX_POEMS, step=1, label="Number of poems"),
        ],
        outputs='text',
        examples=[
            ['thân em', 'lục bát', 2],
            ['chiều chiều', 'bảy chữ', 1],
        ],
    ).launch()
|