zhniu committed · Commit 2e005a5 · verified · 1 Parent(s): 31a0f29

Update README.md

Files changed (1)
  1. README.md +82 -1
README.md CHANGED
@@ -4,4 +4,85 @@ language:
  - en
  base_model:
  - facebook/wav2vec2-base
- ---
+ ---
+
+ SCD (Speaker Change Detection): the task of detecting, within audio or video content, the points at which the active speaker changes. It is typically applied to multi-speaker conversations or presentations to determine when one speaker hands over to the next.
+
+ How to use
+ # Note: at the time this code was originally written, transformers.Wav2Vec2ForAudioFrameClassification was incomplete
+ # -> this adds the then-missing parts (a `labels` argument and the loss computation in forward())
+ import torch
+ import torch.nn as nn
+ from torch.nn import CrossEntropyLoss, MSELoss
+
+ import transformers
+ from transformers import Wav2Vec2Model
+ from transformers.modeling_outputs import TokenClassifierOutput
+ from huggingface_hub import PyTorchModelHubMixin
+
+ # Index of hidden_states in the tuple output of Wav2Vec2Model (same value as transformers' internal constant)
+ _HIDDEN_STATES_START_POSITION = 2
+
+ class Wav2Vec2ForAudioFrameClassification_custom(transformers.Wav2Vec2ForAudioFrameClassification,
+                                                  PyTorchModelHubMixin,
+                                                  repo_url="your-repo-url",
+                                                  pipeline_tag="audio-classification",
+                                                  license="mit"):
+     def __init__(self, config):
+         super().__init__(config)
+         self.num_labels = config.num_labels
+
+         if hasattr(config, "add_adapter") and config.add_adapter:
+             raise ValueError(
+                 "Audio frame classification does not support the use of Wav2Vec2 adapters (config.add_adapter=True)"
+             )
+         self.wav2vec2 = Wav2Vec2Model(config)
+         num_layers = config.num_hidden_layers + 1  # transformer layers + input embeddings
+         if config.use_weighted_layer_sum:
+             self.layer_weights = nn.Parameter(torch.ones(num_layers) / num_layers)
+         self.classifier = nn.Linear(config.hidden_size, config.num_labels)
+
+         self.init_weights()
+
+     def forward(
+         self,
+         input_values,
+         attention_mask=None,
+         output_attentions=None,
+         output_hidden_states=None,
+         return_dict=None,
+         labels=None,  # ADDED
+     ):
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+         output_hidden_states = True if self.config.use_weighted_layer_sum else output_hidden_states
+
+         outputs = self.wav2vec2(
+             input_values,
+             attention_mask=attention_mask,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         if self.config.use_weighted_layer_sum:
+             hidden_states = outputs[_HIDDEN_STATES_START_POSITION]
+             hidden_states = torch.stack(hidden_states, dim=1)
+             norm_weights = nn.functional.softmax(self.layer_weights, dim=-1)
+             hidden_states = (hidden_states * norm_weights.view(-1, 1, 1)).sum(dim=1)
+         else:
+             hidden_states = outputs[0]
+
+         logits = self.classifier(hidden_states)
+
+         # ADDED: compute the frame-level loss only when labels are provided
+         loss = None
+         if labels is not None:
+             labels = labels.reshape(-1, 1)  # 1xN -> Nx1
+             if self.num_labels == 1:
+                 loss_fct = MSELoss()
+                 # loss = loss_fct(logits.squeeze(), labels.squeeze())
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels)
+             else:
+                 loss_fct = CrossEntropyLoss()
+                 loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+
+         if not return_dict:
+             output = (logits,) + outputs[_HIDDEN_STATES_START_POSITION:]
+             return ((loss,) + output) if loss is not None else output
+
+         return TokenClassifierOutput(
+             loss=loss,
+             logits=logits,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+         )
+
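A minimal inference sketch, not part of the committed README: it assumes the repo id `"zhniu/<this-repo>"` as a placeholder for this repository, that the checkpoint can be loaded through `from_pretrained`, that the audio is 16 kHz mono (as `facebook/wav2vec2-base` expects), and that each output frame covers roughly 20 ms of signal.

```python
import torch
from transformers import Wav2Vec2FeatureExtractor

# Placeholder repo id -- replace with the actual model id of this repository.
model = Wav2Vec2ForAudioFrameClassification_custom.from_pretrained("zhniu/<this-repo>")
feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-base")

# Stand-in waveform: 1 second of silence at 16 kHz (use a real mono recording in practice).
waveform = torch.zeros(16000)
inputs = feature_extractor(waveform.numpy(), sampling_rate=16000, return_tensors="pt")

model.eval()
with torch.no_grad():
    logits = model(inputs.input_values).logits  # shape: (batch, num_frames, num_labels)

# One prediction per ~20 ms frame; frames predicted as the change class mark speaker turns.
frame_predictions = logits.argmax(dim=-1)
print(frame_predictions.shape)
```

For fine-tuning, the `labels` argument added to `forward()` takes one integer target per output frame, and the returned `TokenClassifierOutput.loss` is then the cross-entropy (or MSE when `num_labels == 1`) computed in the class above.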