slliac committed on
Commit
ea507cd
·
verified ·
1 Parent(s): 8da5beb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +162 -41
app.py CHANGED
@@ -19,10 +19,14 @@ if 'story' not in st.session_state:
19
  st.session_state.story = None
20
  if 'story_zh' not in st.session_state:
21
  st.session_state.story_zh = None
22
- if 'audio_generated' not in st.session_state:
23
- st.session_state.audio_generated = False
24
- if 'audio_path' not in st.session_state:
25
- st.session_state.audio_path = None
 
 
 
 
26
 
27
 
28
  # function part
@@ -40,26 +44,42 @@ def translate_to_chinese(text):
40
  return translation
41
 
42
 
43
- # text2story
44
  def text2story(text):
45
- # Initialize the text generation pipeline
46
- generator = pipeline('text-generation', model='gpt2')
47
-
48
- # Create a prompt for the story
49
- prompt = f"Create a short story about this scene: {text}\n\nStory:"
50
-
51
- # Generate the story
52
- story = generator(prompt,
53
- max_length=100,
54
- num_return_sequences=1,
55
- temperature=0.7)[0]['generated_text']
56
-
57
- # Clean up the story by removing the prompt
58
- story = story.replace(prompt, "").strip()
59
- return story
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
 
62
- # Text to audio using edge_tts for Cantonese support
63
  async def text2audio_cantonese(text):
64
  try:
65
  # Use Cantonese voice from edge-tts
@@ -80,7 +100,35 @@ async def text2audio_cantonese(text):
80
  'success': True
81
  }
82
  except Exception as e:
