File size: 11,785 Bytes
48847bb
 
9cad313
 
 
48847bb
 
 
 
ed03b25
9cad313
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed03b25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9cad313
 
ed03b25
 
 
 
 
48847bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed03b25
 
 
 
48847bb
 
ed03b25
 
48847bb
 
ed03b25
48847bb
 
 
 
 
 
ed03b25
48847bb
 
 
 
 
 
 
 
9cad313
 
 
 
 
 
 
 
 
 
 
 
 
48847bb
 
 
 
 
 
 
 
 
ed03b25
 
48847bb
 
 
 
ed03b25
 
 
 
 
 
 
 
 
 
 
 
48847bb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
import functools
import re

import gradio as gr
import requests
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoModelForSeq2SeqLM, T5Config

MAX_SOURCE_LENGTH = 512


class ReviewerModel(T5ForConditionalGeneration):
    """T5 encoder-decoder with the extra heads used by microsoft/codereviewer.

    On top of the stock ``T5ForConditionalGeneration`` this adds:

    * ``cls_head``: a 2-way linear classifier applied to the encoder's
      first-position hidden state (see :meth:`cls`);
    * :meth:`review_forward`: a path that adds an encoder-side
      token-prediction loss to the decoder LM loss.

    Which path runs is chosen by the keyword arguments given to :meth:`forward`.
    """

    def __init__(self, config):
        super().__init__(config)
        self.cls_head = nn.Linear(self.config.d_model, 2, bias=True)
        self.init()

    def init(self):
        """Initialise ``lm_head`` and ``cls_head`` weights.

        ``lm_head`` gets Xavier-uniform weights. ``cls_head`` gets normal
        weights with std ``initializer_factor * d_model ** -0.5`` and a zero
        bias. Weights loaded later by ``from_pretrained`` presumably replace
        these for any parameter present in the checkpoint.
        """
        nn.init.xavier_uniform_(self.lm_head.weight)
        factor = self.config.initializer_factor
        self.cls_head.weight.data.normal_(mean=0.0, std=factor * (self.config.d_model ** -0.5))
        self.cls_head.bias.data.zero_()

    def forward(self, *argv, **kwargs):
        """Route the call to one of three forward passes based on ``kwargs``.

        * ``cls`` key present (its value is not inspected): classification via
          :meth:`cls`. Requires ``input_ids``, ``labels`` and
          ``attention_mask``; every other keyword argument is ignored.
        * ``input_labels`` key present: the review objective via
          :meth:`review_forward`. Requires ``input_ids``, ``input_labels``,
          ``decoder_input_ids``, ``attention_mask`` and
          ``decoder_attention_mask``; ``encoder_loss`` is optional (default
          ``True``).
        * otherwise: the unmodified ``T5ForConditionalGeneration.forward``.

        Raises:
            AssertionError: if a required keyword argument is missing.
        """
        if "cls" in kwargs:
            assert (
                "input_ids" in kwargs
                and "labels" in kwargs
                and "attention_mask" in kwargs
            )
            return self.cls(
                input_ids=kwargs["input_ids"],
                labels=kwargs["labels"],
                attention_mask=kwargs["attention_mask"],
            )
        if "input_labels" in kwargs:
            assert (
                "input_ids" in kwargs
                and "decoder_input_ids" in kwargs
                and "attention_mask" in kwargs
                and "decoder_attention_mask" in kwargs
            ), "Please give these arg keys."
            return self.review_forward(
                kwargs["input_ids"],
                kwargs["input_labels"],
                kwargs["decoder_input_ids"],
                kwargs["attention_mask"],
                kwargs["decoder_attention_mask"],
                kwargs.get("encoder_loss", True),
            )
        return super().forward(*argv, **kwargs)

    def cls(
            self,
            input_ids,
            labels,
            attention_mask,
    ):
        """Classify each input from the encoder's first-position hidden state.

        Returns:
            The cross-entropy loss when ``labels`` is not None, otherwise the
            raw ``cls_head`` logits (one row of 2 scores per batch item).
        """
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=False,
            return_dict=False,
        )
        hidden_states = encoder_outputs[0]
        first_hidden = hidden_states[:, 0, :]
        # Honour train/eval mode. An nn.Dropout module constructed inline is
        # always in training mode, so it would keep dropping activations (and
        # make predictions random) even after model.eval().
        first_hidden = nn.functional.dropout(first_hidden, p=0.3, training=self.training)
        logits = self.cls_head(first_hidden)
        if labels is not None:
            return CrossEntropyLoss()(logits, labels)
        return logits

    def review_forward(
            self,
            input_ids,
            input_labels,
            decoder_input_ids,
            attention_mask,
            decoder_attention_mask,
            encoder_loss=True
    ):
        """Compute the review objective, or its raw logits.

        Args:
            input_ids: encoder input ids.
            input_labels: targets for the encoder token-prediction loss
                (positions equal to -100 are ignored); may be None.
            decoder_input_ids: target ids. They are passed through
                ``_shift_right`` to form the decoder input and used unshifted
                as LM targets (id 0 is ignored by the LM loss).
            attention_mask: encoder attention mask.
            decoder_attention_mask: decoder attention mask.
            encoder_loss: whether to compute the encoder token-prediction
                logits (and add their loss).

        Returns:
            The summed loss when ``decoder_input_ids`` is not None, otherwise
            ``(cls_logits, lm_logits)``; ``cls_logits`` is None when
            ``encoder_loss`` is False.
        """
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=False,
            return_dict=False,
        )
        hidden_states = encoder_outputs[0]
        # NOTE(review): _shift_right runs before the `is not None` check below,
        # so decoder_input_ids=None likely fails here -- confirm whether the
        # logits-only return path is ever reachable.
        decoder_inputs = self._shift_right(decoder_input_ids)
        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_inputs,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            output_attentions=False,
            return_dict=False,
        )
        sequence_output = decoder_outputs[0]
        if self.config.tie_word_embeddings:  # this is True by default
            # Presumably mirrors the rescaling T5ForConditionalGeneration
            # applies before a weight-tied output projection.
            sequence_output = sequence_output * (self.model_dim ** -0.5)
        # Initialised up front so the logits return path never hits an unbound
        # name when encoder_loss is False.
        cls_logits = None
        if encoder_loss:
            # Vocabulary-sized logits for every encoder position, scored
            # against the encoder's input embedding matrix.
            cls_logits = nn.functional.linear(hidden_states, self.encoder.get_input_embeddings().weight)
        lm_logits = self.lm_head(sequence_output)
        if decoder_input_ids is not None:
            lm_loss_fct = CrossEntropyLoss(ignore_index=0)  # Warning: PAD_ID should be 0
            loss = lm_loss_fct(lm_logits.view(-1, lm_logits.size(-1)), decoder_input_ids.view(-1))
            if encoder_loss and input_labels is not None:
                cls_loss_fct = CrossEntropyLoss(ignore_index=-100)
                loss = loss + cls_loss_fct(cls_logits.view(-1, cls_logits.size(-1)), input_labels.view(-1))
            return loss
        return cls_logits, lm_logits


