import streamlit as st
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    GPT2LMHeadModel,
    TextGenerationPipeline,
    pipeline,
)
import edge_tts
import asyncio
import os
import tempfile
# Initialize session state for storing data
if 'scenario' not in st.session_state:
    st.session_state.scenario = None
if 'scenario_zh' not in st.session_state:
    st.session_state.scenario_zh = None
if 'story' not in st.session_state:
    st.session_state.story = None
if 'story_zh' not in st.session_state:
    st.session_state.story_zh = None
if 'audio_generated_zh' not in st.session_state:
    st.session_state.audio_generated_zh = False
if 'audio_path_zh' not in st.session_state:
    st.session_state.audio_path_zh = None
if 'audio_generated_en' not in st.session_state:
    st.session_state.audio_generated_en = False
if 'audio_path_en' not in st.session_state:
    st.session_state.audio_path_en = None
# Function definitions
# img2text: generate an English caption for the uploaded image
def img2text(url):
    image_to_text_model = pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")
    text = image_to_text_model(url)[0]["generated_text"]
    return text
# Translation function EN to ZH
def translate_to_chinese(text):
    translator = pipeline("translation", model="steve-tong/opus-mt-en-zh-hk")
    translation = translator(text)[0]["translation_text"]
    return translation
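# Optional sketch (not wired into the functions above): each call to img2text /
# translate_to_chinese re-creates its pipeline, so the models are reloaded on every
# Streamlit rerun. Caching the loaders with st.cache_resource would keep them in
# memory. The helper names below are illustrative additions, not part of the app.
@st.cache_resource
def load_caption_pipeline():
    # Loaded once per session and reused across reruns
    return pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")

@st.cache_resource
def load_translation_pipeline():
    # Loaded once per session and reused across reruns
    return pipeline("translation", model="steve-tong/opus-mt-en-zh-hk")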
# Generate a story from the caption text
def text2story(text):
    try:
        # Use aspis/gpt2-genre-story-generation, which generates better stories
        model_name = "aspis/gpt2-genre-story-generation"
        model = GPT2LMHeadModel.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        generator = TextGenerationPipeline(model=model, tokenizer=tokenizer)
        # Input should be of the format "<BOS> <genre token> optional starter text"
        input_prompt = f"<BOS> <adventure> {text}"
        story = generator(input_prompt, max_length=100, do_sample=True,
                          repetition_penalty=1.5, temperature=1.2,
                          top_p=0.95, top_k=50)
        # Remove the "<BOS> <adventure> " prompt prefix; str.strip() with that string
        # would strip matching characters from both ends and could clip the story text
        return story[0]['generated_text'].replace("<BOS> <adventure> ", "", 1).strip()
    except Exception:
        # Fallback: use plain openai-community/gpt2 if the genre model fails
        fallback_model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2")
        tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2")
        inputs = tokenizer(text, return_tensors="pt")
        output_ids = fallback_model.generate(
            inputs.input_ids,
            min_length=50,
            max_new_tokens=100,
            temperature=0.7,
            top_p=0.9,
            top_k=40,
            repetition_penalty=1.2,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id
        )
        fallback_story = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        return fallback_story
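# Example (hypothetical caption): text2story("a dog running on the beach") builds the
# prompt "<BOS> <adventure> a dog running on the beach" and returns roughly 100 tokens
# of adventure-genre continuation with the prompt prefix removed.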
def load_css(css_file):
    with open(css_file) as f:
        st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
# Text to audio using edge_tts for Cantonese audio
async def text2audio_cantonese(text):
    try:
        # Use a Cantonese voice from edge-tts
        voice = "zh-HK-HiuMaanNeural"  # Female Cantonese voice
        # Alternative: "zh-HK-WanLungNeural" for a male voice
        # Create a temporary file
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
        temp_file.close()
        # Configure edge-tts to save to the file path
        communicate = edge_tts.Communicate(text, voice)
        await communicate.save(temp_file.name)
        # Return the path to the audio file
        return {
            'path': temp_file.name,
            'success': True
        }
    except Exception as e:
        st.error(f"中文音頻製作出左問題: {str(e)}")
        return {
            'path': None,
            'success': False
        }
# Text to audio using edge_tts for English audio
async def text2audio_english(text):
    try:
        # Use an English voice from edge-tts
        voice = "en-US-AriaNeural"  # Female English voice
        # Alternative: "en-US-GuyNeural" for a male voice
        # Create a temporary file
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
        temp_file.close()
        # Configure edge-tts to save to the file path
        communicate = edge_tts.Communicate(text, voice)
        await communicate.save(temp_file.name)
        # Return the path to the audio file
        return {
            'path': temp_file.name,
            'success': True
        }
    except Exception as e:
        st.error(f"English audio generation error: {str(e)}")
        return {
            'path': None,
            'success': False
        }
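# Optional refactor sketch (not used below): the two TTS helpers differ only in the
# voice name and the error message, so they could share a single coroutine. The
# function name and parameter defaults here are illustrative, not from the original.
async def text2audio(text, voice="zh-HK-HiuMaanNeural", error_message="Audio generation error"):
    try:
        # Synthesize speech to a temporary MP3 file and return its path
        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp3")
        temp_file.close()
        communicate = edge_tts.Communicate(text, voice)
        await communicate.save(temp_file.name)
        return {'path': temp_file.name, 'success': True}
    except Exception as e:
        st.error(f"{error_message}: {str(e)}")
        return {'path': None, 'success': False}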
# Apply custom CSS for modern, stylish kid-friendly UI
st.set_page_config(page_title="歡迎嚟到 ISOM 5240 - 故事魔法師!", page_icon="✨", layout="wide")
load_css('styles.css')
# App header with Cantonese
st.title("")
st.markdown("<div class='welcome-message'> <div class='banner-container'><div class='magician-banner'><div class='magic-hat'>🎩</div><div class='magic-elements wand-left'>🪄</div><div class='magic-elements wand-right'>🪄</div><span class='sparkle spark1'>✨</span><span class='sparkle spark2'>✨</span><span class='sparkle spark3'>✨</span><span class='sparkle spark4'>✨</span><span class='sparkle spark5'>✨</span><h1 class='title' style='font-color: white !important;'>歡迎嚟到 ISOM 5240 - 故事魔法師!</h1></div></div><p style='font-size: 1.2rem; color: #6B7897; max-width: 500px; margin: 0 auto 30px;'>上載一張你鍾意嘅相片,我哋嘅魔法師會幫你變出一個好好玩嘅故事!</p><div><span class='emoji'>🚀</span><span class='emoji'>🦄</span> <span class='emoji'>🔮</span><span class='emoji'>🌈</span></div></div>", unsafe_allow_html=True)
# Add a progress indicator for model loading
progress_placeholder = st.empty()
# File uploader with Cantonese
with st.container():
    uploaded_file = st.file_uploader(
        "",
        type=["jpg", "jpeg", "png"],  # Limit file types
        key="upload"
    )
    if uploaded_file is not None:
        st.success("上載成功!")
    else:
        st.info("請選擇一張畫作上載,格式必須係 JPG、JPEG 或 PNG! 最大 200 MB !")
if uploaded_file is not None:
    # Save uploaded file
    bytes_data = uploaded_file.getvalue()
    temp_file_path = uploaded_file.name
    with open(temp_file_path, "wb") as file:
        file.write(bytes_data)
    # Display image
    st.image(uploaded_file, use_container_width=True)
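    # Note (sketch only, not applied): writing to uploaded_file.name puts the image in
    # the working directory under its original name; a tempfile-based variant such as
    #     with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(uploaded_file.name)[1]) as tmp:
    #         tmp.write(bytes_data)
    #         temp_file_path = tmp.name
    # would avoid name clashes between users.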
    # Run the pipeline only if this image hasn't been processed yet (no cached caption)
    if st.session_state.scenario is None:
        # Stage 1: Image to Text
        with st.container():
            st.markdown("<h3><span class='stage-icon'>🔍</span> 以下係我哋嘅魔術師覺得你嘅畫作係關於 : </h3>", unsafe_allow_html=True)
            with progress_placeholder.container():
                st.write("正在分析圖片...")
                progress_bar = st.progress(0)
            # Generate the English caption
            st.session_state.scenario = img2text(temp_file_path)
            progress_bar.progress(33)
            # Display English caption
            st.text("英文描述: " + st.session_state.scenario)
            # Translate the caption to Chinese
            with progress_placeholder.container():
                st.write("正在翻譯...")
            st.session_state.scenario_zh = translate_to_chinese(st.session_state.scenario)
            progress_bar.progress(66)
            # Display Chinese caption
            st.text("中文描述: " + st.session_state.scenario_zh)
        # Stage 2: Text to Story
        with st.container():
            st.markdown("<h3><span class='stage-icon'>📝</span> 以下係我哋嘅魔術師正根據你嘅畫作嘅故事: </h3>", unsafe_allow_html=True)
            with progress_placeholder.container():
                st.write("正在創作故事...")
            # Generate the English story
            st.session_state.story = text2story(st.session_state.scenario)
            progress_bar.progress(85)
            # Display English story
            st.text("英文故事: " + st.session_state.story)
            # Translate the story to Chinese
            with progress_placeholder.container():
                st.write("正在翻譯故事...")
            st.session_state.story_zh = translate_to_chinese(st.session_state.story)
            progress_bar.progress(100)
            # Display Chinese story
            st.text("中文故事: " + st.session_state.story_zh)
        # Clear progress indicator
        progress_placeholder.empty()
    else:
        # Display saved results from session state
        with st.container():
            st.markdown("<h3><span class='stage-icon'>🔍</span> 以下係我哋嘅魔術師覺得你嘅畫作係關於 : </h3>", unsafe_allow_html=True)
            st.text("英文描述: " + st.session_state.scenario)
            st.text("中文描述: " + st.session_state.scenario_zh)
        with st.container():
            st.markdown("<h3><span class='stage-icon'>📝</span> 以下係我哋嘅魔術師正根據你嘅畫作嘅故事: </h3>", unsafe_allow_html=True)
            st.text("英文故事: " + st.session_state.story)
            st.text("中文故事: " + st.session_state.story_zh)
    # Stage 3: Story to Audio
    with st.container():
        st.markdown("<h3><span class='stage-icon'>🔊</span> 你嘅故事準備好俾我哋嘅魔術師讀出嚟喇!</h3>", unsafe_allow_html=True)
        # Create two columns for English and Cantonese buttons
        col1, col2 = st.columns(2)
        # English audio button
        with col1:
            if st.button("🔊 Let our magician read your story aloud in English!"):
                # Only generate audio if not already done
                if not st.session_state.audio_generated_en:
                    # Need to run the async function with asyncio
                    audio_result = asyncio.run(text2audio_english(st.session_state.story))
                    st.session_state.audio_path_en = audio_result['path']
                    st.session_state.audio_generated_en = audio_result['success']
                # Play the audio
                if st.session_state.audio_path_en and os.path.exists(st.session_state.audio_path_en):
                    with open(st.session_state.audio_path_en, "rb") as audio_file:
                        audio_bytes = audio_file.read()
                    st.audio(audio_bytes, format="audio/mp3")
                else:
                    st.error("Sorry! Please try again.")
        # Cantonese audio button
        with col2:
            if st.button("🔊 你嘅故事已經準備好俾我哋嘅魔術師用廣東話讀出嚟喇!"):
                # Only generate audio if not already done
                if not st.session_state.audio_generated_zh:
                    # Need to run the async function with asyncio
                    audio_result = asyncio.run(text2audio_cantonese(st.session_state.story_zh))
                    st.session_state.audio_path_zh = audio_result['path']
                    st.session_state.audio_generated_zh = audio_result['success']
                # Play the audio
                if st.session_state.audio_path_zh and os.path.exists(st.session_state.audio_path_zh):
                    with open(st.session_state.audio_path_zh, "rb") as audio_file:
                        audio_bytes = audio_file.read()
                    st.audio(audio_bytes, format="audio/mp3")
                else:
                    st.error("哎呀!再試多次啦!")
    # Cleanup: remove the temporary image file when the user is done
    if os.path.exists(temp_file_path):
        os.remove(temp_file_path)
else:
    # Clear session state when no file is uploaded
    # Also clean up any temporary audio files
    if st.session_state.audio_path_zh and os.path.exists(st.session_state.audio_path_zh):
        try:
            os.remove(st.session_state.audio_path_zh)
        except OSError:
            pass
    if st.session_state.audio_path_en and os.path.exists(st.session_state.audio_path_en):
        try:
            os.remove(st.session_state.audio_path_en)
        except OSError:
            pass
    st.session_state.scenario = None
    st.session_state.scenario_zh = None
    st.session_state.story = None
    st.session_state.story_zh = None
    st.session_state.audio_generated_zh = False
    st.session_state.audio_path_zh = None
    st.session_state.audio_generated_en = False
    st.session_state.audio_path_en = None