# prompt-extend-2 / app.py
# Author: RamAnanth1 — commit 84cc295 ("Update app.py", verified)
import gradio as gr
from transformers import pipeline
import os
# Text-generation pipeline (a distilgpt2 fine-tune) that continues a short
# idea into a Stable Diffusion 1.x-style prompt.
pipe = pipeline(task='text-generation', model='RamAnanth1/distilgpt2-sd-prompts')

# Remote Hugging Face Spaces, loaded as callables: the first renders images
# from a text prompt, the second turns an image back into a text prompt.
stable_diffusion = gr.load(name="spaces/runwayml/stable-diffusion-v1-5")
clip_interrogator_2 = gr.load(name="spaces/fffiloni/CLIP-Interrogator-2")
# File extensions treated as images when scanning the returned gallery directory.
_IMAGE_EXTENSIONS = frozenset({".png", ".jpg", ".jpeg", ".webp", ".gif", ".bmp"})


def get_images(prompt):
    """Render *prompt* with the Stable Diffusion 1.5 Space and return one image path.

    The loaded Space is called with ``fn_index=2`` and is expected to return
    the path of a directory holding the generated gallery. With the Gradio 3.x
    client, that directory also contains a ``captions.json`` file next to the
    images, so non-image entries must be skipped.
    NOTE(review): confirm the output format against the pinned Gradio version.

    Args:
        prompt: Text prompt forwarded to the Space.

    Returns:
        Path to the first image file, in sorted filename order, inside the
        returned gallery directory.

    Raises:
        FileNotFoundError: If the directory contains no image files.
    """
    gallery_dir = stable_diffusion(prompt, fn_index=2)
    # os.listdir() order is arbitrary and the directory may hold non-image
    # files (e.g. captions.json), so filter by extension and sort so the
    # chosen file is always an image and always the same one.
    image_names = sorted(
        name
        for name in os.listdir(gallery_dir)
        if os.path.splitext(name)[1].lower() in _IMAGE_EXTENSIONS
    )
    if not image_names:
        raise FileNotFoundError(
            f"No image files found in gallery directory {gallery_dir!r}"
        )
    return os.path.join(gallery_dir, image_names[0])
def get_new_prompt(img, mode, max_flavors=12):
    """Interrogate *img* with the CLIP Interrogator 2 Space.

    Args:
        img: Image forwarded unchanged to the Space's ``clipi2`` endpoint
            (``get_images`` supplies a file path here).
        mode: Interrogation mode string, forwarded unchanged (``infer`` uses
            ``'fast'``).
        max_flavors: Third positional input of the ``clipi2`` endpoint.
            Defaults to 12, the value previously hard-coded.
            NOTE(review): presumably the Space's "best max flavors" setting;
            confirm against the Space's API page.

    Returns:
        Whatever the ``clipi2`` endpoint returns; ``infer`` uses its first
        element as the Stable Diffusion 2 prompt.
    """
    return clip_interrogator_2(img, mode, max_flavors, api_name="clipi2")
def infer(input):
    """Extend an idea into a Stable Diffusion 1.x prompt and derive a 2.x prompt.

    The seed text is *input* followed by a comma. It is fed to the
    prompt-generation pipeline, and the generated text becomes the SD 1.x
    prompt. An image is then rendered from that prompt and interrogated in
    ``'fast'`` mode; the first element of the interrogation result is the
    SD 2.x prompt.

    Returns:
        ``(sd1_prompt, sd2_prompt)``, matching the interface's two output
        textboxes.
    """
    generations = pipe(input + ',', num_return_sequences=1)
    sd1_prompt = generations[0]["generated_text"]
    rendered = get_images(sd1_prompt)
    interrogation = get_new_prompt(rendered, 'fast')
    return sd1_prompt, interrogation[0]
# Textboxes for the interface: one input idea, plus one output prompt per
# Stable Diffusion generation.
input_prompt, sd1_output, sd2_output = (
    gr.Text(label="Enter the initial prompt"),
    gr.Text(label="Extended prompt suitable for Stable Diffusion 1.x"),
    gr.Text(label="Extended prompt suitable for Stable Diffusion 2.x"),
)
# HTML shown above the interface (passed as gr.Interface(description=...)).
description = """
<p style="text-align:center;">
Since Stable Diffusion 2 uses the OpenCLIP ViT-H model trained on the LAION dataset, whereas Stable Diffusion 1 uses OpenAI's ViT-L, the prompting style differs and an exact prompt is often hard to write.
<br />This demo extends an initial idea into prompts suitable for both Stable Diffusion v1.x and v2.x. The v1.x prompt is generated first; the corresponding v2.x prompt is then obtained by generating an image with <a href="https://huggingface.co/runwayml/stable-diffusion-v1-5" target="_blank">RunwayML Stable Diffusion 1.5</a> and interrogating that image with <a href="https://huggingface.co/spaces/fffiloni/CLIP-Interrogator-2" target="_blank">CLIP Interrogator 2</a>, which yields an equivalent Stable Diffusion 2 prompt.
</p>
"""
title = "Prompt Extender 2"

# Sample seed ideas shown in the UI; each inner list holds the value for the
# interface's single input.
examples = [
    [idea]
    for idea in (
        "giant dragon flying in the sky",
        "peaceful village landscape",
    )
]
# Wire the pipeline into a single-input, two-output Gradio interface.
demo = gr.Interface(
    fn=infer,
    inputs=input_prompt,
    outputs=[sd1_output, sd2_output],
    description=description,
    title=title,
    examples=examples,
)
# Queueing is enabled by this .queue() call. The former
# launch(enable_queue=True) argument was a deprecated duplicate of it and was
# removed from launch() in Gradio 4.
# NOTE(review): `concurrency_count` is Gradio 3.x-only; Gradio 4 replaced it
# with `default_concurrency_limit`. Update it together with the pinned Gradio version.
demo.queue(max_size=10, concurrency_count=20)
demo.launch()