ZhifengKong committed
Commit 0195d32
Parent(s): 64fc4c7

update

- app.py +3 -1
- inference_utils.py +6 -6
app.py
CHANGED

@@ -132,7 +132,8 @@ laionclap_model = load_laionclap()
 model = prepare_model(
     model_config=model_config,
     clap_config=clap_config,
-    checkpoint_path='chat.pt'
+    checkpoint_path='chat.pt',
+    device=device
 )


@@ -147,6 +148,7 @@ def inference_item(name, prompt):
     outputs = inference(
         model, text_tokenizer, item, processed_item,
         inference_kwargs,
+        device=device
     )

     laionclap_scores = compute_laionclap_text_audio_sim(
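Note: the new device argument passed above assumes a device variable defined earlier in app.py; its definition is outside the hunks shown. A minimal sketch of the usual way such a variable is set (an assumption, not part of this commit):

import torch

# Hypothetical definition of `device`: use the GPU when available,
# otherwise fall back to CPU. The actual definition in app.py is
# not visible in this diff.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")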
inference_utils.py
CHANGED

@@ -33,7 +33,7 @@ def prepare_tokenizer(model_config):
     return text_tokenizer


-def prepare_model(model_config, clap_config, checkpoint_path, device_id=0):
+def prepare_model(model_config, clap_config, checkpoint_path, device=0):
     os.environ["TOKENIZERS_PARALLELISM"] = "false"  # disable the tokenizer parallelism warning
     model, tokenizer = create_model_and_transforms(
         **model_config,

@@ -43,7 +43,7 @@ def prepare_model(model_config, clap_config, checkpoint_path, device_id=0):
         freeze_lm_embeddings=False,
     )
     model.eval()
-    model = model.to(device_id)
+    model = model.to(device)

     checkpoint = torch.load(checkpoint_path, map_location="cpu")
     model_state_dict = checkpoint["model_state_dict"]

@@ -53,11 +53,11 @@ def prepare_model(model_config, clap_config, checkpoint_path, device_id=0):
     return model


-def inference(model, tokenizer, item, processed_item, inference_kwargs, device_id=0):
+def inference(model, tokenizer, item, processed_item, inference_kwargs, device=0):
     filename, audio_clips, audio_embed_mask, input_ids, attention_mask = processed_item
-    audio_clips = audio_clips.to(device_id, dtype=None, non_blocking=True)
-    audio_embed_mask = audio_embed_mask.to(device_id, dtype=None, non_blocking=True)
-    input_ids = input_ids.to(device_id, dtype=None, non_blocking=True).squeeze()
+    audio_clips = audio_clips.to(device, dtype=None, non_blocking=True)
+    audio_embed_mask = audio_embed_mask.to(device, dtype=None, non_blocking=True)
+    input_ids = input_ids.to(device, dtype=None, non_blocking=True).squeeze()

     media_token_id = tokenizer.encode("<audio>")[-1]
     eoc_token_id = tokenizer.encode("<|endofchunk|>")[-1]
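For context on the .to(...) calls above: Tensor.to(device, dtype=None, non_blocking=True) only moves the tensor across devices; dtype=None keeps the existing dtype, and non_blocking=True makes the copy asynchronous only for pinned-memory CPU-to-CUDA transfers. A small self-contained check (a standalone sketch, not code from this repository):

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# dtype=None leaves the dtype untouched; only the device changes.
x = torch.zeros(4, dtype=torch.float16)
y = x.to(device, dtype=None, non_blocking=True)
assert y.dtype == torch.float16
assert y.device.type == device.type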