Load Model file (20220904 Commit)
The model can only be loaded by using this file as a library (a short usage sketch follows the listing below).
I will keep updating and improving it over time.
- load_model.py +181 -0
load_model.py
ADDED
@@ -0,0 +1,181 @@
import logging
import os
import math
import pandas as pd
import numpy as np
import argparse
import json
import gc
import re
import copy
import random
from tqdm import tqdm


from typing import Dict, List, Optional, Tuple

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from transformers import PreTrainedTokenizerFast, LEDForConditionalGeneration, AutoModel
from transformers import BartForConditionalGeneration, BartConfig
from transformers.models.bart.modeling_bart import BartLearnedPositionalEmbedding
from transformers.models.longformer.modeling_longformer import LongformerSelfAttention
from transformers import get_linear_schedule_with_warmup, AdamW, TrainingArguments

import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import CrossEntropyLoss, MSELoss
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset


# Replaces KoBART's self-attention layer with Longformer self-attention.
class LongformerSelfAttentionForBart(nn.Module):
    def __init__(self, config: BartConfig, layer_id: int):
        super().__init__()
        self.embed_dim = config.d_model
        self.longformer_self_attn = LongformerSelfAttention(config, layer_id=layer_id)
        self.output = nn.Linear(self.embed_dim, self.embed_dim)

    # Must accept inputs of the same shape as KoBART's original attention layer
    # and return outputs of the same shape.
    def forward(self,
                hidden_states: torch.Tensor,
                key_value_states: Optional[torch.Tensor] = None,
                past_key_value: Optional[Tuple[torch.Tensor]] = None,
                attention_mask: Optional[torch.Tensor] = None,
                layer_head_mask: Optional[torch.Tensor] = None,
                output_attentions: bool = False,
                ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:

        is_cross_attention = key_value_states is not None
        bsz, tgt_len, embed_dim = hidden_states.size()

        # reduce bs x seq_len x seq_len -> bs x seq_len
        attention_mask = attention_mask.squeeze(dim=1)
        attention_mask = attention_mask[:, 0]

        is_index_masked = attention_mask < 0
        is_index_global_attn = attention_mask > 0
        is_global_attn = is_index_global_attn.flatten().any().item()

        outputs = self.longformer_self_attn(
            hidden_states,
            attention_mask=attention_mask,
            layer_head_mask=None,
            is_index_masked=is_index_masked,
            is_index_global_attn=is_index_global_attn,
            is_global_attn=is_global_attn,
            output_attentions=output_attentions,
        )

        attn_output = self.output(outputs[0])

        return (attn_output,) + outputs[1:] if len(outputs) == 2 else (attn_output, None, None)


class LongformerBartForConditionalGeneration(BartForConditionalGeneration):
    def __init__(self, config):
        super().__init__(config)
        if config.attention_mode == 'n2':
            pass  # do nothing, keep the regular BART self-attention
        else:
            self.model.encoder.embed_positions = BartLearnedPositionalEmbedding(
                config.max_encoder_position_embeddings,
                config.d_model,
                config.pad_token_id)

            self.model.decoder.embed_positions = BartLearnedPositionalEmbedding(
                config.max_decoder_position_embeddings,
                config.d_model,
                config.pad_token_id)

            for i, layer in enumerate(self.model.encoder.layers):
                layer.self_attn = LongformerSelfAttentionForBart(config, layer_id=i)


# Config class for the Longformer-BART model
class LongformerBartConfig(BartConfig):
    def __init__(self, attention_window: List[int] = [512], attention_dilation: List[int] = [1],
                 autoregressive: bool = False, attention_mode: str = 'sliding_chunks',
                 gradient_checkpointing: bool = False, max_seq_len: int = 4096, max_pos: int = 4104, **kwargs):
        """
        Args:
            attention_window: list of attention window sizes of length = number of layers.
                window size = number of attention locations on each side.
                For an effective window size of 512, use `attention_window=[256]*num_layers`
                which is 256 on each side.
            attention_dilation: list of attention dilations of length = number of layers.
                attention dilation of `1` means no dilation.
            autoregressive: do autoregressive attention or attend to both sides
            attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for the TVM implementation of Longformer
                self-attention, 'sliding_chunks' for another implementation of Longformer self-attention
        """

        super().__init__(**kwargs)
        self.attention_window = attention_window
        self.attention_dilation = attention_dilation
        self.autoregressive = autoregressive
        self.attention_mode = attention_mode
        self.gradient_checkpointing = gradient_checkpointing
        assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2']


if __name__ == '__main__':
    # Script that builds the Longformer weights
    max_pos = 4104
    max_seq_len = 4096
    attention_window = 512
    save_path = '../LED_KoBART/model'

    # Load the existing pretrained KoBART tokenizer & model
    tokenizer = PreTrainedTokenizerFast.from_pretrained('ainize/kobart-news', model_max_length=max_pos)
    kobart_longformer = BartForConditionalGeneration.from_pretrained('ainize/kobart-news')
    config = LongformerBartConfig.from_pretrained('ainize/kobart-news')

    kobart_longformer.config = config

    config.attention_probs_dropout_prob = config.attention_dropout
    config.architectures = ['LongformerEncoderDecoderForConditionalGeneration', ]

    # Extend the tokenizer's max_positional_embedding_size
    # extend position embeddings
    tokenizer.model_max_length = max_pos
    tokenizer.init_kwargs['model_max_length'] = max_pos
    current_max_pos, embed_size = kobart_longformer.model.encoder.embed_positions.weight.shape
    assert current_max_pos == config.max_position_embeddings + 2

    config.max_encoder_position_embeddings = max_pos
    config.max_decoder_position_embeddings = config.max_position_embeddings
    del config.max_position_embeddings
    max_pos += 2  # NOTE: BART has positions 0,1 reserved, so embedding size is max position + 2
    assert max_pos >= current_max_pos

    new_encoder_pos_embed = kobart_longformer.model.encoder.embed_positions.weight.new_empty(max_pos, embed_size)

    # Extend the positional embeddings: tile the pretrained rows (index 2 and up)
    # until the longer table is filled
    k = 2
    step = 1028 - 2
    while k < max_pos - 1:
        new_encoder_pos_embed[k:(k + step)] = kobart_longformer.model.encoder.embed_positions.weight[2:]
        k += step
    kobart_longformer.model.encoder.embed_positions.weight.data = new_encoder_pos_embed

    config.attention_window = [attention_window] * config.num_hidden_layers
    config.attention_dilation = [1] * config.num_hidden_layers

    # Replace KoBART self-attention with Longformer self-attention,
    # reusing the pretrained q/k/v projections (and copies of them for global attention)
    for i, layer in enumerate(kobart_longformer.model.encoder.layers):
        longformer_self_attn_for_bart = LongformerSelfAttentionForBart(kobart_longformer.config, layer_id=i)

        longformer_self_attn_for_bart.longformer_self_attn.query = layer.self_attn.q_proj
        longformer_self_attn_for_bart.longformer_self_attn.key = layer.self_attn.k_proj
        longformer_self_attn_for_bart.longformer_self_attn.value = layer.self_attn.v_proj

        longformer_self_attn_for_bart.longformer_self_attn.query_global = copy.deepcopy(layer.self_attn.q_proj)
        longformer_self_attn_for_bart.longformer_self_attn.key_global = copy.deepcopy(layer.self_attn.k_proj)
        longformer_self_attn_for_bart.longformer_self_attn.value_global = copy.deepcopy(layer.self_attn.v_proj)

        longformer_self_attn_for_bart.output = layer.self_attn.out_proj

        layer.self_attn = longformer_self_attn_for_bart

    kobart_longformer.save_pretrained(save_path)
    tokenizer.save_pretrained(save_path, None)
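
Since the commit description says the file is meant to be imported as a library, here is a minimal usage sketch of how the converted checkpoint might be reloaded. Only the module name load_model, the class names, and the save_path directory come from the script above; the rest is an assumption, not part of this commit.

# Usage sketch (assumption, not part of this commit): reload the converted
# checkpoint with the classes defined in load_model.py.
from transformers import PreTrainedTokenizerFast

from load_model import LongformerBartConfig, LongformerBartForConditionalGeneration

save_path = '../LED_KoBART/model'  # same directory used by save_pretrained above

tokenizer = PreTrainedTokenizerFast.from_pretrained(save_path)
config = LongformerBartConfig.from_pretrained(save_path)
model = LongformerBartForConditionalGeneration.from_pretrained(save_path, config=config)
model.eval()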