# Spaces:
# Runtime error
# Runtime error
# search_content.py
import faiss
import numpy as np
import pandas as pd
from sentence_transformers import SentenceTransformer
# Locations of the artifacts this script reads: the saved sentence-transformer
# model, the Faiss index file, and the CSV holding the question/answer pairs.
# All three are relative paths, so they resolve against the current working
# directory of the process.
MODEL_SAVE_PATH = "all-distilroberta-v1-model.pkl"
FAISS_INDEX_FILE_PATH = "index.faiss"
DATA_FILE_PATH = "omdena_qna_dataset/omdena_faq_training_data.csv"
def load_transformer_model(model_file):
    """Load a sentence transformer model from a file.

    Args:
        model_file: Path handed unchanged to ``SentenceTransformer.load``.

    Returns:
        Whatever ``SentenceTransformer.load`` returns for that path.
    """
    # NOTE(review): the caller in this file passes a path ending in ".pkl";
    # confirm that artifact is something SentenceTransformer.load accepts.
    return SentenceTransformer.load(model_file)
def load_faiss_index(filename):
    """Load a Faiss index from a file.

    Args:
        filename: Path of the serialized index, passed to ``faiss.read_index``.

    Returns:
        The index object produced by ``faiss.read_index``.
    """
    return faiss.read_index(filename)
def load_data(file_path):
    """Load the Q&A CSV and add the columns the search step relies on.

    The CSV must provide ``Questions`` and ``Answers`` columns. Two columns
    are added:

    * ``id``  - a copy of the row index assigned by ``pd.read_csv``.
    * ``QNA`` - ``"Question: <q>, Answer: <a>"`` built from each row.

    Args:
        file_path: Path (or anything ``pd.read_csv`` accepts) of the CSV file.

    Returns:
        The data frame indexed by ``id``; ``id`` is also kept as a column.

    Raises:
        KeyError: If the ``Questions`` or ``Answers`` column is missing.
    """
    data_frame = pd.read_csv(file_path)
    data_frame["id"] = data_frame.index
    # Build the combined text with a plain comprehension rather than a
    # row-wise apply: apply(axis=1) returns a DataFrame (not a Series) for a
    # frame with no rows, which breaks the column assignment on a header-only
    # CSV. The formatted strings are the same.
    data_frame["QNA"] = [
        f"Question: {question}, Answer: {answer}"
        for question, answer in zip(data_frame["Questions"], data_frame["Answers"])
    ]
    return data_frame.set_index(["id"], drop=False)
def search_content(query, data_frame_indexed, transformer_model, faiss_index, k=5):
    """Search the content using a query and return the top k results.

    Args:
        query: Query text to embed.
        data_frame_indexed: Data frame whose index values match the labels
            stored in ``faiss_index`` (in this file: the frame from
            ``load_data``, indexed by ``id``).
        transformer_model: Object with an ``encode`` method that turns a list
            of strings into a batch of embedding vectors.
        faiss_index: Faiss index to query.
        k: Maximum number of neighbours to request.

    Returns:
        A new data frame with the matched rows in the order the index
        returned them, plus a ``similarities`` column holding the score the
        index reported for each row. Fewer than ``k`` rows are returned when
        the index cannot supply ``k`` neighbours.
    """
    # Encode a one-element batch; faiss.normalize_L2 works in place and needs
    # a contiguous float32 array, so coerce before normalizing.
    query_vector = np.ascontiguousarray(
        transformer_model.encode([query]), dtype=np.float32
    )
    faiss.normalize_L2(query_vector)

    # search() returns (scores, labels), one row per query vector.
    scores, labels = faiss_index.search(query_vector, k)

    # Faiss pads missing neighbours with label -1 when fewer than k results
    # exist; drop those so the .loc lookup below does not raise KeyError.
    hits = [
        (label, score)
        for label, score in zip(labels[0].tolist(), scores[0].tolist())
        if label != -1
    ]
    ids = [label for label, _ in hits]
    similarities = [score for _, score in hits]

    # Work on an explicit copy so adding the column never writes through to
    # (or warns about) the caller's data frame.
    results = data_frame_indexed.loc[ids].copy()
    # NOTE(review): whether these scores are true similarities depends on the
    # metric the index was built with - confirm against the indexing script.
    results["similarities"] = similarities
    return results
def main_search(query):
    """Run a search for ``query`` end to end and return the matching QNA text.

    The model, the Faiss index and the CSV data are loaded from the
    module-level paths on every call, then ``search_content`` is run with its
    default ``k``.

    Args:
        query: Query text to search for.

    Returns:
        The ``QNA`` column of the search results.
    """
    transformer_model = load_transformer_model(MODEL_SAVE_PATH)
    faiss_index = load_faiss_index(FAISS_INDEX_FILE_PATH)
    data_frame_indexed = load_data(DATA_FILE_PATH)
    results = search_content(query, data_frame_indexed, transformer_model, faiss_index)
    # Only the combined question/answer text is exposed to the caller.
    return results['QNA']
if __name__ == "__main__":
    # When run directly as a script, issue one sample query and print the
    # matching QNA strings.
    query = "school courses"
    print(main_search(query))