import gradio as gr
import pandas as pd
import faiss
import numpy as np
import os
from FlagEmbedding import BGEM3FlagModel

# Load the pre-trained embedding model
model = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)
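# use_fp16=True runs inference in half precision: faster (especially on GPU), with a small accuracy trade-off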

# Load the JSON data into a DataFrame
df = pd.read_json('White-Stride-Red-68.json')
df['embeding_context'] = df['embeding_context'].fillna('').astype(str)

# Filter out any rows where 'embeding_context' might be empty or invalid
df = df[df['embeding_context'] != '']

# # Encode the 'embeding_context' column
# embedding_contexts = df['embeding_context'].tolist()
# embeddings_csv = model.encode(embedding_contexts, batch_size=12, max_length=1024)['dense_vecs']

# # Convert embeddings to numpy array
# embeddings_np = np.array(embeddings_csv).astype('float32')

# # FAISS index file path
# index_file_path = 'vector_store_bge_m3.index'

# # Check if FAISS index file already exists
# if os.path.exists(index_file_path):
#     # Load the existing FAISS index from file
#     index = faiss.read_index(index_file_path)
#     print("FAISS index loaded from file.")
# else:
#     # Initialize FAISS index (exact L2 distance search)
#     dim = embeddings_np.shape[1]
#     index = faiss.IndexFlatL2(dim)

#     # Add embeddings to the FAISS index
#     index.add(embeddings_np)

#     # Save the FAISS index to a file for future use
#     faiss.write_index(index, index_file_path)
#     print("FAISS index created and saved to file.")

index = faiss.read_index('vector_store_bge_m3.index')
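# Note: this assumes the saved index was built from this same DataFrame in the same row order
# (e.g. via the commented-out code above), so FAISS vector IDs line up with df row positions.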


# Function to perform search and return all columns
def search_query(query_text):
    num_records = 50

    # Encode the input query text
    embeddings_query = model.encode([query_text], batch_size=12, max_length=1024)['dense_vecs']
    embeddings_query_np = np.array(embeddings_query).astype('float32')

    # Search in FAISS index for nearest neighbors
    distances, indices = index.search(embeddings_query_np, num_records)
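    # 'distances' are squared L2 distances (unused here); 'indices' are positions into the indexed vectors,
    # which map back to df rows under the ordering assumption noted above.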

    # Get the top results based on FAISS indices
    result_df = df.iloc[indices[0]].drop(columns=['embeding_context']).drop_duplicates().reset_index(drop=True)
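    # drop_duplicates() may leave fewer than num_records rows if nearest neighbors map to duplicate records.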

    return result_df

# Gradio interface function
def gradio_interface(query_text):
    search_results = search_query(query_text)
    return search_results

with gr.Blocks() as app:
    gr.Markdown("<h1>White Stride Red Search (BEG-M3)</h1>")

    # Input text box for the search query
    search_input = gr.Textbox(label="Search Query", placeholder="Enter search text", interactive=True)

    # Search button below the text box
    search_button = gr.Button("Search")

    # Output table for displaying results
    search_output = gr.DataFrame(label="Search Results")

    # Link button click to action
    search_button.click(fn=gradio_interface, inputs=search_input, outputs=search_output)



# Launch the Gradio app
app.launch()