|
import streamlit as st |
|
import os |
|
import requests |
|
import io |
|
from PIL import Image |
|
from freeGPT import Client |
|
|
|
# Two top-level tabs; later sections render into them with
# `with generative:` and `with chat:`.
generative, chat = st.tabs(["Image Generation", "Chat"])

# Hugging Face access token, read from the API_TOKEN environment variable.
hf_token = os.environ.get("API_TOKEN")

# Only attach an Authorization header when a token is actually configured;
# otherwise the literal string "Bearer None" would be sent with every request.
headers = {"Authorization": f"Bearer {hf_token}"} if hf_token else {}

# Inference endpoints, one constant per selectable model.
API_URL_FACE = "https://api-inference.huggingface.co/models/prompthero/linkedin-diffusion"
API_URL_PIX = "https://api-inference.huggingface.co/models/nerijs/pixel-art-xl"
API_URL_3D = "https://api-inference.huggingface.co/models/goofyai/3d_render_style_xl"
API_URL_REAL = "https://api-inference.huggingface.co/models/stablediffusionapi/realistic-vision-v51"
API_URL_DALLE = "https://api-inference.huggingface.co/models/openskyml/dalle-3-xl"
API_URL_INKPUNK = "https://api-inference.huggingface.co/models/Envvi/Inkpunk-Diffusion"
API_URL_DREAM = "https://api-inference.huggingface.co/models/Lykon/dreamshaper-xl-v2-turbo"
API_URL_COVER = "https://api-inference.huggingface.co/models/Norod78/sxl-laisha-magazine-cover-lora"
|
|
|
with generative:
    st.title("✨ Open Text2Image Models Leaderboard")
    st.write("Choose one model to generate image and enter prompt")

    # Dropdown label -> inference endpoint. Insertion order is the order in
    # which the labels appear in the select box.
    endpoint_by_label = {
        'Dall-e 3': API_URL_DALLE,
        'Pixel': API_URL_PIX,
        '3D Render': API_URL_3D,
        'Realistic': API_URL_REAL,
        'Inkpunk': API_URL_INKPUNK,
        'Dremscape': API_URL_DREAM,
        'Magazine-cover': API_URL_COVER,
        'Faces': API_URL_FACE,
    }

    model = st.selectbox('Model', tuple(endpoint_by_label))
    prompt = st.text_area('Enter prompt')
    button = st.button('Generate')

    # Endpoint that `query` posts to for the currently selected model.
    API_URL = endpoint_by_label[model]
|
|
|
|
|
def query(payload, timeout=120):
    """POST *payload* as JSON to the selected endpoint and return the raw body.

    The target URL and request headers are read from the module-level
    ``API_URL`` and ``headers`` at call time.

    Args:
        payload: JSON-serialisable request body.
        timeout: Seconds to wait for the server before giving up, so a
            stalled connection cannot block the app indefinitely.

    Returns:
        The response body as bytes.

    Raises:
        requests.HTTPError: The server answered with a 4xx/5xx status.
        requests.RequestException: The request could not be completed
            (connection failure, timeout, ...).
    """
    response = requests.post(
        API_URL, headers=headers, json=payload, timeout=timeout
    )
    # Fail here with the HTTP status rather than returning an error body
    # that the caller would try to decode as an image.
    response.raise_for_status()
    return response.content
|
|
|
def generate_image(input_text):
    """Request an image for *input_text* via ``query`` and decode it with PIL.

    Args:
        input_text: Prompt sent as the ``"inputs"`` field of the request.

    Returns:
        The image object produced by ``Image.open`` on the response bytes.

    Raises:
        OSError: The response bytes could not be opened as an image. The
            message includes the beginning of the response so that a textual
            error body returned in place of an image is visible.
    """
    image_bytes = query({"inputs": input_text})
    try:
        return Image.open(io.BytesIO(image_bytes))
    except OSError as err:
        # Surface what was actually received instead of only
        # "cannot identify image file <BytesIO ...>".
        detail = image_bytes[:300].decode("utf-8", errors="replace")
        raise OSError(f"Response is not a decodable image: {detail}") from err
|
|
|
if button:
    # Re-enter the tab explicitly so the output is rendered inside
    # "Image Generation" wherever this statement sits in the script.
    with generative:
        if not prompt.strip():
            # Nothing to generate from; avoid a pointless API call.
            st.warning("Enter a prompt before generating.")
        else:
            try:
                with st.spinner("Generating image..."):
                    generated_image = generate_image(prompt)
            except (requests.RequestException, OSError) as err:
                # Network/HTTP failures and undecodable responses are shown
                # as a message instead of aborting the script run.
                st.error(f"Image generation failed: {err}")
            else:
                st.image(generated_image, caption='Generated Image')
|
|
|
with chat:
    st.title("AI Chat")

    # Streamlit re-runs the whole script on every interaction, so the
    # conversation is kept in session state and replayed on each run.
    if "chat_history" not in st.session_state:
        st.session_state.chat_history = []
    history = st.session_state.chat_history

    messages = st.container(height=600)
    for role, content in history:
        messages.chat_message(role).write(content)

    prompt_gpt = st.chat_input("Ask something")
    if prompt_gpt:
        messages.chat_message("user").write(prompt_gpt)
        history.append(("user", prompt_gpt))

        try:
            resp = Client.create_completion("gpt3", prompt_gpt)
        except Exception as e:
            # UI boundary: whatever the client raises is shown in the
            # assistant bubble rather than aborting the script run.
            messages.chat_message("assistant").write(e)
        else:
            messages.chat_message("assistant").write(resp)
            history.append(("assistant", resp))