coffeeee commited on
Commit
36d748c
β€’
1 Parent(s): 781dabe

added proj files

Browse files
Files changed (2) hide show
  1. README.md +5 -5
  2. app.py +100 -0
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Nsfw C0ffees Erotic Story Generator2
3
- emoji: 🐒
4
- colorFrom: blue
5
- colorTo: gray
6
  sdk: gradio
7
- sdk_version: 3.29.0
8
  app_file: app.py
9
  pinned: false
10
  ---
 
1
  ---
2
+ title: "[NSFW] C0ffee's Erotic Story Generator 2"
3
+ emoji: πŸ‘
4
+ colorFrom: gray
5
+ colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 3.27.0
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ import nltk
4
+ import string
5
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer, GenerationConfig, set_seed
6
+ import random
7
+
8
+ nltk.download('punkt')
9
+
10
+ response_length = 200
11
+
12
+ sentence_detector = nltk.data.load('tokenizers/punkt/english.pickle')
13
+
14
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2-medium")
15
+ tokenizer.truncation_side = 'right'
16
+
17
+ # model = GPT2LMHeadModel.from_pretrained('checkpoint-50000')
18
+ model = GPT2LMHeadModel.from_pretrained('coffeeee/nsfw-story-generator2')
19
+ generation_config = GenerationConfig.from_pretrained('gpt2-medium')
20
+ generation_config.max_new_tokens = response_length
21
+ generation_config.pad_token_id = generation_config.eos_token_id
22
+ def generate_response(outputs, new_prompt):
23
+
24
+ story_so_far = "\n".join(outputs[:int(1024 / response_length + 1)]) if outputs else ""
25
+
26
+ set_seed(random.randint(0, 4000000000))
27
+ inputs = tokenizer.encode(story_so_far + "\n" + new_prompt if story_so_far else new_prompt,
28
+ return_tensors='pt', truncation=True,
29
+ max_length=1024 - response_length)
30
+
31
+ output = model.generate(inputs, do_sample=True, generation_config=generation_config)
32
+
33
+ response = clean_paragraph(tokenizer.batch_decode(output)[0][(len(story_so_far) + 1 if story_so_far else 0):])
34
+ outputs.append(response)
35
+ return {
36
+ user_outputs: outputs,
37
+ story: (story_so_far + "\n" if story_so_far else "") + response,
38
+ prompt: None
39
+ }
40
+
41
+ def undo(outputs):
42
+
43
+ outputs = outputs[:-1] if outputs else []
44
+ return {
45
+ user_outputs: outputs,
46
+ story: "\n".join(outputs) if outputs else None
47
+ }
48
+
49
+ def clean_paragraph(entry):
50
+ paragraphs = entry.split('\n')
51
+
52
+ for i in range(len(paragraphs)):
53
+ split_sentences = nltk.tokenize.sent_tokenize(paragraphs[i], language='english')
54
+ if i == len(paragraphs) - 1 and split_sentences[:1][-1] not in string.punctuation:
55
+ paragraphs[i] = " ".join(split_sentences[:-1])
56
+
57
+ return capitalize_first_char("\n".join(paragraphs))
58
+
59
+ def reset():
60
+ return {
61
+ user_outputs: [],
62
+ story: None
63
+ }
64
+
65
+ def capitalize_first_char(entry):
66
+ for i in range(len(entry)):
67
+ if entry[i].isalpha():
68
+ return entry[:i] + entry[i].upper() + entry[i + 1:]
69
+ return entry
70
+
71
+ with gr.Blocks(theme=gr.themes.Default(text_size='lg', font=[gr.themes.GoogleFont("Bitter"), "Arial", "sans-serif"])) as demo:
72
+
73
+ placeholder_text = '''
74
+ Disclaimer: everything this model generates is a work of fiction.
75
+ Content from this model WILL generate inappropriate and potentially offensive content.
76
+
77
+ Use at your own discretion. Please respect the Huggingface code of conduct.'''
78
+
79
+ story = gr.Textbox(label="Story", interactive=False, lines=20, placeholder=placeholder_text)
80
+ story.style(show_copy_button=True)
81
+
82
+ user_outputs = gr.State([])
83
+
84
+ prompt = gr.Textbox(label="Prompt", placeholder="Start a new story, or continue your current one!", lines=3, max_lines=3)
85
+
86
+ with gr.Row():
87
+ gen_button = gr.Button('Generate')
88
+ undo_button = gr.Button("Undo")
89
+ res_button = gr.Button("Reset")
90
+
91
+ prompt.submit(generate_response, [user_outputs, prompt], [user_outputs, story, prompt], scroll_to_output=True)
92
+ gen_button.click(generate_response, [user_outputs, prompt], [user_outputs, story, prompt], scroll_to_output=True)
93
+ undo_button.click(undo, user_outputs, [user_outputs, story], scroll_to_output=True)
94
+ res_button.click(reset, [], [user_outputs, story], scroll_to_output=True)
95
+
96
+ # for local server; comment out for deploy
97
+
98
+ demo.launch(inbrowser=True, server_name='0.0.0.0')
99
+
100
+