File size: 13,704 Bytes
1ef6bf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
import textattack
import transformers
import pandas as pd
import csv
import string
import pickle
# Construct our four components for `Attack`
from textattack.constraints.pre_transformation import (
    RepeatModification,
    StopwordModification,
)
from textattack.constraints.semantics import WordEmbeddingDistance
from textattack.transformations import WordSwapEmbedding
from textattack.search_methods import GreedyWordSwapWIR

import numpy as np
import json
import random
import re
import textattack.shared.attacked_text as atk
import torch.nn.functional as F
import torch


class InvertedText:

    def __init__(

        self,

        swapped_indexes,

        score,

        attacked_text,

        new_class,

    ):
        self.attacked_text = attacked_text
        self.swapped_indexes = (
            swapped_indexes  # dict of swapped indexes with their synonym
        )
        self.score = score  # value of original class
        self.new_class = new_class  # class after inversion

    def __repr__(self):
        return f"InvertedText:\n attacked_text='{self.attacked_text}', \n swapped_indexes={self.swapped_indexes},\n score={self.score}"


def count_matching_classes(original, corrected, perturbed_texts=None):
    if len(original) != len(corrected):
        raise ValueError("Arrays must have the same length")
    hard_samples = []
    easy_samples = []

    matching_count = 0

    for i in range(len(corrected)):
        if original[i] == corrected[i]:
            matching_count += 1
            easy_samples.append(perturbed_texts[i])
        elif perturbed_texts != None:
            hard_samples.append(perturbed_texts[i])

    return matching_count, hard_samples, easy_samples


