File size: 9,425 Bytes
8fbb714
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
import click
import json
import os
import sqlite3
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
from config import DEFAULT_TABLES_DIR, DEFAULT_MODEL_ID, DEFAULT_INTERFACE_MODEL_ID
from src.processing.generate import get_sentences, generate_prediction
from src.utils.utils import load_model_and_tokenizer


class ArxivDatabase:
    """SQLite store of arXiv papers and per-sentence concept predictions.

    The ``papers`` table holds one row per paper; the ``predictions`` table
    holds one row per (paper, sentence, tag type, concept). The same
    ``CREATE TABLE`` statements are also handed to the language model as the
    schema when translating natural-language questions into SQL.

    Call :meth:`init_db` before any other database method and :meth:`close`
    when finished.
    """

    def __init__(self, db_path, model_id=None):
        """Store configuration; no connection is opened here.

        Args:
            db_path: Filesystem path of the SQLite database file.
            model_id: Model identifier used by :meth:`load_model`; falls back
                to ``DEFAULT_INTERFACE_MODEL_ID`` when falsy.
        """
        self.conn = None
        self.cursor = None
        self.db_path = db_path
        self.model_id = model_id if model_id else DEFAULT_INTERFACE_MODEL_ID
        self.model = None
        self.tokenizer = None
        # Refreshed by init_db() and populate_db(); query_db() refuses to run
        # while this is True.
        self.is_db_empty = True
        # These DDL strings double as the schema text embedded in the
        # NL-to-SQL prompt (see natural_language_to_sql).
        self.paper_table = """CREATE TABLE IF NOT EXISTS papers
                        (paper_id TEXT PRIMARY KEY, abstract TEXT, authors TEXT, 
                        primary_category TEXT, url TEXT, updated_on TEXT, sentence_count INTEGER)"""
        self.pred_table = """CREATE TABLE IF NOT EXISTS predictions
                        (id INTEGER PRIMARY KEY AUTOINCREMENT, paper_id TEXT, sentence_index INTEGER, 
                        tag_type TEXT, concept TEXT,
                        FOREIGN KEY (paper_id) REFERENCES papers(paper_id))"""

    def init_db(self):
        """Open the connection, create missing tables and cache emptiness."""
        self.conn = sqlite3.connect(self.db_path)
        self.cursor = self.conn.cursor()
        self.cursor.execute(self.paper_table)
        self.cursor.execute(self.pred_table)
        self.conn.commit()
        self.is_db_empty = self.is_empty()
        if not self.is_db_empty:
            print("Database already contains data.")
        else:
            print("Database and tables created successfully.")

    def is_empty(self):
        """Return True when ``papers`` has no rows (or cannot be read)."""
        try:
            self.cursor.execute("SELECT COUNT(*) FROM papers")
            count = self.cursor.fetchone()[0]
            return count == 0
        except sqlite3.OperationalError:
            # e.g. the table does not exist yet.
            return True

    def get_connection(self):
        """Open and return a new, independent connection to the same file.

        The caller owns the returned connection and must close it.
        """
        # sqlite3.Connection exposes no ``path`` attribute, so reconnect
        # using the path this object was constructed with.
        return sqlite3.connect(self.db_path)

    def populate_db(self, data_path, pred_path):
        """Load papers from ``data_path`` and predictions from ``pred_path``.

        Changes are committed, and ``is_db_empty`` is refreshed so that
        :meth:`query_db` works immediately afterwards.
        """
        papers_info = self._insert_papers(data_path)
        self._insert_predictions(pred_path, papers_info)
        self.conn.commit()
        self.is_db_empty = self.is_empty()
        print("Database population completed.")

    def _insert_papers(self, data_path):
        """Insert papers from a JSON Lines file and return sentence counts.

        Each non-blank line must be a JSON object with ``id``, ``abstract``,
        ``authors``, ``primary_category``, ``url`` and ``updated`` keys.
        Repeated ids within the file are skipped (first occurrence wins);
        rows already stored under the same id are replaced.

        Returns:
            List of ``(paper_id, sentence_count)`` tuples in file order; it
            drives the slicing of the flat prediction list in
            :meth:`_insert_predictions`.
        """
        papers_info = []
        seen_papers = set()
        with open(data_path, "r", encoding="utf-8") as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                paper = json.loads(line)
                if paper["id"] in seen_papers:
                    continue
                seen_papers.add(paper["id"])
                # NOTE(review): the count covers sentences of paper["id"] plus
                # the abstract; it must match how the predictions file was
                # generated or every later paper's predictions shift — confirm.
                sentence_count = len(get_sentences(paper["id"])) + len(
                    get_sentences(paper["abstract"])
                )
                papers_info.append((paper["id"], sentence_count))
                self.cursor.execute(
                    """INSERT OR REPLACE INTO papers VALUES (?, ?, ?, ?, ?, ?, ?)""",
                    (
                        paper["id"],
                        paper["abstract"],
                        json.dumps(paper["authors"]),
                        json.dumps(paper["primary_category"]),
                        json.dumps(paper["url"]),
                        json.dumps(paper["updated"]),
                        sentence_count,
                    ),
                )
        print(f"Inserted {len(papers_info)} papers.")
        return papers_info

    def _insert_predictions(self, pred_path, papers_info):
        """Insert per-sentence predictions aligned to papers by sentence count.

        ``pred_path`` is a JSON file whose ``predicted_tags`` list holds one
        entry per sentence, for all papers concatenated in ``papers_info``
        order. Each entry maps a tag type to an iterable of concepts; an
        empty entry is stored as a single ``("null", "null")`` row so the
        sentence is still represented.
        """
        with open(pred_path, "r", encoding="utf-8") as f:
            predictions = json.load(f)
        predicted_tags = predictions["predicted_tags"]

        k = 0  # offset of the current paper's first sentence in predicted_tags
        papers_with_predictions = set()
        papers_without_predictions = []
        for paper_id, sentence_count in papers_info:
            paper_predictions = predicted_tags[k : k + sentence_count]

            has_predictions = False
            for sentence_index, pred in enumerate(paper_predictions):
                if pred:  # If the prediction is not an empty dictionary
                    has_predictions = True
                    for tag_type, concepts in pred.items():
                        for concept in concepts:
                            self.cursor.execute(
                                """INSERT INTO predictions (paper_id, sentence_index, tag_type, concept) 
                                VALUES (?, ?, ?, ?)""",
                                (paper_id, sentence_index, tag_type, concept),
                            )
                else:
                    # Insert a null prediction to ensure the paper is counted
                    self.cursor.execute(
                        """INSERT INTO predictions (paper_id, sentence_index, tag_type, concept) 
                        VALUES (?, ?, ?, ?)""",
                        (paper_id, sentence_index, "null", "null"),
                    )

            if has_predictions:
                papers_with_predictions.add(paper_id)
            else:
                papers_without_predictions.append(paper_id)

            k += sentence_count

        print(f"Inserted predictions for {len(papers_with_predictions)} papers.")
        print(f"Papers without any predictions: {len(papers_without_predictions)}")

        # Report misalignment in either direction instead of silently
        # truncating when the file is shorter than expected.
        if k < len(predicted_tags):
            print(f"Warning: {len(predicted_tags) - k} predictions were not inserted.")
        elif k > len(predicted_tags):
            print(
                f"Warning: {k - len(predicted_tags)} sentences had no prediction entry."
            )

    def load_model(self):
        """Load the model/tokenizer once and return a human-readable status.

        Load failures are reported in the returned string rather than raised.
        """
        if self.model is None:
            try:
                self.model, self.tokenizer = load_model_and_tokenizer(self.model_id)
                return f"Model {self.model_id} loaded successfully."
            except Exception as e:
                return f"Error loading model: {str(e)}"
        else:
            return "Model is already loaded."

    def natural_language_to_sql(self, question):
        """Ask the loaded model to translate ``question`` into SQLite SQL.

        Assumes :meth:`load_model` has succeeded. The text between the first
        pair of triple backticks in the model output is returned; an
        ``IndexError`` is raised if the output contains no backticks.
        """
        system_prompt = "You are an assistant who converts natural language questions to SQL queries to query a database of scientific papers."
        table = self.paper_table + "; " + self.pred_table
        prefix = (
            f"[INST] Write SQLite query to answer the following question given the database schema. Please wrap your code answer using "
            f"```: Schema: {table} Question: {question}[/INST] Here is the SQLite query to answer to the question: {question}: ``` "
        )

        sql_query = generate_prediction(
            self.model, self.tokenizer, prefix, question, "sql", system_prompt
        )

        return sql_query.split("```")[1]

    def execute_query(self, sql_query):
        """Run ``sql_query`` and return all rows as a list of tuples.

        On an SQLite error a one-row list holding the error message is
        returned instead of raising.
        """
        # NOTE(review): the statement is executed as given (user- or
        # model-written), so it can modify or drop data.
        try:
            self.cursor.execute(sql_query)
            return self.cursor.fetchall()
        except sqlite3.Error as e:
            return [(f"An error occurred: {e}",)]

    def query_db(self, question, is_sql):
        """Answer ``question`` and return a printable report.

        Args:
            question: Raw SQL when ``is_sql`` is true, otherwise a
                natural-language question translated via the model.
            is_sql: Whether ``question`` is already SQL.

        Returns:
            The executed SQL followed by one result row per line, or an
            error / empty-database message. Never raises.
        """
        if self.is_db_empty:
            return "The database is empty. Please populate it with data first."

        try:
            if is_sql:
                sql_query = question.strip()
            else:
                nl_to_sql = self.natural_language_to_sql(question)
                sql_query = nl_to_sql.replace("```sql", "").replace("```", "").strip()

            results = self.execute_query(sql_query)

            output = f"SQL Query: {sql_query}\n\nResults:\n"
            if not isinstance(results, list):
                return output + str(results)
            if not results:
                return output + "No results found."
            return output + "".join(f"{row}\n" for row in results)
        except Exception as e:
            # UI boundary: report rather than propagate.
            return f"An error occurred: {str(e)}"

    def close(self):
        """Commit pending changes and close the connection (safe to repeat)."""
        if self.conn is None:
            return
        self.conn.commit()
        self.conn.close()
        self.conn = None
        self.cursor = None


