|
import re |
|
from itertools import count, islice |
|
from typing import Any, Iterable, Literal, Optional, TypedDict, TypeVar, Union, overload |
|
|
|
from datasets import Features, Value, get_dataset_config_info |
|
from datasets.features.features import FeatureType, _visit |
|
from presidio_analyzer import AnalyzerEngine, BatchAnalyzerEngine, RecognizerResult |
|
|
|
|
|
# Generic element type for the ``batched`` helper.
T = TypeVar("T")
# A dataset row: column name -> cell value.
Row = dict[str, Any]

# How many rows ``presidio_scan_entities`` groups into each ``analyze`` call.
BATCH_SIZE = 1

# Module-level analyzer engine slot; ``presidio_scan_entities`` checks it for
# None before constructing an engine.
batch_analyzer: Optional[BatchAnalyzerEngine] = None
|
|
|
|
|
class PresidioEntity(TypedDict):
    """One detected entity, as built by ``analyze``.

    At runtime this is a plain ``dict``; all four keys are required.
    """

    # Matched substring of the analyzed text, after passing through ``mask``.
    text: str
    # Entity label reported by the analyzer (``RecognizerResult.entity_type``).
    type: str
    # Position of the source row in the scanned row iterable.
    row_idx: int
    # Name of the column the match was found in.
    column_name: str
|
|
|
|
|
@overload
def batched(it: Iterable[T], n: int) -> Iterable[list[T]]:
    ...


@overload
def batched(it: Iterable[T], n: int, with_indices: Literal[False]) -> Iterable[list[T]]:
    ...


@overload
def batched(it: Iterable[T], n: int, with_indices: Literal[True]) -> Iterable[tuple[list[int], list[T]]]:
    ...


def batched(
    it: Iterable[T], n: int, with_indices: bool = False
) -> Union[Iterable[list[T]], Iterable[tuple[list[int], list[T]]]]:
    """Split ``it`` into consecutive lists of at most ``n`` items.

    Args:
        it: Any iterable; it is consumed lazily, one batch at a time.
        n: Maximum batch size; must be at least 1.
        with_indices: When True, yield ``(positions, items)`` pairs, where
            ``positions`` are the 0-based positions of ``items`` within ``it``.

    Yields:
        ``list[T]`` batches, or ``(list[int], list[T])`` pairs when
        ``with_indices`` is True. The final batch may be shorter than ``n``.

    Raises:
        ValueError: If ``n < 1``. Because this is a generator, the error is
            raised when iteration starts, not at call time.
    """
    if n < 1:
        # islice(it, 0) would end the loop at once and silently drop every item;
        # a negative n would surface as an unrelated islice error.
        raise ValueError(f"batch size must be at least 1, got {n}")
    iterator = iter(it)
    start = 0
    while batch := list(islice(iterator, n)):
        if with_indices:
            yield list(range(start, start + len(batch))), batch
        else:
            yield batch
        start += len(batch)
|
|
|
|
|
def mask(text: str) -> str:
    """Return the text to store for a detected entity.

    Currently this is the identity: ``text`` is returned unchanged.

    NOTE(review): despite the name, nothing is masked here, and the module's
    ``re`` import is otherwise unused in the visible code, which suggests a
    masking implementation may have been removed. Confirm whether raw matches
    are really meant to be stored.
    """
    return text
|
|
|
|
|
|
|
|
|
|
|
|
|
def get_strings(row_content: Any) -> str:
    """Flatten a cell value into newline-separated text.

    - ``str``: returned as-is.
    - ``dict`` with a ``"src"`` key: ``""`` (presumably a media object such as
      audio/image; TODO confirm).
    - other ``dict``: flattened through its values, like a list.
    - ``list``: each item is flattened recursively; empty pieces are dropped and
      the rest joined with ``"\\n"``.
    - anything else (numbers, ``None``, tuples, ...): ``""``.
    """
    if isinstance(row_content, dict):
        if "src" in row_content:
            return ""
        row_content = list(row_content.values())
    if isinstance(row_content, list):
        pieces = [get_strings(item) for item in row_content]
        return "\n".join(filter(None, pieces))
    return row_content if isinstance(row_content, str) else ""
|
|
|
|
|
def _simple_analyze_iterator_cache(
    batch_analyzer: BatchAnalyzerEngine,
    texts: Iterable[str],
    language: str,
    score_threshold: float,
    cache: dict[str, list[RecognizerResult]],
) -> list[list[RecognizerResult]]:
    """Analyze ``texts`` with ``batch_analyzer``, reusing results held in ``cache``.

    Texts already in ``cache`` are not re-analyzed, and each distinct remaining
    text is sent to ``batch_analyzer.analyze_iterator`` only once. Afterwards
    ``cache`` is overwritten in place with the results for exactly this call's
    texts, so it only remembers the most recent batch.

    Args:
        batch_analyzer: Engine whose ``analyze_iterator`` is expected to return
            one result list per input text, in order.
        texts: Texts to analyze. Any iterable is accepted; it is materialized once.
        language: Forwarded to ``analyze_iterator``.
        score_threshold: Forwarded to ``analyze_iterator``.
        cache: Mapping text -> results; read first, then replaced.

    Returns:
        One list of recognizer results per input text, in input order.
    """
    # Materialize first: texts are walked several times below, and a one-shot
    # iterator would be exhausted after the first pass (yielding no results).
    texts = list(texts)
    # Order-preserving de-duplication so repeated uncached texts are analyzed once.
    to_analyze = list(dict.fromkeys(text for text in texts if text not in cache))
    analyzed = batch_analyzer.analyze_iterator(to_analyze, language=language, score_threshold=score_threshold)
    fresh = dict(zip(to_analyze, analyzed))
    results = [cache[text] if text in cache else fresh[text] for text in texts]

    cache.clear()
    cache.update(zip(texts, results))
    return results
