# Source captured from a Hugging Face Space page (Space status at capture time: Sleeping).
import asyncio
import hashlib
import io
import os
import tempfile

import edge_tts
import streamlit as st
import torch
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from transformers import AutoProcessor, AutoModel
# Per-session storage for the generated texts and the cached audio files.
# Each key is only seeded when absent, so values survive Streamlit reruns.
_SESSION_DEFAULTS = {
    'scenario': None,            # English image caption
    'scenario_zh': None,         # Chinese translation of the caption
    'story': None,               # English story
    'story_zh': None,            # Chinese translation of the story
    'audio_generated_zh': False,
    'audio_path_zh': None,
    'audio_generated_en': False,
    'audio_path_en': None,
}
for _state_key, _initial_value in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
# function part | |
# img2text | |
def img2text(url): | |
image_to_text_model = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base") | |
text = image_to_text_model(url)[0]["generated_text"] | |
return text | |
# Translation function EN to ZH | |
def translate_to_chinese(text): | |
translator = pipeline("translation", model="steve-tong/opus-mt-en-zh-hk") | |
translation = translator(text)[0]["translation_text"] | |
return translation | |
# text2story - using mosaicml/mpt-7b-storywriter model for better stories
@st.cache_resource
def _load_story_generator():
    """Load the MPT-7B StoryWriter text-generation pipeline once and reuse it."""
    return pipeline("text-generation", model="mosaicml/mpt-7b-storywriter", trust_remote_code=True)


@st.cache_resource
def _load_fallback_story_generator():
    """Load the small GPT-2 pipeline used when the primary story model fails."""
    return pipeline('text-generation', model='gpt2')


def _strip_prompt(generated, prompt):
    """Drop the echoed *prompt* from the start of *generated* and trim whitespace.

    Only a leading copy is removed; the old ``str.replace`` would also have
    deleted any later occurrence of the prompt text inside the story.
    """
    if generated.startswith(prompt):
        generated = generated[len(prompt):]
    return generated.strip()


def _trim_story(story, max_chars=500, max_sentences=5):
    """Shorten *story* to its first *max_sentences* '.'-separated pieces if too long.

    Stories of at most *max_chars* characters are returned unchanged. Trailing
    periods are stripped before one is appended, so a story that already ends
    in '.' no longer becomes '..'.
    """
    if len(story) <= max_chars:
        return story
    fragments = story.split('.')
    head = '.'.join(fragments[:max_sentences]).rstrip().rstrip('.')
    return head + '.'


def text2story(text):
    """Write a short children's story about the scene described by *text*.

    Tries the MPT-7B StoryWriter model first. On any failure it reports the
    error with st.error and falls back to GPT-2. Errors from the fallback
    propagate to the caller.

    Returns:
        The generated story with the prompt removed.
    """
    try:
        generator = _load_story_generator()
        prompt = f"Write a short children's story about this scene: {text}\n\nStory: "
        # max_length counts the prompt tokens too; kept small because of model size.
        # do_sample=True is required for temperature to have any effect; under the
        # default greedy decoding transformers ignores it.
        generated = generator(
            prompt,
            max_length=150,
            num_return_sequences=1,
            do_sample=True,
            temperature=0.7,
            repetition_penalty=1.2,
        )[0]['generated_text']
        return _trim_story(_strip_prompt(generated, prompt))
    except Exception as e:
        st.error(f"故事生成出問題: {str(e)}")
        # Fallback to simpler model if the advanced one fails
        fallback_generator = _load_fallback_story_generator()
        fallback_prompt = f"Create a short story about this scene: {text}\n\nStory:"
        fallback_story = fallback_generator(
            fallback_prompt, max_length=100, num_return_sequences=1
        )[0]['generated_text']
        return _strip_prompt(fallback_story, fallback_prompt)
# Text to audio using edge_tts (shared by the Cantonese and English players)
async def _synthesize_to_mp3(text, voice):
    """Render *text* with the edge-tts *voice* into a new temporary .mp3 file.

    The file is created with ``delete=False``, so the returned path outlives
    this call and the caller owns it. If synthesis fails, the partially
    written file is deleted before the exception propagates, so failed
    attempts do not leak temp files.
    """
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
        path = temp_file.name
    try:
        await edge_tts.Communicate(text, voice).save(path)
    except Exception:
        try:
            os.remove(path)
        except OSError:
            pass
        raise
    return path


# Text to audio using edge_tts for Cantonese audio
async def text2audio_cantonese(text):
    """Synthesize Cantonese speech for *text*.

    Returns:
        dict with 'path' (temp .mp3 path, or None on failure) and 'success'
        (bool). Failures are reported in the UI via st.error, not raised.
    """
    voice = "zh-HK-HiuMaanNeural"  # Female Cantonese voice; "zh-HK-WanLungNeural" is a male alternative
    try:
        path = await _synthesize_to_mp3(text, voice)
    except Exception as e:
        st.error(f"中文音頻製作出左問題: {str(e)}")
        return {'path': None, 'success': False}
    return {'path': path, 'success': True}


