Update wav2vec2speechclassification.py
Browse files
wav2vec2speechclassification.py
CHANGED
@@ -2,6 +2,7 @@ from dataclasses import dataclass
|
|
2 |
from typing import Optional, Tuple
|
3 |
import torch
|
4 |
from transformers.file_utils import ModelOutput
|
|
|
5 |
|
6 |
|
7 |
@dataclass
|
@@ -24,6 +25,7 @@ from transformers.models.wav2vec2.modeling_wav2vec2 import (
|
|
24 |
|
25 |
class Wav2Vec2ClassificationHead(nn.Module):
|
26 |
"""Head for wav2vec classification task."""
|
|
|
27 |
|
28 |
def __init__(self, config):
|
29 |
super().__init__()
|
@@ -42,6 +44,8 @@ class Wav2Vec2ClassificationHead(nn.Module):
|
|
42 |
|
43 |
|
44 |
class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
|
|
|
|
|
45 |
def __init__(self, config):
|
46 |
super().__init__(config)
|
47 |
self.num_labels = config.num_labels
|
|
|
2 |
from typing import Optional, Tuple
|
3 |
import torch
|
4 |
from transformers.file_utils import ModelOutput
|
5 |
+
from transformers import AutoConfig
|
6 |
|
7 |
|
8 |
@dataclass
|
|
|
25 |
|
26 |
class Wav2Vec2ClassificationHead(nn.Module):
|
27 |
"""Head for wav2vec classification task."""
|
28 |
+
config_class = AutoConfig
|
29 |
|
30 |
def __init__(self, config):
|
31 |
super().__init__()
|
|
|
44 |
|
45 |
|
46 |
class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
|
47 |
+
config_class = AutoConfig
|
48 |
+
|
49 |
def __init__(self, config):
|
50 |
super().__init__(config)
|
51 |
self.num_labels = config.num_labels
|