Spaces:
Running
on
A10G
Running
on
A10G
lengyue233
commited on
Commit
•
a4dfb48
1
Parent(s):
f7a538e
wait for init
Browse files- app.py +3 -2
- tools/llama/generate.py +4 -0
app.py
CHANGED
@@ -306,7 +306,7 @@ if __name__ == "__main__":
|
|
306 |
args.vqgan_config_name = "vqgan_pretrain"
|
307 |
|
308 |
logger.info("Loading Llama model...")
|
309 |
-
|
310 |
llama_queue = launch_thread_safe_queue(
|
311 |
config_name=args.llama_config_name,
|
312 |
checkpoint_path=args.llama_checkpoint_path,
|
@@ -314,11 +314,12 @@ if __name__ == "__main__":
|
|
314 |
precision=args.precision,
|
315 |
max_length=args.max_length,
|
316 |
compile=args.compile,
|
|
|
317 |
)
|
318 |
llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
|
|
|
319 |
logger.info("Llama model loaded, loading VQ-GAN model...")
|
320 |
|
321 |
-
hydra.core.global_hydra.GlobalHydra.instance().clear()
|
322 |
vqgan_model = load_vqgan_model(
|
323 |
config_name=args.vqgan_config_name,
|
324 |
checkpoint_path=args.vqgan_checkpoint_path,
|
|
|
306 |
args.vqgan_config_name = "vqgan_pretrain"
|
307 |
|
308 |
logger.info("Loading Llama model...")
|
309 |
+
init_event = threading.Event()
|
310 |
llama_queue = launch_thread_safe_queue(
|
311 |
config_name=args.llama_config_name,
|
312 |
checkpoint_path=args.llama_checkpoint_path,
|
|
|
314 |
precision=args.precision,
|
315 |
max_length=args.max_length,
|
316 |
compile=args.compile,
|
317 |
+
init_event=init_event,
|
318 |
)
|
319 |
llama_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
|
320 |
+
init_event.wait()
|
321 |
logger.info("Llama model loaded, loading VQ-GAN model...")
|
322 |
|
|
|
323 |
vqgan_model = load_vqgan_model(
|
324 |
config_name=args.vqgan_config_name,
|
325 |
checkpoint_path=args.vqgan_checkpoint_path,
|
tools/llama/generate.py
CHANGED
@@ -607,6 +607,7 @@ def launch_thread_safe_queue(
|
|
607 |
precision,
|
608 |
max_length,
|
609 |
compile=False,
|
|
|
610 |
):
|
611 |
input_queue = queue.Queue()
|
612 |
|
@@ -615,6 +616,9 @@ def launch_thread_safe_queue(
|
|
615 |
config_name, checkpoint_path, device, precision, max_length, compile=compile
|
616 |
)
|
617 |
|
|
|
|
|
|
|
618 |
while True:
|
619 |
item = input_queue.get()
|
620 |
if item is None:
|
|
|
607 |
precision,
|
608 |
max_length,
|
609 |
compile=False,
|
610 |
+
init_event=None,
|
611 |
):
|
612 |
input_queue = queue.Queue()
|
613 |
|
|
|
616 |
config_name, checkpoint_path, device, precision, max_length, compile=compile
|
617 |
)
|
618 |
|
619 |
+
if init_event is not None:
|
620 |
+
init_event.set()
|
621 |
+
|
622 |
while True:
|
623 |
item = input_queue.get()
|
624 |
if item is None:
|