Spaces:
Runtime error
Runtime error
/*
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */
/*
 * This program modifies an FST that has no self-loops as follows:
 * for each incoming arc with a non-eps input symbol, it adds a self-loop arc
 * on the destination state with that non-eps symbol as input and eps as
 * output.
 *
 * This ensures that the resulting FST can deduplicate repeated symbols,
 * which are very common in acoustic model output.
 */
namespace { | |
int32 AddSelfLoopsSimple(fst::StdVectorFst* fst) { | |
typedef fst::MutableArcIterator<fst::StdVectorFst> IterType; | |
int32 num_states_before = fst->NumStates(); | |
fst::MakePrecedingInputSymbolsSame(false, fst); | |
int32 num_states_after = fst->NumStates(); | |
KALDI_LOG << "There are " << num_states_before | |
<< " states in the original FST; " | |
<< " after MakePrecedingInputSymbolsSame, there are " | |
<< num_states_after << " states " << std::endl; | |
auto weight_one = fst::StdArc::Weight::One(); | |
int32 num_arc_added = 0; | |
fst::StdArc self_loop_arc; | |
self_loop_arc.weight = weight_one; | |
int32 num_states = fst->NumStates(); | |
std::vector<std::set<int32>> incoming_non_eps_label_per_state(num_states); | |
for (int32 state = 0; state < num_states; state++) { | |
for (IterType aiter(fst, state); !aiter.Done(); aiter.Next()) { | |
fst::StdArc arc(aiter.Value()); | |
if (arc.ilabel != 0) { | |
incoming_non_eps_label_per_state[arc.nextstate].insert(arc.ilabel); | |
} | |
} | |
} | |
for (int32 state = 0; state < num_states; state++) { | |
if (!incoming_non_eps_label_per_state[state].empty()) { | |
auto& ilabel_set = incoming_non_eps_label_per_state[state]; | |
for (auto it = ilabel_set.begin(); it != ilabel_set.end(); it++) { | |
self_loop_arc.ilabel = *it; | |
self_loop_arc.olabel = 0; | |
self_loop_arc.nextstate = state; | |
fst->AddArc(state, self_loop_arc); | |
num_arc_added++; | |
} | |
} | |
} | |
return num_arc_added; | |
} | |
void print_usage() { | |
std::cout << "add-self-loop-simple usage:\n" | |
"\tadd-self-loop-simple <in-fst> <out-fst> \n"; | |
} | |
} // namespace | |
int main(int argc, char** argv) { | |
if (argc != 3) { | |
print_usage(); | |
exit(1); | |
} | |
auto input = argv[1]; | |
auto output = argv[2]; | |
auto fst = fst::ReadFstKaldi(input); | |
auto num_states = fst->NumStates(); | |
KALDI_LOG << "Loading FST from " << input << " with " << num_states | |
<< " states." << std::endl; | |
int32 num_arc_added = AddSelfLoopsSimple(fst); | |
KALDI_LOG << "Adding " << num_arc_added << " self-loop arcs " << std::endl; | |
fst::WriteFstKaldi(*fst, std::string(output)); | |
KALDI_LOG << "Writing FST to " << output << std::endl; | |
delete fst; | |
} | |