@functools.lru_cache(maxsize=1)
def prepare_models():
    """Load the ``microsoft/codereviewer`` tokenizer and model.

    The result is cached, so the slow load happens once per process instead
    of on every ``review_commit`` request; every call returns the same
    ``(tokenizer, model)`` pair. A failed load is not cached and is retried
    on the next call.

    The tokenizer is annotated with the ids of CodeReviewer's special tokens
    (``special_dict`` for ``<e99>`` .. ``<e0>``, plus ``mask_id``, ``bos_id``,
    ``pad_id``, ``eos_id``, ``msg_id``, ``keep_id``, ``add_id``, ``del_id``,
    ``start_id`` and ``end_id``). The model is put in eval mode.

    Returns:
        tuple: ``(tokenizer, model)``.

    Raises:
        KeyError: if one of the expected special tokens is missing from the
            vocabulary.
    """
    tokenizer = AutoTokenizer.from_pretrained("microsoft/codereviewer")
    # Fetch the vocabulary mapping once rather than once per lookup below.
    vocab = tokenizer.get_vocab()

    tokenizer.special_dict = {
        f"<e{i}>": vocab[f"<e{i}>"] for i in range(99, -1, -1)
    }
    tokenizer.mask_id = vocab["<mask>"]
    tokenizer.bos_id = vocab["<s>"]
    tokenizer.pad_id = vocab["<pad>"]
    tokenizer.eos_id = vocab["</s>"]
    tokenizer.msg_id = vocab["<msg>"]
    tokenizer.keep_id = vocab["<keep>"]
    tokenizer.add_id = vocab["<add>"]
    tokenizer.del_id = vocab["<del>"]
    tokenizer.start_id = vocab["<start>"]
    tokenizer.end_id = vocab["<end>"]

    config = T5Config.from_pretrained("microsoft/codereviewer")
    model = ReviewerModel.from_pretrained("microsoft/codereviewer", config=config)

    model.eval()
    return tokenizer, model


