Spaces:
Running
Running
import click | |
import json | |
import os | |
import sqlite3 | |
import sys | |
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))) | |
from config import DEFAULT_TABLES_DIR, DEFAULT_MODEL_ID, DEFAULT_INTERFACE_MODEL_ID | |
from src.processing.generate import get_sentences, generate_prediction | |
from src.utils.utils import load_model_and_tokenizer | |
class ArxivDatabase: | |
def __init__(self, db_path, model_id=None): | |
self.conn = None | |
self.cursor = None | |
self.db_path = db_path | |
self.model_id = model_id if model_id else DEFAULT_INTERFACE_MODEL_ID | |
self.model = None | |
self.tokenizer = None | |
self.is_db_empty = True | |
self.paper_table = """CREATE TABLE IF NOT EXISTS papers | |
(paper_id TEXT PRIMARY KEY, abstract TEXT, authors TEXT, | |
primary_category TEXT, url TEXT, updated_on TEXT, sentence_count INTEGER)""" | |
self.pred_table = """CREATE TABLE IF NOT EXISTS predictions | |
(id INTEGER PRIMARY KEY AUTOINCREMENT, paper_id TEXT, sentence_index INTEGER, | |
tag_type TEXT, concept TEXT, | |
FOREIGN KEY (paper_id) REFERENCES papers(paper_id))""" | |
# def init_db(self): | |
# self.cursor.execute(self.paper_table) | |
# self.cursor.execute(self.pred_table) | |
# print("Database and tables created successfully.") | |
# self.is_db_empty = self.is_empty() | |
def init_db(self): | |
self.conn = sqlite3.connect(self.db_path) | |
self.cursor = self.conn.cursor() | |
self.cursor.execute(self.paper_table) | |
self.cursor.execute(self.pred_table) | |
self.conn.commit() | |
self.is_db_empty = self.is_empty() | |
if not self.is_db_empty: | |
print("Database already contains data.") | |
else: | |
print("Database and tables created successfully.") | |
def is_empty(self): | |
try: | |
self.cursor.execute("SELECT COUNT(*) FROM papers") | |
count = self.cursor.fetchone()[0] | |
return count == 0 | |
except sqlite3.OperationalError: | |
return True | |
def get_connection(self): | |
return sqlite3.connect(self.conn.path) | |
def populate_db(self, data_path, pred_path): | |
papers_info = self._insert_papers(data_path) | |
self._insert_predictions(pred_path, papers_info) | |
print("Database population completed.") | |
def _insert_papers(self, data_path): | |
papers_info = [] | |
seen_papers = set() | |
with open(data_path, "r") as f: | |
for line in f: | |
paper = json.loads(line) | |
if paper["id"] in seen_papers: | |
continue | |
seen_papers.add(paper["id"]) | |
sentence_count = len(get_sentences(paper["id"])) + len( | |
get_sentences(paper["abstract"]) | |
) | |
papers_info.append((paper["id"], sentence_count)) | |
self.cursor.execute( | |
"""INSERT OR REPLACE INTO papers VALUES (?, ?, ?, ?, ?, ?, ?)""", | |
( | |
paper["id"], | |
paper["abstract"], | |
json.dumps(paper["authors"]), | |
json.dumps(paper["primary_category"]), | |
json.dumps(paper["url"]), | |
json.dumps(paper["updated"]), | |
sentence_count, | |
), | |
) | |
print(f"Inserted {len(papers_info)} papers.") | |
return papers_info | |
def _insert_predictions(self, pred_path, papers_info): | |
with open(pred_path, "r") as f: | |
predictions = json.load(f) | |
predicted_tags = predictions["predicted_tags"] | |
k = 0 | |
papers_with_predictions = set() | |
papers_without_predictions = [] | |
for paper_id, sentence_count in papers_info: | |
paper_predictions = predicted_tags[k : k + sentence_count] | |
has_predictions = False | |
for sentence_index, pred in enumerate(paper_predictions): | |
if pred: # If the prediction is not an empty dictionary | |
has_predictions = True | |
for tag_type, concepts in pred.items(): | |
for concept in concepts: | |
self.cursor.execute( | |
"""INSERT INTO predictions (paper_id, sentence_index, tag_type, concept) | |
VALUES (?, ?, ?, ?)""", | |
(paper_id, sentence_index, tag_type, concept), | |
) | |
else: | |
# Insert a null prediction to ensure the paper is counted | |
self.cursor.execute( | |
"""INSERT INTO predictions (paper_id, sentence_index, tag_type, concept) | |
VALUES (?, ?, ?, ?)""", | |
(paper_id, sentence_index, "null", "null"), | |
) | |
if has_predictions: | |
papers_with_predictions.add(paper_id) | |
else: | |
papers_without_predictions.append(paper_id) | |
k += sentence_count | |
print(f"Inserted predictions for {len(papers_with_predictions)} papers.") | |
print(f"Papers without any predictions: {len(papers_without_predictions)}") | |
if k < len(predicted_tags): | |
print(f"Warning: {len(predicted_tags) - k} predictions were not inserted.") | |
def load_model(self): | |
if self.model is None: | |
try: | |
self.model, self.tokenizer = load_model_and_tokenizer(self.model_id) | |
return f"Model {self.model_id} loaded successfully." | |
except Exception as e: | |
return f"Error loading model: {str(e)}" | |
else: | |
return "Model is already loaded." | |
def natural_language_to_sql(self, question): | |
system_prompt = "You are an assistant who converts natural language questions to SQL queries to query a database of scientific papers." | |
table = self.paper_table + "; " + self.pred_table | |
prefix = ( | |
f"[INST] Write SQLite query to answer the following question given the database schema. Please wrap your code answer using " | |
f"```: Schema: {table} Question: {question}[/INST] Here is the SQLite query to answer to the question: {question}: ``` " | |
) | |
sql_query = generate_prediction( | |
self.model, self.tokenizer, prefix, question, "sql", system_prompt | |
) | |
sql_query = sql_query.split("```")[1] | |
return sql_query | |
def execute_query(self, sql_query): | |
try: | |
self.cursor.execute(sql_query) | |
results = self.cursor.fetchall() | |
return results if results else [] | |
except sqlite3.Error as e: | |
return [(f"An error occurred: {e}",)] | |
def query_db(self, question, is_sql): | |
if self.is_db_empty: | |
return "The database is empty. Please populate it with data first." | |
try: | |
if is_sql: | |
sql_query = question.strip() | |
else: | |
nl_to_sql = self.natural_language_to_sql(question) | |
sql_query = nl_to_sql.replace("```sql", "").replace("```", "").strip() | |
results = self.execute_query(sql_query) | |
output = f"SQL Query: {sql_query}\n\nResults:\n" | |
if isinstance(results, list): | |
if len(results) > 0: | |
for row in results: | |
output += str(row) + "\n" | |
else: | |
output += "No results found." | |
else: | |
output += str(results) # In case of an error message | |
return output | |
except Exception as e: | |
return f"An error occurred: {str(e)}" | |
def close(self): | |
self.conn.commit() | |
self.conn.close() | |
def check_db_exists(db_path): | |
return os.path.exists(db_path) and os.path.getsize(db_path) > 0 | |
def main(data_path, pred_path, db_name, force): | |
ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) | |
tables_dir = os.path.join(ROOT, DEFAULT_TABLES_DIR) | |
os.makedirs(tables_dir, exist_ok=True) | |
db_path = os.path.join(tables_dir, db_name) | |
db_exists = check_db_exists(db_path) | |
db = ArxivDatabase(db_path) | |
db.init_db() | |
if db_exists and not db.is_db_empty: | |
if not force: | |
print(f"Warning: The database '{db_name}' already exists and is not empty.") | |
overwrite = input("Do you want to overwrite it? (y/N): ").lower().strip() | |
if overwrite != "y": | |
print("Operation cancelled.") | |
db.close() | |
return | |
else: | |
print( | |
f"Warning: Overwriting existing database '{db_name}' due to --force flag." | |
) | |
db.populate_db(data_path, pred_path) | |
db.close() | |
print(f"Database created and populated at: {db_path}") | |
if __name__ == "__main__": | |
main() | |