class Flow_Corrector:
    def __init__(

        self,

        attack,

        word_rank_file="en_full_ranked.json",

        word_freq_file="en_full_freq.json",

        wir_threshold=0.3,

    ):
        self.attack = attack
        self.attack.cuda_()
        self.wir_threshold = wir_threshold
        with open(word_rank_file, "r") as f:
            self.word_ranked_frequence = json.load(f)
        with open(word_freq_file, "r") as f:
            self.word_frequence = json.load(f)
        self.victim_model = attack.goal_function.model

    def wir_gradient(

        self,

        attack,

        victim_model,

        detected_text,

    ):
        _, indices_to_order = attack.get_indices_to_order(detected_text)

        index_scores = np.zeros(len(indices_to_order))
        grad_output = victim_model.get_grad(detected_text.tokenizer_input)
        gradient = grad_output["gradient"]
        word2token_mapping = detected_text.align_with_model_tokens(victim_model)
        for i, index in enumerate(indices_to_order):
            matched_tokens = word2token_mapping[index]
            if not matched_tokens:
                index_scores[i] = 0.0
            else:
                agg_grad = np.mean(gradient[matched_tokens], axis=0)
                index_scores[i] = np.linalg.norm(agg_grad, ord=1)
        index_order = np.array(indices_to_order)[(-index_scores).argsort()]
        return index_order

    def get_syn_freq_dict(

        self,

        index_order,

        detected_text,

    ):
        most_frequent_syn_dict = {}

        no_syn = []
        freq_thershold = len(self.word_ranked_frequence) / 10

        for idx in index_order:
            # get the synonyms of a specific index

            try:
                synonyms = [
                    attacked_text.words[idx]
                    for attacked_text in self.attack.get_transformations(
                        detected_text, detected_text, indices_to_modify=[idx]
                    )
                ]
                # getting synonyms that exists in dataset with thiere frequency rank
                ranked_synonyms = {
                    syn: self.word_ranked_frequence[syn]
                    for syn in synonyms
                    if syn in self.word_ranked_frequence.keys()
                    and self.word_ranked_frequence[syn] < freq_thershold
                    and self.word_ranked_frequence[detected_text.words[idx]]
                    > self.word_ranked_frequence[syn]
                }
                # selecting the M most frequent synonym

                if list(ranked_synonyms.keys()) != []:
                    most_frequent_syn_dict[idx] = list(ranked_synonyms.keys())
            except:
                # no synonyms avaialble in the dataset
                no_syn.append(idx)

        return most_frequent_syn_dict

    def build_candidates(

        self, detected_text, most_frequent_syn_dict: dict, max_attempt: int

    ):
        candidates = {}
        for _ in range(max_attempt):
            syn_dict = {}
            current_text = detected_text
            for index in most_frequent_syn_dict.keys():
                syn = random.choice(most_frequent_syn_dict[index])
                syn_dict[index] = syn
                current_text = current_text.replace_word_at_index(index, syn)

            candidates[current_text] = syn_dict
        return candidates

    def find_dominant_class(self, inverted_texts):
        class_counts = {}  # Dictionary to store the count of each new class

        for text in inverted_texts:
            new_class = text.new_class
            class_counts[new_class] = class_counts.get(new_class, 0) + 1

        # Find the most dominant class
        most_dominant_class = max(class_counts, key=class_counts.get)

        return most_dominant_class

    def correct(self, detected_texts):
        corrected_classes = []
        for detected_text in detected_texts:

            # convert to Attacked texts
            detected_text = atk.AttackedText(detected_text)

            # getting 30% most important indexes
            index_order = self.wir_gradient(
                self.attack, self.victim_model, detected_text
            )
            index_order = index_order[: int(len(index_order) * self.wir_threshold)]

            # getting synonyms according to frequency conditiontions
            most_frequent_syn_dict = self.get_syn_freq_dict(index_order, detected_text)

            # generate M candidates
            candidates = self.build_candidates(
                detected_text, most_frequent_syn_dict, max_attempt=100
            )

            original_probs = F.softmax(self.victim_model(detected_text.text), dim=1)
            original_class = torch.argmax(original_probs).item()
            original_golden_prob = float(original_probs[0][original_class])

            nbr_inverted = 0
            inverted_texts = []  # a dictionary of inverted texts with
            bad, impr = 0, 0
            dict_deltas = {}

            batch_inputs = [candidate.text for candidate in candidates.keys()]

            batch_outputs = self.victim_model(batch_inputs)

            probabilities = F.softmax(batch_outputs, dim=1)
            for i, (candidate, syn_dict) in enumerate(candidates.items()):

                corrected_class = torch.argmax(probabilities[i]).item()
                new_golden_probability = float(probabilities[i][corrected_class])
                if corrected_class != original_class:
                    nbr_inverted += 1
                    inverted_texts.append(
                        InvertedText(
                            syn_dict, new_golden_probability, candidate, corrected_class
                        )
                    )
                else:
                    delta = new_golden_probability - original_golden_prob
                    if delta <= 0:
                        bad += 1
                    else:
                        impr += 1
                        dict_deltas[candidate] = delta

            if len(original_probs[0]) > 2 and len(inverted_texts) >= len(candidates) / (
                len(original_probs[0])
            ):
                # selecting the most dominant class
                dominant_class = self.find_dominant_class(inverted_texts)
            elif len(inverted_texts) >= len(candidates) / 2:
                dominant_class = corrected_class
            else:
                dominant_class = original_class

            corrected_classes.append(dominant_class)

        return corrected_classes


def remove_brackets(text):
    text = text.replace("[[", "")
    text = text.replace("]]", "")
    return text


def clean_text(text):
    pattern = "[" + re.escape(string.punctuation) + "]"
    cleaned_text = re.sub(pattern, " ", text)

    return cleaned_text


# Load model, tokenizer, and model_wrapper
model = transformers.AutoModelForSequenceClassification.from_pretrained(
    "textattack/bert-base-uncased-ag-news"
)
tokenizer = transformers.AutoTokenizer.from_pretrained(
    "textattack/bert-base-uncased-ag-news"
)
model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)


goal_function = textattack.goal_functions.UntargetedClassification(model_wrapper)
constraints = [
    RepeatModification(),
    StopwordModification(),
    WordEmbeddingDistance(min_cos_sim=0.9),
]
transformation = WordSwapEmbedding(max_candidates=50)
search_method = GreedyWordSwapWIR(wir_method="gradient")

