Spaces:
Paused
Paused
import streamlit as st | |
from PIL import Image | |
from inference import inference | |
import torch | |
import io | |
def main():
    """Render the Streamlit page and generate an image for the chosen genres.

    The page shows a title, a genre multiselect in the sidebar and a
    'Generate Image' button.  When the button is pressed, the selected
    genres are encoded into a 14-element integer tensor, that tensor is
    passed to ``inference``, and the returned image is shown as a PNG.
    """
    # Genre name -> integer code; codes run 1..13 in exactly this order.
    genre_names = (
        'Action',
        'Adventure',
        'Animation',
        'Comedy',
        'Drama',
        'Family',
        'Horror',
        'Music',
        'Romance',
        'Science Fiction',
        'Western',
        'Fantasy',
        'Thriller',
    )
    genre_codes = {name: code for code, name in enumerate(genre_names, start=1)}

    st.title("Image Display App")

    # Sidebar widget: any number of genres may be picked (including none).
    chosen = st.sidebar.multiselect('Select Genres', list(genre_codes))

    # Nothing more to do on this rerun unless the button was just pressed.
    if not st.button('Generate Image'):
        return

    # Conditioning vector: 14 int64 zeros; each chosen genre writes its own
    # code into slot (code - 1).  The last slot (index 13) is never written.
    # NOTE(review): a sample vector left in the original comments
    # (0,0,0,0,0,0,0,1,2,7,4,0,0,0) does not follow this placement rule --
    # confirm which encoding inference() actually expects.
    cond = torch.zeros(14, dtype=torch.int64)
    for name in chosen:
        slot_value = genre_codes[name]
        cond[slot_value - 1] = slot_value

    # Keep a spinner on screen while the model call is running.
    with st.spinner('Generating Image...'):
        generated = inference(cond)

    # Serialise the result (presumably a Pillow image -- it must expose
    # .save(fp, format=...)) to an in-memory PNG and rewind the stream so
    # Streamlit reads it from the first byte.
    png_stream = io.BytesIO()
    generated.save(png_stream, format="PNG")
    png_stream.seek(0)

    st.image(png_stream, caption='Generated Image', use_column_width=True)
if __name__ == "__main__":
    # Build the page only when this file is executed as a script,
    # not when it is imported as a module.
    main()