File size: 5,413 Bytes
b551628
e6ef189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import re
from itertools import count, islice
from typing import Any, Iterable, Literal, Optional, TypedDict, TypeVar, Union, overload

from datasets import Features, Value, get_dataset_config_info
from datasets.features.features import FeatureType, _visit
from presidio_analyzer import AnalyzerEngine, BatchAnalyzerEngine, RecognizerResult


# A dataset row, addressed by column name (see `row[column_name]` in `presidio_scan_entities`).
Row = dict[str, Any]
# Element type for the generic `batched` helper.
T = TypeVar("T")
# Number of rows grouped into each batch by `presidio_scan_entities`.
BATCH_SIZE = 10
# Module-level slot for a shared analyzer engine; starts unset.
batch_analyzer: Optional[BatchAnalyzerEngine] = None


class PresidioEntity(TypedDict):
    """One detected entity, in the shape built by `analyze`."""

    # Matched substring after being passed through `mask`.
    text: str
    # The recognizer result's `entity_type`.
    type: str
    # Index of the row the entity is attributed to.
    row_idx: int
    # Name of the scanned column the entity is attributed to.
    column_name: str


@overload
def batched(it: Iterable[T], n: int) -> Iterable[list[T]]: ...


@overload
def batched(it: Iterable[T], n: int, with_indices: Literal[False]) -> Iterable[list[T]]: ...


@overload
def batched(it: Iterable[T], n: int, with_indices: Literal[True]) -> Iterable[tuple[list[int], list[T]]]: ...


def batched(
    it: Iterable[T], n: int, with_indices: bool = False
) -> Union[Iterable[list[T]], Iterable[tuple[list[int], list[T]]]]:
    """Lazily split ``it`` into consecutive lists of at most ``n`` items.

    Args:
        it: Any iterable; it is consumed once, front to back.
        n: Maximum number of items per batch.
        with_indices: When true, each batch is yielded as ``(indices, batch)``
            where ``indices`` are the 0-based positions of the batch items in
            ``it``.

    Yields:
        Either the batch lists, or ``(indices, batch)`` tuples. Iteration stops
        at the first empty batch, so nothing is yielded for an empty input.
    """
    source = iter(it)
    # Position in `it` of the first item of the next batch.
    offset = 0
    while True:
        chunk = list(islice(source, n))
        if not chunk:
            break
        if with_indices:
            end = offset + len(chunk)
            yield list(range(offset, end)), chunk
            offset = end
        else:
            yield chunk


def mask(text: str) -> str:
    """Obfuscate ``text`` word by word, keeping only a short visible prefix.

    The text is split on single spaces. Each word keeps up to its first two
    characters, but never its last one; every ASCII letter or digit after that
    prefix is replaced with ``*``. Other characters are left untouched.
    """

    def _mask_word(word: str) -> str:
        # For an empty word this is -1; both slices below are then empty.
        visible = min(2, len(word) - 1)
        return word[:visible] + re.sub("[A-Za-z0-9]", "*", word[visible:])

    return " ".join(map(_mask_word, text.split(" ")))


def get_strings(row_content: Any) -> str:
    """Flatten the strings nested in ``row_content`` into one string.

    A string is returned as is. A dict contributes its values and a list its
    items, recursively; the non-empty pieces are joined with newlines. Any
    other value (numbers, ``None``, tuples, ...) yields an empty string.
    """
    if isinstance(row_content, str):
        return row_content

    if isinstance(row_content, dict):
        children = list(row_content.values())
    elif isinstance(row_content, list):
        children = row_content
    else:
        return ""

    pieces = []
    for child in children:
        piece = get_strings(child)
        if piece:
            pieces.append(piece)
    return "\n".join(pieces)


def _simple_analyze_iterator_cache(
    batch_analyzer: BatchAnalyzerEngine,
    texts: Iterable[str],
    language: str,
    score_threshold: float,
    cache: dict[str, list[RecognizerResult]],
) -> list[list[RecognizerResult]]:
    """Analyze ``texts``, reusing results from the previous call held in ``cache``.

    Args:
        batch_analyzer: Engine whose ``analyze_iterator`` is called on the
            texts that are not in ``cache``.
        texts: Texts to analyze; any iterable, including a one-shot iterator.
        language: Forwarded to ``analyze_iterator``.
        score_threshold: Forwarded to ``analyze_iterator``.
        cache: Mapping from text to its results. It is mutated in place: on
            return it contains exactly the texts of this call.

    Returns:
        One list of recognizer results per input text, in input order.
    """
    # The texts are traversed three times below; materialize them so that a
    # one-shot iterator does not come back empty on the second pass.
    texts = list(texts)
    not_cached_texts = [text for text in texts if text not in cache]
    not_cached_results = iter(
        batch_analyzer.analyze_iterator(not_cached_texts, language=language, score_threshold=score_threshold)
    )
    # `cache` is not modified until after this line, so the texts that miss the
    # cache here are exactly `not_cached_texts`, in the same order.
    results = [cache[text] if text in cache else next(not_cached_results) for text in texts]
    # Keep only the latest results, so the cache never grows beyond one call.
    cache.clear()
    cache.update(zip(texts, results))
    return results


