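# Gradio front end for an interactive story generator: each turn, the user's prompt plus the
# story so far is fed to a GPT-2 based model and the continuation is appended to the story.
# Undo removes the last generated turn and Reset clears everything.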
import gradio as gr

import nltk
import string
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GenerationConfig, set_seed
import random

# Punkt data is required by nltk.tokenize.sent_tokenize in clean_paragraph().
nltk.download('punkt')

# Number of new tokens requested from the model on each generation step.
response_length = 200

# The story model shares GPT-2's vocabulary, so the stock gpt2-medium tokenizer is used.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
# Truncate from the right so the beginning of the story is kept when the input is too long.
tokenizer.truncation_side = 'right'

model = GPT2LMHeadModel.from_pretrained('coffeeee/nsfw-story-generator')

# Reuse gpt2-medium's generation defaults, capped at response_length new tokens per turn.
generation_config = GenerationConfig.from_pretrained('gpt2-medium')
generation_config.max_new_tokens = response_length
generation_config.pad_token_id = generation_config.eos_token_id


def generate_response(outputs, new_prompt):
    # Rebuild the story context from earlier turns, keeping only about as many
    # paragraphs as can fit in GPT-2's 1024-token context window.
    story_so_far = "\n".join(outputs[:int(1024 / response_length + 1)]) if outputs else ""

    # Re-seed so repeated generations from the same prompt can differ.
    set_seed(random.randint(0, 4000000000))

    # Tokenize the context plus the new prompt, reserving room for response_length new tokens.
    inputs = tokenizer.encode(story_so_far + "\n" + new_prompt if story_so_far else new_prompt,
                              return_tensors='pt', truncation=True,
                              max_length=1024 - response_length)

    output = model.generate(inputs, do_sample=True, generation_config=generation_config)

    # Decode, strip the story that was already shown from the front, and tidy the new text.
    response = clean_paragraph(tokenizer.batch_decode(output)[0][(len(story_so_far) + 1 if story_so_far else 0):])
    outputs.append(response)
    return {
        user_outputs: outputs,
        story: (story_so_far + "\n" if story_so_far else "") + response
    }

def undo(outputs):
    # Drop the most recent turn and rebuild the displayed story from what remains.
    outputs = outputs[:-1] if outputs else []
    return {
        user_outputs: outputs,
        story: "\n".join(outputs) if outputs else None
    }

def clean_paragraph(entry):
    # Split generated text into paragraphs and trim any sentence fragment left at the end.
    paragraphs = entry.split('\n')

    for i in range(len(paragraphs)):
        split_sentences = nltk.tokenize.sent_tokenize(paragraphs[i], language='english')
        # If the last paragraph was cut off mid-sentence, drop the incomplete trailing sentence.
        if i == len(paragraphs) - 1 and split_sentences and split_sentences[-1][-1] not in string.punctuation:
            paragraphs[i] = " ".join(split_sentences[:-1])

    return capitalize_first_char("\n".join(paragraphs))

def reset():
    return {
        user_outputs: [],
        story: None
    }

def capitalize_first_char(entry):
    # Upper-case the first alphabetic character so the response starts like a sentence.
    for i in range(len(entry)):
        if entry[i].isalpha():
            return entry[:i] + entry[i].upper() + entry[i + 1:]
    return entry

with gr.Blocks() as demo:
    # Read-only textbox where the accumulated story is displayed.
    story = gr.Textbox(interactive=False, lines=20)
    story.style(show_copy_button=True)

    # Per-session list of generated turns; used to rebuild context and to support Undo.
    user_outputs = gr.State([])

    prompt = gr.Textbox(placeholder="Continue the story here!", lines=3, max_lines=3)

    with gr.Row():
        gen_button = gr.Button('Generate')
        undo_button = gr.Button("Undo")
        res_button = gr.Button("Reset")

    # Pressing Enter in the prompt box and clicking Generate both trigger a new turn.
    prompt.submit(generate_response, [user_outputs, prompt], [user_outputs, story], scroll_to_output=True)
    gen_button.click(generate_response, [user_outputs, prompt], [user_outputs, story], scroll_to_output=True)
    undo_button.click(undo, user_outputs, [user_outputs, story], scroll_to_output=True)
    res_button.click(reset, [], [user_outputs, story], scroll_to_output=True)

demo.launch(inbrowser=True)