|
import numpy as np |
|
import pandas as pd |
|
import time |
|
from sentence_transformers import SentenceTransformer |
|
from redis.commands.search.field import VectorField |
|
from redis.commands.search.field import TextField |
|
from redis.commands.search.field import TagField |
|
from redis.commands.search.query import Query |
|
import redis |
|
from tqdm import tqdm |
|
import google.generativeai as palm |
|
import pandas as pd |
|
from langchain.chains import LLMChain |
|
|
|
|
|
from langchain.prompts import PromptTemplate |
|
|
|
import os |
|
|
|
import gradio as gr |
|
import io |
|
|
|
from langchain.llms import GooglePalm |
|
import pandas as pd |
|
|
|
|
|
from langchain.embeddings import GooglePalmEmbeddings |
|
from langchain.memory import ConversationBufferMemory |
|
from dotenv import load_dotenv |
|
|
|
# Load variables from a local .env file into os.environ (PALM API key, Redis
# credentials) before anything below reads them.
load_dotenv()

# Redis Cloud client used for the vector search.  Credentials are read from the
# environment instead of being hard-coded: the password that used to be written
# here was committed to source and should be considered leaked (rotate it).
# The host/port defaults preserve the original endpoint; override them with
# REDIS_HOST / REDIS_PORT.  REDIS_PASSWORD must now be provided via env/.env.
redis_conn = redis.Redis(
    host=os.getenv('REDIS_HOST', 'redis-15860.c322.us-east-1-2.ec2.cloud.redislabs.com'),
    port=int(os.getenv('REDIS_PORT', '15860')),
    password=os.getenv('REDIS_PASSWORD'),
)
|
|
|
# One-off ingestion pipeline: embed each company's ``item_keywords`` and load the
# rows into a RediSearch FLAT vector index.  This code used to sit inside a
# disabled triple-quoted string; it is now wrapped in functions so it is
# syntax-checked and can be run on demand with ``build_index()``, while still
# doing nothing at import time.


def load_vectors(client, company_metadata, vector_dict, vector_field_name):
    """Write one Redis hash per company row, including its embedding bytes.

    Args:
        client: Redis client; only ``pipeline(transaction=False)`` is used.
        company_metadata: mapping of row index -> row dict (as produced by
            ``DataFrame.to_dict(orient='index')``); each row needs a string
            ``primary_key`` value to build its hash key.
        vector_dict: embeddings indexable by the same row index
            (a dict keyed by index, or a list when the index is 0..n-1).
        vector_field_name: hash field that receives the float32 vector bytes.

    Rows whose key cannot be built (missing or non-string ``primary_key``) are
    skipped with a message instead of aborting the whole load.  The input
    row dicts are not modified.
    """
    pipe = client.pipeline(transaction=False)
    for index, row in company_metadata.items():
        try:
            key = 'company:' + str(index) + ':' + row['primary_key']
        except (KeyError, TypeError) as err:
            # The old bare ``except`` printed ``key`` here, which is unbound on
            # the first row and stale afterwards; report the offending row.
            print(f'skipping row {index}: cannot build key ({err!r})')
            continue

        # Copy the row so the caller's metadata is not mutated with vector bytes.
        mapping = dict(row)
        mapping[vector_field_name] = np.asarray(vector_dict[index], dtype=np.float32).tobytes()
        pipe.hset(key, mapping=mapping)

    pipe.execute()


def create_flat_index(redis_conn, vector_field_name, number_of_vectors, vector_dimensions=768, distance_metric='L2'):
    """Create the default RediSearch index (``ft()``) with a FLAT vector field.

    Args:
        redis_conn: Redis client to create the index on.
        vector_field_name: hash field holding the float32 embedding bytes.
        number_of_vectors: used for both INITIAL_CAP and BLOCK_SIZE.
        vector_dimensions: embedding width.  Defaults to 768 to match
            TEXT_EMBEDDING_DIMENSION used with this model (the old default of
            512 did not match).
        distance_metric: RediSearch metric name, e.g. 'L2' or 'COSINE'.
    """
    redis_conn.ft().create_index([
        VectorField(vector_field_name, "FLAT", {
            "TYPE": "FLOAT32",
            "DIM": vector_dimensions,
            "DISTANCE_METRIC": distance_metric,
            "INITIAL_CAP": number_of_vectors,
            "BLOCK_SIZE": number_of_vectors,
        }),
        TagField("company_l_id"),
        TextField("company_name"),
        TextField("item_keywords"),
        TagField("industry"),
    ])