def analyze(
    batch_analyzer: BatchAnalyzerEngine,
    batch: list[dict[str, str]],
    indices: Iterable[int],
    scanned_columns: list[str],
    columns_descriptions: list[str],
    cache: Optional[dict[str, list[RecognizerResult]]] = None,
) -> list[PresidioEntity]:
    """Detect entities in the scanned columns of a batch of rows.

    One text is built per (row, column) pair: a short sentence introducing the
    column description, followed by the cell content. All texts are analyzed
    in English with a score threshold of 0.8.

    Args:
        batch_analyzer: Engine passed to `_simple_analyze_iterator_cache`.
        batch: Rows, each mapping a scanned column name to its string content.
        indices: Row index of each row in ``batch``, in the same order.
        scanned_columns: Names of the columns to scan.
        columns_descriptions: Description of each scanned column, in the same
            order as ``scanned_columns``.
        cache: Optional text-to-results cache shared between calls; a fresh
            one is used when omitted.

    Returns:
        One entity per recognizer result that starts inside the cell content,
        with its text masked, ordered by row, then column, then result.
    """
    cache = {} if cache is None else cache
    columns = list(zip(scanned_columns, columns_descriptions))
    prefixes = [f"The following is {columns_description} data:\n\n" for _, columns_description in columns]
    # Row-major layout: the text of row i, column j is at index i * len(columns) + j.
    texts = [
        f"{prefix}{example[column_name] or ''}"
        for example in batch
        for (column_name, _), prefix in zip(columns, prefixes)
    ]
    all_results = _simple_analyze_iterator_cache(
        batch_analyzer, texts, language="en", score_threshold=0.8, cache=cache
    )

    entities: list[PresidioEntity] = []
    for row_position, row_idx in zip(range(len(batch)), indices):
        for column_position, (column_name, _) in enumerate(columns):
            text_position = row_position * len(columns) + column_position
            text = texts[text_position]
            prefix_length = len(prefixes[column_position])
            for recognizer_result in all_results[text_position]:
                # Skip matches that start in the introductory sentence rather than the data.
                if recognizer_result.start < prefix_length:
                    continue
                entities.append(
                    PresidioEntity(
                        text=mask(text[recognizer_result.start : recognizer_result.end]),
                        type=recognizer_result.entity_type,
                        row_idx=row_idx,
                        column_name=column_name,
                    )
                )
    return entities


def presidio_scan_entities(
    rows: Iterable[Row], scanned_columns: list[str], columns_descriptions: list[str]
) -> Iterable[PresidioEntity]:
    """Lazily scan ``rows`` for entities in the given columns.

    Rows are reduced to the scanned columns (flattened to text with
    `get_strings`), grouped into batches of ``BATCH_SIZE`` and passed to
    `analyze` together with their 0-based positions in ``rows``.

    Args:
        rows: Rows to scan; each must contain every scanned column.
        scanned_columns: Names of the columns to scan.
        columns_descriptions: Description of each scanned column, in the same
            order as ``scanned_columns``.

    Yields:
        The entities found, batch after batch.
    """
    global batch_analyzer
    # Build the engine on first use and store it in the module-level slot so
    # that later calls reuse it instead of constructing a new one.
    if batch_analyzer is None:
        batch_analyzer = BatchAnalyzerEngine(AnalyzerEngine())
    engine = batch_analyzer
    # Shared across the batches of this scan; `analyze` keeps the previous batch's results in it.
    cache: dict[str, list[RecognizerResult]] = {}
    rows_with_scanned_columns_only = (
        {column_name: get_strings(row[column_name]) for column_name in scanned_columns} for row in rows
    )
    for indices, batch in batched(rows_with_scanned_columns_only, BATCH_SIZE, with_indices=True):
        yield from analyze(
            batch_analyzer=engine,
            batch=batch,
            indices=indices,
            scanned_columns=scanned_columns,
            columns_descriptions=columns_descriptions,
            cache=cache,
        )


def get_columns_with_strings(features: Features) -> list[str]:
    """List the columns whose feature contains at least one string value.

    Each column's feature is walked with `_visit`; the column is kept when any
    visited node is a `Value` whose dtype is ``"string"``. Column names are
    returned as ``str``, in the iteration order of ``features``.
    """

    def _contains_string(root: FeatureType) -> bool:
        found = False

        def _check(node: FeatureType) -> None:
            nonlocal found
            if isinstance(node, Value) and node.dtype == "string":
                found = True

        _visit(root, _check)
        return found

    return [str(column) for column, feature in features.items() if _contains_string(feature)]


def get_column_description(column_name: str, feature: FeatureType) -> str:
    """Describe a column by its name and the names of its nested fields.

    The feature is walked with `_visit`, and the keys of every dict node are
    collected in the order they are reached. With no such keys the result is
    just ``column_name``; otherwise it is ``"<column_name> (with <k1>, <k2>, ...)"``.
    """
    collected: list[str] = []

    def _collect_keys(node: FeatureType) -> None:
        if isinstance(node, dict):
            collected.extend(node)

    _visit(feature, _collect_keys)
    if not collected:
        return column_name
    return f"{column_name} (with {', '.join(collected)})"