lhoestq's picture
lhoestq HF staff
missing import
b551628
raw
history blame
5.41 kB
import re
from itertools import count, islice
from typing import Any, Iterable, Literal, Optional, TypedDict, TypeVar, Union, overload

from datasets import Features, Value, get_dataset_config_info
from datasets.features.features import FeatureType, _visit
from presidio_analyzer import AnalyzerEngine, BatchAnalyzerEngine, RecognizerResult
# A dataset row: maps a column name to that column's value for the row.
Row = dict[str, Any]
# Generic element type used in the signatures of `batched`.
T = TypeVar("T")
# Number of rows grouped into one batch by `presidio_scan_entities`.
BATCH_SIZE = 10
# Module-level analyzer slot declared `global` and checked for None in
# `presidio_scan_entities`; starts out unset.
batch_analyzer: Optional[BatchAnalyzerEngine] = None
class PresidioEntity(TypedDict):
    """One entity reported by `analyze` for a single cell of a scanned row."""

    # Masked excerpt of the analyzed text covered by the recognizer result.
    text: str
    # Entity type taken from the recognizer result (`entity_type`).
    type: str
    # Index of the row the entity is attributed to.
    row_idx: int
    # Name of the scanned column the entity is attributed to.
    column_name: str
@overload
def batched(it: Iterable[T], n: int) -> Iterable[list[T]]: ...


@overload
def batched(it: Iterable[T], n: int, with_indices: Literal[False]) -> Iterable[list[T]]: ...


@overload
def batched(it: Iterable[T], n: int, with_indices: Literal[True]) -> Iterable[tuple[list[int], list[T]]]: ...


def batched(
    it: Iterable[T], n: int, with_indices: bool = False
) -> Union[Iterable[list[T]], Iterable[tuple[list[int], list[T]]]]:
    """Yield consecutive chunks of at most ``n`` items taken from ``it``.

    Args:
        it: The iterable to split; it is consumed lazily, one chunk at a time.
        n: Maximum number of items per chunk.
        with_indices: When true, each chunk is yielded as a pair
            ``(positions, items)`` where ``positions`` are the 0-based
            positions of the items within ``it``.

    Yields:
        ``list`` chunks, or ``(list[int], list)`` pairs when ``with_indices``
        is true. Iteration stops at the first empty chunk.
    """
    source = iter(it)
    next_index = 0
    while True:
        chunk = list(islice(source, n))
        if not chunk:
            return
        if not with_indices:
            yield chunk
            continue
        stop = next_index + len(chunk)
        yield list(range(next_index, stop)), chunk
        next_index = stop
def mask(text: str) -> str:
    """Hide most of each space-separated word of ``text`` behind ``*``.

    For every word, up to the first two characters are kept (but never the
    whole word: at most ``len(word) - 1`` characters stay visible) and every
    ASCII letter or digit in the remainder is replaced by ``*``. Other
    characters are left as they are.
    """
    masked_words = []
    for word in text.split(" "):
        visible = min(2, len(word) - 1)
        head, tail = word[:visible], word[visible:]
        masked_words.append(head + re.sub("[A-Za-z0-9]", "*", tail))
    return " ".join(masked_words)
def get_strings(row_content: Any) -> str:
    """Flatten the strings found in ``row_content`` into a single string.

    A string is returned unchanged. For a dict (its values) or a list (its
    items), the function recurses into each child and joins the non-empty
    results with newlines. Any other kind of value contributes ``""``.
    """
    if isinstance(row_content, str):
        return row_content
    if isinstance(row_content, dict):
        children = list(row_content.values())
    elif isinstance(row_content, list):
        children = row_content
    else:
        return ""
    return "\n".join(filter(None, map(get_strings, children)))
def _simple_analyze_iterator_cache(
    batch_analyzer: BatchAnalyzerEngine,
    texts: Iterable[str],
    language: str,
    score_threshold: float,
    cache: dict[str, list[RecognizerResult]],
) -> list[list[RecognizerResult]]:
    """Analyze ``texts``, reusing results stored in ``cache`` by the previous call.

    Args:
        batch_analyzer: Engine whose ``analyze_iterator`` is called on the
            texts that are not in ``cache``.
        texts: Texts to analyze; any iterable, including one-shot iterators.
        language: Forwarded to ``analyze_iterator``.
        score_threshold: Forwarded to ``analyze_iterator``.
        cache: Mapping from text to its results. It is mutated in place: after
            the call it holds exactly the texts and results of this call.

    Returns:
        One list of recognizer results per input text, in input order.
    """
    # `texts` is traversed three times below (filter, lookup, cache refresh),
    # so materialize it once; a one-shot iterator would otherwise be exhausted
    # after the first pass and yield wrong/empty results.
    texts = list(texts)
    uncached_texts = [text for text in texts if text not in cache]
    not_cached_results = iter(
        batch_analyzer.analyze_iterator(uncached_texts, language=language, score_threshold=score_threshold)
    )
    # Cached texts reuse their stored results; the others consume the fresh
    # results in the same order they were submitted.
    results = [cache[text] if text in cache else next(not_cached_results) for text in texts]
    # Replace the cache contents with this call's results only.
    cache.clear()
    cache.update(zip(texts, results))
    return results
