pvanand committed on
Commit
131998f
·
verified ·
1 Parent(s): 6bfafac

use sqlite to replace tables.json

Browse files
Files changed (1) hide show
  1. rag_routerv2.py +141 -136
rag_routerv2.py CHANGED
@@ -9,6 +9,7 @@ import uuid
9
  import io
10
  from io import BytesIO
11
  import csv
 
12
 
13
  # LlamaIndex imports
14
  from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex
@@ -27,13 +28,38 @@ router = APIRouter(
27
 
28
  # Configure global LlamaIndex settings
29
  Settings.embed_model = FastEmbedEmbedding(model_name="BAAI/bge-small-en-v1.5")
30
- tables_file_path = './data/tables.json'
31
 
32
  # Database connection dependency
33
  @lru_cache()
34
  def get_db_connection(db_path: str = "./lancedb/dev"):
35
  return lancedb.connect(db_path)
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  # Pydantic models
38
  class CreateTableResponse(BaseModel):
39
  table_id: str
@@ -48,94 +74,65 @@ class QueryTableResponse(BaseModel):
48
 
49
  @router.post("/create_table", response_model=CreateTableResponse)
50
  async def create_embedding_table(
51
- user_id: str,
52
- files: List[UploadFile] = File(...),
53
- table_id: Optional[str] = None,
54
- table_name: Optional[str] = None
55
  ) -> CreateTableResponse:
56
- """Create a table and load embeddings from uploaded files using LlamaIndex."""
57
- allowed_extensions = {".pdf", ".docx", ".csv", ".txt", ".md"}
58
- for file in files:
59
- if file.filename is None:
60
- raise HTTPException(status_code=400, detail="File must have a valid name.")
61
- file_extension = os.path.splitext(file.filename)[1].lower()
62
- if file_extension not in allowed_extensions:
63
- raise HTTPException(
64
- status_code=400,
65
- detail=f"File type {file_extension} is not allowed. Supported file types are: {', '.join(allowed_extensions)}."
66
- )
67
-
68
- if table_id is None:
69
- table_id = str(uuid.uuid4())
70
- table_name = f"knowledge-base-{str(uuid.uuid4())[:4]}" if not table_name else table_name
71
-
72
- #table_name = table_id #f"{user_id}__table__{table_id}"
73
-
74
- # Create a directory for the uploaded files
75
- directory_path = f"./data/{table_id}"
76
- os.makedirs(directory_path, exist_ok=True)
77
-
78
- # Save each uploaded file to the data directory
79
- for file in files:
80
- file_path = os.path.join(directory_path, file.filename)
81
- with open(file_path, "wb") as buffer:
82
- shutil.copyfileobj(file.file, buffer)
83
-
84
- try:
85
- # Setup LanceDB vector store
86
- vector_store = LanceDBVectorStore(
87
- uri="./lancedb/dev",
88
- table_name=table_id,
89
- mode="overwrite",
90
- query_type="hybrid"
91
- )
92
-
93
- # Load documents using SimpleDirectoryReader
94
- documents = SimpleDirectoryReader(directory_path).load_data()
95
-
96
- # Create the index
97
- index = VectorStoreIndex.from_documents(
98
- documents,
99
- vector_store=vector_store
100
- )
101
- index.storage_context.persist(persist_dir=f"./lancedb/index/{table_id}")
102
-
103
-
104
- # Store user_id and table_name in a JSON file
105
- try:
106
- tables_file_path = './data/tables.json'
107
- os.makedirs(os.path.dirname(tables_file_path), exist_ok=True)
108
- # Load existing tables or create a new file if it doesn't exist
109
- try:
110
- with open(tables_file_path, 'r') as f:
111
- tables = json.load(f)
112
- except (FileNotFoundError, json.JSONDecodeError):
113
- tables = {}
114
-
115
- # Update the tables dictionary
116
- if user_id not in tables:
117
- tables[user_id] = []
118
- if table_id not in [table['table_id'] for table in tables[user_id]]:
119
- tables[user_id].append({"table_id": table_id, "table_name": table_name})
120
-
121
- # Write the updated tables back to the JSON file
122
- with open(tables_file_path, 'w') as f:
123
- json.dump(tables, f)
124
-
125
- except Exception as e:
126
- raise HTTPException(status_code=500, detail=f"Failed to update tables file: {str(e)}")
127
-
128
- return CreateTableResponse(
129
- table_id=table_id,
130
- message="Table created and documents indexed successfully",
131
- status="success",
132
- table_name=table_name
133
- )
134
 
135
 
136
- except Exception as e:
137
- raise HTTPException(status_code=500, detail=f"Table creation failed: {str(e)}")
138
-
139
  @router.post("/query_table/{table_id}", response_model=QueryTableResponse)
140
  async def query_table(
141
  table_id: str,
@@ -172,22 +169,28 @@ async def query_table(
172
 
173
  @router.get("/get_tables/{user_id}")
174
  async def get_tables(user_id: str):
175
- """Get all tables for a user."""
176
-
177
- tables_file_path = './data/tables.json'
178
- try:
179
- # Load existing tables from the JSON file
180
- with open(tables_file_path, 'r') as f:
181
- tables = json.load(f)
182
-
183
- # Retrieve tables for the specified user
184
- user_tables = tables.get(user_id, [])
185
- return user_tables
186
-
187
- except (FileNotFoundError, json.JSONDecodeError):
188
- return [] # Return an empty list if the file doesn't exist or is invalid
189
- except Exception as e:
190
- raise HTTPException(status_code=500, detail=f"Failed to retrieve tables: {str(e)}")
 
 
 
 
 
 
191
 
192
  @router.get("/health")
193
  async def health_check():
@@ -195,40 +198,42 @@ async def health_check():
195
 
196
  @router.on_event("startup")
197
  async def startup():
198
- print("RAG Router started")
199
- from llama_index.core.schema import TextNode
200
- table_name = "digiyatra"
201
- nodes = []
202
- vector_store = LanceDBVectorStore(
203
-
204
- uri="./lancedb/dev",
205
- table_name=table_name,
206
- mode="overwrite",
207
- query_type="hybrid"
208
- )
209
- # load digiyatra csv and create node for each row using csv.reader
210
- with open('combined_digi_yatra.csv', newline='') as f:
211
- reader = csv.reader(f)
212
- data = list(reader)
213
- for row in data[1:]:
214
- node = TextNode(text=str(row), id_=str(uuid.uuid4()))
215
- nodes.append(node)
216
-
217
- index = VectorStoreIndex(nodes, vector_store=vector_store)
218
- index.storage_context.persist(persist_dir=f"./lancedb/index/{table_name}")
219
-
220
- # Create tables dictionary
221
- tables = {}
222
- user_id = "digiyatra"
223
-
224
- tables[user_id] = [
225
- {
226
- "table_id": table_name,
227
- "table_name": table_name
228
- }
229
- ]
230
- with open(tables_file_path, 'w') as f:
231
- json.dump(tables, f)
 
 
232
 
233
  @router.on_event("shutdown")
234
  async def shutdown():
 
9
  import io
10
  from io import BytesIO
11
  import csv
12
+ import sqlite3
13
 
14
  # LlamaIndex imports
15
  from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex
 
28
 
29
  # Configure global LlamaIndex settings
30
  Settings.embed_model = FastEmbedEmbedding(model_name="BAAI/bge-small-en-v1.5")
 
31
 
32
  # Database connection dependency
33
  @lru_cache()
34
  def get_db_connection(db_path: str = "./lancedb/dev"):
35
  return lancedb.connect(db_path)
36
 
37
+ def get_db():
38
+ conn = sqlite3.connect('./data/tables.db')
39
+ conn.row_factory = sqlite3.Row
40
+ return conn
41
+
42
+ def init_db():
43
+ db = get_db()
44
+ db.execute('''
45
+ CREATE TABLE IF NOT EXISTS tables (
46
+ id INTEGER PRIMARY KEY,
47
+ user_id TEXT NOT NULL,
48
+ table_id TEXT NOT NULL,
49
+ table_name TEXT NOT NULL
50
+ )
51
+ ''')
52
+ db.execute('''
53
+ CREATE TABLE IF NOT EXISTS table_files (
54
+ id INTEGER PRIMARY KEY,
55
+ table_id TEXT NOT NULL,
56
+ filename TEXT NOT NULL,
57
+ file_path TEXT NOT NULL,
58
+ FOREIGN KEY (table_id) REFERENCES tables (table_id)
59
+ )
60
+ ''')
61
+ db.commit()
62
+
63
  # Pydantic models
64
  class CreateTableResponse(BaseModel):
65
  table_id: str
 
74
 
75
  @router.post("/create_table", response_model=CreateTableResponse)
76
  async def create_embedding_table(
77
+ user_id: str,
78
+ files: List[UploadFile] = File(...),
79
+ table_id: Optional[str] = None,
80
+ table_name: Optional[str] = None
81
  ) -> CreateTableResponse:
82
+ allowed_extensions = {".pdf", ".docx", ".csv", ".txt", ".md"}
83
+ for file in files:
84
+ if not file.filename:
85
+ raise HTTPException(status_code=400, detail="Invalid filename")
86
+ if os.path.splitext(file.filename)[1].lower() not in allowed_extensions:
87
+ raise HTTPException(status_code=400, detail="Unsupported file type")
88
+
89
+ table_id = table_id or str(uuid.uuid4())
90
+ table_name = table_name or f"knowledge-base-{str(uuid.uuid4())[:4]}"
91
+
92
+ directory_path = f"./data/{table_id}"
93
+ os.makedirs(directory_path, exist_ok=True)
94
+
95
+ for file in files:
96
+ file_path = os.path.join(directory_path, file.filename)
97
+ with open(file_path, "wb") as buffer:
98
+ shutil.copyfileobj(file.file, buffer)
99
+
100
+ try:
101
+ vector_store = LanceDBVectorStore(
102
+ uri="./lancedb/dev",
103
+ table_name=table_id,
104
+ mode="overwrite",
105
+ query_type="hybrid"
106
+ )
107
+
108
+ documents = SimpleDirectoryReader(directory_path).load_data()
109
+ index = VectorStoreIndex.from_documents(documents, vector_store=vector_store)
110
+ index.storage_context.persist(persist_dir=f"./lancedb/index/{table_id}")
111
+
112
+ db = get_db()
113
+ db.execute(
114
+ 'INSERT INTO tables (user_id, table_id, table_name) VALUES (?, ?, ?)',
115
+ (user_id, table_id, table_name)
116
+ )
117
+
118
+ for file in files:
119
+ db.execute(
120
+ 'INSERT INTO table_files (table_id, filename, file_path) VALUES (?, ?, ?)',
121
+ (table_id, file.filename, f"./data/{table_id}/{file.filename}")
122
+ )
123
+ db.commit()
124
+
125
+ return CreateTableResponse(
126
+ table_id=table_id,
127
+ message="Success",
128
+ status="success",
129
+ table_name=table_name
130
+ )
131
+
132
+ except Exception as e:
133
+ raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
 
135
 
 
 
 
136
  @router.post("/query_table/{table_id}", response_model=QueryTableResponse)
137
  async def query_table(
138
  table_id: str,
 
169
 
170
  @router.get("/get_tables/{user_id}")
171
  async def get_tables(user_id: str):
172
+ db = get_db()
173
+ tables = db.execute('''
174
+ SELECT t.*, GROUP_CONCAT(tf.filename) as filenames, GROUP_CONCAT(tf.file_path) as file_paths
175
+ FROM tables t
176
+ LEFT JOIN table_files tf ON t.table_id = tf.table_id
177
+ WHERE t.user_id = ?
178
+ GROUP BY t.table_id
179
+ ''', (user_id,)).fetchall()
180
+
181
+ result = []
182
+ for table in tables:
183
+ table_dict = dict(table)
184
+ table_dict['files'] = [
185
+ {'filename': f, 'file_path': p}
186
+ for f, p in zip(
187
+ table_dict.pop('filenames').split(',') if table_dict['filenames'] else [],
188
+ table_dict.pop('file_paths').split(',') if table_dict['file_paths'] else []
189
+ )
190
+ ]
191
+ result.append(table_dict)
192
+
193
+ return result
194
 
195
  @router.get("/health")
196
  async def health_check():
 
198
 
199
  @router.on_event("startup")
200
  async def startup():
201
+ init_db()
202
+ print("RAG Router started")
203
+
204
+ table_name = "digiyatra"
205
+ user_id = "digiyatra"
206
+
207
+ # Create vector store and index
208
+ vector_store = LanceDBVectorStore(
209
+ uri="./lancedb/dev",
210
+ table_name=table_name,
211
+ mode="overwrite",
212
+ query_type="hybrid"
213
+ )
214
+
215
+ # Load CSV and create nodes
216
+ with open('combined_digi_yatra.csv', newline='') as f:
217
+ nodes = [
218
+ TextNode(text=str(row), id_=str(uuid.uuid4()))
219
+ for row in list(csv.reader(f))[1:]
220
+ ]
221
+
222
+ # Create and persist index
223
+ index = VectorStoreIndex(nodes, vector_store=vector_store)
224
+ index.storage_context.persist(persist_dir=f"./lancedb/index/{table_name}")
225
+
226
+ # Store in SQLite
227
+ db = get_db()
228
+ db.execute(
229
+ 'INSERT INTO tables (user_id, table_id, table_name) VALUES (?, ?, ?)',
230
+ (user_id, table_name, table_name)
231
+ )
232
+ db.execute(
233
+ 'INSERT INTO table_files (table_id, filename, file_path) VALUES (?, ?, ?)',
234
+ (table_name, 'combined_digi_yatra.csv', 'combined_digi_yatra.csv')
235
+ )
236
+ db.commit()
237
 
238
  @router.on_event("shutdown")
239
  async def shutdown():