Spaces:
Sleeping
Sleeping
import streamlit as st | |
from transformers import pipeline | |
import torch | |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline , AutoModelForCausalLM , AutoProcessor, AutoModel , AutoModelForSequenceClassification , TextGenerationPipeline, GPT2LMHeadModel | |
import edge_tts | |
import asyncio | |
import os | |
import io | |
import tempfile | |
# Seed st.session_state on the first run of each browser session; later reruns
# find the keys already present and keep whatever values they hold.
_SESSION_DEFAULTS = {
    'scenario': None,  # English caption of the uploaded image
    'scenario_zh': None,  # Chinese translation of the caption
    'story': None,  # English story generated from the caption
    'story_zh': None,  # Chinese translation of the story
    'audio_generated_zh': False,  # True once Cantonese audio was produced
    'audio_path_zh': None,  # temp MP3 path of the Cantonese narration
    'audio_generated_en': False,  # True once English audio was produced
    'audio_path_en': None,  # temp MP3 path of the English narration
}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# ---------------------------------------------------------------------------
# Model-backed helpers
# ---------------------------------------------------------------------------
@st.cache_resource(show_spinner=False)
def _load_caption_pipeline():
    """Build the BLIP image-captioning pipeline.

    Cached with st.cache_resource so the model is loaded once and then reused
    across reruns and sessions, instead of being rebuilt on every caption.
    (A plain module-level cache would not help: Streamlit re-executes the
    script on every rerun.)
    """
    return pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")


# img2text
def img2text(url):
    """Return an English caption for an image.

    Args:
        url: Image reference accepted by the Hugging Face image-to-text
            pipeline; this app passes a local file path.

    Returns:
        The ``generated_text`` of the first result produced by the pipeline.
    """
    captioner = _load_caption_pipeline()
    return captioner(url)[0]["generated_text"]
@st.cache_resource(show_spinner=False)
def _load_translation_pipeline():
    """Build the steve-tong/opus-mt-en-zh-hk translation pipeline once and reuse it.

    The app translates twice per image (caption and story), so rebuilding the
    pipeline per call would reload the model each time.
    """
    return pipeline("translation", model="steve-tong/opus-mt-en-zh-hk")


# Translation function EN to ZH
def translate_to_chinese(text):
    """Translate English *text* to Chinese.

    Args:
        text: English source text (a caption or a generated story).

    Returns:
        The ``translation_text`` of the first result produced by the pipeline.
    """
    translator = _load_translation_pipeline()
    return translator(text)[0]["translation_text"]
# Generate a story according to the caption text.

# Prompt markers for aspis/gpt2-genre-story-generation, whose input format is
# "<BOS> <Genre token> Optional starter text".
_STORY_PROMPT_PREFIX = "<BOS> <adventure>"


@st.cache_resource(show_spinner=False)
def _load_story_generator():
    """Load the genre-conditioned GPT-2 story model once and wrap it in a pipeline."""
    model_name = "aspis/gpt2-genre-story-generation"
    model = GPT2LMHeadModel.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    return TextGenerationPipeline(model=model, tokenizer=tokenizer)


@st.cache_resource(show_spinner=False)
def _load_fallback_story_model():
    """Load plain GPT-2 as a (model, tokenizer) pair for the fallback path."""
    model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
    tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
    return model, tokenizer


def _strip_prompt_prefix(generated, prefix=_STORY_PROMPT_PREFIX):
    """Remove *prefix* from the start of *generated* (as a whole string) and trim whitespace.

    str.strip(prefix) must not be used here: it treats its argument as a set of
    characters and would also eat letters such as "a", "d", "e", "n", "t" from
    both ends of the story itself.
    """
    if generated.startswith(prefix):
        generated = generated[len(prefix):]
    return generated.strip()


