anhnv125 commited on
Commit
6bd0ee9
·
1 Parent(s): bdb2571

update code

Browse files
.idea/.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
.idea/FRN.iml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="inheritedJdk" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ <component name="PyDocumentationSettings">
9
+ <option name="format" value="PLAIN" />
10
+ <option name="myDocStringFormat" value="Plain" />
11
+ </component>
12
+ </module>
.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <profile version="1.0">
3
+ <option name="myName" value="Project Default" />
4
+ <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
5
+ <option name="ignoredPackages">
6
+ <value>
7
+ <list size="10">
8
+ <item index="0" class="java.lang.String" itemvalue="tqdm" />
9
+ <item index="1" class="java.lang.String" itemvalue="scipy" />
10
+ <item index="2" class="java.lang.String" itemvalue="torchmetrics" />
11
+ <item index="3" class="java.lang.String" itemvalue="tensorboard" />
12
+ <item index="4" class="java.lang.String" itemvalue="scikit_learn" />
13
+ <item index="5" class="java.lang.String" itemvalue="matplotlib" />
14
+ <item index="6" class="java.lang.String" itemvalue="torch" />
15
+ <item index="7" class="java.lang.String" itemvalue="numpy" />
16
+ <item index="8" class="java.lang.String" itemvalue="einops" />
17
+ <item index="9" class="java.lang.String" itemvalue="pandas" />
18
+ </list>
19
+ </value>
20
+ </option>
21
+ </inspection_tool>
22
+ </profile>
23
+ </component>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/FRN.iml" filepath="$PROJECT_DIR$/.idea/FRN.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="VcsDirectoryMappings">
4
+ <mapping directory="" vcs="Git" />
5
+ </component>
6
+ </project>
main.py CHANGED
@@ -104,21 +104,21 @@ if __name__ == '__main__':
104
  model.freeze()
105
  if args.mode == 'eval':
106
  model.cuda(device=0)
107
- trainer = pl.Trainer(gpus=1)
108
  testset = TestLoader()
109
  test_loader = DataLoader(testset, batch_size=1, num_workers=4)
110
  trainer.test(model, test_loader)
111
  print('Version', args.version)
112
  masking = CONFIG.DATA.EVAL.masking
113
  prob = CONFIG.DATA.EVAL.transition_probs[0]
114
- loss_percent = (1 - prob[0]) / (2 - prob[0][prob[1]]) * 100
115
  print('Evaluate with real trace' if masking == 'real' else
116
- 'Evaluate with generated trace {}% packet loss'.format(str(prob)))
117
  elif args.mode == 'test':
118
  model.cuda(device=0)
119
  testset = BlindTestLoader(test_dir=CONFIG.TEST.in_dir)
120
  test_loader = DataLoader(testset, batch_size=1, num_workers=4)
121
- trainer = pl.Trainer(gpus=1)
122
  preds = trainer.predict(model, test_loader, return_predictions=True)
123
  mkdir_p(CONFIG.TEST.out_dir)
124
  for idx, path in enumerate(test_loader.dataset.data_list):
 
104
  model.freeze()
105
  if args.mode == 'eval':
106
  model.cuda(device=0)
107
+ trainer = pl.Trainer(accelerator='gpu', devices=1, enable_checkpointing=False, logger=False)
108
  testset = TestLoader()
109
  test_loader = DataLoader(testset, batch_size=1, num_workers=4)
110
  trainer.test(model, test_loader)
111
  print('Version', args.version)
112
  masking = CONFIG.DATA.EVAL.masking
113
  prob = CONFIG.DATA.EVAL.transition_probs[0]
114
+ loss_percent = (1 - prob[0]) / (2 - prob[0] - prob[1]) * 100
115
  print('Evaluate with real trace' if masking == 'real' else
116
+ 'Evaluate with generated trace with {:.2f}% packet loss'.format(prob))
117
  elif args.mode == 'test':
118
  model.cuda(device=0)
119
  testset = BlindTestLoader(test_dir=CONFIG.TEST.in_dir)
120
  test_loader = DataLoader(testset, batch_size=1, num_workers=4)
121
+ trainer = pl.Trainer(accelerator='gpu', devices=1, enable_checkpointing=False, logger=False)
122
  preds = trainer.predict(model, test_loader, return_predictions=True)
123
  mkdir_p(CONFIG.TEST.out_dir)
124
  for idx, path in enumerate(test_loader.dataset.data_list):
models/frn.py CHANGED
@@ -160,8 +160,8 @@ class PLCModel(pl.LightningModule):
160
  sf.write(os.path.join(path, 'lossy_input.wav'), inp_wav, samplerate=CONFIG.DATA.sr, subtype='PCM_16')
161
  sf.write(os.path.join(path, 'target.wav'), tar_wav, samplerate=CONFIG.DATA.sr, subtype='PCM_16')
162
  if CONFIG.DATA.sr != 16000:
163
- pred = librosa.resample(pred, 48000, 16000)
164
- tar_wav = librosa.resample(tar_wav, 48000, 16000, res_type='kaiser_fast')
165
  ret = plcmos.run(pred, tar_wav)
166
  pesq = self.pesq(torch.tensor(pred), torch.tensor(tar_wav))
167
  metrics = {
 
160
  sf.write(os.path.join(path, 'lossy_input.wav'), inp_wav, samplerate=CONFIG.DATA.sr, subtype='PCM_16')
161
  sf.write(os.path.join(path, 'target.wav'), tar_wav, samplerate=CONFIG.DATA.sr, subtype='PCM_16')
162
  if CONFIG.DATA.sr != 16000:
163
+ pred = librosa.resample(pred, orig_sr=48000, target_sr=16000)
164
+ tar_wav = librosa.resample(tar_wav, orig_sr=48000, target_sr=16000, res_type='kaiser_fast')
165
  ret = plcmos.run(pred, tar_wav)
166
  pesq = self.pesq(torch.tensor(pred), torch.tensor(tar_wav))
167
  metrics = {
utils/utils.py CHANGED
@@ -52,7 +52,7 @@ def visualize(target, input, recon, path):
52
 
53
 
54
  def get_power(x, nfft):
55
- S = librosa.stft(x, nfft)
56
  S = np.log(np.abs(S) ** 2 + 1e-8)
57
  return S
58
 
 
52
 
53
 
54
  def get_power(x, nfft):
55
+ S = librosa.stft(x, n_fft=nfft)
56
  S = np.log(np.abs(S) ** 2 + 1e-8)
57
  return S
58