Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -19,10 +19,14 @@ if 'story' not in st.session_state:
|
|
19 |
st.session_state.story = None
|
20 |
if 'story_zh' not in st.session_state:
|
21 |
st.session_state.story_zh = None
|
22 |
-
if '
|
23 |
-
st.session_state.
|
24 |
-
if '
|
25 |
-
st.session_state.
|
|
|
|
|
|
|
|
|
26 |
|
27 |
|
28 |
# function part
|
@@ -40,26 +44,42 @@ def translate_to_chinese(text):
|
|
40 |
return translation
|
41 |
|
42 |
|
43 |
-
# text2story
|
44 |
def text2story(text):
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
|
61 |
|
62 |
-
# Text to audio using edge_tts for Cantonese
|
63 |
async def text2audio_cantonese(text):
|
64 |
try:
|
65 |
# Use Cantonese voice from edge-tts
|
@@ -80,7 +100,35 @@ async def text2audio_cantonese(text):
|
|
80 |
'success': True
|
81 |
}
|
82 |
except Exception as e:
|
83 |
-
st.error(f"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
return {
|
85 |
'path': None,
|
86 |
'success': False
|
@@ -275,6 +323,21 @@ st.markdown("""
|
|
275 |
font-size: 4rem;
|
276 |
margin-bottom: 1rem;
|
277 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
278 |
</style>
|
279 |
""", unsafe_allow_html=True)
|
280 |
|
@@ -283,6 +346,9 @@ st.title("✨ 故事魔法")
|
|
283 |
st.markdown("<p class='subtitle'>上載一張圖片,睇下佢點變成一個神奇嘅故事!</p>",
|
284 |
unsafe_allow_html=True)
|
285 |
|
|
|
|
|
|
|
286 |
# File uploader with Cantonese
|
287 |
with st.container():
|
288 |
st.subheader("揀一張靚相啦!")
|
@@ -304,14 +370,22 @@ if uploaded_file is not None:
|
|
304 |
with st.container():
|
305 |
st.markdown("<h3><span class='stage-icon'>🔍</span> 圖片解讀中</h3>", unsafe_allow_html=True)
|
306 |
|
|
|
|
|
|
|
|
|
307 |
# Generate caption if not already done
|
308 |
st.session_state.scenario = img2text(temp_file_path)
|
|
|
309 |
|
310 |
# Display English caption
|
311 |
st.text("英文描述: " + st.session_state.scenario)
|
312 |
|
313 |
# Translate the caption to Chinese
|
|
|
|
|
314 |
st.session_state.scenario_zh = translate_to_chinese(st.session_state.scenario)
|
|
|
315 |
|
316 |
# Display Chinese caption
|
317 |
st.text("中文描述: " + st.session_state.scenario_zh)
|
@@ -320,17 +394,28 @@ if uploaded_file is not None:
|
|
320 |
with st.container():
|
321 |
st.markdown("<h3><span class='stage-icon'>📝</span> 故事創作中</h3>", unsafe_allow_html=True)
|
322 |
|
|
|
|
|
|
|
323 |
# Generate story if not already done
|
324 |
st.session_state.story = text2story(st.session_state.scenario)
|
|
|
325 |
|
326 |
# Display English story
|
327 |
st.text("英文故事: " + st.session_state.story)
|
328 |
|
329 |
# Translate the story to Chinese
|
|
|
|
|
330 |
st.session_state.story_zh = translate_to_chinese(st.session_state.story)
|
|
|
331 |
|
332 |
# Display Chinese story
|
333 |
st.text("中文故事: " + st.session_state.story_zh)
|
|
|
|
|
|
|
|
|
334 |
else:
|
335 |
# Display saved results from session state
|
336 |
with st.container():
|
@@ -347,22 +432,50 @@ if uploaded_file is not None:
|
|
347 |
with st.container():
|
348 |
st.markdown("<h3><span class='stage-icon'>🔊</span> 故事準備朗讀中</h3>", unsafe_allow_html=True)
|
349 |
|
350 |
-
#
|
351 |
-
|
352 |
-
|
353 |
-
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
st.session_state.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
358 |
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
363 |
-
|
364 |
-
|
365 |
-
|
|
|
|
|
366 |
|
367 |
# Cleanup: Remove the temporary file when the user is done
|
368 |
if os.path.exists(temp_file_path):
|
@@ -370,9 +483,15 @@ if uploaded_file is not None:
|
|
370 |
else:
|
371 |
# Clear session state when no file is uploaded
|
372 |
# Also clean up any temporary audio files
|
373 |
-
if st.session_state.
|
|
|
|
|
|
|
|
|
|
|
|
|
374 |
try:
|
375 |
-
os.remove(st.session_state.
|
376 |
except:
|
377 |
pass
|
378 |
|
@@ -380,8 +499,10 @@ else:
|
|
380 |
st.session_state.scenario_zh = None
|
381 |
st.session_state.story = None
|
382 |
st.session_state.story_zh = None
|
383 |
-
st.session_state.
|
384 |
-
st.session_state.
|
|
|
|
|
385 |
|
386 |
# Welcome message in Cantonese
|
387 |
st.markdown("""
|
|
|
19 |
st.session_state.story = None
|
20 |
# Seed the remaining per-session slots exactly once; later reruns of the
# script keep whatever value is already stored under each key.
_SESSION_DEFAULTS = (
    ('story_zh', None),             # translated story text
    ('audio_generated_zh', False),  # Cantonese audio already produced?
    ('audio_path_zh', None),        # path of the Cantonese audio file
    ('audio_generated_en', False),  # English audio already produced?
    ('audio_path_en', None),        # path of the English audio file
)
for _slot, _initial in _SESSION_DEFAULTS:
    if _slot not in st.session_state:
        st.session_state[_slot] = _initial
|
30 |
|
31 |
|
32 |
# function part
|
|
|
44 |
return translation
|
45 |
|
46 |
|
47 |
+
# text2story - using mosaicml/mpt-7b-storywriter model for better stories
|
48 |
def _trim_story(story, max_chars=500, max_sentences=5):
    """Cut an over-long story down to its first few sentences.

    Args:
        story: Generated story text.
        max_chars: Stories at or below this length are returned unchanged.
        max_sentences: Number of leading sentences kept when trimming.

    Returns:
        The original story, or its first ``max_sentences`` sentences
        terminated by exactly one period.
    """
    if len(story) <= max_chars:
        return story

    # Drop empty pieces (produced by a trailing '.' or by '...') so the
    # result never ends in a doubled period.
    pieces = [piece for piece in story.split('.') if piece.strip()]
    if not pieces:
        return story
    return '.'.join(pieces[:max_sentences]).rstrip() + '.'


def text2story(text):
    """Generate a short English children's story from an image caption.

    Tries the mosaicml/mpt-7b-storywriter model first. If anything in that
    path raises, the error is shown in the Streamlit UI and gpt2 is used
    instead.

    Args:
        text: Caption describing the scene the story should be about.

    Returns:
        The generated story without the prompt. Output of the primary model
        is additionally trimmed to five sentences when longer than 500
        characters.

    Raises:
        Exception: Whatever the gpt2 fallback raises, if it fails as well.
    """
    try:
        # Initialize the improved story generation pipeline
        generator = pipeline("text-generation", model="mosaicml/mpt-7b-storywriter", trust_remote_code=True)

        # Create a prompt for the story
        prompt = f"Write a short children's story about this scene: {text}\n\nStory: "

        outputs = generator(
            prompt,
            # Budget applies to generated tokens only, so a long caption
            # cannot use up the space reserved for the story.
            max_new_tokens=150,
            num_return_sequences=1,
            # temperature is only honoured when sampling is enabled.
            do_sample=True,
            temperature=0.7,
            repetition_penalty=1.2,
            # Ask the pipeline for the continuation only instead of
            # stripping the prompt afterwards with str.replace.
            return_full_text=False,
        )
        story = outputs[0]['generated_text'].strip()

        # Trim to a reasonable length if needed
        return _trim_story(story)
    except Exception as e:
        st.error(f"故事生成出問題: {str(e)}")
        # Fallback to simpler model if the advanced one fails
        fallback_generator = pipeline('text-generation', model='gpt2')
        fallback_prompt = f"Create a short story about this scene: {text}\n\nStory:"
        fallback_outputs = fallback_generator(
            fallback_prompt,
            max_new_tokens=100,
            num_return_sequences=1,
            return_full_text=False,
        )
        return fallback_outputs[0]['generated_text'].strip()
|
80 |
|
81 |
|
82 |
+
# Text to audio using edge_tts for Cantonese audio
|
83 |
async def text2audio_cantonese(text):
|
84 |
try:
|
85 |
# Use Cantonese voice from edge-tts
|
|
|
100 |
'success': True
|
101 |
}
|
102 |
except Exception as e:
|
103 |
+
st.error(f"中文音頻製作出左問題: {str(e)}")
|
104 |
+
return {
|
105 |
+
'path': None,
|
106 |
+
'success': False
|
107 |
+
}
|
108 |
+
|
109 |
+
|
110 |
+
# Text to audio using edge_tts for English audio
|
111 |
+
async def text2audio_english(text):
|
112 |
+
try:
|
113 |
+
# Use English voice from edge-tts
|
114 |
+
voice = "en-US-AriaNeural" # Female English voice
|
115 |
+
# Alternative: "en-US-GuyNeural" for male voice
|
116 |
+
|
117 |
+
# Create a temporary file
|
118 |
+
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
|
119 |
+
temp_file.close()
|
120 |
+
|
121 |
+
# Configure edge-tts to save to the file path
|
122 |
+
communicate = edge_tts.Communicate(text, voice)
|
123 |
+
await communicate.save(temp_file.name)
|
124 |
+
|
125 |
+
# Return the path to the audio file
|
126 |
+
return {
|
127 |
+
'path': temp_file.name,
|
128 |
+
'success': True
|
129 |
+
}
|
130 |
+
except Exception as e:
|
131 |
+
st.error(f"English audio generation error: {str(e)}")
|
132 |
return {
|
133 |
'path': None,
|
134 |
'success': False
|
|
|
323 |
font-size: 4rem;
|
324 |
margin-bottom: 1rem;
|
325 |
}
|
326 |
+
|
327 |
+
/* Audio player container */
|
328 |
+
.audio-container {
|
329 |
+
background: white;
|
330 |
+
padding: 1rem;
|
331 |
+
border-radius: 12px;
|
332 |
+
margin-bottom: 1rem;
|
333 |
+
box-shadow: var(--shadow);
|
334 |
+
}
|
335 |
+
|
336 |
+
.audio-title {
|
337 |
+
font-weight: 600;
|
338 |
+
margin-bottom: 0.5rem;
|
339 |
+
color: var(--primary-color);
|
340 |
+
}
|
341 |
</style>
|
342 |
""", unsafe_allow_html=True)
|
343 |
|
|
|
346 |
st.markdown("<p class='subtitle'>上載一張圖片,睇下佢點變成一個神奇嘅故事!</p>",
|
347 |
unsafe_allow_html=True)
|
348 |
|
349 |
+
# Add a progress indicator for model loading
|
350 |
+
progress_placeholder = st.empty()
|
351 |
+
|
352 |
# File uploader with Cantonese
|
353 |
with st.container():
|
354 |
st.subheader("揀一張靚相啦!")
|
|
|
370 |
with st.container():
|
371 |
st.markdown("<h3><span class='stage-icon'>🔍</span> 圖片解讀中</h3>", unsafe_allow_html=True)
|
372 |
|
373 |
+
with progress_placeholder.container():
|
374 |
+
st.write("正在分析圖片...")
|
375 |
+
progress_bar = st.progress(0)
|
376 |
+
|
377 |
# Generate caption if not already done
|
378 |
st.session_state.scenario = img2text(temp_file_path)
|
379 |
+
progress_bar.progress(33)
|
380 |
|
381 |
# Display English caption
|
382 |
st.text("英文描述: " + st.session_state.scenario)
|
383 |
|
384 |
# Translate the caption to Chinese
|
385 |
+
with progress_placeholder.container():
|
386 |
+
st.write("正在翻譯...")
|
387 |
st.session_state.scenario_zh = translate_to_chinese(st.session_state.scenario)
|
388 |
+
progress_bar.progress(66)
|
389 |
|
390 |
# Display Chinese caption
|
391 |
st.text("中文描述: " + st.session_state.scenario_zh)
|
|
|
394 |
with st.container():
|
395 |
st.markdown("<h3><span class='stage-icon'>📝</span> 故事創作中</h3>", unsafe_allow_html=True)
|
396 |
|
397 |
+
with progress_placeholder.container():
|
398 |
+
st.write("正在創作故事...")
|
399 |
+
|
400 |
# Generate story if not already done
|
401 |
st.session_state.story = text2story(st.session_state.scenario)
|
402 |
+
progress_bar.progress(85)
|
403 |
|
404 |
# Display English story
|
405 |
st.text("英文故事: " + st.session_state.story)
|
406 |
|
407 |
# Translate the story to Chinese
|
408 |
+
with progress_placeholder.container():
|
409 |
+
st.write("正在翻譯故事...")
|
410 |
st.session_state.story_zh = translate_to_chinese(st.session_state.story)
|
411 |
+
progress_bar.progress(100)
|
412 |
|
413 |
# Display Chinese story
|
414 |
st.text("中文故事: " + st.session_state.story_zh)
|
415 |
+
|
416 |
+
# Clear progress indicator
|
417 |
+
progress_placeholder.empty()
|
418 |
+
|
419 |
else:
|
420 |
# Display saved results from session state
|
421 |
with st.container():
|
|
|
432 |
with st.container():
|
433 |
st.markdown("<h3><span class='stage-icon'>🔊</span> 故事準備朗讀中</h3>", unsafe_allow_html=True)
|
434 |
|
435 |
+
# Create two columns for English and Cantonese buttons
|
436 |
+
col1, col2 = st.columns(2)
|
437 |
+
|
438 |
+
# English audio button
|
439 |
+
with col1:
|
440 |
+
if st.button("🔊 Play Story in English"):
|
441 |
+
# Only generate audio if not already done
|
442 |
+
if not st.session_state.audio_generated_en:
|
443 |
+
with st.spinner("Generating English audio..."):
|
444 |
+
# Need to run async function with asyncio
|
445 |
+
audio_result = asyncio.run(text2audio_english(st.session_state.story))
|
446 |
+
st.session_state.audio_path_en = audio_result['path']
|
447 |
+
st.session_state.audio_generated_en = audio_result['success']
|
448 |
+
|
449 |
+
# Play the audio
|
450 |
+
if st.session_state.audio_path_en and os.path.exists(st.session_state.audio_path_en):
|
451 |
+
with open(st.session_state.audio_path_en, "rb") as audio_file:
|
452 |
+
audio_bytes = audio_file.read()
|
453 |
+
st.markdown("<div class='audio-container'><div class='audio-title'>English Story</div>", unsafe_allow_html=True)
|
454 |
+
st.audio(audio_bytes, format="audio/mp3")
|
455 |
+
st.markdown("</div>", unsafe_allow_html=True)
|
456 |
+
else:
|
457 |
+
st.error("Sorry! Please try again.")
|
458 |
+
|
459 |
+
# Cantonese audio button
|
460 |
+
with col2:
|
461 |
+
if st.button("🔊 播放廣東話故事"):
|
462 |
+
# Only generate audio if not already done
|
463 |
+
if not st.session_state.audio_generated_zh:
|
464 |
+
with st.spinner("正在準備廣東話語音..."):
|
465 |
+
# Need to run async function with asyncio
|
466 |
+
audio_result = asyncio.run(text2audio_cantonese(st.session_state.story_zh))
|
467 |
+
st.session_state.audio_path_zh = audio_result['path']
|
468 |
+
st.session_state.audio_generated_zh = audio_result['success']
|
469 |
|
470 |
+
# Play the audio
|
471 |
+
if st.session_state.audio_path_zh and os.path.exists(st.session_state.audio_path_zh):
|
472 |
+
with open(st.session_state.audio_path_zh, "rb") as audio_file:
|
473 |
+
audio_bytes = audio_file.read()
|
474 |
+
st.markdown("<div class='audio-container'><div class='audio-title'>廣東話故事</div>", unsafe_allow_html=True)
|
475 |
+
st.audio(audio_bytes, format="audio/mp3")
|
476 |
+
st.markdown("</div>", unsafe_allow_html=True)
|
477 |
+
else:
|
478 |
+
st.error("哎呀!再試多次啦!")
|
479 |
|
480 |
# Cleanup: Remove the temporary file when the user is done
|
481 |
if os.path.exists(temp_file_path):
|
|
|
483 |
else:
|
484 |
# Clear session state when no file is uploaded
|
485 |
# Also clean up any temporary audio files
|
486 |
+
if st.session_state.audio_path_zh and os.path.exists(st.session_state.audio_path_zh):
|
487 |
+
try:
|
488 |
+
os.remove(st.session_state.audio_path_zh)
|
489 |
+
except:
|
490 |
+
pass
|
491 |
+
|
492 |
+
if st.session_state.audio_path_en and os.path.exists(st.session_state.audio_path_en):
|
493 |
try:
|
494 |
+
os.remove(st.session_state.audio_path_en)
|
495 |
except:
|
496 |
pass
|
497 |
|
|
|
499 |
st.session_state.scenario_zh = None
|
500 |
st.session_state.story = None
|
501 |
st.session_state.story_zh = None
|
502 |
+
st.session_state.audio_generated_zh = False
|
503 |
+
st.session_state.audio_path_zh = None
|
504 |
+
st.session_state.audio_generated_en = False
|
505 |
+
st.session_state.audio_path_en = None
|
506 |
|
507 |
# Welcome message in Cantonese
|
508 |
st.markdown("""
|