|
"""Training data prep utils.""" |
|
import json |
|
import re |
|
from collections import defaultdict |
|
from schema import ForeignKey, Table, TableColumn |
|
|
|
|
|
def read_tables_json( |
|
schema_file: str, |
|
lowercase: bool = False, |
|
) -> dict[str, dict[str, Table]]: |
|
"""Read tables json.""" |
|
data = json.load(open(schema_file)) |
|
db_to_tables = {} |
|
for db in data: |
|
db_name = db["db_id"] |
|
table_names = db["table_names_original"] |
|
db["column_names_original"] = [ |
|
[x[0], x[1]] for x in db["column_names_original"] |
|
] |
|
db["column_types"] = db["column_types"] |
|
if lowercase: |
|
table_names = [tn.lower() for tn in table_names] |
|
pks = db["primary_keys"] |
|
fks = db["foreign_keys"] |
|
tables = defaultdict(list) |
|
tables_pks = defaultdict(list) |
|
tables_fks = defaultdict(list) |
|
for idx, ((ti, col_name), col_type) in enumerate( |
|
zip(db["column_names_original"], db["column_types"]) |
|
): |
|
if ti == -1: |
|
continue |
|
if lowercase: |
|
col_name = col_name.lower() |
|
col_type = col_type.lower() |
|
if idx in pks: |
|
tables_pks[table_names[ti]].append( |
|
TableColumn(name=col_name, dtype=col_type) |
|
) |
|
for fk in fks: |
|
if idx == fk[0]: |
|
other_column = db["column_names_original"][fk[1]] |
|
other_column_type = db["column_types"][fk[1]] |
|
other_table = table_names[other_column[0]] |
|
tables_fks[table_names[ti]].append( |
|
ForeignKey( |
|
column=TableColumn(name=col_name, dtype=col_type), |
|
references_name=other_table, |
|
references_column=TableColumn( |
|
name=other_column[1], dtype=other_column_type |
|
), |
|
) |
|
) |
|
tables[table_names[ti]].append(TableColumn(name=col_name, dtype=col_type)) |
|
db_to_tables[db_name] = { |
|
table_name: Table( |
|
name=table_name, |
|
columns=tables[table_name], |
|
pks=tables_pks[table_name], |
|
fks=tables_fks[table_name], |
|
examples=None, |
|
) |
|
for table_name in tables |
|
} |
|
return db_to_tables |
|
|
|
|
|
def clean_str(target: str) -> str: |
|
"""Clean string for question.""" |
|
if not target: |
|
return target |
|
|
|
target = re.sub(r"[^\x00-\x7f]", r" ", target) |
|
line = re.sub(r"''", r" ", target) |
|
line = re.sub(r"``", r" ", line) |
|
line = re.sub(r"\"", r"'", line) |
|
line = re.sub(r"[\t ]+", " ", line) |
|
return line.strip() |
|
|