cocoirun committed
Commit 9fcb9ee • 1 Parent(s): 2f60e7d

Update README.md

Files changed (1)
  1. README.md +129 -0
README.md CHANGED
@@ -38,4 +38,133 @@ output ="""
  - 상담원이 카드 번호와 잔액 확인 후 추가 이용 혜택 안내
  - 고객이 여행 할인, 마일리지, 호텔 할인 등 다양한 혜택에 관심 표현
  """
+ ```
+
+
+ The following classes are required to use this model:
+ ```
+ # Imports needed by the code below (PyTorch and Hugging Face transformers)
+ import torch
+ from torch import nn
+ from typing import List, Optional, Tuple
+
+ from transformers import AutoTokenizer, BartConfig, BartForConditionalGeneration
+ from transformers.models.bart.modeling_bart import BartLearnedPositionalEmbedding
+ from transformers.models.longformer.modeling_longformer import LongformerSelfAttention
+
+
+ # Wraps Longformer self-attention so it can replace the self-attention of BART encoder layers.
+ class LongformerSelfAttentionForBart(nn.Module):
+     def __init__(self, config, layer_id):
+         super().__init__()
+         self.embed_dim = config.d_model
+         self.longformer_self_attn = LongformerSelfAttention(config, layer_id=layer_id)
+         self.output = nn.Linear(self.embed_dim, self.embed_dim)
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         key_value_states: Optional[torch.Tensor] = None,
+         past_key_value: Optional[Tuple[torch.Tensor]] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         layer_head_mask: Optional[torch.Tensor] = None,
+         output_attentions: bool = False,
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
+
+         is_cross_attention = key_value_states is not None
+         bsz, tgt_len, embed_dim = hidden_states.size()
+
+         # reduce the mask from bs x seq_len x seq_len to bs x seq_len
+         attention_mask = attention_mask.squeeze(dim=1)
+         attention_mask = attention_mask[:, 0]
+
+         is_index_masked = attention_mask < 0
+         is_index_global_attn = attention_mask > 0
+         is_global_attn = is_index_global_attn.flatten().any().item()
+
+         outputs = self.longformer_self_attn(
+             hidden_states,
+             attention_mask=attention_mask,
+             layer_head_mask=None,
+             is_index_masked=is_index_masked,
+             is_index_global_attn=is_index_global_attn,
+             is_global_attn=is_global_attn,
+             output_attentions=output_attentions,
+         )
+
+         attn_output = self.output(outputs[0])
+
+         return (attn_output,) + outputs[1:] if len(outputs) == 2 else (attn_output, None, None)
+ ```
+
+ ```
+ class LongformerEncoderDecoderForConditionalGeneration(BartForConditionalGeneration):
+     def __init__(self, config):
+         super().__init__(config)
+
+         if config.attention_mode == 'n2':
+             pass  # do nothing, keep the default BART self-attention
+         else:
+             self.model.encoder.embed_positions = BartLearnedPositionalEmbedding(
+                 config.max_encoder_position_embeddings,
+                 config.d_model)
+
+             self.model.decoder.embed_positions = BartLearnedPositionalEmbedding(
+                 config.max_decoder_position_embeddings,
+                 config.d_model)
+
+             for i, layer in enumerate(self.model.encoder.layers):
+                 layer.self_attn = LongformerSelfAttentionForBart(config, layer_id=i)
+ ```
+
+ ```
+ class LongformerEncoderDecoderConfig(BartConfig):
+     def __init__(self, attention_window: List[int] = None, attention_dilation: List[int] = None,
+                  autoregressive: bool = False, attention_mode: str = 'sliding_chunks',
+                  gradient_checkpointing: bool = False, **kwargs):
+         """
+         Args:
+             attention_window: list of attention window sizes of length = number of layers.
+                 window size = number of attention locations on each side.
+                 For an effective window size of 512, use `attention_window=[256]*num_layers`,
+                 which is 256 on each side.
+             attention_dilation: list of attention dilations of length = number of layers.
+                 attention dilation of `1` means no dilation.
+             autoregressive: do autoregressive attention or attend to both sides
+             attention_mode: 'n2' for regular n^2 self-attention, 'tvm' for the TVM implementation of
+                 Longformer self-attention, 'sliding_chunks' for another implementation of Longformer self-attention
+         """
+         super().__init__(**kwargs)
+         self.attention_window = attention_window
+         self.attention_dilation = attention_dilation
+         self.autoregressive = autoregressive
+         self.attention_mode = attention_mode
+         self.gradient_checkpointing = gradient_checkpointing
+         assert self.attention_mode in ['tvm', 'sliding_chunks', 'n2']
+ ```
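
As a quick illustration of the `attention_window` convention described in the docstring above (a hedged sketch with hypothetical values; the published checkpoint ships its own configuration):

```
# Hypothetical values, for illustration only.
num_layers = 6
config = LongformerEncoderDecoderConfig(
    attention_window=[256] * num_layers,  # 256 positions on each side -> effective window of 512
    attention_dilation=[1] * num_layers,  # 1 means no dilation
    attention_mode='sliding_chunks',
    autoregressive=False,
)
```
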
+ After loading the model object, download the weight file separately and load the weights with `load_state_dict`:
+ ```
+ tokenizer = AutoTokenizer.from_pretrained("cocoirun/longforemr-kobart-summary-v1")
+ model = LongformerEncoderDecoderForConditionalGeneration.from_pretrained("cocoirun/longforemr-kobart-summary-v1")
+ device = torch.device('cuda')
+ model.load_state_dict(torch.load("summary weight.ckpt"))
+ model.to(device)
+ ```
+
+ Model summarization function:
+ ```
+ def summarize(text, max_len):
+     max_seq_len = 4096
+     context_tokens = ['<s>'] + tokenizer.tokenize(text) + ['</s>']
+     input_ids = tokenizer.convert_tokens_to_ids(context_tokens)
+
+     # pad to the 4096-token encoder length, or truncate and re-append the EOS token
+     if len(input_ids) < max_seq_len:
+         while len(input_ids) < max_seq_len:
+             input_ids += [tokenizer.pad_token_id]
+     else:
+         input_ids = input_ids[:max_seq_len - 1] + [tokenizer.eos_token_id]
+
+     res_ids = model.generate(torch.tensor([input_ids]).to(device),
+                              max_length=max_len,
+                              num_beams=5,
+                              no_repeat_ngram_size=3,
+                              eos_token_id=tokenizer.eos_token_id,
+                              bad_words_ids=[[tokenizer.unk_token_id]])
+
+     res = tokenizer.batch_decode(res_ids.tolist(), skip_special_tokens=True)[0]
+     res = res.replace("\n\n", "\n")
+     return res
  ```
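
A minimal usage sketch, assuming the classes and the `summarize` helper above are defined and the weights have been loaded; the file name `dialogue.txt` and `max_len=512` are placeholder choices:

```
# Hypothetical usage: any long Korean document or dialogue transcript will do.
with open("dialogue.txt", encoding="utf-8") as f:
    text = f.read()

summary = summarize(text, max_len=512)  # cap the generated summary at 512 tokens
print(summary)
```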