Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -14,20 +14,13 @@ except ImportError:
|
|
14 |
st.success("PyTorch has been successfully installed!")
|
15 |
import torch
|
16 |
|
17 |
-
# Install Einops library
|
18 |
-
try:
|
19 |
-
import einops
|
20 |
-
except ImportError:
|
21 |
-
st.warning("Einops library is not installed. Installing einops...")
|
22 |
-
import subprocess
|
23 |
-
subprocess.run(["pip", "install", "einops"])
|
24 |
-
st.success("Einops library has been successfully installed!")
|
25 |
-
import einops
|
26 |
-
|
27 |
# Load the image captioning model
|
28 |
caption_model = pipeline("image-to-text", model="unography/blip-large-long-cap")
|
29 |
|
30 |
-
story_generator = pipeline("text-generation", model="
|
|
|
|
|
|
|
31 |
|
32 |
def generate_caption(image):
|
33 |
# Generate the caption for the uploaded image
|
@@ -36,7 +29,7 @@ def generate_caption(image):
|
|
36 |
|
37 |
def generate_story(caption):
|
38 |
# Generate the story based on the caption using the GPT-2 model
|
39 |
-
prompt = f"Write a short, simple children's
|
40 |
story = story_generator(prompt, max_length=500, num_return_sequences=1)[0]["generated_text"]
|
41 |
|
42 |
# Extract the story text from the generated output
|
|
|
14 |
st.success("PyTorch has been successfully installed!")
|
15 |
import torch
|
16 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
# Load the image captioning model
|
18 |
caption_model = pipeline("image-to-text", model="unography/blip-large-long-cap")
|
19 |
|
20 |
+
#story_generator = pipeline("text-generation", model="distilbert/distilgpt2")
|
21 |
+
|
22 |
+
story_generator = pipeline("text-generation", model="isarth/distill_gpt2_story_generator")
|
23 |
+
|
24 |
|
25 |
def generate_caption(image):
|
26 |
# Generate the caption for the uploaded image
|
|
|
29 |
|
30 |
def generate_story(caption):
|
31 |
# Generate the story based on the caption using the GPT-2 model
|
32 |
+
prompt = f"Write a short, simple children's story inspired by the image of {caption}. Here's the story:\n\n"
|
33 |
story = story_generator(prompt, max_length=500, num_return_sequences=1)[0]["generated_text"]
|
34 |
|
35 |
# Extract the story text from the generated output
|