lhoestq's picture
lhoestq HF staff
Update analyze.py
fe1b9ba verified
raw history blame
No virus
5.67 kB
import re
from itertools import count, islice
from typing import Any, Iterable, Literal, Optional, TypedDict, TypeVar, Union, overload
from datasets import Features, Value, get_dataset_config_info
from datasets.features.features import FeatureType, _visit
from presidio_analyzer import AnalyzerEngine, BatchAnalyzerEngine, RecognizerResult
# A raw dataset row, indexed by column name.
Row = dict[str, Any]
# Generic item type used by `batched`.
T = TypeVar("T")
# Number of rows passed to each `analyze` call by `presidio_scan_entities`.
BATCH_SIZE = 1
# Each scanned cell is truncated to this many characters before analysis.
MAX_TEXT_LENGTH = 500
# Presidio engines are built once, at import time, and shared by every scan.
# NOTE(review): constructing AnalyzerEngine presumably loads an NLP model, making
# module import slow — confirm before importing this module in latency-sensitive code.
analyzer = AnalyzerEngine()
batch_analyzer = BatchAnalyzerEngine(analyzer)
# One detected PII entity:
#   text        - the matched substring of the cell (passed through `mask`)
#   type        - Presidio entity type of the match
#   row_idx     - index of the row the entity was found in
#   column_name - name of the scanned column the entity was found in
PresidioEntity = TypedDict(
    "PresidioEntity",
    {
        "text": str,
        "type": str,
        "row_idx": int,
        "column_name": str,
    },
)
@overload
def batched(it: Iterable[T], n: int) -> Iterable[list[T]]:
    ...


@overload
def batched(it: Iterable[T], n: int, with_indices: Literal[False]) -> Iterable[list[T]]:
    ...


@overload
def batched(it: Iterable[T], n: int, with_indices: Literal[True]) -> Iterable[tuple[list[int], list[T]]]:
    ...


def batched(
    it: Iterable[T], n: int, with_indices: bool = False
) -> Union[Iterable[list[T]], Iterable[tuple[list[int], list[T]]]]:
    """Lazily split *it* into consecutive lists of at most *n* items.

    The final batch may be shorter than *n*. When *with_indices* is true, each
    batch is yielded as ``(positions, batch)`` where ``positions`` holds the
    0-based position of every item of the batch within *it*.
    """
    source = iter(it)
    start = 0
    while True:
        chunk = list(islice(source, n))
        if not chunk:
            return
        if not with_indices:
            yield chunk
            continue
        stop = start + len(chunk)
        yield list(range(start, stop)), chunk
        start = stop
def mask(text: str, *, enabled: bool = False) -> str:
    """Return *text*, optionally obfuscated for display.

    Masking is disabled by default (demo mode): *text* is returned unchanged.
    With ``enabled=True``, each space-separated word keeps its first two
    characters (one for two-character words, none for one-character words),
    and every remaining ASCII letter or digit is replaced by ``*``. Other
    characters, such as punctuation, are kept as-is.
    """
    if not enabled:
        return text  # don't apply mask for demo
    masked_words = []
    for word in text.split(" "):
        # Keep up to 2 leading characters but leave at least the last one to mask.
        keep = min(2, len(word) - 1)
        masked_words.append(word[:keep] + re.sub(r"[A-Za-z0-9]", "*", word[keep:]))
    return " ".join(masked_words)
def get_strings(row_content: Any) -> str:
    """Flatten a (possibly nested) cell value into newline-separated text.

    Strings are returned as-is. Dicts contribute their values and lists their
    items, recursively; empty pieces are dropped. A dict with a ``"src"`` key
    contributes nothing (presumably an image or audio cell). Any other value
    yields ``""``.
    """
    if isinstance(row_content, str):
        return row_content
    if isinstance(row_content, dict):
        if "src" in row_content:
            return ""  # could be image or audio
        row_content = list(row_content.values())
    if not isinstance(row_content, list):
        return ""
    pieces = [get_strings(item) for item in row_content]
    return "\n".join(filter(None, pieces))
def _simple_analyze_iterator_cache(
    batch_analyzer: BatchAnalyzerEngine,
    texts: Iterable[str],
    language: str,
    score_threshold: float,
    cache: dict[str, list[RecognizerResult]],
) -> list[list[RecognizerResult]]:
    """Analyze *texts* with *batch_analyzer*, reusing results stored in *cache*.

    Only texts missing from *cache* are sent to the analyzer. Each distinct
    missing text is analyzed once, even if it occurs several times in *texts*.
    Afterwards *cache* is replaced in place by this call's results, so it only
    ever holds the most recent batch.

    Args:
        batch_analyzer: Presidio batch engine used for uncached texts.
        texts: Texts to analyze; may be any iterable, including a one-shot iterator.
        language: Language code forwarded to the analyzer.
        score_threshold: Minimum score forwarded to the analyzer.
        cache: Mapping of text to previous results; cleared and refilled in place.

    Returns:
        One list of recognizer results per input text, in input order.
    """
    # Materialize first: *texts* is read several times below, which would
    # silently give wrong results if a generator were passed.
    texts = list(texts)
    # dict.fromkeys dedupes while keeping first-occurrence order.
    to_analyze = [text for text in dict.fromkeys(texts) if text not in cache]
    fresh_results: dict[str, list[RecognizerResult]] = {}
    if to_analyze:
        fresh_results = dict(
            zip(
                to_analyze,
                batch_analyzer.analyze_iterator(to_analyze, language=language, score_threshold=score_threshold),
            )
        )
    results = [cache[text] if text in cache else fresh_results[text] for text in texts]
    # cache the last results only, so the cache stays bounded by the batch size
    cache.clear()
    cache.update(zip(texts, results))
    return results
