import re
from itertools import count, islice
from typing import Any, Iterable, Literal, Optional, TypedDict, TypeVar, Union, overload
from datasets import Features, Value, get_dataset_config_info
from datasets.features.features import FeatureType, _visit
from presidio_analyzer import AnalyzerEngine, BatchAnalyzerEngine, RecognizerResult
Row = dict[str, Any]
T = TypeVar("T")
BATCH_SIZE = 1  # analyze rows one at a time; the analysis cache then skips consecutive duplicate texts
MAX_TEXT_LENGTH = 3000  # truncate each cell's text so very long documents don't slow down the scan
batch_analyzer: Optional[BatchAnalyzerEngine] = None
class PresidioEntity(TypedDict):
text: str
type: str
row_idx: int
column_name: str
@overload
def batched(it: Iterable[T], n: int) -> Iterable[list[T]]:
...
@overload
def batched(it: Iterable[T], n: int, with_indices: Literal[False]) -> Iterable[list[T]]:
...
@overload
def batched(it: Iterable[T], n: int, with_indices: Literal[True]) -> Iterable[tuple[list[int], list[T]]]:
...
def batched(
it: Iterable[T], n: int, with_indices: bool = False
) -> Union[Iterable[list[T]], Iterable[tuple[list[int], list[T]]]]:
it, indices = iter(it), count()
while batch := list(islice(it, n)):
yield (list(islice(indices, len(batch))), batch) if with_indices else batch
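    # e.g. list(batched("abcde", 2, with_indices=True)) == [([0, 1], ["a", "b"]), ([2, 3], ["c", "d"]), ([4], ["e"])]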
def mask(text: str) -> str:
    return text  # no masking for the demo; enable the commented-out code below to partially mask detected strings
# return " ".join(
# word[: min(2, len(word) - 1)] + re.sub("[A-Za-z0-9]", "*", word[min(2, len(word) - 1) :])
# for word in text.split(" ")
# )
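    # e.g. with the masking above enabled, mask("John Smith") == "Jo** Sm***"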
def get_strings(row_content: Any) -> str:
if isinstance(row_content, str):
return row_content
if isinstance(row_content, dict):
if "src" in row_content:
return "" # could be image or audio
row_content = list(row_content.values())
if isinstance(row_content, list):
str_items = (get_strings(row_content_item) for row_content_item in row_content)
return "\n".join(str_item for str_item in str_items if str_item)
return ""
def _simple_analyze_iterator_cache(
batch_analyzer: BatchAnalyzerEngine,
texts: Iterable[str],
language: str,
score_threshold: float,
cache: dict[str, list[RecognizerResult]],
) -> list[list[RecognizerResult]]:
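    """Analyze `texts`, reusing cached results for texts already seen in the previous call.

    Only the latest results are kept, so the cache helps when consecutive batches
    (e.g. with BATCH_SIZE = 1, consecutive rows) contain identical strings.
    """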
not_cached_results = iter(
batch_analyzer.analyze_iterator(
(text for text in texts if text not in cache), language=language, score_threshold=score_threshold
)
)
results = [cache[text] if text in cache else next(not_cached_results) for text in texts]
    # keep only this batch's results so the next batch can reuse them for identical texts
    cache.clear()
    cache.update(dict(zip(texts, results)))
return results
def analyze(
batch_analyzer: BatchAnalyzerEngine,
batch: list[dict[str, str]],
indices: Iterable[int],
scanned_columns: list[str],
columns_descriptions: list[str],
cache: Optional[dict[str, list[RecognizerResult]]] = None,
) -> list[PresidioEntity]:
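    """Detect PII entities in the scanned columns of a batch of rows.

    Each cell is analyzed as a standalone text prefixed with its column description
    to give Presidio some context; matches inside the prefix itself are discarded.
    """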
cache = {} if cache is None else cache
texts = [
f"The following is {columns_description} data:\n\n{example[column_name] or ''}"
for example in batch
for column_name, columns_description in zip(scanned_columns, columns_descriptions)
]
    analyzer_results = _simple_analyze_iterator_cache(
        batch_analyzer, texts, language="en", score_threshold=0.8, cache=cache
    )
    # texts holds one entry per (row, column) pair, in the same order as the
    # (row_idx, column_name) pairs below, so analyzer_results[i] matches them
    return [
        PresidioEntity(
            text=mask(texts[i][recognizer_result.start : recognizer_result.end]),
            type=recognizer_result.entity_type,
            row_idx=row_idx,
            column_name=column_name,
        )
        for i, (row_idx, column_name, columns_description) in enumerate(
            (row_idx, column_name, columns_description)
            for row_idx in indices
            for column_name, columns_description in zip(scanned_columns, columns_descriptions)
        )
        for recognizer_result in analyzer_results[i]
        if recognizer_result.start >= len(f"The following is {columns_description} data:\n\n")
    ]
def presidio_scan_entities(
rows: Iterable[Row], scanned_columns: list[str], columns_descriptions: list[str]
) -> Iterable[PresidioEntity]:
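    """Lazily scan rows with Presidio and yield a PresidioEntity for each match found in the scanned columns."""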
global batch_analyzer
cache: dict[str, list[RecognizerResult]] = {}
if batch_analyzer is None:
        batch_analyzer = BatchAnalyzerEngine(AnalyzerEngine())  # initialize the engine once and reuse it across calls
rows_with_scanned_columns_only = (
        {column_name: get_strings(row[column_name])[:MAX_TEXT_LENGTH] for column_name in scanned_columns}
        for row in rows
)
for indices, batch in batched(rows_with_scanned_columns_only, BATCH_SIZE, with_indices=True):
yield from analyze(
            batch_analyzer=batch_analyzer,
batch=batch,
indices=indices,
scanned_columns=scanned_columns,
columns_descriptions=columns_descriptions,
cache=cache,
)
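# Minimal usage sketch (hypothetical choice of dataset and split; assumes the
# dataset can be streamed and has at least one string column):
#
#     from datasets import load_dataset
#
#     info = get_dataset_config_info("imdb", "plain_text")
#     scanned_columns = get_columns_with_strings(info.features)
#     columns_descriptions = [get_column_description(name, info.features[name]) for name in scanned_columns]
#     rows = load_dataset("imdb", split="train", streaming=True)
#     for entity in islice(presidio_scan_entities(rows, scanned_columns, columns_descriptions), 10):
#         print(entity)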
def get_columns_with_strings(features: Features) -> list[str]:
columns_with_strings: list[str] = []
for column, feature in features.items():
str_column = str(column)
with_string = False
def classify(feature: FeatureType) -> None:
nonlocal with_string
if isinstance(feature, Value) and feature.dtype == "string":
with_string = True
_visit(feature, classify)
if with_string:
columns_with_strings.append(str_column)
return columns_with_strings
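    # e.g. get_columns_with_strings(Features({"id": Value("int64"), "text": Value("string")})) == ["text"]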
def get_column_description(column_name: str, feature: FeatureType) -> str:
nested_fields: list[str] = []
def get_nested_field_names(feature: FeatureType) -> None:
nonlocal nested_fields
if isinstance(feature, dict):
nested_fields += list(feature)
_visit(feature, get_nested_field_names)
return f"{column_name} (with {', '.join(nested_fields)})" if nested_fields else column_name