|
|
|
|
|
def analyze(
    batch_analyzer: BatchAnalyzerEngine,
    batch: list[dict[str, str]],
    indices: Iterable[int],
    scanned_columns: list[str],
    columns_descriptions: list[str],
    cache: Optional[dict[str, list[RecognizerResult]]] = None,
) -> list[PresidioEntity]:
    """Run the analyzer on every scanned cell of ``batch`` and collect the entities.

    Each cell is prefixed with a ``The following is <description> data:``
    preamble to give the analyzer context. Matches that start inside that
    preamble are discarded.

    Args:
        batch_analyzer: Engine passed to ``_simple_analyze_iterator_cache``.
        batch: Rows mapping column name -> text; falsy cell values become ``""``.
        indices: Row index to report for each row of ``batch``, in the same order.
        scanned_columns: Columns to analyze, paired positionally with
            ``columns_descriptions`` (the shorter list bounds the columns used).
        columns_descriptions: Description inserted into each column's preamble.
        cache: Optional text -> results cache shared across calls.

    Returns:
        One ``PresidioEntity`` per kept recognizer result, ordered by row, then
        column, then the analyzer's result order.
    """
    cache = {} if cache is None else cache
    columns = [
        (column_name, f"The following is {columns_description} data:\n\n")
        for column_name, columns_description in zip(scanned_columns, columns_descriptions)
    ]
    # Row-major layout: the text for (row r, column c) is texts[r * len(columns) + c].
    texts = [f"{prefix}{example[column_name] or ''}" for example in batch for column_name, prefix in columns]
    results = _simple_analyze_iterator_cache(batch_analyzer, texts, language="en", score_threshold=0.8, cache=cache)

    entities: list[PresidioEntity] = []
    for row_position, row_idx in zip(range(len(batch)), indices):
        for column_position, (column_name, prefix) in enumerate(columns):
            text_position = row_position * len(columns) + column_position
            text = texts[text_position]
            # A single cell can produce several matches; keep every one of them.
            for recognizer_result in results[text_position]:
                # Matches inside the synthetic preamble come from the column
                # description, not from the data itself.
                if recognizer_result.start < len(prefix):
                    continue
                entities.append(
                    PresidioEntity(
                        text=mask(text[recognizer_result.start : recognizer_result.end]),
                        type=recognizer_result.entity_type,
                        row_idx=row_idx,
                        column_name=column_name,
                    )
                )
    return entities
|
|
|
|
|
def presidio_scan_entities(
    rows: Iterable[Row], scanned_columns: list[str], columns_descriptions: list[str]
) -> Iterable[PresidioEntity]:
    """Lazily scan ``rows`` for entities in ``scanned_columns``.

    Each row is reduced to its scanned columns (flattened to text with
    ``get_strings``). Rows are grouped into batches of ``BATCH_SIZE`` and passed
    to ``analyze``, with one text cache shared across batches.

    The analyzer engine is stored in the module-level ``batch_analyzer`` the
    first time it is needed and reused afterwards.

    Yields:
        ``PresidioEntity`` dicts whose ``row_idx`` counts rows of ``rows`` from 0.
    """
    global batch_analyzer
    if batch_analyzer is None:
        # Store the engine in the module-level global so later scans reuse it
        # instead of constructing a new engine on every call.
        batch_analyzer = BatchAnalyzerEngine(AnalyzerEngine())
    cache: dict[str, list[RecognizerResult]] = {}
    rows_with_scanned_columns_only = (
        {column_name: get_strings(row[column_name]) for column_name in scanned_columns} for row in rows
    )
    for indices, batch in batched(rows_with_scanned_columns_only, BATCH_SIZE, with_indices=True):
        yield from analyze(
            batch_analyzer=batch_analyzer,
            batch=batch,
            indices=indices,
            scanned_columns=scanned_columns,
            columns_descriptions=columns_descriptions,
            cache=cache,
        )
|
|
|
|
|
def get_columns_with_strings(features: Features) -> list[str]:
    """Return the names (as ``str``) of columns whose feature contains a string ``Value``.

    Each column's feature is walked with ``datasets``' ``_visit`` helper. A
    column is kept as soon as any visited sub-feature is ``Value`` with dtype
    ``"string"``. Column order follows ``features``.
    """
    string_columns: list[str] = []
    for column_name, column_feature in features.items():
        string_hits: list[FeatureType] = []

        def _record_string(sub_feature: FeatureType) -> None:
            # Must return None so _visit leaves the feature unchanged.
            if isinstance(sub_feature, Value) and sub_feature.dtype == "string":
                string_hits.append(sub_feature)

        _visit(column_feature, _record_string)
        if string_hits:
            string_columns.append(str(column_name))
    return string_columns
|
|
|
|
|
def get_column_description(column_name: str, feature: FeatureType) -> str:
    """Describe a column for the analyzer preamble.

    Every dict sub-feature visited by ``_visit`` contributes its keys. If any
    keys were found, the result is ``"<column_name> (with <k1>, <k2>, ...)"``;
    otherwise it is just ``column_name``.
    """
    field_names: list[str] = []

    def _collect_dict_keys(sub_feature: FeatureType) -> None:
        # Must return None so _visit leaves the feature unchanged.
        if isinstance(sub_feature, dict):
            field_names.extend(sub_feature)

    _visit(feature, _collect_dict_keys)
    if not field_names:
        return column_name
    return f"{column_name} (with {', '.join(field_names)})"
|
|