import math
from typing import List

import pandas as pd
import torch
from tqdm import trange
from transformers import AutoModel, AutoTokenizer

from src.feature_extractors.base_extractor import BaseExtractor


class BertPretrainFeatureExtractor(BaseExtractor):
    """Extract the [CLS] token embedding from any (untrained) BERT-like model."""

    # Resolved once at class-definition time; shared by all instances.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    def __init__(self, model_name: str, max_length: int = 512, batch_size: int = 64):
        """
        :param model_name: HuggingFace model identifier passed to ``from_pretrained``.
        :param max_length: maximum token length; longer texts are truncated.
        :param batch_size: number of texts encoded per forward pass.
        """
        super().__init__()  # consistent with ManyBertPretrainFeatureExtractor
        self.model_name = model_name
        self.max_length = max_length
        self.batch_size = batch_size
        self.model = AutoModel.from_pretrained(self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)

    @torch.no_grad()
    def generate_features(self, data: pd.Series) -> pd.DataFrame:
        """
        Generate [CLS]-embedding features in batch mode from the BERT model.

        :param data: Series of raw text strings (one document per element).
        :return: DataFrame indexed like ``data``, one column per embedding
                 dimension, named ``{model_name}_feat_{i}``.
        """
        torch.cuda.empty_cache()
        texts = data.tolist()
        self.model = self.model.to(self.device)
        self.model.eval()  # make dropout-off explicit so features are deterministic

        # Fix: the original passed total=len(data)//batch_size + 1, which
        # over-counts by one batch when len(data) is an exact multiple of
        # batch_size; ceil gives the true batch count.
        n_batches = math.ceil(len(data) / self.batch_size)

        classification_outputs = []
        for start in trange(
                0, len(data), self.batch_size,
                total=n_batches,
                desc="Generating bert features...",
        ):
            text_batch = texts[start: start + self.batch_size]
            batch_encoded = self.tokenizer(
                text_batch,
                padding=True,
                truncation=True,
                max_length=self.max_length,
                return_tensors='pt',
            ).to(self.device)
            output = self.model(**batch_encoded)
            # [CLS] is always the first token of every encoded sequence.
            cls_output = output['last_hidden_state'][:, 0].cpu()
            classification_outputs.append(cls_output)

        # Free GPU memory before assembling the result on the CPU side.
        self.model = self.model.to("cpu")
        torch.cuda.empty_cache()

        # Robustness: torch.cat raises on an empty list, so handle an empty
        # input Series explicitly, reading the feature width from the config.
        if classification_outputs:
            features = torch.cat(classification_outputs, dim=0)
            n_features = features.shape[1]
            values = features.tolist()
        else:
            n_features = self.model.config.hidden_size
            values = []

        column_names = [f"{self.model_name}_feat_{i}" for i in range(n_features)]
        return pd.DataFrame(data=values, index=data.index, columns=column_names)


class ManyBertPretrainFeatureExtractor(BaseExtractor):
    """Concatenate [CLS] features from several BERT-like models column-wise."""

    def __init__(self, model_names: List[str], max_length: int = 512, batch_size: int = 64):
        """
        :param model_names: HuggingFace identifiers of the models to use.
        :param max_length: maximum token length forwarded to each extractor.
        :param batch_size: batch size forwarded to each extractor.
        """
        super().__init__()
        self.model_names = model_names
        self.max_length = max_length
        self.batch_size = batch_size

    def generate_features(self, X: pd.Series) -> pd.DataFrame:
        """
        :param X: Series of raw text strings.
        :return: column-wise concatenation of each model's feature DataFrame
                 (column names are prefixed per model, so they do not collide
                 unless the same model name is listed twice).
        """
        extractors = [
            BertPretrainFeatureExtractor(model_name, self.max_length, self.batch_size)
            for model_name in self.model_names
        ]
        dataframes = [extractor.generate_features(X) for extractor in extractors]
        return pd.concat(dataframes, axis='columns')