import json
import os
import sqlite3
import sys

import click

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from config import DEFAULT_TABLES_DIR, DEFAULT_MODEL_ID, DEFAULT_INTERFACE_MODEL_ID
from src.processing.generate import get_sentences, generate_prediction
from src.utils.utils import load_model_and_tokenizer


class ArxivDatabase:
    def __init__(self, db_path, model_id=None):
        self.conn = None
        self.cursor = None
        self.db_path = db_path
        self.model_id = model_id if model_id else DEFAULT_INTERFACE_MODEL_ID
        self.model = None
        self.tokenizer = None
        self.is_db_empty = True
        self.paper_table = """CREATE TABLE IF NOT EXISTS papers
            (paper_id TEXT PRIMARY KEY, abstract TEXT, authors TEXT,
            primary_category TEXT, url TEXT, updated_on TEXT, sentence_count INTEGER)"""
        self.pred_table = """CREATE TABLE IF NOT EXISTS predictions
            (id INTEGER PRIMARY KEY AUTOINCREMENT, paper_id TEXT, sentence_index INTEGER,
            tag_type TEXT, concept TEXT,
            FOREIGN KEY (paper_id) REFERENCES papers(paper_id))"""
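        # Illustrative example of how the two tables relate (the tag_type value
        # below is an assumption, not taken from this file):
        #   SELECT p.paper_id, pr.concept
        #   FROM papers p JOIN predictions pr ON p.paper_id = pr.paper_id
        #   WHERE pr.tag_type = 'task';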

    def init_db(self):
        self.conn = sqlite3.connect(self.db_path)
        self.cursor = self.conn.cursor()
        self.cursor.execute(self.paper_table)
        self.cursor.execute(self.pred_table)
        self.conn.commit()
        self.is_db_empty = self.is_empty()
        if not self.is_db_empty:
            print("Database already contains data.")
        else:
            print("Database and tables created successfully.")

    def is_empty(self):
        try:
            self.cursor.execute("SELECT COUNT(*) FROM papers")
            count = self.cursor.fetchone()[0]
            return count == 0
        except sqlite3.OperationalError:
            return True

    def get_connection(self):
        # Open a new connection to the same database file.
        return sqlite3.connect(self.db_path)

    def populate_db(self, data_path, pred_path):
        papers_info = self._insert_papers(data_path)
        self._insert_predictions(pred_path, papers_info)
        print("Database population completed.")
    def _insert_papers(self, data_path):
        papers_info = []
        seen_papers = set()
        with open(data_path, "r") as f:
            for line in f:
                paper = json.loads(line)
                if paper["id"] in seen_papers:
                    continue
                seen_papers.add(paper["id"])
                sentence_count = len(get_sentences(paper["id"])) + len(
                    get_sentences(paper["abstract"])
                )
                papers_info.append((paper["id"], sentence_count))
                self.cursor.execute(
                    """INSERT OR REPLACE INTO papers VALUES (?, ?, ?, ?, ?, ?, ?)""",
                    (
                        paper["id"],
                        paper["abstract"],
                        json.dumps(paper["authors"]),
                        json.dumps(paper["primary_category"]),
                        json.dumps(paper["url"]),
                        json.dumps(paper["updated"]),
                        sentence_count,
                    ),
                )
        print(f"Inserted {len(papers_info)} papers.")
        return papers_info
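
    # Expected layout of the predictions file (inferred from the loop below):
    #   {"predicted_tags": [{"<tag_type>": ["<concept>", ...]}, {}, ...]}
    # One entry per sentence, in the order the papers were inserted; an empty
    # dict marks a sentence with no predicted concepts.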
    def _insert_predictions(self, pred_path, papers_info):
        with open(pred_path, "r") as f:
            predictions = json.load(f)
        predicted_tags = predictions["predicted_tags"]
        k = 0
        papers_with_predictions = set()
        papers_without_predictions = []
        for paper_id, sentence_count in papers_info:
            paper_predictions = predicted_tags[k : k + sentence_count]
            has_predictions = False
            for sentence_index, pred in enumerate(paper_predictions):
                if pred:  # If the prediction is not an empty dictionary
                    has_predictions = True
                    for tag_type, concepts in pred.items():
                        for concept in concepts:
                            self.cursor.execute(
                                """INSERT INTO predictions (paper_id, sentence_index, tag_type, concept)
                                VALUES (?, ?, ?, ?)""",
                                (paper_id, sentence_index, tag_type, concept),
                            )
                else:
                    # Insert a null prediction to ensure the paper is counted
                    self.cursor.execute(
                        """INSERT INTO predictions (paper_id, sentence_index, tag_type, concept)
                        VALUES (?, ?, ?, ?)""",
                        (paper_id, sentence_index, "null", "null"),
                    )
            if has_predictions:
                papers_with_predictions.add(paper_id)
            else:
                papers_without_predictions.append(paper_id)
            k += sentence_count
        print(f"Inserted predictions for {len(papers_with_predictions)} papers.")
        print(f"Papers without any predictions: {len(papers_without_predictions)}")
        if k < len(predicted_tags):
            print(f"Warning: {len(predicted_tags) - k} predictions were not inserted.")

    def load_model(self):
        if self.model is None:
            try:
                self.model, self.tokenizer = load_model_and_tokenizer(self.model_id)
                return f"Model {self.model_id} loaded successfully."
            except Exception as e:
                return f"Error loading model: {str(e)}"
        else:
            return "Model is already loaded."
    def natural_language_to_sql(self, question):
        system_prompt = "You are an assistant who converts natural language questions to SQL queries to query a database of scientific papers."
        table = self.paper_table + "; " + self.pred_table
        prefix = (
            f"[INST] Write an SQLite query to answer the following question given the database schema. Please wrap your code answer using "
            f"```: Schema: {table} Question: {question}[/INST] Here is the SQLite query to answer the question: {question}: ``` "
        )
        sql_query = generate_prediction(
            self.model, self.tokenizer, prefix, question, "sql", system_prompt
        )
        sql_query = sql_query.split("```")[1]
        return sql_query

    def execute_query(self, sql_query):
        try:
            self.cursor.execute(sql_query)
            results = self.cursor.fetchall()
            return results if results else []
        except sqlite3.Error as e:
            return [(f"An error occurred: {e}",)]

    def query_db(self, question, is_sql):
        if self.is_db_empty:
            return "The database is empty. Please populate it with data first."
        try:
            if is_sql:
                sql_query = question.strip()
            else:
                nl_to_sql = self.natural_language_to_sql(question)
                sql_query = nl_to_sql.replace("```sql", "").replace("```", "").strip()
            results = self.execute_query(sql_query)
            output = f"SQL Query: {sql_query}\n\nResults:\n"
            if isinstance(results, list):
                if len(results) > 0:
                    for row in results:
                        output += str(row) + "\n"
                else:
                    output += "No results found."
            else:
                output += str(results)  # In case of an error message
            return output
        except Exception as e:
            return f"An error occurred: {str(e)}"

    def close(self):
        self.conn.commit()
        self.conn.close()


def check_db_exists(db_path):
    return os.path.exists(db_path) and os.path.getsize(db_path) > 0
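

# Example invocation (the script name and file paths are illustrative, not
# taken from the repository's documentation):
#   python create_db.py --data_path papers.jsonl --pred_path predictions.json --force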
@click.command()
@click.option(
    "--data_path", help="Path to the data file containing the papers information."
)
@click.option("--pred_path", help="Path to the predictions file.")
@click.option("--db_name", default="arxiv.db", help="Name of the database to create.")
@click.option(
    "--force", is_flag=True, help="Force overwrite if database already exists."
)
def main(data_path, pred_path, db_name, force):
    ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    tables_dir = os.path.join(ROOT, DEFAULT_TABLES_DIR)
    os.makedirs(tables_dir, exist_ok=True)
    db_path = os.path.join(tables_dir, db_name)
    db_exists = check_db_exists(db_path)
    db = ArxivDatabase(db_path)
    db.init_db()
    if db_exists and not db.is_db_empty:
        if not force:
            print(f"Warning: The database '{db_name}' already exists and is not empty.")
            overwrite = input("Do you want to overwrite it? (y/N): ").lower().strip()
            if overwrite != "y":
                print("Operation cancelled.")
                db.close()
                return
        else:
            print(
                f"Warning: Overwriting existing database '{db_name}' due to --force flag."
            )
    db.populate_db(data_path, pred_path)
    db.close()
    print(f"Database created and populated at: {db_path}")


if __name__ == "__main__":
    main()
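
# Example of querying the populated database programmatically (the database
# path below is an assumption; adjust to wherever the database was written):
#   db = ArxivDatabase("tables/arxiv.db")
#   db.init_db()
#   print(db.query_db("SELECT paper_id, sentence_count FROM papers LIMIT 5", is_sql=True))
#   db.close()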