tharms commited on
Commit
bcac76b
·
1 Parent(s): e5446e8

playgound with openai dependency 2

Browse files
Files changed (5) hide show
  1. args_manager.py +1 -0
  2. dummy_inference.py +32 -0
  3. inference.py +65 -0
  4. launch.py +10 -7
  5. webui.py +10 -0
args_manager.py CHANGED
@@ -37,6 +37,7 @@ args_parser.parser.add_argument("--always-download-new-model", action='store_tru
37
  args_parser.parser.set_defaults(
38
  disable_cuda_malloc=True,
39
  in_browser=True,
 
40
  port=None
41
  )
42
 
 
37
  args_parser.parser.set_defaults(
38
  disable_cuda_malloc=True,
39
  in_browser=True,
40
+ api_mode=False,
41
  port=None
42
  )
43
 
dummy_inference.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from openai import OpenAI
2
+ from dotenv import load_dotenv
3
+ import os
4
+
5
+ load_dotenv()
6
+ openai_key = os.getenv("OPENAI_KEY")
7
+
8
+ if openai_key == "<YOUR_OPENAI_KEY>":
9
+ openai_key = ""
10
+
11
+ if openai_key == "":
12
+ sys.exit("Please Provide Your OpenAI API Key")
13
+
14
+ def infer_stable_diffusion(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
15
+ return "dummy_image"
16
+
17
+ def infer_dall_e(text, model, quality, size):
18
+ try:
19
+ client = OpenAI(api_key=openai_key)
20
+
21
+ response = client.images.generate(
22
+ prompt=text,
23
+ model=model,
24
+ quality=quality,
25
+ size=size,
26
+ n=1,
27
+ )
28
+ except Exception as error:
29
+ print(str(error))
30
+ raise gr.Error("An error occurred while generating image.")
31
+
32
+ return response.data[0].url
inference.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import random
3
+ from diffusers import DiffusionPipeline
4
+ import torch
5
+ from openai import OpenAI
6
+ from dotenv import load_dotenv
7
+ import os
8
+
9
+ load_dotenv()
10
+ openai_key = os.getenv("OPENAI_KEY")
11
+
12
+ if openai_key == "<YOUR_OPENAI_KEY>":
13
+ openai_key = ""
14
+
15
+ if openai_key == "":
16
+ sys.exit("Please Provide Your OpenAI API Key")
17
+
18
+ device = "cuda" if torch.cuda.is_available() else "cpu"
19
+
20
+ if torch.cuda.is_available():
21
+ torch.cuda.max_memory_allocated(device=device)
22
+ pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
23
+ pipe.enable_xformers_memory_efficient_attention()
24
+ pipe = pipe.to(device)
25
+ else:
26
+ pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
27
+ pipe = pipe.to(device)
28
+
29
+ MAX_SEED = np.iinfo(np.int32).max
30
+
31
+ def infer_stable_diffusion(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
32
+
33
+ if randomize_seed:
34
+ seed = random.randint(0, MAX_SEED)
35
+
36
+ generator = torch.Generator().manual_seed(seed)
37
+
38
+ image = pipe(
39
+ prompt = prompt,
40
+ negative_prompt = negative_prompt,
41
+ guidance_scale = guidance_scale,
42
+ num_inference_steps = num_inference_steps,
43
+ width = width,
44
+ height = height,
45
+ generator = generator
46
+ ).images[0]
47
+
48
+ return image
49
+
50
+ def infer_dall_e(text, model, quality, size):
51
+ try:
52
+ client = OpenAI(api_key=openai_key)
53
+
54
+ response = client.images.generate(
55
+ prompt=text,
56
+ model=model,
57
+ quality=quality,
58
+ size=size,
59
+ n=1,
60
+ )
61
+ except Exception as error:
62
+ print(str(error))
63
+ raise gr.Error("An error occurred while generating image.")
64
+
65
+ return response.data[0].url
launch.py CHANGED
@@ -27,6 +27,7 @@ TRY_INSTALL_XFORMERS = False
27
 
28
 
29
  def prepare_environment():
 
30
  torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121")
31
  torch_command = os.environ.get('TORCH_COMMAND',
32
  f"pip install torch==2.1.0 torchvision==0.16.0 --extra-index-url {torch_index_url}")
@@ -71,10 +72,12 @@ def ini_args():
71
  from args_manager import args
72
  return args
73
 
74
-
75
- prepare_environment()
76
- build_launcher()
77
  args = ini_args()
 
 
 
 
 
78
 
79
  if args.gpu_device_id is not None:
80
  os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_device_id)
@@ -128,9 +131,9 @@ def download_models(default_model, previous_default_models, checkpoint_downloads
128
 
129
  return default_model, checkpoint_downloads
130
 
131
-
132
- config.default_base_model_name, config.checkpoint_downloads = download_models(
133
- config.default_base_model_name, config.previous_default_models, config.checkpoint_downloads,
134
- config.embeddings_downloads, config.lora_downloads)
135
 
136
  from webui import *
 
27
 
28
 
29
  def prepare_environment():
30
+
31
  torch_index_url = os.environ.get('TORCH_INDEX_URL', "https://download.pytorch.org/whl/cu121")
32
  torch_command = os.environ.get('TORCH_COMMAND',
33
  f"pip install torch==2.1.0 torchvision==0.16.0 --extra-index-url {torch_index_url}")
 
72
  from args_manager import args
73
  return args
74
 
 
 
 
75
  args = ini_args()
76
+ if not args.api_mode:
77
+ prepare_environment()
78
+ build_launcher()
79
+ args = ini_args()
80
+
81
 
82
  if args.gpu_device_id is not None:
83
  os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_device_id)
 
131
 
132
  return default_model, checkpoint_downloads
133
 
134
+ if not args.api_mode:
135
+ config.default_base_model_name, config.checkpoint_downloads = download_models(
136
+ config.default_base_model_name, config.previous_default_models, config.checkpoint_downloads,
137
+ config.embeddings_downloads, config.lora_downloads)
138
 
139
  from webui import *
webui.py CHANGED
@@ -16,6 +16,7 @@ import modules.meta_parser
16
  import args_manager
17
  import copy
18
  import launch
 
19
 
20
  from modules.sdxl_styles import legal_style_names
21
  from modules.private_logger import get_current_html_path
@@ -23,6 +24,15 @@ from modules.ui_gradio_extensions import reload_javascript
23
  from modules.auth import auth_enabled, check_auth
24
  from modules.util import is_json
25
 
 
 
 
 
 
 
 
 
 
26
  def get_task(*args):
27
  args = list(args)
28
  args.pop(0)
 
16
  import args_manager
17
  import copy
18
  import launch
19
+ import torch
20
 
21
  from modules.sdxl_styles import legal_style_names
22
  from modules.private_logger import get_current_html_path
 
24
  from modules.auth import auth_enabled, check_auth
25
  from modules.util import is_json
26
 
27
+ def ini_args():
28
+ from args_manager import args
29
+ return args
30
+
31
+ if ini_args().api_mode:
32
+ import dummy_inference as inf
33
+ else:
34
+ import inference as inf
35
+
36
  def get_task(*args):
37
  args = list(args)
38
  args.pop(0)