Spaces:
Sleeping
Sleeping
bibliotecadebabel
committed on
Commit
•
b1179cf
1
Parent(s):
37c2a8d
mxbai endpoint
Browse files- .gitignore +31 -0
- app.py +29 -22
- requirements.txt +2 -1
- src/constants/config.py +0 -46
- src/constants/credentials.py +2 -1
- src/utils_search.py +18 -10
.gitignore
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
lib
|
2 |
+
dist
|
3 |
+
allWatches.json
|
4 |
+
vectors.json
|
5 |
+
node_modules/
|
6 |
+
/test-results/
|
7 |
+
/playwright-report/
|
8 |
+
/blob-report/
|
9 |
+
/playwright/.cache/
|
10 |
+
|
11 |
+
|
12 |
+
src/all_shopping_scrape_results/*
|
13 |
+
src/all_product_scrape_results/*
|
14 |
+
.envrc
|
15 |
+
tsconfig.tsbuildinfo
|
16 |
+
.tshy
|
17 |
+
.tshy-build
|
18 |
+
response-cache
|
19 |
+
venv
|
20 |
+
*.html
|
21 |
+
response2.json
|
22 |
+
code.py
|
23 |
+
mic-scrape.json
|
24 |
+
*.json
|
25 |
+
*.parquet
|
26 |
+
__pycache__
|
27 |
+
.DS_Store
|
28 |
+
.passwd-s3fs
|
29 |
+
.idea/*
|
30 |
+
myenv/
|
31 |
+
env/
|
app.py
CHANGED
@@ -2,36 +2,32 @@ import torch
|
|
2 |
import src.constants.config as configurations
|
3 |
from sentence_transformers import SentenceTransformer
|
4 |
from sentence_transformers import CrossEncoder
|
5 |
-
from src.constants.credentials import cohere_trial_key
|
6 |
import streamlit as st
|
7 |
from src.reader import Reader
|
8 |
from src.utils_search import UtilsSearch
|
9 |
from copy import deepcopy
|
10 |
import numpy as np
|
11 |
import cohere
|
|
|
|
|
12 |
|
13 |
|
14 |
configurations = configurations.service_mxbai_msc_direct_config
|
15 |
api_key = cohere_trial_key
|
16 |
co = cohere.Client(api_key)
|
17 |
semantic_column_names = configurations["semantic_column_names"]
|
18 |
-
|
19 |
-
|
20 |
-
if torch.cuda.is_available():
|
21 |
-
torch.cuda.set_device(0) # Use the first GPU
|
22 |
-
else:
|
23 |
-
st.write("CUDA is not available. Using CPU instead.")
|
24 |
|
25 |
@st.cache_data
|
26 |
def init():
|
27 |
config = configurations
|
28 |
search_utils = UtilsSearch(config)
|
29 |
reader = Reader(config=config["reader_config"])
|
30 |
-
model = SentenceTransformer(config['sentence_transformer_name'], device='cuda:0')
|
31 |
-
cross_encoder = CrossEncoder(config['cross_encoder_name'], device='cuda:0')
|
32 |
df = reader.read()
|
33 |
index = search_utils.dataframe_to_index(df)
|
34 |
-
return df,
|
35 |
|
36 |
def get_possible_values_for_column(column_name, search_utils, df):
|
37 |
if column_name not in st.session_state:
|
@@ -44,14 +40,15 @@ if 'init_results' not in st.session_state:
|
|
44 |
st.session_state.init_results = init()
|
45 |
|
46 |
# Now you can access your initialized objects directly from the session state
|
47 |
-
df,
|
48 |
|
49 |
# Streamlit app layout
|
50 |
st.title('Search Demo')
|
51 |
|
52 |
# Input fields
|
53 |
query = st.text_input('Enter your search query here')
|
54 |
-
use_cohere = st.checkbox('Use Cohere', value=False) # Default to checked
|
|
|
55 |
|
56 |
programmatic_search_config = deepcopy(configurations['programmatic_search_config'])
|
57 |
|
@@ -87,21 +84,31 @@ programmatic_search_config['discrete_columns'] = dynamic_programmatic_search_con
|
|
87 |
# Search button
|
88 |
if st.button('Search'):
|
89 |
if query: # Checking if a query was entered
|
90 |
-
|
|
|
|
|
|
|
|
|
91 |
if len(df_filtered) == 0:
|
92 |
st.write('No results found')
|
93 |
else:
|
94 |
-
index = search_utils.dataframe_to_index(df_filtered)
|
95 |
if use_cohere == False:
|
96 |
-
|
97 |
-
|
98 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
|
100 |
else:
|
101 |
-
|
102 |
-
|
103 |
-
df_retrieved.fillna(value="", inplace=True)
|
104 |
-
docs = df_retrieved.to_dict('records')
|
105 |
column_names = semantic_column_names
|
106 |
docs = [{name: str(doc[name]) for name in column_names} for doc in docs]
|
107 |
rank_fields = list(docs[0].keys())
|
@@ -109,7 +116,7 @@ if st.button('Search'):
|
|
109 |
rank_fields=rank_fields)
|
110 |
top_ids = [hit.index for hit in results.results]
|
111 |
# Create the DataFrame with the rerank results
|
112 |
-
results_df =
|
113 |
results_df['rank'] = (np.arange(len(results_df)) + 1)
|
114 |
|
115 |
st.write(results_df)
|
|
|
2 |
import src.constants.config as configurations
|
3 |
from sentence_transformers import SentenceTransformer
|
4 |
from sentence_transformers import CrossEncoder
|
5 |
+
from src.constants.credentials import cohere_trial_key, mixedbread_key
|
6 |
import streamlit as st
|
7 |
from src.reader import Reader
|
8 |
from src.utils_search import UtilsSearch
|
9 |
from copy import deepcopy
|
10 |
import numpy as np
|
11 |
import cohere
|
12 |
+
from mixedbread_ai.client import MixedbreadAI
|
13 |
+
from src.pytorch_modules.datasets.schema_string_dataset import SchemaStringDataset
|
14 |
|
15 |
|
16 |
configurations = configurations.service_mxbai_msc_direct_config
|
17 |
api_key = cohere_trial_key
|
18 |
co = cohere.Client(api_key)
|
19 |
semantic_column_names = configurations["semantic_column_names"]
|
20 |
+
model = MixedbreadAI(api_key=mixedbread_key)
|
21 |
+
cross_encoder_name = configurations["cross_encoder_name"]
|
|
|
|
|
|
|
|
|
22 |
|
23 |
@st.cache_data
|
24 |
def init():
|
25 |
config = configurations
|
26 |
search_utils = UtilsSearch(config)
|
27 |
reader = Reader(config=config["reader_config"])
|
|
|
|
|
28 |
df = reader.read()
|
29 |
index = search_utils.dataframe_to_index(df)
|
30 |
+
return df, index, search_utils
|
31 |
|
32 |
def get_possible_values_for_column(column_name, search_utils, df):
|
33 |
if column_name not in st.session_state:
|
|
|
40 |
st.session_state.init_results = init()
|
41 |
|
42 |
# Now you can access your initialized objects directly from the session state
|
43 |
+
df, index, search_utils = st.session_state.init_results
|
44 |
|
45 |
# Streamlit app layout
|
46 |
st.title('Search Demo')
|
47 |
|
48 |
# Input fields
|
49 |
query = st.text_input('Enter your search query here')
|
50 |
+
# use_cohere = st.checkbox('Use Cohere', value=False) # Default to checked
|
51 |
+
use_cohere = False
|
52 |
|
53 |
programmatic_search_config = deepcopy(configurations['programmatic_search_config'])
|
54 |
|
|
|
84 |
# Search button
|
85 |
if st.button('Search'):
|
86 |
if query: # Checking if a query was entered
|
87 |
+
df_retrieved = search_utils.retrieve(query, df, model, index, top_k=1000, api=True)
|
88 |
+
df_filtered = search_utils.filter_dataframe(df_retrieved, programmatic_search_config)
|
89 |
+
df_filtered = df_filtered.sort_values(by='similarities', ascending=True)
|
90 |
+
df_filtered = df_filtered[:100].reset_index(drop=True)
|
91 |
+
|
92 |
if len(df_filtered) == 0:
|
93 |
st.write('No results found')
|
94 |
else:
|
|
|
95 |
if use_cohere == False:
|
96 |
+
records = df_filtered.to_dict(orient='records')
|
97 |
+
dataset_str = SchemaStringDataset(records, configurations)
|
98 |
+
documents = [batch["inputs"][:256] for batch in dataset_str]
|
99 |
+
res = model.reranking(
|
100 |
+
model=cross_encoder_name,
|
101 |
+
query=query,
|
102 |
+
input=documents,
|
103 |
+
top_k=10,
|
104 |
+
return_input=False
|
105 |
+
)
|
106 |
+
ids = [item.index for item in res.data]
|
107 |
+
results_df = df_filtered.loc[ids]
|
108 |
|
109 |
else:
|
110 |
+
df_filtered.fillna(value="", inplace=True)
|
111 |
+
docs = df_filtered.to_dict('records')
|
|
|
|
|
112 |
column_names = semantic_column_names
|
113 |
docs = [{name: str(doc[name]) for name in column_names} for doc in docs]
|
114 |
rank_fields = list(docs[0].keys())
|
|
|
116 |
rank_fields=rank_fields)
|
117 |
top_ids = [hit.index for hit in results.results]
|
118 |
# Create the DataFrame with the rerank results
|
119 |
+
results_df = df_filtered.iloc[top_ids].copy()
|
120 |
results_df['rank'] = (np.arange(len(results_df)) + 1)
|
121 |
|
122 |
st.write(results_df)
|
requirements.txt
CHANGED
@@ -8,4 +8,5 @@ s3fs
|
|
8 |
numpy
|
9 |
faiss-gpu
|
10 |
sentence_transformers
|
11 |
-
cohere
|
|
|
|
8 |
numpy
|
9 |
faiss-gpu
|
10 |
sentence_transformers
|
11 |
+
cohere
|
12 |
+
mixedbread_ai
|
src/constants/config.py
CHANGED
@@ -1,52 +1,6 @@
|
|
1 |
import src.constants.credentials as cred
|
2 |
import os
|
3 |
|
4 |
-
service_mxbai_made_in_china_config = {"reader_config": {"input_path": os.environ['made_in_china_s3_path'],
|
5 |
-
"credentials": cred.credentials_backblaze,
|
6 |
-
"format":"parquet"
|
7 |
-
},
|
8 |
-
"sample_size": 32,
|
9 |
-
"sentence_transformer_name": "mixedbread-ai/mxbai-embed-large-v1",
|
10 |
-
"cross_encoder_name": "mixedbread-ai/mxbai-rerank-large-v1",
|
11 |
-
"batch_size": 4,
|
12 |
-
"dataset_size": 32,
|
13 |
-
"seq_len": 256,
|
14 |
-
"top_k": 1000,
|
15 |
-
"programmatic_search_config": {
|
16 |
-
"scalar_columns": [{"column_name": "price", "min_value": 0, "max_value": "10000"}],
|
17 |
-
"discrete_columns": [{"column_name": "supplierName",
|
18 |
-
# "default_values": ['Zhongshan Norye Hardware Co., Ltd.']
|
19 |
-
"default_values": []
|
20 |
-
},
|
21 |
-
{"column_name": "warranty",
|
22 |
-
# "default_values": ['Zhongshan Norye Hardware Co., Ltd.']
|
23 |
-
"default_values": []
|
24 |
-
}
|
25 |
-
],
|
26 |
-
"columns_to_drop": ["similarities", "embeddings"]
|
27 |
-
}
|
28 |
-
}
|
29 |
-
|
30 |
-
|
31 |
-
service_mxbai_msc_direct_sample_config = {"reader_config": {"input_path": os.environ['msc_direct_s3_path'],
|
32 |
-
"credentials": cred.credentials_backblaze,
|
33 |
-
"format":"parquet"
|
34 |
-
},
|
35 |
-
"sample_size": 32,
|
36 |
-
"sentence_transformer_name": "mixedbread-ai/mxbai-embed-large-v1",
|
37 |
-
"cross_encoder_name": "mixedbread-ai/mxbai-rerank-large-v1",
|
38 |
-
"batch_size": 4,
|
39 |
-
"dataset_size": 32,
|
40 |
-
"seq_len": 256,
|
41 |
-
"top_k": 50,
|
42 |
-
"semantic_column_names": ['name', 'price', 'brand', 'keyword', 'description',
|
43 |
-
'specifications'],
|
44 |
-
"programmatic_search_config": {
|
45 |
-
"scalar_columns": [{"column_name": "price", "min_value": 0, "max_value": "10000"}],
|
46 |
-
"discrete_columns": [{"column_name": "brand", "default_values": []}],
|
47 |
-
"columns_to_drop": ["similarities", "embeddings", "index"]
|
48 |
-
}
|
49 |
-
}
|
50 |
|
51 |
service_mxbai_msc_direct_config = {"reader_config": {"input_path": os.environ['msc_direct_s3_path'],
|
52 |
"credentials": cred.credentials_backblaze,
|
|
|
1 |
import src.constants.credentials as cred
|
2 |
import os
|
3 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
service_mxbai_msc_direct_config = {"reader_config": {"input_path": os.environ['msc_direct_s3_path'],
|
6 |
"credentials": cred.credentials_backblaze,
|
src/constants/credentials.py
CHANGED
@@ -8,4 +8,5 @@ credentials_backblaze = {"access_key_id": os.environ['credentials_backblaze_acce
|
|
8 |
}
|
9 |
|
10 |
|
11 |
-
cohere_trial_key = os.environ["cohere_trial_key"]
|
|
|
|
8 |
}
|
9 |
|
10 |
|
11 |
+
cohere_trial_key = os.environ["cohere_trial_key"]
|
12 |
+
mixedbread_key = os.environ["mixedbread_key"]
|
src/utils_search.py
CHANGED
@@ -21,8 +21,8 @@ class UtilsSearch:
|
|
21 |
index.add(norm_embeddings)
|
22 |
return index # Ad
|
23 |
|
24 |
-
|
25 |
-
def retrieve(query, df, model, index, top_k=100):
|
26 |
query += "Represent this sentence for searching relevant passages: "
|
27 |
"""
|
28 |
Search the DataFrame for the given query and return a sorted DataFrame based on similarity.
|
@@ -35,14 +35,24 @@ class UtilsSearch:
|
|
35 |
:return: A new DataFrame sorted by similarity to the query, with a 'similarities' column.
|
36 |
"""
|
37 |
# Check if CUDA is available and set the device accordingly
|
38 |
-
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
-
# Compute the query embedding
|
42 |
-
query_vector = model.encode(query, convert_to_tensor=True, device=device).cpu().numpy()
|
43 |
|
44 |
# Normalize the query vector
|
45 |
-
|
46 |
|
47 |
# Perform the search
|
48 |
distances, indices = index.search(np.array([query_vector]), top_k)
|
@@ -55,7 +65,7 @@ class UtilsSearch:
|
|
55 |
retrieved_df = retrieved_df.assign(similarities=distances[0])
|
56 |
|
57 |
if 'similarities' in retrieved_df.columns:
|
58 |
-
retrieved_df = retrieved_df.sort_values(by='similarities', ascending=
|
59 |
|
60 |
# Optionally, you might want to reset the index if the order matters or if you need to serialize the DataFrame without index issues
|
61 |
retrieved_df = retrieved_df.reset_index(drop=True)
|
@@ -149,5 +159,3 @@ class UtilsSearch:
|
|
149 |
columns_to_drop = config.get('columns_to_drop', [])
|
150 |
df_dropped = df.drop(columns_to_drop, axis=1)
|
151 |
return df_dropped
|
152 |
-
|
153 |
-
|
|
|
21 |
index.add(norm_embeddings)
|
22 |
return index # Ad
|
23 |
|
24 |
+
|
25 |
+
def retrieve(self, query, df, model, index, top_k=100, api=False):
|
26 |
query += "Represent this sentence for searching relevant passages: "
|
27 |
"""
|
28 |
Search the DataFrame for the given query and return a sorted DataFrame based on similarity.
|
|
|
35 |
:return: A new DataFrame sorted by similarity to the query, with a 'similarities' column.
|
36 |
"""
|
37 |
# Check if CUDA is available and set the device accordingly
|
38 |
+
if not api:
|
39 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
40 |
+
model.to(device)
|
41 |
+
|
42 |
+
# Compute the query embedding
|
43 |
+
query_vector = model.encode(query, convert_to_tensor=True, device=device).cpu().numpy()
|
44 |
+
query_vector /= np.linalg.norm(query_vector)
|
45 |
+
else:
|
46 |
+
res = model.embeddings(
|
47 |
+
input=[query],
|
48 |
+
model=self.config["sentence_transformer_name"],
|
49 |
+
prompt=None,
|
50 |
+
)
|
51 |
+
query_vector = np.array([entry.embedding for entry in res.data][0]).astype(np.float32)
|
52 |
|
|
|
|
|
53 |
|
54 |
# Normalize the query vector
|
55 |
+
|
56 |
|
57 |
# Perform the search
|
58 |
distances, indices = index.search(np.array([query_vector]), top_k)
|
|
|
65 |
retrieved_df = retrieved_df.assign(similarities=distances[0])
|
66 |
|
67 |
if 'similarities' in retrieved_df.columns:
|
68 |
+
retrieved_df = retrieved_df.sort_values(by='similarities', ascending=True)
|
69 |
|
70 |
# Optionally, you might want to reset the index if the order matters or if you need to serialize the DataFrame without index issues
|
71 |
retrieved_df = retrieved_df.reset_index(drop=True)
|
|
|
159 |
columns_to_drop = config.get('columns_to_drop', [])
|
160 |
df_dropped = df.drop(columns_to_drop, axis=1)
|
161 |
return df_dropped
|
|
|
|