---
license: cc-by-sa-3.0
language:
- de
---

# xLSTM Model trained on German Wikipedia

![xLSTM](brat-logo.png)

Research & development of an xLSTM model trained on German Wikipedia.

The Flair team is currently working on the integration of xLSTM (both LM training and fine-tuning models for downstream tasks).

For pretraining this xLSTM model, we use this [fork](https://github.com/HallerPatrick/helibrunna) (from [Patrick Haller](https://huggingface.co/PatrickHaller)) of the awesome [Helibrunna](https://github.com/AI-Guru/helibrunna) library from [Tristan](https://huggingface.co/TristanBehrens).

Initially, we integrated xLSTM model training into Flair - for more information about this, please refer to the archived [flair-old](https://huggingface.co/stefan-it/xlstm-german-wikipedia/blob/flair-old/README.md) branch of this repository.

# Changelog

- 29.08.2024: Uploaded a model re-trained for 1 epoch over the complete German Wikipedia corpus. Training was done with gradient clipping (0.25).
- 28.08.2024: Model training is now done with a fork of [Helibrunna](https://github.com/AI-Guru/helibrunna) - find it [here](https://github.com/HallerPatrick/helibrunna).
- 10.06.2024: Initial version. xLSTM was trained with the Flair library, see this [old](https://huggingface.co/stefan-it/xlstm-german-wikipedia/blob/flair-old/README.md) branch.

# Training

The current model was trained with commit `a1b3772` from the [`main` branch](https://github.com/HallerPatrick/helibrunna) of the forked Helibrunna repo.

The `xlstm` [library](https://github.com/NX-AI/xlstm) needs to be installed manually - also make sure that Ninja is installed (`pip3 install Ninja`).
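
Before starting a training run, a quick check along these lines can confirm that both prerequisites are visible to the current Python environment. This is only a minimal sketch: the helper name `check_prerequisites` is hypothetical, and it assumes that the library is importable as `xlstm` and that the Ninja package places a `ninja` executable on the `PATH`.

```python
import importlib.util
import shutil


def check_prerequisites():
    """Report whether the xlstm package and the ninja build tool can be found."""
    return {
        # Assumption: the xLSTM library is importable under the name "xlstm".
        "xlstm": importlib.util.find_spec("xlstm") is not None,
        # Assumption: `pip3 install Ninja` places a `ninja` executable on PATH.
        "ninja": shutil.which("ninja") is not None,
    }


if __name__ == "__main__":
    for name, found in check_prerequisites().items():
        print(f"{name}: {'found' if found else 'missing'}")
```

If either entry is reported as missing, install it before launching the training script.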

The German Wikipedia dump from [this repository](https://huggingface.co/datasets/gwlms/dewiki-20230701-flair-corpus) is used.

The following training configuration is used:

```yaml
description: "Train a wikipedia xLSTM"

training:
  model_name: "german_wikipedia"
  batch_size: 10
  lr: 6e-4
  lr_warmup_steps: 4584
  lr_decay_until_steps: "auto"
  lr_decay_factor: 0.001
  weight_decay: 0.1
  amp_precision: bfloat16
  weight_precision: float32
  enable_mixed_precision: true
  num_epochs: 1
  output_dir: "./output"
  save_every_step: 2000
  log_every_step: 10
  generate_every_step: 5000
  wandb_project: "xlstm"
  max_grad_norm: 0.25
  # wandb_project: "lovecraftxlstm"

model:
  num_blocks: 24
  embedding_dim: 768
  mlstm_block:
    mlstm:
      num_heads: 4
  slstm_block: {}
  slstm_at: []
  context_length: 512

dataset:
  output_path: "./output/german-wikipedia-dataset"
  hugging_face_id: ["stefan-it/dewiki-20230701"]
  split: "train" # Also subsetting is possible: "train[:100000]"
  shuffle: False
  seed: 42

tokenizer:
  type: "pretrained"
  pretrained_class: "LlamaTokenizer"
  pretrained_id: "meta-llama/Llama-2-7b-hf"
```
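
To get a feeling for how `lr`, `lr_warmup_steps`, `lr_decay_until_steps` and `lr_decay_factor` interact, the schedule can be sketched in a few lines of Python. This is an illustration under assumptions, not the scheduler used by the Helibrunna fork: it assumes a linear warmup followed by a cosine decay down to `lr * lr_decay_factor`, and it assumes that `"auto"` resolves to the total number of training steps (458,431 for this run). The function name `lr_at_step` is hypothetical.

```python
import math


def lr_at_step(step, peak_lr=6e-4, warmup_steps=4584,
               decay_until_steps=458_431, decay_factor=0.001):
    """Learning rate at a given step (assumed: linear warmup + cosine decay)."""
    min_lr = peak_lr * decay_factor
    if step < warmup_steps:
        # Linear ramp from 0 up to the peak learning rate.
        return peak_lr * step / warmup_steps
    if step >= decay_until_steps:
        # After the decay horizon the rate stays at its floor.
        return min_lr
    # Cosine decay from peak_lr (progress = 0) to min_lr (progress = 1).
    progress = (step - warmup_steps) / (decay_until_steps - warmup_steps)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1 + math.cos(math.pi * progress))


if __name__ == "__main__":
    for step in (0, 2292, 4584, 230_000, 458_431):
        print(f"step {step:>7}: lr = {lr_at_step(step):.3e}")
```

With the values from the configuration, the warmup covers roughly the first 1% of the steps, and the rate ends at 0.1% of its peak.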

The training loss curve can be seen here:

![Training Loss](training-loss.png)

The uploaded model checkpoint is from 458,431 steps (1 epoch over the corpus). Training took 1d 3h 17m 58s on a single RTX 4090.
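
These numbers can be turned into a rough throughput estimate. The sketch below assumes that `batch_size` counts sequences per optimizer step and that every sequence is filled to the full `context_length`, so the token figures are upper bounds and not measured values.

```python
from datetime import timedelta

# Values taken from the training configuration and the run described above.
steps = 458_431
wall_clock = timedelta(days=1, hours=3, minutes=17, seconds=58)
batch_size = 10
context_length = 512

total_seconds = wall_clock.total_seconds()
steps_per_second = steps / total_seconds

# Assumption: each step processes batch_size full-length sequences.
tokens_per_step = batch_size * context_length
max_tokens_seen = steps * tokens_per_step

print(f"training time:       {total_seconds:.0f} s")
print(f"steps per second:    {steps_per_second:.2f}")
print(f"tokens per step:     {tokens_per_step} (upper bound)")
print(f"tokens in one epoch: {max_tokens_seen:,} (upper bound)")
```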

# Usage

It is possible to use the model to generate some text:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name_or_path = "stefan-it/xlstm-german-wikipedia"

# Load the model weights and the matching tokenizer from the Hugging Face Hub
model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

# Encode a German prompt and sample a continuation of up to 100 tokens (prompt included)
input_ids = tokenizer.encode("Heute ist schönes Wetter in", return_tensors="pt")
output = model.generate(input_ids, max_length=100, temperature=0.7, do_sample=True)

# Convert the generated token ids back into text
generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

print(generated_text)
```

# Caveats

Notice: this model integration is under heavy development, and we are still in the process of finding good hyper-parameters.
Downstream experiments are also coming very soon.