Spaces:
Build error
Build error
ZhifengKong
commited on
Commit
·
64fc4c7
1
Parent(s):
715f77e
auto select device
Browse files- app.py +6 -2
- src/factory.py +7 -3
app.py
CHANGED
@@ -17,9 +17,13 @@ import laion_clap
|
|
17 |
from inference_utils import prepare_tokenizer, prepare_model, inference
|
18 |
from data import AudioTextDataProcessor
|
19 |
|
|
|
|
|
|
|
|
|
20 |
|
21 |
def load_laionclap():
|
22 |
-
model = laion_clap.CLAP_Module(enable_fusion=True, amodel='HTSAT-tiny').
|
23 |
model.load_ckpt(ckpt='630k-audioset-fusion-best.pt')
|
24 |
model.eval()
|
25 |
return model
|
@@ -94,7 +98,7 @@ def compute_laionclap_text_audio_sim(audio_file, laionclap_model, outputs):
|
|
94 |
return [0.0] * len(outputs)
|
95 |
|
96 |
audio_data = data.reshape(1, -1)
|
97 |
-
audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float().
|
98 |
audio_embed = laionclap_model.get_audio_embedding_from_data(x=audio_data_tensor, use_tensor=True)
|
99 |
|
100 |
text_embed = laionclap_model.get_text_embedding(outputs, use_tensor=True)
|
|
|
17 |
from inference_utils import prepare_tokenizer, prepare_model, inference
|
18 |
from data import AudioTextDataProcessor
|
19 |
|
20 |
+
if torch.cuda.is_available():
|
21 |
+
device = 'cuda:0'
|
22 |
+
else:
|
23 |
+
device = 'cpu'
|
24 |
|
25 |
def load_laionclap():
|
26 |
+
model = laion_clap.CLAP_Module(enable_fusion=True, amodel='HTSAT-tiny').to(device)
|
27 |
model.load_ckpt(ckpt='630k-audioset-fusion-best.pt')
|
28 |
model.eval()
|
29 |
return model
|
|
|
98 |
return [0.0] * len(outputs)
|
99 |
|
100 |
audio_data = data.reshape(1, -1)
|
101 |
+
audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float().to(device)
|
102 |
audio_embed = laionclap_model.get_audio_embedding_from_data(x=audio_data_tensor, use_tensor=True)
|
103 |
|
104 |
text_embed = laionclap_model.get_text_embedding(outputs, use_tensor=True)
|
src/factory.py
CHANGED
@@ -30,7 +30,11 @@ class CLAP(nn.Module):
|
|
30 |
def __init__(self, clap_config):
|
31 |
super(CLAP, self).__init__()
|
32 |
self.method = clap_config["method"]
|
33 |
-
|
|
|
|
|
|
|
|
|
34 |
|
35 |
if self.method == 'laion-clap':
|
36 |
# https://github.com/LAION-AI/CLAP
|
@@ -42,7 +46,7 @@ class CLAP(nn.Module):
|
|
42 |
raise NotImplementedError
|
43 |
|
44 |
enable_fusion = 'fusion' in clap_config["model_name"].lower()
|
45 |
-
self.laion_clap = CLAP_Module(enable_fusion=enable_fusion, amodel=amodel, device=
|
46 |
self.laion_clap.load_ckpt(ckpt=clap_config["checkpoint"])
|
47 |
|
48 |
for param in self.laion_clap.parameters():
|
@@ -57,7 +61,7 @@ class CLAP(nn.Module):
|
|
57 |
clap_config["checkpoint"],
|
58 |
config_root=clap_config["config_root"],
|
59 |
version=clap_config['model_name'],
|
60 |
-
use_cuda=
|
61 |
)
|
62 |
|
63 |
if clap_config['model_name'] in ['2022', '2023']:
|
|
|
30 |
def __init__(self, clap_config):
|
31 |
super(CLAP, self).__init__()
|
32 |
self.method = clap_config["method"]
|
33 |
+
|
34 |
+
if torch.cuda.is_available():
|
35 |
+
device = 'cuda:0'
|
36 |
+
else:
|
37 |
+
device = 'cpu'
|
38 |
|
39 |
if self.method == 'laion-clap':
|
40 |
# https://github.com/LAION-AI/CLAP
|
|
|
46 |
raise NotImplementedError
|
47 |
|
48 |
enable_fusion = 'fusion' in clap_config["model_name"].lower()
|
49 |
+
self.laion_clap = CLAP_Module(enable_fusion=enable_fusion, amodel=amodel, device=device)
|
50 |
self.laion_clap.load_ckpt(ckpt=clap_config["checkpoint"])
|
51 |
|
52 |
for param in self.laion_clap.parameters():
|
|
|
61 |
clap_config["checkpoint"],
|
62 |
config_root=clap_config["config_root"],
|
63 |
version=clap_config['model_name'],
|
64 |
+
use_cuda=torch.cuda.is_available()
|
65 |
)
|
66 |
|
67 |
if clap_config['model_name'] in ['2022', '2023']:
|