# rapGPT / app.py
import argparse
import random
import torch
from rapgpt.config import Config
from rapgpt.encoder import Encoder
from rapgpt.model import HFHubTransformerModel
import gradio as gr
from huggingface_hub import hf_hub_download
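
# Gradio demo for rapGPT: download the artist-token mapping and model config from the
# Hugging Face Hub, load the pretrained model, and expose a lyric-generation UI.
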
if __name__ == "__main__":
    # Fetch the artist-token mapping and model config from the rapgpt model repo on the Hub
    artists_tokens = hf_hub_download(repo_id="hugojarkoff/rapgpt", filename="artists_tokens.txt", repo_type="model")
    config_file = hf_hub_download(repo_id="hugojarkoff/rapgpt", filename="config.toml", repo_type="model")

    # Each line of artists_tokens.txt maps an artist name to an integer token ("name:token")
    with open(artists_tokens, "r") as f:
        artists_tokens = {
            line.split(":")[0]: int(line.split(":")[1].rstrip("\n")) for line in f
        }

    # Load the encoder config and the pretrained model weights
    config = Config.load_from_toml(config_file)
    encoder = Encoder(config=config)
    model = HFHubTransformerModel.from_pretrained("hugojarkoff/rapgpt")
    def predict(
        lyrics_prompt: str,
        new_tokens: int,
        artist_token: int,
        seed: int = 42,
    ):
        # Set Seed
        random.seed(seed)
        torch.manual_seed(seed)

        # Predict
        sample_input = encoder.encode_data(lyrics_prompt)
        sample_input = torch.tensor(sample_input).unsqueeze(0)
        output = model.generate(
            x=sample_input,
            new_tokens=new_tokens,
            artist_token=artist_token,
        )
        return encoder.decode_data(output[0].tolist())
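
    # Example (hypothetical values): predict("ekip", new_tokens=100, artist_token=0)
    # would return the decoded continuation for the artist at index 0 of the dropdown choices.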
    # Gradio UI: the Dropdown uses type="index", so the selected artist's position
    # in the choices list is what gets passed to predict as artist_token
    gradio_app = gr.Interface(
        predict,
        inputs=[
            gr.Textbox(value="ekip"),
            gr.Number(value=100),
            gr.Dropdown(
                value="freeze corleone", choices=artists_tokens.keys(), type="index"
            ),
            gr.Number(value=42),
        ],
        outputs=[gr.TextArea()],
        title="rapGPT",
    )
    gradio_app.launch()
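
# Note: gradio_app.launch() starts a local Gradio server when the script is run directly
# (e.g. `python app.py`); on a Hugging Face Space this file serves as the app entry point.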