File size: 5,053 Bytes
dc2e0ab 93d8bb2 dc2e0ab 93d8bb2 dc2e0ab 9ce513b bcce814 dc2e0ab 8a8b067 93d8bb2 dc2e0ab bcce814 dc2e0ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
import numpy as np
import pandas as pd
import time
from sentence_transformers import SentenceTransformer
from redis.commands.search.field import VectorField
from redis.commands.search.field import TextField
from redis.commands.search.field import TagField
from redis.commands.search.query import Query
import redis
from tqdm import tqdm
import google.generativeai as palm
import pandas as pd
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
import os
import gradio as gr
import io
from langchain.llms import GooglePalm
import pandas as pd
#from yolopandas import pd
from langchain.embeddings import GooglePalmEmbeddings
from langchain.memory import ConversationBufferMemory
from dotenv import load_dotenv
load_dotenv()
# Redis Cloud connection used for the vector search in answer().
# Connection details are read from the environment (load_dotenv() above merges
# in a local .env file, the same mechanism used for the PALM API key) so that
# the password is no longer committed to source control.
#   REDIS_HOST     - defaults to the previously hard-coded Redis Cloud host
#   REDIS_PORT     - defaults to the previously hard-coded port
#   REDIS_PASSWORD - required by the cloud instance; if unset no AUTH is sent
#                    and the first command will be rejected by the server.
# NOTE(review): the password that used to be hard-coded here has been published
# with the source and should be rotated.
redis_conn = redis.Redis(
    host=os.environ.get(
        'REDIS_HOST', 'redis-15860.c322.us-east-1-2.ec2.cloud.redislabs.com'
    ),
    port=int(os.environ.get('REDIS_PORT', '15860')),
    password=os.environ.get('REDIS_PASSWORD'),
)
# One-off ingestion script, disabled by wrapping it in a string literal (it is
# evaluated as a string and discarded, never executed). When re-enabled it:
#   - reads "coms3.csv" and embeds each row's 'item_keywords' with the
#     'sentence-transformers/all-distilroberta-v1' model (the same model name the
#     query side uses),
#   - calls redis_conn.flushall(), which wipes EVERY key in the Redis database,
#   - creates a FLAT vector index (COSINE, dim 768) plus tag/text fields,
#   - HSETs each row under the key "company:<index>:<primary_key>".
# NOTE(review): in load_vectors the bare `except:` prints `key`, which is
# unbound (NameError) if the very first row fails, and otherwise prints the
# previous row's key. Fix before re-enabling.
# NOTE(review): vector_dict is a list indexed by the DataFrame index labels, so
# this presumably assumes the CSV loads with a default 0..n-1 RangeIndex.
'''
df = pd.read_csv("coms3.csv")
print(list(df))
print(df['item_keywords'].sample(2))
company_metadata = df.to_dict(orient='index')
model = SentenceTransformer('sentence-transformers/all-distilroberta-v1')
item_keywords = [company_metadata[i]['item_keywords'] for i in company_metadata.keys()]
item_keywords_vectors = []
for sentence in tqdm(item_keywords):
s = model.encode(sentence)
item_keywords_vectors.append(s)
print(company_metadata[0])
def load_vectors(client, company_metadata, vector_dict, vector_field_name):
p = client.pipeline(transaction=False)
for index in company_metadata.keys():
#hash key
#print(index)
#print(company_metadata[index]['company_l_id'])
try:
key=str('company:'+ str(index)+ ':' + company_metadata[index]['primary_key'])
except:
print(key)
continue
#hash values
item_metadata = company_metadata[index]
item_keywords_vector = vector_dict[index].astype(np.float32).tobytes()
item_metadata[vector_field_name]=item_keywords_vector
# HSET
p.hset(key,mapping=item_metadata)
p.execute()
def create_flat_index (redis_conn,vector_field_name,number_of_vectors, vector_dimensions=512, distance_metric='L2'):
redis_conn.ft().create_index([
VectorField(vector_field_name, "FLAT", {"TYPE": "FLOAT32", "DIM": vector_dimensions, "DISTANCE_METRIC": distance_metric, "INITIAL_CAP": number_of_vectors, "BLOCK_SIZE":number_of_vectors }),
TagField("company_l_id"),
TextField("company_name"),
TextField("item_keywords"),
TagField("industry")
])
ITEM_KEYWORD_EMBEDDING_FIELD='item_keyword_vector'
TEXT_EMBEDDING_DIMENSION=768
NUMBER_COMPANIES=1000
print ('Loading and Indexing + ' + str(NUMBER_COMPANIES) + 'companies')
#flush all data
redis_conn.flushall()
#create flat index & load vectors
create_flat_index(redis_conn, ITEM_KEYWORD_EMBEDDING_FIELD,NUMBER_COMPANIES,TEXT_EMBEDDING_DIMENSION,'COSINE')
load_vectors(redis_conn,company_metadata,item_keywords_vectors,ITEM_KEYWORD_EMBEDDING_FIELD)
'''
# Name of the Redis hash field / index field that stores keyword vectors.
ITEM_KEYWORD_EMBEDDING_FIELD = 'item_keyword_vector'
# Vector width; presumably the output size of the encoder loaded below.
TEXT_EMBEDDING_DIMENSION = 768
# Appears unused in this file; kept in case other code imports it.
NUMBER_PRODUCTS = 1000

