lhoestq HF staff commited on
Commit
0d3784b
·
1 Parent(s): 09a5205

fix analyze bug

Browse files
Files changed (1) hide show
  1. analyze.py +6 -8
analyze.py CHANGED
@@ -72,15 +72,12 @@ def _simple_analyze_iterator_cache(
72
  score_threshold: float,
73
  cache: dict[str, list[RecognizerResult]],
74
  ) -> list[list[RecognizerResult]]:
75
- print(cache)
76
- print(texts)
77
  not_cached_results = iter(
78
  batch_analyzer.analyze_iterator(
79
  (text for text in texts if text not in cache), language=language, score_threshold=score_threshold
80
  )
81
  )
82
  results = [cache[text] if text in cache else next(not_cached_results) for text in texts]
83
- print(results)
84
  # cache the last results
85
  cache.clear()
86
  cache.update(dict(zip(texts, results)))
@@ -103,19 +100,20 @@ def analyze(
103
  ]
104
  return [
105
  PresidioEntity(
106
- text=mask(texts[i][recognizer_result.start : recognizer_result.end]),
107
  type=recognizer_result.entity_type,
108
  row_idx=row_idx,
109
  column_name=column_name,
110
  )
111
- for i, row_idx, recognizer_results in zip(
112
  count(),
113
  indices,
114
- _simple_analyze_iterator_cache(batch_analyzer, texts, language="en", score_threshold=0.8, cache=cache),
115
  )
116
- for column_name, columns_description, recognizer_result in zip(
117
- scanned_columns, columns_descriptions, recognizer_results
118
  )
 
119
  if recognizer_result.start >= len(f"The following is {columns_description} data:\n\n")
120
  ]
121
 
 
72
  score_threshold: float,
73
  cache: dict[str, list[RecognizerResult]],
74
  ) -> list[list[RecognizerResult]]:
 
 
75
  not_cached_results = iter(
76
  batch_analyzer.analyze_iterator(
77
  (text for text in texts if text not in cache), language=language, score_threshold=score_threshold
78
  )
79
  )
80
  results = [cache[text] if text in cache else next(not_cached_results) for text in texts]
 
81
  # cache the last results
82
  cache.clear()
83
  cache.update(dict(zip(texts, results)))
 
100
  ]
101
  return [
102
  PresidioEntity(
103
+ text=mask(texts[i * len(scanned_columns) + j][recognizer_result.start : recognizer_result.end]),
104
  type=recognizer_result.entity_type,
105
  row_idx=row_idx,
106
  column_name=column_name,
107
  )
108
+ for i, row_idx, recognizer_row_results in zip(
109
  count(),
110
  indices,
111
+ batched(_simple_analyze_iterator_cache(batch_analyzer, texts, language="en", score_threshold=0.8, cache=cache), len(scanned_columns)),
112
  )
113
+ for j, column_name, columns_description, recognizer_results in zip(
114
+ count(), scanned_columns, columns_descriptions, recognizer_row_results
115
  )
116
+ for recognizer_result in recognizer_results
117
  if recognizer_result.start >= len(f"The following is {columns_description} data:\n\n")
118
  ]
119