Spaces:
Runtime error
Runtime error
File size: 3,005 Bytes
10fb5a7 adafa1e 10fb5a7 dca7fa0 10fb5a7 59fcbd6 d0d15d6 10fb5a7 67eeb75 10fb5a7 cfc7981 10fb5a7 67eeb75 59fcbd6 67eeb75 59fcbd6 d31aa4a 59fcbd6 d0d15d6 59fcbd6 10fb5a7 59fcbd6 d31aa4a 59fcbd6 10fb5a7 d31aa4a 59fcbd6 10fb5a7 59fcbd6 10fb5a7 59fcbd6 10fb5a7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import gradio as gr
import nltk
import string
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GenerationConfig, set_seed
import random
# Fetch NLTK's Punkt sentence-tokenizer data (clean_paragraph splits text with
# nltk.tokenize.sent_tokenize).
nltk.download('punkt')
# Maximum number of new tokens per model reply; generate_response also reserves
# this many positions out of the 1024-token input budget for the reply.
response_length = 200
# NOTE(review): sentence_detector is not referenced anywhere else in this file;
# clean_paragraph calls nltk.tokenize.sent_tokenize instead. Candidate for removal.
sentence_detector = nltk.data.load('tokenizers/punkt/english.pickle')
# Tokenizer is taken from the base "gpt2-medium" checkpoint -- presumably the
# fine-tuned model below shares its vocabulary; confirm.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
# When a truncating encode exceeds max_length, drop tokens from the right (end)
# of the text.
tokenizer.truncation_side = 'right'
# GPT-2 language-model checkpoint used for story generation, loaded by Hub id.
model = GPT2LMHeadModel.from_pretrained('coffeeee/nsfw-story-generator')
# Start from gpt2-medium's generation defaults, then cap each reply at
# response_length new tokens.
generation_config = GenerationConfig.from_pretrained('gpt2-medium')
generation_config.max_new_tokens = response_length
# Reuse the EOS id as the padding id (GPT-2 checkpoints typically ship without a
# dedicated pad token).
generation_config.pad_token_id = generation_config.eos_token_id
def generate_response(outputs, new_prompt):
    """Continue the story with the user's prompt and return updated UI state.

    The most recent story entries plus ``new_prompt`` are fed to the model as
    context. The user's prompt together with the model's continuation is cleaned
    up and appended to ``outputs`` as one new entry.

    Args:
        outputs: List of story entries held in the ``user_outputs`` state. It is
            extended in place with the new entry.
        new_prompt: Text typed by the user to steer the continuation.

    Returns:
        A dict mapping the ``user_outputs`` state to the updated list and the
        ``story`` textbox to the full story (all entries joined by newlines).
    """
    context_window = 1024  # GPT-2 maximum sequence length used by this app
    max_context = context_window - response_length
    max_entries = int(context_window / response_length + 1)

    # Use the *latest* entries as context so the model continues from where
    # the story currently is, not from its opening paragraphs.
    recent = outputs[-max_entries:] if outputs else []
    story_so_far = "\n".join(recent)

    set_seed(random.randint(0, 4000000000))

    context = story_so_far + "\n" + new_prompt if story_so_far else new_prompt
    if not context:
        # generate() cannot start from an empty sequence.
        context = tokenizer.bos_token

    input_ids = tokenizer.encode(context, return_tensors='pt')
    # Keep the *end* of an over-long context: it holds the prompt the user
    # just typed, which right-side truncation would throw away.
    input_ids = input_ids[:, -max_context:]

    output = model.generate(input_ids, do_sample=True, generation_config=generation_config)

    # Decode only the newly generated tokens. Slicing the decoded full text by
    # character count breaks when the context was truncated or the decode does
    # not round-trip exactly. Special tokens (e.g. <|endoftext|>) are dropped so
    # they never appear in the story.
    continuation = tokenizer.decode(output[0, input_ids.shape[-1]:], skip_special_tokens=True)
    response = clean_paragraph(new_prompt + continuation)

    outputs.append(response)
    return {
        user_outputs: outputs,
        story: "\n".join(outputs),
    }
def undo(outputs):
    """Drop the most recent story entry.

    Args:
        outputs: Current list of story entries; may be empty or None.

    Returns:
        A dict updating the ``user_outputs`` state with the shortened list and
        the ``story`` textbox with the remaining entries joined by newlines,
        or ``None`` when no entries are left.
    """
    if not outputs:
        remaining = []
    else:
        remaining = outputs[:-1]

    if remaining:
        text = "\n".join(remaining)
    else:
        text = None

    return {user_outputs: remaining, story: text}
def clean_paragraph(entry):
    """Tidy a generated story entry for display.

    If the last paragraph ends in an unfinished sentence (its final non-space
    character is not punctuation), that trailing sentence is removed. The
    first letter of the text is then capitalised.

    Args:
        entry: Raw text whose paragraphs are separated by newlines.

    Returns:
        The cleaned text.
    """
    paragraphs = entry.split('\n')

    # Only the final paragraph can be cut off mid-sentence (presumably because
    # generation stopped at the max_new_tokens limit), so only it is examined.
    sentences = nltk.tokenize.sent_tokenize(paragraphs[-1], language='english')

    # Guard the empty case (e.g. text ending in "\n") and look at the *last*
    # character of the *last* sentence to decide whether it is complete.
    if sentences and sentences[-1].rstrip()[-1:] not in string.punctuation:
        paragraphs[-1] = " ".join(sentences[:-1])

    return capitalize_first_char("\n".join(paragraphs))
def reset():
    """Start over: clear the stored entries and blank the story textbox.

    Returns:
        A dict setting the ``user_outputs`` state to an empty list and the
        ``story`` textbox to ``None``.
    """
    cleared = {}
    cleared[user_outputs] = []
    cleared[story] = None
    return cleared
def capitalize_first_char(entry):
    """Upper-case the first alphabetic character of ``entry``.

    Leading non-letters (spaces, quotes, digits, punctuation) are left as they
    are. A string containing no letters is returned unchanged.

    Args:
        entry: Text to adjust.

    Returns:
        ``entry`` with its first letter upper-cased.
    """
    for index, char in enumerate(entry):
        if char.isalpha():
            return entry[:index] + char.upper() + entry[index + 1:]
    return entry
with gr.Blocks() as demo:
    # Read-only view of the whole story. The copy button is requested through
    # the constructor: Textbox.style() was deprecated in Gradio 3.x and removed
    # in Gradio 4, where calling it raises AttributeError at start-up.
    story = gr.Textbox(interactive=False, lines=20, show_copy_button=True)
    # Story entries kept in gr.State; every handler receives and returns it.
    user_outputs = gr.State([])
    prompt = gr.Textbox(placeholder="Continue the story here!", lines=3, max_lines=3)
    with gr.Row():
        gen_button = gr.Button('Generate')
        undo_button = gr.Button("Undo")
        res_button = gr.Button("Reset")
    # Pressing Enter in the prompt box and clicking Generate do the same thing.
    prompt.submit(generate_response, [user_outputs, prompt], [user_outputs, story], scroll_to_output=True)
    gen_button.click(generate_response, [user_outputs, prompt], [user_outputs, story], scroll_to_output=True)
    undo_button.click(undo, user_outputs, [user_outputs, story], scroll_to_output=True)
    res_button.click(reset, [], [user_outputs, story], scroll_to_output=True)
demo.launch(inbrowser=True)
|