def build_index(csv_path='coms3.csv', client=None, flush=True):
    """Embed the CSV's ``item_keywords`` and (re)build the Redis vector index.

    Args:
        csv_path: CSV with at least ``item_keywords`` and ``primary_key`` columns.
        client: Redis client; defaults to the module-level ``redis_conn``.
        flush: when true (the original behaviour) run FLUSHALL first.
            WARNING: FLUSHALL clears the whole Redis server, not just this
            index.  With ``flush=False`` an already existing index makes
            ``create_index`` fail.

    Uses the module-level ITEM_KEYWORD_EMBEDDING_FIELD and
    TEXT_EMBEDDING_DIMENSION constants, resolved when this function is called.
    """
    client = redis_conn if client is None else client

    df = pd.read_csv(csv_path)
    company_metadata = df.to_dict(orient='index')

    encoder = SentenceTransformer('sentence-transformers/all-distilroberta-v1')
    # Key vectors by the same row index load_vectors looks them up with, so the
    # alignment does not depend on the DataFrame having a 0..n-1 index.
    item_keywords_vectors = {
        index: encoder.encode(row['item_keywords'])
        for index, row in tqdm(company_metadata.items())
    }

    print(f'Loading and Indexing {len(company_metadata)} companies')

    if flush:
        client.flushall()

    # Size the index from the data instead of a hard-coded 1000 (at least 1).
    create_flat_index(
        client,
        ITEM_KEYWORD_EMBEDDING_FIELD,
        max(len(company_metadata), 1),
        TEXT_EMBEDDING_DIMENSION,
        'COSINE',
    )
    load_vectors(client, company_metadata, item_keywords_vectors, ITEM_KEYWORD_EMBEDDING_FIELD)
|
# Vector-index settings: the hash field that stores each company's keyword
# embedding, and the expected embedding width used when the index is built.
ITEM_KEYWORD_EMBEDDING_FIELD = 'item_keyword_vector'
TEXT_EMBEDDING_DIMENSION = 768
NUMBER_PRODUCTS = 1000  # not referenced anywhere in the visible code

# Sentence encoder applied to the LLM-generated search keywords at query time.
model = SentenceTransformer('sentence-transformers/all-distilroberta-v1')
|
|
|
# Prompt that turns the user's free-text question into comma-separated search
# keywords.  It previously had no placeholder, so the question never reached
# the LLM. It was also bound to ``prompt``, where the chat prompt below
# immediately overwrote it.  It now has its own name and a
# ``{company_description}`` slot.
keyword_prompt = PromptTemplate(
    input_variables=["company_description"],
    template='Create comma seperated company keywords to perform a query on a company dataset for this user input: {company_description}',
)

# Chat prompt used to present the retrieved companies as the final answer.
template = """You are a chatbot. Be kind, detailed and nice. Present the given queried search result in a nice way as answer to the user input. dont ask questions back! just take the given context

{chat_history}
Human: {user_question}
Chatbot:
"""

prompt = PromptTemplate(
    input_variables=["chat_history", "user_question"],
    template=template,
)

# Module-level placeholder kept for compatibility; answer() no longer reads it.
chat_history = ""


def answer(user_question):
    """Answer *user_question* using the best-matching companies stored in Redis.

    Steps:
      1. Ask the PaLM LLM to turn the question into search keywords.
      2. Embed the keywords and run a top-3 KNN vector query on the Redis index.
      3. Ask the LLM to present the matches as a chatbot reply.

    Args:
        user_question: free-text question from the Gradio textbox.

    Returns:
        The chatbot reply produced by the LLM.

    Raises:
        KeyError: if the ``PALM`` environment variable is not set.
    """
    llm = GooglePalm(temperature=0, google_api_key=os.environ['PALM'])

    # 1. Keyword extraction uses the dedicated keyword prompt.  The old code
    #    ran the chat prompt here, so it embedded a chat reply, not keywords.
    keyword_chain = LLMChain(llm=llm, prompt=keyword_prompt)
    keywords = keyword_chain.run(company_description=user_question)

    # 2. KNN search over the stored keyword embeddings.
    top_k = 3
    query_vector = model.encode(keywords).astype(np.float32).tobytes()
    q = (
        Query(f'*=>[KNN {top_k} @{ITEM_KEYWORD_EMBEDDING_FIELD} $vec_param AS vector_score]')
        .sort_by('vector_score')
        .paging(0, top_k)
        # Return the fields that are read below.  The old list requested
        # item_name/item_id (not in the index schema) and omitted
        # company_name/company_l_id, so every hit raised AttributeError.
        .return_fields('vector_score', 'company_name', 'company_l_id', 'item_keywords')
        .dialect(2)
    )
    results = redis_conn.ft().search(q, query_params={"vec_param": query_vector})

    # getattr(..., '') tolerates hashes that were stored without one of the fields.
    full_result_string = ''.join(
        f"{getattr(doc, 'company_name', '')} "
        f"{getattr(doc, 'item_keywords', '')} "
        f"{getattr(doc, 'company_l_id', '')}\n\n\n"
        for doc in results.docs
    )

    # 3. Presentation.  The memory supplies the prompt's ``chat_history``.
    # NOTE(review): memory is recreated on every call, so no history carries
    # over between questions. Confirm whether that is intended.
    memory = ConversationBufferMemory(memory_key="chat_history")
    llm_chain = LLMChain(
        llm=llm,
        prompt=prompt,
        verbose=False,
        memory=memory,
    )

    # The prompt's input variable is ``user_question``.  The old ``user_msg``
    # keyword left it unset, so LLMChain rejected the call.
    return llm_chain.predict(user_question=f"{full_result_string} ---\n\n {user_question}")
|
|
|
# Gradio front-end: one text box in, the chatbot's text reply out, served by answer().
demo = gr.Interface(fn=answer, inputs=["text"], outputs=["text"], title="Ask Sonity")

# share=True requests a public share link in addition to the local server (see Gradio docs).
demo.launch(share=True)
|
|