Commit
•
452f87d
1
Parent(s):
559f623
Update app.py
Browse files
app.py
CHANGED
@@ -17,7 +17,7 @@ if not torch.cuda.is_available():
|
|
17 |
MAX_SEED = np.iinfo(np.int32).max
|
18 |
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
|
19 |
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1536"))
|
20 |
-
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "
|
21 |
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
|
22 |
|
23 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
@@ -32,11 +32,11 @@ if torch.cuda.is_available():
|
|
32 |
add_watermarker=False,
|
33 |
variant="fp16"
|
34 |
)
|
35 |
-
if ENABLE_CPU_OFFLOAD:
|
36 |
-
|
37 |
-
else:
|
38 |
-
|
39 |
-
|
40 |
|
41 |
if USE_TORCH_COMPILE:
|
42 |
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
@@ -67,6 +67,7 @@ def generate(
|
|
67 |
use_resolution_binning: bool = True,
|
68 |
progress=gr.Progress(track_tqdm=True),
|
69 |
):
|
|
|
70 |
seed = int(randomize_seed_fn(seed, randomize_seed))
|
71 |
generator = torch.Generator().manual_seed(seed)
|
72 |
|
|
|
17 |
MAX_SEED = np.iinfo(np.int32).max
|
18 |
CACHE_EXAMPLES = torch.cuda.is_available() and os.getenv("CACHE_EXAMPLES", "1") == "1"
|
19 |
MAX_IMAGE_SIZE = int(os.getenv("MAX_IMAGE_SIZE", "1536"))
|
20 |
+
USE_TORCH_COMPILE = os.getenv("USE_TORCH_COMPILE", "0") == "1"
|
21 |
ENABLE_CPU_OFFLOAD = os.getenv("ENABLE_CPU_OFFLOAD", "0") == "1"
|
22 |
|
23 |
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
|
|
32 |
add_watermarker=False,
|
33 |
variant="fp16"
|
34 |
)
|
35 |
+
#if ENABLE_CPU_OFFLOAD:
|
36 |
+
# pipe.enable_model_cpu_offload()
|
37 |
+
#else:
|
38 |
+
# pipe.to(device)
|
39 |
+
# print("Loaded on Device!")
|
40 |
|
41 |
if USE_TORCH_COMPILE:
|
42 |
pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
|
|
|
67 |
use_resolution_binning: bool = True,
|
68 |
progress=gr.Progress(track_tqdm=True),
|
69 |
):
|
70 |
+
pipe.to(device)
|
71 |
seed = int(randomize_seed_fn(seed, randomize_seed))
|
72 |
generator = torch.Generator().manual_seed(seed)
|
73 |
|