Sergidev commited on
Commit
2360e47
1 Parent(s): bf6dd2e

updated v3p3

Browse files
Files changed (1) hide show
  1. app.py +52 -0
app.py CHANGED
@@ -15,6 +15,58 @@ from diffusers import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipelin
15
 
16
  # ... (keep the existing imports and configurations)
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  # Add a new function to parse and validate JSON input
19
  def parse_json_parameters(json_str):
20
  try:
 
15
 
16
  # ... (keep the existing imports and configurations)
17
 
18
+ logging.basicConfig(level=logging.INFO)
19
+ logger = logging.getLogger(__name__)
20
+
21
+ DESCRIPTION = "PonyDiffusion V6 XL"
22
+ if not torch.cuda.is_available():
23
+ DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU. </p>"
24
+ IS_COLAB = utils.is_google_colab() or os.getenv("IS_COLAB") == "1"
25
+ HF_TOKEN = os.getenv("HF_TOKEN")
26
+ CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES") == "1"
27
+ MIN_IMAGE_SIZE = int(os.getenv("MIN_IMAGE_SIZE", "512"))
28
+ MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "2048"))
29
+ USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE") == "1"
30
+ ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD") == "1"
31
+ OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./outputs")
32
+
33
+ MODEL = os.getenv(
34
+ "MODEL",
35
+ "https://huggingface.co/AstraliteHeart/pony-diffusion-v6/blob/main/v6.safetensors",
36
+ )
37
+
38
+ torch.backends.cudnn.deterministic = True
39
+ torch.backends.cudnn.benchmark = False
40
+
41
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
42
+
43
+
44
+ def load_pipeline(model_name):
45
+ vae = AutoencoderKL.from_pretrained(
46
+ "madebyollin/sdxl-vae-fp16-fix",
47
+ torch_dtype=torch.float16,
48
+ )
49
+ pipeline = (
50
+ StableDiffusionXLPipeline.from_single_file
51
+ if MODEL.endswith(".safetensors")
52
+ else StableDiffusionXLPipeline.from_pretrained
53
+ )
54
+
55
+ pipe = pipeline(
56
+ model_name,
57
+ vae=vae,
58
+ torch_dtype=torch.float16,
59
+ custom_pipeline="lpw_stable_diffusion_xl",
60
+ use_safetensors=True,
61
+ add_watermarker=False,
62
+ use_auth_token=HF_TOKEN,
63
+ variant="fp16",
64
+ )
65
+
66
+ pipe.to(device)
67
+ return pipe
68
+
69
+
70
  # Add a new function to parse and validate JSON input
71
  def parse_json_parameters(json_str):
72
  try: