surveyor-0 / scripts /create_db.py
Abhipsha Das
initial spaces deploy
8fbb714 unverified
import click
import json
import os
import sqlite3
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from config import DEFAULT_TABLES_DIR, DEFAULT_MODEL_ID, DEFAULT_INTERFACE_MODEL_ID
from src.processing.generate import get_sentences, generate_prediction
from src.utils.utils import load_model_and_tokenizer
class ArxivDatabase:
def __init__(self, db_path, model_id=None):
self.conn = None
self.cursor = None
self.db_path = db_path
self.model_id = model_id if model_id else DEFAULT_INTERFACE_MODEL_ID
self.model = None
self.tokenizer = None
self.is_db_empty = True
self.paper_table = """CREATE TABLE IF NOT EXISTS papers
(paper_id TEXT PRIMARY KEY, abstract TEXT, authors TEXT,
primary_category TEXT, url TEXT, updated_on TEXT, sentence_count INTEGER)"""
self.pred_table = """CREATE TABLE IF NOT EXISTS predictions
(id INTEGER PRIMARY KEY AUTOINCREMENT, paper_id TEXT, sentence_index INTEGER,
tag_type TEXT, concept TEXT,
FOREIGN KEY (paper_id) REFERENCES papers(paper_id))"""
# def init_db(self):
# self.cursor.execute(self.paper_table)
# self.cursor.execute(self.pred_table)
# print("Database and tables created successfully.")
# self.is_db_empty = self.is_empty()
def init_db(self):
self.conn = sqlite3.connect(self.db_path)
self.cursor = self.conn.cursor()
self.cursor.execute(self.paper_table)
self.cursor.execute(self.pred_table)
self.conn.commit()
self.is_db_empty = self.is_empty()
if not self.is_db_empty:
print("Database already contains data.")
else:
print("Database and tables created successfully.")
def is_empty(self):
try:
self.cursor.execute("SELECT COUNT(*) FROM papers")
count = self.cursor.fetchone()[0]
return count == 0
except sqlite3.OperationalError:
return True
def get_connection(self):
return sqlite3.connect(self.conn.path)
def populate_db(self, data_path, pred_path):
papers_info = self._insert_papers(data_path)
self._insert_predictions(pred_path, papers_info)
print("Database population completed.")
def _insert_papers(self, data_path):
papers_info = []
seen_papers = set()
with open(data_path, "r") as f:
for line in f:
paper = json.loads(line)
if paper["id"] in seen_papers:
continue
seen_papers.add(paper["id"])
sentence_count = len(get_sentences(paper["id"])) + len(
get_sentences(paper["abstract"])
)
papers_info.append((paper["id"], sentence_count))
self.cursor.execute(
"""INSERT OR REPLACE INTO papers VALUES (?, ?, ?, ?, ?, ?, ?)""",
(
paper["id"],
paper["abstract"],
json.dumps(paper["authors"]),
json.dumps(paper["primary_category"]),
json.dumps(paper["url"]),
json.dumps(paper["updated"]),
sentence_count,
),
)
print(f"Inserted {len(papers_info)} papers.")
return papers_info
def _insert_predictions(self, pred_path, papers_info):
with open(pred_path, "r") as f:
predictions = json.load(f)
predicted_tags = predictions["predicted_tags"]
k = 0
papers_with_predictions = set()
papers_without_predictions = []
for paper_id, sentence_count in papers_info:
paper_predictions = predicted_tags[k : k + sentence_count]
has_predictions = False
for sentence_index, pred in enumerate(paper_predictions):
if pred: # If the prediction is not an empty dictionary
has_predictions = True
for tag_type, concepts in pred.items():
for concept in concepts:
self.cursor.execute(
"""INSERT INTO predictions (paper_id, sentence_index, tag_type, concept)
VALUES (?, ?, ?, ?)""",
(paper_id, sentence_index, tag_type, concept),
)
else:
# Insert a null prediction to ensure the paper is counted
self.cursor.execute(
"""INSERT INTO predictions (paper_id, sentence_index, tag_type, concept)
VALUES (?, ?, ?, ?)""",
(paper_id, sentence_index, "null", "null"),
)
if has_predictions:
papers_with_predictions.add(paper_id)
else:
papers_without_predictions.append(paper_id)
k += sentence_count
print(f"Inserted predictions for {len(papers_with_predictions)} papers.")
print(f"Papers without any predictions: {len(papers_without_predictions)}")
if k < len(predicted_tags):
print(f"Warning: {len(predicted_tags) - k} predictions were not inserted.")
def load_model(self):
if self.model is None:
try:
self.model, self.tokenizer = load_model_and_tokenizer(self.model_id)
return f"Model {self.model_id} loaded successfully."
except Exception as e:
return f"Error loading model: {str(e)}"
else:
return "Model is already loaded."
def natural_language_to_sql(self, question):
system_prompt = "You are an assistant who converts natural language questions to SQL queries to query a database of scientific papers."
table = self.paper_table + "; " + self.pred_table
prefix = (
f"[INST] Write SQLite query to answer the following question given the database schema. Please wrap your code answer using "
f"```: Schema: {table} Question: {question}[/INST] Here is the SQLite query to answer to the question: {question}: ``` "
)
sql_query = generate_prediction(
self.model, self.tokenizer, prefix, question, "sql", system_prompt
)
sql_query = sql_query.split("```")[1]
return sql_query
def execute_query(self, sql_query):
try:
self.cursor.execute(sql_query)
results = self.cursor.fetchall()
return results if results else []
except sqlite3.Error as e:
return [(f"An error occurred: {e}",)]
def query_db(self, question, is_sql):
if self.is_db_empty:
return "The database is empty. Please populate it with data first."
try:
if is_sql:
sql_query = question.strip()
else:
nl_to_sql = self.natural_language_to_sql(question)
sql_query = nl_to_sql.replace("```sql", "").replace("```", "").strip()
results = self.execute_query(sql_query)
output = f"SQL Query: {sql_query}\n\nResults:\n"
if isinstance(results, list):
if len(results) > 0:
for row in results:
output += str(row) + "\n"
else:
output += "No results found."
else:
output += str(results) # In case of an error message
return output
except Exception as e:
return f"An error occurred: {str(e)}"
def close(self):
self.conn.commit()
self.conn.close()
def check_db_exists(db_path):
return os.path.exists(db_path) and os.path.getsize(db_path) > 0
@click.command()
@click.option(
"--data_path", help="Path to the data file containing the papers information."
)
@click.option("--pred_path", help="Path to the predictions file.")
@click.option("--db_name", default="arxiv.db", help="Name of the database to create.")
@click.option(
"--force", is_flag=True, help="Force overwrite if database already exists"
)
def main(data_path, pred_path, db_name, force):
ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
tables_dir = os.path.join(ROOT, DEFAULT_TABLES_DIR)
os.makedirs(tables_dir, exist_ok=True)
db_path = os.path.join(tables_dir, db_name)
db_exists = check_db_exists(db_path)
db = ArxivDatabase(db_path)
db.init_db()
if db_exists and not db.is_db_empty:
if not force:
print(f"Warning: The database '{db_name}' already exists and is not empty.")
overwrite = input("Do you want to overwrite it? (y/N): ").lower().strip()
if overwrite != "y":
print("Operation cancelled.")
db.close()
return
else:
print(
f"Warning: Overwriting existing database '{db_name}' due to --force flag."
)
db.populate_db(data_path, pred_path)
db.close()
print(f"Database created and populated at: {db_path}")
if __name__ == "__main__":
main()