from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import PromptTemplate
from uuid import uuid4
from prompt import *
import random
from itext2kg.models import KnowledgeGraph
from langchain.text_splitter import RecursiveCharacterTextSplitter
import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore
from pydantic import BaseModel, Field
from dotenv import load_dotenv
import os
from langchain_core.tools import tool
import pickle
import unicodedata

load_dotenv()
index_name = os.environ.get("INDEX_NAME")

# Global initialization
embedding_model = "text-embedding-3-small"
embedding = OpenAIEmbeddings(model=embedding_model)
# vector_store = PineconeVectorStore(index=index_name, embedding=embedding)
def advanced_graph_to_json(graph: KnowledgeGraph) -> dict:
    """Flatten an itext2kg KnowledgeGraph into a JSON-serializable dict of nodes and edges."""
    nodes = []
    edges = []
    for node in graph.entities:
        node_id = node.name.replace(" ", "_")
        label = node.name
        node_type = node.label  # renamed to avoid shadowing the built-in `type`
        nodes.append({"id": node_id, "label": label, "type": node_type})
    for relationship in graph.relationships:
        source = relationship.startEntity
        source_id = source.name.replace(" ", "_")
        target = relationship.endEntity
        target_id = target.name.replace(" ", "_")
        label = relationship.name
        edges.append({"source": source_id, "label": label, "cible": target_id})
    # The French keys ("noeuds", "relations", "cible") are kept as-is:
    # downstream consumers of this JSON may depend on them.
    return {"noeuds": nodes, "relations": edges}
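# Illustrative sketch of the returned shape (entity and relationship names here
# are made up, not from the actual graph):
# {
#     "noeuds": [{"id": "Magritte", "label": "Magritte", "type": "Person"}, ...],
#     "relations": [{"source": "Magritte", "label": "painted", "cible": "La_Trahison_des_images"}, ...],
# }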
with open("kg_ia_signature.pkl", "rb") as file:
    loaded_graph = pickle.load(file)
graph = advanced_graph_to_json(loaded_graph)
print("Graph loaded")

with open("chunks_ia_signature.pkl", "rb") as file:
    chunks = pickle.load(file)
print("Chunks loaded")

with open("scenes.pkl", "rb") as file:
    scenes = pickle.load(file)
print("Scenes loaded")
class SphinxOutput(BaseModel):
    question: str = Field(description="The question to ask the user to test if they read the entire book")
    answers: list[str] = Field(description="The possible answers to the question to test if the user read the entire book")

class VerifyResponseModel(BaseModel):
    response: str = Field(description="The response from the user to the question")
    answers: list[str] = Field(description="The possible answers to the question to test if the user read the entire book")
    initial_question: str = Field(description="The question asked to the user to test if they read the entire book")

class VerificationScore(BaseModel):
    score: float = Field(description="The score of the user's response from 0 to 10 to the question")

llm = ChatOpenAI(model="gpt-4o", max_tokens=300, temperature=0.5)
def split_texts(text: str) -> list[str]:
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=1000,
        chunk_overlap=200,
        length_function=len,
        is_separator_regex=False,
    )
    return splitter.split_text(text)
#########################################################################
### CHOOSE A SPECIFIC SCENE HERE VIA THE FUNCTION'S ARGUMENT
def get_random_chunk(scene_specific=(4, 6)):  # scene_specific = None means the whole story is considered
    # A tuple default avoids the mutable-default-argument pitfall.
    if scene_specific:
        scene_specific_content = " ".join(scenes[i - 1] for i in scene_specific)
        chunks_scene = split_texts(scene_specific_content)
        print(f"Scene {scene_specific} has {len(chunks_scene)} chunks")
        print([chunk[0:50] for chunk in chunks_scene])
        print('---')
        chunk_chosen = random.choice(chunks_scene)
        print(f"Chosen chunk: {chunk_chosen}")
        return chunk_chosen, scene_specific
    return random.choice(chunks), scene_specific
def get_vectorstore() -> FAISS:
    # Size the FAISS index from the embedding dimensionality of a probe query.
    index = faiss.IndexFlatL2(len(embedding.embed_query("hello world")))
    vector_store = FAISS(
        embedding_function=embedding,
        index=index,
        docstore=InMemoryDocstore(),
        index_to_docstore_id={},
    )
    documents = [Document(page_content=chunk) for chunk in chunks]
    uuids = [str(uuid4()) for _ in range(len(documents))]
    vector_store.add_documents(documents=documents, ids=uuids)
    return vector_store

vector_store = get_vectorstore()
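# Minimal usage sketch (illustrative query string, commented out so module
# import stays side-effect free): the FAISS store also supports direct
# similarity search in addition to the retriever interface used below.
# hits = vector_store.similarity_search("limites planétaires", k=3)
# print([hit.page_content[:50] for hit in hits])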
def generate_sphinx_response() -> SphinxOutput:
    writer = "Laurent Tripied"
    book_name = "Limites de l'imaginaire ou limites planétaires"
    summary = summary_text
    excerpt, scene_number = get_random_chunk()
    if scene_number:
        summary = "scene " + str(scene_number)
    prompt = PromptTemplate.from_template(template_sphinx)
    structured_llm = llm.with_structured_output(SphinxOutput)
    # Chain the prompt with the structured-output LLM
    llm_chain = prompt | structured_llm
    return llm_chain.invoke({"writer": writer, "book_name": book_name, "summary": summary, "excerpt": excerpt})
#############################################################
### CHOOSE THE STRICTNESS OF THE VERIFICATION HERE
def verify_response(response: str, answers: list[str], question: str) -> bool:
    prompt = PromptTemplate.from_template(template_verify)
    structured_llm = llm.with_structured_output(VerificationScore)
    llm_chain = prompt | structured_llm
    score = llm_chain.invoke({"response": response, "answers": answers, "initial_question": question})
    # Always return a bool (the original fell through to None below the threshold).
    return score.score >= 5
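# Sketch of one quiz round-trip (the user answer is hypothetical; commented
# out so it does not run at import time):
# quiz = generate_sphinx_response()
# print(quiz.question, quiz.answers)
# passed = verify_response("ma réponse", quiz.answers, quiz.question)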
def retrieve_context_from_vectorstore(query: str) -> list[Document]:
    # MMR retrieval returns a list of Documents, not a plain string.
    retriever = vector_store.as_retriever(search_type="mmr", search_kwargs={"k": 3})
    return retriever.invoke(query)
def generate_stream(query: str, messages=None, model="gpt-4o-mini", max_tokens=300, temperature=1, index_name="", stream=True, vector_store=None):
    # `messages`, `model`, `max_tokens`, `temperature`, `index_name` and
    # `vector_store` are kept for API compatibility but unused here.
    try:
        print("init chat")
        print("init template")
        prompt = PromptTemplate.from_template(template)
        writer = "Laurent Tripied"
        name_book = "Limites de l'imaginaire ou limites planétaires"
        name_icon = "Magritte"
        kg = graph  # JSON view of the knowledge graph, built at load time
        print("retrieving context")
        context = retrieve_context_from_vectorstore(query)
        print(f"Context: {context}")
        llm_chain = prompt | llm | StrOutputParser()
        print("streaming")
        inputs = {"name_book": name_book, "writer": writer, "name_icon": name_icon, "kg": kg, "context": context, "query": query}
        if stream:
            return llm_chain.stream(inputs)
        else:
            return llm_chain.invoke(inputs)
    except Exception as e:
        print(e)
        return False
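if __name__ == "__main__":
    # Smoke test, assuming the .pkl files and an OPENAI_API_KEY are available;
    # the query string is illustrative. generate_stream returns False on failure,
    # otherwise a token generator when stream=True.
    result = generate_stream("Qui est Magritte ?")
    if result:
        for token in result:
            print(token, end="", flush=True)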