Ar4ikov commited on
Commit
8b9f2c3
1 Parent(s): 71c53d9

Update wav2vec2speechclassification.py

Browse files
Files changed (1) hide show
  1. wav2vec2speechclassification.py +3 -3
wav2vec2speechclassification.py CHANGED
@@ -2,7 +2,7 @@ from dataclasses import dataclass
2
  from typing import Optional, Tuple
3
  import torch
4
  from transformers.file_utils import ModelOutput
5
- from transformers import AutoConfig
6
 
7
 
8
  @dataclass
@@ -25,7 +25,7 @@ from transformers.models.wav2vec2.modeling_wav2vec2 import (
25
 
26
  class Wav2Vec2ClassificationHead(nn.Module):
27
  """Head for wav2vec classification task."""
28
- config_class = AutoConfig
29
 
30
  def __init__(self, config):
31
  super().__init__()
@@ -44,7 +44,7 @@ class Wav2Vec2ClassificationHead(nn.Module):
44
 
45
 
46
  class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
47
- config_class = AutoConfig
48
 
49
  def __init__(self, config):
50
  super().__init__(config)
 
2
  from typing import Optional, Tuple
3
  import torch
4
  from transformers.file_utils import ModelOutput
5
+ from transformers import Wav2Vec2Config
6
 
7
 
8
  @dataclass
 
25
 
26
  class Wav2Vec2ClassificationHead(nn.Module):
27
  """Head for wav2vec classification task."""
28
+ config_class = Wav2Vec2Config
29
 
30
  def __init__(self, config):
31
  super().__init__()
 
44
 
45
 
46
  class Wav2Vec2ForSpeechClassification(Wav2Vec2PreTrainedModel):
47
+ config_class = Wav2Vec2Config
48
 
49
  def __init__(self, config):
50
  super().__init__(config)