eliphatfs
commited on
Commit
•
654bd81
1
Parent(s):
c064d59
Improve retrieval performance.
Browse files
openshape/demo/retrieval.py
CHANGED
@@ -37,6 +37,9 @@ def retrieve(embedding, top, sim_th=0.0, filter_fn=None):
|
|
37 |
sims.append(embedding @ F.normalize(chunk.float(), dim=-1).T)
|
38 |
sims = torch.cat(sims)
|
39 |
sims, idx = torch.sort(sims, descending=True)
|
|
|
|
|
|
|
40 |
results = []
|
41 |
for i, sim in zip(idx, sims):
|
42 |
if us[i] in meta:
|
@@ -44,6 +47,4 @@ def retrieve(embedding, top, sim_th=0.0, filter_fn=None):
|
|
44 |
results.append(dict(meta[us[i]], sim=sim))
|
45 |
if len(results) >= top:
|
46 |
break
|
47 |
-
if sim < sim_th:
|
48 |
-
break
|
49 |
return results
|
|
|
37 |
sims.append(embedding @ F.normalize(chunk.float(), dim=-1).T)
|
38 |
sims = torch.cat(sims)
|
39 |
sims, idx = torch.sort(sims, descending=True)
|
40 |
+
sim_mask = sims > sim_th
|
41 |
+
sims = sims[sim_mask]
|
42 |
+
idx = idx[sim_mask]
|
43 |
results = []
|
44 |
for i, sim in zip(idx, sims):
|
45 |
if us[i] in meta:
|
|
|
47 |
results.append(dict(meta[us[i]], sim=sim))
|
48 |
if len(results) >= top:
|
49 |
break
|
|
|
|
|
50 |
return results
|