# app.py -- page header captured with the file: last commit "Update app.py" by slliac (ea507cd, verified), 17.2 kB
import streamlit as st
from transformers import pipeline
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from transformers import AutoProcessor, AutoModel
import edge_tts
import asyncio
import os
import io
import tempfile
# Seed st.session_state with a default for every key the app reads, so the
# first script run (and every rerun) can access them without a KeyError.
# Keys that already hold a value from an earlier rerun are left untouched.
_SESSION_DEFAULTS = {
    'scenario': None,             # English image caption
    'scenario_zh': None,          # Chinese translation of the caption
    'story': None,                # English story
    'story_zh': None,             # Chinese translation of the story
    'audio_generated_zh': False,  # whether Cantonese audio was produced
    'audio_path_zh': None,        # path of the Cantonese MP3
    'audio_generated_en': False,  # whether English audio was produced
    'audio_path_en': None,        # path of the English MP3
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# function part
# img2text
@st.cache_resource(show_spinner=False)
def _load_image_captioner():
    """Build the BLIP image-captioning pipeline.

    Wrapped in st.cache_resource so the model is constructed once and the
    same pipeline object is reused by later calls and script reruns instead
    of being reloaded every time a caption is requested.
    """
    return pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")


def img2text(url):
    """Return a caption for an image.

    Args:
        url: Image reference handed unchanged to the image-to-text pipeline
            (the app passes the path of the saved upload).

    Returns:
        The "generated_text" field of the pipeline's first result.
    """
    image_to_text_model = _load_image_captioner()
    text = image_to_text_model(url)[0]["generated_text"]
    return text
# Translation function EN to ZH
@st.cache_resource(show_spinner=False)
def _load_en_zh_translator():
    """Build the English -> Chinese (zh-hk) translation pipeline.

    Wrapped in st.cache_resource so the model is constructed once; the app
    translates both the caption and the story, which previously loaded the
    model from scratch for each call.
    """
    return pipeline("translation", model="steve-tong/opus-mt-en-zh-hk")


def translate_to_chinese(text):
    """Translate English text to Chinese.

    Args:
        text: English text to translate.

    Returns:
        The "translation_text" field of the pipeline's first result.
    """
    translator = _load_en_zh_translator()
    translation = translator(text)[0]["translation_text"]
    return translation
# text2story - using mosaicml/mpt-7b-storywriter model for better stories
def text2story(text):
    """Generate a short English children's story from an image caption.

    The primary generator is mosaicml/mpt-7b-storywriter. If anything in that
    path raises, the error is shown with st.error and a gpt2 pipeline is used
    instead; an exception raised by the fallback itself propagates.

    Args:
        text: Scene description (the image caption) to build the prompt from.

    Returns:
        The generated text with the prompt removed and surrounding whitespace
        stripped. On the primary path, a story longer than 500 characters is
        cut down to its first five '.'-separated fragments.
    """
    try:
        # SECURITY NOTE: trust_remote_code=True executes Python code shipped in
        # the model repository. Kept because the call was written this way;
        # review before pointing this at a different model id.
        generator = pipeline("text-generation", model="mosaicml/mpt-7b-storywriter", trust_remote_code=True)

        # Create a prompt for the story
        prompt = f"Write a short children's story about this scene: {text}\n\nStory: "

        # max_length counts prompt tokens too, so the story itself is shorter.
        # do_sample=True is required for `temperature` to have any effect;
        # without it decoding is greedy and the temperature is ignored.
        story = generator(prompt,
                          max_length=150,
                          num_return_sequences=1,
                          do_sample=True,
                          temperature=0.7,
                          repetition_penalty=1.2)[0]['generated_text']

        # The pipeline returns prompt + continuation; keep the continuation.
        story = story.replace(prompt, "").strip()

        # Trim to a reasonable length if needed
        if len(story) > 500:
            trimmed_story = '.'.join(story.split('.')[:5])
            # Close the last sentence, but do not double the full stop when
            # the kept text already ends with one.
            if not trimmed_story.endswith('.'):
                trimmed_story += '.'
            return trimmed_story

        return story

    except Exception as e:
        st.error(f"故事生成出問題: {str(e)}")
        # Fallback to simpler model if the advanced one fails
        fallback_generator = pipeline('text-generation', model='gpt2')
        fallback_prompt = f"Create a short story about this scene: {text}\n\nStory:"
        fallback_story = fallback_generator(fallback_prompt, max_length=100, num_return_sequences=1)[0]['generated_text']
        return fallback_story.replace(fallback_prompt, "").strip()
# Shared implementation for the edge_tts text-to-speech helpers below
async def _synthesize_speech(text, voice, error_prefix):
    """Render `text` to a temporary MP3 file with edge_tts.

    Args:
        text: Text to speak.
        voice: edge_tts voice name.
        error_prefix: Text shown before ": <exception>" in the st.error
            message when synthesis fails.

    Returns:
        {'path': <mp3 path>, 'success': True} on success, otherwise
        {'path': None, 'success': False}. Exceptions are reported through
        st.error rather than raised.
    """
    audio_path = None
    try:
        # The temporary file only reserves a unique path; it is closed right
        # away so edge_tts can write to it by name. delete=False keeps it on
        # disk for the caller to play (and later remove).
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
            audio_path = temp_file.name

        communicate = edge_tts.Communicate(text, voice)
        await communicate.save(audio_path)

        return {
            'path': audio_path,
            'success': True
        }
    except Exception as e:
        # Do not leave an empty/partial MP3 behind when synthesis fails.
        if audio_path is not None:
            try:
                os.remove(audio_path)
            except OSError:
                pass  # best-effort cleanup; the failure is reported below
        st.error(f"{error_prefix}: {str(e)}")
        return {
            'path': None,
            'success': False
        }


# Text to audio using edge_tts for Cantonese audio
async def text2audio_cantonese(text):
    """Speak `text` with a Cantonese voice; see _synthesize_speech for the result dict."""
    # Female Cantonese voice. Alternative: "zh-HK-WanLungNeural" for male voice
    return await _synthesize_speech(text, "zh-HK-HiuMaanNeural", "中文音頻製作出左問題")


# Text to audio using edge_tts for English audio
async def text2audio_english(text):
    """Speak `text` with an English voice; see _synthesize_speech for the result dict."""
    # Female English voice. Alternative: "en-US-GuyNeural" for male voice
    return await _synthesize_speech(text, "en-US-AriaNeural", "English audio generation error")
# Apply custom CSS for modern, stylish kid-friendly UI
st.set_page_config(page_title="故事魔法", page_icon="✨", layout="wide")
# The stylesheet is injected as raw HTML. The only change from the previous
# version is in the first "accent border" rule, where a stray "it" token made
# the border-left declaration invalid.
# NOTE(review): the .css-xxxxxx selectors look like auto-generated class names
# and may not match other Streamlit versions -- verify against the rendered page.
st.markdown("""
<style>
/* Modern, stylish kid-friendly design */
@import url('https://fonts.googleapis.com/css2?family=Quicksand:wght@400;600;700&display=swap');
@import url('https://fonts.googleapis.com/css2?family=Noto+Sans+HK:wght@400;500;700&display=swap');
:root {
--primary-color: #6C63FF;
--secondary-color: #41B883;
--accent-color: #FF6B6B;
--background-light: #F7F9FC;
--text-dark: #2E3A59;
--shadow: 0 10px 20px rgba(0,0,0,0.08);
--border-radius: 16px;
}
.stApp {
background: linear-gradient(135deg, #F4F9FF, #EEFAFF);
font-family: 'Noto Sans HK', sans-serif;
color: var(--text-dark);
}
.main .block-container {
max-width: 1000px;
padding-top: 2rem;
padding-bottom: 2rem;
}
/* Modern headers */
h1, h2, h3 {
font-family: 'Noto Sans HK', sans-serif;
font-weight: 700;
color: var(--primary-color);
}
h1 {
font-size: 2.5rem;
text-align: center;
margin-bottom: 0;
}
h2 {
font-size: 1.8rem;
margin-bottom: 1rem;
}
h3 {
font-size: 1.4rem;
margin-bottom: 0.8rem;
}
/* Subtitle */
.subtitle {
text-align: center;
color: #6B7897;
font-size: 1.2rem;
margin-bottom: 2rem;
}
/* Card containers */
.stCard {
background: white;
border-radius: var(--border-radius);
padding: 1.5rem;
box-shadow: var(--shadow);
margin-bottom: 1.5rem;
}
/* Accent borders for stages */
.css-nahz7x, .css-ocqkz7, .css-4z1n4l {
border-left: 5px solid var(--primary-color) !important;
}
.css-1r6slb0, .css-1ubpcwi {
border-left: 5px solid var(--secondary-color) !important;
}
.css-pkbazv, .css-5rimss {
border-left: 5px solid var(--accent-color) !important;
}
/* Custom file uploader */
.stFileUploader > div > div {
background: var(--background-light);
border: 2px dashed #D0D8E6;
border-radius: 12px;
padding: 20px;
transition: all 0.3s ease;
}
.stFileUploader > div > div:hover {
border-color: var(--primary-color);
}
/* Uploaded image styling */
.stImage img {
border-radius: 12px;
box-shadow: var(--shadow);
}
/* Stage icons */
.stage-icon {
font-size: 1.6rem;
margin-right: 10px;
vertical-align: middle;
}
/* Response styling */
.stText {
font-size: 1.1rem;
line-height: 1.7;
background: var(--background-light);
padding: 1rem;
border-radius: 12px;
border-left: 4px solid var(--secondary-color);
margin: 1rem 0;
box-shadow: 0 5px 15px rgba(0,0,0,0.05);
}
/* Button styling */
.stButton > button {
background: var(--secondary-color) !important;
color: white !important;
border: none !important;
border-radius: 50px !important;
padding: 0.6rem 1.5rem !important;
font-size: 1.1rem !important;
font-weight: 600 !important;
font-family: 'Noto Sans HK', sans-serif !important;
transition: all 0.3s ease !important;
box-shadow: 0 5px 15px rgba(65, 184, 131, 0.3) !important;
}
.stButton > button:hover {
background: #37A574 !important;
transform: translateY(-3px) !important;
box-shadow: 0 8px 20px rgba(65, 184, 131, 0.4) !important;
}
.stButton > button:active {
transform: translateY(0) !important;
}
/* Audio player styling */
audio {
width: 100%;
border-radius: 50px;
height: 40px;
}
/* Emoji animation */
@keyframes bounce {
0%, 100% { transform: translateY(0); }
50% { transform: translateY(-15px); }
}
.emoji {
font-size: 1.8rem;
display: inline-block;
animation: bounce 2s infinite;
margin: 0 8px;
}
.emoji:nth-child(2) {
animation-delay: 0.2s;
}
.emoji:nth-child(3) {
animation-delay: 0.4s;
}
.emoji:nth-child(4) {
animation-delay: 0.6s;
}
/* Welcome message */
.welcome-message {
text-align: center;
padding: 3rem 1.5rem;
}
.welcome-icon {
font-size: 4rem;
margin-bottom: 1rem;
}
/* Audio player container */
.audio-container {
background: white;
padding: 1rem;
border-radius: 12px;
margin-bottom: 1rem;
box-shadow: var(--shadow);
}
.audio-title {
font-weight: 600;
margin-bottom: 0.5rem;
color: var(--primary-color);
}
</style>
""", unsafe_allow_html=True)
# ---------------------------------------------------------------------------
# Main page flow: upload an image -> caption -> story -> optional audio.
# Streamlit re-executes this script on every interaction, so generated
# results are kept in st.session_state and only recomputed when missing.
# ---------------------------------------------------------------------------

# Session keys that hold the generated texts; all four must be present before
# the "show saved results" path is taken.
_RESULT_KEYS = ("scenario", "scenario_zh", "story", "story_zh")


def _remove_file_quietly(path):
    """Delete `path` if it exists; a failure to delete is ignored (best effort)."""
    if path and os.path.exists(path):
        try:
            os.remove(path)
        except OSError:
            pass


def _reset_story_state():
    """Forget all generated texts and audio, deleting the temporary MP3 files."""
    for audio_key in ("audio_path_zh", "audio_path_en"):
        _remove_file_quietly(st.session_state.get(audio_key))
    for result_key in _RESULT_KEYS:
        st.session_state[result_key] = None
    st.session_state.audio_generated_zh = False
    st.session_state.audio_path_zh = None
    st.session_state.audio_generated_en = False
    st.session_state.audio_path_en = None


def _render_story_audio(language, button_label, spinner_label, title,
                        failure_message, synthesize, text):
    """Draw one "play story" button and, when clicked, the audio player.

    Args:
        language: Suffix of the session keys to use ("en" or "zh"), i.e.
            audio_path_<language> / audio_generated_<language>.
        button_label: Caption of the button.
        spinner_label: Text shown while audio is being synthesised.
        title: Heading shown above the player.
        failure_message: Text for st.error when no audio file is available.
        synthesize: Coroutine function returning {'path', 'success'}.
        text: Story text to speak.
    """
    if not st.button(button_label):
        return

    path_key = f"audio_path_{language}"
    generated_key = f"audio_generated_{language}"
    audio_path = st.session_state.get(path_key)

    # Synthesise when nothing was generated yet, or when the cached file has
    # disappeared in the meantime (otherwise the button could never recover).
    have_audio = bool(st.session_state.get(generated_key)
                      and audio_path and os.path.exists(audio_path))
    if not have_audio:
        with st.spinner(spinner_label):
            # The TTS helpers are coroutines; run them to completion here.
            audio_result = asyncio.run(synthesize(text))
        st.session_state[path_key] = audio_result['path']
        st.session_state[generated_key] = audio_result['success']
        audio_path = audio_result['path']

    # Play the audio
    if audio_path and os.path.exists(audio_path):
        with open(audio_path, "rb") as audio_file:
            audio_bytes = audio_file.read()
        st.markdown(f"<div class='audio-container'><div class='audio-title'>{title}</div>", unsafe_allow_html=True)
        st.audio(audio_bytes, format="audio/mp3")
        st.markdown("</div>", unsafe_allow_html=True)
    else:
        st.error(failure_message)


# App header with Cantonese
st.title("✨ 故事魔法")
st.markdown("<p class='subtitle'>上載一張圖片,睇下佢點變成一個神奇嘅故事!</p>",
            unsafe_allow_html=True)

# Placeholder near the top of the page for the status text / progress bar
progress_placeholder = st.empty()

# File uploader with Cantonese
with st.container():
    st.subheader("揀一張靚相啦!")
    uploaded_file = st.file_uploader("", key="upload")

if uploaded_file is not None:
    bytes_data = uploaded_file.getvalue()

    # Detect a replaced upload: the uploader can go straight from one file to
    # another without ever being empty, so compare a fingerprint of the
    # current file with the one the stored results belong to.
    upload_signature = (uploaded_file.name, len(bytes_data))
    if st.session_state.get("upload_signature") != upload_signature:
        _reset_story_state()
        st.session_state["upload_signature"] = upload_signature

    # Display image
    st.image(uploaded_file, use_column_width=True)

    # Regenerate unless every result is present; a run that failed part-way
    # would otherwise leave None values that the display branch cannot show.
    results_ready = all(st.session_state.get(key) is not None for key in _RESULT_KEYS)

    if not results_ready:
        # Status text and bar live in separate slots so that changing the
        # text does not replace (and thereby remove) the progress bar.
        with progress_placeholder.container():
            status_line = st.empty()
            progress_bar = st.progress(0)

        # Stage 1: Image to Text
        with st.container():
            st.markdown("<h3><span class='stage-icon'>🔍</span> 圖片解讀中</h3>", unsafe_allow_html=True)
            status_line.write("正在分析圖片...")

            # The captioner is given a file path. Write the upload to a
            # private temporary file instead of the working directory under
            # the user-supplied name, which could overwrite (and, on cleanup,
            # delete) an existing file such as the app itself.
            image_suffix = os.path.splitext(uploaded_file.name)[1]
            with tempfile.NamedTemporaryFile(delete=False, suffix=image_suffix) as image_file:
                image_file.write(bytes_data)
                temp_file_path = image_file.name
            try:
                st.session_state.scenario = img2text(temp_file_path)
            finally:
                # Remove the temporary image even when captioning raises.
                _remove_file_quietly(temp_file_path)
            progress_bar.progress(33)

            # Display English caption
            st.text("英文描述: " + st.session_state.scenario)

            # Translate the caption to Chinese
            status_line.write("正在翻譯...")
            st.session_state.scenario_zh = translate_to_chinese(st.session_state.scenario)
            progress_bar.progress(66)

            # Display Chinese caption
            st.text("中文描述: " + st.session_state.scenario_zh)

        # Stage 2: Text to Story
        with st.container():
            st.markdown("<h3><span class='stage-icon'>📝</span> 故事創作中</h3>", unsafe_allow_html=True)
            status_line.write("正在創作故事...")
            st.session_state.story = text2story(st.session_state.scenario)
            progress_bar.progress(85)

            # Display English story
            st.text("英文故事: " + st.session_state.story)

            # Translate the story to Chinese
            status_line.write("正在翻譯故事...")
            st.session_state.story_zh = translate_to_chinese(st.session_state.story)
            progress_bar.progress(100)

            # Display Chinese story
            st.text("中文故事: " + st.session_state.story_zh)

        # Clear progress indicator
        progress_placeholder.empty()
    else:
        # Display saved results from session state
        with st.container():
            st.markdown("<h3><span class='stage-icon'>🔍</span> 圖片解讀中</h3>", unsafe_allow_html=True)
            st.text("英文描述: " + st.session_state.scenario)
            st.text("中文描述: " + st.session_state.scenario_zh)

        with st.container():
            st.markdown("<h3><span class='stage-icon'>📝</span> 故事創作中</h3>", unsafe_allow_html=True)
            st.text("英文故事: " + st.session_state.story)
            st.text("中文故事: " + st.session_state.story_zh)

    # Stage 3: Story to Audio data
    with st.container():
        st.markdown("<h3><span class='stage-icon'>🔊</span> 故事準備朗讀中</h3>", unsafe_allow_html=True)

        # Create two columns for English and Cantonese buttons
        col1, col2 = st.columns(2)

        # English audio button
        with col1:
            _render_story_audio(
                language="en",
                button_label="🔊 Play Story in English",
                spinner_label="Generating English audio...",
                title="English Story",
                failure_message="Sorry! Please try again.",
                synthesize=text2audio_english,
                text=st.session_state.story,
            )

        # Cantonese audio button
        with col2:
            _render_story_audio(
                language="zh",
                button_label="🔊 播放廣東話故事",
                spinner_label="正在準備廣東話語音...",
                title="廣東話故事",
                failure_message="哎呀!再試多次啦!",
                synthesize=text2audio_cantonese,
                text=st.session_state.story_zh,
            )
else:
    # No file selected: drop all generated results (and their temporary audio
    # files) so the next upload starts from a clean state.
    _reset_story_state()
    st.session_state["upload_signature"] = None

    # Welcome message in Cantonese
    st.markdown("""
<div class="welcome-message">
<div class="welcome-icon">✨</div>
<h2>歡迎嚟到故事魔法!</h2>
<p style="font-size: 1.2rem; color: #6B7897; max-width: 500px; margin: 0 auto 30px;">
上載一張你鍾意嘅相片,我哋嘅魔法師會幫你變出一個好好玩嘅故事!
</p>
<div>
<span class="emoji">🚀</span>
<span class="emoji">🦄</span>
<span class="emoji">🔮</span>
<span class="emoji">🌈</span>
</div>
</div>
""", unsafe_allow_html=True)