File size: 2,239 Bytes
6e45ae9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# search_content.py

import faiss
import pandas as pd
from sentence_transformers import SentenceTransformer

# Default locations of the three artefacts this script reads. main_search()
# hands each one to the matching loader defined in this file.
#
# Saved sentence-transformer model, opened via SentenceTransformer.load().
# NOTE(review): the ".pkl" suffix suggests a pickle file -- confirm that
# SentenceTransformer.load() accepts whatever is stored at this path.
MODEL_SAVE_PATH = "all-distilroberta-v1-model.pkl"
# Serialized Faiss index, read via faiss.read_index().
FAISS_INDEX_FILE_PATH = "index.faiss"
# FAQ CSV; load_data() reads its 'Questions' and 'Answers' columns.
# NOTE(review): absolute "/content/..." path looks environment-specific
# (presumably a notebook runtime) -- confirm before running elsewhere.
DATA_FILE_PATH = "/content/omdena_faq_training_data.csv"

def load_transformer_model(model_file):
    """Return the sentence-transformer model stored at ``model_file``.

    Args:
        model_file: Location of the saved model; passed unchanged to
            ``SentenceTransformer.load``.

    Returns:
        Whatever ``SentenceTransformer.load`` produces for that location.
    """
    model = SentenceTransformer.load(model_file)
    return model

def load_faiss_index(filename):
    """Return the Faiss index read from ``filename``.

    Args:
        filename: Location of the serialized index; passed unchanged to
            ``faiss.read_index``.

    Returns:
        The index object produced by ``faiss.read_index``.
    """
    index = faiss.read_index(filename)
    return index

def load_data(file_path):
    """Load the FAQ CSV and prepare it for lookup by row id.

    Args:
        file_path: Path or file-like object accepted by ``pandas.read_csv``.
            The CSV must provide 'Questions' and 'Answers' columns.

    Returns:
        A DataFrame holding the CSV columns plus:

        * ``id``  -- the index pandas assigned while reading the file, kept
          both as a regular column and as the frame's index;
        * ``QNA`` -- the text "Question: <question>, Answer: <answer>" for
          each row.

    Raises:
        KeyError: If the 'Questions' or 'Answers' column is missing.
    """
    data_frame = pd.read_csv(file_path)

    # Fail early with a message that names the problem, instead of a bare
    # KeyError surfacing from deep inside the row formatting.
    missing = [
        name for name in ("Questions", "Answers")
        if name not in data_frame.columns
    ]
    if missing:
        raise KeyError(f"CSV is missing required column(s): {missing}")

    data_frame["id"] = data_frame.index

    # Build the combined text column-wise rather than with a row-wise
    # DataFrame.apply(axis=1): that builds a Series object per row (slow) and
    # its result shape on an empty frame depends on the pandas version. The
    # list built here always has exactly one entry per row.
    data_frame["QNA"] = [
        f"Question: {question}, Answer: {answer}"
        for question, answer in zip(data_frame["Questions"], data_frame["Answers"])
    ]

    # drop=False keeps 'id' available as a column as well as the index.
    return data_frame.set_index(["id"], drop=False)

def search_content(query, data_frame_indexed, transformer_model, faiss_index, k=5):
    """Search the indexed content for ``query`` and return the best matches.

    Args:
        query: Text to search for.
        data_frame_indexed: DataFrame whose index values match the labels
            stored in ``faiss_index`` (``load_data`` produces such a frame).
        transformer_model: Object exposing ``encode(list_of_texts)``; its
            output is passed to ``faiss.normalize_L2`` and to the index.
        faiss_index: Faiss index exposing ``search(vectors, k)``.
        k: Maximum number of results to request from the index.

    Returns:
        A new DataFrame with the matching rows of ``data_frame_indexed`` in
        the order returned by the index, plus a ``similarities`` column with
        the corresponding scores. It holds fewer than ``k`` rows when the
        index cannot supply ``k`` neighbours.
    """
    # Encode the query as a single-row batch, then normalize it in place.
    # NOTE(review): presumably the index was built from L2-normalized vectors
    # so the returned scores act as cosine similarities -- confirm against
    # the script that created the index.
    query_vector = transformer_model.encode([query])
    faiss.normalize_L2(query_vector)

    # search() returns (scores, labels), each with one row per query vector.
    scores, labels = faiss_index.search(query_vector, k)

    # Faiss pads the result with label -1 when it finds fewer than k
    # neighbours (e.g. the index holds fewer than k vectors). Those
    # placeholders are not row ids; passing them to .loc raises KeyError.
    hits = [
        (label, score)
        for label, score in zip(labels[0].tolist(), scores[0].tolist())
        if label >= 0
    ]
    ids = [label for label, _ in hits]
    similarities = [score for _, score in hits]

    # assign() returns a new frame, so the caller's data frame is never
    # written to and no chained-assignment warning can be triggered.
    return data_frame_indexed.loc[ids].assign(similarities=similarities)

def main_search(query):
    """Run a complete search for ``query`` using the default artefacts.

    Loads the model, the Faiss index and the FAQ data from the module-level
    paths (in that order), runs ``search_content`` with its default ``k``
    and returns only the 'QNA' column of the hits.
    """
    # Keyword arguments are evaluated in the order written, so the three
    # artefacts are loaded model first, then index, then data.
    hits = search_content(
        query,
        transformer_model=load_transformer_model(MODEL_SAVE_PATH),
        faiss_index=load_faiss_index(FAISS_INDEX_FILE_PATH),
        data_frame_indexed=load_data(DATA_FILE_PATH),
    )
    return hits["QNA"]

if __name__ == "__main__":
    # Script entry point: run one sample query and print the matching
    # question/answer texts.
    print(main_search("school courses"))