eliphatfs commited on
Commit
654bd81
1 Parent(s): c064d59

Improve retrieval performance.

Browse files
Files changed (1) hide show
  1. openshape/demo/retrieval.py +3 -2
openshape/demo/retrieval.py CHANGED
@@ -37,6 +37,9 @@ def retrieve(embedding, top, sim_th=0.0, filter_fn=None):
37
  sims.append(embedding @ F.normalize(chunk.float(), dim=-1).T)
38
  sims = torch.cat(sims)
39
  sims, idx = torch.sort(sims, descending=True)
 
 
 
40
  results = []
41
  for i, sim in zip(idx, sims):
42
  if us[i] in meta:
@@ -44,6 +47,4 @@ def retrieve(embedding, top, sim_th=0.0, filter_fn=None):
44
  results.append(dict(meta[us[i]], sim=sim))
45
  if len(results) >= top:
46
  break
47
- if sim < sim_th:
48
- break
49
  return results
 
37
  sims.append(embedding @ F.normalize(chunk.float(), dim=-1).T)
38
  sims = torch.cat(sims)
39
  sims, idx = torch.sort(sims, descending=True)
40
+ sim_mask = sims > sim_th
41
+ sims = sims[sim_mask]
42
+ idx = idx[sim_mask]
43
  results = []
44
  for i, sim in zip(idx, sims):
45
  if us[i] in meta:
 
47
  results.append(dict(meta[us[i]], sim=sim))
48
  if len(results) >= top:
49
  break
 
 
50
  return results