Transformers
PyTorch
Russian
bert
Inference Endpoints

RuTextSegModel

Model for Russian text segmentation, trained on wiki and news corpora

Model description

This model is a top-level part of HierBERT model and solves the problem of text segmentation as a token classification at the sentence level. The ai-forever/sbert_large_nlu_ru with max pooling is used as a low-level model (sentence embedding generator). It's recommended to use this model only with specified low-level model with defined pooling for embeddings.

Intended uses & limitations

How to use

Here is how to use this model in PyTorch:

import torch
import torch.nn as nn
from transformers import BertForTokenClassification, AutoModel, AutoTokenizer
from razdel import sentenize

class BertForTextSegmentationEmbeddings(nn.Module):
    def __init__(self, config, embeddings_dim=768):
        super(BertForTextSegmentationEmbeddings, self).__init__()

        self.config = config
        self.position_embeddings = torch.nn.Embedding(config.max_position_embeddings, config.hidden_size)

        self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = torch.nn.Dropout(config.hidden_dropout_prob)

    def forward(self, inputs_embeds, position_ids=None, input_ids=None, token_type_ids=None, past_key_values_length=None):
        input_shape = inputs_embeds.size()[:-1]
        seq_length = input_shape[1]
        device = inputs_embeds.device

        assert seq_length <= self.config.max_position_embeddings, \
            f"Too long sequence is passed {seq_length}. Maximum allowed sequence length is {self.config.max_position_embeddings}"

        if position_ids is None:
            position_ids = torch.arange(seq_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0).expand(input_shape)

        position_embeddings = self.position_embeddings(position_ids)

        embeddings = inputs_embeds + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings

class BertForTextSegmentation(BertForTokenClassification):
    def __init__(self, config):
        super(BertForTextSegmentation, self).__init__(config)

        self.bert.base_model.embeddings = BertForTextSegmentationEmbeddings(config)

        self.init_weights()
        
def max_pooling(model_output, attention_mask):
    token_embeddings = model_output[0] #First element of model_output contains all token embeddings
    input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
    token_embeddings[input_mask_expanded == 0] = -1e9  # Set padding tokens to large negative value
    return torch.max(token_embeddings, 1)[0]

def create_embeddings(sentences, tokenizer, model):
    # Tokenize sentences
    encoded_input = tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
    # Compute token embeddings
    with torch.no_grad():
        model_output = model(**encoded_input.to(device))
    # Perform pooling. In this case, max pooling.
    sentence_embeddings = max_pooling(model_output, encoded_input['attention_mask'])
    
    return sentence_embeddings

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

emb_tokenizer = AutoTokenizer.from_pretrained("ai-forever/sbert_large_nlu_ru")
emb_model = AutoModel.from_pretrained("ai-forever/sbert_large_nlu_ru")
model = BertForTextSegmentation.from_pretrained("mlenjoyneer/RuTextSegModel")

emb_model.to(device)
model.to(device)

text = """В Норильске за годы работы телефона доверия консультанты приняли в общей сложности порядка 75 тысяч обращений, сообщает «Заполярная Правда». Служба психологической помощи появилась в 2000 году. Руководитель службы профилактики наркомании Елена Слатвицкая рассказала журналистам, что в Заполярье настал период, когда ухудшается психо– эмоциональное состояние населения. Это происходит на входе в полярную ночь и на выходе из нее. Осень является кризисным моментом. Сейчас на телефоне доверия работают 15 специалистов. Каждый — под своим псевдонимом. Тему беседы определяет звонящий. Это могут быть наркомания и алкоголизм, ВИЧ–инфекция и прочие заболевания и зависимости, кризисы семейных отношений и многое другое. Сотрудники службы отметили, что больше стало звонков по поводу суицидальных намерений. Наибольшее количество обращений по суицидам пришлось на октябрь — ноябрь. Много звонков как от мужчин, так и от женщин с вопросами об одиночестве. Лидерами по количеству обращений пока остаются женщины. В сентябре в Норильске обнаружили тело девятиклассницы. По версии следствия, девочка сбросилась с крыши. В январе подросток нанес себе порезы стеклом от разбитой бутылки, пытаясь покончить с собой. Мальчик поссорился с матерью и в ходе ссоры нанес себе несколько порезов. Проводится расследование."""

input_embeds = create_embeddings([s.text for s in sentenize(text)], emb_tokenizer, emb_model).unsqueeze(0)
outputs = model(inputs_embeds=input_embeds)

logits = outputs.logits.cpu()
preds = logits.argmax(axis=2).tolist()[0]  # [0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0]

# true_labels = [0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0]

Training data

Model trained on mlenjoyneer/RuTextSegNews and mlenjoyneer/RuTextSegWiki datasets.

Evaluation results

Train Dataset Test Dataset F1_total F1_1 Pk Pk_5 WinDiff WinDiff_5
News+Wiki News 0.88 0.80 0.16 0.11 0.20 0.35
News+Wiki Wiki 0.89 0.80 0.18 0.16 0.09 0.19

Citation info

In progress
Downloads last month
33
Inference Providers NEW
This model is not currently available via any of the supported Inference Providers.
The model cannot be deployed to the HF Inference API: The model has no pipeline_tag.

Datasets used to train mlenjoyneer/RuTextSegModel