generative-media-ai / inference.py
tharms's picture
Reverting to the state of the project at 590cec9cc252b3de2e304bc845d7b711c47919cc
27cf343
raw
history blame
1.84 kB
import os
import random
import sys

import torch
from diffusers import DiffusionPipeline
from dotenv import load_dotenv
from openai import OpenAI

import modules.constants as constants
# Read variables from a local .env file (if present) into the process
# environment before looking up the OpenAI key.
load_dotenv()

openai_key = os.getenv("OPENAI_KEY")
# An unset variable (None), an empty value, and the unedited placeholder all
# mean "no key configured"; refuse to start in any of those cases.
if not openai_key or openai_key == "<YOUR_OPENAI_KEY>":
    sys.exit("Please Provide Your OpenAI API Key")

device = "cuda" if torch.cuda.is_available() else "cpu"

# NOTE: torch.cuda.max_memory_allocated() only *reports* peak memory use; it
# does not limit allocation, so it is intentionally not called here.
if device == "cuda":
    # GPU path: half-precision fp16 weights plus xFormers memory-efficient
    # attention (requires the xformers package to be installed).
    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/sdxl-turbo",
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    )
    pipe.enable_xformers_memory_efficient_attention()
else:
    # CPU path: load the default-precision weights.
    pipe = DiffusionPipeline.from_pretrained(
        "stabilityai/sdxl-turbo",
        use_safetensors=True,
    )
pipe = pipe.to(device)
def infer_stable_diffusion(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
    """Generate a single image with the module-level diffusion pipeline.

    Args:
        prompt: Text describing the desired image.
        negative_prompt: Text describing what the image should avoid.
        seed: Seed for the random generator; ignored when ``randomize_seed``
            is true. May be an int or an integral float.
        randomize_seed: If true, draw a fresh seed uniformly from
            ``[0, constants.MAX_SEED]`` (inclusive).
        width: Output width in pixels, forwarded to the pipeline.
        height: Output height in pixels, forwarded to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps forwarded to the
            pipeline.

    Returns:
        The first entry of the pipeline output's ``.images``.
    """
    if randomize_seed:
        seed = random.randint(0, constants.MAX_SEED)

    # Callers such as UI sliders may hand us a float (e.g. 42.0), but
    # Generator.manual_seed only accepts integers. The generator is created
    # on the default (CPU) device.
    generator = torch.Generator().manual_seed(int(seed))

    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    )
    return output.images[0]
def infer_dall_e(text, model, quality, size):
    """Generate one image through the OpenAI Images API and return its URL.

    Args:
        text: Prompt sent to the API.
        model: Image model name forwarded to ``client.images.generate``.
        quality: Quality setting forwarded to the API.
        size: Image size setting forwarded to the API.

    Returns:
        The URL of the generated image (``response.data[0].url``).

    Raises:
        gradio.Error: If creating the client or the generation request fails.
            The underlying exception is printed and chained as ``__cause__``.
    """
    try:
        client = OpenAI(api_key=openai_key)
        response = client.images.generate(
            prompt=text,
            model=model,
            quality=quality,
            size=size,
            n=1,
        )
    except Exception as error:
        # This module has no top-level gradio import, so bring `gr` into scope
        # here; otherwise this error path would itself fail with NameError.
        import gradio as gr

        print(str(error))
        raise gr.Error("An error occurred while generating image.") from error
    return response.data[0].url