Vsevolod committed on
Commit
fa438f6
·
1 Parent(s): 2c135d2

added core functionality

Browse files
Files changed (3) hide show
  1. .gitignore +1 -0
  2. app.py +25 -5
  3. requirements.txt +4 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ __pycache__
app.py CHANGED
@@ -1,4 +1,8 @@
 
 
 
1
  import gradio as gr
 
2
 
3
 
4
  K = 5
@@ -20,16 +24,32 @@ def create_demo(callback):
20
 
21
 
22
  class Callback:
23
- def __init__(self):
24
- pass
 
 
 
 
25
 
26
  def __call__(self, input_name):
27
- res = ["These are completely different companies."] * K
28
- return res
 
 
 
 
 
 
 
 
29
 
30
 
31
  def main():
32
- callback = Callback()
 
 
 
 
33
  demo = create_demo(callback)
34
  demo.launch()
35
 
 
1
+ import torch
2
+ import pickle
3
+ import nmslib
4
  import gradio as gr
5
+ from sentence_transformers import SentenceTransformer
6
 
7
 
8
  K = 5
 
24
 
25
 
26
  class Callback:
27
+ def __init__(self, model, data):
28
+ self.index = nmslib.init(method='hnsw', space='cosinesimil')
29
+ self.index.addDataPointBatch(data["emb"])
30
+ self.index.createIndex({'post': 2}, print_progress=True)
31
+ self.model = model
32
+ self.data = data
33
 
34
  def __call__(self, input_name):
35
+ emb = self.model.encode(input_name)
36
+ ids, _ = self.index.knnQuery(emb, k=K)
37
+ names = [self.data["names"][id] for id in ids]
38
+ return names
39
+
40
+
41
+ def load_data(filename):
42
+ with open(filename, "rb") as file:
43
+ data = pickle.load(file)
44
+ return data
45
 
46
 
47
  def main():
48
+ data = load_data("data.pickle")
49
+ device = "cuda:0" if torch.cuda.is_available() else "cpu"
50
+ model = SentenceTransformer("Vsevolod/company-names-similarity-sentence-transformer").to(device)
51
+ callback = Callback(model, data)
52
+
53
  demo = create_demo(callback)
54
  demo.launch()
55
 
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ torch
2
+ gradio
3
+ nmslib
4
+ sentence_transformers