# Construct the actual attack
attack = textattack.Attack(goal_function, constraints, transformation, search_method)
attack.cuda_()


results = pd.read_csv("ag_news_results.csv")
perturbed_texts = [
    results["perturbed_text"][i]
    for i in range(len(results))
    if results["result_type"][i] == "Successful"
]
original_texts = [
    results["original_text"][i]
    for i in range(len(results))
    if results["result_type"][i] == "Successful"
]

perturbed_texts = [remove_brackets(text) for text in perturbed_texts]
original_texts = [remove_brackets(text) for text in original_texts]

perturbed_texts = [clean_text(text) for text in perturbed_texts]
original_texts = [clean_text(text) for text in original_texts]


victim_model = attack.goal_function.model

print("Getting corrected classes")
print("This may take a while ...")
# we can use directly resultds in csv file
original_classes = [
    torch.argmax(F.softmax(victim_model(original_text), dim=1)).item()
    for original_text in original_texts
]

batch_size = 1000
num_batches = (len(perturbed_texts) + batch_size - 1) // batch_size
batched_perturbed_texts = []
batched_original_texts = []
batched_original_classes = []

for i in range(num_batches):
    start = i * batch_size
    end = min(start + batch_size, len(perturbed_texts))
    batched_perturbed_texts.append(perturbed_texts[start:end])
    batched_original_texts.append(original_texts[start:end])
    batched_original_classes.append(original_classes[start:end])
print(batched_original_classes)
hard_samples_list = []
easy_samples_list = []


# Open a CSV file for writing
csv_filename = "flow_correction_results_ag_news.csv"
with open(csv_filename, "w", newline="") as csvfile:
    fieldnames = ["freq_threshold", "batch_num", "match_perturbed", "match_original"]
    writer = csv.DictWriter(csvfile, fieldnames=fieldnames)

    # Write the header row
    writer.writeheader()

    # Iterate over batched lists
    batch_num = 0
    for perturbed, original, classes in zip(
        batched_perturbed_texts, batched_original_texts, batched_original_classes
    ):
        batch_num += 1
        print(f"Processing batch number: {batch_num}")

        for i in range(2):
            wir_threshold = 0.1 * (i + 1)
            print(f"Setting Word threshold to: {wir_threshold}")

            corrector = Flow_Corrector(
                attack,
                word_rank_file="en_full_ranked.json",
                word_freq_file="en_full_freq.json",
                wir_threshold=wir_threshold,
            )

            # Correct perturbed texts
            print("Correcting perturbed texts...")
            corrected_perturbed_classes = corrector.correct(perturbed)

            match_perturbed, hard_samples, easy_samples = count_matching_classes(
                classes, corrected_perturbed_classes, perturbed
            )
            hard_samples_list.extend(hard_samples)
            easy_samples_list.extend(easy_samples)


            print(f"Number of matching classes (perturbed): {match_perturbed}")

            # Correct original texts
            print("Correcting original texts...")
            corrected_original_classes = corrector.correct(original)
            match_original, hard_samples, easy_samples = count_matching_classes(
                classes, corrected_original_classes, perturbed
            )
            print(f"Number of matching classes (original): {match_original}")

            # Write results to CSV file
            print("Writing results to CSV file...")
            writer.writerow(
                {
                    "freq_threshold": wir_threshold,
                    "batch_num": batch_num,
                    "match_perturbed": match_perturbed/len(perturbed),
                    "match_original": match_original/len(perturbed),
                }
            )
            print("-" * 20)

print("savig samples for more statistics studies")

# Save hard_samples_list and easy_samples_list to files
with open('hard_samples.pkl', 'wb') as f:
    pickle.dump(hard_samples_list, f)

with open('easy_samples.pkl', 'wb') as f:
    pickle.dump(easy_samples_list, f)