File size: 5,248 Bytes
37c2a8d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
import src.constants.config as configurations
from sentence_transformers import SentenceTransformer
from sentence_transformers import CrossEncoder
from src.constants.credentials import cohere_trial_key
import streamlit as st
from src.reader import Reader
from src.utils_search import UtilsSearch
from copy import deepcopy
import numpy as np
import cohere


# Pick the concrete service configuration. This deliberately rebinds the name
# `configurations` from the imported config module to the selected config
# mapping; `init()` and the search-config section below read that mapping.
configurations = configurations.service_mxbai_msc_direct_config

# Cohere client, used only when the "Use Cohere" reranking path is selected.
api_key = cohere_trial_key
co = cohere.Client(api_key)

# Columns whose (stringified) values are sent to Cohere's rerank call.
semantic_column_names = configurations["semantic_column_names"]

# Pin the first GPU when CUDA is present; otherwise tell the user.
if not torch.cuda.is_available():
    st.write("CUDA is not available. Using CPU instead.")
else:
    torch.cuda.set_device(0)  # first GPU

@st.cache_resource
def init():
    """Load the data, models and search index used by every search.

    ``st.cache_resource`` (rather than ``st.cache_data``) is Streamlit's cache
    for models and other global resources: it hands back the same objects
    instead of pickling and copying them on every cache hit. As a result, the
    returned objects are shared by all sessions, so callers must not mutate
    ``df`` in place.

    Returns:
        Tuple ``(df, model, cross_encoder, index, search_utils)``:

        - the dataframe returned by ``Reader.read()``;
        - the ``SentenceTransformer`` model;
        - the ``CrossEncoder`` model;
        - the index built from ``df`` by ``UtilsSearch.dataframe_to_index``;
        - the ``UtilsSearch`` helper itself.
    """
    config = configurations
    # Previously hard-coded to 'cuda:0', which fails on CPU-only machines even
    # though the module-level check announces a CPU fallback.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    search_utils = UtilsSearch(config)
    reader = Reader(config=config["reader_config"])
    model = SentenceTransformer(config['sentence_transformer_name'], device=device)
    cross_encoder = CrossEncoder(config['cross_encoder_name'], device=device)

    df = reader.read()
    index = search_utils.dataframe_to_index(df)
    return df, model, cross_encoder, index, search_utils

def get_possible_values_for_column(column_name, search_utils, df):
    """Return the candidate values offered for a discrete filter column.

    The result of ``search_utils.top_10_common_values(df, column_name)`` is
    memoised in ``st.session_state`` under the key ``column_name``. It is
    therefore computed at most once per Streamlit session.
    """
    state = st.session_state
    if column_name in state:
        return state[column_name]
    state[column_name] = search_utils.top_10_common_values(df, column_name)
    return state[column_name]


# Build the heavy objects on the first run of a session and keep them in
# session state, so every later rerun of this script reuses them.
if "init_results" not in st.session_state:
    st.session_state["init_results"] = init()

(df, model, cross_encoder, index, search_utils) = st.session_state["init_results"]

# Page layout.
st.title("Search Demo")

# Input fields
query = st.text_input('Enter your search query here')
# Unchecked by default: the local search_utils.search path (which also receives the CrossEncoder).
use_cohere = st.checkbox('Use Cohere', value=False)

# Work on a copy so user input never mutates the shared configuration.
programmatic_search_config = deepcopy(configurations['programmatic_search_config'])

dynamic_programmatic_search_config = {
    "scalar_columns": [],
    "discrete_columns": []
}


for column in programmatic_search_config['scalar_columns']:
    # One pair of number inputs per scalar column, bounded by the configured range.
    col_name = column["column_name"]
    min_val = float(column["min_value"])
    max_val = float(column["max_value"])
    user_min = st.number_input(f'Minimum {col_name.capitalize()}', min_value=min_val, max_value=max_val, value=min_val)
    user_max = st.number_input(f'Maximum {col_name.capitalize()}', min_value=min_val, max_value=max_val, value=max_val)
    if user_min > user_max:
        # The two widgets are independent, so nothing else prevents an inverted range.
        st.warning(
            f'Minimum {col_name.capitalize()} ({user_min}) is greater than '
            f'Maximum {col_name.capitalize()} ({user_max}).'
        )
    dynamic_programmatic_search_config['scalar_columns'].append({"column_name": col_name, "min_value": user_min, "max_value": user_max})

for column in programmatic_search_config['discrete_columns']:
    # Multiselect over the column's most common values (memoised per session).
    col_name = column["column_name"]
    default_values = column["default_values"] or []
    # Copy into a fresh list so the cached session-state value is never mutated.
    # Assumes top_10_common_values returns an iterable of values. TODO confirm.
    options = list(get_possible_values_for_column(col_name, search_utils, df))
    # st.multiselect raises if any default is missing from `options`, and the
    # configured defaults are not guaranteed to be among the common values.
    options += [value for value in default_values if value not in options]
    selected_values = st.multiselect(f'Select {col_name.capitalize()}', options=options, default=default_values)
    dynamic_programmatic_search_config['discrete_columns'].append({"column_name": col_name, "default_values": selected_values})


programmatic_search_config['scalar_columns'] = dynamic_programmatic_search_config['scalar_columns']
programmatic_search_config['discrete_columns'] = dynamic_programmatic_search_config['discrete_columns']


# Search button: filter -> index -> (local search | retrieve + Cohere rerank) -> display.
if st.button('Search'):
    if not query:
        st.write("Please enter a query to search.")
    else:
        df_filtered = search_utils.filter_dataframe(df, programmatic_search_config)
        results_df = None  # stays None when nothing survives filtering/retrieval

        if not df_filtered.empty:
            # Index only the filtered rows. A separate name keeps the
            # session-wide `index` from init() untouched.
            filtered_index = search_utils.dataframe_to_index(df_filtered)

            if not use_cohere:
                # Local path: search_utils.search gets both the SentenceTransformer and the CrossEncoder.
                results_df = search_utils.search(query, df_filtered, model, cross_encoder, filtered_index)
                results_df = search_utils.drop_columns(results_df, programmatic_search_config)
            else:
                # Cohere path: retrieve candidates locally, then rerank them via the Cohere API.
                df_retrieved = search_utils.retrieve(query, df_filtered, model, filtered_index)
                df_retrieved = search_utils.drop_columns(df_retrieved, programmatic_search_config)

                # Previously an empty retrieval crashed with IndexError on docs[0].
                if not df_retrieved.empty:
                    # Non-inplace to avoid writing into a frame that may be derived from shared data.
                    df_retrieved = df_retrieved.fillna(value="")
                    # Same (order-preserving, de-duplicated) field list the
                    # original derived from the keys of the first document.
                    rank_fields = list(dict.fromkeys(semantic_column_names))
                    docs = [
                        {name: str(record[name]) for name in rank_fields}
                        for record in df_retrieved.to_dict('records')
                    ]
                    results = co.rerank(query=query, documents=docs, top_n=10, model='rerank-english-v3.0',
                                        rank_fields=rank_fields)
                    # hit.index is a position in `docs`, which follows df_retrieved's row order.
                    top_ids = [hit.index for hit in results.results]
                    results_df = df_retrieved.iloc[top_ids].copy()
                    results_df['rank'] = np.arange(len(results_df)) + 1

        if results_df is None:
            st.write('No results found')
        else:
            st.write(results_df)