# Text to audio using edge_tts for English audio
async def text2audio_english(text):
    """Synthesize English speech for *text*.

    Returns:
        dict with 'path' (temp .mp3 path, or None on failure) and 'success'
        (bool). Failures are reported in the UI via st.error, not raised.
    """
    voice = "en-US-AriaNeural"  # Female English voice; "en-US-GuyNeural" is a male alternative
    try:
        path = await _synthesize_to_mp3(text, voice)
    except Exception as e:
        st.error(f"English audio generation error: {str(e)}")
        return {'path': None, 'success': False}
    return {'path': path, 'success': True}
# Page setup: browser-tab title/icon and wide layout.
st.set_page_config(page_title="故事魔法", page_icon="✨", layout="wide")

# Custom CSS for a modern, stylish kid-friendly UI, injected as raw HTML.
st.markdown("""
<style>
    /* Modern, stylish kid-friendly design */
    @import url('https://fonts.googleapis.com/css2?family=Quicksand:wght@400;600;700&display=swap');
    @import url('https://fonts.googleapis.com/css2?family=Noto+Sans+HK:wght@400;500;700&display=swap');
    :root {
        --primary-color: #6C63FF;
        --secondary-color: #41B883;
        --accent-color: #FF6B6B;
        --background-light: #F7F9FC;
        --text-dark: #2E3A59;
        --shadow: 0 10px 20px rgba(0,0,0,0.08);
        --border-radius: 16px;
    }
    .stApp {
        background: linear-gradient(135deg, #F4F9FF, #EEFAFF);
        font-family: 'Noto Sans HK', sans-serif;
        color: var(--text-dark);
    }
    .main .block-container {
        max-width: 1000px;
        padding-top: 2rem;
        padding-bottom: 2rem;
    }
    /* Modern headers */
    h1, h2, h3 {
        font-family: 'Noto Sans HK', sans-serif;
        font-weight: 700;
        color: var(--primary-color);
    }
    h1 {
        font-size: 2.5rem;
        text-align: center;
        margin-bottom: 0;
    }
    h2 {
        font-size: 1.8rem;
        margin-bottom: 1rem;
    }
    h3 {
        font-size: 1.4rem;
        margin-bottom: 0.8rem;
    }
    /* Subtitle */
    .subtitle {
        text-align: center;
        color: #6B7897;
        font-size: 1.2rem;
        margin-bottom: 2rem;
    }
    /* Card containers */
    .stCard {
        background: white;
        border-radius: var(--border-radius);
        padding: 1.5rem;
        box-shadow: var(--shadow);
        margin-bottom: 1.5rem;
    }
    /* Accent borders for stages.
       NOTE(review): these .css-* names look like build-generated hashes and
       may not match the Streamlit version actually deployed. */
    .css-nahz7x, .css-ocqkz7, .css-4z1n4l {
        border-left: 5px solid var(--primary-color) !important;
    }
    .css-1r6slb0, .css-1ubpcwi {
        border-left: 5px solid var(--secondary-color) !important;
    }
    .css-pkbazv, .css-5rimss {
        border-left: 5px solid var(--accent-color) !important;
    }
    /* Custom file uploader */
    .stFileUploader > div > div {
        background: var(--background-light);
        border: 2px dashed #D0D8E6;
        border-radius: 12px;
        padding: 20px;
        transition: all 0.3s ease;
    }
    .stFileUploader > div > div:hover {
        border-color: var(--primary-color);
    }
    /* Uploaded image styling */
    .stImage img {
        border-radius: 12px;
        box-shadow: var(--shadow);
    }
    /* Stage icons */
    .stage-icon {
        font-size: 1.6rem;
        margin-right: 10px;
        vertical-align: middle;
    }
    /* Response styling */
    .stText {
        font-size: 1.1rem;
        line-height: 1.7;
        background: var(--background-light);
        padding: 1rem;
        border-radius: 12px;
        border-left: 4px solid var(--secondary-color);
        margin: 1rem 0;
        box-shadow: 0 5px 15px rgba(0,0,0,0.05);
    }
    /* Button styling */
    .stButton > button {
        background: var(--secondary-color) !important;
        color: white !important;
        border: none !important;
        border-radius: 50px !important;
        padding: 0.6rem 1.5rem !important;
        font-size: 1.1rem !important;
        font-weight: 600 !important;
        font-family: 'Noto Sans HK', sans-serif !important;
        transition: all 0.3s ease !important;
        box-shadow: 0 5px 15px rgba(65, 184, 131, 0.3) !important;
    }
    .stButton > button:hover {
        background: #37A574 !important;
        transform: translateY(-3px) !important;
        box-shadow: 0 8px 20px rgba(65, 184, 131, 0.4) !important;
    }
    .stButton > button:active {
        transform: translateY(0) !important;
    }
    /* Audio player styling */
    audio {
        width: 100%;
        border-radius: 50px;
        height: 40px;
    }
    /* Emoji animation */
    @keyframes bounce {
        0%, 100% { transform: translateY(0); }
        50% { transform: translateY(-15px); }
    }
    .emoji {
        font-size: 1.8rem;
        display: inline-block;
        animation: bounce 2s infinite;
        margin: 0 8px;
    }
    .emoji:nth-child(2) {
        animation-delay: 0.2s;
    }
    .emoji:nth-child(3) {
        animation-delay: 0.4s;
    }
    .emoji:nth-child(4) {
        animation-delay: 0.6s;
    }
    /* Welcome message */
    .welcome-message {
        text-align: center;
        padding: 3rem 1.5rem;
    }
    .welcome-icon {
        font-size: 4rem;
        margin-bottom: 1rem;
    }
    /* Audio player container */
    .audio-container {
        background: white;
        padding: 1rem;
        border-radius: 12px;
        margin-bottom: 1rem;
        box-shadow: var(--shadow);
    }
    .audio-title {
        font-weight: 600;
        margin-bottom: 0.5rem;
        color: var(--primary-color);
    }
</style>
""", unsafe_allow_html=True)
# App header (Cantonese copy).
st.title("✨ 故事魔法")
_subtitle_html = "<p class='subtitle'>上載一張圖片,睇下佢點變成一個神奇嘅故事!</p>"
st.markdown(_subtitle_html, unsafe_allow_html=True)

