File size: 2,949 Bytes
478e2ea
dc29be0
478e2ea
dc29be0
224b704
dc29be0
4f0c03a
 
 
 
 
 
 
 
 
 
dc29be0
3ac0172
dc29be0
9fda940
deaf8e9
b92cea6
478e2ea
 
 
 
 
200efcf
0de4b61
5235e45
 
200efcf
 
620b014
200efcf
 
ef3e97a
200efcf
 
b54ef0a
e220f85
 
 
 
200efcf
478e2ea
e220f85
478e2ea
 
 
dc29be0
 
 
 
478e2ea
 
 
 
224b704
 
478e2ea
 
224b704
 
 
478e2ea
0f63834
478e2ea
 
224b704
478e2ea
 
 
16e1809
478e2ea
 
 
 
 
dc29be0
478e2ea
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import io
import re

import streamlit as st
from gtts import gTTS
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# Ensure PyTorch is importable, installing it on the fly if it is missing.
try:
    import torch
except ImportError:
    import importlib
    import subprocess
    import sys

    st.warning("PyTorch is not installed. Installing PyTorch...")
    # Use the running interpreter's pip: a bare "pip" on PATH may install into
    # a different Python. check=True makes a failed install raise
    # CalledProcessError instead of falsely reporting success below.
    subprocess.run([sys.executable, "-m", "pip", "install", "torch"], check=True)
    st.success("PyTorch has been successfully installed!")
    # The import system caches directory listings; without this, a package
    # installed after interpreter start-up may not be found by the import below.
    importlib.invalidate_caches()
    import torch

# Load the image captioning model
caption_model = pipeline("image-to-text", model="unography/blip-large-long-cap")

story_generator = pipeline("text-generation", model="distilbert/distilgpt2")


def generate_caption(image):
    """Describe *image* in text using the module-level ``caption_model`` pipeline.

    Returns the ``generated_text`` of the first result produced by the pipeline.
    """
    results = caption_model(image)
    best = results[0]
    return best["generated_text"]

# Words scrubbed from generated stories. They are matched case-insensitively and
# only as whole words (\b anchors), so innocent words that merely contain a
# listed term ("skill", "begun", "soldier", "diet") are left intact.
INAPPROPRIATE_WORDS = (
    "violence", "horror", "scary", "adult", "death", "gun", "shoot",
    "criminal", "rumors", "die", "died", "kill", "accident", "drug",
    "shot", "ghost", "sex",
)
INAPPROPRIATE_WORD_PATTERN = re.compile(
    r"\b(?:" + "|".join(map(re.escape, INAPPROPRIATE_WORDS)) + r")\b",
    re.IGNORECASE,
)

# Stories with more words than this are cut off and suffixed with "...".
MAX_STORY_WORDS = 100


def generate_story(caption):
    """Write a short children's story based on an image caption.

    Args:
        caption: Text description of the image (e.g. from generate_caption).

    Returns:
        The generated continuation with the prompt removed, listed
        inappropriate words deleted, and truncated to at most
        MAX_STORY_WORDS words (with "..." appended when truncated).
    """
    # Generate the story based on the caption using the GPT-2 model
    prompt = f"Write a short, simple children's story approximately 100 words based on the following image description:\n\n{caption}\n\nStory:"
    generated = story_generator(prompt, max_length=500, num_return_sequences=1)[0]["generated_text"]

    # The text-generation pipeline echoes the prompt before the continuation.
    # Strip it by exact prefix: splitting on "Story:" and taking piece [1] cut
    # the story short whenever the caption or the model output also
    # contained "Story:".
    if generated.startswith(prompt):
        story = generated[len(prompt):]
    else:
        # Fallback: everything after the first marker, or the whole text.
        _, marker, rest = generated.partition("Story:")
        story = rest if marker else generated
    story = story.strip()

    # Remove listed words as whole words only. The previous str.replace loop
    # deleted substrings (e.g. "skill" -> "s", "died" -> "d") and missed
    # capitalised forms such as "Death".
    story = INAPPROPRIATE_WORD_PATTERN.sub("", story)
    # Collapse the runs of spaces left behind where words were removed.
    story = re.sub(r" {2,}", " ", story).strip()

    # Limit the story to approximately MAX_STORY_WORDS words
    words = story.split()
    if len(words) > MAX_STORY_WORDS:
        story = " ".join(words[:MAX_STORY_WORDS]) + "..."

    return story
    
def convert_to_audio(story):
    """Synthesize *story* as English speech with gTTS.

    Returns an in-memory ``io.BytesIO`` buffer containing the audio, rewound
    to position 0 so callers can read it from the start.
    """
    buffer = io.BytesIO()
    gTTS(text=story, lang="en").write_to_fp(buffer)
    buffer.seek(0)  # rewind so the consumer reads from the beginning
    return buffer

def main():
    """Render the storytelling page.

    The user uploads a JPG image. The page shows the image, a generated
    caption, a short story based on that caption, and an audio player that
    reads the story aloud. Unreadable uploads and empty stories are reported
    in the UI instead of raising.
    """
    st.title("Storytelling Application")

    # File uploader for the image (restricted to JPG)
    uploaded_image = st.file_uploader("Upload an image", type=["jpg"])
    if uploaded_image is None:
        return  # nothing uploaded yet

    # Image.open is lazy, so force decoding here. Otherwise a corrupt or
    # truncated upload would surface as a traceback inside st.image or the
    # captioning pipeline.
    try:
        image = Image.open(uploaded_image)
        image.load()
    except OSError as err:  # PIL's UnidentifiedImageError subclasses OSError
        st.error(f"Could not read the uploaded image: {err}")
        return

    # Display the uploaded image
    st.image(image, caption="Uploaded Image", use_container_width=True)

    # Generate the caption for the image
    with st.spinner("Generating caption..."):
        caption = generate_caption(image)
    st.subheader("Generated Caption:")
    st.write(caption)

    # Generate the story based on the caption using the GPT-2 model
    with st.spinner("Writing story..."):
        story = generate_story(caption)
    st.subheader("Generated Story:")
    st.write(story)

    # gTTS refuses empty input; word filtering could in principle leave nothing.
    if not story.strip():
        st.warning("No story text was generated, so there is no audio to play.")
        return

    # Convert the story to audio
    with st.spinner("Converting story to audio..."):
        audio_bytes = convert_to_audio(story)

    # gTTS produces MP3 data; "audio/mpeg" is the registered MIME type for MP3.
    st.audio(audio_bytes, format="audio/mpeg")


if __name__ == "__main__":
    main()