geekyrakshit commited on
Commit
77a97ce
·
1 Parent(s): 70d9de4

add: utility save_vector_index

Browse files
medrag_multi_modal/retrieval/common.py CHANGED
@@ -1,5 +1,11 @@
1
  from enum import Enum
2
 
 
 
 
 
 
 
3
 
4
  class SimilarityMetric(Enum):
5
  COSINE = "cosine"
@@ -19,3 +25,20 @@ def argsort_scores(scores: list[float], descending: bool = False):
19
  list(enumerate(scores)), key=lambda x: x[1], reverse=descending
20
  )
21
  ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from enum import Enum
2
 
3
+ import safetensors
4
+ import safetensors.torch
5
+ import torch
6
+
7
+ import wandb
8
+
9
 
10
  class SimilarityMetric(Enum):
11
  COSINE = "cosine"
 
25
  list(enumerate(scores)), key=lambda x: x[1], reverse=descending
26
  )
27
  ]
28
+
29
+
30
+ def save_vector_index(
31
+ vector_index: torch.Tensor,
32
+ index_name: str,
33
+ metadata: dict,
34
+ filename: str = "vector_index.safetensors",
35
+ ):
36
+ safetensors.torch.save_file({"vector_index": vector_index.cpu()}, filename)
37
+ if wandb.run:
38
+ artifact = wandb.Artifact(
39
+ name=index_name,
40
+ type="contriever-index",
41
+ metadata=metadata,
42
+ )
43
+ artifact.add_file(filename)
44
+ artifact.save()