import re
from itertools import count, islice
from typing import Any, Iterable, Literal, Optional, TypedDict, TypeVar, Union, overload

from datasets import Features, Value, get_dataset_config_info
from datasets.features.features import FeatureType, _visit
from presidio_analyzer import AnalyzerEngine, BatchAnalyzerEngine, RecognizerResult

Row = dict[str, Any]
T = TypeVar("T")
BATCH_SIZE = 1

# Module-wide analyzer, created lazily on first use by `presidio_scan_entities`
# and reused by every later call.
batch_analyzer: Optional[BatchAnalyzerEngine] = None


class PresidioEntity(TypedDict):
    """One entity detected by Presidio in a cell of a scanned row."""

    text: str  # the detected span, masked with `mask`
    type: str  # the Presidio recognizer result's `entity_type`
    row_idx: int  # 0-based position of the row in the scanned `rows` iterable
    column_name: str  # column whose value contained the entity


@overload
def batched(it: Iterable[T], n: int) -> Iterable[list[T]]: ...


@overload
def batched(it: Iterable[T], n: int, with_indices: Literal[False]) -> Iterable[list[T]]: ...


@overload
def batched(it: Iterable[T], n: int, with_indices: Literal[True]) -> Iterable[tuple[list[int], list[T]]]: ...


def batched(
    it: Iterable[T], n: int, with_indices: bool = False
) -> Union[Iterable[list[T]], Iterable[tuple[list[int], list[T]]]]:
    """Yield successive lists of up to ``n`` items taken from ``it``.

    The last batch may be shorter than ``n``. When ``with_indices`` is true, each
    batch is yielded as ``(indices, batch)`` where ``indices`` are the 0-based
    positions of the batch's items in ``it``.
    """
    it, indices = iter(it), count()
    while batch := list(islice(it, n)):
        yield (list(islice(indices, len(batch))), batch) if with_indices else batch


def mask(text: str) -> str:
    """Hide most of each space-separated word in ``text``.

    Each word keeps its first ``min(2, len(word) - 1)`` characters verbatim. In the
    remainder, every ASCII letter or digit becomes ``*`` and other characters are
    kept. Words are split and re-joined on single spaces.
    """
    return " ".join(
        word[: min(2, len(word) - 1)] + re.sub("[A-Za-z0-9]", "*", word[min(2, len(word) - 1) :])
        for word in text.split(" ")
    )


def get_strings(row_content: Any) -> str:
    """Collect the string content of a (possibly nested) cell value.

    Strings are returned as-is. Lists and dict values are visited recursively, and
    their non-empty string parts are joined with newlines. A dict with a ``"src"``
    key (could be image or audio) contributes nothing. Any other type yields ``""``.
    """
    if isinstance(row_content, str):
        return row_content
    if isinstance(row_content, dict):
        if "src" in row_content:
            return ""  # could be image or audio
        row_content = list(row_content.values())
    if isinstance(row_content, list):
        str_items = (get_strings(row_content_item) for row_content_item in row_content)
        return "\n".join(str_item for str_item in str_items if str_item)
    return ""


def _simple_analyze_iterator_cache(
    batch_analyzer: BatchAnalyzerEngine,
    texts: Iterable[str],
    language: str,
    score_threshold: float,
    cache: dict[str, list[RecognizerResult]],
) -> list[list[RecognizerResult]]:
    """Analyze ``texts`` with ``batch_analyzer``, reusing results found in ``cache``.

    Only texts absent from ``cache`` are sent to the analyzer. Afterwards ``cache`` is
    replaced by the results of this call only, so it remembers just the latest batch.

    Returns:
        One list of recognizer results per input text, in input order.
    """
    # `texts` is iterated several times below; materialize it so a one-shot
    # iterator is not silently exhausted by the first pass.
    texts = list(texts)
    not_cached_results = iter(
        batch_analyzer.analyze_iterator(
            (text for text in texts if text not in cache), language=language, score_threshold=score_threshold
        )
    )
    results = [cache[text] if text in cache else next(not_cached_results) for text in texts]
    # cache the last results
    cache.clear()
    cache.update(zip(texts, results))
    return results


