"""
Managing Attack Logs.
========================
"""

from typing import Dict, Optional

from textattack.metrics.attack_metrics import (
    AttackQueries,
    AttackSuccessRate,
    WordsPerturbed,
)
from textattack.metrics.quality_metrics import Perplexity, USEMetric

from . import (
    CSVLogger,
    FileLogger,
    JsonSummaryLogger,
    VisdomLogger,
    WeightsAndBiasesLogger,
)
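
# A minimal usage sketch (``results`` stands in for a list of ``AttackResult``
# objects produced by an attack; nothing below is executed by this module):
#
#     manager = AttackLogManager(metrics=None)
#     manager.enable_stdout()
#     manager.add_output_csv("results.csv", color_method="file")
#     manager.log_results(results)
#     manager.flush()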


class AttackLogManager:
    """Logs the results of an attack to all attached loggers."""

    # metrics maps metric names (strings) to ``textattack.metrics.Metric`` objects
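    # (e.g. ``{"use_score": USEMetric()}`` — the key is a display name of your
    # choosing; any object with a ``calculate(results)`` method works)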
    metrics: Dict

    def __init__(self, metrics: Optional[Dict] = None):
        self.loggers = []
        self.results = []
        self.enable_advance_metrics = False
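        # When True, ``log_summary`` also reports perplexity and USE metrics.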
        if metrics is None:
            self.metrics = {}
        else:
            self.metrics = metrics

    def enable_stdout(self):
        """Attaches a logger that prints results to stdout."""
        self.loggers.append(FileLogger(stdout=True))

    def enable_visdom(self):
        """Attaches a logger that plots results to a Visdom server."""
        self.loggers.append(VisdomLogger())

    def enable_wandb(self, **kwargs):
        """Attaches a logger that reports results to Weights & Biases."""
        self.loggers.append(WeightsAndBiasesLogger(**kwargs))

    def disable_color(self):
        """Attaches a stdout logger that uses the plain-text ``file`` color
        method instead of terminal colors."""
        self.loggers.append(FileLogger(stdout=True, color_method="file"))

    def add_output_file(self, filename, color_method):
        """Attaches a logger that writes results to the text file ``filename``."""
        self.loggers.append(FileLogger(filename=filename, color_method=color_method))

    def add_output_csv(self, filename, color_method):
        """Attaches a logger that writes results to the CSV file ``filename``."""
        self.loggers.append(CSVLogger(filename=filename, color_method=color_method))

    def add_output_summary_json(self, filename):
        """Attaches a logger that writes the attack summary to the JSON file
        ``filename``."""
        self.loggers.append(JsonSummaryLogger(filename=filename))

    def log_result(self, result):
        """Logs an ``AttackResult`` on each of `self.loggers`."""
        self.results.append(result)
        for logger in self.loggers:
            logger.log_attack_result(result)

    def log_results(self, results):
        """Logs an iterable of ``AttackResult`` objects on each of
        `self.loggers`."""
        for result in results:
            self.log_result(result)
        self.log_summary()

    def log_summary_rows(self, rows, title, window_id):
        """Logs a table of summary ``rows`` on each of ``self.loggers``."""
        for logger in self.loggers:
            logger.log_summary_rows(rows, title, window_id)

    def log_sep(self):
        """Logs a separator on each of ``self.loggers``."""
        for logger in self.loggers:
            logger.log_sep()

    def flush(self):
        """Flushes each of ``self.loggers``."""
        for logger in self.loggers:
            logger.flush()

    def log_attack_details(self, attack_name, model_name):
        """Logs the attack algorithm and model name as a summary table."""
        # @TODO log a more complete set of attack details
        attack_detail_rows = [
            ["Attack algorithm:", attack_name],
            ["Model:", model_name],
        ]
        self.log_summary_rows(attack_detail_rows, "Attack Details", "attack_details")

    def log_summary(self):
        """Computes metrics over all logged results and logs a summary table,
        plus a histogram of words changed."""
        total_attacks = len(self.results)
        if total_attacks == 0:
            return

        # Default metrics - calculated on every attack
        attack_success_stats = AttackSuccessRate().calculate(self.results)
        words_perturbed_stats = WordsPerturbed().calculate(self.results)
        attack_query_stats = AttackQueries().calculate(self.results)

        # @TODO generate this table based on user input, with each row built by
        # its own metric class. Example to demonstrate:
        # summary_table_rows = attack_success_stats.display_row() + words_perturbed_stats.display_row() + ...
        summary_table_rows = [
            [
                "Number of successful attacks:",
                attack_success_stats["successful_attacks"],
            ],
            ["Number of failed attacks:", attack_success_stats["failed_attacks"]],
            ["Number of skipped attacks:", attack_success_stats["skipped_attacks"]],
            [
                "Original accuracy:",
                str(attack_success_stats["original_accuracy"]) + "%",
            ],
            [
                "Accuracy under attack:",
                str(attack_success_stats["attack_accuracy_perc"]) + "%",
            ],
            [
                "Attack success rate:",
                str(attack_success_stats["attack_success_rate"]) + "%",
            ],
            [
                "Average perturbed word %:",
                str(words_perturbed_stats["avg_word_perturbed_perc"]) + "%",
            ],
            [
                "Average num. words per input:",
                words_perturbed_stats["avg_word_perturbed"],
            ],
            [
                "Average num. queries:",
                attack_query_stats["avg_num_queries"],
            ],
        ]

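        # User-supplied metrics run after the built-in statistics; each row
        # holds whatever the metric's ``calculate`` method returns.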
        for metric_name, metric in self.metrics.items():
            summary_table_rows.append([metric_name, metric.calculate(self.results)])

        if self.enable_advance_metrics:
            perplexity_stats = Perplexity().calculate(self.results)
            use_stats = USEMetric().calculate(self.results)

            summary_table_rows.append(
                [
                    "Average Original Perplexity:",
                    perplexity_stats["avg_original_perplexity"],
                ]
            )

            summary_table_rows.append(
                [
                    "Average Attack Perplexity:",
                    perplexity_stats["avg_attack_perplexity"],
                ]
            )
            summary_table_rows.append(
                ["Average Attack USE Score:", use_stats["avg_attack_use_score"]]
            )

        self.log_summary_rows(
            summary_table_rows, "Attack Results", "attack_results_summary"
        )
        # Show histogram of words changed.
        numbins = max(words_perturbed_stats["max_words_changed"], 10)
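        # ``num_words_changed_until_success`` is indexed by number of words
        # changed, so truncate it to ``numbins`` buckets before plotting.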
        for logger in self.loggers:
            logger.log_hist(
                words_perturbed_stats["num_words_changed_until_success"][:numbins],
                numbins=numbins,
                title="Num Words Perturbed",
                window_id="num_words_perturbed",
            )