# utils.py
import streamlit as st
import os
import re
import pandas as pd
from langchain_pinecone import PineconeVectorStore
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.docstore.document import Document
from dotenv import load_dotenv
from pinecone import Pinecone
from openai import OpenAI
# Load environment variables from a local .env file at import time
# (PINECONE/OPENAI keys, org & project ids used by the helpers below).
load_dotenv()
# Initialize OpenAI client
def get_openai_client():
    """Build an OpenAI API client scoped to the configured org and project.

    Organization and project ids are read from the ``OPENAI_ORG_ID`` and
    ``OPENAI_PROJECT_ID`` environment variables. Note that no API key is
    passed explicitly here — it is expected to be available to the SDK from
    the environment.
    """
    org_id = os.getenv('OPENAI_ORG_ID')
    project_id = os.getenv('OPENAI_PROJECT_ID')
    return OpenAI(organization=org_id, project=project_id)
# Initialize embeddings
@st.cache_resource
def initialize_embeddings(model_name: str = "all-mpnet-base-v2"):
    """Load a HuggingFace sentence-embedding model (cached by Streamlit).

    Args:
        model_name: sentence-transformers model identifier.

    Returns:
        A ``HuggingFaceEmbeddings`` instance for the given model.
    """
    return HuggingFaceEmbeddings(model_name=model_name)
# Initialize vector store
@st.cache_resource
def initialize_vector_store(pinecone_api_key: str, index_name: str):
    """Connect to a Pinecone index and wrap it in a LangChain vector store.

    Cached by Streamlit so the connection and embedding model are created
    once per session/configuration.

    Args:
        pinecone_api_key: API key used to authenticate with Pinecone.
        index_name: name of the Pinecone index to query.

    Returns:
        Tuple of ``(vector_store, embeddings)`` — the store reads document
        text from the ``'content'`` metadata key.
    """
    pinecone_client = Pinecone(api_key=pinecone_api_key)
    pinecone_index = pinecone_client.Index(index_name)
    embedding_model = initialize_embeddings()
    store = PineconeVectorStore(
        index=pinecone_index,
        embedding=embedding_model,
        text_key='content',
    )
    return store, embedding_model
# Fetch documents based on query and filters
def get_docs(vector_store, embeddings, query, country=None, vulnerability_cat=None):
    """Retrieve the top-k passages for a query, optionally filtered by metadata.

    Args:
        vector_store: LangChain Pinecone vector store to search.
        embeddings: embedding model used to embed the query text.
        query: free-text user question.
        country: list of country names to filter on, or None/empty for all.
        vulnerability_cat: list of vulnerability categories to filter on,
            or None/empty for all.

    Returns:
        List of ``Document`` objects, each carrying the original metadata plus
        a 1-based ``ref_id`` (used for citations) and the similarity ``score``.
    """
    # NOTE: defaults are None (not []) — mutable default arguments are shared
    # across calls in Python. Falsy values normalize to the "no filter" case.
    if not country:
        country = "All Countries"
    # Build the Pinecone metadata filter from whichever facets were supplied.
    if not vulnerability_cat:
        filters = None if country == "All Countries" else {'country': {'$in': country}}
    else:
        if country == "All Countries":
            filters = {'vulnerability_cat': {'$in': vulnerability_cat}}
        else:
            filters = {
                'country': {'$in': country},
                'vulnerability_cat': {'$in': vulnerability_cat}
            }
    docs = vector_store.similarity_search_by_vector_with_score(
        embeddings.embed_query(query),
        k=20,
        filter=filters,
    )
    # Flatten (Document, score) pairs into one row per hit.
    docs_dict = [
        {**doc.metadata, "score": score, "content": doc.page_content}
        for doc, score in docs
    ]
    # drop=True: avoid injecting a spurious 'index' column into the frame.
    df_docs = pd.DataFrame(docs_dict).reset_index(drop=True)
    df_docs['ref_id'] = df_docs.index + 1  # 1-based citation ids for the prompt
    # Re-wrap rows as Documents so downstream prompt/citation helpers can
    # consume a uniform interface.
    ls_dict = [
        Document(
            page_content=row['content'],
            metadata={
                'country': row['country'],
                'document': row['document'],
                'page': row['page'],
                'file_name': row['file_name'],
                'ref_id': row['ref_id'],
                'vulnerability_cat': row['vulnerability_cat'],
                'score': row['score']
            }
        )
        for _, row in df_docs.iterrows()
    ]
    return ls_dict
