File size: 2,446 Bytes
e827598
 
 
 
 
 
 
 
 
 
 
 
 
8ae5b18
e827598
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ae5b18
 
 
 
e827598
8ae5b18
e827598
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import random
import warnings
from io import BytesIO
import torch
import stability_sdk.interfaces.gooseai.generation.generation_pb2 as generation
from stability_sdk import client
from PIL import Image
from IPython.display import display

import argparse
import subprocess
import sys
import time
import os


def setup():
    """Install the Python dependencies and clone clip-interrogator.

    Each command is run in order with its stdout captured, and the decoded
    stdout is printed once the command finishes. Exit codes are not
    inspected.
    """
    pip_install = ['pip', 'install']
    commands = (
        pip_install + [
            'ftfy', 'gradio', 'regex', 'tqdm', 'stability-sdk',
            'transformers==4.21.2', 'timm', 'fairscale', 'requests',
        ],
        pip_install + [
            '-e', 'git+https://github.com/openai/CLIP.git@main#egg=clip',
        ],
        pip_install + [
            '-e', 'git+https://github.com/pharmapsychotic/BLIP.git@lib#egg=blip',
        ],
        ['git', 'clone',
         'https://github.com/pharmapsychotic/clip-interrogator.git'],
    )
    for command in commands:
        completed = subprocess.run(command, stdout=subprocess.PIPE)
        print(completed.stdout.decode('utf-8'))

# Run the installs at import time; the imports further down depend on the
# packages and checkouts that setup() fetches.
# NOTE(review): stability_sdk is already imported at the top of this file,
# i.e. before setup() installs 'stability-sdk' — confirm it is preinstalled.
setup()

# Make the freshly fetched sources importable.
# NOTE(review): 'src/blip' and 'src/clip' are presumably where the editable
# pip installs place their checkouts — verify against the deployment layout.
sys.path.append('src/blip')
sys.path.append('src/clip')
sys.path.append('clip-interrogator')


# These imports intentionally come after setup() and the sys.path changes.
import clip
import gradio as gr
import torch
from clip_interrogator import Interrogator, Config

# Module-level interrogator shared by every request, built from a default Config.
ci = Interrogator(Config())

# Stability inference client. Subscripting os.environ raises KeyError at
# import time if STABILITY_KEY is not set.
stability_api = client.StabilityInference(
    key=os.environ['STABILITY_KEY'],
    verbose=True
)


def inferAndRebuild(image, mode):
    """Interrogate an image for a prompt, then generate images from that prompt.

    Args:
        image: Input image; must support ``convert('RGB')`` (the Gradio input
            in this file is configured with ``type='pil'``).
        mode: ``'best'`` calls ``ci.interrogate``, ``'classic'`` calls
            ``ci.interrogate_classic``; any other value falls back to
            ``ci.interrogate_fast``.

    Returns:
        A two-item list ``[images, prompt]``: the images opened from the
        image artifacts returned by the Stability API, and the interrogator
        output used as the generation prompt.
    """
    image = image.convert('RGB')

    if mode == 'best':
        output = ci.interrogate(image)
    elif mode == 'classic':
        output = ci.interrogate_classic(image)
    else:
        output = ci.interrogate_fast(image)

    answers = stability_api.generate(
        prompt=str(output),
        seed=34567,
        steps=30,
        samples=5,
    )

    imglist = []
    for resp in answers:
        for artifact in resp.artifacts:
            # The protobuf module is imported as `generation`; the previous
            # `generate.FILTER` reference raised NameError on the first
            # artifact, so no image was ever returned.
            if artifact.finish_reason == generation.FILTER:
                warnings.warn(
                    "Your request activated the API's safety filters and could not be processed. Please modify the prompt and try again.")
            if artifact.type == generation.ARTIFACT_IMAGE:
                imglist.append(Image.open(BytesIO(artifact.binary)))
    return [imglist, output]


# Gradio UI wiring.
# Components use the top-level `gr.*` classes throughout: the legacy
# `gr.inputs` / `gr.outputs` namespaces are deprecated and removed in newer
# Gradio releases, and this file already uses `gr.Radio` / `gr.Gallery`.
inputs = [
    gr.Image(type='pil'),
    gr.Radio(['best', 'classic', 'fast'], label='Models', value='fast')
]

# Order matches the `[imglist, output]` list returned by inferAndRebuild.
outputs = [
    gr.Gallery(),
    gr.Textbox(label='Prompt')
]

io = gr.Interface(
    inferAndRebuild,
    inputs,
    outputs,
    # The string form replaces the deprecated boolean `False`.
    allow_flagging='never',
)

# Starts the server at import time with a public share link and debug mode.
io.launch(share=True, debug=True)