def analyze(
    batch_analyzer: BatchAnalyzerEngine,
    batch: list[dict[str, str]],
    indices: Iterable[int],
    scanned_columns: list[str],
    columns_descriptions: list[str],
    cache: Optional[dict[str, list[RecognizerResult]]] = None,
    language: str = "en",
    score_threshold: float = 0.8,
) -> list[PresidioEntity]:
    """Detect PII entities in a batch of rows.

    Each cell is prefixed with a sentence naming its column
    ("The following is <description> data:") to give the analyzer context.
    Entities starting inside that prefix are discarded, so only entities
    found in the cell text itself are reported.

    Args:
        batch_analyzer: Presidio batch engine used for the analysis.
        batch: Rows to scan, mapping column name to the cell's text.
        indices: Row index to report for each row of *batch*, in the same order.
        scanned_columns: Columns to scan in every row.
        columns_descriptions: Description of each scanned column, aligned
            with *scanned_columns*.
        cache: Optional cache of previous analyzer results keyed by analyzed
            text, updated in place; a fresh one is used when omitted.
        language: Language code passed to the analyzer.
        score_threshold: Minimum score for an entity to be reported.

    Returns:
        The detected entities, ordered by row, then column, then analyzer output.

    Raises:
        ValueError: If *scanned_columns* and *columns_descriptions* differ in length.
    """
    if len(scanned_columns) != len(columns_descriptions):
        # A mismatch would silently misalign texts and columns (zip truncates,
        # while results are re-batched by len(scanned_columns)).
        raise ValueError(
            f"Got {len(scanned_columns)} scanned columns but "
            f"{len(columns_descriptions)} column descriptions"
        )
    cache = {} if cache is None else cache
    num_columns = len(scanned_columns)
    prefixes = [f"The following is {columns_description} data:\n\n" for columns_description in columns_descriptions]
    texts = [
        f"{prefix}{example[column_name] or ''}"
        for example in batch
        for column_name, prefix in zip(scanned_columns, prefixes)
    ]
    results = _simple_analyze_iterator_cache(
        batch_analyzer, texts, language=language, score_threshold=score_threshold, cache=cache
    )
    entities: list[PresidioEntity] = []
    # Results are flat (row-major); regroup them into one chunk per row.
    for i, (row_idx, row_results) in enumerate(zip(indices, batched(results, num_columns))):
        for j, (column_name, prefix, recognizer_results) in enumerate(
            zip(scanned_columns, prefixes, row_results)
        ):
            text = texts[i * num_columns + j]
            for recognizer_result in recognizer_results:
                # Skip entities detected in the synthetic context prefix.
                if recognizer_result.start < len(prefix):
                    continue
                entities.append(
                    PresidioEntity(
                        text=mask(text[recognizer_result.start : recognizer_result.end]),
                        type=recognizer_result.entity_type,
                        row_idx=row_idx,
                        column_name=column_name,
                    )
                )
    return entities
def presidio_scan_entities(
    rows: Iterable[Row], scanned_columns: list[str], columns_descriptions: list[str]
) -> Iterable[PresidioEntity]:
    """Lazily yield the PII entities Presidio detects in *rows*.

    Every scanned cell is flattened to text with ``get_strings`` and truncated
    to ``MAX_TEXT_LENGTH`` characters. Rows are then analyzed ``BATCH_SIZE``
    at a time. Reported ``row_idx`` values are 0-based positions in *rows*.
    """
    # Shared across batches; `analyze` refreshes it with the latest results so
    # values repeated from the previous batch are not re-analyzed.
    last_results_cache: dict[str, list[RecognizerResult]] = {}

    def _truncated_strings(row: Row) -> dict[str, str]:
        return {name: get_strings(row[name])[:MAX_TEXT_LENGTH] for name in scanned_columns}

    truncated_rows = map(_truncated_strings, rows)
    for batch_indices, rows_batch in batched(truncated_rows, BATCH_SIZE, with_indices=True):
        yield from analyze(
            batch_analyzer=batch_analyzer,
            batch=rows_batch,
            indices=batch_indices,
            scanned_columns=scanned_columns,
            columns_descriptions=columns_descriptions,
            cache=last_results_cache,
        )
def get_columns_with_strings(features: Features) -> list[str]:
    """Return the names of the columns whose feature contains string values.

    A column qualifies as soon as any ``Value`` reached by visiting its
    (possibly nested) feature has a ``"string"`` or ``"large_string"`` dtype.
    Column names are returned as ``str``, in *features* order.
    """
    # Both dtypes hold text; checking only "string" skipped large_string columns.
    string_dtypes = ("string", "large_string")
    columns_with_strings: list[str] = []
    for column, feature in features.items():
        with_string = False

        def classify(sub_feature: FeatureType) -> None:
            nonlocal with_string
            if isinstance(sub_feature, Value) and sub_feature.dtype in string_dtypes:
                with_string = True

        # NOTE: `_visit` is a private `datasets` helper; its traversal may change across versions.
        _visit(feature, classify)
        if with_string:
            columns_with_strings.append(str(column))
    return columns_with_strings
def get_column_description(column_name: str, feature: FeatureType) -> str:
    """Describe a column by its name plus the field names of any nested dicts.

    Returns ``"<column_name> (with <field>, <field>, ...)"`` when visiting
    *feature* reaches dict-like features, otherwise just *column_name*.
    """
    field_names: list[str] = []

    def _collect(sub_feature: FeatureType) -> None:
        if isinstance(sub_feature, dict):
            field_names.extend(sub_feature)

    _visit(feature, _collect)
    if not field_names:
        return column_name
    return f"{column_name} (with {', '.join(field_names)})"