jitesh commited on
Commit
644c96b
1 Parent(s): 5d27e4d

adds simple story generator

Browse files
Files changed (2) hide show
  1. app.py +40 -0
  2. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+
3
+ from transformers import pipeline, set_seed
4
+
5
+ import time, sys
6
+
7
+ from transformers import pipeline, set_seed
8
+ import printj
9
+ start = time.time()
10
+ generator = pipeline('text-generation', model='gpt2')
11
+ classifier = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base", return_all_scores=True)
12
+ process_time = time.time()-start
13
+ print(f'Process Time: {process_time}')
14
+ # set_seed(42)
15
+ # sys.exit()
16
+
17
+ def story(story_till_now, num_generation, length):
18
+ last_length = 0
19
+
20
+ # story_till_now = "Hello, I'm a language model,"
21
+ for i in range(num_generation):
22
+ # start = time.time()
23
+ results = generator(story_till_now, max_length=30+length*i, num_return_sequences=1)
24
+ # process_time = time.time()-start
25
+ # print(f'Process Time: {process_time}, avg. time: {process_time/num_return_sequences}')
26
+ story_till_now = results[0]['generated_text']
27
+ new_sentence = story_till_now[last_length:]
28
+ emotion = classifier(new_sentence)
29
+ printj.yellow(f'Sentence {i}:')
30
+ story_to_print = f'{printj.ColorText.cyan(story_till_now[:last_length])}{printj.ColorText.green(story_till_now[last_length:])}\n'
31
+ print(story_to_print)
32
+ printj.purple(f'Emotion: {emotion}')
33
+ last_length = len(story_till_now)
34
+ return story_till_now
35
+ story_till_now=st.text_input(label='First Sentence', value='Hello, I\'m a language model,', placeholder="Start writing your story...")
36
+
37
+ num_generation= st.sidebar.slider(label='Number of generation', min_value=1, max_value=100, value=10, step=1)
38
+ length= st.sidebar.slider(label='Length of the generated sentence', min_value=1, max_value=100, value=20, step=1)
39
+ story_till_now=story(story_till_now, num_generation, length)
40
+ st.text(story_till_now)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ transformers==4.16.2
2
+ printj==0.1.0
3
+ torch==1.10.2
4
+ torchvision==0.11.3