# ia_back/rag.py
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from uuid import uuid4
from prompt import *
from pydantic import BaseModel, Field
from dotenv import load_dotenv
import os
import random
from langchain_core.tools import tool
import unicodedata
load_dotenv()
index_name = os.environ.get("INDEX_NAME")
# Global initialization
embedding_model = "text-embedding-3-small"
embedding = OpenAIEmbeddings(model=embedding_model)
# vector_store = PineconeVectorStore(index=index_name, embedding=embedding)
class sphinx_output(BaseModel):
    question: str = Field(description="The question to ask the user to test if they read the entire book")
    answers: list[str] = Field(description="The possible answers to the question to test if the user read the entire book")
llm = ChatOpenAI(model="gpt-4o-mini", max_tokens=300, temperature=0.5)
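
# Illustrative sketch (an assumption, not part of the original file): one way to have the
# model fill the sphinx_output schema is to bind it with with_structured_output.
# generate_sphinx_question is a hypothetical helper name; the prompt text is an example.
def generate_sphinx_question(chunk: str) -> sphinx_output:
    structured_llm = llm.with_structured_output(sphinx_output)
    return structured_llm.invoke(
        "Based on the following excerpt, write one question and its possible answers "
        f"to check that the reader has read the whole book:\n\n{chunk}"
    )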
def get_random_chunk(chunks: list[str]) -> str:
    """Return a randomly selected chunk."""
    return chunks[random.randint(0, len(chunks) - 1)]
def get_vectorstore(chunks: list[str]) -> FAISS:
    """Build an in-memory FAISS vector store from the text chunks."""
    documents = [Document(page_content=chunk, id=str(uuid4())) for chunk in chunks]
    return FAISS.from_documents(documents, embedding)
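
# Hedged usage note (assumption, not in the original file): the FAISS store built above
# can back retrieval with similarity_search, for example:
#     store = get_vectorstore(chunks)
#     docs = store.similarity_search(query, k=4)
#     context = "\n\n".join(doc.page_content for doc in docs)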
def generate_stream(query: str, messages=None, model="gpt-4o-mini", max_tokens=300,
                    temperature=0.5, index_name="", stream=True, vector_store=None):
    """Answer a query with retrieved context, streaming tokens by default.

    `template` and `retreive_context` are expected to come from prompt.py
    (imported above with `from prompt import *`).
    """
    try:
        if messages is None:
            messages = []
        print("init chat")
        print("init template")
        prompt = PromptTemplate.from_template(template)
        print("retrieving context")
        context = retreive_context(query=query, index=index_name, vector_store=vector_store)
        print(f"Context: {context}")
        llm_chain = prompt | llm | StrOutputParser()
        print("streaming")
        if stream:
            return llm_chain.stream({"context": context, "history": messages, "query": query})
        return llm_chain.invoke({"context": context, "history": messages, "query": query})
    except Exception as e:
        print(e)
        return False
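
# Minimal manual test sketch (assumption, not part of the original file). It presumes
# prompt.py defines `template` and `retreive_context`; the sample chunks are placeholders.
if __name__ == "__main__":
    sample_chunks = ["The hero leaves the village.", "The hero returns home changed."]
    store = get_vectorstore(sample_chunks)
    for token in generate_stream("How does the story end?", vector_store=store):
        print(token, end="", flush=True)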