# Placeholder that the generation stages fill with status text and a progress bar.
progress_placeholder = st.empty()

# Upload section: a heading plus the image uploader (empty label).
_upload_section = st.container()
with _upload_section:
    st.subheader("揀一張靚相啦!")
    uploaded_file = st.file_uploader("", key="upload")
# Welcome screen shown while no image is uploaded (Cantonese copy). HTML lines
# are kept flush-left so markdown does not treat them as indented code.
_WELCOME_HTML = """
<div class="welcome-message">
<div class="welcome-icon">✨</div>
<h2>歡迎嚟到故事魔法!</h2>
<p style="font-size: 1.2rem; color: #6B7897; max-width: 500px; margin: 0 auto 30px;">
上載一張你鍾意嘅相片,我哋嘅魔法師會幫你變出一個好好玩嘅故事!
</p>
<div>
<span class="emoji">🚀</span>
<span class="emoji">🦄</span>
<span class="emoji">🔮</span>
<span class="emoji">🌈</span>
</div>
</div>
"""


def _remove_file_quietly(path):
    """Best-effort delete of *path*; falsy/missing paths and OS errors are ignored."""
    if path and os.path.exists(path):
        try:
            os.remove(path)
        except OSError:
            pass


def _reset_story_state():
    """Delete any generated audio files and clear every per-upload result."""
    _remove_file_quietly(st.session_state.get('audio_path_zh'))
    _remove_file_quietly(st.session_state.get('audio_path_en'))
    st.session_state.scenario = None
    st.session_state.scenario_zh = None
    st.session_state.story = None
    st.session_state.story_zh = None
    st.session_state.audio_generated_zh = False
    st.session_state.audio_path_zh = None
    st.session_state.audio_generated_en = False
    st.session_state.audio_path_en = None
    st.session_state.upload_id = None


def _save_upload_to_temp(file_bytes, original_name):
    """Write the uploaded bytes to a fresh temp file and return its path.

    The user-supplied filename contributes only its extension. Writing under
    the raw upload name (as before) could overwrite files in the working
    directory and would collide between concurrent sessions.
    """
    suffix = os.path.splitext(original_name)[1]
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        tmp.write(file_bytes)
        return tmp.name


