# Guess-the-prompt / main.py
# Page header captured with this file: author jeremyLE-Ekimetrics,
# commit 452c139 ("fix"), size 5.85 kB.
import os
import random
import unicodedata

import numpy as np
import streamlit as st
import torch
from diffusers import AutoPipelineForText2Image
from openai import OpenAI
from PIL import Image
# OpenAI client used by get_prompt_to_guess for the chat-completion request.
# NOTE(review): no API key is passed here; the SDK is expected to pick one up
# from the environment (OPENAI_API_KEY) — confirm the deployment provides it.
client = OpenAI()
# Quote pairs a chat model may wrap its whole answer in: straight double and
# single quotes, French guillemets, and typographic double quotes.
_PROMPT_QUOTE_PAIRS = (('"', '"'), ("'", "'"), ("«", "»"), ("\u201c", "\u201d"))


def _clean_generated_prompt(text):
    """Return the model reply *text* as a bare prompt string.

    Surrounding whitespace is removed, then at most one enclosing pair of
    quotes (plus any whitespace just inside it). ``None`` becomes ``""``.
    Players must type the prompt word for word, so quotes left around it would
    make every honest guess wrong.
    """
    cleaned = (text or "").strip()
    for opening, closing in _PROMPT_QUOTE_PAIRS:
        if (
            len(cleaned) >= len(opening) + len(closing)
            and cleaned.startswith(opening)
            and cleaned.endswith(closing)
        ):
            cleaned = cleaned[len(opening):-len(closing)].strip()
            break
    return cleaned


@st.cache_data(ttl=3600)
def get_prompt_to_guess(index):
    """Ask the chat model for a short French image prompt for the player to guess.

    Args:
        index: not read by the body. It exists so that ``st.cache_data`` keys
            on it: the same index returns the same cached prompt for up to an
            hour (``ttl=3600``), and a new index produces a fresh prompt.

    Returns:
        The generated prompt, with surrounding whitespace and one enclosing
        pair of quotes removed.

    Raises:
        RuntimeError: if the model returns no usable text. Raising instead of
            returning an empty string keeps that result out of the cache, so
            the next rerun asks the model again.
    """
    # One subject is drawn at random to steer the generated prompt.
    subjects = (
        "arbre",
        "écologie",
        "chat",
        "chien",
        "consultant",
        "artificial intelligence",
        "beauté",
        "immeuble",
        "plage",
        "cyborg",
        "futuriste",
    )
    response = client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[
            {"role": "system", "content": "You are a helpful assistant to generate one simple prompt in order to generate an image. Your given prompt won't go over 10 words. You only return the prompt. You will also answer in french."},
            {"role": "user", "content": f"Donne moi un prompt pour generer une image de {random.choice(subjects)}"},
        ],
    )
    prompt = _clean_generated_prompt(response.choices[0].message.content)
    if not prompt:
        raise RuntimeError("The chat model returned an empty prompt; please retry.")
    return prompt
@st.cache_resource
def get_model():
    """Load the SDXL-Turbo text-to-image pipeline.

    ``st.cache_resource`` keeps the returned pipeline across reruns, so the
    weights are not reloaded on each interaction. The ``fp16`` weight variant
    is requested but loaded as ``float32``, presumably because the app runs
    on CPU (see the disclosure in the rules text) — confirm if moving to GPU.
    """
    return AutoPipelineForText2Image.from_pretrained(
        "stabilityai/sdxl-turbo",
        variant="fp16",
        torch_dtype=torch.float32,
    )
@st.cache_data
def generate_image(_pipe, prompt):
    """Render *prompt* to an image with a single SDXL-Turbo denoising step.

    Args:
        _pipe: the diffusers text-to-image pipeline returned by ``get_model``.
            The leading underscore tells ``st.cache_data`` not to hash it, so
            results are cached by *prompt* alone.
        prompt: the text to render.

    Returns:
        The first image of the pipeline output.
    """
    # The pipeline has no ``seed`` parameter. diffusers' SDXL ``__call__``
    # absorbs unknown keywords into ``**kwargs``, so ``seed=1`` was silently
    # ignored. Seeding through a ``torch.Generator`` makes identical prompts
    # yield identical images even when they miss the cache.
    generator = torch.Generator(device="cpu").manual_seed(1)
    return _pipe(
        prompt=prompt,
        num_inference_steps=1,
        # SDXL-Turbo is meant to be run without classifier-free guidance.
        guidance_scale=0.0,
        generator=generator,
    ).images[0]
def _normalize_prompt(text):
    """Return *text* in NFC form with every whitespace run collapsed to one space."""
    # NFC makes "e" + U+0301 COMBINING ACUTE ACCENT equal to the precomposed
    # "é"; split()/join trims the ends and squeezes inner runs (including
    # non-breaking spaces, which str.split treats as whitespace).
    return " ".join(unicodedata.normalize("NFC", text).split())