def text2story(text):
    """Write a short adventure story that continues *text*.

    The genre-conditioned GPT-2 model is tried first. If anything on that
    path raises (download, load or generation), plain GPT-2 sampling is used
    instead.

    Args:
        text: Starter text for the story; this app passes the image caption.

    Returns:
        The generated story as a string. On the primary path the
        "<BOS> <adventure>" markers are removed as an exact prefix. This
        assumes the pipeline echoes the prompt, which is its default
        return_full_text behaviour.
    """
    try:
        generator = _load_story_generator()
        input_prompt = f"{_STORY_PROMPT_PREFIX} {text}"
        story = generator(input_prompt, max_length=100, do_sample=True,
                          repetition_penalty=1.5, temperature=1.2,
                          top_p=0.95, top_k=50)
        return _strip_prompt_prefix(story[0]['generated_text'])
    except Exception:
        # Deliberate best-effort boundary: any failure of the specialised model
        # falls back to plain openai-community/gpt2 rather than breaking the page.
        fallback_model, tokenizer = _load_fallback_story_model()
        fallback_prompt = f"{text}"
        inputs = tokenizer(fallback_prompt, return_tensors="pt")
        output_ids = fallback_model.generate(
            inputs.input_ids,
            attention_mask=inputs.attention_mask,  # explicit mask: pad == eos for GPT-2
            min_length=50,
            max_new_tokens=100,
            temperature=0.7,
            top_p=0.9,
            top_k=40,
            repetition_penalty=1.2,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
        return tokenizer.decode(output_ids[0], skip_special_tokens=True)
def load_css(css_file):
    """Inject the stylesheet at *css_file* into the page as an inline <style> tag.

    The file is read as UTF-8 explicitly, so non-ASCII content (e.g. emoji in
    CSS `content:` rules) does not depend on the server's locale encoding.

    Args:
        css_file: Path to the CSS file, resolved relative to the working directory.

    Raises:
        OSError: If the file cannot be opened (e.g. FileNotFoundError).
    """
    with open(css_file, encoding="utf-8") as css_handle:
        css = css_handle.read()
    st.markdown(f'<style>{css}</style>', unsafe_allow_html=True)
# Text to audio using edge_tts for Cantonese audio
async def text2audio_cantonese(text):
    """Synthesize *text* to an MP3 temp file with a zh-HK edge-tts voice.

    Args:
        text: Text to read aloud (the Chinese story in this app).

    Returns:
        A dict with keys:
            'path': path of the written MP3, or None on failure;
            'success': True if the audio was written, else False.
        On success the caller owns the file and is responsible for deleting it.
        On failure an error is shown via st.error, and any partially created
        temp file is removed so it does not leak.
    """
    voice = "zh-HK-HiuMaanNeural"  # Female Cantonese voice
    # Alternative: "zh-HK-WanLungNeural" for male voice
    audio_path = None
    try:
        # delete=False: the file must outlive this function so it can be played later.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
            audio_path = temp_file.name
        communicate = edge_tts.Communicate(text, voice)
        await communicate.save(audio_path)
        return {
            'path': audio_path,
            'success': True
        }
    except Exception as e:
        st.error(f"中文音頻製作出左問題: {str(e)}")
        if audio_path is not None:
            try:
                os.remove(audio_path)
            except OSError:
                pass  # best effort; the user already sees the error above
        return {
            'path': None,
            'success': False
        }
# Text to audio using edge_tts for English audio
async def text2audio_english(text):
    """Synthesize *text* to an MP3 temp file with an en-US edge-tts voice.

    Args:
        text: Text to read aloud (the English story in this app).

    Returns:
        A dict with keys:
            'path': path of the written MP3, or None on failure;
            'success': True if the audio was written, else False.
        On success the caller owns the file and is responsible for deleting it.
        On failure an error is shown via st.error, and any partially created
        temp file is removed so it does not leak.
    """
    voice = "en-US-AriaNeural"  # Female English voice
    # Alternative: "en-US-GuyNeural" for male voice
    audio_path = None
    try:
        # delete=False: the file must outlive this function so it can be played later.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as temp_file:
            audio_path = temp_file.name
        communicate = edge_tts.Communicate(text, voice)
        await communicate.save(audio_path)
        return {
            'path': audio_path,
            'success': True
        }
    except Exception as e:
        st.error(f"English audio generation error: {str(e)}")
        if audio_path is not None:
            try:
                os.remove(audio_path)
            except OSError:
                pass  # best effort; the user already sees the error above
        return {
            'path': None,
            'success': False
        }
# Page-wide settings, then the custom CSS for the kid-friendly UI.
st.set_page_config(page_title="歡迎嚟到 ISOM 5240 - 故事魔法師!", page_icon="✨", layout="wide")
# styles.css is resolved relative to the working directory.
load_css('styles.css')

# Welcome banner markup (split across literals only for readability; the
# concatenated string is unchanged).
# NOTE(review): `font-color` is not a CSS property, so the inline style on the
# <h1> has no effect; `color` was probably intended - confirm against styles.css.
WELCOME_BANNER_HTML = (
    "<div class='welcome-message'> <div class='banner-container'>"
    "<div class='magician-banner'><div class='magic-hat'>🎩</div>"
    "<div class='magic-elements wand-left'>🪄</div>"
    "<div class='magic-elements wand-right'>🪄</div>"
    "<span class='sparkle spark1'>✨</span>"
    "<span class='sparkle spark2'>✨</span>"
    "<span class='sparkle spark3'>✨</span>"
    "<span class='sparkle spark4'>✨</span>"
    "<span class='sparkle spark5'>✨</span>"
    "<h1 class='title' style='font-color: white !important;'>歡迎嚟到 ISOM 5240 - 故事魔法師!</h1>"
    "</div></div>"
    "<p style='font-size: 1.2rem; color: #6B7897; max-width: 500px; margin: 0 auto 30px;'>"
    "上載一張你鍾意嘅相片,我哋嘅魔法師會幫你變出一個好好玩嘅故事!</p>"
    "<div><span class='emoji'>🚀</span><span class='emoji'>🦄</span> "
    "<span class='emoji'>🔮</span><span class='emoji'>🌈</span></div></div>"
)

# App header with Cantonese
st.title("")
st.markdown(WELCOME_BANNER_HTML, unsafe_allow_html=True)

# Placeholder where progress is shown while the models run.
progress_placeholder = st.empty()

# File uploader with Cantonese. The label is hidden rather than empty, so it
# looks the same but screen readers get a name and Streamlit does not warn
# about an empty label.
with st.container():
    uploaded_file = st.file_uploader(
        "上載畫作",
        type=["jpg", "jpeg", "png"],  # Limit file types
        key="upload",
        label_visibility="collapsed",
    )
    if uploaded_file is not None:
        st.success("上載成功!")
    else:
        st.info("請選擇一張畫作上載,格式必須係 JPG、JPEG 或 PNG! 最大 200 MB !")
# Section headings, shared by the "just generated" and "restored from session" views.
STAGE1_HEADING = "<h3><span class='stage-icon'>🔍</span> 以下係我哋嘅魔術師覺得你嘅畫作係關於 : </h3>"
STAGE2_HEADING = "<h3><span class='stage-icon'>📝</span> 以下係我哋嘅魔術師正根據你嘅畫作嘅故事: </h3>"
STAGE3_HEADING = "<h3><span class='stage-icon'>🔊</span> 你嘅故事準備好俾我哋嘅魔術師讀出嚟喇!</h3>"

# Every session_state entry that belongs to one uploaded image.
_PER_IMAGE_DEFAULTS = {
    "scenario": None,
    "scenario_zh": None,
    "story": None,
    "story_zh": None,
    "audio_generated_zh": False,
    "audio_path_zh": None,
    "audio_generated_en": False,
    "audio_path_en": None,
}


def _remove_quietly(path):
    """Delete *path* if set; a file that is already gone or locked must not break the page."""
    if path:
        try:
            os.remove(path)
        except OSError:
            pass


def _reset_results():
    """Delete any generated audio files and restore the per-image session_state defaults."""
    _remove_quietly(st.session_state.get("audio_path_zh"))
    _remove_quietly(st.session_state.get("audio_path_en"))
    for key, default in _PER_IMAGE_DEFAULTS.items():
        st.session_state[key] = default


def _upload_identity(upload):
    """Return a value that changes when the user uploads a different file."""
    # UploadedFile.file_id is present in recent Streamlit releases; fall back to
    # name/size if it is missing. Assumes those attributes exist - confirm on upgrade.
    return getattr(upload, "file_id", None) or (upload.name, upload.size)


def _generate_results(image_path):
    """Caption, translate, write and translate the story, showing each step as it finishes.

    Results are stored in st.session_state so later reruns can re-display them
    without running the models again.
    """
    # One status line and ONE progress bar live in the placeholder. The status
    # line is an inner st.empty() that is rewritten in place; re-entering
    # progress_placeholder.container() would replace (and drop) the bar.
    with progress_placeholder.container():
        status = st.empty()
        progress_bar = st.progress(0)

    # Stage 1: Image to Text
    with st.container():
        st.markdown(STAGE1_HEADING, unsafe_allow_html=True)
        status.write("正在分析圖片...")
        st.session_state.scenario = img2text(image_path)
        progress_bar.progress(33)
        st.text("英文描述: " + st.session_state.scenario)

        status.write("正在翻譯...")
        st.session_state.scenario_zh = translate_to_chinese(st.session_state.scenario)
        progress_bar.progress(66)
        st.text("中文描述: " + st.session_state.scenario_zh)

    # Stage 2: Text to Story
    with st.container():
        st.markdown(STAGE2_HEADING, unsafe_allow_html=True)
        status.write("正在創作故事...")
        st.session_state.story = text2story(st.session_state.scenario)
        progress_bar.progress(85)
        st.text("英文故事: " + st.session_state.story)

        status.write("正在翻譯故事...")
        st.session_state.story_zh = translate_to_chinese(st.session_state.story)
        progress_bar.progress(100)
        st.text("中文故事: " + st.session_state.story_zh)

    # Clear progress indicator
    progress_placeholder.empty()


def _show_saved_results():
    """Re-display the captions and stories already stored in session_state."""
    with st.container():
        st.markdown(STAGE1_HEADING, unsafe_allow_html=True)
        st.text("英文描述: " + st.session_state.scenario)
        st.text("中文描述: " + st.session_state.scenario_zh)
    with st.container():
        st.markdown(STAGE2_HEADING, unsafe_allow_html=True)
        st.text("英文故事: " + st.session_state.story)
        st.text("中文故事: " + st.session_state.story_zh)


def _play_story_audio(synthesize, text, path_key, flag_key, error_message):
    """Synthesize *text* at most once per image, then play the resulting MP3.

    Args:
        synthesize: One of the async text2audio_* functions; it returns a dict
            with 'path' and 'success' keys.
        text: Story text to read aloud.
        path_key / flag_key: session_state keys caching the MP3 path and success flag.
        error_message: Shown when no playable audio file exists.
    """
    if not st.session_state.get(flag_key):
        # Streamlit's script thread has no running event loop, so asyncio.run is used.
        audio_result = asyncio.run(synthesize(text))
        st.session_state[path_key] = audio_result['path']
        st.session_state[flag_key] = audio_result['success']
    audio_path = st.session_state.get(path_key)
    if audio_path and os.path.exists(audio_path):
        with open(audio_path, "rb") as audio_file:
            audio_bytes = audio_file.read()
        st.audio(audio_bytes, format="audio/mp3")
    else:
        st.error(error_message)


if uploaded_file is not None:
    # A different upload invalidates everything generated for the previous image.
    # Without this, stale results would be shown for the new picture.
    upload_id = _upload_identity(uploaded_file)
    if st.session_state.get("upload_id") != upload_id:
        _reset_results()
        st.session_state.upload_id = upload_id

    # Display image
    st.image(uploaded_file, use_container_width=True)

    # story_zh is written last, so None means the pipeline has not completed for this image.
    if st.session_state.get("story_zh") is None:
        # The captioning pipeline is given a file path. Write the upload to a
        # private temp file (never to the user-supplied name in the working
        # directory) and always delete it, even if a model call raises.
        suffix = os.path.splitext(uploaded_file.name)[1]
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as image_file:
            image_file.write(uploaded_file.getvalue())
            image_path = image_file.name
        try:
            _generate_results(image_path)
        finally:
            _remove_quietly(image_path)
    else:
        _show_saved_results()

    # Stage 3: Story to Audio data
    with st.container():
        st.markdown(STAGE3_HEADING, unsafe_allow_html=True)
        # Two columns: English and Cantonese read-aloud buttons
        col1, col2 = st.columns(2)
        with col1:
            if st.button("🔊 Play Story in English that read aloud by our magician !"):
                _play_story_audio(text2audio_english, st.session_state.story,
                                  "audio_path_en", "audio_generated_en",
                                  "Sorry! Please try again.")
        with col2:
            if st.button("🔊 你嘅故事已經準備好俾我哋嘅魔術師用廣東話讀出嚟喇!"):
                _play_story_audio(text2audio_cantonese, st.session_state.story_zh,
                                  "audio_path_zh", "audio_generated_zh",
                                  "哎呀!再試多次啦!")
else:
    # No file uploaded: drop the previous image's results and audio temp files.
    _reset_results()
    st.session_state.upload_id = None