import os

import boto3
import gradio as gr
import openai
import pinecone
from dotenv import load_dotenv
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.text_splitter import CharacterTextSplitter
from langchain.docstore.document import Document
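# Load API keys and AWS credentials from the local .env file.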
load_dotenv()

openai.api_key = os.getenv("OPENAI_API_KEY")
pinecone_api_key = os.getenv("PINECONE_API_KEY")
aws_access_key = os.getenv("AWS_ACCESS_KEY_ID")
aws_secret_key = os.getenv("AWS_SECRET_ACCESS_KEY")
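
# Source document in S3 and the target Pinecone index.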
bucket_name = 'amtrak-superliner-ai-poc'
txt_file_name = 'combined_extracted_text.txt'
index_name = "amtrak-acela-ai-demo"
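
# Pinecone client (v3+ SDK).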
pc = pinecone.Pinecone(api_key=pinecone_api_key)
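
# S3 client used to download the combined source document.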
s3_client = boto3.client(
    's3',
    aws_access_key_id=aws_access_key,
    aws_secret_access_key=aws_secret_key,
    region_name='us-east-1'
)
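

# Create the serverless index on first use, then return a handle to it.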
def initialize_pinecone_index(index_name, embedding_dim):
    available_indexes = pc.list_indexes().names()
    if index_name not in available_indexes:
        pc.create_index(
            name=index_name,
            dimension=embedding_dim,
            metric="cosine",
            spec=pinecone.ServerlessSpec(
                cloud="aws",
                region="us-east-1"
            )
        )
    return pc.Index(index_name)
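

# msmarco-distilbert-base-v4 produces 768-dimensional vectors, so the index
# dimension must match.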
embedding_dim = 768
index = initialize_pinecone_index(index_name, embedding_dim)
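
# Local sentence-transformer used for both document and query embeddings.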
embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/msmarco-distilbert-base-v4")
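

# Fetch the combined source text from S3 and read it into memory.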
def download_text_from_s3(s3_client, bucket_name, file_name):
    local_txt_path = os.path.join(os.getcwd(), file_name)
    s3_client.download_file(bucket_name, file_name, local_txt_path)
    with open(local_txt_path, 'r', encoding='utf-8') as f:
        return f.read()


doc_text = download_text_from_s3(s3_client, bucket_name, txt_file_name)
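

# Split the document into overlapping chunks and embed each chunk.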
def process_text_into_embeddings(doc_text):
    text_splitter = CharacterTextSplitter(separator='\n', chunk_size=3000, chunk_overlap=500)
    docs = text_splitter.split_documents([Document(page_content=doc_text)])
    doc_embeddings = embedding_model.embed_documents([doc.page_content for doc in docs])
    return docs, doc_embeddings
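

# Report whether the index already holds vectors, so ingestion can be skipped.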
def check_embeddings_in_pinecone(index):
    try:
        stats = index.describe_index_stats()
        return stats['total_vector_count'] > 0
    except Exception as e:
        print(f"Error checking Pinecone index: {e}")
        return False
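

# Embed and upsert the document only on the first run.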
if not check_embeddings_in_pinecone(index):
    split_docs, doc_embeddings = process_text_into_embeddings(doc_text)
    for i, doc in enumerate(split_docs):
        metadata = {'content': doc.page_content}
        index.upsert(vectors=[(str(i), doc_embeddings[i], metadata)])
else:
    print("Embeddings already exist in Pinecone. Skipping embedding process.")
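

# Retrieve the most relevant chunks from Pinecone and have gpt-3.5-turbo
# answer the question against that context.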
def get_model_response(human_input, chat_history=None):
    try:
        # Embed the question and retrieve the three most similar chunks.
        query_embedding = embedding_model.embed_query(human_input)
        search_results = index.query(vector=query_embedding, top_k=3, include_metadata=True)

        context_list = []
        images = []

        for ind, result in enumerate(search_results['matches']):
            document_content = result.get('metadata', {}).get('content', 'No content found')
            # image_path and figure_description only exist if a separate
            # ingestion step stored them; the upsert above writes 'content' only.
            image_url = result.get('metadata', {}).get('image_path', None)
            figure_desc = result.get('metadata', {}).get('figure_description', '')

            context_list.append(f"Document {ind + 1}: {document_content}")

            if image_url and figure_desc:
                images.append((figure_desc, image_url))

        context_string = '\n\n'.join(context_list)

        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": f"Here is some context:\n{context_string}\n\nUser's question: {human_input}"}
        ]

        # Legacy (pre-1.0) OpenAI SDK call; with openai>=1.0 this becomes
        # openai.OpenAI().chat.completions.create(...).
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=messages,
            max_tokens=500,
            temperature=0.5
        )

        output_text = response['choices'][0]['message']['content'].strip()

        return output_text, images

    except Exception as e:
        return f"Error invoking model: {str(e)}", []
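

# Adapter for gr.ChatInterface, which calls fn(message, history) and
# expects a string reply.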
def get_model_response_with_history(human_input, chat_history=None):
    if chat_history is None:
        chat_history = []

    output_text, images = get_model_response(human_input, chat_history)

    # ChatInterface renders a plain string, so surface any retrieved
    # figures as text references appended to the reply.
    for figure_desc, image_url in images:
        output_text += f"\n\n{figure_desc}: {image_url}"

    return output_text
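

# Gradio chat UI wired to the retrieval-augmented responder.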
gr_interface = gr.ChatInterface(
    fn=get_model_response_with_history,
    title="Maintenance Assistant",
    description="Ask questions related to the RMM documents."
)

gr_interface.launch()