xujinheng666 committed on
Commit
b736fc0
·
verified ·
1 Parent(s): d9299d7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -33
app.py CHANGED
@@ -10,53 +10,64 @@ def img2text(url):
10
 
11
  # text2story
12
  def text2story(text):
13
- story_generator = pipeline("text-generation", model="aspis/gpt2-genre-story-generation") # Corrected pipeline initialization
14
- story_text = story_generator(text, max_length=150, num_return_sequences=1) # Pass parameters here
15
- return story_text[0]["generated_text"] # Extract generated text
16
 
17
  # text2audio
18
  def text2audio(story_text):
19
- tts_model = pipeline("text-to-speech", model="facebook/mms-tts-eng") # Initialize pipeline
20
- audio_data = tts_model(story_text) # Generate audio
21
- return audio_data # Return generated audio
22
-
23
 
24
- #main part
 
25
  def main():
26
- st.set_page_config(page_title="Your Image to Audio Story",
27
- page_icon="🦜")
28
  st.header("Turn Your Image to Audio Story")
 
 
 
 
 
 
 
 
29
  uploaded_file = st.file_uploader("Select an Image...")
30
-
31
- if uploaded_file is not None:
32
  print(uploaded_file)
33
  bytes_data = uploaded_file.getvalue()
34
  with open(uploaded_file.name, "wb") as file:
35
  file.write(bytes_data)
36
-
37
- st.image(uploaded_file, caption="Uploaded Image",
38
- use_column_width=True)
39
-
40
- #Stage 1: Image to Text
41
  st.text('Processing img2text...')
42
- scenario = img2text(uploaded_file.name)
43
- st.write(scenario)
44
-
45
- #Stage 2: Text to Story
46
  st.text('Generating a story...')
47
- story = text2story(scenario)
48
- st.write(story)
49
-
50
- #Stage 3: Story to Audio data
51
  st.text('Generating audio data...')
52
- audio_data =text2audio(story)
 
 
 
 
 
53
 
54
- # Play button
55
- if st.button("Play Audio"):
56
- st.audio(audio_data['audio'],
57
- format="audio/wav",
58
- start_time=0,
59
- sample_rate = audio_data['sampling_rate'])
60
 
61
  if __name__ == "__main__":
62
- main()
 
10
 
11
  # text2story
12
  def text2story(text):
13
+ story_generator = pipeline("text-generation", model="aspis/gpt2-genre-story-generation")
14
+ story_text = story_generator(text, max_length=150, num_return_sequences=1)
15
+ return story_text[0]["generated_text"]
16
 
17
  # text2audio
18
  def text2audio(story_text):
19
+ tts_model = pipeline("text-to-speech", model="facebook/mms-tts-eng")
20
+ audio_data = tts_model(story_text)
21
+ return audio_data
 
22
 
23
+
24
+ # Main part
25
  def main():
26
+ st.set_page_config(page_title="Your Image to Audio Story", page_icon="🦜")
 
27
  st.header("Turn Your Image to Audio Story")
28
+
29
+ if "scenario" not in st.session_state:
30
+ st.session_state.scenario = None
31
+ if "story" not in st.session_state:
32
+ st.session_state.story = None
33
+ if "audio_data" not in st.session_state:
34
+ st.session_state.audio_data = None
35
+
36
  uploaded_file = st.file_uploader("Select an Image...")
37
+
38
+ if uploaded_file is not None and st.session_state.scenario is None:
39
  print(uploaded_file)
40
  bytes_data = uploaded_file.getvalue()
41
  with open(uploaded_file.name, "wb") as file:
42
  file.write(bytes_data)
43
+
44
+ st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
45
+
46
+ # Stage 1: Image to Text
 
47
  st.text('Processing img2text...')
48
+ st.session_state.scenario = img2text(uploaded_file.name)
49
+ st.write(st.session_state.scenario)
50
+
51
+ # Stage 2: Text to Story
52
  st.text('Generating a story...')
53
+ st.session_state.story = text2story(st.session_state.scenario)
54
+ st.write(st.session_state.story)
55
+
56
+ # Stage 3: Story to Audio Data
57
  st.text('Generating audio data...')
58
+ st.session_state.audio_data = text2audio(st.session_state.story)
59
+
60
+ elif st.session_state.scenario:
61
+ st.image(uploaded_file, caption="Uploaded Image", use_column_width=True)
62
+ st.write("Image Caption: ", st.session_state.scenario)
63
+ st.write("Generated Story: ", st.session_state.story)
64
 
65
+ # Play button (No reprocessing)
66
+ if st.session_state.audio_data and st.button("Play Audio"):
67
+ st.audio(st.session_state.audio_data['audio'],
68
+ format="audio/wav",
69
+ start_time=0,
70
+ sample_rate=st.session_state.audio_data['sampling_rate'])
71
 
72
  if __name__ == "__main__":
73
+ main()