# app1/model.py
# Captured page header: Sandaruth · temp · 58df5ed · raw · history blame · 2.1 kB
import os
from dotenv import load_dotenv
from prompts import qa_template_V0, qa_template_V1, qa_template_V2
# Load environment variables from a .env file into os.environ (by default
# python-dotenv does not override variables that are already set).
load_dotenv()
# OpenAI API key, or None when it is neither in the environment nor in .env.
# Nothing active in this module uses it; only the commented-out ChatOpenAI
# line refers to an OpenAI model.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
# os.environ only accepts str values: unconditionally assigning None raised
# TypeError and aborted the import whenever the key was missing, even though
# the active LLM is Anyscale-hosted. Re-export only when a key is present.
if OPENAI_API_KEY is not None:
    os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
from langchain_openai import ChatOpenAI
# llm_OpenAi = ChatOpenAI(model="gpt-3.5-turbo", temperature=0,)
from langchain.chat_models import ChatAnyscale
# Anyscale-hosted Mistral-7B-Instruct chat model (deterministic: temperature 0,
# non-streaming); this is the `llm` the Web_qa retrieval chain is built on.
# Both token names are kept bound because other modules may import either.
# NOTE(review): when ANYSCALE_ENDPOINT_TOKEN is unset, None is passed through
# to ChatAnyscale; confirm how the installed langchain version resolves that.
ANYSCALE_ENDPOINT_TOKEN = anyscale_api_key = os.environ.get("ANYSCALE_ENDPOINT_TOKEN")
llm = ChatAnyscale(
    model_name='mistralai/Mistral-7B-Instruct-v0.1',
    anyscale_api_key=anyscale_api_key,
    temperature=0,
    streaming=False,
)
## Create embeddings and splitter
from langchain.embeddings import HuggingFaceBgeEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
# Sentence embeddings from BAAI's BGE-large (English) model, with
# normalize_embeddings=True forwarded to the encoder.
# NOTE(review): presumably this must be the same model (and normalization)
# the FAISS index loaded below was built with — confirm.
# To run on a GPU, also pass model_kwargs={'device': 'cuda'}.
model_name = "BAAI/bge-large-en"
embedding = HuggingFaceBgeEmbeddings(
    encode_kwargs=dict(normalize_embeddings=True),
    model_name=model_name,
)
# Recursive text splitter: target chunk_size 1000 with 100 of overlap between
# consecutive chunks.
# NOTE(review): not used by the code visible in this module, and the index
# directory loaded below is named "...C500..."; confirm which chunk size
# callers/index builds actually expect.
splitter = RecursiveCharacterTextSplitter(chunk_overlap=100, chunk_size=1000)
from langchain_community.vectorstores import FAISS
# Pre-built FAISS index loaded from disk with the embedding model above.
# Other index directories previously used here:
#   ./faiss_Test02_500_C_BGE_large
#   ./faiss_V03_C500_BGE_large-final
#   ./faiss_V03_C1000_BGE_large-final
#   ./faiss_V04_C500_BGE_large-final
# NOTE(review): FAISS.load_local unpickles the docstore stored in this
# directory, so only point it at trusted data. Newer langchain_community
# releases refuse to load unless allow_dangerous_deserialization=True is
# passed — confirm against the installed version.
# The "persits" spelling is kept because other modules may import this name.
persits_directory = "./faiss_V04_C500_BGE_large_web_doc_with_split-final"
vectorstore = FAISS.load_local(
    persits_directory,
    embedding,
)
# Custom QA prompt for the user manual: template qa_template_V2 (from
# prompts), filled with the retrieved {context} and the user's {question}.
from langchain.prompts import PromptTemplate
QA_PROMPT = PromptTemplate(
    template=qa_template_V2,
    input_variables=["context", "question"],
)
# Retrieval-QA chain for the web docs: retrieve the top k=4 chunks from the
# FAISS store, "stuff" them into QA_PROMPT and ask `llm`. The chain's input
# key is "question", and source documents are returned alongside the answer.
from langchain.chains import RetrievalQA
Web_qa = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    chain_type_kwargs=dict(prompt=QA_PROMPT),
    retriever=vectorstore.as_retriever(search_kwargs=dict(k=4)),
    input_key="question",
    return_source_documents=True,
)