def check_db_exists(db_path):
    """Report whether ``db_path`` exists and has a non-zero size.

    A zero-byte file (e.g. left behind by an aborted connect) counts as
    not existing.
    """
    if not os.path.exists(db_path):
        return False
    size_in_bytes = os.path.getsize(db_path)
    return size_in_bytes > 0


@click.command()
@click.option(
    "--data_path",
    required=True,
    help="Path to the data file containing the papers information.",
)
@click.option("--pred_path", required=True, help="Path to the predictions file.")
@click.option("--db_name", default="arxiv.db", help="Name of the database to create.")
@click.option(
    "--force", is_flag=True, help="Force overwrite if database already exists"
)
def main(data_path, pred_path, db_name, force):
    """Create (or overwrite) the arXiv SQLite database and populate it.

    The database file is placed in ``DEFAULT_TABLES_DIR`` under the parent
    directory of this script's folder. If it already holds papers, the user
    is asked to confirm the overwrite unless ``--force`` is given. On
    overwrite, all existing rows are removed before repopulating.
    """
    root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
    tables_dir = os.path.join(root, DEFAULT_TABLES_DIR)
    os.makedirs(tables_dir, exist_ok=True)
    db_path = os.path.join(tables_dir, db_name)

    # Checked before init_db(), which creates the file.
    db_exists = check_db_exists(db_path)

    db = ArxivDatabase(db_path)
    db.init_db()

    if db_exists and not db.is_db_empty:
        if not force:
            print(f"Warning: The database '{db_name}' already exists and is not empty.")
            overwrite = input("Do you want to overwrite it? (y/N): ").lower().strip()
            if overwrite != "y":
                print("Operation cancelled.")
                db.close()
                return
        else:
            print(
                f"Warning: Overwriting existing database '{db_name}' due to --force flag."
            )
        # populate_db() replaces papers by id but *appends* prediction rows,
        # so clear both tables to make this a real overwrite instead of
        # duplicating every prediction. The deletes share the insert
        # transaction and are committed together by db.close().
        db.cursor.execute("DELETE FROM predictions")
        db.cursor.execute("DELETE FROM papers")

    db.populate_db(data_path, pred_path)
    db.close()

    print(f"Database created and populated at: {db_path}")


if __name__ == "__main__":
    main()