# Extract references from the response
def get_refs(docs, res):
    """Build a markdown reference list for the passages cited in a response.

    Scans ``res`` for citation markers of the form "Ref. <n>"
    (case-insensitive) and emits one markdown entry per matching document.

    Args:
        docs: iterable of Document-like objects exposing ``.metadata`` (with
            keys 'ref_id', 'country', 'document', 'page', 'file_name',
            'vulnerability_cat') and ``.page_content``.
        res: model response text to scan for citations.

    Returns:
        Markdown string of cited references (empty if nothing was cited).
    """
    # Lowercase once so "Ref. 3" and "ref. 3" both match the pattern.
    ref_ids = {int(m) for m in re.findall(r'ref\. (\d+)', res.lower())}
    # Accumulate parts and join once — avoids quadratic += string building.
    parts = []
    for doc in docs:
        metadata = doc.metadata
        if metadata['ref_id'] not in ref_ids:
            continue
        # Supplementary documents additionally display their file name.
        if metadata['document'] == "Supplementary":
            parts.append(
                f"**Ref. {metadata['ref_id']} [{metadata['country']} {metadata['document']}: "
                f"{metadata['file_name']} p{metadata['page']}; "
                f"vulnerabilities: {metadata['vulnerability_cat']}]:** *'{doc.page_content}'*\n"
            )
        else:
            parts.append(
                f"**Ref. {metadata['ref_id']} [{metadata['country']} {metadata['document']} "
                f"p{metadata['page']}; "
                f"vulnerabilities: {metadata['vulnerability_cat']}]:** *'{doc.page_content}'*\n"
            )
    return "".join(parts)
# Construct the prompt for the model
def get_prompt(prompt_template, docs, input_query):
    """Assemble the final LLM prompt from template, retrieved context, and question.

    Each retrieved passage is wrapped in ``&&&`` markers together with its
    ``ref_id`` so the model can cite sources as "ref. <n>".

    Args:
        prompt_template: instruction text prepended to the prompt.
        docs: Document-like objects with 'ref_id' and 'document' metadata.
        input_query: the user's question.

    Returns:
        The fully assembled prompt string.
    """
    snippets = []
    for doc in docs:
        meta = doc.metadata
        snippets.append(
            f"&&& [ref. {meta['ref_id']}] {meta['document']} &&&: {doc.page_content}"
        )
    context = ' - '.join(snippets)
    return f"{prompt_template}; Context: {context}; Question: {input_query}; Answer:"
# Execute the query and generate the response
def run_query(client, prompt, docs, res_box):
    """Stream a chat completion, render the answer, and return its references.

    Args:
        client: OpenAI client (see ``get_openai_client``).
        prompt: fully assembled prompt (see ``get_prompt``).
        docs: retrieved documents used to resolve citations.
        res_box: Streamlit placeholder the answer is rendered into.

    Returns:
        Markdown reference string produced by ``get_refs`` for the answer.
    """
    response_stream = client.chat.completions.create(
        model="gpt-4o-mini-2024-07-18",
        messages=[{"role": "user", "content": prompt}],
        stream=True,
    )
    # Collect streamed deltas; chunks with no content (role/stop events)
    # carry None and are skipped.
    pieces = []
    for event in response_stream:
        delta = event.choices[0].delta.content
        if delta is not None:
            pieces.append(delta)
    answer = "".join(pieces).strip()
    res_box.success(answer)
    return get_refs(docs, answer)