def _generate_results(image_path):
    """Caption, translate, write and translate the story, storing results in session state.

    Each result is rendered as soon as it is ready. A status line and a
    progress bar live in ``progress_placeholder`` and are removed at the end,
    even if a stage raises.
    """
    # Create the status line and progress bar once. Re-entering
    # progress_placeholder.container() for every message (as before) replaces
    # the placeholder's content and leaves progress_bar pointing at a
    # discarded element.
    with progress_placeholder.container():
        status_text = st.empty()
        progress_bar = st.progress(0)
    try:
        # Stage 1: Image to Text
        with st.container():
            st.markdown("<h3><span class='stage-icon'>🔍</span> 圖片解讀中</h3>", unsafe_allow_html=True)
            status_text.write("正在分析圖片...")
            st.session_state.scenario = img2text(image_path)
            progress_bar.progress(33)
            st.text("英文描述: " + st.session_state.scenario)

            status_text.write("正在翻譯...")
            st.session_state.scenario_zh = translate_to_chinese(st.session_state.scenario)
            progress_bar.progress(66)
            st.text("中文描述: " + st.session_state.scenario_zh)

        # Stage 2: Text to Story
        with st.container():
            st.markdown("<h3><span class='stage-icon'>📝</span> 故事創作中</h3>", unsafe_allow_html=True)
            status_text.write("正在創作故事...")
            st.session_state.story = text2story(st.session_state.scenario)
            progress_bar.progress(85)
            st.text("英文故事: " + st.session_state.story)

            status_text.write("正在翻譯故事...")
            st.session_state.story_zh = translate_to_chinese(st.session_state.story)
            progress_bar.progress(100)
            st.text("中文故事: " + st.session_state.story_zh)
    finally:
        # Clear progress indicator
        progress_placeholder.empty()


def _show_saved_results():
    """Re-render the caption and story stored in session state (no model calls)."""
    with st.container():
        st.markdown("<h3><span class='stage-icon'>🔍</span> 圖片解讀中</h3>", unsafe_allow_html=True)
        st.text("英文描述: " + st.session_state.scenario)
        st.text("中文描述: " + st.session_state.scenario_zh)
    with st.container():
        st.markdown("<h3><span class='stage-icon'>📝</span> 故事創作中</h3>", unsafe_allow_html=True)
        st.text("英文故事: " + st.session_state.story)
        st.text("中文故事: " + st.session_state.story_zh)


def _render_audio_player(lang, button_label, spinner_text, synthesize, text, title, error_text):
    """Show a play button; when clicked, synthesize *text* once and play it.

    The result is cached in session state under ``audio_generated_<lang>`` /
    ``audio_path_<lang>``. Later clicks reuse the file, and a failed attempt
    is retried on the next click.
    """
    if not st.button(button_label):
        return
    generated_key = f"audio_generated_{lang}"
    path_key = f"audio_path_{lang}"
    if not st.session_state[generated_key]:
        with st.spinner(spinner_text):
            # The TTS helpers are coroutines; run one to completion here.
            audio_result = asyncio.run(synthesize(text))
        st.session_state[path_key] = audio_result['path']
        st.session_state[generated_key] = audio_result['success']

    audio_path = st.session_state[path_key]
    if audio_path and os.path.exists(audio_path):
        with open(audio_path, "rb") as audio_file:
            audio_bytes = audio_file.read()
        # Each st.markdown call is its own element, so raw HTML cannot wrap the
        # st.audio widget; render the styled title as one well-formed block.
        st.markdown(
            f"<div class='audio-container'><div class='audio-title'>{title}</div></div>",
            unsafe_allow_html=True,
        )
        st.audio(audio_bytes, format="audio/mp3")
    else:
        st.error(error_text)


if uploaded_file is not None:
    bytes_data = uploaded_file.getvalue()

    # Detect a *different* image by content: previously only "no scenario yet"
    # triggered generation, so a replacement upload kept showing (and playing)
    # the old story and audio.
    upload_id = hashlib.sha256(bytes_data).hexdigest()
    if st.session_state.get('upload_id') != upload_id:
        _reset_story_state()
        st.session_state.upload_id = upload_id

    # Display image
    st.image(uploaded_file, use_column_width=True)

    # story_zh is written last, so a run that failed part-way is regenerated
    # instead of crashing later on a None concatenation.
    if st.session_state.story_zh is None:
        # The pipelines take a file path; the temp copy is removed even if a stage fails.
        temp_file_path = _save_upload_to_temp(bytes_data, uploaded_file.name)
        try:
            _generate_results(temp_file_path)
        finally:
            _remove_file_quietly(temp_file_path)
    else:
        _show_saved_results()

    # Stage 3: Story to Audio data
    with st.container():
        st.markdown("<h3><span class='stage-icon'>🔊</span> 故事準備朗讀中</h3>", unsafe_allow_html=True)
        col1, col2 = st.columns(2)
        with col1:
            _render_audio_player(
                lang="en",
                button_label="🔊 Play Story in English",
                spinner_text="Generating English audio...",
                synthesize=text2audio_english,
                text=st.session_state.story,
                title="English Story",
                error_text="Sorry! Please try again.",
            )
        with col2:
            _render_audio_player(
                lang="zh",
                button_label="🔊 播放廣東話故事",
                spinner_text="正在準備廣東話語音...",
                synthesize=text2audio_cantonese,
                text=st.session_state.story_zh,
                title="廣東話故事",
                error_text="哎呀!再試多次啦!",
            )
else:
    # No file uploaded: drop generated audio files and all cached results.
    _reset_story_state()
    # Welcome message in Cantonese
    st.markdown(_WELCOME_HTML, unsafe_allow_html=True)