ZhifengKong committed
Commit 0195d32
Parent: 64fc4c7
Files changed (2)
  1. app.py +3 -1
  2. inference_utils.py +6 -6
app.py CHANGED
@@ -132,7 +132,8 @@ laionclap_model = load_laionclap()
 model = prepare_model(
     model_config=model_config,
     clap_config=clap_config,
-    checkpoint_path='chat.pt'
+    checkpoint_path='chat.pt',
+    device=device
 )
 
 
@@ -147,6 +148,7 @@ def inference_item(name, prompt):
     outputs = inference(
         model, text_tokenizer, item, processed_item,
         inference_kwargs,
+        device=device
     )
 
     laionclap_scores = compute_laionclap_text_audio_sim(
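
Both call sites now thread an explicit `device` argument through instead of relying on the callee's hardcoded default. The diff does not show where `device` is defined in app.py; a minimal sketch of the usual selection pattern (assumed, not part of this commit):

import torch

# Hypothetical definition, not shown in this diff: use the GPU when one is
# available, otherwise fall back to CPU so the app can still run without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
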
inference_utils.py CHANGED
@@ -33,7 +33,7 @@ def prepare_tokenizer(model_config):
     return text_tokenizer
 
 
-def prepare_model(model_config, clap_config, checkpoint_path, device_id=0):
+def prepare_model(model_config, clap_config, checkpoint_path, device=0):
     os.environ["TOKENIZERS_PARALLELISM"] = "false"  # disable the tokenizer parallelism warning
     model, tokenizer = create_model_and_transforms(
         **model_config,
@@ -43,7 +43,7 @@ def prepare_model(model_config, clap_config, checkpoint_path, device_id=0):
         freeze_lm_embeddings=False,
     )
     model.eval()
-    model = model.to(device_id)
+    model = model.to(device)
 
     checkpoint = torch.load(checkpoint_path, map_location="cpu")
     model_state_dict = checkpoint["model_state_dict"]
@@ -53,11 +53,11 @@ def prepare_model(model_config, clap_config, checkpoint_path, device_id=0):
     return model
 
 
-def inference(model, tokenizer, item, processed_item, inference_kwargs, device_id=0):
+def inference(model, tokenizer, item, processed_item, inference_kwargs, device=0):
     filename, audio_clips, audio_embed_mask, input_ids, attention_mask = processed_item
-    audio_clips = audio_clips.to(device_id, dtype=None, non_blocking=True)
-    audio_embed_mask = audio_embed_mask.to(device_id, dtype=None, non_blocking=True)
-    input_ids = input_ids.to(device_id, dtype=None, non_blocking=True).squeeze()
+    audio_clips = audio_clips.to(device, dtype=None, non_blocking=True)
+    audio_embed_mask = audio_embed_mask.to(device, dtype=None, non_blocking=True)
+    input_ids = input_ids.to(device, dtype=None, non_blocking=True).squeeze()
 
     media_token_id = tokenizer.encode("<audio>")[-1]
     eoc_token_id = tokenizer.encode("<|endofchunk|>")[-1]
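
The rename from `device_id` to `device` is more than cosmetic: PyTorch's `.to()` accepts an integer CUDA index, a device string, or a `torch.device` object, so callers can now pass "cpu" as well as a GPU index. A minimal standalone sketch of the equivalent forms (illustrative, not from this repo):

import torch

x = torch.zeros(2, 3)

x_cpu = x.to("cpu")                 # device string
x_obj = x.to(torch.device("cpu"))   # torch.device object
if torch.cuda.is_available():
    x_gpu = x.to(0)                 # a bare int is treated as cuda:0

The default `device=0` keeps the old single-GPU behavior for existing callers, while app.py can now pass whatever device it selected at startup.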