def analyze(
    batch_analyzer: BatchAnalyzerEngine,
    batch: list[dict[str, str]],
    indices: Iterable[int],
    scanned_columns: list[str],
    columns_descriptions: list[str],
    cache: Optional[dict[str, list[RecognizerResult]]] = None,
) -> list[PresidioEntity]:
    """Detect entities in every scanned column of every row of ``batch``.

    Each cell is analyzed with a short sentence describing its column placed
    in front of it. Recognizer results that start inside that leading
    sentence are discarded, and the matched text of the remaining ones is
    masked with ``mask`` before being returned.

    Args:
        batch_analyzer: Engine used to analyze the texts.
        batch: Rows, each mapping a scanned column name to its string content.
        indices: Row index of each row in ``batch``, in the same order.
        scanned_columns: Names of the columns to analyze.
        columns_descriptions: Description of each scanned column, paired
            positionally with ``scanned_columns``.
        cache: Optional text -> results cache shared across calls; a fresh
            one is used when omitted.

    Returns:
        One ``PresidioEntity`` per retained recognizer result, ordered by row,
        then column, then result.
    """
    cache = {} if cache is None else cache
    # Pair each column with the exact prefix prepended to its cells, so the
    # same string is used both to build the text and to filter results.
    columns = [
        (column_name, f"The following is {columns_description} data:\n\n")
        for column_name, columns_description in zip(scanned_columns, columns_descriptions)
    ]
    # Flattened in row-major order: all columns of row 0, then row 1, ...
    texts = [f"{prefix}{example[column_name] or ''}" for example in batch for column_name, prefix in columns]
    analyzed = iter(
        zip(
            texts,
            _simple_analyze_iterator_cache(batch_analyzer, texts, language="en", score_threshold=0.8, cache=cache),
        )
    )
    entities: list[PresidioEntity] = []
    # Walk rows x columns in the same order `texts` was built, so every result
    # list is attributed to the row and column its text came from, and every
    # recognizer result of a text is reported (not just the first one).
    for row_idx, _example in zip(indices, batch):
        for column_name, prefix in columns:
            text, recognizer_results = next(analyzed)
            for recognizer_result in recognizer_results:
                if recognizer_result.start < len(prefix):
                    # The match starts inside the descriptive prefix, not the data.
                    continue
                entities.append(
                    PresidioEntity(
                        text=mask(text[recognizer_result.start : recognizer_result.end]),
                        type=recognizer_result.entity_type,
                        row_idx=row_idx,
                        column_name=column_name,
                    )
                )
    return entities
def presidio_scan_entities(
    rows: Iterable[Row], scanned_columns: list[str], columns_descriptions: list[str]
) -> Iterable[PresidioEntity]:
    """Lazily scan ``rows`` and yield the entities found in the scanned columns.

    This is a generator: nothing is analyzed until it is iterated. Rows are
    reduced to the scanned columns (flattened to strings with
    ``get_strings``), grouped into batches of ``BATCH_SIZE`` and passed to
    ``analyze``.

    Args:
        rows: Rows to scan; each must contain every name in ``scanned_columns``.
        scanned_columns: Names of the columns to analyze.
        columns_descriptions: Description of each scanned column, paired
            positionally with ``scanned_columns``.

    Yields:
        The ``PresidioEntity`` items returned by ``analyze`` for each batch.
    """
    global batch_analyzer
    # Build the engine once and store it in the module-level variable so later
    # calls reuse it. (The previous code assigned to a misspelled local name,
    # so the global was never set.)
    if batch_analyzer is None:
        batch_analyzer = BatchAnalyzerEngine(AnalyzerEngine())
    # Shared across the batches of this scan and passed to `analyze`.
    cache: dict[str, list[RecognizerResult]] = {}
    rows_with_scanned_columns_only = (
        {column_name: get_strings(row[column_name]) for column_name in scanned_columns} for row in rows
    )
    for indices, batch in batched(rows_with_scanned_columns_only, BATCH_SIZE, with_indices=True):
        yield from analyze(
            batch_analyzer=batch_analyzer,
            batch=batch,
            indices=indices,
            scanned_columns=scanned_columns,
            columns_descriptions=columns_descriptions,
            cache=cache,
        )
def get_columns_with_strings(features: Features) -> list[str]:
    """Return the names of the columns whose feature contains a string value.

    A column qualifies when the callback passed to ``_visit`` sees at least
    one ``Value`` whose ``dtype`` is ``"string"``. Column names are converted
    with ``str`` and returned in the iteration order of ``features``.
    """

    def contains_string(feature: FeatureType) -> bool:
        found = False

        def classify(sub_feature: FeatureType) -> None:
            nonlocal found
            if isinstance(sub_feature, Value) and sub_feature.dtype == "string":
                found = True

        _visit(feature, classify)
        return found

    return [str(column) for column, feature in features.items() if contains_string(feature)]
def get_column_description(column_name: str, feature: FeatureType) -> str:
    """Describe a column by its name and, if any, its nested field names.

    The keys of every dict handed to the ``_visit`` callback are collected.
    When at least one was found the result is
    ``"<column_name> (with <key>, <key>, ...)"``; otherwise it is just
    ``column_name``.
    """
    field_names: list[str] = []

    def collect_field_names(sub_feature: FeatureType) -> None:
        if isinstance(sub_feature, dict):
            field_names.extend(sub_feature)

    _visit(feature, collect_field_names)
    if not field_names:
        return column_name
    return f"{column_name} (with {', '.join(field_names)})"