Spaces:
Running
on
L4
Running
on
L4
lengyue233
commited on
Enable compile on A10G
Browse files- app.py +2 -2
- tools/llama/generate.py +37 -23
app.py
CHANGED
@@ -251,7 +251,7 @@ def build_app():
|
|
251 |
# speaker,
|
252 |
],
|
253 |
[audio, error],
|
254 |
-
|
255 |
)
|
256 |
|
257 |
return app
|
@@ -287,7 +287,7 @@ if __name__ == "__main__":
|
|
287 |
args = parse_args()
|
288 |
|
289 |
args.precision = torch.half if args.half else torch.bfloat16
|
290 |
-
|
291 |
|
292 |
logger.info("Loading Llama model...")
|
293 |
llama_model, decode_one_token = load_llama_model(
|
|
|
251 |
# speaker,
|
252 |
],
|
253 |
[audio, error],
|
254 |
+
concurrency_limit=1,
|
255 |
)
|
256 |
|
257 |
return app
|
|
|
287 |
args = parse_args()
|
288 |
|
289 |
args.precision = torch.half if args.half else torch.bfloat16
|
290 |
+
args.compile = True
|
291 |
|
292 |
logger.info("Loading Llama model...")
|
293 |
llama_model, decode_one_token = load_llama_model(
|
tools/llama/generate.py
CHANGED
@@ -14,7 +14,7 @@ from loguru import logger
|
|
14 |
from tqdm import tqdm
|
15 |
from transformers import AutoTokenizer
|
16 |
|
17 |
-
from fish_speech.datasets.text import CODEBOOK_EOS_TOKEN_ID
|
18 |
from fish_speech.text.clean import clean_text
|
19 |
|
20 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
@@ -291,11 +291,11 @@ def encode_tokens(
|
|
291 |
):
|
292 |
string = clean_text(string)
|
293 |
|
294 |
-
if speaker is
|
295 |
-
|
296 |
|
297 |
string = (
|
298 |
-
f"<|im_start|>user<|im_sep|>{string}<|im_end|><|im_start|>
|
299 |
)
|
300 |
if bos:
|
301 |
string = f"<|begin_of_sequence|>{string}"
|
@@ -309,7 +309,10 @@ def encode_tokens(
|
|
309 |
tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
|
310 |
|
311 |
# Codebooks
|
312 |
-
zeros =
|
|
|
|
|
|
|
313 |
prompt = torch.cat((tokens, zeros), dim=0)
|
314 |
|
315 |
if prompt_tokens is None:
|
@@ -331,13 +334,23 @@ def encode_tokens(
|
|
331 |
)
|
332 |
data = data[:num_codebooks]
|
333 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
334 |
# Since 1.0, we use <|semantic|>
|
335 |
s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
|
336 |
-
|
337 |
-
|
338 |
-
dtype=torch.int,
|
339 |
-
device=device,
|
340 |
)
|
|
|
341 |
|
342 |
data = torch.cat((main_token_ids, data), dim=0)
|
343 |
prompt = torch.cat((prompt, data), dim=1)
|
@@ -450,6 +463,20 @@ def generate_long(
|
|
450 |
use_prompt = prompt_text is not None and prompt_tokens is not None
|
451 |
encoded = []
|
452 |
texts = split_text(text, chunk_length) if iterative_prompt else [text]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
453 |
for idx, text in enumerate(texts):
|
454 |
encoded.append(
|
455 |
encode_tokens(
|
@@ -457,25 +484,12 @@ def generate_long(
|
|
457 |
string=text,
|
458 |
bos=idx == 0 and not use_prompt,
|
459 |
device=device,
|
460 |
-
speaker=
|
461 |
num_codebooks=model.config.num_codebooks,
|
462 |
)
|
463 |
)
|
464 |
logger.info(f"Encoded text: {text}")
|
465 |
|
466 |
-
if use_prompt:
|
467 |
-
encoded_prompt = encode_tokens(
|
468 |
-
tokenizer,
|
469 |
-
prompt_text,
|
470 |
-
prompt_tokens=prompt_tokens,
|
471 |
-
bos=True,
|
472 |
-
device=device,
|
473 |
-
speaker=speaker,
|
474 |
-
num_codebooks=model.config.num_codebooks,
|
475 |
-
)
|
476 |
-
|
477 |
-
encoded[0] = torch.cat((encoded_prompt, encoded[0]), dim=1)
|
478 |
-
|
479 |
for sample_idx in range(num_samples):
|
480 |
torch.cuda.synchronize()
|
481 |
global_encoded = []
|
|
|
14 |
from tqdm import tqdm
|
15 |
from transformers import AutoTokenizer
|
16 |
|
17 |
+
from fish_speech.datasets.text import CODEBOOK_EOS_TOKEN_ID, CODEBOOK_PAD_TOKEN_ID
|
18 |
from fish_speech.text.clean import clean_text
|
19 |
|
20 |
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
291 |
):
|
292 |
string = clean_text(string)
|
293 |
|
294 |
+
if speaker is None:
|
295 |
+
speaker = "assistant"
|
296 |
|
297 |
string = (
|
298 |
+
f"<|im_start|>user<|im_sep|>{string}<|im_end|><|im_start|>{speaker}<|im_sep|>"
|
299 |
)
|
300 |
if bos:
|
301 |
string = f"<|begin_of_sequence|>{string}"
|
|
|
309 |
tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
|
310 |
|
311 |
# Codebooks
|
312 |
+
zeros = (
|
313 |
+
torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
|
314 |
+
* CODEBOOK_PAD_TOKEN_ID
|
315 |
+
)
|
316 |
prompt = torch.cat((tokens, zeros), dim=0)
|
317 |
|
318 |
if prompt_tokens is None:
|
|
|
334 |
)
|
335 |
data = data[:num_codebooks]
|
336 |
|
337 |
+
# Add eos token for each codebook
|
338 |
+
data = torch.cat(
|
339 |
+
(
|
340 |
+
data,
|
341 |
+
torch.ones((data.size(0), 1), dtype=torch.int, device=device)
|
342 |
+
* CODEBOOK_EOS_TOKEN_ID,
|
343 |
+
),
|
344 |
+
dim=1,
|
345 |
+
)
|
346 |
+
|
347 |
# Since 1.0, we use <|semantic|>
|
348 |
s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
|
349 |
+
end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
|
350 |
+
main_token_ids = (
|
351 |
+
torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
|
|
|
352 |
)
|
353 |
+
main_token_ids[0, -1] = end_token_id
|
354 |
|
355 |
data = torch.cat((main_token_ids, data), dim=0)
|
356 |
prompt = torch.cat((prompt, data), dim=1)
|
|
|
463 |
use_prompt = prompt_text is not None and prompt_tokens is not None
|
464 |
encoded = []
|
465 |
texts = split_text(text, chunk_length) if iterative_prompt else [text]
|
466 |
+
|
467 |
+
if use_prompt:
|
468 |
+
encoded.append(
|
469 |
+
encode_tokens(
|
470 |
+
tokenizer,
|
471 |
+
prompt_text,
|
472 |
+
prompt_tokens=prompt_tokens,
|
473 |
+
bos=True,
|
474 |
+
device=device,
|
475 |
+
speaker=speaker,
|
476 |
+
num_codebooks=model.config.num_codebooks,
|
477 |
+
)
|
478 |
+
)
|
479 |
+
|
480 |
for idx, text in enumerate(texts):
|
481 |
encoded.append(
|
482 |
encode_tokens(
|
|
|
484 |
string=text,
|
485 |
bos=idx == 0 and not use_prompt,
|
486 |
device=device,
|
487 |
+
speaker=speaker,
|
488 |
num_codebooks=model.config.num_codebooks,
|
489 |
)
|
490 |
)
|
491 |
logger.info(f"Encoded text: {text}")
|
492 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
493 |
for sample_idx in range(num_samples):
|
494 |
torch.cuda.synchronize()
|
495 |
global_encoded = []
|