lengyue233 commited on
Commit
662d788
·
verified ·
1 Parent(s): 2f06fba

Enable compile on A10G

Browse files
Files changed (2) hide show
  1. app.py +2 -2
  2. tools/llama/generate.py +37 -23
app.py CHANGED
@@ -251,7 +251,7 @@ def build_app():
251
  # speaker,
252
  ],
253
  [audio, error],
254
- # concurrency_limit=1,
255
  )
256
 
257
  return app
@@ -287,7 +287,7 @@ if __name__ == "__main__":
287
  args = parse_args()
288
 
289
  args.precision = torch.half if args.half else torch.bfloat16
290
- # args.compile = True
291
 
292
  logger.info("Loading Llama model...")
293
  llama_model, decode_one_token = load_llama_model(
 
251
  # speaker,
252
  ],
253
  [audio, error],
254
+ concurrency_limit=1,
255
  )
256
 
257
  return app
 
287
  args = parse_args()
288
 
289
  args.precision = torch.half if args.half else torch.bfloat16
290
+ args.compile = True
291
 
292
  logger.info("Loading Llama model...")
293
  llama_model, decode_one_token = load_llama_model(
tools/llama/generate.py CHANGED
@@ -14,7 +14,7 @@ from loguru import logger
14
  from tqdm import tqdm
15
  from transformers import AutoTokenizer
16
 
17
- from fish_speech.datasets.text import CODEBOOK_EOS_TOKEN_ID
18
  from fish_speech.text.clean import clean_text
19
 
20
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -291,11 +291,11 @@ def encode_tokens(
291
  ):
292
  string = clean_text(string)
293
 
294
- if speaker is not None:
295
- string = f"[SPK: {speaker}] {string}"
296
 
297
  string = (
298
- f"<|im_start|>user<|im_sep|>{string}<|im_end|><|im_start|>assistant<|im_sep|>"
299
  )
300
  if bos:
301
  string = f"<|begin_of_sequence|>{string}"
@@ -309,7 +309,10 @@ def encode_tokens(
309
  tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
310
 
311
  # Codebooks
312
- zeros = torch.zeros((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
 
 
 
313
  prompt = torch.cat((tokens, zeros), dim=0)
314
 
315
  if prompt_tokens is None:
@@ -331,13 +334,23 @@ def encode_tokens(
331
  )
332
  data = data[:num_codebooks]
333
 
 
 
 
 
 
 
 
 
 
 
334
  # Since 1.0, we use <|semantic|>
335
  s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
336
- main_token_ids = torch.tensor(
337
- [[s0_token_id] * data.size(1)],
338
- dtype=torch.int,
339
- device=device,
340
  )
 
341
 
342
  data = torch.cat((main_token_ids, data), dim=0)
343
  prompt = torch.cat((prompt, data), dim=1)
@@ -450,6 +463,20 @@ def generate_long(
450
  use_prompt = prompt_text is not None and prompt_tokens is not None
451
  encoded = []
452
  texts = split_text(text, chunk_length) if iterative_prompt else [text]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
453
  for idx, text in enumerate(texts):
454
  encoded.append(
455
  encode_tokens(
@@ -457,25 +484,12 @@ def generate_long(
457
  string=text,
458
  bos=idx == 0 and not use_prompt,
459
  device=device,
460
- speaker=None,
461
  num_codebooks=model.config.num_codebooks,
462
  )
463
  )
464
  logger.info(f"Encoded text: {text}")
465
 
466
- if use_prompt:
467
- encoded_prompt = encode_tokens(
468
- tokenizer,
469
- prompt_text,
470
- prompt_tokens=prompt_tokens,
471
- bos=True,
472
- device=device,
473
- speaker=speaker,
474
- num_codebooks=model.config.num_codebooks,
475
- )
476
-
477
- encoded[0] = torch.cat((encoded_prompt, encoded[0]), dim=1)
478
-
479
  for sample_idx in range(num_samples):
480
  torch.cuda.synchronize()
481
  global_encoded = []
 
14
  from tqdm import tqdm
15
  from transformers import AutoTokenizer
16
 
17
+ from fish_speech.datasets.text import CODEBOOK_EOS_TOKEN_ID, CODEBOOK_PAD_TOKEN_ID
18
  from fish_speech.text.clean import clean_text
19
 
20
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
291
  ):
292
  string = clean_text(string)
293
 
294
+ if speaker is None:
295
+ speaker = "assistant"
296
 
297
  string = (
298
+ f"<|im_start|>user<|im_sep|>{string}<|im_end|><|im_start|>{speaker}<|im_sep|>"
299
  )
300
  if bos:
301
  string = f"<|begin_of_sequence|>{string}"
 
309
  tokens = torch.tensor([new_tokens], dtype=torch.int, device=device)
310
 
311
  # Codebooks
312
+ zeros = (
313
+ torch.ones((num_codebooks, tokens.size(1)), dtype=torch.int, device=device)
314
+ * CODEBOOK_PAD_TOKEN_ID
315
+ )
316
  prompt = torch.cat((tokens, zeros), dim=0)
317
 
318
  if prompt_tokens is None:
 
334
  )
335
  data = data[:num_codebooks]
336
 
337
+ # Add eos token for each codebook
338
+ data = torch.cat(
339
+ (
340
+ data,
341
+ torch.ones((data.size(0), 1), dtype=torch.int, device=device)
342
+ * CODEBOOK_EOS_TOKEN_ID,
343
+ ),
344
+ dim=1,
345
+ )
346
+
347
  # Since 1.0, we use <|semantic|>
348
  s0_token_id = tokenizer.convert_tokens_to_ids("<|semantic|>")
349
+ end_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
350
+ main_token_ids = (
351
+ torch.ones((1, data.size(1)), dtype=torch.int, device=device) * s0_token_id
 
352
  )
353
+ main_token_ids[0, -1] = end_token_id
354
 
355
  data = torch.cat((main_token_ids, data), dim=0)
356
  prompt = torch.cat((prompt, data), dim=1)
 
463
  use_prompt = prompt_text is not None and prompt_tokens is not None
464
  encoded = []
465
  texts = split_text(text, chunk_length) if iterative_prompt else [text]
466
+
467
+ if use_prompt:
468
+ encoded.append(
469
+ encode_tokens(
470
+ tokenizer,
471
+ prompt_text,
472
+ prompt_tokens=prompt_tokens,
473
+ bos=True,
474
+ device=device,
475
+ speaker=speaker,
476
+ num_codebooks=model.config.num_codebooks,
477
+ )
478
+ )
479
+
480
  for idx, text in enumerate(texts):
481
  encoded.append(
482
  encode_tokens(
 
484
  string=text,
485
  bos=idx == 0 and not use_prompt,
486
  device=device,
487
+ speaker=speaker,
488
  num_codebooks=model.config.num_codebooks,
489
  )
490
  )
491
  logger.info(f"Encoded text: {text}")
492
 
 
 
 
 
 
 
 
 
 
 
 
 
 
493
  for sample_idx in range(num_samples):
494
  torch.cuda.synchronize()
495
  global_encoded = []