83
- st.error(f"音頻製作出左問題: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  return {
85
  'path': None,
86
  'success': False
@@ -275,6 +323,21 @@ st.markdown("""
275
  font-size: 4rem;
276
  margin-bottom: 1rem;
277
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
  </style>
279
  """, unsafe_allow_html=True)
280
 
@@ -283,6 +346,9 @@ st.title("✨ 故事魔法")
283
  st.markdown("<p class='subtitle'>上載一張圖片,睇下佢點變成一個神奇嘅故事!</p>",
284
  unsafe_allow_html=True)
285
 
 
 
 
286
  # File uploader with Cantonese
287
  with st.container():
288
  st.subheader("揀一張靚相啦!")
@@ -304,14 +370,22 @@ if uploaded_file is not None:
304
  with st.container():
305
  st.markdown("<h3><span class='stage-icon'>🔍</span> 圖片解讀中</h3>", unsafe_allow_html=True)
306
 
 
 
 
 
307
  # Generate caption if not already done
308
  st.session_state.scenario = img2text(temp_file_path)
 
309
 
310
  # Display English caption
311
  st.text("英文描述: " + st.session_state.scenario)
312
 
313
  # Translate the caption to Chinese
 
 
314
  st.session_state.scenario_zh = translate_to_chinese(st.session_state.scenario)
 
315
 
316
  # Display Chinese caption
317
  st.text("中文描述: " + st.session_state.scenario_zh)
@@ -320,17 +394,28 @@ if uploaded_file is not None:
320
  with st.container():
321
  st.markdown("<h3><span class='stage-icon'>📝</span> 故事創作中</h3>", unsafe_allow_html=True)
322
 
 
 
 
323
  # Generate story if not already done
324
  st.session_state.story = text2story(st.session_state.scenario)
 
325
 
326
  # Display English story
327
  st.text("英文故事: " + st.session_state.story)
328
 
329
  # Translate the story to Chinese
 
 
330
  st.session_state.story_zh = translate_to_chinese(st.session_state.story)
 
331
 
332
  # Display Chinese story
333
  st.text("中文故事: " + st.session_state.story_zh)
 
 
 
 
334
  else:
335
  # Display saved results from session state
336
  with st.container():
@@ -347,22 +432,50 @@ if uploaded_file is not None:
347
  with st.container():
348
  st.markdown("<h3><span class='stage-icon'>🔊</span> 故事準備朗讀中</h3>", unsafe_allow_html=True)
349
 
350
- # Play button with Cantonese text
351
- if st.button("🔊 播放故事"):
352
- # Only generate audio if not already done
353
- if not st.session_state.audio_generated:
354
- # Need to run async function with asyncio
355
- audio_result = asyncio.run(text2audio_cantonese(st.session_state.story_zh))
356
- st.session_state.audio_path = audio_result['path']
357
- st.session_state.audio_generated = audio_result['success']
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
358
 
359
- # Play the audio
360
- if st.session_state.audio_path and os.path.exists(st.session_state.audio_path):
361
- with open(st.session_state.audio_path, "rb") as audio_file:
362
- audio_bytes = audio_file.read()
363
- st.audio(audio_bytes, format="audio/mp3")
364
- else:
365
- st.error("哎呀!再試多次啦!")
 
 
366
 
367
  # Cleanup: Remove the temporary file when the user is done
368
  if os.path.exists(temp_file_path):
@@ -370,9 +483,15 @@ if uploaded_file is not None:
370
  else:
371
  # Clear session state when no file is uploaded
372
  # Also clean up any temporary audio files
373
- if st.session_state.audio_path and os.path.exists(st.session_state.audio_path):
 
 
 
 
 
 
374
  try:
375
- os.remove(st.session_state.audio_path)
376
  except:
377
  pass
378
 
@@ -380,8 +499,10 @@ else:
380
  st.session_state.scenario_zh = None
381
  st.session_state.story = None
382
  st.session_state.story_zh = None
383
- st.session_state.audio_generated = False
384
- st.session_state.audio_path = None
 
 
385
 
386
  # Welcome message in Cantonese
387
  st.markdown("""
 
19
  st.session_state.story = None
20
  if 'story_zh' not in st.session_state:
21
  st.session_state.story_zh = None
22
+ if 'audio_generated_zh' not in st.session_state:
23
+ st.session_state.audio_generated_zh = False
24
+ if 'audio_path_zh' not in st.session_state:
25
+ st.session_state.audio_path_zh = None
26
+ if 'audio_generated_en' not in st.session_state:
27
+ st.session_state.audio_generated_en = False
28
+ if 'audio_path_en' not in st.session_state:
29
+ st.session_state.audio_path_en = None
30
 
31
 
32
  # function part
 
44
  return translation
45
 
46
 
47
# text2story - using mosaicml/mpt-7b-storywriter model for better stories
def _trim_story(story, max_chars=500, max_sentences=5):
    """Cap an over-long story at its first ``max_sentences`` sentences.

    Stories of at most ``max_chars`` characters are returned unchanged.
    Longer stories are split on '.' and cut to the leading sentences; the
    terminating period is re-added only when something was actually cut, so
    a long story with few sentences does not gain a stray trailing '.'.
    """
    if len(story) <= max_chars:
        return story

    sentences = story.split('.')
    if len(sentences) <= max_sentences:
        # Nothing would be removed; appending '.' here would only produce
        # artefacts such as '..' or '!.' at the end of the text.
        return story

    return '.'.join(sentences[:max_sentences]) + '.'


def text2story(text):
    """Generate a short English children's story from an image caption.

    The primary attempt uses the ``mosaicml/mpt-7b-storywriter``
    text-generation pipeline. If anything in that attempt raises, the error
    is shown with ``st.error`` and a ``gpt2`` pipeline is used instead.

    Args:
        text: Scene description that is embedded in the generation prompt.

    Returns:
        The generated story with the prompt removed and surrounding
        whitespace stripped. Stories longer than 500 characters are cut to
        their first five sentences.

    Raises:
        Exception: whatever the fallback ``gpt2`` pipeline raises; the
            fallback itself is not guarded.
    """
    try:
        # NOTE(review): the pipeline is rebuilt on every call; consider
        # caching it at the call site if repeated loads prove too slow.
        generator = pipeline("text-generation", model="mosaicml/mpt-7b-storywriter", trust_remote_code=True)

        # Create a prompt for the story
        prompt = f"Write a short children's story about this scene: {text}\n\nStory: "

        # max_length counts the prompt tokens as well as the generated ones.
        # do_sample=True is required for temperature to take effect; without
        # it decoding is greedy and the temperature value is ignored.
        outputs = generator(prompt,
                            max_length=150,
                            num_return_sequences=1,
                            do_sample=True,
                            temperature=0.7,
                            repetition_penalty=1.2)
        story = outputs[0]['generated_text']

        # Clean up the story by removing the prompt
        story = story.replace(prompt, "").strip()

        # Trim to a reasonable length if needed
        return _trim_story(story)
    except Exception as e:
        st.error(f"故事生成出問題: {str(e)}")
        # Fallback to simpler model if the advanced one fails
        fallback_generator = pipeline('text-generation', model='gpt2')
        fallback_prompt = f"Create a short story about this scene: {text}\n\nStory:"
        fallback_story = fallback_generator(fallback_prompt, max_length=100, num_return_sequences=1)[0]['generated_text']
        return fallback_story.replace(fallback_prompt, "").strip()
80
 
81
 
82
+ # Text to audio using edge_tts for Cantonese audio
83
  async def text2audio_cantonese(text):
84
  try:
85
  # Use Cantonese voice from edge-tts
 
100
  'success': True
101
  }
102
  except Exception as e:
103
+ st.error(f"中文音頻製作出左問題: {str(e)}")
104
+ return {
105
+ 'path': None,
106
+ 'success': False
107
+ }
108
+
109
+
110
+ # Text to audio using edge_tts for English audio
111
+ async def text2audio_english(text):
112
+ try:
113
+ # Use English voice from edge-tts
114
+ voice = "en-US-AriaNeural" # Female English voice
115
+ # Alternative: "en-US-GuyNeural" for male voice
116
+
117
+ # Create a temporary file
118
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
119
+ temp_file.close()
120
+
121
+ # Configure edge-tts to save to the file path
122
+ communicate = edge_tts.Communicate(text, voice)
123
+ await communicate.save(temp_file.name)
124
+
125
+ # Return the path to the audio file
126
+ return {
127
+ 'path': temp_file.name,
128
+ 'success': True
129
+ }
130
+ except Exception as e:
131
+ st.error(f"English audio generation error: {str(e)}")
132
  return {
133
  'path': None,
134
  'success': False
 
323
  font-size: 4rem;
324
  margin-bottom: 1rem;
325
  }
326
+
327
+ /* Audio player container */
328
+ .audio-container {
329
+ background: white;
330
+ padding: 1rem;
331
+ border-radius: 12px;
332
+ margin-bottom: 1rem;
333
+ box-shadow: var(--shadow);
334
+ }
335
+
336
+ .audio-title {
337
+ font-weight: 600;
338
+ margin-bottom: 0.5rem;
339
+ color: var(--primary-color);
340
+ }
341
  </style>
342
  """, unsafe_allow_html=True)
343
 
 
346
  st.markdown("<p class='subtitle'>上載一張圖片,睇下佢點變成一個神奇嘅故事!</p>",
347
  unsafe_allow_html=True)
348
 
349
+ # Add a progress indicator for model loading
350
+ progress_placeholder = st.empty()
351
+
352
  # File uploader with Cantonese
353
  with st.container():
354
  st.subheader("揀一張靚相啦!")
 
370
  with st.container():
371
  st.markdown("<h3><span class='stage-icon'>🔍</span> 圖片解讀中</h3>", unsafe_allow_html=True)
372
 
373
+ with progress_placeholder.container():
374
+ st.write("正在分析圖片...")
375
+ progress_bar = st.progress(0)
376
+
377
  # Generate caption if not already done
378
  st.session_state.scenario = img2text(temp_file_path)
379
+ progress_bar.progress(33)
380
 
381
  # Display English caption
382
  st.text("英文描述: " + st.session_state.scenario)
383
 
384
  # Translate the caption to Chinese
385
+ with progress_placeholder.container():
386
+ st.write("正在翻譯...")
387
  st.session_state.scenario_zh = translate_to_chinese(st.session_state.scenario)
388
+ progress_bar.progress(66)
389
 
390
  # Display Chinese caption
391
  st.text("中文描述: " + st.session_state.scenario_zh)
 
394
  with st.container():
395
  st.markdown("<h3><span class='stage-icon'>📝</span> 故事創作中</h3>", unsafe_allow_html=True)
396
 
397
+ with progress_placeholder.container():
398
+ st.write("正在創作故事...")
399
+
400
  # Generate story if not already done
401
  st.session_state.story = text2story(st.session_state.scenario)
402
+ progress_bar.progress(85)
403
 
404
  # Display English story
405
  st.text("英文故事: " + st.session_state.story)
406
 
407
  # Translate the story to Chinese
408
+ with progress_placeholder.container():
409
+ st.write("正在翻譯故事...")
410
  st.session_state.story_zh = translate_to_chinese(st.session_state.story)
411
+ progress_bar.progress(100)
412
 
413
  # Display Chinese story
414
  st.text("中文故事: " + st.session_state.story_zh)
415
+
416
+ # Clear progress indicator
417
+ progress_placeholder.empty()
418
+
419
  else:
420
  # Display saved results from session state
421
  with st.container():
 
432
  with st.container():
433
  st.markdown("<h3><span class='stage-icon'>🔊</span> 故事準備朗讀中</h3>", unsafe_allow_html=True)
434
 
435
+ # Create two columns for English and Cantonese buttons
436
+ col1, col2 = st.columns(2)
437
+
438
+ # English audio button
439
+ with col1:
440
+ if st.button("🔊 Play Story in English"):
441
+ # Only generate audio if not already done
442
+ if not st.session_state.audio_generated_en:
443
+ with st.spinner("Generating English audio..."):
444
+ # Need to run async function with asyncio
445
+ audio_result = asyncio.run(text2audio_english(st.session_state.story))
446
+ st.session_state.audio_path_en = audio_result['path']
447
+ st.session_state.audio_generated_en = audio_result['success']
448
+
449
+ # Play the audio
450
+ if st.session_state.audio_path_en and os.path.exists(st.session_state.audio_path_en):
451
+ with open(st.session_state.audio_path_en, "rb") as audio_file:
452
+ audio_bytes = audio_file.read()
453
+ st.markdown("<div class='audio-container'><div class='audio-title'>English Story</div>", unsafe_allow_html=True)
454
+ st.audio(audio_bytes, format="audio/mp3")
455
+ st.markdown("</div>", unsafe_allow_html=True)
456
+ else:
457
+ st.error("Sorry! Please try again.")
458
+
459
+ # Cantonese audio button
460
+ with col2:
461
+ if st.button("🔊 播放廣東話故事"):
462
+ # Only generate audio if not already done
463
+ if not st.session_state.audio_generated_zh:
464
+ with st.spinner("正在準備廣東話語音..."):
465
+ # Need to run async function with asyncio
466
+ audio_result = asyncio.run(text2audio_cantonese(st.session_state.story_zh))
467
+ st.session_state.audio_path_zh = audio_result['path']
468
+ st.session_state.audio_generated_zh = audio_result['success']
469
 
470
+ # Play the audio
471
+ if st.session_state.audio_path_zh and os.path.exists(st.session_state.audio_path_zh):
472
+ with open(st.session_state.audio_path_zh, "rb") as audio_file:
473
+ audio_bytes = audio_file.read()
474
+ st.markdown("<div class='audio-container'><div class='audio-title'>廣東話故事</div>", unsafe_allow_html=True)
475
+ st.audio(audio_bytes, format="audio/mp3")
476
+ st.markdown("</div>", unsafe_allow_html=True)
477
+ else:
478
+ st.error("哎呀!再試多次啦!")
479
 
480
  # Cleanup: Remove the temporary file when the user is done
481
  if os.path.exists(temp_file_path):
 
483
  else:
484
  # Clear session state when no file is uploaded
485
  # Also clean up any temporary audio files
486
+ if st.session_state.audio_path_zh and os.path.exists(st.session_state.audio_path_zh):
487
+ try:
488
+ os.remove(st.session_state.audio_path_zh)
489
+ except:
490
+ pass
491
+
492
+ if st.session_state.audio_path_en and os.path.exists(st.session_state.audio_path_en):
493
  try:
494
+ os.remove(st.session_state.audio_path_en)
495
  except:
496
  pass
497
 
 
499
  st.session_state.scenario_zh = None
500
  st.session_state.story = None
501
  st.session_state.story_zh = None
502
+ st.session_state.audio_generated_zh = False
503
+ st.session_state.audio_path_zh = None
504
+ st.session_state.audio_generated_en = False
505
+ st.session_state.audio_path_en = None
506
 
507
  # Welcome message in Cantonese
508
  st.markdown("""