---
license: cc-by-sa-3.0
language:
  - de
library_name: flair
---

# Flair xLSTM Embeddings (German Wikipedia, Forward)

Research & development of Flair xLSTM Embeddings (Forward), trained on a German Wikipedia dump.

The Flair team is currently working on the integration of xLSTM (both language model training and fine-tuning models for downstream tasks). Check out the `xlstm` branch in the Flair repository; many thanks to Patrick Haller for his work on it!

## Training

The current model was trained with commit `18ef331` from the `xlstm` branch. The `xlstm` library needs to be installed manually; also make sure that Ninja is installed (`pip3 install Ninja`).

The German Wikipedia dump from this repository is used, sharded into a Flair-compatible corpus layout (a minimal sharding sketch follows the list):

* `valid.txt` -> Validation corpus
* `test.txt` -> Test corpus
* `train` -> Folder with text files as the training corpus
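For illustration, here is a minimal sketch of producing that layout. The dump filename (`dewiki.txt`), hold-out sizes, and shard size are assumptions for demonstration, not taken from the actual preprocessing:

```python
from pathlib import Path

# Target layout expected by Flair's TextCorpus (used in the training script below).
corpus_dir = Path("/home/ubuntu/splitted_corpus")
(corpus_dir / "train").mkdir(parents=True, exist_ok=True)

# Hypothetical plain-text dump, one sentence/paragraph per line.
lines = Path("dewiki.txt").read_text(encoding="utf-8").splitlines(keepends=True)

# Hold out small validation and test slices (sizes are arbitrary here).
(corpus_dir / "valid.txt").write_text("".join(lines[:10_000]), encoding="utf-8")
(corpus_dir / "test.txt").write_text("".join(lines[10_000:20_000]), encoding="utf-8")

# Shard the remainder into the train/ folder.
train_lines = lines[20_000:]
shard_size = 100_000  # lines per shard -- an assumption
for i in range(0, len(train_lines), shard_size):
    shard = corpus_dir / "train" / f"train_split_{i // shard_size:04d}.txt"
    shard.write_text("".join(train_lines[i : i + shard_size]), encoding="utf-8")
```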

The model was trained with the following parameters for 2 epochs:

```python
import flair
import torch

from flair.data import SubTokenDictionary
from flair.models import xLSTMLanguageModel
from flair.trainers.language_model_trainer import LanguageModelTrainer, TextCorpus

flair.device = torch.device('cuda:0')

# Train a forward language model (a backward model would set this to False).
is_forward_lm = True

# Subword dictionary from the GWLMS German BERT tokenizer.
dictionary = SubTokenDictionary.load("gwlms/bert-base-dewiki-v1")

corpus = TextCorpus("/home/ubuntu/splitted_corpus",
                    dictionary,
                    is_forward_lm,
                    character_level=False,
                    random_case_flip=True,
                    )

# xLSTM architecture configuration ("ablation 1"): 7 blocks with one
# sLSTM block at position 1, mLSTM blocks everywhere else.
xlstm_ablation_1 = """
mlstm_block:
  mlstm:
    conv1d_kernel_size: 2
    qkv_proj_blocksize: 2
    num_heads: 2
slstm_block:
  slstm:
    backend: cuda
    num_heads: 2
    conv1d_kernel_size: 2
    bias_init: powerlaw_blockdependent
  feedforward:
    proj_factor: 1.3
    act_fn: gelu
context_length: 256
num_blocks: 7
embedding_dim: 128
slstm_at: [1]
"""

language_model = xLSTMLanguageModel(dictionary, xlstm_cfg=xlstm_ablation_1,
                                    is_forward_lm=is_forward_lm)
print(language_model)

trainer = LanguageModelTrainer(language_model, corpus)

trainer.train("xflair-german-wikipedia-xlstm_ablation_1-bs64-lr5-e2",
              sequence_length=256,
              mini_batch_size=64,
              learning_rate=5,
              patience=50,
              max_epochs=2,
              checkpoint=False,
              num_workers=4,
              )
```
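After training, the model can be used like any other Flair language model. A hedged usage sketch, assuming the `xlstm` branch saves a `best-lm.pt` checkpoint in the output folder (as Flair's `LanguageModelTrainer` normally does) and that `FlairEmbeddings` can wrap the resulting checkpoint:

```python
from flair.data import Sentence
from flair.embeddings import FlairEmbeddings

# Load the trained LM as contextual string embeddings; the checkpoint
# path is an assumption based on Flair's usual output layout.
embeddings = FlairEmbeddings(
    "xflair-german-wikipedia-xlstm_ablation_1-bs64-lr5-e2/best-lm.pt"
)

sentence = Sentence("Berlin ist die Hauptstadt von Deutschland .")
embeddings.embed(sentence)

for token in sentence:
    print(token.text, token.embedding.shape)
```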

## Caveats

Notice: this model integration is heavily under development, and we are still in the process of finding good hyper-parameters. Downstream experiments are coming very soon.