def pad_assert(tokenizer, source_ids):
    """Frame ``source_ids`` with BOS/EOS and right-pad to ``MAX_SOURCE_LENGTH``.

    At most ``MAX_SOURCE_LENGTH - 2`` of the given ids are kept so that both
    framing ids fit. The rest of the window is filled with
    ``tokenizer.pad_id``. A new list is returned; the caller's list is not
    modified.
    """
    room_for_body = MAX_SOURCE_LENGTH - 2
    framed = [tokenizer.bos_id] + source_ids[:room_for_body] + [tokenizer.eos_id]
    padding = [tokenizer.pad_id] * (MAX_SOURCE_LENGTH - len(framed))
    padded = framed + padding
    assert len(padded) == MAX_SOURCE_LENGTH, "Not equal length."
    return padded


def encode_diff(tokenizer, diff, msg, source):
    """Encode one diff hunk, with its file source and commit message, as model ids.

    The model input string is
    ``"<s>" + source + "</s><msg>" + msg`` followed by one tagged entry per
    hunk line: ``<add>`` for "+" lines, ``<del>`` for "-" lines and ``<keep>``
    for anything else. The line's first character is removed and the rest is
    stripped.

    The following lines are skipped:

    * the first line of ``diff``;
    * blank lines;
    * ``@@`` hunk headers (``FileDiffs`` hunks start with a newline, so the
      header is not the first line and would otherwise leak into the input);
    * ``\\ No newline at end of file`` markers, which are diff metadata and
      not code.

    Args:
        tokenizer: tokenizer from ``prepare_models`` (needs ``bos_id``,
            ``eos_id`` and ``pad_id``).
        diff: text of a single hunk.
        msg: commit message.
        source: file source text.

    Returns:
        list: exactly ``MAX_SOURCE_LENGTH`` ids (see ``pad_assert``).
    """
    tag_for_prefix = {"-": "<del>", "+": "<add>"}
    pieces = ["<s>", source, "</s>", "<msg>", msg]
    for line in diff.split("\n")[1:]:  # first line: hunk start ("@@ ..." or "")
        if not line.strip():
            continue
        if line.startswith(("@@", "\\")):
            continue
        pieces.append(tag_for_prefix.get(line[0], "<keep>") + line[1:].strip())
    inputstr = "".join(pieces)
    # NOTE(review): source and msg come before the diff, so when they exceed
    # the token budget, truncation drops the diff -- the part actually under
    # review. Consider budgeting tokens for the diff first.
    # [1:-1] drops the first/last ids (the special tokens encode() adds by
    # default); pad_assert re-adds BOS/EOS.
    source_ids = tokenizer.encode(inputstr, max_length=MAX_SOURCE_LENGTH, truncation=True)[1:-1]
    return pad_assert(tokenizer, source_ids)


class FileDiffs(object):
    """One file's section of a unified ``git diff``, split into hunks.

    ``diff_string`` is the text that follows a ``diff --git`` marker, so its
    first line is the `` a/<old path> b/<new path>`` header.

    Attributes:
        file_name: the raw header line.
        file_path: the "a/" (pre-change) path parsed from the header.
        diffs: one string per hunk. Each is the ``@@`` hunk header followed
            by the hunk's lines, with every line (header included) preceded
            by a newline character. Files without hunks (e.g. binary files)
            get an empty list.
    """

    def __init__(self, diff_string):
        diff_array = diff_string.split("\n")
        self.file_name = diff_array[0]
        self.file_path = self._parse_path(self.file_name)
        self.diffs = list()
        for line in diff_array[1:]:
            if line.startswith("@@"):
                self.diffs.append(str())
            elif not self.diffs:
                # Still in the extended header (index, ---/+++, new file mode,
                # rename from/to, Binary files ...). Its length varies, so skip
                # to the first hunk instead of assuming exactly 4 lines.
                continue
            self.diffs[-1] += "\n" + line

    @staticmethod
    def _parse_path(header):
        """Return the "a/" path from a `` a/<path> b/<path>`` header line."""
        header = header.strip()
        # Unchanged path: "a/" + p + " b/" + p has length 2 * len(p) + 5, so p
        # can be sliced out unambiguously even if it contains "b/" or " b/".
        path_len = (len(header) - 5) // 2
        candidate = header[2:2 + path_len]
        if header == f"a/{candidate} b/{candidate}":
            return candidate
        # Renamed/copied file: the old path is everything before the first " b/".
        old_part = header.split(" b/", 1)[0]
        return old_part[2:] if old_part.startswith("a/") else old_part


