KennethTM commited on
Commit
f0f2a62
1 Parent(s): b07eb8a

update db connection handling

Browse files
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -2,20 +2,26 @@ import gradio as gr
2
  from sentence_transformers import SentenceTransformer
3
  import psycopg2
4
  import os
 
5
 
6
  conn_string = os.environ.get("DATABASE_URL")
7
- conn = psycopg2.connect(conn_string)
8
- cur = conn.cursor()
9
 
10
  model = SentenceTransformer("multilingual-e5-small", device="cpu")
 
11
 
12
  def search(query, top_k):
13
- query_embedding = model.encode("query: " + query)
14
-
 
 
 
 
 
15
  query_sql = f"SELECT source_file, chunk FROM items ORDER BY embedding <=> '{str(query_embedding.tolist())}' LIMIT {int(top_k)};"
16
 
17
  cur.execute(query_sql)
18
  results = cur.fetchall()
 
19
 
20
  results_format = "\n".join([f"{i+1}. {text} __({file})__" for i, (file, text) in enumerate(results)])
21
 
 
2
  from sentence_transformers import SentenceTransformer
3
  import psycopg2
4
  import os
5
+ import torch
6
 
7
  conn_string = os.environ.get("DATABASE_URL")
 
 
8
 
9
  model = SentenceTransformer("multilingual-e5-small", device="cpu")
10
+ model.eval()
11
 
12
  def search(query, top_k):
13
+
14
+ with torch.no_grad():
15
+ query_embedding = model.encode("query: " + query)
16
+
17
+ conn = psycopg2.connect(conn_string)
18
+ cur = conn.cursor()
19
+
20
  query_sql = f"SELECT source_file, chunk FROM items ORDER BY embedding <=> '{str(query_embedding.tolist())}' LIMIT {int(top_k)};"
21
 
22
  cur.execute(query_sql)
23
  results = cur.fetchall()
24
+ conn.close()
25
 
26
  results_format = "\n".join([f"{i+1}. {text} __({file})__" for i, (file, text) in enumerate(results)])
27