Spaces:
Runtime error
Runtime error
import langchain.document_loaders | |
from langchain.document_loaders import DirectoryLoader, PyPDFLoader | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
from langchain.schema import Document | |
from langchain.embeddings import OpenAIEmbeddings | |
from langchain.vectorstores.chroma import Chroma | |
import os | |
import shutil | |
from langchain.vectorstores.chroma import Chroma | |
from langchain.embeddings import OpenAIEmbeddings | |
from langchain.chat_models import ChatOpenAI | |
from langchain.prompts import ChatPromptTemplate | |
def get_chunks(file_path, *, chunk_size=300, chunk_overlap=100):
    """Load a PDF and split its pages into overlapping text chunks.

    Args:
        file_path: Path to the PDF file, passed to ``PyPDFLoader``.
        chunk_size: Maximum chunk length in characters (``length_function=len``).
            Defaults to 300, the previously hard-coded value.
        chunk_overlap: Number of characters shared by consecutive chunks.
            Defaults to 100, the previously hard-coded value.

    Returns:
        The chunk documents produced by
        ``RecursiveCharacterTextSplitter.split_documents``.
    """
    documents = PyPDFLoader(file_path).load()
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=len,
        # Ask the splitter to record each chunk's start offset in its metadata.
        add_start_index=True,
    )
    return text_splitter.split_documents(documents)
def get_vectordb(chunks, CHROMA_PATH, *, chroma_root="../../chroma"):
    """Open the persisted Chroma store named *CHROMA_PATH*, creating it if absent.

    Args:
        chunks: Documents to embed when the store does not exist yet.
            Ignored when an existing store is opened.
        CHROMA_PATH: Name of the store directory under *chroma_root*.
            The upper-case name is kept so keyword callers keep working.
        chroma_root: Directory that holds the stores. Defaults to the
            previously hard-coded ``"../../chroma"``, which is resolved
            relative to the current working directory.

    Returns:
        The ``Chroma`` vector store.
    """
    persist_dir = f"{chroma_root}/{CHROMA_PATH}"
    embeddings = OpenAIEmbeddings()

    if os.path.exists(persist_dir):
        # Reuse the existing store. Nothing new is embedded here, so there is
        # nothing to persist and no chunks were "saved".
        db = Chroma(persist_directory=persist_dir, embedding_function=embeddings)
        print(f"Loaded existing vector store from {persist_dir}.")
        return db

    db = Chroma.from_documents(chunks, embeddings, persist_directory=persist_dir)
    db.persist()
    print(f"Saved {len(chunks)} chunks to {persist_dir}.")
    return db
def gen_sample(text, decision, db, *, k=5):
    """Generate the next passage of the story, grounded in retrieved source text.

    Args:
        text: The passage the reader is currently at.
        decision: The choice the reader made in response to *text*.
        db: Vector store providing ``similarity_search_with_relevance_scores``
            (e.g. the store returned by ``get_vectordb``).
        k: Number of context chunks to retrieve. Defaults to 5, the
            previously hard-coded value.

    Returns:
        The model's reply. If the reply is a Python literal (string, number,
        list, dict, ...), the parsed value is returned, matching what the
        former ``eval`` call produced. Otherwise the raw reply text is
        returned; this is the normal case, because the prompt asks for prose.
    """
    # literal_eval accepts only literals, so model output is never executed
    # (the previous eval() ran arbitrary model output as code).
    import ast

    PROMPT_TEMPLATE = """
Answer the question based only on the following context:
{context}
---
Answer the question based on the above context: {question}
"""
    query_text = f"""
Act as the author of a Choose Your Own Adventure Book. This book is special as it is based on existing material.
Now, as with any choose your own adventure book, there are inifinite paths based on the choices a user makes.
Given some relevant text and the decision taken with respect to the relevant text, generate the next part of the story.
It should be within 6-8 sentences and be coherent as it were actually part of the story.
Relevant: {text}
Decision: {decision}
"""
    # NOTE(review): the retrieval query includes the instruction boilerplate
    # above, not just `text` and `decision`. That may dilute the similarity
    # search; consider querying with the story content alone.
    results = db.similarity_search_with_relevance_scores(query_text, k=k)
    context_text = "\n\n---\n\n".join(doc.page_content for doc, _score in results)

    prompt = ChatPromptTemplate.from_template(PROMPT_TEMPLATE).format(
        context=context_text, question=query_text
    )
    response_text = ChatOpenAI().predict(prompt)

    try:
        return ast.literal_eval(response_text)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # A prose reply is not a literal; return it unchanged.
        return response_text