def analyze(
    batch_analyzer: BatchAnalyzerEngine,
    batch: list[dict[str, str]],
    indices: Iterable[int],
    scanned_columns: list[str],
    columns_descriptions: list[str],
    cache: Optional[dict[str, list[RecognizerResult]]] = None,
) -> list[PresidioEntity]:
    """Detect Presidio entities in every (row, scanned column) cell of ``batch``.

    Each cell value is prefixed with ``"The following is <description> data:\\n\\n"``
    (presumably to give the recognizers context) before analysis. Entities starting
    inside that prefix are discarded.

    Args:
        batch_analyzer: Presidio batch engine used for the analysis.
        batch: rows restricted to the scanned columns, with string values.
        indices: row index to report for each row of ``batch``, in the same order.
        scanned_columns: columns to analyze, paired positionally with
            ``columns_descriptions``.
        columns_descriptions: human-readable description of each scanned column.
        cache: optional text -> results cache shared across calls.

    Returns:
        The detected entities, with masked text, row index and column name.
    """
    cache = {} if cache is None else cache
    # (column name, prefix) pairs; zip keeps the original truncation semantics.
    columns = [
        (column_name, f"The following is {columns_description} data:\n\n")
        for column_name, columns_description in zip(scanned_columns, columns_descriptions)
    ]
    # Texts are flattened row-major: all columns of row 0, then row 1, ...
    texts = [f"{prefix}{example[column_name] or ''}" for example in batch for column_name, prefix in columns]
    results = _simple_analyze_iterator_cache(batch_analyzer, texts, language="en", score_threshold=0.8, cache=cache)

    entities: list[PresidioEntity] = []
    # Regroup the flat (text, results) pairs into one chunk per row so each result
    # is attributed to the right row AND the right column.
    for row_idx, row_cells in zip(indices, batched(zip(texts, results), len(columns))):
        for (column_name, prefix), (text, recognizer_results) in zip(columns, row_cells):
            for recognizer_result in recognizer_results:
                # Skip entities detected in the synthetic prefix rather than the value.
                if recognizer_result.start >= len(prefix):
                    entities.append(
                        PresidioEntity(
                            text=mask(text[recognizer_result.start : recognizer_result.end]),
                            type=recognizer_result.entity_type,
                            row_idx=row_idx,
                            column_name=column_name,
                        )
                    )
    return entities


def presidio_scan_entities(
    rows: Iterable[Row], scanned_columns: list[str], columns_descriptions: list[str]
) -> Iterable[PresidioEntity]:
    """Lazily scan ``rows`` for Presidio entities in ``scanned_columns``.

    The module-level ``batch_analyzer`` is created on first use and reused afterwards.
    Each scanned cell is flattened to text with `get_strings`. Rows are analyzed in
    batches of ``BATCH_SIZE``, and a results cache is shared between consecutive
    batches of this scan.

    Yields:
        Detected entities; ``row_idx`` is the position of the row in ``rows``.
    """
    global batch_analyzer
    cache: dict[str, list[RecognizerResult]] = {}
    if batch_analyzer is None:
        # Assign the module-level engine so it is built once and reused.
        batch_analyzer = BatchAnalyzerEngine(AnalyzerEngine())
    analyzer = batch_analyzer
    rows_with_scanned_columns_only = (
        {column_name: get_strings(row[column_name]) for column_name in scanned_columns} for row in rows
    )
    for indices, batch in batched(rows_with_scanned_columns_only, BATCH_SIZE, with_indices=True):
        yield from analyze(
            batch_analyzer=analyzer,
            batch=batch,
            indices=indices,
            scanned_columns=scanned_columns,
            columns_descriptions=columns_descriptions,
            cache=cache,
        )


def get_columns_with_strings(features: Features) -> list[str]:
    """Return the names of columns whose feature contains a ``Value("string")``.

    Nested features are inspected via ``datasets``' ``_visit`` helper.
    """
    columns_with_strings: list[str] = []
    for column, feature in features.items():
        str_column = str(column)
        with_string = False

        def classify(feature: FeatureType) -> None:
            nonlocal with_string
            if isinstance(feature, Value) and feature.dtype == "string":
                with_string = True

        _visit(feature, classify)
        if with_string:
            columns_with_strings.append(str_column)
    return columns_with_strings


def get_column_description(column_name: str, feature: FeatureType) -> str:
    """Describe a column as ``"<name> (with <field>, ...)"`` when it has nested fields.

    The field names are the keys of every dict feature visited by ``_visit`` within
    ``feature``. Without any, the bare column name is returned.
    """
    nested_fields: list[str] = []

    def get_nested_field_names(feature: FeatureType) -> None:
        if isinstance(feature, dict):
            nested_fields.extend(feature)

    _visit(feature, get_nested_field_names)
    return f"{column_name} (with {', '.join(nested_fields)})" if nested_fields else column_name