File size: 2,822 Bytes
7d09b5b 4b2df1b 7d09b5b 99bb109 3359dd0 4b2df1b 7d09b5b 99bb109 4b2df1b 7d09b5b 4b2df1b 7d09b5b 4b2df1b 7d09b5b 4d1b45b 99bb109 5cafae6 99bb109 5cafae6 99bb109 7d09b5b 99bb109 7d09b5b 51f0842 5cafae6 7d09b5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
import gradio as gr
import pandas as pd
import faiss
import numpy as np
import os
from FlagEmbedding import BGEM3FlagModel
# Load the pre-trained BGE-M3 embedding model (fp16 halves memory usage).
model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)

# Load the JSON data into a DataFrame.
df = pd.read_json('White-Stride-Red-68.json')

# Fill missing values BEFORE casting to str: astype(str) turns NaN into the
# literal string 'nan', which fillna('') would then leave untouched and the
# emptiness filter below would silently keep.
df['embeding_context'] = df['embeding_context'].fillna('').astype(str)

# Drop rows whose 'embeding_context' is empty — they carry no searchable text.
df = df[df['embeding_context'] != '']
# Load the prebuilt FAISS index of dense embeddings for df['embeding_context'].
# The index was built offline as an IndexFlatL2 over
# model.encode(...)['dense_vecs'] (float32) and persisted with
# faiss.write_index; rebuild it the same way whenever the JSON data changes.
index = faiss.read_index('vector_store_bge_m3.index')
# Function to perform search and return all columns
def search_query(query_text, num_records=50):
    """Embed *query_text* and return the nearest rows from ``df``.

    Parameters
    ----------
    query_text : str
        Free-text search query.
    num_records : int, optional
        Number of nearest neighbours requested from FAISS (default 50,
        matching the original hard-coded value).

    Returns
    -------
    pandas.DataFrame
        Matching rows with the 'embeding_context' column dropped,
        duplicates removed, and a fresh integer index.
    """
    # Encode the input query text; FAISS requires float32 vectors.
    embeddings_query = model.encode([query_text], batch_size=12, max_length=1024)['dense_vecs']
    embeddings_query_np = np.array(embeddings_query).astype('float32')

    # Search the FAISS index for the nearest neighbours.
    distances, indices = index.search(embeddings_query_np, num_records)

    # FAISS pads the result with -1 when the index holds fewer than
    # num_records vectors; df.iloc[-1] would silently wrap to the last
    # row, so filter the padding out before indexing.
    valid = indices[0][indices[0] >= 0]

    result_df = (
        df.iloc[valid]
        .drop(columns=['embeding_context'])
        .drop_duplicates()
        .reset_index(drop=True)
    )
    return result_df
# Gradio interface function
def gradio_interface(query_text):
    """Gradio callback: delegate the query to search_query and return its DataFrame."""
    return search_query(query_text)
# Assemble and run the Gradio UI.
with gr.Blocks() as app:
    # Page heading.
    gr.Markdown("<h1>White Stride Red Search (BEG-M3)</h1>")

    # Free-text query input.
    search_input = gr.Textbox(label="Search Query", placeholder="Enter search text", interactive=True)

    # Button that triggers the search.
    search_button = gr.Button("Search")

    # Table that displays the matching rows.
    search_output = gr.DataFrame(label="Search Results")

    # Wire the button click to the search callback.
    search_button.click(fn=gradio_interface, inputs=search_input, outputs=search_output)

# Start the web server.
app.launch()
|