hoang1007 committed on
Commit
5381499
1 Parent(s): 74c8f6d
.gitignore ADDED
@@ -0,0 +1 @@
+ __pycache__
app.py ADDED
@@ -0,0 +1,44 @@
+ import sys
+
+ sys.path.append("..")
+
+ import gradio
+ import torch, torchaudio
+ import numpy as np
+ from transformers import (
+     Wav2Vec2ForPreTraining,
+     Wav2Vec2CTCTokenizer,
+     Wav2Vec2FeatureExtractor,
+ )
+ from finetuning.wav2vec2 import SpeechRecognizer
+
+
+ def load_model(ckpt_path: str):
+     model_name = "nguyenvulebinh/wav2vec2-base-vietnamese-250h"
+
+     wav2vec2 = Wav2Vec2ForPreTraining.from_pretrained(model_name)
+     tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(model_name)
+     feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
+
+     model = SpeechRecognizer.load_from_checkpoint(
+         ckpt_path,
+         wav2vec2=wav2vec2,
+         tokenizer=tokenizer,
+         feature_extractor=feature_extractor,
+     )
+
+     return model
+
+ model = load_model("checkpoints/last.ckpt")
+ model.eval()
+
+ def transcribe(audio):
+     sample_rate, waveform = audio
+     waveform = torch.from_numpy(waveform[:, 0]).float().unsqueeze_(0)
+     waveform = torchaudio.functional.resample(waveform, sample_rate, 16_000)
+
+     transcript = model.predict(waveform)[0]
+
+     return transcript
+
+ gradio.Interface(fn=transcribe, inputs=gradio.Audio(source="microphone", type="numpy"), outputs="textbox").launch()
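For reference, a minimal sketch of the preprocessing that transcribe() applies, using a synthetic stereo clip in place of the (sample_rate, ndarray) tuple that gradio.Audio(type="numpy") passes in; the random waveform and its shape are purely illustrative:

import numpy as np
import torch
import torchaudio

# Synthetic 1-second stereo clip at 44.1 kHz standing in for a microphone recording.
sample_rate = 44_100
recording = (np.random.randn(sample_rate, 2) * 0.1).astype(np.float32)

# Same steps as transcribe(): keep the first channel, add a batch dim, resample to 16 kHz.
waveform = torch.from_numpy(recording[:, 0]).float().unsqueeze(0)
waveform = torchaudio.functional.resample(waveform, sample_rate, 16_000)
print(waveform.shape)  # torch.Size([1, 16000])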
checkpoints/.gitkeep ADDED
File without changes
finetuning/preprocess.py ADDED
@@ -0,0 +1,27 @@
+ import sys
+
+ sys.path.append("..")
+
+ import os
+ import argparse
+ from torch.utils.data import random_split
+ from src.datamodule import VLSP2020TarDataset, VLSP2020Dataset
+
+
+ def prepare_tar_dataset(data_dir: str, dest_dir: str):
+     dts = VLSP2020Dataset(data_dir)
+     train_set, val_set = random_split(dts, [42_000, 14_427])
+
+     VLSP2020TarDataset(os.path.join(dest_dir, "vlsp2020_train_set.tar")).convert(
+         train_set
+     )
+     VLSP2020TarDataset(os.path.join(dest_dir, "vlsp2020_val_set.tar")).convert(val_set)
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--data_dir", type=str, required=True)
+     parser.add_argument("--dest_dir", type=str, required=True)
+     args = parser.parse_args()
+
+     prepare_tar_dataset(args.data_dir, args.dest_dir)
finetuning/run.sh ADDED
@@ -0,0 +1,13 @@
+ python3 main.py \
+     --batch_size 2 \
+     --num_workers 2 \
+     --classifier_lr 1e-4 \
+     --wav2vec2_lr 1e-5 \
+     --max_epochs 10 \
+     --accelerator cpu \
+     --weight_decay 0.001 \
+     --warmup_steps 0.1 \
+     --constant_steps 0.4 \
+     --scheduler_factor 0.001 \
+     --data_dir data \
+     --ckpt_dir ckpt
finetuning/train.py ADDED
@@ -0,0 +1,135 @@
+ import sys
+
+ sys.path.append("..")
+
+ from argparse import ArgumentParser
+ import os, string
+ from transformers import (
+     Wav2Vec2ForPreTraining,
+     Wav2Vec2CTCTokenizer,
+     Wav2Vec2FeatureExtractor,
+ )
+ from pytorch_lightning import seed_everything
+ from pytorch_lightning import Trainer
+ from pytorch_lightning.callbacks import ModelCheckpoint, LearningRateMonitor
+ from pytorch_lightning.loggers import WandbLogger
+
+ from src.datamodule import VLSP2020TarDataset
+ from src.datamodule.vlsp2020 import get_dataloader
+ from finetuning.wav2vec2 import SpeechRecognizer
+
+
+ def remove_punctuation(text: str):
+     return text.translate(str.maketrans("", "", string.punctuation)).lower()
+
+
+ def prepare_dataloader(data_dir, batch_size, num_workers):
+     train_dataset = VLSP2020TarDataset(
+         os.path.join(data_dir, "vlsp2020_train_set.tar")
+     ).load()
+     val_dataset = VLSP2020TarDataset(
+         os.path.join(data_dir, "vlsp2020_val_set.tar")
+     ).load()
+
+     train_dataloader = get_dataloader(
+         train_dataset,
+         return_transcript=True,
+         target_transform=remove_punctuation,
+         batch_size=batch_size,
+         num_workers=num_workers,
+     )
+
+     val_dataloader = get_dataloader(
+         val_dataset,
+         return_transcript=True,
+         target_transform=remove_punctuation,
+         batch_size=batch_size,
+         num_workers=num_workers,
+     )
+
+     return train_dataloader, val_dataloader
+
+
+ def prepare_model(adam_config: dict, tristate_scheduler_config: dict):
+     model_name = "nguyenvulebinh/wav2vec2-base-vietnamese-250h"
+
+     wav2vec2 = Wav2Vec2ForPreTraining.from_pretrained(model_name)
+     tokenizer = Wav2Vec2CTCTokenizer.from_pretrained(model_name)
+     feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name)
+
+     model = SpeechRecognizer(
+         wav2vec2, tokenizer, feature_extractor, adam_config, tristate_scheduler_config
+     )
+
+     return model
+
+
+ def main():
+     parser = ArgumentParser()
+
+     parser.add_argument("--batch_size", type=int, default=2)
+     parser.add_argument("--num_workers", type=int, default=0)
+     parser.add_argument("--classifier_lr", type=float, default=1e-4)
+     parser.add_argument("--wav2vec2_lr", type=float, default=1e-5)
+     parser.add_argument("--max_epochs", type=int, default=10)
+     parser.add_argument("--accelerator", type=str, default="gpu")
+     parser.add_argument("--weight_decay", type=float, default=0.0)
+     parser.add_argument("--warmup_steps", type=float, default=0.1)
+     parser.add_argument("--constant_steps", type=float, default=0.4)
+     parser.add_argument("--scheduler_factor", type=float, default=1e-3)
+     parser.add_argument("--data_dir", type=str, default="data")
+     parser.add_argument("--ckpt_dir", type=str, default="ckpt")
+     parser.add_argument("--ckpt_path", type=str, default=None)
+     parser.add_argument("--detect_anomaly", type=bool, default=False)
+     parser.add_argument("--grad_clip", type=float, default=None)
+     parser.add_argument("--wandb_id", type=str, default=None)
+
+     args = parser.parse_args()
+     print(args)
+
+     train_loader, val_loader = prepare_dataloader(
+         args.data_dir, args.batch_size, args.num_workers
+     )
+
+     total_steps = args.max_epochs * 42_000 // args.batch_size
+     warmup_steps = int(total_steps * args.warmup_steps)
+     constant_steps = int(total_steps * args.constant_steps)
+
+     model = prepare_model(
+         {
+             "wav2vec2_lr": args.wav2vec2_lr,
+             "classifier_lr": args.classifier_lr,
+             "weight_decay": args.weight_decay,
+         },
+         {
+             "warmup_steps": warmup_steps,
+             "constant_steps": constant_steps,
+             "total_steps": total_steps,
+             "factor": args.scheduler_factor,
+         },
+     )
+
+     trainer = Trainer(
+         accelerator=args.accelerator,
+         callbacks=[
+             ModelCheckpoint(
+                 args.ckpt_dir,
+                 monitor="val/wer",
+                 mode="min",
+                 save_top_k=1,
+                 save_last=True,
+             ),
+             LearningRateMonitor(logging_interval="step"),
+         ],
+         logger=WandbLogger(project="Wav2Vec2", id=args.wandb_id),
+         max_epochs=args.max_epochs,
+         detect_anomaly=args.detect_anomaly,
+         gradient_clip_val=args.grad_clip,
+     )
+
+     trainer.fit(model, train_loader, val_loader)
+
+
+ if __name__ == "__main__":
+     seed_everything(188)
+     main()
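A worked example of the scheduler bookkeeping main() performs for the defaults in run.sh (batch_size=2, max_epochs=10, warmup_steps=0.1, constant_steps=0.4), using the 42,000 training utterances hard-coded above:

max_epochs, batch_size = 10, 2
total_steps = max_epochs * 42_000 // batch_size  # 210_000 optimizer steps
warmup_steps = int(total_steps * 0.1)            # 21_000 steps of cosine warmup
constant_steps = int(total_steps * 0.4)          # 84_000 steps at the peak learning rate
print(total_steps, warmup_steps, constant_steps)  # 210000 21000 84000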
finetuning/wav2vec2.py ADDED
@@ -0,0 +1,200 @@
+ from typing import Tuple
+ import torch
+ from pytorch_lightning import LightningModule
+ from torchmetrics import MeanMetric
+ from transformers import (
+     Wav2Vec2ForPreTraining,
+     Wav2Vec2CTCTokenizer,
+     Wav2Vec2FeatureExtractor,
+ )
+
+ from src.utils.metrics import character_error_rate, word_error_rate
+ from src.utils.scheduler import TriStateScheduler
+
+
+ class SpeechRecognizer(LightningModule):
+     def __init__(
+         self,
+         wav2vec2: Wav2Vec2ForPreTraining,
+         tokenizer: Wav2Vec2CTCTokenizer,
+         feature_extractor: Wav2Vec2FeatureExtractor,
+         adam_config: dict,
+         tristate_scheduler_config: dict,
+     ):
+         super().__init__()
+
+         self.hidden_size = wav2vec2.config.proj_codevector_dim
+         self.vocab_size = tokenizer.vocab_size
+
+         self.wav2vec2 = wav2vec2
+         self.wav2vec2.freeze_feature_encoder()
+         self.tokenizer = tokenizer
+         self.feature_extractor = feature_extractor
+
+         self.adam_config = adam_config
+         self.tristate_scheduler_config = tristate_scheduler_config
+
+         self.dropout = torch.nn.Dropout(0.1)
+         self.fc = torch.nn.Sequential(
+             torch.nn.Linear(self.hidden_size, self.hidden_size // 2),
+             torch.nn.ReLU(inplace=True),
+             torch.nn.Linear(self.hidden_size // 2, self.vocab_size),
+         )
+
+         self.criterion = torch.nn.CTCLoss(blank=tokenizer.pad_token_id, zero_infinity=True)
+
+         self.train_loss = MeanMetric()
+
+         self.save_hyperparameters(ignore=["wav2vec2", "tokenizer", "feature_extractor"])
+
+     def forward(self, waveforms: Tuple[torch.Tensor], transcripts: Tuple[str] = None):
+         # convert torch.Tensor to numpy.ndarray
+         waveforms = tuple(waveform.cpu().numpy() for waveform in waveforms)
+
+         input_values, attention_mask = self.feature_extractor(
+             waveforms,
+             sampling_rate=16000,
+             padding=True,
+             return_tensors="pt",
+             return_attention_mask=True,
+         ).values()
+
+         input_values = input_values.to(self.device)
+         attention_mask = attention_mask.to(self.device)
+
+         # hidden_states.shape == (batch_size, sequence_length, hidden_size)
+         hidden_states = self.wav2vec2(
+             input_values,
+             attention_mask=attention_mask,
+         )[0]
+
+         hidden_states = self.dropout(hidden_states)
+
+         # logits.shape == (batch_size, sequence_length, vocab_size)
+         logits = self.fc(hidden_states)
+
+         # get the length of the valid sequences
+         input_lengths = self.wav2vec2._get_feat_extract_output_lengths(
+             attention_mask.sum(-1)
+         ).long()
+
+         if transcripts is not None:
+             # tokenize transcripts
+             target_ids, target_lengths = self.tokenizer(
+                 transcripts,
+                 padding=True,
+                 return_length=True,
+                 return_attention_mask=False,
+                 return_tensors="pt",
+             ).values()
+
+             target_ids = target_ids.to(self.device)
+             assert (
+                 target_ids < self.tokenizer.vocab_size
+             ).all(), "target_ids is out of range"
+
+             target_lengths = target_lengths.to(self.device)
+             assert (
+                 target_lengths <= logits.size(1)
+             ).all(), "target_lengths is out of range"
+
+             # (batch_size, sequence_length, vocab_size) -> (sequence_length, batch_size, vocab_size)
+             log_probs = torch.nn.functional.log_softmax(logits, dim=-1).transpose(0, 1)
+
+             # compute loss
+             loss = self.criterion(log_probs, target_ids, input_lengths, target_lengths)
+
+             return loss, logits, input_lengths
+         else:
+             return logits, input_lengths
+
+     @staticmethod
+     def _get_predicted_ids(logits: torch.Tensor, lengths: torch.Tensor):
+         # logits.shape == (batch_size, sequence_length, vocab_size)
+         # lengths.shape == (batch_size, )
+
+         # get the max value of logits
+         predicted_ids = torch.argmax(logits, dim=-1)
+
+         # remove the padding
+         predicted_ids = [
+             predicted_id[:length]
+             for predicted_id, length in zip(predicted_ids, lengths)
+         ]
+
+         return predicted_ids
+
+     def training_step(self, batch, batch_idx):
+         transcripts, waveforms = batch
+
+         loss = self(waveforms, transcripts)[0]
+
+         self.train_loss(loss)
+
+         if self.global_step % 500 == 0:
+             self.log("train/loss", self.train_loss, on_step=True, on_epoch=True)
+
+         return loss
+
+     def on_train_epoch_end(self) -> None:
+         self.train_loss.reset()
+
+     def validation_step(self, batch, batch_idx):
+         transcripts, waveforms = batch
+
+         logits, seq_lengths = self(waveforms)
+
+         predicted_ids = self._get_predicted_ids(logits, seq_lengths)
+         predicted_texts = self.tokenizer.batch_decode(
+             predicted_ids, skip_special_tokens=True
+         )
+
+         wer = word_error_rate(predicted_texts, transcripts)
+         cer = character_error_rate(predicted_texts, transcripts)
+
+         return wer, cer
+
+     def validation_epoch_end(self, outputs):
+         wer, cer = zip(*outputs)
+
+         wer = sum(wer) / len(wer)
+         cer = sum(cer) / len(cer)
+
+         self.log("val/wer", wer, on_epoch=True)
+         self.log("val/cer", cer, on_epoch=True)
+
+     @torch.no_grad()
+     def predict(self, waveforms: Tuple[torch.Tensor]):
+         logits, seq_lengths = self(waveforms)
+
+         predicted_ids = self._get_predicted_ids(logits, seq_lengths)
+         predicted_texts = self.tokenizer.batch_decode(
+             predicted_ids, skip_special_tokens=True
+         )
+
+         return predicted_texts
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.AdamW(
+             params=[
+                 {
+                     "params": self.wav2vec2.parameters(),
+                     "lr": self.adam_config["wav2vec2_lr"],
+                 },
+                 {
+                     "params": self.fc.parameters(),
+                     "lr": self.adam_config["classifier_lr"],
+                 },
+             ],
+             weight_decay=self.adam_config["weight_decay"],
+         )
+
+         scheduler = TriStateScheduler(optimizer, **self.tristate_scheduler_config)
+         return {
+             "optimizer": optimizer,
+             "lr_scheduler": {
+                 "scheduler": scheduler,
+                 "interval": "step",
+                 "frequency": 1,
+             },
+         }
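A small, self-contained sketch of the greedy decoding step in _get_predicted_ids; the logits and lengths are toy values, and in the real model the resulting ids go to tokenizer.batch_decode:

import torch

# Toy batch: 2 sequences, 5 frames, vocabulary of 4 tokens.
logits = torch.randn(2, 5, 4)
lengths = torch.tensor([5, 3])

# Argmax per frame, then trim each sequence to its valid length, as in _get_predicted_ids.
predicted_ids = torch.argmax(logits, dim=-1)
predicted_ids = [ids[:length] for ids, length in zip(predicted_ids, lengths)]
print([ids.tolist() for ids in predicted_ids])  # e.g. [[2, 0, 3, 1, 1], [0, 2, 2]]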
packages.txt ADDED
@@ -0,0 +1 @@
+ ffmpeg
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ torch
+ torchaudio
+ pytorch-lightning
+ einops
+ easydict
+ webdataset
+ transformers
+ gradio
+ altair
src/__init__.py ADDED
File without changes
src/config/__init__.py ADDED
File without changes
src/config/model.py ADDED
@@ -0,0 +1,57 @@
+ from easydict import EasyDict as dict
+
+ D_MODEL = 768
+ HIDDEN_SIZE = 512
+
+
+
+ context_encoder = dict(
+     feature_projection=dict(
+         in_features=HIDDEN_SIZE,
+         out_features=D_MODEL,
+         dropout=0.1,
+     ),
+     encoder=dict(
+         d_model=D_MODEL,
+         num_layers=12,
+         layer_drop=0.05,
+         pos_embedding=dict(
+             d_model=D_MODEL,
+             kernel_size=3,
+             groups=2,
+             dropout=0.1,
+         ),
+         layer=dict(
+             d_model=D_MODEL,
+             num_heads=8,
+             layer_norm_first=False,
+             feed_forward_dim=2048,
+             dropout=0.1,
+         ),
+     )
+ )
+
+ feature_extractor = dict(
+     num_channels=7 * (HIDDEN_SIZE,),
+     kernel_sizes=(10,) + 4 * (3,) + 2 * (2,),
+     strides=(5,) + 6 * (2,),
+ )
+
+ quantizer = dict(
+     in_features=HIDDEN_SIZE,
+     num_codebooks=2,
+     num_codewords=320,
+     d_model=D_MODEL,
+ )
+
+ wav2vec2_pretraining = dict(
+     context_encoder=context_encoder,
+     feature_extractor=feature_extractor,
+     quantizer=quantizer,
+     mask_prob=0.65,
+     mask_length=10,
+     min_masks=2,
+     num_negatives=100,
+     contrastive_logits_temperature=0.1,
+     diversity_loss_weight=0.2,
+ )
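The tuple arithmetic in feature_extractor expands to the usual seven-layer wav2vec 2.0 convolutional stack; a quick check:

HIDDEN_SIZE = 512
num_channels = 7 * (HIDDEN_SIZE,)           # (512, 512, 512, 512, 512, 512, 512)
kernel_sizes = (10,) + 4 * (3,) + 2 * (2,)  # (10, 3, 3, 3, 3, 2, 2)
strides = (5,) + 6 * (2,)                   # (5, 2, 2, 2, 2, 2, 2)
print(len(num_channels), len(kernel_sizes), len(strides))  # 7 7 7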
src/datamodule/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .vlsp2020 import (
+     VLSP2020TarDataset,
+     VLSP2020Dataset,
+ )
src/datamodule/vlsp2020.py ADDED
@@ -0,0 +1,131 @@
+ from typing import Callable, Optional, Union
+ from tqdm import tqdm
+ import os
+ import torch
+ import torchaudio
+ import torchaudio.functional as F
+ from torch.utils.data import Dataset, DataLoader, IterableDataset, random_split
+ from pytorch_lightning import LightningDataModule
+ import webdataset
+
+
+ class VLSP2020Dataset(Dataset):
+     def __init__(self, root: str, sample_rate: int = 16000):
+         super().__init__()
+
+         self.sample_rate = sample_rate
+         self.memory = self._prepare_data(root)
+         self._memory = tuple(
+             (v["transcript"], v["audio"]) for v in self.memory.values()
+         )
+
+     @staticmethod
+     def _prepare_data(root: str):
+         memory = {}
+
+         for f in os.scandir(root):
+             file_name, file_ext = os.path.splitext(f.name)
+
+             if file_ext == ".txt":
+                 if file_name not in memory:
+                     memory[file_name] = {"transcript": f.path}
+                 elif "transcript" not in memory[file_name]:
+                     memory[file_name]["transcript"] = f.path
+                 else:
+                     raise ValueError(f"Duplicate transcript for {f.path}")
+             else:
+                 if file_name not in memory:
+                     memory[file_name] = {"audio": f.path}
+                 elif "audio" not in memory[file_name]:
+                     memory[file_name]["audio"] = f.path
+                 else:
+                     raise ValueError(f"Duplicate audio for {f.path}")
+
+         for key, value in memory.items():
+             if "audio" not in value:
+                 raise ValueError(f"Missing audio for {key}")
+             elif "transcript" not in value:
+                 raise ValueError(f"Missing transcript for {key}")
+
+         return memory
+
+     def __len__(self):
+         return len(self.memory)
+
+     def __getitem__(self, index: int):
+         transcript, audio = self._memory[index]
+
+         with open(transcript, "r") as f:
+             transcript = f.read()
+
+         audio, sample_rate = torchaudio.load(audio)
+         audio = F.resample(audio, sample_rate, self.sample_rate)
+
+         return transcript, audio
+
+
+ class VLSP2020TarDataset:
+     def __init__(self, outpath: str):
+         self.outpath = outpath
+
+     def convert(self, dataset: VLSP2020Dataset):
+         writer = webdataset.TarWriter(self.outpath)
+
+         for idx, (transcript, audio) in enumerate(tqdm(dataset, colour="green")):
+             writer.write(
+                 {
+                     "__key__": f"{idx:08d}",
+                     "txt": transcript,
+                     "pth": audio,
+                 }
+             )
+
+         writer.close()
+
+     def load(self) -> webdataset.WebDataset:
+         self.data = (
+             webdataset.WebDataset(self.outpath)
+             .decode(
+                 webdataset.handle_extension("txt", lambda x: x.decode("utf-8")),
+                 webdataset.torch_audio,
+             )
+             .to_tuple("txt", "pth")
+         )
+
+         return self.data
+
+
+ def get_dataloader(
+     dataset: Union[VLSP2020Dataset, webdataset.WebDataset],
+     return_transcript: bool = False,
+     target_transform: Optional[Callable] = None,
+     batch_size: int = 32,
+     num_workers: int = 2,
+ ):
+     def collate_fn(batch):
+         def get_audio(item):
+             audio = item[1]
+
+             assert (
+                 isinstance(audio, torch.Tensor)
+                 and audio.ndim == 2
+                 and audio.size(0) == 1
+             )
+
+             return audio.squeeze(0)
+
+         audio = tuple(get_audio(item) for item in batch)
+
+         if return_transcript:
+             if target_transform is not None:
+                 transcript = tuple(target_transform(item[0]) for item in batch)
+             else:
+                 transcript = tuple(item[0] for item in batch)
+
+             return transcript, audio
+         else:
+             return audio
+
+     return DataLoader(
+         dataset, batch_size=batch_size, num_workers=num_workers, collate_fn=collate_fn
+     )
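A minimal usage sketch for get_dataloader, assuming it is run from the repository root; the in-memory list below stands in for VLSP2020Dataset, since the collate_fn only needs (transcript, 1 x T waveform) pairs:

import torch

from src.datamodule.vlsp2020 import get_dataloader

# Two fake utterances of different lengths (mono, shape (1, T) like torchaudio.load returns).
toy_dataset = [
    ("xin chao", torch.randn(1, 16_000)),
    ("viet nam", torch.randn(1, 8_000)),
]

loader = get_dataloader(toy_dataset, return_transcript=True, batch_size=2, num_workers=0)
transcripts, waveforms = next(iter(loader))
print(transcripts)                   # ('xin chao', 'viet nam')
print([w.shape for w in waveforms])  # [torch.Size([16000]), torch.Size([8000])]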
src/model/__init__.py ADDED
@@ -0,0 +1 @@
+ from .wav2vec2 import Wav2Vec2PretrainingModule
src/model/modules/__init__.py ADDED
@@ -0,0 +1,4 @@
+ from .context_encoder import ContextEncoder
+ from .feature_extractor import FeatureExtractor
+ from .quantization import QuantizationModule
+ from .processor import Wav2Vec2Processor
src/model/modules/context_encoder.py ADDED
@@ -0,0 +1,149 @@
+ from typing import Optional
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from .transformers import EncoderLayer
+
+
+ class FeatureProjection(nn.Module):
+     def __init__(self, in_features: int, out_features: int, dropout: float = 0.1):
+         """
+         Projects the extracted features to the encoder dimension.
+
+         Args:
+             x (Tensor): The input features. Shape: (batch, num_frames, in_features)
+
+         Returns:
+             hiddens (Tensor): The latent features. Shape: (batch, num_frames, out_features)
+         """
+         super().__init__()
+
+         self.projection = nn.Linear(in_features, out_features)
+         self.layernorm = nn.LayerNorm(in_features)
+         self.dropout = nn.Dropout(dropout)
+
+     def forward(self, x: torch.Tensor):
+
+         hiddens = self.layernorm(x)
+         hiddens = self.projection(x)
+         hiddens = self.dropout(hiddens)
+         return hiddens
+
+
+ class RelativePositionalEmbedding(nn.Module):
+     def __init__(
+         self, d_model: int, kernel_size: int, groups: int, dropout: float = 0.1
+     ):
+         """
+         Args:
+             x (Tensor): The extracted features. Shape: (batch, num_frames, d_model)
+
+         Returns:
+             out (Tensor): The output which encoded the relative positional information. Shape: (batch, num_frames, d_model)
+         """
+         super().__init__()
+
+         self.conv = nn.Conv1d(
+             in_channels=d_model,
+             out_channels=d_model,
+             kernel_size=kernel_size,
+             padding=kernel_size // 2,
+             groups=groups,
+         )
+         self.dropout = nn.Dropout(dropout)
+         self.num_remove = 1 if kernel_size % 2 == 0 else 0
+
+     def forward(self, x: torch.Tensor):
+         # (batch, channels=d_model, num_frames)
+         out = x.transpose(1, 2)
+
+         out = self.conv(out)
+
+         if self.num_remove > 0:
+             out = out[..., : -self.num_remove]
+
+         out = F.gelu(out)
+
+         # (batch, num_frames, channels=d_model)
+         out = out.transpose_(1, 2)
+         out = out + x
+         out = self.dropout(out)
+
+         return out
+
+
+ class TranformerEncoder(nn.Module):
+     def __init__(self, config):
+         """
+         Args:
+             x (Tensor): The extracted features. Shape: (batch, num_frames, d_model)
+             mask (Tensor): The mask for the valid frames. Shape: (batch, num_frames)
+
+         Returns:
+             out (Tensor): The output of the transformer encoder. Shape: (batch, num_frames, d_model)
+         """
+         super().__init__()
+
+         self.pos_embedding = RelativePositionalEmbedding(**config.pos_embedding)
+         self.layernorm = nn.LayerNorm(config.d_model)
+         self.layer_drop = config.layer_drop
+
+         self.layers = nn.ModuleList(
+             EncoderLayer(**config.layer) for _ in range(config.num_layers)
+         )
+
+     def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
+         out = self.pos_embedding(x)
+
+         for layer in self.layers:
+             skip_layer = self.training and torch.rand(1).item() < self.layer_drop
+
+             if skip_layer:
+                 continue
+             else:
+                 out, _ = layer(out, attention_mask=mask)
+
+         out = self.layernorm(out)
+
+         return out
+
+
+ class ContextEncoder(nn.Module):
+     def __init__(self, config):
+         """
+         Args:
+             x (Tensor): The extracted features. Shape: (batch, num_frames, in_features)
+             attention_mask (BoolTensor): The mask for the valid frames. `True` is invalid. Shape: (batch, num_frames)
+         """
+         super().__init__()
+
+         self.feature_projection = FeatureProjection(**config.feature_projection)
+         self.encoder = TranformerEncoder(config.encoder)
+         self.masked_spec_embed = nn.Parameter(
+             torch.FloatTensor(config.feature_projection.out_features).uniform_()
+         )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         attention_mask: torch.Tensor = None,
+         mask_time_indices: torch.Tensor = None,
+     ):
+         x = self.feature_projection(x)
+
+         if mask_time_indices is not None:
+             x[mask_time_indices] = self.masked_spec_embed.to(x.dtype)
+
+         if attention_mask is not None:
+             x[attention_mask] = 0.0  # turn invalid frames to zero
+
+             attention_mask = attention_mask[:, None, None, :]
+             # (batch, 1, num_frames, num_frames)
+             # mask = mask[:, None, None, :].repeat(1, 1, mask.size(1), 1) # TODO: check this
+             attention_mask = (
+                 torch.maximum(attention_mask, attention_mask.transpose(2, 3)) * -1e6
+             )
+
+         x = self.encoder(x, mask=attention_mask)
+
+         return x
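A small sketch of the additive attention mask built in ContextEncoder.forward; the padding mask is cast to float here for clarity, and the shapes follow the comments in the code:

import torch

# Padding mask for a batch of 2 clips, 4 frames each; True marks an invalid (padded) frame.
pad = torch.tensor([[False, False, True, True],
                    [False, False, False, True]]).float()

# Broadcast to (batch, 1, num_frames, num_frames) and penalize any pair touching a padded frame.
mask = pad[:, None, None, :]
mask = torch.maximum(mask, mask.transpose(2, 3)) * -1e6
print(mask.shape)   # torch.Size([2, 1, 4, 4])
print(mask[0, 0])   # rows/columns of the two padded frames are -1e6, valid pairs stay 0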
src/model/modules/feature_extractor.py ADDED
@@ -0,0 +1,103 @@
+ from typing import Tuple
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+
+
+ class _Conv1DLayer(nn.Module):
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int,
+         stride: int,
+     ):
+         """
+         Args:
+             x (Tensor): The input. Shape: (batch, in_channels, in_frames)
+             length (Tensor): The valid length of each sample. Shape: (batch)
+
+         Returns:
+             x (Tensor): The output. Shape: (batch, out_channels, out_frames)
+             length (Tensor): The valid length of each sample. Shape: (batch)
+         """
+         super().__init__()
+
+         self.kernel_size = kernel_size
+         self.stride = stride
+
+         self.conv = nn.Conv1d(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             stride=stride,
+             kernel_size=kernel_size,
+             bias=False,
+         )
+
+         self.layernorm = nn.LayerNorm(out_channels)
+
+     def forward(self, x: torch.Tensor, length: torch.Tensor):
+         x = self.conv(x)
+         x = x.transpose_(1, 2)
+         x = self.layernorm(x)
+         x = x.transpose_(1, 2)
+         x = F.gelu(x)
+
+         length = (length - self.kernel_size) // self.stride + 1
+         length = length.clamp_min_(min=0)  # prevent negative lengths
+         return x, length
+
+
+ class FeatureExtractor(nn.Module):
+     def __init__(self, config):
+         """
+         Extracts features from the waveform.
+
+         Args:
+             waveforms (Tensor): The waveform to extract features from. Shape: (batch, wavelength)
+             wavelength (Tensor): The valid length of each waveform. Shape: (batch)
+
+         Returns:
+             features (Tensor): The extracted features. Shape: (batch, num_frames, num_channels)
+             num_frames (Tensor): The valid length of each feature. Shape: (batch)
+         """
+         super().__init__()
+
+         num_channels = config.num_channels
+         kernel_sizes = config.kernel_sizes
+         strides = config.strides
+
+         assert (
+             len(num_channels) == len(kernel_sizes) == len(strides)
+         ), "The number of layers must be the same for all parameters"
+
+         self.conv_layers = nn.ModuleList(
+             (
+                 _Conv1DLayer(
+                     in_channels=1,
+                     out_channels=num_channels[0],
+                     kernel_size=kernel_sizes[0],
+                     stride=strides[0],
+                 ),
+             )
+         )
+
+         for i in range(1, len(num_channels)):
+             self.conv_layers.append(
+                 _Conv1DLayer(
+                     in_channels=num_channels[i - 1],
+                     out_channels=num_channels[i],
+                     kernel_size=kernel_sizes[i],
+                     stride=strides[i],
+                 )
+             )
+
+     def forward(self, waveforms: torch.Tensor, wavelength: torch.Tensor):
+         features = waveforms.unsqueeze(1)
+
+         for conv_layer in self.conv_layers:
+             features, wavelength = conv_layer(features, wavelength)
+
+         # (batch, num_channels, num_frames) -> (batch, num_frames, num_channels)
+         features = features.transpose(1, 2)
+         return features, wavelength
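A worked example of the valid-length bookkeeping in _Conv1DLayer.forward, using the kernel sizes and strides from src/config/model.py: one second of 16 kHz audio comes out as 49 frames.

kernel_sizes = (10, 3, 3, 3, 3, 2, 2)
strides = (5, 2, 2, 2, 2, 2, 2)

length = 16_000
for k, s in zip(kernel_sizes, strides):
    length = (length - k) // s + 1
print(length)  # 49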
src/model/modules/processor.py ADDED
@@ -0,0 +1,42 @@
+ from typing import Tuple
+ import torch
+ from torch import nn
+
+
+ class Wav2Vec2Processor(nn.Module):
+     def __init__(self):
+         """
+         Convert a tuple of variable-length waveforms into a single zero-padded batch.
+
+         Args:
+             waveforms (Tuple[torch.Tensor]): The waveforms. Shape: (batch_size, wave_length).
+
+         Returns:
+             waveforms (torch.Tensor): The batched waveforms. Shape: (batch_size, max_wave_length).
+             wave_lengths (torch.Tensor): The wave length of each waveform. Shape: (batch_size,).
+         """
+         super().__init__()
+
+     def forward(self, waveforms: Tuple[torch.Tensor, ...]):
+         device = waveforms[0].device
+         wave_lengths = torch.tensor(
+             tuple(waveform.size(0) for waveform in waveforms), device=device
+         )
+
+         max_length = wave_lengths.max().item()
+
+         padded = []
+
+         for waveform in waveforms:
+             padded.append(
+                 nn.functional.pad(
+                     waveform,
+                     (0, max_length - waveform.size(0)),
+                     mode="constant",
+                     value=0.0,
+                 )
+             )
+
+         batched_waveforms = torch.stack(padded, dim=0)
+
+         return batched_waveforms, wave_lengths
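A quick usage sketch, assuming the repository root is on the path: two clips of different length are zero-padded into one batch.

import torch

from src.model.modules import Wav2Vec2Processor

processor = Wav2Vec2Processor()
waveforms = (torch.randn(16_000), torch.randn(12_000))  # two mono clips
batched, lengths = processor(waveforms)
print(batched.shape)  # torch.Size([2, 16000]) -- padded to the longest clip
print(lengths)        # tensor([16000, 12000])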
src/model/modules/quantization.py ADDED
@@ -0,0 +1,103 @@
+ from typing import Optional
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ import einops
+
+
+ class QuantizationModule(nn.Module):
+     def __init__(
+         self, config
+     ):
+         """
+         Args:
+             x (Tensor): The extracted features from waveforms. Shape: (batch, num_frames, in_features)
+             mask (BoolTensor): The mask for the valid frames. `True` is invalid. Shape: (batch, num_frames)
+
+         Returns:
+             out (Tensor): The quantized features. Shape: (batch, num_frames, d_model)
+             perplexity (Tensor): The perplexity of the quantized features. Shape: (1)
+         """
+         super().__init__()
+
+         assert (
+             config.d_model % config.num_codebooks == 0
+         ), "d_model must be divisible by num_codebooks"
+
+         self.num_codebooks = config.num_codebooks
+         self.num_codewords = config.num_codewords
+         self.d_model = config.d_model
+         self.codeword_dim = config.d_model // config.num_codebooks
+
+         self.codebooks = self._init_codebooks()
+
+         self.projection = nn.Linear(
+             config.in_features, self.num_codebooks * self.num_codewords
+         )
+
+         self.tau = 1  # temperature factor
+
+     def _init_codebooks(self):
+         codebooks = torch.randn(
+             1, 1, self.num_codebooks, self.num_codewords, self.codeword_dim
+         )
+         nn.init.xavier_uniform_(codebooks)
+
+         return nn.Parameter(codebooks)
+
+     @property
+     def total_codewords(self):
+         return self.num_codebooks * self.num_codewords
+
+     @staticmethod
+     def _compute_perplexity(probs: torch.Tensor, mask: Optional[torch.Tensor] = None):
+         """
+         Computes the perplexity of the quantized features. (Diversity loss)
+
+         Args:
+             probs (Tensor): The probability distribution of words in codebooks. Shape: (batch, num_frames, num_codebooks, num_codewords)
+             mask (BoolTensor): The mask for the valid frames. `True` is invalid. Shape: (batch, num_frames)
+         """
+         if mask is not None:
+             probs = (
+                 probs * ~mask[..., None, None]
+             )  # Turn invalid frames' probability to 0
+             marginal_probs = (
+                 einops.reduce(probs, "b nf nb nw -> nb nw", "sum") / mask.sum()
+             )
+         else:
+             marginal_probs = einops.reduce(probs, "b nf nb nw -> nb nw", "mean")
+
+         perplexity = torch.exp(
+             -torch.sum(marginal_probs * torch.log(marginal_probs + 1e-7), dim=-1)
+         ).sum()
+         return perplexity
+
+     def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None):
+         batch_size, num_frames, _ = x.shape
+
+         logits = self.projection(x)
+         logits = logits.view(
+             batch_size, num_frames, self.num_codebooks, self.num_codewords
+         )
+
+         if self.training:
+             word_probs = F.gumbel_softmax(logits, tau=self.tau, hard=True, dim=-1)
+             word_soft_probs = F.softmax(logits, dim=-1)
+
+             perplexity = self._compute_perplexity(word_soft_probs, mask=mask)
+         else:
+             word_ids = torch.argmax(logits, dim=-1, keepdim=True)
+             word_probs = torch.zeros_like(logits).scatter_(-1, word_ids, 1.0)  # One-hot
+
+             perplexity = self._compute_perplexity(word_probs, mask=mask)
+
+         # (batch, num_frames, num_codebooks, num_codewords, 1) x (1, 1, num_codebooks, num_codewords, codeword_dim)
+         # -> (batch, num_frames, num_codebooks x codeword_dim)
+         quantized = einops.reduce(
+             word_probs.unsqueeze_(-1) * self.codebooks,
+             "b nf nb nw d -> b nf (nb d)",
+             reduction="sum",
+         )
+
+         return quantized, perplexity
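A toy version of the evaluation path in QuantizationModule.forward: hard argmax over each codebook, then the same einops reduction that concatenates the chosen codewords; sizes here are made up for illustration.

import torch
import einops

num_codebooks, num_codewords, codeword_dim = 2, 4, 3
codebooks = torch.randn(1, 1, num_codebooks, num_codewords, codeword_dim)

# Fake projection output for 1 clip with 5 frames: (batch, frames, codebooks, codewords).
logits = torch.randn(1, 5, num_codebooks, num_codewords)
word_ids = torch.argmax(logits, dim=-1, keepdim=True)
one_hot = torch.zeros_like(logits).scatter_(-1, word_ids, 1.0)

# Select one codeword per codebook and concatenate along the feature dimension.
quantized = einops.reduce(one_hot.unsqueeze(-1) * codebooks, "b nf nb nw d -> b nf (nb d)", "sum")
print(quantized.shape)  # torch.Size([1, 5, 6])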
src/model/modules/transformers.py ADDED
@@ -0,0 +1,200 @@
+ """
+ This file contains the implementation of the Transformer Encoder layer.
+ Source: https://github.com/pytorch/audio/blob/main/torchaudio/models/wav2vec2/components.py
+ """
+ from typing import Optional, Tuple
+ import torch
+ from torch import nn, Tensor
+ from torch.nn import Module
+
+
+ class SelfAttention(Module):
+     """Multihead Self Attention module
+     Args:
+         embed_dim (int): Total dimension of the model.
+         num_heads (int): The number of heads.
+         dropout (float, optional):
+             Dropout probability on attn_output_weights. Default: ``0.0``
+     """
+
+     def __init__(
+         self,
+         embed_dim: int,
+         num_heads: int,
+         dropout: float = 0.0,
+     ):
+         super().__init__()
+         head_dim = embed_dim // num_heads
+         if head_dim * num_heads != embed_dim:
+             raise ValueError(
+                 f"`embed_dim ({embed_dim})` is not divisible by `num_heads ({num_heads})`"
+             )
+
+         self.embed_dim = embed_dim
+         self.num_heads = num_heads
+         self.dropout = torch.nn.Dropout(dropout)
+         self.head_dim = head_dim
+
+         self.scaling = self.head_dim**-0.5
+
+         self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+         self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+         self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+         self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)
+
+     def forward(
+         self,
+         x: Tensor,
+         attention_mask: Optional[Tensor] = None,
+         position_bias: Optional[Tensor] = None,
+         key_padding_mask: Optional[Tensor] = None,
+     ) -> Tuple[Tensor, Optional[Tensor]]:
+         """
+         Args:
+             x (Tensor): shape: ``[batch_size, sequence_length, embed_dim]``.
+             attention_mask (Tensor or ``None``, optional):
+                 shape: ``[batch_size, 1, sequence_length, sequence_length]``
+             position_bias: Not used. Only for the compatibility with :py:class:`WavLMSelfAttention`.
+             key_padding_mask (Tensor or ``None``): Not used. Only for the compatibility with
+                 :py:class:`WavLMSelfAttention`.
+         Returns:
+             (Tensor, ``None``): The resulting attention output and ``None`` (necessary for compatibility
+                 with :py:class:`WavLMSelAttention`).
+                 Attention output shape: ``[batch, sequence_length, embed_dim]``.
+         """
+         if x.ndim != 3 or x.shape[2] != self.embed_dim:
+             raise ValueError(
+                 f"The expected input shape is (batch, sequence, embed_dim=={self.embed_dim}). "
+                 f"Found {x.shape}."
+             )
+         batch_size, length, embed_dim = x.size()
+         if attention_mask is not None:
+             shape_ = (batch_size, 1, length, length)
+             if attention_mask.size() != shape_:
+                 raise ValueError(
+                     f"The expected attention mask shape is {shape_}. "
+                     f"Found {attention_mask.size()}."
+                 )
+
+         shape = (batch_size, length, self.num_heads, self.head_dim)
+         q = self.q_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd
+         k = self.k_proj(x).view(*shape).permute(0, 2, 3, 1)  # B, nH, Hd, L
+         v = self.v_proj(x).view(*shape).transpose(2, 1)  # B, nH, L, Hd
+
+         # scale down q to avoid value overflow.
+         weights = (self.scaling * q) @ k  # B, nH, L, L
+         if attention_mask is not None:
+             weights += attention_mask
+         # subtracting a constant value from the tensor won't change the output of softmax.
+         # apply the subtraction to avoid value overflow in torch.nn.functional.softmax.
+         # for more details, please see Equation 7 in https://arxiv.org/abs/2112.08778
+         weights = weights - weights.max(dim=-1, keepdim=True)[0]
+
+         weights = torch.nn.functional.softmax(weights, dim=-1)
+         weights = self.dropout(weights)
+
+         output = weights @ v  # B, nH, L, Hd
+         output = output.transpose(2, 1).reshape(batch_size, length, embed_dim)
+
+         output = self.out_proj(output)
+         return output, None  # Necessary for compatibility with WavLMSelAttention
+
+
+ class FeedForward(Module):
+     """Layer that follows attention layer in encoder layer."""
+
+     def __init__(
+         self,
+         io_features: int,
+         intermediate_features: int,
+         intermediate_dropout: float,
+         output_dropout: float,
+     ):
+         super().__init__()
+         self.intermediate_dense = nn.Linear(io_features, intermediate_features)
+         self.intermediate_dropout = nn.Dropout(intermediate_dropout)
+         self.output_dense = nn.Linear(intermediate_features, io_features)
+         self.output_dropout = nn.Dropout(output_dropout)
+
+     def forward(self, x):
+         """
+         Args:
+             x (Tensor): shape: `(batch, sequence_length, io_features)`
+         Returns:
+             x (Tensor): shape: `(batch, sequence_length, io_features)`
+         """
+         x = self.intermediate_dense(x)
+         x = torch.nn.functional.gelu(x)
+         x = self.intermediate_dropout(x)
+
+         x = self.output_dense(x)
+         x = self.output_dropout(x)
+         return x
+
+
+ class EncoderLayer(Module):
+     """A layer unit in encoder. Combines multihead self attention and feed forward."""
+
+     def __init__(
+         self,
+         d_model: int,
+         num_heads: int,
+         layer_norm_first: bool,
+         feed_forward_dim: int,
+         dropout: float = 0.1,
+     ):
+         super().__init__()
+         self.attention = SelfAttention(
+             embed_dim=d_model,
+             num_heads=num_heads,
+             dropout=dropout,
+         )
+
+         self.dropout = nn.Dropout(dropout)
+         self.layer_norm = nn.LayerNorm(d_model)
+         self.layer_norm_first = layer_norm_first
+         self.feed_forward = FeedForward(d_model, feed_forward_dim, dropout, dropout)
+         self.final_layer_norm = nn.LayerNorm(d_model)
+
+     def forward(
+         self,
+         x: Tensor,
+         attention_mask: Optional[Tensor] = None,
+         position_bias: Optional[Tensor] = None,
+         key_padding_mask: Optional[Tensor] = None,
+     ) -> Tuple[Tensor, Optional[Tensor]]:
+         """
+         Args:
+             x (Tensor): Input of shape ``(batch, sequence_length, embed_dim)``.
+             attention_mask (Tensor or ``None``, optional): attention mask
+                 of shape ``(batch, 1, sequence_length, sequence_length)``. (Default: ``None``)
+             position_bias (Tensor or ``None``, optional): position bias of shape
+                 ``(batch_size * num_heads, src_len, src_len)``.
+                 Only necessary for WavLM model, ``None`` otherwise. (Default: ``None``)
+             key_padding_mask (Tensor or ``None``, optional): key padding mask of shape ``(batch_size, src_len)``.
+                 Only used for WavLM model, ignored otherwise. (Default: ``None``)
+         Returns:
+             (x, position_bias): Shapes are the same as in the input. Position bias is only relevant for WaLM model,
+                 ``None`` otherwise.
+         """
+         residual = x
+
+         if self.layer_norm_first:
+             x = self.layer_norm(x)
+
+         x, position_bias = self.attention(
+             x,
+             attention_mask=attention_mask,
+             position_bias=position_bias,
+             key_padding_mask=key_padding_mask,
+         )
+
+         x = self.dropout(x)
+         x = residual + x
+
+         if self.layer_norm_first:
+             x = x + self.feed_forward(self.final_layer_norm(x))
+         else:
+             x = self.layer_norm(x)
+             x = self.final_layer_norm(x + self.feed_forward(x))
+         return x, position_bias
src/model/wav2vec2.py ADDED
@@ -0,0 +1,293 @@
+ """
+ A wrapper of Wav2Vec2 for training phase.
+ """
+ from typing import Tuple, Optional
+ import torch
+ from pytorch_lightning import LightningModule
+ import einops
+ from torchmetrics import MeanMetric
+
+ from .modules import (
+     ContextEncoder,
+     FeatureExtractor,
+     QuantizationModule,
+     Wav2Vec2Processor,
+ )
+ from src.utils import init_module_weights
+
+
+ class Wav2Vec2PretrainingModule(LightningModule):
+     def __init__(self, config):
+         super().__init__()
+
+         self.save_hyperparameters(config)
+
+         self.processor = Wav2Vec2Processor()
+         self.context_encoder = ContextEncoder(config.context_encoder)
+         self.feature_extractor = FeatureExtractor(config.feature_extractor)
+         self.quantizer = QuantizationModule(config.quantizer)
+
+         self.train_loss = MeanMetric()
+
+     def forward(self, waveforms: Tuple[torch.Tensor, ...]):
+         """
+         Args:
+             waveforms (Tuple[torch.Tensor]): The waveforms. Shape: (batch_size, wave_length).
+
+         Returns:
+             loss: The loss of the model. Contrastive loss + Diversity loss.
+         """
+         waveforms, wave_lengths = self.processor(waveforms)
+
+         # features.shape == (batch_size, num_frames, hidden_size)
+         features, num_frames = self.feature_extractor(waveforms, wave_lengths)
+
+         attention_mask = self._compute_attention_mask(num_frames)
+         mask_time_indices = self._compute_mask_span(
+             shape=features.shape[:-1],
+             mask_prob=self.hparams.mask_prob,
+             mask_length=self.hparams.mask_length,
+             attention_mask=attention_mask,
+             device=features.device,
+             min_masks=self.hparams.min_masks,
+         )
+
+         context_features = self.context_encoder(
+             features, attention_mask=attention_mask, mask_time_indices=mask_time_indices
+         )
+
+         quantized_features, perplexity = self.quantizer(features, attention_mask)
+
+         negative_quantized_features = self._sample_negatives(
+             quantized_features,
+             num_negatives=self.hparams.num_negatives,
+             attention_mask=attention_mask,
+         )
+
+         # (batch_size, num_frames, num_negatives + 1)
+         contrastive_logits = self._compute_contrastive_logits(
+             context_features,
+             quantized_features,
+             negative_quantized_features,
+             self.hparams.contrastive_logits_temperature,
+         ).flatten(0, -2)
+
+         # compute contrastive loss
+         # positive indices are always the first one
+         targets = (1 - mask_time_indices.long().flatten()) * -100
+
+         contrastive_loss = torch.nn.functional.cross_entropy(
+             contrastive_logits, targets, reduction="sum"
+         )
+
+         # compute diversity loss
+         diversity_loss = 1 - perplexity / self.quantizer.total_codewords
+
+         loss = contrastive_loss + diversity_loss * self.hparams.diversity_loss_weight
+
+         return loss
+
+     @staticmethod
+     def _sample_negatives(
+         features: torch.Tensor,
+         num_negatives: int,
+         attention_mask: Optional[torch.Tensor] = None,
+     ):
+         """
+         Sampling negative features from quantized features to compute the contrastive loss.
+
+         Args:
+             features (torch.Tensor): The quantized features. Shape: (batch_size, num_frames, d_model).
+             num_negatives (int): The number of negative samples.
+             attention_mask (Optional[torch.Tensor]): The mask for valid frames. `True` is invalid. Shape: (batch_size, num_frames).
+
+         Returns:
+             sampled_negatives (torch.Tensor): The sampled negative features. Shape: (batch_size, num_frames, num_negatives, d_model).
+         """
+
+         batch_size, num_frames, d_model = features.shape
+
+         features = features.view(-1, d_model)  # (batch_size * num_frames, d_model)
+
+         with torch.no_grad():
+             sampled_ids = []
+
+             for batch_idx in range(batch_size):
+                 num_valid_frames = (
+                     features.size(1)
+                     if attention_mask is None
+                     else (1 - attention_mask[batch_idx].long()).sum()
+                 ).item()
+
+                 sampled_ids.append(
+                     torch.randint(
+                         0,
+                         num_valid_frames - 1,
+                         (num_frames * num_negatives,),
+                         device=features.device,
+                     )
+                 )
+
+             sampled_ids = torch.stack(
+                 sampled_ids, dim=0
+             )  # (batch_size, num_frames * num_negatives)
+
+             feature_ids = einops.repeat(
+                 torch.arange(num_frames, device=features.device),
+                 "f -> (f n)",
+                 n=num_negatives,
+             )
+
+             # avoid sampling the same positive vector, but keep the distribution uniform
+             sampled_ids[sampled_ids >= feature_ids] += 1
+
+         # correct for batch size
+         # E.g [[0, 1, 2], [0, 1, 2]] -> [0, 1, 2, 3, 4, 5]
+         sampled_ids += torch.arange(
+             0, batch_size * num_frames, num_frames, device=features.device
+         ).unsqueeze_(-1)
+
+         sampled_negatives = features[sampled_ids.view(-1)]
+         sampled_negatives = einops.rearrange(
+             sampled_negatives,
+             "(b f n) d -> b f n d",
+             b=batch_size,
+             f=num_frames,
+             n=num_negatives,
+         )
+
+         return sampled_negatives
+
+     @staticmethod
+     def _compute_contrastive_logits(
+         predicted_features: torch.Tensor,
+         target_features: torch.Tensor,
+         negative_features: torch.Tensor,
+         temperature: int = 1,
+     ):
+         """
+         Compute the logits for contrastive loss.
+
+         Args:
+             predicted_features (torch.Tensor): The predicted features. Shape: (batch_size, num_frames, d_model).
+             target_features (torch.Tensor): The target features. Shape: (batch_size, num_frames, d_model).
+             negative_features (torch.Tensor): The negative features. Shape: (batch_size, num_frames, num_negatives, d_model).
+             temperature (int): The temperature for contrastive loss.
+
+         Returns:
+             logits (torch.Tensor): The logits for contrastive loss. Shape: (batch_size, num_frames, num_negatives + 1).
+         """
+
+         # (batch_size, num_frames, num_negatives + 1, d_model)
+         target_features = torch.cat(
+             (target_features.unsqueeze_(2), negative_features), dim=2
+         )
+
+         # (batch_size, num_frames, 1, d_model)
+         predicted_features = predicted_features.unsqueeze_(2)
+
+         # (batch_size, num_frames, num_negatives + 1)
+         logits = torch.cosine_similarity(predicted_features, target_features, dim=-1)
+         logits /= temperature
+
+         return logits
+
+     @staticmethod
+     def _compute_mask_span(
+         shape: Tuple[int, int],
+         mask_prob: float = 0.065,
+         mask_length: int = 10,
+         attention_mask: Optional[torch.Tensor] = None,
+         device: torch.device = torch.device("cpu"),
+         min_masks: int = 0,
+     ):
+         """
+         Compute the mask span for contrastive task.
+
+         Args:
+             shape (Tuple[int, int]): The shape of the mask span. Shape: (batch_size, num_frames).
+             mask_prob (float): The probability of choosing a frame to be the start of masking position.
+             mask_length (int): The length of the mask span.
+             attention_mask (Optional[torch.Tensor]): The mask for valid frames. `True` is invalid. Shape: (batch_size, num_frames).
+             device (torch.device): The device of the mask span.
+             min_masks (int): The minimum number of masks.
+
+         Returns:
+             mask_span (torch.Tensor): The mask span. Shape: (batch_size, num_frames).
+         """
+
+         batch_size, num_frames = shape
+
+         # NOTE: num_frames / mask_length: the number of spans in one waveform
+         num_masked_spans = int(
+             mask_prob * num_frames / mask_length + torch.rand(1).item()
+         )
+         num_masked_spans = max(num_masked_spans, min_masks)
+
+         # make sure num masked indices <= num frames
+         if num_masked_spans * mask_length > num_frames:
+             num_masked_spans = num_frames // mask_length
+
+         # uniform distribution to sample from
+         # NOTE: num_frames - (mask_length - 1): the number of start positions of the span
+         uniform_dist = torch.ones(
+             (batch_size, num_frames - (mask_length - 1)), device=device
+         )
+
+         # (batch_size, num_masked_spans)
+         mask_span_ids = torch.multinomial(uniform_dist, num_masked_spans)
+
+         # (batch_size, num_masked_spans * mask_length)
+         mask_span_ids = einops.repeat(mask_span_ids, "b n -> b (n l)", l=mask_length)
+
+         offsets = einops.repeat(
+             torch.arange(mask_length, device=device),
+             "l -> b (n l)",
+             b=batch_size,
+             n=num_masked_spans,
+         )
+
+         mask_span_ids = mask_span_ids + offsets
+
+         mask_span = torch.zeros(shape, device=device, dtype=torch.bool)
+         mask_span = mask_span.scatter_(1, mask_span_ids, True)
+
+         if attention_mask is not None:
+             # Make sure the invalid frames are not masked
+             mask_span = torch.where(attention_mask.bool(), mask_span, False)
+
+         return mask_span
+
+     @staticmethod
+     def _compute_attention_mask(length: torch.Tensor):
+         """
+         Args:
+             length (Tensor): The length of valid frames. Shape: (batch)
+             max_length (int): The maximum length of the frames.
+
+         Returns:
+             attention_mask (BoolTensor): The mask for the valid frames. `True` is invalid. Shape: (batch, num_frames)
+         """
+         max_length = length.max().item()
+
+         mask = (
+             torch.arange(max_length, device=length.device).expand(
+                 length.size(0), max_length
+             )
+             >= length[:, None]
+         )
+
+         return mask
+
+     def training_step(self, batch, batch_idx):
+         loss = self(batch)
+
+         self.train_loss(loss)
+
+         if batch_idx % 100 == 0:
+             self.log("train/loss", self.train_loss, on_step=True, on_epoch=True)
+
+         return loss
+
+     def configure_optimizers(self):
+         return torch.optim.AdamW(self.parameters(), lr=1e-4)
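A worked example of the span count that _compute_mask_span produces for this repo's pretraining config (mask_prob=0.65, mask_length=10, min_masks=2) on a 49-frame clip, i.e. roughly one second of audio:

import torch

mask_prob, mask_length, min_masks, num_frames = 0.65, 10, 2, 49

num_masked_spans = int(mask_prob * num_frames / mask_length + torch.rand(1).item())  # 3 or 4
num_masked_spans = max(num_masked_spans, min_masks)
if num_masked_spans * mask_length > num_frames:  # cap so all spans fit, as in the method
    num_masked_spans = num_frames // mask_length
print(num_masked_spans)  # 3 or 4 spans, i.e. 30-40 masked frames out of 49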
src/train.py ADDED
@@ -0,0 +1,27 @@
+ import sys
+ sys.path.append(".")
+
+ from src.config import model as conf
+ from src.model import Wav2Vec2PretrainingModule
+ from src.datamodule import WebDatasetConverter, VLSP2020ForPretrainingDataModule
+ from pytorch_lightning import Trainer
+ from pytorch_lightning.callbacks import ModelCheckpoint
+
+
+ if __name__ == "__main__":
+
+     model = Wav2Vec2PretrainingModule(conf.wav2vec2_pretraining)
+     dts = WebDatasetConverter(conf.dataset.path).get_dataset()
+     dtm = VLSP2020ForPretrainingDataModule(dts, **conf.dataset)
+     trainer = Trainer(
+         callbacks=[
+             ModelCheckpoint(
+                 monitor="val/loss",
+                 dirpath=conf["checkpoint_dir"],
+             )
+         ],
+         gradient_clip_val=1.0,
+         accelerator="gpu"
+     )
+
+     trainer.fit(model, dtm)
src/utils/__init__.py ADDED
@@ -0,0 +1 @@
+ from .functional import init_module_weights
src/utils/functional.py ADDED
@@ -0,0 +1,28 @@
+ import torch
+
+
+ def init_module_weights(module):
+     """Initialize the weights"""
+
+     from src.model.modules import QuantizationModule
+
+     # gumbel softmax requires special init
+     if isinstance(module, QuantizationModule):
+         module.weight_proj.weight.data.normal_(mean=0.0, std=1)
+         module.weight_proj.bias.data.zero_()
+         torch.nn.init.uniform_(module.codebooks)
+     elif isinstance(module, torch.nn.Linear):
+         # Slightly different from the TF version which uses truncated_normal for initialization
+         # cf https://github.com/pytorch/pytorch/pull/5617
+         module.weight.data.normal_(mean=0.0, std=0.5)
+     elif isinstance(module, (torch.nn.LayerNorm, torch.nn.GroupNorm)):
+         module.bias.data.zero_()
+         module.weight.data.fill_(1.0)
+     elif isinstance(module, torch.nn.Conv1d):
+         torch.nn.init.kaiming_normal_(module.weight.data)
+
+     if (
+         isinstance(module, (torch.nn.Linear, torch.nn.Conv1d))
+         and module.bias is not None
+     ):
+         module.bias.data.zero_()
src/utils/metrics.py ADDED
@@ -0,0 +1,72 @@
+ from typing import Tuple, Union
+ import re
+
+
+ def levenshtein_distance(source: Tuple[str], target: Tuple[str]):
+     """
+     Compute the Levenshtein distance between two sequences.
+     """
+
+     n, m = len(source), len(target)
+     if n > m:
+         # Make sure n <= m, to use O(min(n,m)) space
+         source, target = target, source
+         n, m = m, n
+
+     current_row = range(n + 1)  # Keep current and previous row, not entire matrix
+     for i in range(1, m + 1):
+         previous_row, current_row = current_row, [i] + [0] * n
+         for j in range(1, n + 1):
+             add, delete, change = (
+                 previous_row[j] + 1,
+                 current_row[j - 1] + 1,
+                 previous_row[j - 1],
+             )
+             if source[j - 1] != target[i - 1]:
+                 change += 1
+             current_row[j] = min(add, delete, change)
+
+     distance = current_row[n]
+
+     del current_row
+     del previous_row
+
+     return distance
+
+
+ def word_error_rate(
+     predicted: Union[str, Tuple[str]], transcript: Union[str, Tuple[str]]
+ ):
+     if isinstance(predicted, str):
+         predicted = (predicted,)
+     if isinstance(transcript, str):
+         transcript = (transcript,)
+
+     pattern = r"\W+"
+
+     err, total = 0, 0
+
+     for pred, tgt in zip(predicted, transcript):
+         pred_tokens = re.split(pattern, pred)
+         tgt_tokens = re.split(pattern, tgt)
+         err += levenshtein_distance(pred_tokens, tgt_tokens)
+         total += len(tgt_tokens)
+
+     return err / total
+
+
+ def character_error_rate(
+     predicted: Union[str, Tuple[str]], transcript: Union[str, Tuple[str]]
+ ):
+     if isinstance(predicted, str):
+         predicted = (predicted,)
+     if isinstance(transcript, str):
+         transcript = (transcript,)
+
+     err, total = 0, 0
+
+     for pred, tgt in zip(predicted, transcript):
+         err += levenshtein_distance(pred, tgt)
+         total += len(tgt)
+
+     return err / total
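A quick usage check, assuming the repository root is on the path; the two accent substitutions below are counted at the word and the character level respectively:

from src.utils.metrics import word_error_rate, character_error_rate

prediction = "xin chao viet nam"
reference = "xin chào việt nam"

print(word_error_rate(prediction, reference))       # 2 substituted words / 4 words = 0.5
print(character_error_rate(prediction, reference))  # 2 substituted characters / 17 characters ~= 0.12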
src/utils/scheduler.py ADDED
@@ -0,0 +1,83 @@
+ import math
+ from torch.optim.lr_scheduler import _LRScheduler
+
+
+ class WarmUpScheduler(_LRScheduler):
+     def __init__(
+         self,
+         optimizer,
+         warmup_steps: int,
+         feature_size: int,
+         factor: float = 1.0,
+         last_epoch=-1,
+     ):
+         self.warmup_steps = warmup_steps
+         self.feature_size = feature_size
+         self.factor = factor
+         super().__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         lr = self._compute_lr()
+         return [lr] * len(self.base_lrs)
+
+     def _compute_lr(self):
+         if self.last_epoch == 0:
+             return 0.0
+
+         lr = (self.feature_size ** (-0.5)) * min(
+             self.last_epoch ** (-0.5), self.last_epoch * self.warmup_steps ** (-1.5)
+         )
+
+         return lr * self.factor
+
+
+ class TriStateScheduler(_LRScheduler):
+     def __init__(
+         self,
+         optimizer,
+         total_steps: int,
+         warmup_steps: int,
+         constant_steps: int,
+         factor: float = 0.3,
+         last_epoch: int = -1,
+     ):
+         self.warmup_steps = warmup_steps
+         self.constant_steps = constant_steps
+         self.total_steps = total_steps
+         self.factor = factor
+
+         super().__init__(optimizer, last_epoch)
+
+     def get_lr(self):
+         if not hasattr(self, "eta_min"):
+             self.eta_max = self.base_lrs.copy()
+             self.eta_min = [eta_max * self.factor for eta_max in self.eta_max]
+
+         return [
+             self._compute_lr(group["lr"], eta_min, eta_max)
+             for group, eta_min, eta_max in zip(
+                 self.optimizer.param_groups, self.eta_min, self.eta_max
+             )
+         ]
+
+     def _compute_lr(self, prev_lr: float, eta_min: float, eta_max: float):
+         # first stage
+         if self.last_epoch <= self.warmup_steps:
+             lr = eta_max - 0.5 * (eta_max - eta_min) * (
+                 1 + math.cos(math.pi * self.last_epoch / self.warmup_steps)
+             )
+         # second stage
+         elif self.last_epoch <= self.warmup_steps + self.constant_steps:
+             lr = prev_lr
+         else:
+             # third stage
+             decay_steps = self.total_steps - self.warmup_steps - self.constant_steps
+             k = self.last_epoch - self.warmup_steps - self.constant_steps
+             lr = eta_min + 0.5 * (eta_max - eta_min) * (
+                 1 + math.cos(math.pi * k / decay_steps)
+             )
+
+         return lr
+
+     def state_dict(self) -> dict:
+         return super().state_dict()
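A short sketch of the resulting schedule, assuming the repository root is on the path: with factor=0.1 the learning rate ramps from eta_min up to the base rate over the warmup steps, holds through the constant phase, then decays back.

import torch

from src.utils.scheduler import TriStateScheduler

param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter just to drive the optimizer
optimizer = torch.optim.AdamW([param], lr=1e-4)
scheduler = TriStateScheduler(
    optimizer, total_steps=100, warmup_steps=10, constant_steps=40, factor=0.1
)

lrs = []
for _ in range(100):
    optimizer.step()
    scheduler.step()
    lrs.append(optimizer.param_groups[0]["lr"])

print(lrs[9], lrs[49], lrs[-1])  # ~1e-4 after warmup, ~1e-4 during the plateau, ~1e-5 at the end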