AidenYan committed on
Commit
635f88c
·
verified ·
1 Parent(s): 620aad6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -52
app.py CHANGED
@@ -1,59 +1,56 @@
1
- from transformers import pipeline
2
  import streamlit as st
 
 
3
  from PIL import Image
4
  import requests
5
- from io import BytesIO
6
- from diffusers import DiffusionPipeline
7
 
 
 
8
 
 
9
  text_to_image = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
10
- # Initialize the pipeline
11
- image_to_text = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
12
 
13
- st.title('Image Captioning Application')
14
-
15
- # Function to load images from URL
16
def load_image_from_url(url):
    """Download an image from *url* and return it as a PIL image.

    Any failure (network error, non-2xx status, undecodable data) is reported
    in the Streamlit UI via ``st.error`` and ``None`` is returned instead of
    raising, so callers only need to check for ``None``.
    """
    try:
        # A timeout keeps an unresponsive host from hanging the app forever.
        response = requests.get(url, timeout=10)
        # Fail on 4xx/5xx instead of handing an HTML error page to PIL.
        response.raise_for_status()
        img = Image.open(BytesIO(response.content))
        # Image.open is lazy; decode now so corrupt data is caught here.
        img.load()
        return img
    except Exception as e:  # UI boundary: show the problem, never crash the page
        st.error(f"Error loading image from URL: {e}")
        return None
24
-
25
# User option to select input type: Upload or URL
input_type = st.radio("Select input type:", ("Upload Image", "Image URL"))

# Bind these up front: the button handler below runs even when the user has
# not provided any input yet, and must not hit an unbound name.
image = None

if input_type == "Upload Image":
    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        st.image(image, caption='Uploaded Image', use_column_width=True)
elif input_type == "Image URL":
    image_url = st.text_input("Enter the image URL here:", "")
    if image_url:
        image = load_image_from_url(image_url)
        if image:
            st.image(image, caption='Image from URL', use_column_width=True)

# Generate caption button
if st.button('Generate Caption'):
    if image is None:
        st.warning("Please upload an image or enter an image URL.")
    else:
        with st.spinner("Generating caption..."):
            # Pass the already-loaded PIL image straight to the pipeline.
            # This avoids writing a fixed-name temp file (shared by every
            # session using the app) and avoids downloading the URL twice.
            result = image_to_text(image)

        if result:
            generated_text = result[0]['generated_text']
            st.success(f'Generated Caption: {generated_text}')
        else:
            st.error("Failed to generate caption.")
 
 
1
  import streamlit as st
2
+ from transformers import pipeline
3
+ from diffusers import DiffusionPipeline
4
  from PIL import Image
5
  import requests
6
+ import io
 
7
 
8
# Streamlit re-executes this script on every user interaction, so the models
# are loaded through cached factories: each one is instantiated once per
# server process and reused by later reruns.
@st.cache_resource
def _load_image_to_text():
    """Return the image-captioning pipeline (created once, then cached)."""
    return pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")


@st.cache_resource
def _load_text_to_image():
    """Return the text-to-image diffusion pipeline (created once, then cached)."""
    return DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")


# Load the image-to-text pipeline
image_to_text = _load_image_to_text()

# Load the text-to-image model
text_to_image = _load_text_to_image()
 
 
13
 
14
def _caption_and_generate(captioner, source):
    """Caption *source* with *captioner*, then generate an image from the caption.

    Download and decode failures are reported in the UI with ``st.error``
    rather than surfacing as a raw traceback.
    """
    try:
        with st.spinner("Describing the image..."):
            image_text = captioner(source)
    except (requests.RequestException, OSError) as exc:
        # OSError also covers PIL's "cannot identify image file" error.
        st.error(f"Could not read the image: {exc}")
        return
    generate_image(image_text)


def main():
    """Render the Streamlit page and dispatch on the selected input mode.

    "Text" feeds the entered text directly to ``generate_image``. "URL"
    accepts either an uploaded image file or, when nothing is uploaded, an
    image URL; the image is captioned first and the caption is then passed
    to ``generate_image``.
    """
    st.title("Image to Story to Image Converter")

    # User input for text or URL
    input_option = st.radio("Select input option:", ("Text", "URL"))

    # Input text
    if input_option == "Text":
        text_input = st.text_input("Enter the text:")
        # The button is evaluated first so it is always rendered; blank or
        # whitespace-only text is ignored.
        if st.button("Generate Story and Image") and text_input.strip():
            generate_image(text_input.strip())

    # Input URL (or an uploaded file, which takes precedence when present)
    elif input_option == "URL":
        uploaded_file = st.file_uploader("Upload an image file:", type=["jpg", "jpeg", "png"])
        if uploaded_file is not None:
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded Image", use_column_width=True)
            if st.button("Generate Story and Image"):
                _caption_and_generate(image_to_text_from_file, uploaded_file)
        else:
            image_url = st.text_input("Enter the image URL:")
            if st.button("Generate Story and Image") and image_url.strip():
                _caption_and_generate(image_to_text_from_url, image_url.strip())
40
+
41
def image_to_text_from_file(uploaded_file):
    """Return the generated caption for an uploaded image file.

    *uploaded_file* is a binary file-like object (as returned by
    ``st.file_uploader``). The caption is the ``generated_text`` of the first
    result produced by the ``image_to_text`` pipeline.
    """
    # The caller has already read from this stream to preview the image, so
    # rewind it; otherwise read() would return only the unread remainder.
    uploaded_file.seek(0)
    # The captioning pipeline accepts a path/URL string or a PIL image, not a
    # raw BytesIO, so decode the bytes into a PIL image first.
    image = Image.open(io.BytesIO(uploaded_file.read()))
    return image_to_text(image)[0]['generated_text']
44
+
45
def image_to_text_from_url(image_url):
    """Download the image at *image_url* and return its generated caption.

    Raises:
        requests.RequestException: on connection failure, timeout, or a
            non-2xx HTTP status.
        OSError: if the downloaded bytes are not a decodable image.
    """
    # Bound the wait so an unresponsive host cannot hang the app.
    response = requests.get(image_url, timeout=30)
    # Reject 4xx/5xx responses instead of trying to decode an error page.
    response.raise_for_status()
    # The captioning pipeline accepts a path/URL string or a PIL image, not a
    # raw BytesIO, so decode the bytes into a PIL image first.
    image = Image.open(io.BytesIO(response.content))
    return image_to_text(image)[0]['generated_text']
49
+
50
def generate_image(text):
    """Generate an image from *text* and display it on the page.

    The text is embedded in a fixed prompt template, passed to the
    ``text_to_image`` diffusion pipeline, and the first resulting image is
    rendered with ``st.image``.
    """
    # NOTE(review): the literal "[MASK]" looks like a leftover from a
    # fill-mask step that is not present in this file; it is sent to the
    # diffusion model verbatim. Confirm whether that is intended.
    rephrased_text = "I want to buy " + text + " and [MASK] for my children"
    # The pipeline returns an output object; the generated pictures are in
    # its ``images`` list, and st.image needs an actual image.
    generated_image = text_to_image(rephrased_text).images[0]
    st.image(generated_image, caption="Generated Image", use_column_width=True)
54
+
55
+ if __name__ == "__main__":
56
+ main()