def check_prompt(prompt, prompt_to_guess):
    """Return True when the player's guess matches the hidden prompt.

    The match stays exact, so case, accents and punctuation all count, as the
    rules require. Two differences a player cannot see are ignored:

    * whitespace: leading and trailing spaces, and repeated spaces between
      words;
    * Unicode composition: an accented letter typed as a base letter plus a
      combining accent equals its precomposed form.

    Args:
        prompt: the player's guess.
        prompt_to_guess: the hidden prompt.
    """
    return _normalize_prompt(prompt) == _normalize_prompt(prompt_to_guess)
# Shared SDXL-Turbo pipeline (cached by get_model across reruns).
pipe = get_model()

# Inject the app's custom stylesheet. Read it explicitly as UTF-8: without an
# encoding, open() uses the platform's locale codec, which can fail or garble
# non-ASCII CSS on some systems.
with open("style.css", encoding="utf-8") as f:
    st.markdown(f"<style>{f.read()}</style>", unsafe_allow_html=True)
from text import compare_text, format_text_html
# Per-visitor counter of "next guess" clicks. It is the argument passed to
# get_prompt_to_guess, whose cached result changes when the counter does.
st.session_state.setdefault("guess_number", 0)

# Rules shown under the title (Streamlit markdown, including :color[] tags).
_RULES_MARKDOWN = """
Game developed by Jeremy LE from Ekimetrics to test and play with the new SDXL Turbo model from stability.ai\n
Rules : \n
- guess the prompt (in French, with no fault) to generate the left image with the sdxl turbo model\n
- use testing prompt side to help you guess the prompt by testing some\n
- If a word is **correct** and **at the right place in sentence**, the word is in :green[green]\n
- If a word is **correct** and **not** at the right place in sentence, the word is in :gray[gray]\n
- If a word is **incorrect**, the word is in :red[red]\n
**Disclosure** : this runs on CPU so generation are quite slow (even with sdxl turbo). Generation time took approx 40s.
"""

st.title("Guess the prompt by Ekimetrics")
st.markdown(_RULES_MARKDOWN)

next_guess = st.button("click here for next guess", use_container_width=True)
if next_guess:
    st.session_state["guess_number"] = st.session_state["guess_number"] + 1
prompt = get_prompt_to_guess(st.session_state["guess_number"])
# Two equal-width columns: the guessing game on the left, free-form prompt
# testing on the right. Streamlit draws widgets in execution order, so the
# order of the calls below defines the layout.
col_1, col_2 = st.columns([1,1])
with col_1:
    st.header("GUESS THE PROMPT")
    guessed_prompt = st.text_input("Input your guess prompt")
    submit_guess = st.button("guess the prompt", use_container_width=True, type="primary")
    if submit_guess:
        if check_prompt(guessed_prompt, prompt):
            # "1h" refers to the one-hour ttl on get_prompt_to_guess's cache.
            st.markdown("Good prompt ! test again in 1h or click on next guess!")
        else:
            st.markdown("wrong prompt !")
            # Word-by-word feedback, colour-coded as described in the rules.
            # NOTE(review): compare_text/format_text_html live in the local
            # `text` module (not shown). If format_text_html returns raw HTML,
            # st.markdown needs unsafe_allow_html=True to render it — confirm.
            compare_dict = compare_text(guessed_prompt, prompt)
            st.markdown(format_text_html(compare_dict))
    get_answer = st.button("get the answer", use_container_width=True)
    if get_answer:
        st.markdown(f"Cheater ! but here is the prompt : \n**{prompt}**")
with col_2:
    st.header("TEST THE PROMPT")
    testing_prompt = st.text_input("Input your testing prompt")
    # Only records the click; the test image is generated and shown later in
    # the script, below the image to guess.
    test_prompt = st.button("test prompt",use_container_width=True, type="primary")
# Render the image to guess in the left column and, on request, the image for
# the player's test prompt (or a black placeholder) in the right column.
with col_1:
    im_to_guess = generate_image(pipe, prompt)
    # PIL's Image.size is (width, height), so unpack it in that order.
    w, h = im_to_guess.size
    st.image(im_to_guess)
with col_2:
    if test_prompt:
        im = generate_image(pipe, testing_prompt)
        # NOTE(review): "testing" is not read anywhere in this file — confirm
        # whether this write is still needed.
        st.session_state["testing"] = False
    else:
        # Black placeholder matching the image to guess. NumPy images are laid
        # out as (rows, columns, channels) = (height, width, RGB).
        im = np.zeros([h, w, 3])
    st.image(im)