pvanand committed on
Commit
6e45ae9
·
1 Parent(s): 5b552f8

Upload search_content.py

Browse files

Upload search_content.py to enable vector search

Files changed (1) hide show
  1. actions/search_content.py +55 -0
actions/search_content.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # search_content.py
2
+
3
+ import faiss
4
+ import pandas as pd
5
+ from sentence_transformers import SentenceTransformer
6
+
7
# Define paths for model, Faiss index, and data file.
# These defaults are consumed by main_search(); the two relative paths are
# resolved against the process's current working directory.
# NOTE(review): the ".pkl" suffix suggests a pickled model, but the path is
# passed to SentenceTransformer.load() -- confirm the artefact format matches.
MODEL_SAVE_PATH = "all-distilroberta-v1-model.pkl"
FAISS_INDEX_FILE_PATH = "index.faiss"
# NOTE(review): absolute "/content/..." looks like a Google Colab location --
# confirm this file exists in the deployment environment.
DATA_FILE_PATH = "/content/omdena_faq_training_data.csv"
11
+
12
def load_transformer_model(model_file):
    """Return the sentence-transformer model stored at ``model_file``.

    The path is passed unchanged to ``SentenceTransformer.load``; any error
    raised by that call propagates to the caller.
    """
    model = SentenceTransformer.load(model_file)
    return model
15
+
16
def load_faiss_index(filename):
    """Return the Faiss index read from ``filename``.

    The path is passed unchanged to ``faiss.read_index``; any error raised by
    that call propagates to the caller.
    """
    index = faiss.read_index(filename)
    return index
19
+
20
def load_data(file_path):
    """Read the FAQ CSV at ``file_path`` and prepare it for lookups.

    The CSV must provide ``Questions`` and ``Answers`` columns. Two columns
    are added: ``id`` (a copy of the default row index) and ``QNA`` (the
    question and answer combined into one string). The returned frame is
    indexed by ``id`` while keeping ``id`` as a regular column too.
    """

    def _format_qna(record):
        # One display string per row, combining both text fields.
        return f"Question: {record['Questions']}, Answer: {record['Answers']}"

    frame = pd.read_csv(file_path)
    frame["id"] = frame.index
    frame["QNA"] = frame.apply(_format_qna, axis=1)
    indexed = frame.set_index(["id"], drop=False)
    return indexed
27
+
28
def search_content(query, data_frame_indexed, transformer_model, faiss_index, k=5):
    """Search the content using a query and return the top k results.

    Args:
        query: Text to search for.
        data_frame_indexed: DataFrame whose index labels are the ids returned
            by ``faiss_index`` (presumably the frame built by ``load_data`` --
            verify the index was built from the same rows).
        transformer_model: Object whose ``encode([text])`` returns a 2-D
            array accepted by ``faiss.normalize_L2``.
        faiss_index: Index object exposing ``search(vectors, k)``.
        k: Maximum number of results to return.

    Returns:
        A new DataFrame with the matched rows in the order the index returned
        them, plus a ``similarities`` column holding the scores reported by
        the index. It has fewer than ``k`` rows (possibly none) when the
        index cannot supply ``k`` neighbours.

    Raises:
        KeyError: If the index returns an id that is not a label of
            ``data_frame_indexed``.
    """
    # Encode the query; the vector is then L2-normalised in place.
    query_vector = transformer_model.encode([query])
    faiss.normalize_L2(query_vector)

    # search() returns (scores, labels), each with one row per query vector.
    scores, labels = faiss_index.search(query_vector, k)

    # Faiss pads the label row with -1 when fewer than k neighbours exist;
    # looking those up with .loc would raise KeyError, so drop them here.
    hits = [
        (int(label), float(score))
        for label, score in zip(labels[0], scores[0])
        if label != -1
    ]
    ids = [label for label, _ in hits]
    similarities = [score for _, score in hits]

    # assign() returns a new frame, so the caller's frame is never mutated
    # and no chained-assignment warning can be triggered.
    return data_frame_indexed.loc[ids].assign(similarities=similarities)
44
+
45
# Loaded (model, index, data frame) triples, keyed by the paths they came from.
_SEARCH_RESOURCES = {}


def _get_search_resources(model_path, index_path, data_path):
    """Load the model, Faiss index and data frame once per set of paths."""
    key = (model_path, index_path, data_path)
    if key not in _SEARCH_RESOURCES:
        _SEARCH_RESOURCES[key] = (
            load_transformer_model(model_path),
            load_faiss_index(index_path),
            load_data(data_path),
        )
    return _SEARCH_RESOURCES[key]


def main_search(query):
    """Main function to execute the search.

    Runs ``query`` against the resources located at ``MODEL_SAVE_PATH``,
    ``FAISS_INDEX_FILE_PATH`` and ``DATA_FILE_PATH``. The resources are
    loaded on the first call and reused afterwards, so repeated queries do
    not re-read the model, index and CSV from disk each time.

    Args:
        query: Text to search for.

    Returns:
        The ``QNA`` column of the matching rows (one combined
        question/answer string per hit), using ``search_content``'s default
        number of results.
    """
    transformer_model, faiss_index, data_frame_indexed = _get_search_resources(
        MODEL_SAVE_PATH, FAISS_INDEX_FILE_PATH, DATA_FILE_PATH
    )
    results = search_content(query, data_frame_indexed, transformer_model, faiss_index)
    return results['QNA']
52
+
53
if __name__ == "__main__":
    # Manual smoke test: run one sample query and print the matches when
    # this file is executed directly as a script.
    print(main_search("school courses"))