Spaces:
Runtime error
Runtime error
fixes weight slider and storytelling mode
Browse files- app.py +8 -8
- story_gen.py +1 -1
- story_gen_test.py +3 -0
app.py
CHANGED
@@ -21,14 +21,14 @@ num_generation = container_param.slider(
|
|
21 |
length = container_param.slider(label='Length of the generated sentence',
|
22 |
min_value=1, max_value=100, value=10, step=1)
|
23 |
if mode == 'Create Statistics':
|
24 |
-
|
25 |
num_tests = container_param.slider(
|
26 |
label='Number of tests', min_value=1, max_value=1000, value=3, step=1)
|
27 |
reaction_weight_mode = container_param.select_slider(
|
28 |
"Reaction Weight w:", ["Random", "Fixed"])
|
29 |
if reaction_weight_mode == "Fixed":
|
30 |
reaction_weight = container_param.slider(
|
31 |
-
label='
|
32 |
elif reaction_weight_mode == "Random":
|
33 |
reaction_weight = -1
|
34 |
if container_button.button('Analyse'):
|
@@ -41,9 +41,9 @@ if mode == 'Create Statistics':
|
|
41 |
data=gen.stats_df[gen.stats_df.sentence_no==3]
|
42 |
fig = px.violin(data_frame=data, x="reaction_weight", y="num_reactions", hover_data=data.columns)
|
43 |
st.plotly_chart(fig, use_container_width=True)
|
44 |
-
|
45 |
-
|
46 |
-
elif mode == '
|
47 |
container_mode.write('Let\'s play storytelling.')
|
48 |
|
49 |
# # , placeholder="Start writing your story...")
|
@@ -57,8 +57,8 @@ elif mode == 'Create Statistics':
|
|
57 |
if container_button.button('Run'):
|
58 |
story_till_now, emotion = gen.story(
|
59 |
story_till_now, num_generation, length)
|
60 |
-
st.
|
61 |
st.text(story_till_now)
|
62 |
-
st.
|
63 |
else:
|
64 |
-
st.
|
|
|
21 |
length = container_param.slider(label='Length of the generated sentence',
|
22 |
min_value=1, max_value=100, value=10, step=1)
|
23 |
if mode == 'Create Statistics':
|
24 |
+
|
25 |
num_tests = container_param.slider(
|
26 |
label='Number of tests', min_value=1, max_value=1000, value=3, step=1)
|
27 |
reaction_weight_mode = container_param.select_slider(
|
28 |
"Reaction Weight w:", ["Random", "Fixed"])
|
29 |
if reaction_weight_mode == "Fixed":
|
30 |
reaction_weight = container_param.slider(
|
31 |
+
label='Reaction Weight w', min_value=0.0, max_value=1.0, value=0.5, step=0.01)
|
32 |
elif reaction_weight_mode == "Random":
|
33 |
reaction_weight = -1
|
34 |
if container_button.button('Analyse'):
|
|
|
41 |
data=gen.stats_df[gen.stats_df.sentence_no==3]
|
42 |
fig = px.violin(data_frame=data, x="reaction_weight", y="num_reactions", hover_data=data.columns)
|
43 |
st.plotly_chart(fig, use_container_width=True)
|
44 |
+
else:
|
45 |
+
st.markdown('### You selected statistics. Now set your parameters and click the `Analyse` button.')
|
46 |
+
elif mode == 'Play Storytelling':
|
47 |
container_mode.write('Let\'s play storytelling.')
|
48 |
|
49 |
# # , placeholder="Start writing your story...")
|
|
|
57 |
if container_button.button('Run'):
|
58 |
story_till_now, emotion = gen.story(
|
59 |
story_till_now, num_generation, length)
|
60 |
+
st.markdown(f'### Story')
|
61 |
st.text(story_till_now)
|
62 |
+
st.markdown(f'The last sentence has an **Emotion** of {emotion["label"]} with a confidence score of {emotion["score"]}.')
|
63 |
else:
|
64 |
+
st.markdown('### Write the first sentence and then hit the `Run` button')
|
story_gen.py
CHANGED
@@ -57,7 +57,7 @@ class StoryGenerator:
|
|
57 |
for i in range(num_generation):
|
58 |
last_length = len(story_till_now)
|
59 |
genreate_robot_sentence = self.generator(story_till_now, max_length=self.get_num_token(story_till_now) +
|
60 |
-
length
|
61 |
story_till_now = genreate_robot_sentence[0]['generated_text']
|
62 |
new_sentence = story_till_now[last_length:]
|
63 |
emotions = self.classifier(new_sentence)
|
|
|
57 |
for i in range(num_generation):
|
58 |
last_length = len(story_till_now)
|
59 |
genreate_robot_sentence = self.generator(story_till_now, max_length=self.get_num_token(story_till_now) +
|
60 |
+
length, num_return_sequences=1)
|
61 |
story_till_now = genreate_robot_sentence[0]['generated_text']
|
62 |
new_sentence = story_till_now[last_length:]
|
63 |
emotions = self.classifier(new_sentence)
|
story_gen_test.py
CHANGED
@@ -29,3 +29,6 @@ ax = sns.violinplot(x="reaction_weight", y="num_reactions", data=data).set_title
|
|
29 |
|
30 |
gen.stats_df[gen.stats_df.sentence_no==3]
|
31 |
# %%
|
|
|
|
|
|
|
|
29 |
|
30 |
gen.stats_df[gen.stats_df.sentence_no==3]
|
31 |
# %%
|
32 |
+
import re
|
33 |
+
len(re.findall(r'\w+', 'line ive '))
|
34 |
+
# %%
|