orionweller committed on
Commit
d89580e
1 Parent(s): d231d5c
Files changed (2) hide show
  1. app.py +41 -10
  2. scifact/faiss_index.bin +3 -0
app.py CHANGED
@@ -15,6 +15,7 @@ import spaces
15
  import ir_datasets
16
  import pytrec_eval
17
  from huggingface_hub import login
 
18
 
19
  # Set up logging
20
  logging.basicConfig(level=logging.INFO)
@@ -72,21 +73,51 @@ def load_model():
72
  model = model.merge_and_unload()
73
  model.eval()
74
 
 
 
 
 
 
 
 
 
 
 
 
 
75
  def load_corpus_embeddings(dataset_name):
76
  global retrievers, corpus_lookups
77
  corpus_path = f"{dataset_name}/corpus_emb.*.pkl"
78
  index_files = glob.glob(corpus_path)
79
  logger.info(f'Loading {len(index_files)} files into index for {dataset_name}.')
80
 
81
- p_reps_0, p_lookup_0 = pickle_load(index_files[0])
82
- retrievers[dataset_name] = FaissFlatSearcher(p_reps_0)
83
-
84
- shards = [(p_reps_0, p_lookup_0)] + [pickle_load(f) for f in index_files[1:]]
85
- corpus_lookups[dataset_name] = []
86
 
87
- for p_reps, p_lookup in tqdm.tqdm(shards, desc=f'Loading shards into index for {dataset_name}', total=len(index_files)):
88
- retrievers[dataset_name].add(p_reps)
89
- corpus_lookups[dataset_name] += p_lookup
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
 
91
  def pickle_load(path):
92
  with open(path, 'rb') as f:
@@ -187,7 +218,6 @@ def gradio_interface(dataset, postfix):
187
  return run_evaluation(dataset, postfix)
188
 
189
 
190
-
191
  # Create Gradio interface
192
  iface = gr.Interface(
193
  fn=gradio_interface,
@@ -201,7 +231,8 @@ iface = gr.Interface(
201
  examples=[
202
  ["scifact", ""],
203
  ["scifact", "When judging the relevance of a document, focus on the pragmatics of the query and consider irrelevant any documents for which the user would have used a different query."]
204
- ]
 
205
  )
206
 
207
  # Launch the interface
 
15
  import ir_datasets
16
  import pytrec_eval
17
  from huggingface_hub import login
18
+ import faiss
19
 
20
  # Set up logging
21
  logging.basicConfig(level=logging.INFO)
 
73
  model = model.merge_and_unload()
74
  model.eval()
75
 
76
def save_faiss_index(index, dataset_name):
    """Write a FAISS index to disk inside the dataset's directory.

    The index is stored at ``<dataset_name>/faiss_index.bin`` so a later
    run can memory-map it instead of rebuilding from embedding shards.
    """
    index_path = f"{dataset_name}/faiss_index.bin"
    faiss.write_index(index, index_path)
    logger.info(f"Saved FAISS index for {dataset_name} to {index_path}")
80
+
81
def load_faiss_index(dataset_name):
    """Load a previously saved FAISS index for *dataset_name*, if possible.

    Returns the memory-mapped index, or ``None`` when no usable index file
    exists — the caller then rebuilds the index from the embedding shards.

    Robustness note: the index file is tracked via git-LFS, so a checkout
    without ``git lfs pull`` leaves an ASCII pointer stub at this path;
    ``faiss.read_index`` would raise on it (or on any truncated/corrupt
    file) and crash startup. Treat any read failure as "no index".
    """
    index_path = f"{dataset_name}/faiss_index.bin"
    if not os.path.exists(index_path):
        return None
    logger.info(f"Loading existing FAISS index for {dataset_name} from {index_path}")
    try:
        # IO_FLAG_MMAP avoids loading the whole index into RAM up front.
        return faiss.read_index(index_path, faiss.IO_FLAG_MMAP)
    except Exception:
        # Fall back to rebuilding from shards rather than crashing.
        logger.exception(f"Could not read FAISS index at {index_path}; will rebuild from shards.")
        return None
87
+
88
def load_corpus_embeddings(dataset_name):
    """Populate the global retriever and doc-id lookup for *dataset_name*.

    Loads a cached FAISS index when one exists; otherwise builds the index
    from the pickled embedding shards and saves it for next time. In both
    cases ``corpus_lookups[dataset_name]`` is filled with the shard doc-id
    lists, concatenated in shard order.

    Bug fix: ``glob.glob`` returns files in an OS-dependent, unsorted
    order. FAISS assigns ids in insertion order at build time, so the
    reload branch must reconstruct the lookups in the *same* shard order —
    otherwise ids and doc-ids silently misalign. Sorting the glob result
    makes the order deterministic across runs.
    """
    global retrievers, corpus_lookups
    corpus_path = f"{dataset_name}/corpus_emb.*.pkl"
    # Deterministic shard order: build-time index ids must match the
    # lookup order reconstructed on a later cached-index load.
    index_files = sorted(glob.glob(corpus_path))
    logger.info(f'Loading {len(index_files)} files into index for {dataset_name}.')

    # Try to load existing FAISS index
    faiss_index = load_faiss_index(dataset_name)

    if faiss_index is None:
        # No cached index: build one from every shard, tracking doc ids.
        p_reps_0, p_lookup_0 = pickle_load(index_files[0])
        retrievers[dataset_name] = FaissFlatSearcher(p_reps_0)

        shards = [(p_reps_0, p_lookup_0)] + [pickle_load(f) for f in index_files[1:]]
        corpus_lookups[dataset_name] = []

        for p_reps, p_lookup in tqdm.tqdm(shards, desc=f'Loading shards into index for {dataset_name}', total=len(index_files)):
            retrievers[dataset_name].add(p_reps)
            corpus_lookups[dataset_name] += p_lookup

        # Persist so the next startup can skip the shard scan.
        save_faiss_index(retrievers[dataset_name].index, dataset_name)
    else:
        # NOTE(review): assumes FaissFlatSearcher accepts a ready-made
        # faiss index object as its constructor argument — confirm, since
        # the build branch passes raw embeddings instead.
        retrievers[dataset_name] = FaissFlatSearcher(faiss_index)

        # Still need the id lookups; read only the lookup half of each shard.
        corpus_lookups[dataset_name] = []
        for shard_file in index_files:
            _, p_lookup = pickle_load(shard_file)
            corpus_lookups[dataset_name] += p_lookup
121
 
122
  def pickle_load(path):
123
  with open(path, 'rb') as f:
 
218
  return run_evaluation(dataset, postfix)
219
 
220
 
 
221
  # Create Gradio interface
222
  iface = gr.Interface(
223
  fn=gradio_interface,
 
231
  examples=[
232
  ["scifact", ""],
233
  ["scifact", "When judging the relevance of a document, focus on the pragmatics of the query and consider irrelevant any documents for which the user would have used a different query."]
234
+ ],
235
+ cache_examples=True,
236
  )
237
 
238
  # Launch the interface
scifact/faiss_index.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d04b686b0c2f04a4fdeabb58c840eacf9471ac3f4625395f7664419b3c51cf57
3
+ size 84918317