import os

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from diffusers import StableDiffusionPipeline

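# Hugging Face access token, read from the environment; used to authenticate the model downloads.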
READ_TOKEN = os.environ.get('HF_ACCESS_TOKEN', None)

model_id = "runwayml/stable-diffusion-v1-5"

has_cuda = torch.cuda.is_available()

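# Load Stable Diffusion; on GPU, the fp16 weights halve memory usage and speed up inference.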
if has_cuda:
    pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, revision="fp16", use_auth_token=READ_TOKEN)
    device = "cuda"
else:
    pipe = StableDiffusionPipeline.from_pretrained(model_id, use_auth_token=READ_TOKEN)
    device = "cpu"

pipe.to(device)

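# Disable the NSFW safety checker by replacing it with a pass-through.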
def safety_checker(images, clip_input):
    return images, False

pipe.safety_checker = safety_checker

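# distilgpt2 fine-tuned on the Lord of the Rings trilogy; it continues the user's prompt into a story.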
SAVED_CHECKPOINT = 'mikegarts/distilgpt2-lotr'
model = AutoModelForCausalLM.from_pretrained(SAVED_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(SAVED_CHECKPOINT)

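# Default summarization pipeline (distilbart); the summary doubles as the image prompt.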
summarizer = pipeline("summarization")


def break_until_dot(txt):
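    # Truncate at the last full stop so the text ends on a complete sentence.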
    return txt.rsplit('.', 1)[0] + '.'


def generate(prompt):
    input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)

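    # Sample a continuation of the prompt; do_sample with a moderate temperature keeps the stories varied.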
    outputs = model.generate(
        input_ids=input_ids,
        max_length=120,
        min_length=50,
        temperature=0.7,
        num_return_sequences=3,
        do_sample=True
    )
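    # Only the first of the three returned sequences is used.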
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return break_until_dot(decoded)


def generate_story(prompt):
    story = generate(prompt=prompt)
    summary = summarizer(story, min_length=5, max_length=15)[0]['summary_text']
    summary = break_until_dot(summary)
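    # The third output reveals the hidden "Generate an image" button.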
    return story, summary, gr.update(visible=True)


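# Polled periodically while the app is running; pushes the latest intermediate images and a status line to the UI.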
def on_change_event(app_state):
    print(f'on_change_event {app_state}')
    if app_state and app_state['running'] and app_state['img']:
        print(f'Updating the image: {app_state}')
        # Animate the status message with a growing trail of stars.
        app_state['dots'] = (app_state['dots'] + 1) % 10
        message = app_state['status_msg'] + ' *' * app_state['dots']
        print(f'message={message}')
        return gr.update(value=app_state['img_list'], label='intermediate steps'), gr.update(value=message)
    else:
        return gr.update(label='images list'), gr.update(value='')


with gr.Blocks() as demo:

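    # Runs the diffusion pipeline; the callback below publishes intermediate results through app_state.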
    def generate_image(prompt, inference_steps, app_state):
        app_state['running'] = True
        app_state['img_list'] = []
        app_state['status_msg'] = 'Starting'

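        # Invoked by the pipeline after every denoising step (callback_steps=1 below).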
        def callback(step, ts, latents):
            app_state['status_msg'] = f'Reconstructing an image from the latent state on step {step}'
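            # Undo the VAE latent scaling factor (0.18215) and decode the latents into pixel space.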
            latents = 1 / 0.18215 * latents
            res = pipe.vae.decode(latents).sample
            # Map from [-1, 1] to [0, 1], move channels last, and convert to a PIL image.
            res = (res / 2 + 0.5).clamp(0, 1)
            res = res.detach().cpu().permute(0, 2, 3, 1).numpy()
            res = pipe.numpy_to_pil(res)[0]
            app_state['img'] = res
            app_state['step'] = step
            app_state['img_list'].append(res)
            app_state['status_msg'] = f'Generating step ({step + 1})'

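        # Append style keywords to steer the model towards the app's illustration look.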
        prompt = prompt + ' masterpiece charcoal pencil art lord of the rings illustration'
        img = pipe(prompt, height=512, width=512, num_inference_steps=inference_steps, callback=callback, callback_steps=1)
        # Generation finished: reset the live-preview state.
        app_state['running'] = False
        app_state['img'] = None
        app_state['status_msg'] = ''
        app_state['dots'] = 0
        return gr.update(value=img.images[0], label='Generated image')

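    # Shared mutable state that carries intermediate results from the callback to the polling handler.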
    app_state = gr.State({'img': None,
                          'step': 0,
                          'running': False,
                          'status_msg': '',
                          'img_list': [],
                          'dots': 0
                          })

    title = gr.Markdown('## Lord of the Rings app')
    description = gr.Markdown('#### A Lord of the Rings inspired app that combines text and image generation.'
                              ' The language modeling is done by fine-tuning distilgpt2 on the LOTR trilogy.'
                              f' The text2img model is {model_id}. The summarization is done using distilbart.')
    prompt = gr.Textbox(label="Your prompt", value="Frodo took the sword and")
    story = gr.Textbox(label="Your story")
    summary = gr.Textbox(label="Summary")

    bt_make_text = gr.Button("Generate text")
    # Hidden until a story has been generated.
    bt_make_image = gr.Button("Generate an image (takes about 10-15 minutes on CPU).", visible=False)

    img_description = gr.Markdown('Image generation takes some time,'
                                  ' but here you can see what is generated from the latent state of the diffuser every few steps.'
                                  ' Usually there is a significant improvement around step 12 that yields a much better image.')
    status_msg = gr.Markdown()

    gallery = gr.Gallery()
    image = gr.Image(label='Illustration for your story', show_label=True)

    gallery.style(grid=[4])

    inference_steps = gr.Slider(5, 30,
                                value=20,
                                step=1,
                                visible=True,
                                label="Num inference steps (more steps yield a better image but take more time)")

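    # Wire up the buttons: generate the text first, then optionally illustrate the summary.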
    bt_make_text.click(fn=generate_story, inputs=prompt, outputs=[story, summary, bt_make_image])
    bt_make_image.click(fn=generate_image, inputs=[summary, inference_steps, app_state], outputs=image)

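    # Poll on_change_event every 5 seconds to refresh the gallery and the status message.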
    eventslider = gr.Slider(visible=False)
    dep = demo.load(on_change_event, app_state, [gallery, status_msg], every=5)
    eventslider.change(fn=on_change_event, inputs=[app_state], outputs=[gallery, status_msg], every=5, cancels=[dep])


if READ_TOKEN:
    demo.queue().launch()
else:
    # No access token set: launch with a public share link and debug output.
    demo.queue().launch(share=True, debug=True)