File size: 2,288 Bytes
d0f499d
 
 
 
 
 
cb54c90
d0f499d
 
402a061
 
 
 
 
 
d0f499d
cb54c90
d0f499d
 
 
 
cb54c90
d0f499d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b03c785
 
 
 
 
 
 
 
 
 
 
d0f499d
b03c785
 
 
 
 
 
 
 
d0f499d
 
b03c785
d0f499d
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import random
import torch
from rapgpt.config import Config
from rapgpt.encoder import Encoder
from rapgpt.model import HFHubTransformerModel
import gradio as gr
from huggingface_hub import hf_hub_download

if __name__ == "__main__":
    artists_tokens = hf_hub_download(
        repo_id="hugojarkoff/rapgpt", filename="artists_tokens.txt", repo_type="model"
    )
    config_file = hf_hub_download(
        repo_id="hugojarkoff/rapgpt", filename="config.toml", repo_type="model"
    )

    with open(artists_tokens, "r") as f:
        artists_tokens = {
            line.split(":")[0]: int(line.split(":")[1].rstrip("\n")) for line in f
        }

    config = Config.load_from_toml(config_file)
    encoder = Encoder(config=config)
    model = HFHubTransformerModel.from_pretrained("hugojarkoff/rapgpt")

    def predict(
        lyrics_prompt: str,
        new_tokens: int,
        artist_token: int,
        seed: int = 42,
    ):
        # Set Seed
        random.seed(seed)
        torch.manual_seed(seed)

        # Predict
        sample_input = encoder.encode_data(lyrics_prompt)
        sample_input = torch.tensor(sample_input).unsqueeze(0)
        output = model.generate(
            x=sample_input,
            new_tokens=new_tokens,
            artist_token=artist_token,
        )
        return encoder.decode_data(output[0].tolist())

    gradio_app = gr.Interface(
        predict,
        inputs=[
            gr.Textbox(
                value="ekip",
                label="Lyrics prompt",
                info="rapGPT will continue this prompt",
            ),
            gr.Number(
                value=100,
                maximum=100,
                label="New tokens to generate",
                info="Number of new tokens to generate (limited to 100)",
            ),
            gr.Dropdown(
                value="freeze corleone",
                choices=artists_tokens.keys(),
                type="index",
                label="Artist",
                info="Which artist style to generate",
            ),
            gr.Number(
                value=42, label="Random seed", info="Change for different results"
            ),
        ],
        outputs=[gr.TextArea(label="Generated Lyrics")],
        title="rapGPT",
    )

    gradio_app.launch()