storytelling / app.py
jitesh's picture
adds button
d93b279
raw
history blame
2 kB
import streamlit as st
from transformers import pipeline, set_seed
import time, sys
from transformers import pipeline, set_seed
import printj
# --- Model loading ---------------------------------------------------------
# Build the two Hugging Face pipelines used by story(): GPT-2 for text
# generation, and an English emotion classifier (DistilRoBERTa-based, per the
# model name). return_all_scores=True presumably makes the classifier report
# a score for every label rather than only the top one (per transformers
# docs; newer versions spell this top_k=None) — TODO confirm for the
# installed version.
# The wall-clock time spent building both pipelines is printed to the console.
# NOTE(review): this runs at module top level, so it presumably re-runs on
# every Streamlit rerun (each widget interaction) and reloads both models;
# consider caching them (e.g. st.cache_resource, if the installed Streamlit
# provides it) — TODO confirm.
start = time.time()
generator = pipeline('text-generation', model='gpt2')
classifier = pipeline("text-classification", model="j-hartmann/emotion-english-distilroberta-base", return_all_scores=True)
process_time = time.time()-start
print(f'Process Time: {process_time}')
# set_seed(42)  # uncomment to seed RNGs via transformers.set_seed, presumably for repeatable stories
# sys.exit()  # uncommenting stops the script right after model loading (debug aid)
# Wide page layout. Streamlit documents set_page_config as having to be the
# first Streamlit command in the script, so keep it ahead of any other st.* call.
st.set_page_config(layout="wide")
def story(story_till_now, num_generation, length):
    """Grow a story by repeatedly continuing it with the text generator.

    Each of the ``num_generation`` rounds passes the whole story so far to the
    module-level ``generator`` pipeline, asking for at most ``length`` new
    tokens. It then runs the module-level ``classifier`` on the text appended
    in that round. In the first round this slice is the entire text, including
    the original prompt. Every round is also logged to the console: the
    previously existing text in cyan, the new text in green, followed by the
    classifier output.

    Args:
        story_till_now: Text to continue (the user's first sentence).
        num_generation: Number of generation rounds; values below 1 run none.
        length: Maximum number of new tokens generated per round.

    Returns:
        Tuple ``(story, emotion)``: the extended text, and the raw classifier
        output for the slice added in the last round (``None`` if no round ran).
    """
    emotion = None  # defined even when the loop body never executes
    last_length = 0  # character offset where the current round's new text starts
    for i in range(num_generation):
        # max_new_tokens bounds only the generated tokens, so every round adds
        # up to `length` tokens regardless of how long the story already is.
        # A max_length limit would also count the prompt tokens, starving the
        # early rounds whenever the first sentence is long.
        results = generator(story_till_now, max_new_tokens=length, num_return_sequences=1)
        story_till_now = results[0]['generated_text']
        # Slicing at the previous length presumes the generated text starts
        # with the input text verbatim — TODO confirm for the pipeline version.
        new_sentence = story_till_now[last_length:]
        emotion = classifier(new_sentence)
        printj.yellow(f'Sentence {i}:')
        story_to_print = f'{printj.ColorText.cyan(story_till_now[:last_length])}{printj.ColorText.green(story_till_now[last_length:])}\n'
        print(story_to_print)
        printj.purple(f'Emotion: {emotion}')
        last_length = len(story_till_now)
    return story_till_now, emotion
# --- Inputs ----------------------------------------------------------------
# The opening sentence is entered on the main page; the number of rounds and
# the per-round length are chosen in the sidebar.
story_till_now = st.text_input(
    label='First Sentence',
    value="Hello, I'm a language model,",
)
num_generation = st.sidebar.slider(
    label='Number of generation',
    min_value=1,
    max_value=100,
    value=10,
    step=1,
)
length = st.sidebar.slider(
    label='Length of the generated sentence',
    min_value=1,
    max_value=100,
    value=20,
    step=1,
)

# --- Output ----------------------------------------------------------------
run_requested = st.button('Run')
if not run_requested:
    # Button not pressed on this run: only show the usage hint.
    st.write('Write the first sentence and then hit the Run button')
else:
    # Generate the story, then show it together with the last emotion result.
    story_till_now, emotion = story(story_till_now, num_generation, length)
    st.write('Story:')
    st.text(story_till_now)
    st.text('Emotion: {}'.format(emotion))