# Scraped from a Hugging Face Space page (Space status at capture time: "Runtime error").
import base64
import logging
import os

import gradio as gr
import requests
import torch
from diffusers import StableDiffusionPipeline
# Hugging Face access token read from the environment. It is passed both to
# the model download below and to the flagging dataset saver.
auth_token = os.environ.get("auth_token")

# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fixed seed so the sequence of sampled latents is reproducible across restarts.
seed = 496012807434005
generator = torch.Generator(device=device).manual_seed(seed)

# Flagging callback: manually flagged samples are saved to the
# "dst-movie-poster-demo" Hugging Face dataset.
hf_writer = gr.HuggingFaceDatasetSaver(auth_token, "dst-movie-poster-demo")
def improve_image(img, scale=3, timeout=120):
    """Enhance *img* with the public abidlabs/GFPGAN Space, best effort.

    The image is base64-encoded, POSTed to the Space's ``/api/predict``
    endpoint together with *scale*, and the first output in the JSON
    response's ``data`` list is decoded back into an image.

    Args:
        img: PIL image to enhance.
        scale: Second input forwarded to the GFPGAN Space (presumably the
            upscaling factor). Defaults to 3, the previously hard-coded value.
        timeout: Seconds to wait for the remote Space before giving up.

    Returns:
        The enhanced image. If the remote call fails or returns an
        unexpected payload, *img* is returned unchanged. Enhancement is an
        optional post-processing step, so an outage of the external Space
        should not break the whole request.
    """
    img_in_base64 = gr.processing_utils.encode_pil_to_base64(img)
    try:
        resp = requests.post(
            "https://hf.space/embed/abidlabs/GFPGAN/+/api/predict",
            json={"data": [img_in_base64, scale]},
            # Without a timeout a hung remote Space blocks this worker forever.
            timeout=timeout,
        )
        resp.raise_for_status()
        enhanced_b64 = resp.json()["data"][0]
        return gr.processing_utils.decode_base64_to_image(enhanced_b64)
    except (
        requests.RequestException,  # network errors, timeouts, HTTP error codes
        ValueError,  # invalid JSON / invalid base64
        KeyError,
        IndexError,
        TypeError,  # unexpected response shape
        OSError,  # undecodable image bytes
    ) as err:
        logging.getLogger(__name__).warning(
            "GFPGAN enhancement failed, returning original image: %s", err
        )
        return img
# Load Stable Diffusion v1-4 once at start-up (authenticated with
# ``auth_token``) and move it to the selected device, so every request
# reuses the same pipeline instance.
pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    use_auth_token=auth_token,
).to(device)
def generate(celebrity, setting):
    """Generate a movie poster of *celebrity* in the movie *setting*.

    The function:
    - builds a text prompt;
    - samples a starting latent from the module-level seeded ``generator``;
    - runs the Stable Diffusion ``pipe`` for 70 inference steps at guidance
      scale 14;
    - post-processes the first resulting image with ``improve_image``.

    Args:
        celebrity: Name entered in the "Celebrity" textbox.
        setting: Movie title picked from the "Movie" dropdown. The dropdown
            is declared without a default value, so this may arrive empty
            (e.g. None) when nothing is selected. In that case the movie
            clause is left out of the prompt instead of being interpolated
            literally.

    Returns:
        The image returned by ``improve_image`` for the generated poster.
    """
    prompt = f"A movie poster of {celebrity}"
    if setting:
        # Keep a space before the title ("in the movie Titanic", not "movieTitanic").
        prompt += f" in the movie {setting}"
    # Latent of shape (1, 4, 64, 64); presumably decodes to a 512x512 image
    # for SD v1-4 -- TODO confirm if the model is changed. Drawing from the
    # shared seeded generator makes the sequence of outputs reproducible
    # across restarts while still varying between calls.
    latent_sample = torch.randn((1, 4, 64, 64), generator=generator, device=device)
    gen_img = pipe(
        prompt,
        latents=latent_sample,
        num_inference_steps=70,
        guidance_scale=14,
    ).images[0]
    return improve_image(gen_img)
# Gradio UI wiring: a celebrity name and a movie title go in, a generated
# poster comes out. Manually flagged outputs are handed to ``hf_writer``.
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(label='Celebrity', value='Tom Cruise'),
        gr.Dropdown(['The Godfather', 'Titanic', 'Fast and Furious'], label='Movie'),
    ],
    # Gradio's Image component only accepts 'numpy', 'pil' or 'filepath' for
    # `type`; the former value 'pill' is rejected when the component is built.
    outputs=gr.Image(type='pil'),
    title="Movie Poster Generation Using Stable Diffusion",
    description="This is a movie poster generation app created as part of End to End Vision application course on CoRise by Abubakar Abid!<br/> Set the celebrity name and choose a movie name from the dropdown to generate the image.",
    allow_flagging="manual",
    flagging_options=['Incorrect movie poster', 'Incorrect Actor', 'Other Problem'],
    flagging_callback=hf_writer,
    flagging_dir='flagged_data',
)
demo.launch()