ZhifengKong committed
Commit 64fc4c7 · 1 Parent(s): 715f77e

auto select device

Files changed (2):
  1. app.py +6 -2
  2. src/factory.py +7 -3
app.py CHANGED
@@ -17,9 +17,13 @@ import laion_clap
 from inference_utils import prepare_tokenizer, prepare_model, inference
 from data import AudioTextDataProcessor
 
+if torch.cuda.is_available():
+    device = 'cuda:0'
+else:
+    device = 'cpu'
 
 def load_laionclap():
-    model = laion_clap.CLAP_Module(enable_fusion=True, amodel='HTSAT-tiny').cuda()
+    model = laion_clap.CLAP_Module(enable_fusion=True, amodel='HTSAT-tiny').to(device)
     model.load_ckpt(ckpt='630k-audioset-fusion-best.pt')
     model.eval()
     return model
@@ -94,7 +98,7 @@ def compute_laionclap_text_audio_sim(audio_file, laionclap_model, outputs):
         return [0.0] * len(outputs)
 
     audio_data = data.reshape(1, -1)
-    audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float().cuda()
+    audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float().to(device)
     audio_embed = laionclap_model.get_audio_embedding_from_data(x=audio_data_tensor, use_tensor=True)
 
     text_embed = laionclap_model.get_text_embedding(outputs, use_tensor=True)
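
The app.py hunks swap the CUDA-only .cuda() calls for a module-level device string, so the demo can also start on CPU-only hosts. A minimal standalone sketch of the same pattern (the tensor and variable names below are illustrative, not from the repository):

import torch

# Prefer the first GPU when CUDA is available; otherwise fall back to CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# .to(device) is valid on both backends, whereas .cuda() raises a
# RuntimeError on machines without a usable CUDA device.
x = torch.zeros(1, 4).to(device)
print(x.device)  # prints cuda:0 on a GPU host, cpu otherwise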
src/factory.py CHANGED
@@ -30,7 +30,11 @@ class CLAP(nn.Module):
     def __init__(self, clap_config):
         super(CLAP, self).__init__()
         self.method = clap_config["method"]
-        device_id = f'cuda:{torch.cuda.current_device()}'
+
+        if torch.cuda.is_available():
+            device = 'cuda:0'
+        else:
+            device = 'cpu'
 
         if self.method == 'laion-clap':
             # https://github.com/LAION-AI/CLAP
@@ -42,7 +46,7 @@ class CLAP(nn.Module):
                 raise NotImplementedError
 
             enable_fusion = 'fusion' in clap_config["model_name"].lower()
-            self.laion_clap = CLAP_Module(enable_fusion=enable_fusion, amodel=amodel, device=device_id)
+            self.laion_clap = CLAP_Module(enable_fusion=enable_fusion, amodel=amodel, device=device)
             self.laion_clap.load_ckpt(ckpt=clap_config["checkpoint"])
 
             for param in self.laion_clap.parameters():
@@ -57,7 +61,7 @@ class CLAP(nn.Module):
                 clap_config["checkpoint"],
                 config_root=clap_config["config_root"],
                 version=clap_config['model_name'],
-                use_cuda=True
+                use_cuda=torch.cuda.is_available()
             )
 
         if clap_config['model_name'] in ['2022', '2023']:
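
After this commit, app.py and src/factory.py each inline the same four-line check. A hedged sketch of factoring it into a shared helper (get_device is a hypothetical name, not part of this commit):

import torch

def get_device() -> str:
    # Hypothetical helper mirroring the selection logic this commit
    # duplicates in app.py and src/factory.py.
    return 'cuda:0' if torch.cuda.is_available() else 'cpu'

Note that the two branches consume the selection differently: CLAP_Module accepts a device argument at construction time, while the other CLAP path only exposes a boolean flag, which is why the commit passes use_cuda=torch.cuda.is_available() there instead of a device string.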