import logging
import pickle
from random import shuffle
from typing import Optional

import numpy as np
import soundfile as sf
import torch
import torchaudio
from torch import Tensor
from transformers import (
    AutoModelForCTC,
    BatchEncoding,
    CamembertForSequenceClassification,
    DistilBertForSequenceClassification,
    PreTrainedTokenizerBase,
    Wav2Vec2Processor,
    pipeline,
)

from Model import BERT, tokenizer, mult_token_id, cls_token_id, pad_token_id, max_pred, maxlen, sep_token_id
from FrModel import FR_BERT, fr_tokenizer, fr_mult_token_id, fr_cls_token_id, fr_pad_token_id, fr_sep_token_id

device = "cpu"


def load_models():
    print("Loading DistilBERT model...")
    model = DistilBertForSequenceClassification.from_pretrained("DistillMDPI1/DistillMDPI1/saved_model")

    print("Loading BERT model...")
    neptune = BERT()
    model_save_path = "neptune_270_papers/neptune_270_papers/model.pt"
    neptune.load_state_dict(torch.load(model_save_path, map_location=torch.device("cpu")))
    neptune.to(device)

    print("Loading speech recognition pipeline...")
    pipe = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-tiny.en",
        chunk_length_s=30,
        device=device,
    )

    # NOTE: the label encoder is loaded but not returned; extend the return
    # value if a caller needs to map encoded labels back to class names.
    with open("DistillMDPI1/DistillMDPI1/label_encoder.pkl", "rb") as f:
        label_encoder = pickle.load(f)

    return model, neptune, pipe
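

# Illustrative usage (a sketch; assumes the checkpoint directories referenced
# above exist locally):
#
#   model, neptune, pipe = load_models()
#   predicted_class, percentages = predict_class(article_text, model)
#   transcript = pipe("clip.wav")["text"]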


def load_fr_models():
    print("Loading CamemBERT model...")
    fr_model = CamembertForSequenceClassification.from_pretrained("Camembert/Camembert/saved_model")

    print("Loading FR_BERT model...")
    fr_neptune = FR_BERT()
    model_save_path = "fr_neptune/fr_neptune/model.pt"
    fr_neptune.load_state_dict(torch.load(model_save_path, map_location=torch.device("cpu")))
    fr_neptune.to(device)

    print("Loading Wav2Vec2 model for French...")
    wav2vec2_processor = Wav2Vec2Processor.from_pretrained("bhuang/asr-wav2vec2-french")
    wav2vec2_model = AutoModelForCTC.from_pretrained("bhuang/asr-wav2vec2-french").to(device)
    return fr_model, fr_neptune, wav2vec2_processor, wav2vec2_model


# Each entry maps a class index to (display name, CSS class, hex color).
fr_class_labels = {
    0: ('Physics', 'primary', '#5e7cc8'),
    1: ('AI', 'cyan', '#0dcaf0'),
    2: ('economies', 'warning', '#f7c32e'),
    3: ('environments', 'success', '#0cbc87'),
    4: ('sports', 'orange', '#fd7e14'),
}

class_labels = {
    0: ('Physics', 'primary', '#0f6fec'),
    1: ('Societies', 'purple', '#6f42c1'),
    2: ('administration', 'pink', '#d63384'),
    3: ('agriculture', 'agri', '#a8c686'),
    4: ('AI', 'cyan', '#0dcaf0'),
    5: ('Science', 'picton', '#5fa8d3'),
    6: ('Innovation', 'rosy', '#BF98A0'),
    7: ('biology', 'cambridge', '#88aa99'),
    8: ('economies', 'warning', '#f7c32e'),
    9: ('energies', 'danger', '#d6293e'),
    10: ('environments', 'success', '#0cbc87'),
    11: ('ML', 'yellow', '#ffc107'),
    12: ('mathematics', 'coffe', '#7f5539'),
    13: ('robotics', 'moss', '#B1E5F2'),
    14: ('sports', 'orange', '#fd7e14'),
    15: ('technologies', 'vanila', '#FDF0D5'),
    16: ('vehicles', 'info', '#4f9ef8'),
}


def predict_class(text, model):
    inputs = transform_list_of_texts([text], tokenizer, 510, 510, 1, 2550)

    all_probabilities = []

    model.eval()
    with torch.no_grad():
        for i, sample in enumerate(inputs['input_ids']):
            for j in range(len(sample)):
                # The chunks are already tensors, so clone/detach them rather
                # than re-wrapping with torch.tensor().
                input_ids_tensor = sample[j].clone().detach().to(device).unsqueeze(0)
                attention_mask_tensor = inputs['attention_mask'][i][j].clone().detach().to(device).unsqueeze(0)
                outputs = model(input_ids=input_ids_tensor, attention_mask=attention_mask_tensor)

                probabilities = torch.softmax(outputs.logits, dim=1)[0]
                all_probabilities.append(probabilities)

    # Average chunk-level probabilities into a document-level prediction.
    if len(all_probabilities) > 1:
        mean_probabilities = torch.stack(all_probabilities).mean(dim=0)
    else:
        mean_probabilities = all_probabilities[0]

    predicted_class_index = torch.argmax(mean_probabilities).item()
    predicted_class = class_labels[predicted_class_index]

    percentages = {class_labels[idx]: mean_probabilities[idx].item() * 100 for idx in range(len(class_labels))}
    sorted_percentages = dict(sorted(percentages.items(), key=lambda item: item[1], reverse=True))

    return predicted_class, sorted_percentages


def predict_class_for_Neptune(text, model):
    encoded_text = transform_for_inference_text(text, tokenizer, 125, 125, 1, 2550)
    batch, sentences = prepare_text(encoded_text)

    model.eval()
    all_probabilities = []
    with torch.no_grad():
        for sample in batch:
            input_ids = torch.tensor(sample[0], device=device, dtype=torch.long).unsqueeze(0)
            segment_ids = torch.tensor(sample[1], device=device, dtype=torch.long).unsqueeze(0)
            masked_pos = torch.tensor(sample[2], device=device, dtype=torch.long).unsqueeze(0)

            # The model exposes two classification heads, one per packed chunk.
            _, _, logits_mclsf1, logits_mclsf2 = model(input_ids, segment_ids, masked_pos)
            probabilities1 = torch.softmax(logits_mclsf1, dim=1)[0]
            probabilities2 = torch.softmax(logits_mclsf2, dim=1)[0]
            all_probabilities.extend([probabilities1, probabilities2])

    aggregated_probabilities = torch.stack(all_probabilities).mean(dim=0)

    predicted_class_index = torch.argmax(aggregated_probabilities).item()
    predicted_class = class_labels[predicted_class_index]

    percentages = {class_labels[idx]: aggregated_probabilities[idx].item() * 100 for idx in range(len(class_labels))}
    sorted_percentages = dict(sorted(percentages.items(), key=lambda item: item[1], reverse=True))

    return predicted_class, sorted_percentages


def predict_sentences_class(text, model):
    inputs = transform_list_of_texts([text], tokenizer, 510, 510, 1, 2550)
    aligned_predictions = {}

    model.eval()
    with torch.no_grad():
        for i, sample in enumerate(inputs['input_ids']):
            for j in range(len(sample)):
                input_ids_tensor = sample[j].clone().detach().to(device).unsqueeze(0)
                attention_mask_tensor = inputs['attention_mask'][i][j].clone().detach().to(device).unsqueeze(0)

                sentence = tokenizer.decode(input_ids_tensor[0], skip_special_tokens=True)
                outputs = model(input_ids=input_ids_tensor, attention_mask=attention_mask_tensor)

                predicted_class_index = torch.argmax(outputs.logits, dim=1).item()
                predicted_class = class_labels[predicted_class_index]

                if sentence not in aligned_predictions:
                    aligned_predictions[sentence] = predicted_class

    return aligned_predictions


def transform_list_of_texts(
    texts: list[str],
    tokenizer: PreTrainedTokenizerBase,
    chunk_size: int,
    stride: int,
    minimal_chunk_length: int,
    maximal_text_length: Optional[int] = None,
) -> BatchEncoding:
    model_inputs = [
        transform_single_text(text, tokenizer, chunk_size, stride, minimal_chunk_length, maximal_text_length)
        for text in texts
    ]
    input_ids = [model_input[0] for model_input in model_inputs]
    attention_mask = [model_input[1] for model_input in model_inputs]
    tokens = {"input_ids": input_ids, "attention_mask": attention_mask}
    return BatchEncoding(tokens)
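

# Shape sketch (illustrative): for a single text split into N chunks with
# chunk_size=510, the returned BatchEncoding holds a list with one
# (N, 512) tensor per input text -- 510 content tokens plus [CLS] and [SEP].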


def transform_single_text(
    text: str,
    tokenizer: PreTrainedTokenizerBase,
    chunk_size: int,
    stride: int,
    minimal_chunk_length: int,
    maximal_text_length: Optional[int],
) -> tuple[Tensor, Tensor]:
    """Transforms (the entire) text to model input of the BERT classifier."""
    if maximal_text_length:
        tokens = tokenize_text_with_truncation(text, tokenizer, maximal_text_length)
    else:
        tokens = tokenize_whole_text(text, tokenizer)
    input_id_chunks, mask_chunks = split_tokens_into_smaller_chunks(tokens, chunk_size, stride, minimal_chunk_length)
    add_special_tokens_at_beginning_and_end(input_id_chunks, mask_chunks)
    add_padding_tokens(input_id_chunks, mask_chunks, chunk_size)
    input_ids, attention_mask = stack_tokens_from_all_chunks(input_id_chunks, mask_chunks)
    return input_ids, attention_mask


def tokenize_whole_text(text: str, tokenizer: PreTrainedTokenizerBase) -> BatchEncoding:
    """Tokenizes the entire text without truncation and without special tokens."""
    tokens = tokenizer(text, add_special_tokens=False, truncation=False, return_tensors="pt")
    return tokens


def tokenize_text_with_truncation(
    text: str, tokenizer: PreTrainedTokenizerBase, maximal_text_length: int
) -> BatchEncoding:
    """Tokenizes the text with truncation to maximal_text_length and without special tokens."""
    tokens = tokenizer(
        text, add_special_tokens=False, max_length=maximal_text_length, truncation=True, return_tensors="pt"
    )
    return tokens


def split_tokens_into_smaller_chunks(
    tokens: BatchEncoding,
    chunk_size: int,
    stride: int,
    minimal_chunk_length: int,
) -> tuple[list[Tensor], list[Tensor]]:
    """Splits tokens into overlapping chunks with the given size and stride."""
    input_id_chunks = split_overlapping(tokens["input_ids"][0], chunk_size, stride, minimal_chunk_length)
    mask_chunks = split_overlapping(tokens["attention_mask"][0], chunk_size, stride, minimal_chunk_length)
    return input_id_chunks, mask_chunks


def add_special_tokens_at_beginning_and_end(input_id_chunks: list[Tensor], mask_chunks: list[Tensor]) -> None:
    """
    Adds the special CLS token (token id = 101) at the beginning and the SEP
    token (token id = 102) at the end of each chunk, with an attention mask
    of 1 for both added positions.
    """
    for i in range(len(input_id_chunks)):
        input_id_chunks[i] = torch.cat([torch.tensor([101]), input_id_chunks[i], torch.tensor([102])])
        mask_chunks[i] = torch.cat([torch.tensor([1]), mask_chunks[i], torch.tensor([1])])


def add_padding_tokens(input_id_chunks: list[Tensor], mask_chunks: list[Tensor], chunk_size) -> None:
    """Adds padding tokens (token id = 0) at the end so every chunk has exactly chunk_size + 2 tokens."""
    for i in range(len(input_id_chunks)):
        pad_len = chunk_size + 2 - input_id_chunks[i].shape[0]
        if pad_len > 0:
            input_id_chunks[i] = torch.cat([input_id_chunks[i], torch.tensor([0] * pad_len)])
            mask_chunks[i] = torch.cat([mask_chunks[i], torch.tensor([0] * pad_len)])


def stack_tokens_from_all_chunks(input_id_chunks: list[Tensor], mask_chunks: list[Tensor]) -> tuple[Tensor, Tensor]:
    """Reshapes data to a form compatible with BERT model input."""
    input_ids = torch.stack(input_id_chunks)
    attention_mask = torch.stack(mask_chunks)
    return input_ids.long(), attention_mask.int()


def split_overlapping(tensor: Tensor, chunk_size: int, stride: int, minimal_chunk_length: int) -> list[Tensor]:
    """Helper function for dividing 1-dimensional tensors into overlapping chunks."""
    result = [tensor[i : i + chunk_size] for i in range(0, len(tensor), stride)]
    if len(result) > 1:
        # Drop the short leftover chunks produced at the end of the sequence.
        result = [x for x in result if len(x) >= minimal_chunk_length]
    return result
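

# Worked example (illustrative): a 5-token tensor split with chunk_size=3,
# stride=2, minimal_chunk_length=2 yields [t0 t1 t2] and [t2 t3 t4]; the
# trailing [t4] chunk is dropped because it is shorter than 2 tokens.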


def stack_tokens_from_all_chunks_for_inference(input_id_chunks: list[Tensor], mask_chunks: list[Tensor]) -> tuple[Tensor, Tensor]:
    """Reshapes data to a form compatible with BERT model input (inference twin of stack_tokens_from_all_chunks)."""
    input_ids = torch.stack(input_id_chunks)
    attention_mask = torch.stack(mask_chunks)
    return input_ids.long(), attention_mask.int()


def transform_for_inference_text(text: str,
                                 tokenizer: PreTrainedTokenizerBase,
                                 chunk_size: int,
                                 stride: int,
                                 minimal_chunk_length: int,
                                 maximal_text_length: Optional[int]) -> BatchEncoding:
    if maximal_text_length:
        tokens = tokenize_text_with_truncation(text, tokenizer, maximal_text_length)
    else:
        tokens = tokenize_whole_text(text, tokenizer)
    input_id_chunks, mask_chunks = split_tokens_into_smaller_chunks(tokens, chunk_size, stride, minimal_chunk_length)
    add_special_tokens_at_beginning_and_end_inference(input_id_chunks, mask_chunks)
    add_padding_tokens_inference(input_id_chunks, mask_chunks, chunk_size)
    input_ids, attention_mask = stack_tokens_from_all_chunks_for_inference(input_id_chunks, mask_chunks)
    # Wrap in a BatchEncoding so the return value matches the annotation.
    return BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask})


def add_special_tokens_at_beginning_and_end_inference(input_id_chunks: list[Tensor], mask_chunks: list[Tensor]) -> None:
    """
    Intentionally leaves the chunks unchanged: on the Neptune inference path
    the MULT/CLS/SEP special tokens are added later, in prepare_text(), when
    two chunks are packed into one sequence.
    """
    for i in range(len(input_id_chunks)):
        input_id_chunks[i] = torch.cat([input_id_chunks[i]])
        mask_chunks[i] = torch.cat([mask_chunks[i]])


def add_padding_tokens_inference(input_id_chunks: list[Tensor], mask_chunks: list[Tensor], chunk_size: int) -> None:
    """Adds padding tokens (token id = 0) at the end so every chunk has exactly chunk_size tokens."""
    pad_token_id = 0
    for i in range(len(input_id_chunks)):
        pad_len = chunk_size - input_id_chunks[i].shape[0]
        if pad_len > 0:
            input_id_chunks[i] = torch.cat([input_id_chunks[i], torch.tensor([pad_token_id] * pad_len)])
            mask_chunks[i] = torch.cat([mask_chunks[i], torch.tensor([0] * pad_len)])


def prepare_text(tokens_splitted: BatchEncoding):
    batch = []
    sentences = []
    input_ids_list = tokens_splitted['input_ids']

    # Pack chunks two at a time into one BERT-style sequence:
    # [CLS] [MULT] chunk_a [SEP] [MULT] chunk_b [SEP]
    for i in range(0, len(input_ids_list), 2):
        k = i + 1
        if k == len(input_ids_list):
            # Odd number of chunks: the last sequence gets an empty second segment.
            input_ids_a = input_ids_list[i]
            input_ids_a = [token for token in input_ids_a.view(-1).tolist() if token != pad_token_id]
            input_ids_b = []
            input_ids = [cls_token_id] + [mult_token_id] + input_ids_a + [sep_token_id] + [mult_token_id] + input_ids_b + [sep_token_id]
            text_input_a = tokenizer.decode(input_ids_a)
            sentences.append(text_input_a)
            segment_ids = [0] * (1 + 1 + len(input_ids_a) + 1) + [1] * (1 + len(input_ids_b) + 1)
        else:
            input_ids_a = input_ids_list[i]
            input_ids_b = input_ids_list[k]
            input_ids_a = [token for token in input_ids_a.view(-1).tolist() if token != pad_token_id]
            input_ids_b = [token for token in input_ids_b.view(-1).tolist() if token != pad_token_id]
            input_ids = [cls_token_id] + [mult_token_id] + input_ids_a + [sep_token_id] + [mult_token_id] + input_ids_b + [sep_token_id]
            segment_ids = [0] * (1 + 1 + len(input_ids_a) + 1) + [1] * (1 + len(input_ids_b) + 1)
            text_input_a = tokenizer.decode(input_ids_a)
            text_input_b = tokenizer.decode(input_ids_b)
            sentences.append(text_input_a)
            sentences.append(text_input_b)

        # Mask ~15% of the non-special tokens; the model expects masked_pos
        # even at inference time.
        n_pred = min(max_pred, max(1, int(round(len(input_ids) * 0.15))))
        cand_masked_pos = [idx for idx, token in enumerate(input_ids) if token not in [cls_token_id, sep_token_id, mult_token_id]]
        shuffle(cand_masked_pos)
        masked_tokens, masked_pos = [], []
        for pos in cand_masked_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            input_ids[pos] = tokenizer.mask_token_id

        # Pad the sequence, the segment ids, and the mask bookkeeping.
        n_pad = maxlen - len(input_ids)
        input_ids.extend([pad_token_id] * n_pad)
        segment_ids.extend([0] * n_pad)

        if max_pred > n_pred:
            n_pad = max_pred - n_pred
            masked_tokens.extend([0] * n_pad)
            masked_pos.extend([0] * n_pad)

        batch.append([input_ids, segment_ids, masked_pos])
    return batch, sentences
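

# Illustrative layout of one packed sample (token ids are schematic):
#   input_ids:   [CLS] [MULT] a1 .. aN [SEP] [MULT] b1 .. bM [SEP] [PAD] ...
#   segment_ids:   0     0    0 ..  0    0     1    1 ..  1    1     0  ...
# with ~15% of the a/b tokens replaced by [MASK] and recorded in masked_pos.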


def inference(text: str):
    encoded_text = transform_for_inference_text(text, tokenizer, 125, 125, 1, 2550)
    batch, sentences = prepare_text(encoded_text)
    return batch, sentences


def predict(inference_batch, neptune, device=device):
    all_preds_mult1 = []
    neptune.eval()
    with torch.no_grad():
        for batch in inference_batch:
            input_ids = torch.tensor(batch[0], device=device, dtype=torch.long).unsqueeze(0)
            segment_ids = torch.tensor(batch[1], device=device, dtype=torch.long).unsqueeze(0)
            masked_pos = torch.tensor(batch[2], device=device, dtype=torch.long).unsqueeze(0)
            _, _, logits_mclsf1, logits_mclsf2 = neptune(input_ids, segment_ids, masked_pos)
            # One prediction per packed chunk: head 1 covers segment A, head 2 segment B.
            preds_mult1 = torch.argmax(logits_mclsf1, dim=1).cpu().detach().numpy()
            preds_mult2 = torch.argmax(logits_mclsf2, dim=1).cpu().detach().numpy()

            all_preds_mult1.extend(preds_mult1)
            all_preds_mult1.extend(preds_mult2)

    return all_preds_mult1


def align_predictions_with_sentences(sentences, preds):
    dc = {}
    for sentence, pred in zip(sentences, preds):
        dc[sentence] = class_labels.get(pred, "Unknown")
    return dc
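

# End-to-end sketch for the Neptune path (illustrative):
#
#   batch, sentences = inference(article_text)
#   preds = predict(batch, neptune)
#   per_sentence = align_predictions_with_sentences(sentences, preds)
#
# zip() in align_predictions_with_sentences silently drops the extra head-2
# prediction produced when the last packed sample has an empty second segment.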


def predict_fr_class(text, model):
    # Wrap the text in a list: transform_list_of_fr_texts expects list[str].
    inputs = transform_list_of_fr_texts([text], fr_tokenizer, 126, 30, 1, 2550)

    input_ids_tensor = inputs["input_ids"][0]
    attention_mask_tensor = inputs["attention_mask"][0]

    model.eval()
    with torch.no_grad():
        outputs = model(input_ids=input_ids_tensor, attention_mask=attention_mask_tensor)

    # Average over chunks, mirroring the English predict_class path.
    probabilities = torch.softmax(outputs.logits, dim=1).mean(dim=0)

    predicted_class_index = torch.argmax(probabilities).item()
    predicted_class = fr_class_labels[predicted_class_index]

    percentages = {fr_class_labels[idx]: probabilities[idx].item() * 100 for idx in range(len(fr_class_labels))}
    sorted_percentages = dict(sorted(percentages.items(), key=lambda item: item[1], reverse=True))

    return predicted_class, sorted_percentages
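

# Illustrative usage for the French pipeline (assumes load_fr_models() succeeded
# and "clip.wav" is a short French audio file you supply):
#
#   fr_model, fr_neptune, processor, w2v2_model = load_fr_models()
#   transcript = transcribe_speech("clip.wav", processor, w2v2_model)
#   predicted_class, percentages = predict_fr_class(transcript, fr_model)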


def prepare_fr_text(tokens_splitted: BatchEncoding):
    batch = []
    sentences = []
    input_ids_list = tokens_splitted['input_ids']

    for i in range(0, len(input_ids_list), 2):
        k = i + 1
        if k == len(input_ids_list):
            # Odd number of chunks: the last sequence gets an empty second segment.
            input_ids_a = input_ids_list[i]
            # Strip CamemBERT padding (fr_pad_token_id, not the English pad id).
            input_ids_a = [token for token in input_ids_a.view(-1).tolist() if token != fr_pad_token_id]
            input_ids_b = []
            input_ids = [fr_cls_token_id] + [fr_mult_token_id] + input_ids_a + [fr_sep_token_id] + [fr_mult_token_id] + input_ids_b + [fr_sep_token_id]
            text_input_a = fr_tokenizer.decode(input_ids_a, skip_special_tokens=True)
            sentences.append(text_input_a)
            segment_ids = [0] * (1 + 1 + len(input_ids_a) + 1) + [1] * (1 + len(input_ids_b) + 1)
        else:
            input_ids_a = input_ids_list[i]
            input_ids_b = input_ids_list[k]
            input_ids_a = [token for token in input_ids_a.view(-1).tolist() if token != fr_pad_token_id]
            input_ids_b = [token for token in input_ids_b.view(-1).tolist() if token != fr_pad_token_id]
            input_ids = [fr_cls_token_id] + [fr_mult_token_id] + input_ids_a + [fr_sep_token_id] + [fr_mult_token_id] + input_ids_b + [fr_sep_token_id]
            segment_ids = [0] * (1 + 1 + len(input_ids_a) + 1) + [1] * (1 + len(input_ids_b) + 1)
            text_input_a = fr_tokenizer.decode(input_ids_a, skip_special_tokens=True)
            text_input_b = fr_tokenizer.decode(input_ids_b, skip_special_tokens=True)
            sentences.append(text_input_a)
            sentences.append(text_input_b)

        # Mask ~15% of the non-special tokens, as in prepare_text().
        n_pred = min(max_pred, max(1, int(round(len(input_ids) * 0.15))))
        cand_masked_pos = [idx for idx, token in enumerate(input_ids) if token not in [fr_cls_token_id, fr_sep_token_id, fr_mult_token_id]]
        shuffle(cand_masked_pos)
        masked_tokens, masked_pos = [], []
        for pos in cand_masked_pos[:n_pred]:
            masked_pos.append(pos)
            masked_tokens.append(input_ids[pos])
            input_ids[pos] = fr_tokenizer.mask_token_id

        n_pad = maxlen - len(input_ids)
        input_ids.extend([fr_pad_token_id] * n_pad)
        segment_ids.extend([0] * n_pad)

        if max_pred > n_pred:
            n_pad = max_pred - n_pred
            masked_tokens.extend([0] * n_pad)
            masked_pos.extend([0] * n_pad)

        batch.append([input_ids, segment_ids, masked_pos])
    return batch, sentences


def fr_inference(text: str):
    encoded_text = transform_for_inference_fr_text(text, fr_tokenizer, 125, 125, 1, 2550)
    batch, sentences = prepare_fr_text(encoded_text)
    return batch, sentences


def align_fr_predictions_with_sentences(sentences, preds):
    dc = {}
    for sentence, pred in zip(sentences, preds):
        dc[sentence] = fr_class_labels.get(pred, "Unknown")
    return dc


def transform_for_inference_fr_text(text: str,
                                    tokenizer: PreTrainedTokenizerBase,
                                    chunk_size: int,
                                    stride: int,
                                    minimal_chunk_length: int,
                                    maximal_text_length: Optional[int]) -> BatchEncoding:
    if maximal_text_length:
        tokens = tokenize_text_with_truncation(text, tokenizer, maximal_text_length)
    else:
        tokens = tokenize_whole_text(text, tokenizer)
    input_id_chunks, mask_chunks = split_tokens_into_smaller_chunks(tokens, chunk_size, stride, minimal_chunk_length)
    add_special_tokens_at_beginning_and_end_inference(input_id_chunks, mask_chunks)
    add_padding_fr_tokens_inference(input_id_chunks, mask_chunks, chunk_size)
    input_ids, attention_mask = stack_tokens_from_all_chunks_for_inference(input_id_chunks, mask_chunks)
    # Wrap in a BatchEncoding so the return value matches the annotation.
    return BatchEncoding({"input_ids": input_ids, "attention_mask": attention_mask})


def add_padding_fr_tokens_inference(input_id_chunks: list[Tensor], mask_chunks: list[Tensor], chunk_size: int) -> None:
    """Adds CamemBERT padding tokens (fr_pad_token_id) at the end so every chunk has exactly chunk_size tokens."""
    for i in range(len(input_id_chunks)):
        pad_len = chunk_size - input_id_chunks[i].shape[0]
        if pad_len > 0:
            # fr_pad_token_id replaces the hardcoded literal 1 (CamemBERT's pad id).
            input_id_chunks[i] = torch.cat([input_id_chunks[i], torch.tensor([fr_pad_token_id] * pad_len)])
            mask_chunks[i] = torch.cat([mask_chunks[i], torch.tensor([0] * pad_len)])


def transform_list_of_fr_texts(
    texts: list[str],
    tokenizer: PreTrainedTokenizerBase,
    chunk_size: int,
    stride: int,
    minimal_chunk_length: int,
    maximal_text_length: Optional[int] = None,
) -> BatchEncoding:
    model_inputs = [
        transform_single_fr_text(text, tokenizer, chunk_size, stride, minimal_chunk_length, maximal_text_length)
        for text in texts
    ]
    input_ids = [model_input[0] for model_input in model_inputs]
    attention_mask = [model_input[1] for model_input in model_inputs]
    tokens = {"input_ids": input_ids, "attention_mask": attention_mask}
    return BatchEncoding(tokens)


def transform_single_fr_text(
    text: str,
    tokenizer: PreTrainedTokenizerBase,
    chunk_size: int,
    stride: int,
    minimal_chunk_length: int,
    maximal_text_length: Optional[int],
) -> tuple[Tensor, Tensor]:
    """Transforms (the entire) text to model input of the CamemBERT classifier."""
    if maximal_text_length:
        tokens = tokenize_text_with_truncation(text, tokenizer, maximal_text_length)
    else:
        tokens = tokenize_whole_text(text, tokenizer)
    input_id_chunks, mask_chunks = split_tokens_into_smaller_chunks(tokens, chunk_size, stride, minimal_chunk_length)
    add_fr_special_tokens_at_beginning_and_end(input_id_chunks, mask_chunks)
    add_padding_tokens(input_id_chunks, mask_chunks, chunk_size)
    input_ids, attention_mask = stack_tokens_from_all_chunks(input_id_chunks, mask_chunks)
    return input_ids, attention_mask


def add_fr_special_tokens_at_beginning_and_end(input_id_chunks: list[Tensor], mask_chunks: list[Tensor]) -> None:
    """
    Adds CamemBERT's special <s> token (token id = 5) at the beginning and the
    </s> token (token id = 6) at the end of each chunk, with an attention mask
    of 1 for both added positions.
    """
    for i in range(len(input_id_chunks)):
        input_id_chunks[i] = torch.cat([torch.tensor([5]), input_id_chunks[i], torch.tensor([6])])
        mask_chunks[i] = torch.cat([torch.tensor([1]), mask_chunks[i], torch.tensor([1])])


def transcribe_speech(audio_path, wav2vec2_processor, wav2vec2_model):
    logging.info(f"Starting transcription of {audio_path}")

    try:
        waveform, sample_rate = torchaudio.load(audio_path)
        waveform = waveform.squeeze().numpy()
        logging.info(f"Audio loaded with torchaudio. Shape: {waveform.shape}, Sample rate: {sample_rate}")
    except Exception as e:
        logging.warning(f"torchaudio failed to load the audio. Trying with soundfile. Error: {str(e)}")
        try:
            # soundfile returns (frames, channels); transpose to (channels, frames)
            # and stay in NumPy so both loaders feed the same code below.
            waveform, sample_rate = sf.read(audio_path)
            waveform = waveform.T
            logging.info(f"Audio loaded with soundfile. Shape: {waveform.shape}, Sample rate: {sample_rate}")
        except Exception as e:
            logging.error(f"Both torchaudio and soundfile failed to load the audio. Error: {str(e)}")
            raise ValueError("Unable to load the audio file.")

    # Downmix multi-channel audio to mono.
    if waveform.ndim > 1:
        waveform = np.mean(waveform, axis=0)
        logging.info(f"Waveform reduced to 1D. New shape: {waveform.shape}")

    # Resample to the rate the Wav2Vec2 feature extractor expects (16 kHz).
    if sample_rate != wav2vec2_processor.feature_extractor.sampling_rate:
        resampler = torchaudio.transforms.Resample(sample_rate, wav2vec2_processor.feature_extractor.sampling_rate)
        waveform = resampler(torch.from_numpy(waveform).float())
        logging.info(f"Audio resampled to {wav2vec2_processor.feature_extractor.sampling_rate}Hz")

    try:
        input_values = wav2vec2_processor(waveform, sampling_rate=wav2vec2_processor.feature_extractor.sampling_rate, return_tensors="pt").input_values
        logging.info(f"Input values shape after processing: {input_values.shape}")
    except Exception as e:
        logging.error(f"Error during audio processing: {str(e)}")
        raise

    # Normalize to a (batch, time) tensor regardless of how the processor
    # shaped its output.
    input_values = input_values.squeeze()
    if input_values.dim() == 0:
        input_values = input_values.unsqueeze(0).unsqueeze(0)
    elif input_values.dim() == 1:
        input_values = input_values.unsqueeze(0)
    logging.info(f"Final input values shape: {input_values.shape}")

    try:
        with torch.inference_mode():
            logits = wav2vec2_model(input_values.to(device)).logits
        logging.info(f"Model inference successful. Logits shape: {logits.shape}")
    except Exception as e:
        logging.error(f"Error during model inference: {str(e)}")
        raise

    # Greedy CTC decoding: take the most likely token per frame, then collapse.
    predicted_ids = torch.argmax(logits, dim=-1)
    predicted_sentence = wav2vec2_processor.batch_decode(predicted_ids)
    logging.info(f"Transcription complete. Result: {predicted_sentence[0]}")
    return predicted_sentence[0]
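

# Minimal smoke test (illustrative sketch, not part of the serving code).
# It assumes the checkpoint directories referenced above exist locally and
# that "sample.wav" is a short French audio clip supplied by the user.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    fr_model, fr_neptune, fr_processor, fr_w2v2 = load_fr_models()

    transcript = transcribe_speech("sample.wav", fr_processor, fr_w2v2)
    print("Transcript:", transcript)

    predicted_fr_class, fr_percentages = predict_fr_class(transcript, fr_model)
    print("Predicted class:", predicted_fr_class)
    print("Class percentages:", fr_percentages)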