# Sentence encoder used to vectorise search queries. Loaded once at import time
# so every request reuses the same model instance.
model = SentenceTransformer(
    'sentence-transformers/all-distilroberta-v1'
)
# Prompt for the first LLM call: turn the user's free-text request into search
# keywords for the vector query. It previously declared `company_description`
# without a placeholder for it (which validating LangChain versions reject) and
# was immediately overwritten by the chat prompt below, so answer() ended up
# generating its "keywords" with the chat prompt instead.
keyword_prompt = PromptTemplate(
    input_variables=["company_description"],
    template=(
        'Create comma seperated company keywords to perform a query on a '
        'company dataset for this user input'
        '\n\nUser input: {company_description}\nKeywords:'
    ),
)

# Prompt for the final LLM call: present the retrieved search results.
# {chat_history} is supplied by the ConversationBufferMemory in answer().
template = """You are a chatbot. Be kind, detailed and nice. Present the given queried search result in a nice way as answer to the user input. dont ask questions back! just take the given context
{chat_history}
Human: {user_question}
Chatbot:
"""
prompt = PromptTemplate(
    input_variables=["chat_history", "user_question"],
    template=template,
)
# Kept for backward compatibility; answer() no longer reads it.
chat_history = ""


def answer(user_input):
    """Answer a user request using company search results from Redis.

    Pipeline:
      1. Ask the PaLM LLM to turn ``user_input`` into search keywords.
      2. Embed the keywords with ``model`` and run a KNN vector query against
         the Redis search index.
      3. Ask the LLM to present the top hits in reply to ``user_input``.

    Args:
        user_input: Free-text question typed into the Gradio text box.

    Returns:
        The LLM's reply as a string.

    Raises:
        KeyError: if the ``PALM`` environment variable (API key) is not set.
    """
    llm = GooglePalm(temperature=0, google_api_key=os.environ['PALM'])

    # Step 1: generate search keywords from the user's request.
    keyword_chain = LLMChain(llm=llm, prompt=keyword_prompt)
    keywords = keyword_chain.run(company_description=user_input)

    # Step 2: KNN search over the stored keyword vectors.
    top_k = 3
    query_vector = model.encode(keywords).astype(np.float32).tobytes()
    query = (
        Query(f'*=>[KNN {top_k} @{ITEM_KEYWORD_EMBEDDING_FIELD} $vec_param AS vector_score]')
        .sort_by('vector_score')
        .paging(0, top_k)
        .return_fields('vector_score', 'item_name', 'item_id', 'item_keywords')
        .dialect(2)
    )
    results = redis_conn.ft().search(query, query_params={"vec_param": query_vector})

    # getattr: a hit whose hash lacks 'item_keywords' would otherwise raise
    # AttributeError and abort the whole request.
    full_result_string = ''.join(
        f"{doc.id} {getattr(doc, 'item_keywords', '')}\n\n\n"
        for doc in results.docs
    )

    # Step 3: present the results. The memory is created per call on purpose:
    # a module-level memory would presumably be shared by every user of the
    # Gradio app. As a consequence no history carries over between requests.
    memory = ConversationBufferMemory(memory_key="chat_history")
    llm_chain = LLMChain(
        llm=llm,
        prompt=prompt,
        verbose=False,
        memory=memory,
    )
    return llm_chain.predict(user_question=f"{full_result_string} ---\n\n {user_input}")
# Gradio UI: one text box in, the chatbot's text reply (from answer()) out.
demo = gr.Interface(
    fn=answer,
    inputs=["text"],
    outputs=["text"],
    title="Ask Sonity",
)

# Start the web server only when run as a script, so importing this module
# (e.g. from tests or another app) does not launch the UI.
# NOTE(review): share=True publishes a public gradio.live link to this app;
# confirm that is intended.
if __name__ == "__main__":
    demo.launch(share=True)
|