def review_commit(user="p4vv37", repository="ueflow", commit="610a8c7b02b946bc9e5e26e6dacbba0e2abba259"):
    """Run CodeReviewer over every hunk of a GitHub commit and format the results.

    Args:
        user: GitHub account that owns the repository.
        repository: repository name.
        commit: commit SHA.

    Returns:
        str: for each changed file, a ``File:<path>`` line, followed by every
        hunk the classifier flags as needing review and its top generated
        comment, delimited by ``#######`` lines.

    Raises:
        requests.HTTPError: if the GitHub API returns an error status for the
            commit metadata or diff request.
        requests.Timeout: if a request exceeds the timeout.
    """
    request_timeout = 30  # seconds; without a timeout a stalled request hangs forever
    tokenizer, model = prepare_models()
    commit_url = f"https://api.github.com/repos/{user}/{repository}/commits/{commit}"

    # Get diff and commit metadata from GitHub API (same endpoint, different Accept).
    metadata_response = requests.get(commit_url, timeout=request_timeout)
    metadata_response.raise_for_status()
    msg = metadata_response.json()["commit"]["message"]
    diff_response = requests.get(commit_url,
                                 headers={"Accept": "application/vnd.github.diff"},
                                 timeout=request_timeout)
    diff_response.raise_for_status()
    code_diff = diff_response.text

    # Parse diff into FileDiffs objects. Split only at line starts, so that
    # "diff --git" text inside changed lines is not mistaken for a file header.
    files_diffs = [
        FileDiffs(section)
        for section in re.split(r"^diff --git", code_diff, flags=re.MULTILINE)
        if section
    ]

    # Generate comments for each diff
    output_parts = []
    with torch.no_grad():  # inference only: no autograd graph needed
        for fd in files_diffs:
            output_parts.append(f"File:{fd.file_path}\n")
            if not fd.diffs:
                continue  # e.g. binary or mode-only change: nothing to encode
            # NOTE(review): "^{commit}" is not standard git rev syntax (a
            # commit's parent is "{commit}^"), so this may return an error page
            # whose text becomes `source` -- confirm. Fetching the real file
            # could also crowd the diff out of encode_diff's token budget.
            source = requests.get(
                f"https://raw.githubusercontent.com/{user}/{repository}/^{commit}/{fd.file_path}",
                timeout=request_timeout,
            ).text

            for diff in fd.diffs:
                inputs = torch.tensor([encode_diff(tokenizer, diff, msg, source)], dtype=torch.long).to("cpu")
                inputs_mask = inputs.ne(tokenizer.pad_id)
                # Classification pass; only input_ids/attention_mask/labels are used.
                logits = model(
                    input_ids=inputs,
                    cls=True,
                    attention_mask=inputs_mask,
                    labels=None,
                )
                needs_review = torch.argmax(logits, dim=-1)[0].item()
                if not needs_review:
                    continue
                preds = model.generate(inputs,
                                       attention_mask=inputs_mask,
                                       use_cache=True,
                                       num_beams=5,
                                       early_stopping=True,
                                       max_length=100,
                                       num_return_sequences=2
                                       )
                # Only the best beam is reported. Its first two ids are dropped
                # (presumably decoder-start plus a leading marker -- confirm).
                comment = tokenizer.decode(preds[0][2:].tolist(), skip_special_tokens=True,
                                           clean_up_tokenization_spaces=False)
                output_parts.append(diff + "\n#######\nComment:\n#######\n" + comment + "\n#######\n")
    return "".join(output_parts)


description = (
    "An interface for running "
    "\"Microsoft CodeBERT CodeReviewer: Pre-Training for Automating Code Review Activities.\" "
    "(microsoft/codereviewer) on GitHub commits."
)
# Each example is [user, repository, commit], matching review_commit's parameters.
examples = [
    ["p4vv37", "ueflow", "610a8c7b02b946bc9e5e26e6dacbba0e2abba259"],
    ["microsoft", "vscode", "378b0d711f6b82ac59b47fb246906043a6fb995a"],
]
iface = gr.Interface(fn=review_commit,
                     description=description,
                     inputs=["text", "text", "text"],
                     outputs="text",
                     examples=examples)

# Start the server only when run as a script, so importing this module (e.g.
# from tests or another app) does not launch a web server as a side effect.
if __name__ == "__main__":
    iface.launch()