Dekode committed on
Commit
9125950
1 Parent(s): bae8925

Upload 10 files

Files changed (10)
  1. app.py +166 -0
  2. config.py +33 -0
  3. dataset.py +90 -0
  4. model.py +267 -0
  5. predict.py +16 -0
  6. requirements.txt +16 -0
  7. tokenizer_en.json +0 -0
  8. tokenizer_it.json +0 -0
  9. train.py +283 -0
  10. translate.py +79 -0
app.py ADDED
@@ -0,0 +1,166 @@
import streamlit as st
from transformers import AutoTokenizer, AutoModel, utils
from bertviz import model_view
import streamlit.components.v1 as components
from train import get_or_build_tokenizer, greedy_decode
from config import get_config, latest_weights_file_path
from model import build_transformer
import torch
import altair as alt
import pandas as pd
import warnings
warnings.filterwarnings("ignore")

utils.logging.set_verbosity_error()  # Suppress standard warnings

st.set_page_config(page_title='Attention Visualizer', layout='wide')

def mtx2df(m, max_row, max_col, row_tokens, col_tokens):
    return pd.DataFrame(
        [
            (
                r,
                c,
                float(m[r, c]),
                "%.2d - %s" % (r, row_tokens[r] if len(row_tokens) > r else "<blank>"),
                "%.2d - %s" % (c, col_tokens[c] if len(col_tokens) > c else "<blank>"),
            )
            for r in range(m.shape[0])
            for c in range(m.shape[1])
            if r < max_row and c < max_col
        ],
        columns=["row", "column", "value", "row_token", "col_token"],
    )


def get_attn_map(attn_type: str, layer: int, head: int, model):
    if attn_type == "encoder":
        attn = model.encoder.layers[layer].self_attention_block.attention_scores
    elif attn_type == "decoder":
        attn = model.decoder.layers[layer].self_attention_block.attention_scores
    elif attn_type == "encoder-decoder":
        attn = model.decoder.layers[layer].cross_attention_block.attention_scores
    return attn[0, head].data

def attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len, model):
    df = mtx2df(
        get_attn_map(attn_type, layer, head, model),
        max_sentence_len,
        max_sentence_len,
        row_tokens,
        col_tokens,
    )
    return (
        alt.Chart(data=df)
        .mark_rect()
        .encode(
            x=alt.X("col_token", axis=alt.Axis(title="")),
            y=alt.Y("row_token", axis=alt.Axis(title="")),
            color=alt.Color("value", scale=alt.Scale(scheme="blues")),
            tooltip=["row", "column", "value", "row_token", "col_token"],
        )
        #.title(f"Layer {layer} Head {head}")
        .properties(height=200, width=200, title=f"Layer {layer} Head {head}")
        .interactive()
    )

def get_all_attention_maps(attn_type: str, layers: list[int], heads: list[int], row_tokens: list, col_tokens, max_sentence_len: int, model):
    charts = []
    for layer in layers:
        rowCharts = []
        for head in heads:
            rowCharts.append(attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len, model))
        charts.append(alt.hconcat(*rowCharts))
    return alt.vconcat(*charts)

def initiate_model(config, device):
    tokenizer_src = get_or_build_tokenizer(config, None, config["lang_src"])
    tokenizer_tgt = get_or_build_tokenizer(config, None, config["lang_tgt"])

    model = build_transformer(tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size(), config["seq_len"], config['seq_len'], d_model=config['d_model']).to(device)

    model_filename = latest_weights_file_path(config)
    state = torch.load(model_filename)
    model.load_state_dict(state['model_state_dict'])
    return model, tokenizer_src, tokenizer_tgt

def process_input(input_text, tokenizer_src, tokenizer_tgt, model, config, device):
    src = tokenizer_src.encode(input_text)
    src = torch.cat([
        torch.tensor([tokenizer_src.token_to_id('[SOS]')], dtype=torch.int64),
        torch.tensor(src.ids, dtype=torch.int64),
        torch.tensor([tokenizer_src.token_to_id('[EOS]')], dtype=torch.int64),
        torch.tensor([tokenizer_src.token_to_id('[PAD]')] * (config['seq_len'] - len(src.ids) - 2), dtype=torch.int64)
    ], dim=0).to(device)
    source_mask = (src != tokenizer_src.token_to_id('[PAD]')).unsqueeze(0).unsqueeze(0).int().to(device)

    encoder_input_tokens = [tokenizer_src.id_to_token(i) for i in src.cpu().numpy()]
    encoder_input_tokens = [i for i in encoder_input_tokens if i != '[PAD]']

    model_out = greedy_decode(model, src, source_mask, tokenizer_src, tokenizer_tgt, config['seq_len'], device)

    decoder_input_tokens = [tokenizer_tgt.id_to_token(i) for i in model_out.cpu().numpy()]

    output = tokenizer_tgt.decode(model_out.detach().cpu().numpy())

    return encoder_input_tokens, decoder_input_tokens, output


# def get_html_data(model_name, input_text):
#     model_name = "microsoft/xtremedistil-l12-h384-uncased"
#     model = AutoModel.from_pretrained(model_name, output_attentions=True, cache_dir='__pycache__')  # Configure model to return attention values
#     tokenizer = AutoTokenizer.from_pretrained(model_name)
#     inputs = tokenizer.encode(input_text, return_tensors='pt')  # Tokenize input text
#     outputs = model(inputs)  # Run model
#     attention = outputs[-1]  # Retrieve attention from model outputs
#     tokens = tokenizer.convert_ids_to_tokens(inputs[0])  # Convert input ids to token strings
#     model_html = model_view(attention, tokens, html_action="return")  # Display model view
#     with open("static/model_view.html", 'w') as file:
#         file.write(model_html.data)

def main():
    st.title('Transformer Visualizer')
    # st.info('Enter a sentence to visualize the attention of the model')
    st.write('This app visualizes the attention of a transformer model on a given sentence.')
    # add a sidebar with model options and a prompt
    config = get_config()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model, tokenizer_src, tokenizer_tgt = initiate_model(config, device)
    with st.sidebar:
        input_text = st.text_input('Enter a sentence')
        # put two buttons side by side in the sidebar
        # translate_button = st.button('Translate', key='translate_button')
        viz_button = st.button('Visualize Attention', key='viz_button')
        attn_type = st.selectbox('Select attention type', ['encoder', 'decoder', 'encoder-decoder'])
        layers = st.multiselect('Select layers', list(range(3)))
        heads = st.multiselect('Select heads', list(range(7)))
        # allow the user to select all the layers and heads at once to visualize
        if st.checkbox('Select all layers'):
            layers = list(range(3))
        if st.checkbox('Select all heads'):
            heads = list(range(7))

    if viz_button and input_text != '':
        encoder_input_tokens, decoder_input_tokens, output = process_input(input_text, tokenizer_src, tokenizer_tgt, model, config, device)
        max_sentence_len = len(encoder_input_tokens)
        row_tokens = encoder_input_tokens
        col_tokens = decoder_input_tokens
        st.write('Input:', ' '.join(encoder_input_tokens))
        st.write('Output:', ' '.join(decoder_input_tokens))
        st.write('Translated:', output)
        st.write('Attention Visualization')
        st.write(get_all_attention_maps(attn_type, layers, heads, row_tokens, col_tokens, max_sentence_len, model))
    else:
        st.write('Enter a sentence to visualize the attention of the model')

    # add a footer with the github repo link and dataset link
    st.markdown('---')
    st.write('Made by [Pratik Dwivedi](https://github.com/Dekode1859)')
    st.write('Check out the Scratch Implementation and Visualizer Code on [GitHub](https://github.com/Dekode1859/transformer-visualizer)')
    st.write('Dataset: [Opus-books: english-Italian](https://huggingface.co/datasets/Helsinki-NLP/opus_books)')
    # st.write('This app is a Streamlit implementation of the [BERTViz](

if __name__ == '__main__':
    main()
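
A minimal sketch (illustrative only, with small hyperparameters rather than the trained checkpoint) of where get_attn_map finds its data: each MultiHeadAttentionBlock stores its latest softmaxed scores in `attention_scores` during the forward pass, so a single encode/decode pass populates every layer.

import torch
from model import build_transformer

model = build_transformer(src_vocab_size=100, tgt_vocab_size=100, src_seq_len=16, tgt_seq_len=16,
                          d_model=64, N=2, h=4, d_ff=128)
src = torch.randint(0, 100, (1, 16))
tgt = torch.randint(0, 100, (1, 16))
enc = model.encode(src, None)
_ = model.decode(enc, None, tgt, None)
scores = model.encoder.layers[0].self_attention_block.attention_scores
print(scores.shape)  # torch.Size([1, 4, 16, 16]) -> (batch, heads, seq_len, seq_len), one head of which mtx2df turns into a DataFrame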
config.py ADDED
@@ -0,0 +1,33 @@
from pathlib import Path

def get_config():
    return {
        "batch_size": 8,
        "num_epochs": 20,
        "lr": 10**-4,
        "seq_len": 350,
        "d_model": 512,
        "datasource": 'opus_books',
        "lang_src": "en",
        "lang_tgt": "it",
        "model_folder": "weights",
        "model_basename": "tmodel_",
        "preload": "latest",
        "tokenizer_file": "tokenizer_{0}.json",
        "experiment_name": "runs/tmodel"
    }

def get_weights_file_path(config, epoch: str):
    model_folder = f"{config['datasource']}_{config['model_folder']}"
    model_filename = f"{config['model_basename']}{epoch}.pt"
    return str(Path('.') / model_folder / model_filename)

# Find the latest weights file in the weights folder
def latest_weights_file_path(config):
    model_folder = f"{config['datasource']}_{config['model_folder']}"
    model_filename = f"{config['model_basename']}*"
    weights_files = list(Path(model_folder).glob(model_filename))
    if len(weights_files) == 0:
        return None
    weights_files.sort()
    return str(weights_files[-1])
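
For illustration, with the defaults above the helpers resolve to paths like the following (POSIX-style separators assumed):

from config import get_config, get_weights_file_path

cfg = get_config()
print(get_weights_file_path(cfg, "19"))               # opus_books_weights/tmodel_19.pt
print(cfg["tokenizer_file"].format(cfg["lang_src"]))  # tokenizer_en.json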
dataset.py ADDED
@@ -0,0 +1,90 @@
import torch
import torch.nn as nn
from torch.utils.data import Dataset

class BilingualDataset(Dataset):

    def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
        super().__init__()
        self.seq_len = seq_len

        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

        self.sos_token = torch.tensor([tokenizer_tgt.token_to_id("[SOS]")], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_tgt.token_to_id("[EOS]")], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer_tgt.token_to_id("[PAD]")], dtype=torch.int64)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        src_target_pair = self.ds[idx]
        src_text = src_target_pair['translation'][self.src_lang]
        tgt_text = src_target_pair['translation'][self.tgt_lang]

        # Transform the text into tokens
        enc_input_tokens = self.tokenizer_src.encode(src_text).ids
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids

        # Add sos, eos and padding to each sentence
        enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2  # We will add <s> and </s>
        # We will only add <s>, and </s> only on the label
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1

        # Make sure the number of padding tokens is not negative. If it is, the sentence is too long
        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError("Sentence is too long")

        # Add <s> and </s> token
        encoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(enc_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token] * enc_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )

        # Add only <s> token
        decoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )

        # Add only </s> token
        label = torch.cat(
            [
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )

        # Double check the size of the tensors to make sure they are all seq_len long
        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        return {
            "encoder_input": encoder_input,  # (seq_len)
            "decoder_input": decoder_input,  # (seq_len)
            "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(),  # (1, 1, seq_len)
            "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)),  # (1, seq_len) & (1, seq_len, seq_len),
            "label": label,  # (seq_len)
            "src_text": src_text,
            "tgt_text": tgt_text,
        }

def causal_mask(size):
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    return mask == 0
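
A quick illustration of the causal mask that "decoder_mask" combines with the padding mask (size 4 shown for readability; the real sequence length is 350):

from dataset import causal_mask

print(causal_mask(4).int())
# tensor([[[1, 0, 0, 0],
#          [1, 1, 0, 0],
#          [1, 1, 1, 0],
#          [1, 1, 1, 1]]], dtype=torch.int32)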
model.py ADDED
@@ -0,0 +1,267 @@
import torch
import torch.nn as nn
import math

class LayerNormalization(nn.Module):

    def __init__(self, features: int, eps: float = 10**-6) -> None:
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(features))  # alpha is a learnable parameter
        self.bias = nn.Parameter(torch.zeros(features))  # bias is a learnable parameter

    def forward(self, x):
        # x: (batch, seq_len, hidden_size)
        # Keep the dimension for broadcasting
        mean = x.mean(dim=-1, keepdim=True)  # (batch, seq_len, 1)
        # Keep the dimension for broadcasting
        std = x.std(dim=-1, keepdim=True)  # (batch, seq_len, 1)
        # eps is to prevent dividing by zero or when std is very small
        return self.alpha * (x - mean) / (std + self.eps) + self.bias

class FeedForwardBlock(nn.Module):

    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff)  # w1 and b1
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model)  # w2 and b2

    def forward(self, x):
        # (batch, seq_len, d_model) --> (batch, seq_len, d_ff) --> (batch, seq_len, d_model)
        return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))

class InputEmbeddings(nn.Module):

    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        # (batch, seq_len) --> (batch, seq_len, d_model)
        # Multiply by sqrt(d_model) to scale the embeddings according to the paper
        return self.embedding(x) * math.sqrt(self.d_model)

class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)
        # Create a matrix of shape (seq_len, d_model)
        pe = torch.zeros(seq_len, d_model)
        # Create a vector of shape (seq_len)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)  # (seq_len, 1)
        # Create a vector of shape (d_model)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))  # (d_model / 2)
        # Apply sine to even indices
        pe[:, 0::2] = torch.sin(position * div_term)  # sin(position * (10000 ** (2i / d_model))
        # Apply cosine to odd indices
        pe[:, 1::2] = torch.cos(position * div_term)  # cos(position * (10000 ** (2i / d_model))
        # Add a batch dimension to the positional encoding
        pe = pe.unsqueeze(0)  # (1, seq_len, d_model)
        # Register the positional encoding as a buffer
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)  # (batch, seq_len, d_model)
        return self.dropout(x)

class ResidualConnection(nn.Module):

    def __init__(self, features: int, dropout: float) -> None:
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = LayerNormalization(features)

    def forward(self, x, sublayer):
        return x + self.dropout(sublayer(self.norm(x)))

class MultiHeadAttentionBlock(nn.Module):

    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model  # Embedding vector size
        self.h = h  # Number of heads
        # Make sure d_model is divisible by h
        assert d_model % h == 0, "d_model is not divisible by h"

        self.d_k = d_model // h  # Dimension of vector seen by each head
        self.w_q = nn.Linear(d_model, d_model, bias=False)  # Wq
        self.w_k = nn.Linear(d_model, d_model, bias=False)  # Wk
        self.w_v = nn.Linear(d_model, d_model, bias=False)  # Wv
        self.w_o = nn.Linear(d_model, d_model, bias=False)  # Wo
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        d_k = query.shape[-1]
        # Just apply the formula from the paper
        # (batch, h, seq_len, d_k) --> (batch, h, seq_len, seq_len)
        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Write a very low value (indicating -inf) to the positions where mask == 0
            attention_scores.masked_fill_(mask == 0, -1e9)
        attention_scores = attention_scores.softmax(dim=-1)  # (batch, h, seq_len, seq_len) # Apply softmax
        if dropout is not None:
            attention_scores = dropout(attention_scores)
        # (batch, h, seq_len, seq_len) --> (batch, h, seq_len, d_k)
        # return attention scores which can be used for visualization
        return (attention_scores @ value), attention_scores

    def forward(self, q, k, v, mask):
        query = self.w_q(q)  # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        key = self.w_k(k)  # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        value = self.w_v(v)  # (batch, seq_len, d_model) --> (batch, seq_len, d_model)

        # (batch, seq_len, d_model) --> (batch, seq_len, h, d_k) --> (batch, h, seq_len, d_k)
        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)

        # Calculate attention
        x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)

        # Combine all the heads together
        # (batch, h, seq_len, d_k) --> (batch, seq_len, h, d_k) --> (batch, seq_len, d_model)
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)

        # Multiply by Wo
        # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        return self.w_o(x)

class EncoderBlock(nn.Module):

    def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])

    def forward(self, x, src_mask):
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))
        x = self.residual_connections[1](x, self.feed_forward_block)
        return x

class Encoder(nn.Module):

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

class DecoderBlock(nn.Module):

    def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(3)])

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, tgt_mask))
        x = self.residual_connections[1](x, lambda x: self.cross_attention_block(x, encoder_output, encoder_output, src_mask))
        x = self.residual_connections[2](x, self.feed_forward_block)
        return x

class Decoder(nn.Module):

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.norm(x)

class ProjectionLayer(nn.Module):

    def __init__(self, d_model, vocab_size) -> None:
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # (batch, seq_len, d_model) --> (batch, seq_len, vocab_size)
        return self.proj(x)

class Transformer(nn.Module):

    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbeddings, tgt_embed: InputEmbeddings, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        # (batch, seq_len, d_model)
        src = self.src_embed(src)
        src = self.src_pos(src)
        return self.encoder(src, src_mask)

    def decode(self, encoder_output: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
        # (batch, seq_len, d_model)
        tgt = self.tgt_embed(tgt)
        tgt = self.tgt_pos(tgt)
        return self.decoder(tgt, encoder_output, src_mask, tgt_mask)

    def project(self, x):
        # (batch, seq_len, vocab_size)
        return self.projection_layer(x)

def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len: int, d_model: int=512, N: int=6, h: int=8, dropout: float=0.1, d_ff: int=2048) -> Transformer:
    # Create the embedding layers
    src_embed = InputEmbeddings(d_model, src_vocab_size)
    tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)

    # Create the positional encoding layers
    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)

    # Create the encoder blocks
    encoder_blocks = []
    for _ in range(N):
        encoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        encoder_block = EncoderBlock(d_model, encoder_self_attention_block, feed_forward_block, dropout)
        encoder_blocks.append(encoder_block)

    # Create the decoder blocks
    decoder_blocks = []
    for _ in range(N):
        decoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        decoder_cross_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        decoder_block = DecoderBlock(d_model, decoder_self_attention_block, decoder_cross_attention_block, feed_forward_block, dropout)
        decoder_blocks.append(decoder_block)

    # Create the encoder and decoder
    encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))
    decoder = Decoder(d_model, nn.ModuleList(decoder_blocks))

    # Create the projection layer
    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)

    # Create the transformer
    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)

    # Initialize the parameters
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)

    return transformer
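
A small shape sanity check for build_transformer (illustrative hyperparameters and random, untrained weights):

import torch
from model import build_transformer

model = build_transformer(src_vocab_size=1000, tgt_vocab_size=1000, src_seq_len=32, tgt_seq_len=32,
                          d_model=128, N=2, h=4, d_ff=256)
src = torch.randint(0, 1000, (2, 32))
tgt = torch.randint(0, 1000, (2, 32))
enc_out = model.encode(src, None)                 # (2, 32, 128)
dec_out = model.decode(enc_out, None, tgt, None)  # (2, 32, 128)
logits = model.project(dec_out)                   # (2, 32, 1000)
print(enc_out.shape, logits.shape)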
predict.py ADDED
@@ -0,0 +1,16 @@
from transformers import AutoTokenizer, AutoModel, utils
from bertviz import model_view
utils.logging.set_verbosity_error()  # Suppress standard warnings

def get_predictions(input_text):
    model_name = "microsoft/xtremedistil-l12-h384-uncased"
    model = AutoModel.from_pretrained(model_name, output_attentions=True)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    inputs = tokenizer.encode(input_text, return_tensors='pt')
    outputs = model(inputs)
    attention = outputs[-1]
    tokens = tokenizer.convert_ids_to_tokens(inputs[0])
    model_html = model_view(attention, tokens, html_action="return")
    with open("static/model_view.html", 'w') as file:
        file.write(model_html.data)
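
Usage sketch: get_predictions loads a separate Hugging Face checkpoint (not the scratch transformer above) and writes a BertViz model view to disk; a local static/ directory is assumed to exist.

from predict import get_predictions

get_predictions("The quick brown fox jumps over the lazy dog.")
# static/model_view.html now holds the interactive BertViz model view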
requirements.txt ADDED
@@ -0,0 +1,16 @@
## Use python 3.9

torch
torchvision
torchaudio
torchtext
datasets
tokenizers
torchmetrics
tensorboard
altair
wandb
transformers
bertviz
IPython
streamlit
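
An optional environment check (illustrative) that the unpinned packages above import cleanly:

import torch, streamlit, transformers, tokenizers, datasets, altair, bertviz
print(torch.__version__, streamlit.__version__, transformers.__version__)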
tokenizer_en.json ADDED
The diff for this file is too large to render. See raw diff
 
tokenizer_it.json ADDED
The diff for this file is too large to render. See raw diff
 
train.py ADDED
@@ -0,0 +1,283 @@
from model import build_transformer
from dataset import BilingualDataset, causal_mask
from config import get_config, get_weights_file_path, latest_weights_file_path

# import torchtext.datasets as datasets
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torch.optim.lr_scheduler import LambdaLR

import warnings
from tqdm import tqdm
import os
from pathlib import Path

# Huggingface datasets and tokenizers
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.trainers import WordLevelTrainer
from tokenizers.pre_tokenizers import Whitespace

# import torchmetrics
# from torch.utils.tensorboard import SummaryWriter

def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')

    # Precompute the encoder output and reuse it for every step
    encoder_output = model.encode(source, source_mask)
    # Initialize the decoder input with the sos token
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)
    while True:
        if decoder_input.size(1) == max_len:
            break

        # build mask for target
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)

        # calculate output
        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

        # get next token
        prob = model.project(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        decoder_input = torch.cat(
            [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1
        )

        if next_word == eos_idx:
            break

    return decoder_input.squeeze(0)


def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, writer=None, num_examples=2):
    model.eval()
    count = 0

    source_texts = []
    expected = []
    predicted = []

    try:
        # get the console window width
        with os.popen('stty size', 'r') as console:
            _, console_width = console.read().split()
            console_width = int(console_width)
    except:
        # If we can't get the console width, use 80 as default
        console_width = 80

    with torch.no_grad():
        for batch in validation_ds:
            count += 1
            encoder_input = batch["encoder_input"].to(device)  # (b, seq_len)
            encoder_mask = batch["encoder_mask"].to(device)  # (b, 1, 1, seq_len)

            # check that the batch size is 1
            assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"

            model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)

            source_text = batch["src_text"][0]
            target_text = batch["tgt_text"][0]
            model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())

            source_texts.append(source_text)
            expected.append(target_text)
            predicted.append(model_out_text)

            # Print the source, target and model output
            print_msg('-'*console_width)
            print_msg(f"{f'SOURCE: ':>12}{source_text}")
            print_msg(f"{f'TARGET: ':>12}{target_text}")
            print_msg(f"{f'PREDICTED: ':>12}{model_out_text}")

            if count == num_examples:
                print_msg('-'*console_width)
                break

    # if writer:
    #     # Evaluate the character error rate
    #     # Compute the char error rate
    #     metric = torchmetrics.CharErrorRate()
    #     cer = metric(predicted, expected)
    #     writer.add_scalar('validation cer', cer, global_step)
    #     writer.flush()

    #     # Compute the word error rate
    #     metric = torchmetrics.WordErrorRate()
    #     wer = metric(predicted, expected)
    #     writer.add_scalar('validation wer', wer, global_step)
    #     writer.flush()

    #     # Compute the BLEU metric
    #     metric = torchmetrics.BLEUScore()
    #     bleu = metric(predicted, expected)
    #     writer.add_scalar('validation BLEU', bleu, global_step)
    #     writer.flush()

def get_all_sentences(ds, lang):
    for item in ds:
        yield item['translation'][lang]

def get_or_build_tokenizer(config, ds, lang):
    print(f"Checking for existing tokenizer for {lang}")
    tokenizer_path = Path(config['tokenizer_file'].format(lang))
    if not Path.exists(tokenizer_path):
        print(f"Building tokenizer for {lang}")
        # Most code taken from: https://huggingface.co/docs/tokenizers/quicktour
        tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
        tokenizer.pre_tokenizer = Whitespace()
        trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)
        print(f"Training tokenizer for {lang}")
        tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)
        print(f"Saving tokenizer for {lang}")
        tokenizer.save(str(tokenizer_path))
    else:
        print(f"Found existing tokenizer for {lang}")
        tokenizer = Tokenizer.from_file(str(tokenizer_path))
    return tokenizer

def get_ds(config):
    # It only has the train split, so we divide it ourselves
    print(f"Loading dataset {config['datasource']}")
    ds_raw = load_dataset(f"{config['datasource']}", f"{config['lang_src']}-{config['lang_tgt']}", split='train')

    # Build tokenizers
    print(f"Building tokenizers for {config['lang_src']} and {config['lang_tgt']}")
    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])

    # Keep 90% for training, 10% for validation
    print("Splitting dataset into training and validation")
    train_ds_size = int(0.9 * len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])

    train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
    val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])

    # Find the maximum length of each sentence in the source and target sentence
    print("Finding the maximum length of the source and target sentences")
    max_len_src = 0
    max_len_tgt = 0

    for item in ds_raw:
        src_ids = tokenizer_src.encode(item['translation'][config['lang_src']]).ids
        tgt_ids = tokenizer_tgt.encode(item['translation'][config['lang_tgt']]).ids
        max_len_src = max(max_len_src, len(src_ids))
        max_len_tgt = max(max_len_tgt, len(tgt_ids))

    print(f'Max length of source sentence: {max_len_src}')
    print(f'Max length of target sentence: {max_len_tgt}')

    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt

def get_model(config, vocab_src_len, vocab_tgt_len):
    model = build_transformer(vocab_src_len, vocab_tgt_len, config["seq_len"], config['seq_len'], d_model=config['d_model'])
    return model

def train_model(config):
    # Define the device
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.has_mps or torch.backends.mps.is_available() else "cpu"
    print("Using device:", device)
    if (device == 'cuda'):
        print(f"Device name: {torch.cuda.get_device_name(device)}")
        print(f"Device memory: {torch.cuda.get_device_properties(device).total_memory / 1024 ** 3} GB")
    elif (device == 'mps'):
        print(f"Device name: <mps>")
    else:
        print("NOTE: If you have a GPU, consider using it for training.")
        print("      On a Windows machine with NVidia GPU, check this video: https://www.youtube.com/watch?v=GMSjDTU8Zlc")
        print("      On a Mac machine, run: pip3 install --pre torch torchvision torchaudio torchtext --index-url https://download.pytorch.org/whl/nightly/cpu")
    device = torch.device(device)

    # Make sure the weights folder exists
    Path(f"{config['datasource']}_{config['model_folder']}").mkdir(parents=True, exist_ok=True)

    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
    # Tensorboard
    # writer = SummaryWriter(config['experiment_name'])

    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)

    # If the user specified a model to preload before training, load it
    initial_epoch = 0
    global_step = 0
    preload = config['preload']
    model_filename = latest_weights_file_path(config) if preload == 'latest' else get_weights_file_path(config, preload) if preload else None
    if model_filename:
        print(f'Preloading model {model_filename}')
        state = torch.load(model_filename)
        model.load_state_dict(state['model_state_dict'])
        initial_epoch = state['epoch'] + 1
        optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step = state['global_step']
    else:
        print('No model to preload, starting from scratch')

    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_src.token_to_id('[PAD]'), label_smoothing=0.1).to(device)

    for epoch in range(initial_epoch, config['num_epochs']):
        torch.cuda.empty_cache()
        model.train()
        batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
        for batch in batch_iterator:

            encoder_input = batch['encoder_input'].to(device)  # (b, seq_len)
            decoder_input = batch['decoder_input'].to(device)  # (B, seq_len)
            encoder_mask = batch['encoder_mask'].to(device)  # (B, 1, 1, seq_len)
            decoder_mask = batch['decoder_mask'].to(device)  # (B, 1, seq_len, seq_len)

            # Run the tensors through the encoder, decoder and the projection layer
            encoder_output = model.encode(encoder_input, encoder_mask)  # (B, seq_len, d_model)
            decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)  # (B, seq_len, d_model)
            proj_output = model.project(decoder_output)  # (B, seq_len, vocab_size)

            # Compare the output with the label
            label = batch['label'].to(device)  # (B, seq_len)

            # Compute the loss using a simple cross entropy
            loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
            batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})

            # Log the loss
            # writer.add_scalar('train loss', loss.item(), global_step)
            # writer.flush()

            # Backpropagate the loss
            loss.backward()

            # Update the weights
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            global_step += 1

        # Run validation at the end of every epoch
        run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step, writer=None)

        # Save the model at the end of every epoch
        model_filename = get_weights_file_path(config, f"{epoch:02d}")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'global_step': global_step
        }, model_filename)


if __name__ == '__main__':
    warnings.filterwarnings("ignore")
    config = get_config()
    train_model(config)
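
For reference, a minimal shape sketch (illustrative values only) of the cross-entropy call in the training loop above: the logits are flattened to (batch * seq_len, vocab_size) and the labels to (batch * seq_len), with [PAD] positions ignored via ignore_index.

import torch
import torch.nn as nn

batch, seq_len, vocab_size, pad_id = 2, 6, 10, 1      # toy sizes, pad_id stands in for token_to_id('[PAD]')
proj_output = torch.randn(batch, seq_len, vocab_size)  # stand-in for model.project(decoder_output)
label = torch.randint(0, vocab_size, (batch, seq_len))  # stand-in for batch['label']
loss_fn = nn.CrossEntropyLoss(ignore_index=pad_id, label_smoothing=0.1)
loss = loss_fn(proj_output.view(-1, vocab_size), label.view(-1))
print(loss.item())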
translate.py ADDED
@@ -0,0 +1,79 @@
from pathlib import Path
from config import get_config, latest_weights_file_path
from model import build_transformer
from tokenizers import Tokenizer
from datasets import load_dataset
from dataset import BilingualDataset
import torch
import sys

def translate(sentence: str):
    # Define the device, tokenizers, and model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Using device:", device)
    config = get_config()
    tokenizer_src = Tokenizer.from_file(str(Path(config['tokenizer_file'].format(config['lang_src']))))
    tokenizer_tgt = Tokenizer.from_file(str(Path(config['tokenizer_file'].format(config['lang_tgt']))))
    model = build_transformer(tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size(), config["seq_len"], config['seq_len'], d_model=config['d_model']).to(device)

    # Load the pretrained weights
    model_filename = latest_weights_file_path(config)
    state = torch.load(model_filename)
    model.load_state_dict(state['model_state_dict'])

    # If the sentence is a number, use it as an index into the test set
    label = ""
    if type(sentence) == int or sentence.isdigit():
        id = int(sentence)
        ds = load_dataset(f"{config['datasource']}", f"{config['lang_src']}-{config['lang_tgt']}", split='all')
        ds = BilingualDataset(ds, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
        sentence = ds[id]['src_text']
        label = ds[id]["tgt_text"]
    seq_len = config['seq_len']

    # Translate the sentence
    model.eval()
    with torch.no_grad():
        # Precompute the encoder output and reuse it for every generation step
        source = tokenizer_src.encode(sentence)
        source = torch.cat([
            torch.tensor([tokenizer_src.token_to_id('[SOS]')], dtype=torch.int64),
            torch.tensor(source.ids, dtype=torch.int64),
            torch.tensor([tokenizer_src.token_to_id('[EOS]')], dtype=torch.int64),
            torch.tensor([tokenizer_src.token_to_id('[PAD]')] * (seq_len - len(source.ids) - 2), dtype=torch.int64)
        ], dim=0).to(device)
        source_mask = (source != tokenizer_src.token_to_id('[PAD]')).unsqueeze(0).unsqueeze(0).int().to(device)
        encoder_output = model.encode(source, source_mask)

        # Initialize the decoder input with the sos token
        decoder_input = torch.empty(1, 1).fill_(tokenizer_tgt.token_to_id('[SOS]')).type_as(source).to(device)

        # Print the source sentence and target start prompt
        if label != "": print(f"{f'ID: ':>12}{id}")
        print(f"{f'SOURCE: ':>12}{sentence}")
        if label != "": print(f"{f'TARGET: ':>12}{label}")
        print(f"{f'PREDICTED: ':>12}", end='')

        # Generate the translation word by word
        while decoder_input.size(1) < seq_len:
            # build mask for target and calculate output
            decoder_mask = torch.triu(torch.ones((1, decoder_input.size(1), decoder_input.size(1))), diagonal=1).type(torch.int).type_as(source_mask).to(device)
            out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

            # project next token
            prob = model.project(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            decoder_input = torch.cat([decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1)

            # print the translated word
            print(f"{tokenizer_tgt.decode([next_word.item()])}", end=' ')

            # break if we predict the end of sentence token
            if next_word == tokenizer_tgt.token_to_id('[EOS]'):
                break

    # convert ids to tokens
    return tokenizer_tgt.decode(decoder_input[0].tolist())

# Read the sentence from the command-line argument
translate(sys.argv[1] if len(sys.argv) > 1 else "I am not a very good student.")
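
Usage sketch: because translate.py calls translate(...) at the bottom, it is normally run as a script, e.g. python translate.py "Why do I need to translate this?" or python translate.py 34 to use sample 34 of the opus_books en-it split as the source. The function can also be called directly, keeping in mind that importing the module triggers one translation of the default sentence.

from translate import translate  # note: the import itself runs the default translation once

print(translate("The book is on the table."))
print(translate("34"))  # dataset index; prints SOURCE/TARGET/PREDICTED and returns the predicted text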