|
"""Utility metrics.""" |
|
import sqlglot |
|
from rich.console import Console |
|
from sqlglot import parse_one |
|
|
|
# Module-level rich console; soft_wrap=True makes soft wrapping the default
# for this console's print calls.
console = Console(soft_wrap=True)
|
|
|
|
|
def correct_casing(sql: str) -> str:
    """Correct casing of SQL.

    The statement is parsed with sqlglot using the SQLite dialect and then
    rendered back to text with the generator's default settings, so the
    result is the SQL as sqlglot regenerates it rather than the raw input.

    Args:
        sql: SQL statement to normalise.

    Returns:
        The SQL text regenerated from the parsed expression.

    Note:
        No exceptions are handled here; anything raised by ``parse_one``
        propagates to the caller.
    """
    parsed: sqlglot.expressions.Expression = parse_one(sql, read="sqlite")
    return parsed.sql()
|
|
|
|
|
def prec_recall_f1(gold: set, pred: set) -> dict[str, float]:
    """Compute precision, recall and F1 score.

    Args:
        gold: Set of reference (expected) items.
        pred: Set of predicted items.

    Returns:
        A dict with the keys ``"prec"``, ``"recall"`` and ``"f1"``.
        Precision is ``|gold & pred| / |pred|`` and recall is
        ``|gold & pred| / |gold|``. Any ratio whose denominator would be
        zero (empty ``pred``, empty ``gold``, or precision + recall == 0)
        is reported as ``0.0`` instead of raising ``ZeroDivisionError``.
    """
    # Compute the overlap once; both precision and recall need it.
    n_common = len(gold.intersection(pred))

    prec = n_common / len(pred) if pred else 0.0
    recall = n_common / len(gold) if gold else 0.0
    # F1 is the harmonic mean of precision and recall.
    f1 = 2 * prec * recall / (prec + recall) if prec + recall else 0.0
    return {"prec": prec, "recall": recall, "f1": f1}
|
|
|
|
|
def edit_distance(s1: str, s2: str) -> int:
    """Compute edit distance between two strings.

    This is the Levenshtein distance: the minimum number of
    single-character insertions, deletions and substitutions needed to
    turn one string into the other. Only one row of the dynamic-programming
    table is kept at a time.

    Args:
        s1: First string.
        s2: Second string.

    Returns:
        The edit distance between ``s1`` and ``s2``. The result is
        symmetric in its arguments.
    """
    # Make s1 the shorter string so the rolling row has minimal length.
    if len(s1) > len(s2):
        s1, s2 = s2, s1

    # prev_row[j] = distance between s1[:j] and the prefix of s2 consumed so far.
    prev_row: list[int] = list(range(len(s1) + 1))
    for i2, c2 in enumerate(s2):
        # Transforming the empty prefix of s1 into s2[:i2 + 1] costs i2 + 1 inserts.
        curr_row = [i2 + 1]
        for i1, c1 in enumerate(s1):
            if c1 == c2:
                # Matching characters: carry the diagonal value at no cost.
                curr_row.append(prev_row[i1])
            else:
                # One edit on top of the cheapest of substitution (diagonal),
                # and the two insert/delete neighbours.
                curr_row.append(
                    1 + min(prev_row[i1], prev_row[i1 + 1], curr_row[-1])
                )
        prev_row = curr_row
    return prev_row[-1]
|
|