kenken999 commited on
Commit
a1e6264
·
1 Parent(s): d6a2797

add sentence-transformer

Browse files
Files changed (2) hide show
  1. babyagi/classesa/diamond.py +19 -7
  2. requirements.txt +2 -1
babyagi/classesa/diamond.py CHANGED
@@ -31,7 +31,7 @@ class ProductDatabase:
31
  def insert_vector(self, product_id, text):
32
  vector = self.get_embedding(text).tolist() # ndarray をリストに変換
33
  with self.conn.cursor() as cursor:
34
- cursor.execute("UPDATE products SET vector_col = %s WHERE id = %s", (vector, product_id))
35
  self.conn.commit()
36
 
37
  def search_similar_vectors(self, query_text, top_k=5):
@@ -39,13 +39,23 @@ class ProductDatabase:
39
  with self.conn.cursor() as cursor:
40
  cursor.execute("""
41
  SELECT id, vector_col <=> %s::vector AS distance
42
- FROM products
43
  ORDER BY distance
44
  LIMIT %s;
45
  """, (query_vector, top_k))
46
  results = cursor.fetchall()
47
  return results
48
 
 
 
 
 
 
 
 
 
 
 
49
  def main():
50
  # データベース接続情報
51
  DATABASE_URL = "postgresql://miyataken999:yz1wPf4KrWTm@ep-odd-mode-93794521.us-east-2.aws.neon.tech/neondb?sslmode=require"
@@ -60,12 +70,14 @@ def main():
60
  # pgvector拡張機能のインストールとカラムの追加
61
  db.setup_vector_extension_and_column()
62
  print("Vector extension installed and column added successfully.")
63
-
 
 
 
 
 
64
  # サンプルデータの挿入
65
- sample_text = """検査にはどのぐらい時間かかりますか?⇒当日に分かります。
66
- 法人取引やってますか?⇒大丈夫ですよ。成約時に必要な書類の説明
67
- LINEで金粉送って、査定はできますか?⇒できますが、今お話した内容と同様で、検査が必要な旨を返すだけなので、金粉ではなく、他のお品物でLINE査定くださいと。
68
- 分かりました、またどうするか検討して連絡しますと"""
69
  sample_product_id = 1 # 実際の製品IDを使用
70
  db.insert_vector(sample_product_id, sample_text)
71
  db.insert_vector(2, sample_text)
 
31
  def insert_vector(self, product_id, text):
32
  vector = self.get_embedding(text).tolist() # ndarray をリストに変換
33
  with self.conn.cursor() as cursor:
34
+ cursor.execute("UPDATE diamondprice SET vector_col = %s WHERE id = %s", (vector, product_id))
35
  self.conn.commit()
36
 
37
  def search_similar_vectors(self, query_text, top_k=5):
 
39
  with self.conn.cursor() as cursor:
40
  cursor.execute("""
41
  SELECT id, vector_col <=> %s::vector AS distance
42
+ FROM diamondprice
43
  ORDER BY distance
44
  LIMIT %s;
45
  """, (query_vector, top_k))
46
  results = cursor.fetchall()
47
  return results
48
 
49
+ def search_similar_all(self, query_text, top_k=5):
50
+ query_vector = self.get_embedding(query_text).tolist() # ndarray をリストに変換
51
+ with self.conn.cursor() as cursor:
52
+ cursor.execute("""
53
+ SELECT id,'carat', 'cut', 'color', 'clarity', 'depth', 'diamondprice.table', 'x', 'y', 'z'
54
+ FROM diamondprice
55
+ """, (query_vector, top_k))
56
+ results = cursor.fetchall()
57
+ return results
58
+
59
  def main():
60
  # データベース接続情報
61
  DATABASE_URL = "postgresql://miyataken999:yz1wPf4KrWTm@ep-odd-mode-93794521.us-east-2.aws.neon.tech/neondb?sslmode=require"
 
70
  # pgvector拡張機能のインストールとカラムの追加
71
  db.setup_vector_extension_and_column()
72
  print("Vector extension installed and column added successfully.")
73
+ query_text="1"
74
+ results = db.search_similar_all(query_text)
75
+ print("Search results:")
76
+ for result in results:
77
+ print(result)
78
+ return
79
  # サンプルデータの挿入
80
+ sample_text = """"""
 
 
 
81
  sample_product_id = 1 # 実際の製品IDを使用
82
  db.insert_vector(sample_product_id, sample_text)
83
  db.insert_vector(2, sample_text)
requirements.txt CHANGED
@@ -56,4 +56,5 @@ torchvision
56
  transformers
57
  langchain
58
  langchain_groq
59
- sqlalchemy
 
 
56
  transformers
57
  langchain
58
  langchain_groq
59
+ sqlalchemy
60
+ sentence-transformers