tdoehmen committed on
Commit
0e01bbd
·
1 Parent(s): 93f5976
Files changed (1) hide show
  1. duckdb-nsql/eval/evaluate.py +6 -7
duckdb-nsql/eval/evaluate.py CHANGED
@@ -114,9 +114,11 @@ def compute_exact_match_metric(
114
  return exact_match
115
 
116
 
117
- def evaluate_with_timeout(evaluator, *args, timeout):
 
118
  with ThreadPoolExecutor(max_workers=1) as executor:
119
- future = executor.submit(evaluator.evaluate_one, *args)
 
120
  try:
121
  result = future.result(timeout=timeout)
122
  except TimeoutError:
@@ -150,15 +152,12 @@ def compute_test_suite_metric(
150
  zip(predictions, references, gold_dbs, setup_sqls, validate_sqls, categories),
151
  total=len(predictions),
152
  ):
153
- turn_idx = 0
154
- # skip final utterance-query pairs
155
- if turn_idx < 0:
156
- continue
157
 
158
  # Use the new function to evaluate with timeout
159
  ex_metrics = evaluate_with_timeout(
160
  evaluator, gold_db, reference, prediction, setup_sql, validate_sql,
161
- turn_scores, timeout=TIMEOUT_SECONDS
162
  )
163
 
164
  if ex_metrics:
 
114
  return exact_match
115
 
116
 
117
+ def evaluate_with_timeout(evaluator, gold_db, reference, prediction,
118
+ setup_sql, validate_sql, turn_scores, idx, category, timeout):
119
  with ThreadPoolExecutor(max_workers=1) as executor:
120
+ future = executor.submit(evaluator.evaluate_one, gold_db, reference, prediction,
121
+ setup_sql, validate_sql, turn_scores, idx=idx, category=category)
122
  try:
123
  result = future.result(timeout=timeout)
124
  except TimeoutError:
 
152
  zip(predictions, references, gold_dbs, setup_sqls, validate_sqls, categories),
153
  total=len(predictions),
154
  ):
155
+ turn_idx = 0 # or any value that represents the current index if this is incorrect
 
 
 
156
 
157
  # Use the new function to evaluate with timeout
158
  ex_metrics = evaluate_with_timeout(
159
  evaluator, gold_db, reference, prediction, setup_sql, validate_sql,
160
+ turn_scores, idx=turn_idx, category=category, timeout=TIMEOUT_SECONDS
161
  )
162
 
163
  if ex_metrics: