playground with openai dependency 2
- args_manager.py +1 -0
- dummy_inference.py +34 -0
- inference.py +67 -0
- launch.py +10 -7
- webui.py +10 -0
args_manager.py
CHANGED
@@ -37,6 +37,7 @@ args_parser.parser.add_argument("--always-download-new-model", action='store_tru
 args_parser.parser.set_defaults(
     disable_cuda_malloc=True,
     in_browser=True,
+    api_mode=False,
     port=None
 )

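Note: set_defaults() only supplies a fallback value for the api_mode destination; the command-line option that would flip it is not part of this commit. A minimal sketch of how such a default behaves (the --api-mode add_argument call below is an assumption, not the repository's code):

import argparse

parser = argparse.ArgumentParser()
# Hypothetical flag; this commit only sets the default, not the option itself.
parser.add_argument("--api-mode", dest="api_mode", action="store_true")
parser.set_defaults(api_mode=False)

print(parser.parse_args([]).api_mode)              # False (the new default)
print(parser.parse_args(["--api-mode"]).api_mode)  # True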
dummy_inference.py
ADDED
@@ -0,0 +1,34 @@
+from openai import OpenAI
+from dotenv import load_dotenv
+import os
+import sys
+import gradio as gr
+
+load_dotenv()
+openai_key = os.getenv("OPENAI_KEY", "")
+
+if openai_key == "<YOUR_OPENAI_KEY>":
+    openai_key = ""
+
+if openai_key == "":
+    sys.exit("Please Provide Your OpenAI API Key")
+
+def infer_stable_diffusion(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
+    return "dummy_image"
+
+def infer_dall_e(text, model, quality, size):
+    try:
+        client = OpenAI(api_key=openai_key)
+
+        response = client.images.generate(
+            prompt=text,
+            model=model,
+            quality=quality,
+            size=size,
+            n=1,
+        )
+    except Exception as error:
+        print(str(error))
+        raise gr.Error("An error occurred while generating image.")
+
+    return response.data[0].url
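Note that dummy_inference.py only stubs the Stable Diffusion path (it returns the placeholder string "dummy_image"); infer_dall_e still performs a real OpenAI Images API call, and the module exits at import time unless OPENAI_KEY is set in the environment or a .env file. A hypothetical call, assuming a valid key is configured ("dall-e-3" accepts quality "standard" or "hd" and size "1024x1024", among others):

import dummy_inference

# Returns the URL of the generated image hosted by OpenAI.
url = dummy_inference.infer_dall_e(
    "a watercolor painting of a fox",
    model="dall-e-3", quality="standard", size="1024x1024",
)
print(url)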
inference.py
ADDED
@@ -0,0 +1,67 @@
+import numpy as np
+import random
+import sys
+from diffusers import DiffusionPipeline
+import torch
+from openai import OpenAI
+from dotenv import load_dotenv
+import os
+import gradio as gr
+
+load_dotenv()
+openai_key = os.getenv("OPENAI_KEY", "")
+
+if openai_key == "<YOUR_OPENAI_KEY>":
+    openai_key = ""
+
+if openai_key == "":
+    sys.exit("Please Provide Your OpenAI API Key")
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+if torch.cuda.is_available():
+    torch.cuda.max_memory_allocated(device=device)
+    pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
+    pipe.enable_xformers_memory_efficient_attention()
+    pipe = pipe.to(device)
+else:
+    pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
+    pipe = pipe.to(device)
+
+MAX_SEED = np.iinfo(np.int32).max
+
+def infer_stable_diffusion(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
+
+    if randomize_seed:
+        seed = random.randint(0, MAX_SEED)
+
+    generator = torch.Generator().manual_seed(seed)
+
+    image = pipe(
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        guidance_scale=guidance_scale,
+        num_inference_steps=num_inference_steps,
+        width=width,
+        height=height,
+        generator=generator
+    ).images[0]
+
+    return image
+
+def infer_dall_e(text, model, quality, size):
+    try:
+        client = OpenAI(api_key=openai_key)
+
+        response = client.images.generate(
+            prompt=text,
+            model=model,
+            quality=quality,
+            size=size,
+            n=1,
+        )
+    except Exception as error:
+        print(str(error))
+        raise gr.Error("An error occurred while generating image.")
+
+    return response.data[0].url
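SDXL-Turbo is a few-step distilled model: its model card recommends 512x512 output, 1-4 inference steps, and guidance_scale=0.0, and with guidance disabled the negative prompt has no effect. A hypothetical call with those recommended settings (the parameter values are illustrative):

# Assumes inference.py imported successfully; the pipeline returns PIL images.
image = infer_stable_diffusion(
    prompt="a cinematic photo of a lighthouse at dusk",
    negative_prompt="",
    seed=0,
    randomize_seed=True,
    width=512,
    height=512,
    guidance_scale=0.0,
    num_inference_steps=2,
)
image.save("lighthouse.png")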
launch.py
CHANGED
@@ -27,6 +27,7 @@ TRY_INSTALL_XFORMERS = False


 def prepare_environment():
+
     torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121")
     torch_command = os.environ.get('TORCH_COMMAND',
                                    f"pip install torch==2.1.0 torchvision==0.16.0 --extra-index-url {torch_index_url}")
@@ -71,10 +72,12 @@ def ini_args():
     from args_manager import args
     return args

-
-prepare_environment()
-build_launcher()
 args = ini_args()
+if not args.api_mode:
+    prepare_environment()
+    build_launcher()
+    args = ini_args()
+

 if args.gpu_device_id is not None:
     os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_device_id)
@@ -128,9 +131,9 @@ def download_models(default_model, previous_default_models, checkpoint_downloads

     return default_model, checkpoint_downloads

-
-config.default_base_model_name, config.checkpoint_downloads = download_models(
-    config.default_base_model_name, config.previous_default_models, config.checkpoint_downloads,
-    config.embeddings_downloads, config.lora_downloads)
+if not args.api_mode:
+    config.default_base_model_name, config.checkpoint_downloads = download_models(
+        config.default_base_model_name, config.previous_default_models, config.checkpoint_downloads,
+        config.embeddings_downloads, config.lora_downloads)

 from webui import *
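With both guards in place, args = ini_args() now runs first (the guard itself needs args.api_mode), and API mode skips the pip installs, launcher setup, and checkpoint downloads entirely before reaching from webui import *. A self-contained toy sketch of the pattern (the function names below are placeholders, not the repository's):

def prepare_environment() -> None:
    print("installing torch, downloading checkpoints ...")

def start_ui() -> None:
    print("starting UI")

def run(api_mode: bool) -> None:
    # Same shape as the guard added to launch.py: heavy setup only runs
    # for the full local pipeline.
    if not api_mode:
        prepare_environment()
    start_ui()

run(api_mode=True)   # -> starting UI
run(api_mode=False)  # -> installing ..., then starting UI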
webui.py
CHANGED
@@ -16,6 +16,7 @@ import modules.meta_parser
 import args_manager
 import copy
 import launch
+import torch

 from modules.sdxl_styles import legal_style_names
 from modules.private_logger import get_current_html_path
@@ -23,6 +24,15 @@ from modules.ui_gradio_extensions import reload_javascript
 from modules.auth import auth_enabled, check_auth
 from modules.util import is_json

+def ini_args():
+    from args_manager import args
+    return args
+
+if ini_args().api_mode:
+    import dummy_inference as inf
+else:
+    import inference as inf
+
 def get_task(*args):
     args = list(args)
     args.pop(0)
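Because dummy_inference and inference expose infer_stable_diffusion and infer_dall_e with identical signatures, the alias inf keeps the rest of webui.py branch-free; callers never check api_mode again. A hypothetical call site (the Gradio event wiring is not part of this diff):

# `inf` is whichever module the conditional import above bound.
def handle_dalle(text: str) -> str:
    return inf.infer_dall_e(text, model="dall-e-3",
                            quality="standard", size="1024x1024")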