noahkim commited on
Commit
5780d70
·
1 Parent(s): 55c4f32

Load Model file


<<20220904 Commit>>
The model can only be loaded by using this file as a library; a minimal usage sketch follows below.
I will keep updating this file as the project evolves.

Files changed (1)
  1. load_model.py +181 -0
load_model.py ADDED
@@ -0,0 +1,181 @@
import logging
import os
import math
import pandas as pd
import numpy as np
import argparse
import json
import gc
import re
import copy
import random
from tqdm import tqdm


from typing import Dict, List, Optional, Tuple

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from transformers import PreTrainedTokenizerFast, LEDForConditionalGeneration, AutoModel
from transformers import BartForConditionalGeneration, BartConfig
from transformers.models.bart.modeling_bart import BartLearnedPositionalEmbedding
from transformers.models.longformer.modeling_longformer import LongformerSelfAttention
from transformers import get_linear_schedule_with_warmup, AdamW, TrainingArguments

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset


# Replaces KoBART's self-attention layers with Longformer self-attention
class LongformerSelfAttentionForBart(nn.Module):
    def __init__(self, config, layer_id: int):
        super().__init__()
        self.embed_dim = config.d_model
        self.longformer_self_attn = LongformerSelfAttention(config, layer_id=layer_id)
        self.output = nn.Linear(self.embed_dim, self.embed_dim)

    # Must accept the same inputs and return the same output shapes as KoBART's original attention layer.
    def forward(self,
                hidden_states: torch.Tensor,
                key_value_states: Optional[torch.Tensor] = None,
                past_key_value: Optional[Tuple[torch.Tensor]] = None,
                attention_mask: Optional[torch.Tensor] = None,
                layer_head_mask: Optional[torch.Tensor] = None,
                output_attentions: bool = False,
                ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

        is_cross_attention = key_value_states is not None
        bsz, tgt_len, embed_dim = hidden_states.size()

        # Reduce the mask from bs x seq_len x seq_len to bs x seq_len
        attention_mask = attention_mask.squeeze(dim=1)
        attention_mask = attention_mask[:, 0]

        is_index_masked = attention_mask < 0
        is_index_global_attn = attention_mask > 0
        is_global_attn = is_index_global_attn.flatten().any().item()

        outputs = self.longformer_self_attn(
            hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=None,
            is_index_masked=is_index_masked,
            is_index_global_attn=is_index_global_attn,
            is_global_attn=is_global_attn,
            output_attentions=output_attentions,
        )

        attn_output = self.output(outputs[0])

        return (attn_output,) + outputs[1:] if len(outputs) == 2 else (attn_output, None, None)


class LongformerBartForConditionalGeneration(BartForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)
        if config.attention_mode == 'n2':
            pass  # do nothing, keep BART's default self-attention
        else:
            self.model.encoder.embed_positions = BartLearnedPositionalEmbedding(
                config.max_encoder_position_embeddings,
                config.d_model,
                config.pad_token_id)

            self.model.decoder.embed_positions = BartLearnedPositionalEmbedding(
                config.max_decoder_position_embeddings,
                config.d_model,
                config.pad_token_id)

            for i, layer in enumerate(self.model.encoder.layers):
                layer.self_attn = LongformerSelfAttentionForBart(config, layer_id=i)


# Config class for the Longformer-BART model
class LongformerBartConfig(BartConfig):
    def __init__(self, attention_window: List[int] = [512], attention_dilation: List[int] = [1],
                 autoregressive: bool = False, attention_mode: str = 'sliding_chunks',
                 gradient_checkpointing: bool = False, max_seq_len: int = 4096, max_pos: int = 4104, **kwargs):
        """
        Args:
            attention_window: list of attention window sizes of length = number of layers.
                window size = number of attention locations on each side.
                For an effective window size of 512, use `attention_window=[256]*num_layers`
                which is 256 on each side.
            attention_dilation: list of attention dilations of length = number of layers.
                attention dilation of `1` means no dilation.
            autoregressive: do autoregressive attention or attend to both sides
            attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for the TVM implementation of Longformer
                self-attention, 'sliding_chunks' for another implementation of Longformer self-attention
        """

        super().__init__(**kwargs)
        self.attention_window = attention_window
        self.attention_dilation = attention_dilation
        self.autoregressive = autoregressive
        self.attention_mode = attention_mode
        self.gradient_checkpointing = gradient_checkpointing
        assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2']


if __name__ == '__main__':
    # Build the long-input Longformer weights from the pretrained KoBART checkpoint
    max_pos = 4104
    max_seq_len = 4096
    attention_window = 512
    save_path = '../LED_KoBART/model'

    # Load the pretrained KoBART tokenizer & model
    tokenizer = PreTrainedTokenizerFast.from_pretrained('ainize/kobart-news', model_max_length=max_pos)
    kobart_longformer = BartForConditionalGeneration.from_pretrained('ainize/kobart-news')
    config = LongformerBartConfig.from_pretrained('ainize/kobart-news')

    kobart_longformer.config = config

    config.attention_probs_dropout_prob = config.attention_dropout
    config.architectures = ['LongformerEncoderDecoderForConditionalGeneration', ]

    # Extend the tokenizer's max length and the encoder position embeddings
    tokenizer.model_max_length = max_pos
    tokenizer.init_kwargs['model_max_length'] = max_pos
    current_max_pos, embed_size = kobart_longformer.model.encoder.embed_positions.weight.shape
    assert current_max_pos == config.max_position_embeddings + 2

    config.max_encoder_position_embeddings = max_pos
    config.max_decoder_position_embeddings = config.max_position_embeddings
    del config.max_position_embeddings
    max_pos += 2  # NOTE: BART has positions 0,1 reserved, so embedding size is max position + 2
    assert max_pos >= current_max_pos

    new_encoder_pos_embed = kobart_longformer.model.encoder.embed_positions.weight.new_empty(max_pos, embed_size)

    # Extend the positional embeddings by tiling the pretrained weights
    k = 2
    step = 1028 - 2  # original embedding table size minus the 2 reserved positions
    while k < max_pos - 1:
        new_encoder_pos_embed[k:(k + step)] = kobart_longformer.model.encoder.embed_positions.weight[2:]
        k += step
    kobart_longformer.model.encoder.embed_positions.weight.data = new_encoder_pos_embed

    config.attention_window = [attention_window] * config.num_hidden_layers
    config.attention_dilation = [1] * config.num_hidden_layers

    # Swap KoBART self-attention for Longformer self-attention, reusing the pretrained projection weights
    for i, layer in enumerate(kobart_longformer.model.encoder.layers):
        longformer_self_attn_for_bart = LongformerSelfAttentionForBart(kobart_longformer.config, layer_id=i)

        longformer_self_attn_for_bart.longformer_self_attn.query = layer.self_attn.q_proj
        longformer_self_attn_for_bart.longformer_self_attn.key = layer.self_attn.k_proj
        longformer_self_attn_for_bart.longformer_self_attn.value = layer.self_attn.v_proj

        # The global-attention projections start as copies of the local ones
        longformer_self_attn_for_bart.longformer_self_attn.query_global = copy.deepcopy(layer.self_attn.q_proj)
        longformer_self_attn_for_bart.longformer_self_attn.key_global = copy.deepcopy(layer.self_attn.k_proj)
        longformer_self_attn_for_bart.longformer_self_attn.value_global = copy.deepcopy(layer.self_attn.v_proj)

        longformer_self_attn_for_bart.output = layer.self_attn.out_proj

        layer.self_attn = longformer_self_attn_for_bart

    kobart_longformer.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path, None)