from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr


class PoetryGenerator:
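    """Vietnamese poetry generator backed by a causal language model checkpoint."""

    # Supported verse forms; each string is also the genre label inserted into the prompt.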
    GENRES = ('bốn chữ', 'năm chữ', 'sáu chữ', 'bảy chữ', 'tám chữ', 'lục bát', 'song thất lục bát')

    def __init__(
        self,
        model_name_or_path: str = './checkpoint',
        max_length: int = 70
    ):
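        """Load the tokenizer and model from `model_name_or_path`; generation is capped at `max_length` tokens."""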
        self.max_length = max_length

        self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        self.model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
    
    def generate(self, start_words: str, genre: str, n_poems: int = 1, collate: bool = False):
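        """Generate `n_poems` poems in `genre`, each continuing from `start_words`.

        Returns a list of poem strings, or a single newline-joined string when
        `collate` is True (handy for a single gradio text output). For example,
        assuming a valid checkpoint: generator.generate('thân em', 'lục bát', n_poems=2).
        """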
        assert genre in self.GENRES, f"Expect genre in {self.GENRES}. Got {genre}."
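        # Prompt layout: <bos> + genre + <sep> + start words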
        tokenized = self.tokenizer(
            self.tokenizer.bos_token +
            genre + 
            self.tokenizer.sep_token +
            start_words,
            return_tensors='pt'
        )

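        # One generate() call per poem: beam sampling (do_sample with num_beams) plus top-k
        # keeps the samples varied, and no_repeat_ngram_size=2 suppresses repeated bigrams.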
        generated = [
            self.model.generate(
                **tokenized,
                do_sample=True,
                max_length=self.max_length,
                top_k=4,
                num_beams=5,
                no_repeat_ngram_size=2,
                num_return_sequences=1
            )[0]
            for _ in range(n_poems)
        ]

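        # The text after the <sep> token is the poem body; verses are separated by the
        # eos token, and only the first four verses are kept.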
        poems = []
        for token_ids in generated:
            decoded = self.tokenizer.decode(token_ids)
            poem_content = decoded.split(self.tokenizer.sep_token)[1]
            poem_verses = poem_content.split(self.tokenizer.eos_token)[:4]
            poem_content = '\n'.join(poem_verses)

            poems.append(poem_content)

        # Collate the poems into one labelled string so gradio can display them in a single text output
        if collate:
            for i in range(n_poems):
                poems[i] = f'BÀI {i + 1}\n' + poems[i]
            
            return '\n\n'.join(poems)

        return poems


if __name__ == '__main__':
    generator = PoetryGenerator()
    MAX_POEMS = 5

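    # Gradio demo: start-word textbox, genre dropdown, and a slider for the number of poems.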
    gr.Interface(
        # int() guards against the slider value arriving as a float; collate=True returns one string
        lambda start_words, genre, n_poems: generator.generate(start_words, genre, int(n_poems), collate=True),
        inputs=[
            gr.Textbox(label="Start words"),
            gr.Dropdown(choices=PoetryGenerator.GENRES, label="Genre"),
            gr.Slider(1, MAX_POEMS, step=1, label="Number of poems")
        ],
        outputs='text',
        examples=[
            ['thân em', 'lục bát', 2],
            ['chiều chiều', 'bảy chữ', 1]
        ]
    ).launch()