File size: 1,982 Bytes
c61d2bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
265eae6
c61d2bb
19d38a8
d6c38d2
c61d2bb
 
 
fa250ff
 
 
 
 
c61d2bb
 
 
 
 
265eae6
 
c61d2bb
 
 
 
5387232
c61d2bb
5387232
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from transformers import GPT2Tokenizer, TFGPT2LMHeadModel, pipeline
import gradio as gr


# Fine-tuned weights come from the "egosumkira/gpt2-fantasy" checkpoint, while
# the tokenizer is loaded from the base "gpt2" checkpoint.
model = TFGPT2LMHeadModel.from_pretrained("egosumkira/gpt2-fantasy")
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Text-generation pipeline shared by every request handled by `generate`.
# NOTE(review): device=0 presumably selects the first GPU - confirm that the
# deployment target actually has one.
story = pipeline(task="text-generation", model=model, tokenizer=tokenizer, device=0)


def generate(tags_text, temp, n_beams, max_len):
    """Generate a story conditioned on comma-separated keywords.

    The keywords are joined into the conditioning prefix ``~^kw1^kw2~@``,
    which is fed to the module-level ``story`` pipeline. Everything up to and
    including the first ``@`` of the generated text is stripped from the
    result.

    Args:
        tags_text: Keywords separated by commas; surrounding whitespace and
            empty entries are ignored. ``None`` is treated as no keywords.
        temp: Sampling temperature (converted with ``float``).
        n_beams: Number of beams for beam search (converted with ``int``).
        max_len: Maximum length passed to the pipeline as ``max_length``
            (converted with ``int``).

    Returns:
        The generated text following the first ``@``; the whole generated
        text if it contains no ``@``.
    """
    # Split on bare commas and trim each piece so that "a,b", "a, b" and
    # "a , b" all produce the same tags; blank pieces are dropped.
    pieces = (piece.strip() for piece in (tags_text or "").split(","))
    tags = [piece for piece in pieces if piece]
    prefix = f"~^{'^'.join(tags)}~@"

    # The UI slider allows 0, but transformers requires a strictly positive
    # temperature when sampling, so clamp to a small positive floor.
    temperature = max(float(temp), 1e-2)

    outputs = story(
        prefix,
        temperature=temperature,
        repetition_penalty=7.0,
        num_beams=int(n_beams),
        max_length=int(max_len),
    )
    g_text = outputs[0]["generated_text"]

    # Drop the conditioning prefix: keep only what follows the first "@".
    _, marker, remainder = g_text.partition("@")
    return remainder if marker else g_text


# Heading shown at the top of the demo page.
title = "GPT-2 fantasy story generator"

# Explanatory text shown under the title. The literal is split at sentence
# boundaries; adjacent string literals are concatenated by the parser, so the
# resulting text is one string with three "\n" separators.
description = (
    'This is fine-tuned GPT-2 model for "conditional" generation. '
    'The model was trained on a custom-made dataset of IMDB plots & keywords.\n'
    'There are two main parameters to generate output:\n'
    '1. Temperature. '
    'If the temperature is low, the probabilities to sample other but the class with the highest log probability will be small, and the model will probably output the most correct text, but rather boring, with small variation. '
    'If the temperature is high, the model can output, with rather high probability, other words than those with the highest probability. '
    'The generated text will be more diverse, but there is a higher possibility of grammar mistakes and generation of nonsense.\n'
    '2. Number of beams in Beam Search. '
    'Beam search is a clever way to find the best sentences in a algorithm that writes words. '
    'It looks at a few possible sentences at a time, and keeps track of the most promising ones.'
)

# Web UI: four inputs map positionally onto
# generate(tags_text, temp, n_beams, max_len); the single output textbox
# receives the generated story.
# Sliders use the top-level gr.Slider component with `value=` (the
# gr.inputs namespace and its `default=` argument are deprecated), matching
# the gr.Textbox / gr.Number components already used here.
iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label="Keywords (comma separated)"),
        gr.Slider(minimum=0, maximum=2, value=1.0, step=0.05, label="Temperature"),
        gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Number of beams"),
        gr.Number(label="Max lenght", value=128),
    ],
    outputs=gr.Textbox(label="Output"),
    title=title,
    description=description,
)

# Enable request queuing before starting the server.
iface.queue()
iface.launch()