File size: 1,710 Bytes
bb38282
da30eb6
 
 
 
7786a61
bb38282
ce9d54d
 
da30eb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90a5bb8
da30eb6
 
 
 
431e0cf
90a5bb8
da30eb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0401773
 
921eebb
da30eb6
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import os
import gradio as gr 
from pathlib import Path
from diffusers import StableDiffusionPipeline
from PIL import Image
from huggingface_hub import notebook_login
# Hugging Face access token, read from the environment instead of being
# hard-coded. A token committed to source is readable by anyone who can see
# the file and must be treated as leaked (revoke it on the Hub).
# HUGGING_FACE_HUB_TOKEN is checked first, then HF_TOKEN; an empty value
# counts as unset. If neither is set, `token` is None.
token = os.environ.get("HUGGING_FACE_HUB_TOKEN") or os.environ.get("HF_TOKEN")

import torch, logging
# One-time process setup executed at import time:
#   - suppress every log record at WARNING severity or lower,
#   - ask torch to release unused cached CUDA memory,
#   - seed torch's default RNG with a fixed value so the sequence of
#     random draws repeats across restarts.
logging.disable(level=logging.WARNING)
torch.cuda.empty_cache()
torch.manual_seed(seed=3407)
from torch import autocast
from contextlib import nullcontext
from diffusers import StableDiffusionPipeline



# Hub repository id of the Stable Diffusion checkpoint to load.
model_id = "CompVis/stable-diffusion-v1-4"

# Prefer the GPU when torch can see one. On CUDA, inference is wrapped in
# `autocast`; on CPU, a no-op `nullcontext` is used instead.
if torch.cuda.is_available():
    device, context = "cuda", autocast
else:
    device, context = "cpu", nullcontext

# Load the pretrained pipeline (authenticating with `token`) and place it
# on the selected device.
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    use_auth_token=token,
)
pipe = pipe.to(device)


def infer(prompt, samples):
    """Generate images for *prompt* using the module-level ``pipe``.

    Args:
        prompt: Text prompt; it is repeated once per requested image so the
            whole batch is produced in a single pipeline call.
        samples: Number of images to generate. Gradio's slider may deliver
            this as a float (e.g. ``2.0``), so it is coerced to ``int``
            before being used as a repeat count.

    Returns:
        The ``.images`` attribute of the pipeline output, one entry per
        prompt copy (presumably PIL images; confirm against the installed
        diffusers version).
    """
    # `float * list` raises TypeError, so normalise the slider value first.
    n_images = int(samples)
    # `context` is autocast on CUDA and a no-op context manager on CPU.
    with context(device):
        images = pipe([prompt] * n_images, guidance_scale=7.5).images
    return images


# Gradio UI: a prompt box, an image-count slider, a generate button and a
# gallery that receives the images returned by `infer`.
with gr.Blocks() as demo:
    text = gr.Textbox(lines=7,placeholder="Enter your prompt to generate a background image... something like - Photorealistic scenery of bookshelf in a room")
    samples = gr.Slider(label="Number of Images", minimum=1, maximum=5, value=2, step=1)
    # NOTE(review): `.style()` belongs to the Gradio 3.x component API and was
    # removed in Gradio 4 -- confirm the Space pins gradio<4.
    btn = gr.Button("Generate images", variant="primary").style(
        margin=False,
        rounded=(False, True, True, False),
    )
    gallery = gr.Gallery(label="Generated images", show_label=True).style(grid=(1, 3), height="auto")

    # Pressing Enter in the textbox and clicking the button trigger the same
    # inference call. The former `status_tracker=None` argument was a
    # deprecated no-op and has been dropped.
    text.submit(infer, inputs=[text, samples], outputs=gallery)
    btn.click(infer, inputs=[text, samples], outputs=gallery, show_progress=True)


demo.launch()